【Rust 研学】 | LLM 入门之旅 2 : BPE 算法

2024-04-22 11:29:11 浏览数 (2)

「 Rust 与 LLM」 是本合集的主题系列之一,本文为正文第二篇。 本合集将优先完成该主题系列文章,所以其他主题的文章优先级将降低。 「 Rust 与 LLM」主题系列将专注于自然语言处理、 Transfomer 架构和大模型相关内容,依托 Rust 开源生态和 HuggingFace 的相关 Rust 库,探秘从模型训练到模型部署、模型量化与 WebAssembly 轻量化部署的技术原理。

我们的作品是基于大模型实现的一个代码转译可视化工具,完全由 Rust 实现,也可能是这次赛事唯一一个用 Rust 实现的作品吧。

传统转译工具,比如 c2rust,其实是基于 ast 的转译方式,无法保留原项目架构的抽象信息,并且转译出来都是 unsafe 代码,实际应用效果不好。

本工具借助 Rust 实现了一个 AI Agent ,可以借助大模型的能力对 C/Cpp 进行转译,并得到更加安全的 Rust 实现。一共包括两个组件:一个是完全自动化的 cargo 插件,另一个是可以让开发者和 ai agent 无缝交互的终端操作界面

这个工具的开发时间一共加起来差不多十天左右,我和另一个小伙伴一起开发,分工协作,其实 Rust 开发效率还是非常不错的。在计算开发效率的时候,不要把 Rust 学习曲线也计算在内。

目前我正在重构这个工具,并且准备完善更多功能,在合适的时间点,我会开源这个项目,大家一起来玩。

有时候,我还是非常享受编程创造或解决某个问题(也是一种创造)的这个过程,包括在准备比赛作品的时候,我精神亢奋地通宵了两个晚上。以至于我顿悟:“如果编程不是为了创造,那和 AI 有何区别?”。因为 AI 仅仅是一个无情的代码生成机器,它不懂创造的乐趣。

然而更多的时候,我享受深入了解技术的实现和事物运行原理。所以 Rust 研学 LLM 系列文章也该继续更新了。

自然语言处理背景

在自然语言处理(NLP)中,标记化过程是文本预处理的一个关键步骤,通常发生在模型训练或预测的最初阶段。用 Transformer 架构(后续文章再讲)来说明时,大概分为以下几个步骤:

  1. 原始文本输入:在任何 NLP 任务开始之前,首先我们有原始的文本数据,这可以是句子、段落或整个文档。
  2. 标记化(Tokenization):在将文本输入 Transformer 模型之前,我们需要将文本转换为模型能理解的形式。标记化(或分词)就是这个过程的一部分,其中原始文本被分解成更小的单元(或词元,token)。这些标记可以是单词、子词或字符等。在许多现代应用中,特别是使用 BPE(字节对编码)或其变体(如 SentencePiece 或 WordPiece)进行子词标记化,可以有效处理未知词汇和减少词汇表的大小。
  3. 标记转换为ID:标记化之后,每个标记会被转换为一个唯一的数字ID,这些ID对应于模型词汇表中的条目。这一步是必要的,因为模型无法直接处理文本数据,而是通过这些数字ID来理解和生成文本。
  4. 输入 Transformer 模型:转换为 ID 的标记序列随后被输入到 Transformer 模型中。在模型内部,这些 ID 首先会通过嵌入层被转换为密集的向量表示,这些向量随后被用于模型的自注意力和其他处理层。
  5. 模型处理:Transformer 模型通过其多层自注意力机制和前馈网络处理输入的标记向量,执行所需的任务,如文本分类、翻译、摘要等。
  6. 输出处理:模型输出通常也是标记的形式,这些标记表示模型的预测结果。在生成任务中,如文本生成或机器翻译,输出标记序列将被转换回文本形式,以供最终用户使用。

我们这个 LLM 系列遵循这个自然语言处理过程。前面的文章讲述了分词器,属于标记化环节。本文再详细解读一下 BPE 算法的 Rust 实现思路

minbpe Rust 实现

关于 BPE 算法的概念如果你忘记了,可以再翻看本合集前面发的文章 [【Rust 研学】LLM 入门之旅番外篇 1.3 (上):OpenAI 工程师 Andrej 权威解读 GPT 分词器 ]( "【Rust 研学】LLM 入门之旅番外篇 1.3 (上):OpenAI 工程师 Andrej 权威解读 GPT 分词器 ")。

前 OpenAI 工程师 Karpathy 用 70 行纯 Python 编写了 最小化 BPE 算法实现:minbpe[1]。社区有人用 Rust 重新实现了 gnp/minbpe-rs[2] ,所以我们直接解读这个项目。

阅读源码和看书其实效果一样,只是学习作者的实现思路。但要想真正掌握它,应该还需要自己亲自再动手实现一遍,才能了解更多细节。后面找个时间我也会自己实现一遍。

minbpe-rs 是对原始 Python 版本 minbpe 的 Rust 移植,与 Python 版的文件组织是一一对应的。

  1. base.rs:
    • 它实现了 Tokenizer 类的基本框架,包含 train, encode, 和 decode 的基本结构和存取功能,以及一些共通的辅助函数。
    • 在 Rust 版本中,这个模块包括基本的 Tokenizer trait 和一些实用函数,但主要提供了用于被其他具体实现依赖的基础代码。
    • 对应于 Python 的: minbpe/base.py
    • 功能描述:
  2. basic.rs:
    • 实现了 BPE 算法的最简单形式,直接对文本进行处理。
    • 在 Rust 版本中,这个文件提供了基于字节级的 BPE 算法实现,即 BasicTokenizer,它处理直接输入的文本,并能进行训练、编码和解码。
    • 对应于 Python 的: minbpe/basic.py
    • 功能描述:
  3. regex.rs:
    • 实现了一个进一步通过正则表达式分割输入文本的分词器。
    • 在 Rust 版本中,这个模块包含了对文本的预处理步骤,使用正则表达式按类别分割文本(如字母、数字、标点符号等),以确保在类别边界不会进行合并。
    • 对应于 Python 的: minbpe/regex.py
    • 功能描述:
  4. gpt4.rs:
    • 一个轻量级的封装器,围绕 RegexTokenizer 实现,用于复现 GPT-4 的标记化过程。
    • 在 Rust 版本中,这个模块实现了特定的 GPT-4 标记化逻辑,处理一些特定的细节,如确保能够正确恢复 GPT-4 使用的特定合并和标记转换。
    • 对应于 Python 的: minbpe/gpt4.py
    • 功能描述:

Base.rs

base.rs 中定义了基本的 Tokenizer trait 和 一些 子 trait:

代码语言:javascript复制
/// Base trait for Tokenizers to implement.
pub trait Tokenizer {
 // 提供对特殊标记的访问,这些特殊标记通常用于处理如句子开始、结束等特定功能。
    fn special_tokens(&self) -> &IndexMap<String, Token>;

 // 提供对合并规则的访问,这些规则定义了在训练过程中哪些标记被合并。
    fn merges(&self) -> &IndexMap<(Token, Token), Token>;

 // 提供对词汇表的访问,这是标记化过程中的关键数据结构。
    fn vocab(&self) -> &IndexMap<Token, Vec<u8>>;

 // 将文本转换成一系列标记ID。这是文本处理中的基础步骤,用于后续的处理如模型训练或文本生成。
    /// A Tokenizer can encode a string into a list of integers.
    fn encode(&self, text: &str) -> Vec<Token>;

 // 将标记ID序列转换回原始文本。这通常用于生成文本后的输出阶段,验证标记化过程的准确性或用户界面展示
    /// A Tokenizer can decode a list of integers into a string.
    fn decode(&self, ids: &[Token]) -> String;
}

/// A Tokenizer that can be trained.
pub trait Trainable: Tokenizer {
 // 允许从给定的文本中训练词汇表和合并规则,
 // `vocab_size` 指定了目标词汇表大小,
 // `verbose` 标志控制是否输出详细的训练信息。
    /// Train a vocabulary of size `vocab_size` in distinct Tokens from `text`.
    fn train(&mut self, text: &str, vocab_size: Token, verbose: bool);
}


pub trait Saveable: Tokenizer {
    fn pattern(&self) -> &str;
    /// Saves the tokenizer's model and vocabulary to two files
    fn save(&self, dir: &Path, prefix: &str) {
        // let dir = dir.as_ref();

        // Write the model file (used for loading the tokenizer later)
        // 省略实现
    }
}


pub trait Loadable: Tokenizer {
    fn set_pattern(&mut self, pattern: &str);
    fn set_special_tokens(&mut self, special_tokens: IndexMap<String, Token>);
    fn set_merges(&mut self, merges: IndexMap<(Token, Token), Token>);
    fn set_vocab(&mut self, vocab: IndexMap<Token, Vec<u8>>);

    /// Loads the tokenizer's model from a file.
    fn load(&mut self, model_file: &Path) {
  // 省略实现
 }
}


这些 trait 提供了一个清晰的框架,通过定义一系列的接口来指定处理文本数据的标记化、训练、保存和加载功能。

我为什么说 Rust 提升了普通程序员的架构思维,就是这个原因。你用 Rust 的时候,需要面向接口编程。所谓面向接口,就是你需要思考系统变化的地方是什么。

你可以对照一下 Python 版本的 minbpe 面向对象的设计。相比于继承而言,面向接口的系统耦合性更低。

在这个案例中,Tokenizer trait 是基础的接口,定义了所有分词器应具备的核心功能。这包括能够对文本进行编码和解码,以及访问分词器的内部数据结构如词汇表、合并规则和特殊标记。

Trainable trait 扩展了 Tokenizer,为需要进行训练的分词器提供了额外的功能。这允许分词器根据实际文本数据学习和优化其内部的词汇表和合并规则。

Saveable trait 为分词器添加了保存功能。当分词器配置或训练代价高昂时,能够保存和重新加载是必要的。将分词器的状态(包括模型和词汇表)保存到指定的文件中,以便未来重用或分发。

Saveable 相对应,Loadable trait 允许从文件中加载先前保存的分词器状态。以便在需要时,复现实验结果或部署训练好的模型。

注意到在 Tokenizer trait 中使用了 IndexMap crate。这是因为 BPE 算法需要依赖元素插入的顺序,相比于 HashMapIndexMap 可以在保持插入顺序的同时,还提供了接近 HashMap 的性能。IndexMap 同时允许在更新时保持键的顺序,简化了合并字典和访问最频繁元素的逻辑。

除了接口之外,还有一些辅助函数:

  1. **get_statsupdate_stats**:用于计算和更新给定序列中连续标记对的出现次数。这对于 train 方法中的合并决策至关重要。
  2. **get_max_entry**:从统计数据中找到出现次数最多的标记对。这是选择合并操作的基础。
  3. **merge**:将序列中连续出现的标记对合并为一个新的标记。这是 BPE 算法中核心的合并步骤。
  4. **build_vocab**:根据特殊标记和合并历史构建词汇表。这个函数是在加载模型后重建词汇表的关键。
  5. **replace_control_charactersrender_token**:这些函数用于处理和格式化输出,特别是在创建可供人类阅读的词汇表文件时。

Basic.rs

basic.rs 定义了 BasicTokenizer 结构体,实现了字节级字节对编码(Byte Pair Encoding, BPE)算法的分词器。它直接操作文本,不处理正则表达式拆分模式或特殊标记。此实现主要参照了 GPT 分词器的算法。

代码语言:javascript复制
pub struct BasicTokenizer {
    special_tokens: IndexMap<String, Token>,
    merges: IndexMap<(Token, Token), Token>,
    vocab: IndexMap<Token, Vec<u8>>,
}


impl Tokenizer for BasicTokenizer {
    fn special_tokens(&self) -> &IndexMap<String, Token> {
        &self.special_tokens
    }

    fn merges(&self) -> &IndexMap<(Token, Token), Token> {
        &self.merges
    }

    fn vocab(&self) -> &IndexMap<Token, Vec<u8>> {
        &self.vocab
    }

    fn decode(&self, ids: &[Token]) -> String {
        // 将输入的标记ID序列转换成字符串
        // 通过遍历每个标记ID,从 `vocab` 映射中查找对应的字节序列
        // 然后将这些序列合并成一个完整的 UTF-8 字符串
        let text_bytes: Vec<u8> = ids
            .iter()
            .flat_map(|&idx| self.vocab[&idx].clone())
            .collect();
        // 将字节向量转换为字符串,
        // 这个方法会用特殊字符替换任何无效的 UTF-8 序列。
        String::from_utf8_lossy(&text_bytes).into_owned()
    }

    fn encode(&self, text: &str) -> Vec<Token> {
        // 将输入文本转换为其字节表示形式的序列
        let text_bytes = text.as_bytes();
        let mut ids: Vec<Token> = text_bytes.iter().map(|&b| b as Token).collect();
        while ids.len() >= 2 {
            // 找出序列中最常见的相邻标记对
            // 根据 `merges` 选择最低的合并索引来合并标记
            let stats = get_stats(&ids);

            let pair_opt = stats
                .keys()
                .filter_map(|&pair| self.merges.get(&pair).map(|_| pair))
                .min_by_key(|&pair| self.merges[&pair]);

   // 循环进行直到没有可合并的标记对为止
            match pair_opt {
                None => break, // If there are no more merges available, break
                Some(pair) => {
                    // Otherwise, merge the best pair (lowest merge index)
                    let idx = self.merges[&pair];
                    ids = merge(&ids, pair, idx);
                }
            };
        }
        ids
    }
}

impl Trainable for BasicTokenizer {
    /// 根据提供的文本训练一个大小为 `vocab_size` 的词汇表。
    /// 该方法实现了字节对编码(Byte Pair Encoding, BPE)算法,迭代地找出并合并最频繁的相邻标记对,
    /// 直到词汇表达到所需的大小。
    fn train(&mut self, text: &str, vocab_size: Token, verbose: bool) {
        // 确保请求的词汇表大小至少为256,以容纳所有单字节字符。
        assert!(vocab_size >= 256, "词汇表大小必须至少为256");

        // 计算需要创建的新标记数量,词汇表大小减去256(基础单字节字符的数量)。
        let num_merges = vocab_size - 256;

        // 将输入文本预处理为字节序列,每个字节视为一个初始标记。
        let text_bytes = text.as_bytes();
        let mut ids: Vec<Token> = text_bytes.iter().map(|&b| b as Token).collect();

        // 初始化合并记录和词汇表。
        let mut merges: IndexMap<(Token, Token), Token> = IndexMap::new();
        let mut vocab: IndexMap<Token, Vec<u8>> = (0..256).map(|idx| (idx, vec![idx as u8])).collect();

        // 迭代合并最常见的标记对,直到达到预设的词汇表大小或没有可合并的标记对。
        for i in 0..num_merges {
            // 计算当前所有相邻标记对的出现频率。
            let stats = get_stats(&ids);
            // 找到出现次数最多的标记对。
            let pair = get_max_entry(&stats).unwrap().0;

            // 为合并后的新标记分配一个新的ID(从256开始)。
            let idx = 256   i;
            // 在标记序列中替换所有出现的该标记对为新标记。
            ids = merge(&ids, *pair, idx);
            
            // 保存合并规则和更新词汇表。
            merges.insert(*pair, idx);
            vocab.insert(
                idx,
                [vocab[&pair.0].clone(), vocab[&pair.1].clone()].concat(),
            );

            // 如果设置为详细模式,则打印每次合并的详细信息。
            if verbose {
                println!(
                    "merge {}/{}: {:?} -> {} ({:?}) had {} occurrences",
                    i   1,
                    num_merges,
                    pair,
                    idx,
                    vocab[&idx],
                    stats[pair]
                );
            }
        }

        // 保存实例变量。
        self.merges = merges;
        self.vocab = vocab;
    }
}

代码说明可以参考文本注释。

为什么要实现字节级的 BPE ?有如下几个好处:

  • 处理未知词(OOV, Out-Of-Vocabulary)问题。字节级 BPE 通过将文本分解为更小的单位(字节而不是字符或单词),有效减少了未知词的问题。即使是未见过的词汇,也可以通过已知的字节组合来表示,这在处理多样化或专业领域的文本时尤其重要。
  • 词汇表大小可控。字节级 BPE 允许通过合并频繁出现的字节对来动态构建词汇表,最终词汇表的大小是可控的,这对模型的效率和性能都有积极影响。
  • 语言无关性。字节级处理意味着算法不依赖于任何特定语言的语法或词汇,使得同一模型能够应用于多种语言,增强了模型的通用性。
  • 简化模型复杂性。使用字节级的标记减少了模型需要学习的语言规则的复杂性,因为它主要关注于如何最有效地组合这些基本单元,而非解析高级语法结构。

为什么要训练词汇表

  • 这种通过统计大量文本数据中的字节对频率来确定哪些字节对应当合并,这种基于数据的方法可以自动发现最有效的标记策略,而不是依赖人工预定义。
  • 不同的文本集可能有不同的用语习惯和专业术语。通过在特定数据集上训练 BPE 模型,可以定制化词汇表以最好地反映该数据集的特点,从而提高模型的预测性能和准确性。
  • 适当的训练可以减少模型运行时对内存和其他计算资源的需求。合理的词汇表大小可以平衡模型的表达能力和资源消耗之间的关系。
  • 训练可以帮助确定合并操作的优先级,优化编码过程。这意味着常用的词或短语可以用更少的标记来表示,从而在使用模型处理实际任务时减少计算量和提高速度。

Regex.rs

这部分代码涵盖了正则表达式分词器 RegexTokenizerStruct 的实现,这种分词器可以处理更复杂的文本模式,包括特殊标记和正则表达式分割。

代码语言:javascript复制
pub const GPT2_SPLIT_PATTERN: &str = r"'(?:[sdmt]|ll|ve|re)| ?p{L} | ?p{N} | ?[^sp{L}p{N}] |s (?!S)|s ";
pub const GPT4_SPLIT_PATTERN: &str = r"'(?i:[sdmt]|ll|ve|re)|[^rnp{L}p{N}]? p{L} |p{N}{1,3}| ?[^sp{L}p{N}]  [rn]*|s*[rn]|s (?!S)|s ";

这部分定义了用于 GPT-2 和 GPT-4 的文本分割正则表达式,这些表达式用来分割输入文本以提取出适合处理的单元(tokens)。

代码语言:javascript复制
pub enum AllowedSpecial {
    All, // 允许在编码中使用所有特殊标记
    None, // 忽略所有特殊标记,将其视为普通文本进行编码
    NoneRaise, // 如果在编码过程中遇到特殊标记则引发错误
    Set(HashSet<String>), // 仅允许指定的特殊标记集合
}

该枚举定义了在编码过程中如何处理特殊标记的不同方式。

代码语言:javascript复制
pub trait RegexTokenizerTrait: Tokenizer {
 // 获取编译后的正则表达式对象
    fn compiled_pattern(&self) -> &Regex; 
 // 获取反向特殊标记映射
    fn inverse_special_tokens(&self) -> &IndexMap<Token, String>; 
 // 将文本编码成标记ID序列
    fn encode(&self, text: &str) -> Vec<Token> { 
     // 如果遇到特殊标记默认情况下引发错误
        self.encode_special(text, AllowedSpecial::NoneRaise) 
    }
 // 将标记ID序列解码成字符串
    fn decode(&self, ids: &[Token]) -> String { 
        // 默认实现,解码逻辑,包括处理特殊标记
    }
 // 编码任意忽略的特殊标记
 fn encode_ordinary(&self, text: &str) -> Vec<Token> {
   // 默认实现
  }
 // 根据 allowed_special 参数处理特殊标记的编码逻辑
    fn encode_special(&self, text: &str, allowed_special: AllowedSpecial) -> Vec<Token> {
       // 默认实现 
    }
}

RegexTokenizerTrait 也是 Tokenizer 的子 trait 定义了使用正则表达式处理文本的分词器应有的功能。

代码语言:javascript复制
pub struct RegexTokenizerStruct {
 // 使用的正则表达式模式字符串
    pattern: String, 
    // 编译后的正则表达式对象
    compiled_pattern: Regex, 
    // 特殊标记映射
    special_tokens: IndexMap<String, Token>,
    // 反向特殊标记映射 
    inverse_special_tokens: IndexMap<Token, String>, 
    // 合并规则映射
    merges: IndexMap<(Token, Token), Token>, 
    // 词汇表
    vocab: IndexMap<Token, Vec<u8>>, 
}
impl Tokenizer for RegexTokenizerStruct {
 // 实现
}
impl Trainable for RegexTokenizerStruct {
 // 实现
}

RegexTokenizerStruct 结构体定义了正则表达式分词器的具体数据结构。

特殊标记与 Prompt 模版

说到特殊标记,这里得说一下 Prompt 模版。在做大模型应用时候,Prompt 很重要。一个好的 Prompt 不仅仅是提升大模型输出的准确率,更重要的是,它也许能大幅地降低你的 token 成本。

当你遇到 LLM 问题时,最好的方法是首先使用提示(Prompt)。只有在你的提示达到最佳状态时,再考虑微调或更智能、更昂贵的模型。

0 人点赞