NumPyML Source Code Analysis (5)

2024-02-17 10:05:25

numpy-ml/numpy_ml/preprocessing/nlp.py

# Import the required libraries and modules
import re
import heapq
import os.path as op
from collections import Counter, OrderedDict, defaultdict
import numpy as np

# English stop-word list, from the "Glasgow Information Retrieval Group"
# (the full space-separated stop-word string is elided in this excerpt)
_STOP_WORDS = set(
    (
        "..."
    ).split(" "),
)

# Regexes used by the tokenizers to split strings into word tokens
_WORD_REGEX = re.compile(r"(?u)\b\w\w+\b")  # sklearn default
_WORD_REGEX_W_PUNC = re.compile(r"(?u)\w+|[^a-zA-Z0-9\s]")
_WORD_REGEX_W_PUNC_AND_WHITESPACE = re.compile(r"(?u)\s?\w+\s?|\s?[^a-zA-Z0-9\s]\s?")

# Regex matching the decimal byte values of punctuation characters
_PUNC_BYTE_REGEX = re.compile(
    r"(33|34|35|36|37|38|39|40|41|42|43|44|45|"
    r"46|47|58|59|60|61|62|63|64|91|92|93|94|"
    r"95|96|123|124|125|126)",
)
# Punctuation characters
_PUNCTUATION = "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
# Translation table used to strip punctuation from a string
_PUNC_TABLE = str.maketrans("", "", _PUNCTUATION)
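
# Illustrative example (added for this analysis, not part of the original nlp.py):
# how the word regex and the punctuation table behave. `_WORD_REGEX` keeps only
# word tokens of length >= 2, and `_PUNC_TABLE` deletes ASCII punctuation outright:
#
#   >>> _WORD_REGEX.findall("Hello, world! 42")
#   ['Hello', 'world', '42']
#   >>> "Hello, world!".translate(_PUNC_TABLE)
#   'Hello world'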

# Return all N-grams of the elements in a sequence
def ngrams(sequence, N):
    """Return all `N`-grams of the elements in `sequence`"""
    assert N >= 1
    return list(zip(*[sequence[i:] for i in range(N)]))
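
# Illustrative example (added for this analysis, not part of the original nlp.py):
# bigrams of a short token sequence.
#
#   >>> ngrams(["byte", "pair", "encoding"], 2)
#   [('byte', 'pair'), ('pair', 'encoding')]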

# Split a string on whitespace, optionally lowercasing it and filtering stop words and punctuation
def tokenize_whitespace(
    line, lowercase=True, filter_stopwords=True, filter_punctuation=True, **kwargs,
):
    """
    Split a string at any whitespace characters, optionally removing
    punctuation and stop-words in the process.
    """
    line = line.lower() if lowercase else line
    words = line.split()
    line = [strip_punctuation(w) for w in words] if filter_punctuation else line
    return remove_stop_words(words) if filter_stopwords else words
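
# Illustrative example (added for this analysis, not part of the original nlp.py).
# Note that the punctuation-stripped list is bound to `line` but never returned,
# so the whitespace tokens keep their trailing punctuation:
#
#   >>> tokenize_whitespace("Good night, moon!", filter_stopwords=False)
#   ['good', 'night,', 'moon!']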

# Split a string into words, optionally lowercasing it and filtering stop words and punctuation
def tokenize_words(
    line, lowercase=True, filter_stopwords=True, filter_punctuation=True, **kwargs,
):
    """
    Split a string into individual words, optionally removing punctuation and
    stop-words in the process.
    """
    REGEX = _WORD_REGEX if filter_punctuation else _WORD_REGEX_W_PUNC
    words = REGEX.findall(line.lower() if lowercase else line)
    return remove_stop_words(words) if filter_stopwords else words
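
# Illustrative example (added for this analysis, not part of the original nlp.py).
# With the sklearn-style regex, only word tokens of length >= 2 survive:
#
#   >>> tokenize_words("Hello, World! I am 42.", filter_stopwords=False)
#   ['hello', 'world', 'am', '42']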

# Split a string into words and convert each word to a list of bytes
def tokenize_words_bytes(
    line,
    # whether to lowercase the text
    lowercase=True,
    # whether to filter out stop words
    filter_stopwords=True,
    # whether to filter out punctuation
    filter_punctuation=True,
    # text encoding used when converting words to bytes
    encoding="utf-8",
    # any additional keyword arguments are forwarded to `tokenize_words`
    **kwargs,
):
    """
    Split a string into individual words, optionally removing punctuation and
    stop-words in the process. Translate each word into a list of bytes.
    """
    # tokenize into words, optionally lowercasing and filtering stop words / punctuation
    words = tokenize_words(
        line,
        lowercase=lowercase,
        filter_stopwords=filter_stopwords,
        filter_punctuation=filter_punctuation,
        **kwargs,
    )
    # represent each word as a space-separated string of its decimal byte values
    words = [" ".join([str(i) for i in w.encode(encoding)]) for w in words]
    # return the bytes, not chars
    return words
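
# Illustrative example (added for this analysis, not part of the original nlp.py):
# each word becomes a space-separated string of its decimal byte values:
#
#   >>> tokenize_words_bytes("Hi!", filter_stopwords=False, filter_punctuation=False)
#   ['104 105', '33']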


# Convert the characters in a string to a collection of bytes, each represented
# as a decimal integer between 0 and 255
def tokenize_bytes_raw(line, encoding="utf-8", splitter=None, **kwargs):
    # encode the string to bytes and join them as a space-separated string
    byte_str = [" ".join([str(i) for i in line.encode(encoding)])]
    # if splitting on punctuation, cut the byte string at punctuation bytes
    # before returning
    if splitter == "punctuation":
        byte_str = _PUNC_BYTE_REGEX.sub(r"-\1-", byte_str[0]).split("-")
    return byte_str
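
# Illustrative example (added for this analysis, not part of the original nlp.py).
# Without a splitter the whole line becomes one byte string; with
# `splitter="punctuation"` the string is cut at punctuation bytes (note the
# empty trailing entry):
#
#   >>> tokenize_bytes_raw("Hi!")
#   ['72 105 33']
#   >>> tokenize_bytes_raw("Hi!", splitter="punctuation")
#   ['72 105 ', '33', '']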


# Decode bytes (represented as integers between 0 and 255) into characters of
# the specified encoding
def bytes_to_chars(byte_list, encoding="utf-8"):
    # convert each integer byte to its hexadecimal representation
    hex_array = [hex(a).replace("0x", "") for a in byte_list]
    # join the hex strings, zero-padding single-digit values
    hex_array = " ".join([h if len(h) > 1 else f"0{h}" for h in hex_array])
    # convert the hex string to a byte array, then decode with the given encoding
    return bytearray.fromhex(hex_array).decode(encoding)
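
# Illustrative example (added for this analysis, not part of the original nlp.py):
# decoding the byte values 72, 105, 33 recovers the original characters:
#
#   >>> bytes_to_chars([72, 105, 33])
#   'Hi!'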


# Split a string into individual characters, optionally lowercasing it and
# removing punctuation in the process
def tokenize_chars(line, lowercase=True, filter_punctuation=True, **kwargs):
    """
    Split a string into individual characters, optionally removing punctuation
    and stop-words in the process.
    """
    # lowercase the string if requested
    line = line.lower() if lowercase else line
    # strip punctuation if requested
    line = strip_punctuation(line) if filter_punctuation else line
    # collapse runs of spaces to a single space, trim, and split into characters
    chars = list(re.sub(" {2,}", " ", line).strip())
    return chars
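
# Illustrative example (added for this analysis, not part of the original nlp.py):
#
#   >>> tokenize_chars("Hi there!")
#   ['h', 'i', ' ', 't', 'h', 'e', 'r', 'e']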


# Remove stop words from a list of word strings
def remove_stop_words(words):
    """Remove stop words from a list of word strings"""
    # keep only words that are not in the stop-word list
    return [w for w in words if w.lower() not in _STOP_WORDS]


# Remove punctuation from a string
def strip_punctuation(line):
    """Remove punctuation from a string"""
    # delete punctuation via _PUNC_TABLE, then trim surrounding whitespace
    return line.translate(_PUNC_TABLE).strip()
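
# Illustrative example (added for this analysis, not part of the original nlp.py):
#
#   >>> strip_punctuation("wait... what?!")
#   'wait what'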


#######################################################################
#                          Byte-Pair Encoder                          #
#######################################################################


# Byte-pair encoder class
class BytePairEncoder(object):
    def __init__(self, max_merges=3000, encoding="utf-8"):
        """
        A byte-pair encoder for sub-word embeddings.

        Notes
        -----
        Byte-pair encoding [1][2] is a compression algorithm that iteratively
        replaces the most frequently occurring byte pairs in a set of documents
        with a new, single token. It has gained popularity as a preprocessing
        step for many NLP tasks due to its simplicity and expressiveness: using
        a base codebook of just 256 unique tokens (bytes), any string can be
        encoded.

        References
        ----------
        .. [1] Gage, P. (1994). A new algorithm for data compression. *C
           Users Journal, 12(2)*, 23–38.
        .. [2] Sennrich, R., Haddow, B., & Birch, A. (2015). Neural machine
           translation of rare words with subword units, *Proceedings of the
           54th Annual Meeting of the Association for Computational
           Linguistics,* 1715-1725.

        Parameters
        ----------
        max_merges : int
            The maximum number of byte pair merges to perform during the
            :meth:`fit` operation. Default is 3000.
        encoding : str
            The encoding scheme for the documents used to train the encoder.
            Default is `'utf-8'`.
        """
        # store the hyperparameters
        self.parameters = {
            "max_merges": max_merges,
            "encoding": encoding,
        }

        # Initialize the byte-to-token and token-to-byte ordered dictionaries.
        # Bytes are represented in decimal as integers between 0 and 255; up to
        # 255 there is a one-to-one correspondence between token and byte
        # representations.
        self.byte2token = OrderedDict({i: i for i in range(256)})
        self.token2byte = OrderedDict({v: k for k, v in self.byte2token.items()})
    # Train a byte-pair codebook on a given corpus
    def fit(self, corpus_fps, encoding="utf-8"):
        """
        Train a byte pair codebook on a set of documents.

        Parameters
        ----------
        corpus_fps : str or list of strs
            The filepath / list of filepaths for the document(s) to be used to
            learn the byte pair codebook.
        encoding : str
            The text encoding for documents. Common entries are either 'utf-8'
            (no header byte), or 'utf-8-sig' (header byte). Default is
            'utf-8'.
        """
        # build a byte-level vocabulary over the corpus; `counts` maps each
        # byte string to its frequency
        vocab = (
            Vocabulary(
                lowercase=False,
                min_count=None,
                max_tokens=None,
                filter_stopwords=False,
                filter_punctuation=False,
                tokenizer="bytes",
            )
            # fit the vocabulary on the corpus files
            .fit(corpus_fps, encoding=encoding)
            # keep only the token counts
            .counts
        )

        # iteratively merge the most frequent byte bigram across the documents
        for _ in range(self.parameters["max_merges"]):
            # count the byte bigrams in the current vocabulary
            pair_counts = self._get_counts(vocab)
            # find the most frequent byte bigram
            most_common_bigram = max(pair_counts, key=pair_counts.get)
            # merge that bigram into a single token throughout the vocabulary
            vocab = self._merge(most_common_bigram, vocab)

        # collect the merged byte tokens (those containing "-") produced above
        token_bytes = set()
        for k in vocab.keys():
            # split each vocabulary entry on spaces and keep the merged tokens
            token_bytes = token_bytes.union([w for w in k.split(" ") if "-" in w])

        # assign each merged byte token a new ID above 255
        for i, t in enumerate(token_bytes):
            # convert the merged token, e.g. "104-105", into a tuple of ints
            byte_tuple = tuple(int(j) for j in t.split("-"))
            # map the new token ID to its byte tuple
            self.token2byte[256 + i] = byte_tuple
            # map the byte tuple back to its token ID
            self.byte2token[byte_tuple] = 256 + i

        return self

    # Collect bigram counts for the tokens in the vocabulary
    def _get_counts(self, vocab):
        """Collect bigram counts for the tokens in vocab"""
        # accumulate bigram counts in a default dict
        pair_counts = defaultdict(int)
        # iterate over each (word, count) pair in the vocabulary
        for word, count in vocab.items():
            # form the byte bigrams for this word
            pairs = ngrams(word.split(" "), 2)
            for p in pairs:
                # each bigram occurs `count` times, once per occurrence of the word
                pair_counts[p] += count
        return pair_counts
    # Replace the given bigram with a single merged token and update the
    # vocabulary accordingly
    def _merge(self, bigram, vocab):
        v_out = {}
        # escape the bigram string for use inside a regex
        bg = re.escape(" ".join(bigram))
        # match the bigram only when it is not embedded in a longer token
        bigram_regex = re.compile(r"(?<!\S)" + bg + r"(?!\S)")
        # rewrite every vocabulary entry, joining the merged bigram with "-"
        for word in vocab.keys():
            w_out = bigram_regex.sub("-".join(bigram), word)
            v_out[w_out] = vocab[word]
        return v_out
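
    # Worked example (added for this analysis, not part of the original nlp.py):
    # one merge step on a tiny byte-level vocabulary, where "104 105" is "hi"
    # (3 occurrences) and "104 97 116" is "hat" (2 occurrences). The most
    # frequent bigram ('104', '105') is fused into the single token "104-105":
    #
    #   >>> vocab = {"104 105": 3, "104 97 116": 2}
    #   >>> B = BytePairEncoder()
    #   >>> B._get_counts(vocab)
    #   defaultdict(<class 'int'>, {('104', '105'): 3, ('104', '97'): 2, ('97', '116'): 2})
    #   >>> B._merge(("104", "105"), vocab)
    #   {'104-105': 3, '104 97 116': 2}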

    # Transform the words in `text` into their byte-pair encoded token IDs
    def transform(self, text):
        """
        Transform the words in `text` into their byte pair encoded token IDs.

        Parameters
        ----------
        text: str or list of `N` strings
            The list of strings to encode

        Returns
        -------
        codes : list of `N` lists
            A list of byte pair token IDs for each of the `N` strings in
            `text`.

        Examples
        --------
        >>> B = BytePairEncoder(max_merges=100).fit("./example.txt")
        >>> encoded_tokens = B.transform("Hello! How are you 


	
