LLMs: Tokenizer - BPE

17 minute read

Published:

Byte-pair encoding (BPE)

[TOC]

1. Overview

BPE is a tokenization method. Its core idea is to build larger subword units by repeatedly merging the most frequent pair of symbols, which keeps the vocabulary compact and handles the rare-word problem. It must first be trained on a corpus to produce a vocabulary before it can encode or decode text.

2. Training steps

  1. Choose the hyperparameter: the vocabulary size. Normalize and pre-tokenize the text.
  2. Initialize the base vocabulary
    • Split each pre-tokenized token into a character sequence, appending the suffix '</w>' to its last character.
  3. Count the frequencies of adjacent character pairs (bi-grams). (No pair can start with '</w>'.)
  4. Merge iteratively
    • Merge the most frequent character pair.
    • Add the new symbol to the vocabulary encoder_vocab; existing entries are kept.
    • Record the order of merges in a second table, merges, which is later needed for encoding.
    • Re-split the words with the updated vocabulary and repeat steps 2-4 until the target vocabulary size is reached (a small sketch of steps 2-3 follows this list).
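
A minimal sketch of steps 2-3, assuming simple whitespace pre-tokenization and a toy corpus chosen purely for illustration:

from collections import Counter

# toy pre-tokenization (whitespace split); the corpus is illustrative only
pre_tokens = "low lower lowest".split()

# step 2: split each pre-token into characters and mark the last one with '</w>'
words = [tuple(w[:-1]) + (w[-1] + "</w>",) for w in pre_tokens]
print(words[0])  # ('l', 'o', 'w</w>')

# step 3: count adjacent character pairs (bi-grams)
pair_counts = Counter()
for word in words:
    for pair in zip(word, word[1:]):
        pair_counts[pair] += 1
print(pair_counts.most_common(1))  # [(('l', 'o'), 3)] -> the first merge would create 'lo'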

BPE implementation in GPT-1

Because the BookCorpus dataset provided by Hugging Face has already been carefully post-processed, we cannot fully reproduce the results of the original GPT code.
This post therefore only covers the encoding and decoding parts, based on the original implementation text_utils.py.

  • encoder vocabulary: encoder_bpe_40000.json
  • table recording the order of character-pair merges: vocab_40000.bpe

For the BPE training procedure, see Karpathy's minbpe at the end of this post.

Training

  1. Pre-tokenization: normalize Unicode characters with ftfy, unify non-standard punctuation, standardize whitespace, then tokenize with spacy's en_core_web_sm model (see bpe.py).
  2. Initialize the vocabulary: split the whole corpus into single-character subword units and append </w> to the last character of each token. In the trained vocabulary encoder_bpe_40000.json, ids 1-238 are single characters and 239-476 are a single character + </w>. Here </w> marks the end of a token; for example, in the word bamboo the last o is treated as o</w> to distinguish it from the second-to-last o.
  3. Count bi-gram pair frequencies. get_stats takes a list, iterates over every adjacent pair, and counts occurrences.
  4. Merge the most frequent pair into a new subword unit, update the vocabulary encoder_vocab, and record the merge operation in merges.
  5. Repeat steps 3-4 40,000 times, which yields 40,000 new subword units on top of the 476 single-symbol tokens. Adding <unk> and \n</w> gives 40,478 tokens in total.
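
As a quick sanity check on these numbers, the released vocabulary can be inspected directly. This is a sketch assuming encoder_bpe_40000.json has been downloaded into the working directory:

import json

encoder = json.load(open("encoder_bpe_40000.json"))  # token -> id
decoder = {v: k for k, v in encoder.items()}         # id -> token

print(len(encoder))              # 40478
print(decoder[1], decoder[239])  # per the description above: a single character, then a character + '</w>'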

minbpe code implementation

Counting character-pair frequencies:

def get_stats(ids, counts=None):
    """
    Given a list of integers, return a dictionary of counts of consecutive pairs
    Example: [1, 2, 3, 1, 2] -> {(1, 2): 2, (2, 3): 1, (3, 1): 1}
    Note: this does not take the </w> marker into account
    Optionally allows to update an existing dictionary of counts
    """
    counts = {} if counts is None else counts
    for pair in zip(ids, ids[1:]): # iterate consecutive elements
        counts[pair] = counts.get(pair, 0) + 1
    return counts
  1. Why byte pairs? Because the algorithm first encodes the string as UTF-8; one byte has 8 bits and can represent 256 values, so the character sequence is actually represented as a list of integers (a list of bytes).
    token = 'T Ss'
    text_bytes = token.encode("utf-8") # raw bytes
    ids = list(text_bytes) # list of integers in range 0..255
    print(ids)
    # [84, 32, 83, 115]  (the space character is the familiar 32)
    
  2. The initial vocabulary likewise has only 256 entries. For example, take the input "aaabdaaabac" and perform three merges. Since the vocabulary starts with 256 entries, new subwords are simply appended after them, e.g. 256: 'aa', 257: 'ab', 258: 'aaab', and the encoded output is [258, 100, 258, 97, 99] (see the sketch after this list).
vocab = {idx: bytes([idx]) for idx in range(256)} # int -> bytes
  3. Special tokens:
    A special token is registered as the last entry of the vocabulary; its id equals the vocabulary size.
    from minbpe import RegexTokenizer
    tokenizer = RegexTokenizer()
    tokenizer.train(very_long_training_string, vocab_size=32768)
    tokenizer.register_special_tokens({"<|endoftext|>": 32768})
    tokenizer.encode("<|endoftext|>hello world", allowed_special="all")
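
To reproduce the "aaabdaaabac" example with minbpe's BasicTokenizer (a sketch: intermediate merges may differ on ties, but the final ids match the ones quoted above):

from minbpe import BasicTokenizer

tokenizer = BasicTokenizer()
tokenizer.train("aaabdaaabac", vocab_size=256 + 3, verbose=True)  # 3 merges on top of the 256 byte tokens

print(tokenizer.merges)                           # the three newly minted tokens 256, 257, 258
print(tokenizer.encode("aaabdaaabac"))            # [258, 100, 258, 97, 99]
print(tokenizer.decode([258, 100, 258, 97, 99]))  # 'aaabdaaabac'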
    

Encoding

  1. Load the trained merge table merges.
  2. Pre-tokenize the input text, in the same way as during training.
  3. Split each pre-token into a character sequence and append </w> to the last character.
  4. Collect the bi-gram character pairs with get_pairs.
  5. From these pairs, pick the one that was merged earliest according to merges; call its parts first and second. Find the index of first in the word and keep everything before it. If first is immediately followed by second, merge them; otherwise keep this single first. Each merge forms a new subword, which replaces the pair in the word.
  6. Repeat steps 4-5 until no mergeable pair remains or only one symbol is left.
  7. Cache the result and map the subword units to their token ids in the vocabulary.

Decoding

  1. Load the trained vocabulary encoder_vocab.
  2. Build the reverse mapping from the vocabulary and map each token id back to its original subword.
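
A usage sketch for the BPETokenizer class listed below, assuming the spacy model and the two vocabulary files are available locally (paths and the example sentence are illustrative):

tokenizer = BPETokenizer(
    encoder_path="encoder_bpe_40000.json",
    bpe_path="vocab_40000.bpe",
)

ids = tokenizer.encode("We are learning byte-pair encoding.", verbose=False)
print(ids)                    # one list of token ids per input string
print(tokenizer.decode(ids))  # roughly recovers the normalized, lowercased text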

Code

encoder: the vocabulary, stored as key-value pairs (subword → id).

"""
bpe is short for Byte Pair Encoder. It translates arbitrary utf-8 strings into
sequences of integers, where each integer represents small chunks of commonly
occurring characters. This implementation is based on openai's gpt text_utils.py:
https://github.com/openai/finetune-transformer-lm/blob/master/text_utils.py
"""

import re
from typing import List, Optional, Union
import ftfy
import json
import spacy

from tqdm import tqdm


def get_pairs(word):
    """
    Return all bigrams as a set of tuples, of consecutive elements in the iterable word.
    
    Count adjacent symbol pairs. Demo:
    token = 'chat'
    word = tuple(token[:-1]) + (token[-1] + '</w>',)
    # ('c', 'h', 'a', 't</w>')
    pairs = get_pairs(word)
    # {('c', 'h'), ('h', 'a'), ('a', 't</w>')}

    bigram = min(pairs, key=lambda pair: bpe_ranks.get(pair, float("inf")))
    first, second = bigram
    # bigram: a t</w>

    1:  ['c', 'h']
    2:  ['c', 'h', 'at</w>']
    pairs:  {('h', 'at</w>'), ('c', 'h')}

    # bigram: c h
    1:  []
    2:  ['ch']
    pairs:  {('ch', 'at</w>')}

    # bigram: ch at</w>
    1:  []
    2:  ['chat</w>']
    ('chat</w>',)
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


def text_standardize(text):
    """
    fixes some issues the spacy tokenizer had on books corpus
    also does some whitespace standardization
    """
    text = text.replace("—", "-")
    text = text.replace("–", "-")
    text = text.replace("―", "-")
    text = text.replace("…", "...")
    text = text.replace("´", "'")

    # add spaces around all punctuations (-, ~, !, ", ;, ?, +, `,`, ), (, \, /, *, [, ], }, {, |, _)
    # example: "Hi!Kami-chanw" will be converted to "Hi ! Kami - chanw"
    text = re.sub(
        r"""(-+|~+|!+|"+|;+|\?+|\++|,+|\)+|\(+|\\+|\/+|\*+|\[+|\]+|}+|{+|\|+|_+)""",
        r" \1 ",
        text,
    )

    # shrink spaces (or add spaces if none exist) around `\n`
    # example: "Hi\nKamichanw    \n" will be converted to "Hi \n Kamichanw \n "
    text = re.sub(r"\s*\n\s*", " \n ", text)

    # replace all space characters (e.g. `\t`) except `\n` with a single space
    # example: "Hi\tKamichanw   \n" will be converted to "Hi Kamichanw \n "
    text = re.sub(r"[^\S\n]+", " ", text)
    return text.strip()


class BPETokenizer(object):
    """
    mostly a wrapper for a public python bpe tokenizer
    """

    def __init__(self, encoder_path, bpe_path):
        self.nlp = spacy.load(
            "en_core_web_sm",
            disable=["parser", "tagger", "ner", "textcat", "lemmatizer"],
        )
        self.encoder = json.load(open(encoder_path))
        self.decoder = {v: k for k, v in self.encoder.items()}
        merges = open(bpe_path).read().split("\n")[1:-1]
        merges = [tuple(merge.split()) for merge in merges]
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {}
        self.special_tokens = {"</w>"}

    def add_special_tokens(self, new_tokens: List[str]):
        start_idx = len(self.encoder)

        for i, token in enumerate(new_tokens):
            if token in self.encoder:
                raise ValueError(f"Token '{token}' already exists in the encoder.")

            self.encoder[token] = start_idx + i
            self.decoder[start_idx + i] = token

            # no need to update BPE ranks for special tokens as they are not merged
            self.cache[token] = token
        self.special_tokens.update(new_tokens)

    def get_vocab_size(self):
        return len(self.encoder)

    def bpe(self, token):
        word = tuple(token[:-1]) + (token[-1] + "</w>",)
        if token in self.cache:
            return self.cache[token]
        pairs = get_pairs(word)

        if not pairs:
            return token + "</w>"

        while True:

            # find the next lowest rank bigram that can be merged
            # the lower rank means earlier be merged
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
            if bigram not in self.bpe_ranks:
                break  # no more bigrams are eligible to be merged
            first, second = bigram

            # we will now replace all occurrences of (first, second) in the current
            # word with the merged token first+second, building the output list new_word
            new_word = []
            i = 0
            while i < len(word):

                # find the next occurrence of first in the current word
                # and keep everything before it
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    new_word.extend(word[i:])
                    break

                # if this occurrence of first is immediately followed by second, merge them into one
                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                # otherwise keep this single first, like any other symbol
                else:
                    new_word.append(word[i])
                    i += 1

            # all occurrences of (first, second) have been merged into first+second
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = " ".join(word)
        if word == "\n  </w>":
            word = "\n</w>"
        self.cache[token] = word
        return word

    def token_to_id(self, token: str) -> int:
        return self.encoder.get(token, 0)

    def encode(
        self,
        texts: Union[str, List[str]],
        verbose: bool = True,
    ) -> List[List[int]]:
        if not isinstance(texts, list):
            texts = [texts]

        texts_tokens = []
        bar = tqdm(texts, ncols=80, leave=False) if verbose else texts

        for text in bar:
            text = self.nlp(text_standardize(ftfy.fix_text(text)))
            text_tokens = []
            for token in text:
                text_tokens.extend(
                    [
                        self.encoder.get(t, 0)
                        for t in self.bpe(token.text.lower()).split(" ")
                    ]
                )
            texts_tokens.append(text_tokens)
            
        return texts_tokens

    def decode(
        self, bpe_idx: Union[List[List[int]], List[int]], skip_special_tokens=True
    ):
        """lists of integers comes in, a list of string comes out"""
        if not isinstance(bpe_idx[0], list):
            bpe_idx = [bpe_idx]

        texts = []
        for idx in bpe_idx:
            # inverse map the integers to get the tokens
            tokens_merged = [self.decoder[token] for token in idx]
            text = "".join(tokens_merged)
            if skip_special_tokens:
                text = re.sub("|".join(self.special_tokens), " ", text)
            texts.append(text)

        return texts

3. Example: Training

corpus = 'chat chatt cat chap', vocab_size = 8

1. Initialize the vocabulary: split each word into characters and append the end-of-word marker to the last character.

1. "chat"  → ['c', 'h', 'a', 't</w>'] 
2. "chatt" → ['c', 'h', 'a', 't', 't</w>'] 
3. "cat"   → ['c', 'a', 't</w>'] 
4. "chap"  → ['c', 'h', 'a', 'p</w>'] 

Vocabulary:

{ 'c': 1, 'h': 2, 'a': 3, 't': 4, 'p</w>': 5, 't</w>': 6 }

Merge table (bpe): (empty at this point)


2. First merge

Count adjacent character-pair frequencies:

('c','h') → 3 ("chat", "chatt", "chap")
('h','a') → 3 (same words)
('a','t</w>') → 2 ("chat", "cat")
('a','t') → 1 ("chatt")
('t','t</w>') → 1 ("chatt")
('c','a') → 1 ("cat")
('a','p</w>') → 1 ("chap")

Most frequent pairs: ('c','h') and ('h','a'), 3 occurrences each (pick either; assume ('c','h')).

Merge operation: merge c and h into ch.

New vocabulary:

{ 'c': 1, 'h': 2, 'a': 3, 't': 4, 'p</w>': 5, 't</w>': 6, 'ch': 7 }

Merge table (bpe):

c h

Word splits after the vocabulary update:

chat  → ['ch', 'a', 't</w>']
chatt → ['ch', 'a', 't', 't</w>']
cat   → ['c', 'a', 't</w>'] (unchanged)
chap  → ['ch', 'a', 'p</w>']

3. Second merge

Count adjacent character-pair frequencies:

('ch','a') → 3 ("chat", "chatt", "chap")
('a','t</w>') → 2 ("chat", "cat")
('a','t') → 1 ("chatt")
('t','t</w>') → 1 ("chatt")
('c','a') → 1 ("cat")
('a','p</w>') → 1 ("chap")

Most frequent pair: ('ch','a')

Merge operation: merge ch and a into cha.

New vocabulary:

{ 'c': 1, 'h': 2, 'a': 3, 't': 4, 'p</w>': 5, 't</w>': 6, 'ch': 7, 'cha': 8 }

Merge table (bpe):

c h
ch a

Word splits after the update:

chat  → ['cha', 't</w>']
chatt → ['cha', 't', 't</w>']
cat   → ['c', 'a', 't</w>'] (unchanged)
chap  → ['cha', 'p</w>']
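
The walkthrough above can be reproduced with a small character-level trainer. This is an illustrative sketch, not the original GPT-1 training code; get_pair_stats and merge_word are local helpers written for this example:

from collections import Counter

def get_pair_stats(words):
    """Count adjacent symbol pairs over all words (each word is a tuple of symbols)."""
    counts = Counter()
    for word, freq in words.items():
        for pair in zip(word, word[1:]):
            counts[pair] += freq
    return counts

def merge_word(word, pair):
    """Replace each adjacent occurrence of `pair` in `word` with the merged symbol."""
    first, second = pair
    out, i = [], 0
    while i < len(word):
        if i < len(word) - 1 and word[i] == first and word[i + 1] == second:
            out.append(first + second)
            i += 2
        else:
            out.append(word[i])
            i += 1
    return tuple(out)

corpus = "chat chatt cat chap"
# initialize: split each word into characters and mark the last one with '</w>'
words = Counter(tuple(w[:-1]) + (w[-1] + "</w>",) for w in corpus.split())

for step in range(2):  # two merges, matching the walkthrough
    stats = get_pair_stats(words)
    best = max(stats, key=stats.get)
    new_words = Counter()
    for w, f in words.items():
        new_words[merge_word(w, best)] += f
    words = new_words
    print(f"merge {step + 1}: {best} -> {''.join(best)}")
# merge 1: ('c', 'h') -> ch
# merge 2: ('ch', 'a') -> cha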

Encoding example

token = 'chat'
word = ('c', 'h', 'a', 't</w>')
pairs = get_pairs(word)
# {('c', 'h'), ('h', 'a'), ('a', 't</w>')}

bigram = min(pairs, key=lambda pair: bpe_ranks.get(pair, float("inf")))
first, second = bigram
# bigram: a t</w>

1:  ['c', 'h']
2:  ['c', 'h', 'at</w>']
pairs:  {('h', 'at</w>'), ('c', 'h')}

# bigram: c h
1:  []
2:  ['ch']
pairs:  {('ch', 'at</w>')}

# bigram: ch at</w>
1:  []
2:  ['chat</w>']
('chat</w>',)
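
The trace is driven entirely by the rank lookup: at every step, the pair with the lowest rank in merges (i.e. the merge recorded earliest during training) is applied next. Below is a small sketch of that selection step, reusing get_pairs from the code above; the ranks are hypothetical stand-ins for the real values in vocab_40000.bpe:

# hypothetical ranks; in practice they are loaded from vocab_40000.bpe
bpe_ranks = {("a", "t</w>"): 10, ("c", "h"): 25, ("ch", "at</w>"): 1844}

word = ("c", "h", "a", "t</w>")
pairs = get_pairs(word)  # {('c', 'h'), ('h', 'a'), ('a', 't</w>')}

bigram = min(pairs, key=lambda pair: bpe_ranks.get(pair, float("inf")))
print(bigram)  # ('a', 't</w>') -- lowest rank, so it is merged first, exactly as in the trace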

minbpe

1. Overview

  1. BPE is the most widely used tokenization algorithm for large language models today. It is a byte-pair algorithm because it operates on UTF-8 encoded strings, i.e. on bytes. It became popular with GPT-2, and current LLMs generally use BPE.
  2. The three main responsibilities of a Tokenizer:
    • train a vocabulary by merging pre-tokens
    • encode from text to tokens
    • decode from tokens to text

2. Building your own GPT-4 Tokenizer

  1. Write a BasicTokenizer class with three key methods: train, encode, and decode:
    • def train(self, text, vocab_size, verbose=False)
    • def encode(self, text)
    • def decode(self, ids)
  2. Turn the BasicTokenizer into a RegexTokenizer (a split-pattern sketch follows this list).
  3. Load the merges and compare encode/decode results against the tiktoken library:
    # match this
    import tiktoken
    enc = tiktoken.get_encoding("cl100k_base") # this is the GPT-4 tokenizer
    ids = enc.encode("hello world!!!? (안녕하세요!) lol123 😉")
    text = enc.decode(ids) # get the same text back
    
  4. Add the ability to handle special tokens, by analogy with:
    import tiktoken
    enc = tiktoken.get_encoding("cl100k_base") # this is the GPT-4 tokenizer
    ids = enc.encode("<|endoftext|>hello world", allowed_special="all")
    
  5. The Llama family uses sentencepiece for tokenization; the difference is that sentencepiece runs BPE directly on Unicode code points instead of on UTF-8 encoded bytes.
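
For reference, this is how a split pattern chops text into pre-tokens before any merging happens; BPE then runs inside each chunk, so merges never cross these boundaries. The pattern below is the GPT-2 one from openai/gpt-2's encoder.py (minbpe's RegexTokenizer defaults to a GPT-4 variant of the same idea); it needs the third-party regex package for the \p{...} classes:

import regex as re

GPT2_SPLIT_PATTERN = r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

print(re.findall(GPT2_SPLIT_PATTERN, "Hello world!!!? lol123 😉"))
# ['Hello', ' world', '!!!?', ' lol', '123', ' 😉']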

Code

  1. base.py
"""
Contains the base Tokenizer class and a few common helper functions.
The base class also contains the (common) save/load functionality.
It would be possible to be a lot more strict about the interface and
e.g. isolating all regex/pattern parts to the RegexTokenizer, but
some concessions are made for simplicity.
"""
import unicodedata

# -----------------------------------------------------------------------------
# a few helper functions useful for both BasicTokenizer and RegexTokenizer

def get_stats(ids, counts=None):
    """
    Given a list of integers, return a dictionary of counts of consecutive pairs
    Example: [1, 2, 3, 1, 2] -> {(1, 2): 2, (2, 3): 1, (3, 1): 1}
    Optionally allows to update an existing dictionary of counts
    """
    counts = {} if counts is None else counts
    for pair in zip(ids, ids[1:]): # iterate consecutive elements
        counts[pair] = counts.get(pair, 0) + 1
    return counts


def merge(ids, pair, idx):
    """
    In the list of integers (ids), replace all consecutive occurrences
    of pair with the new integer token idx
    Example: ids=[1, 2, 3, 1, 2], pair=(1, 2), idx=4 -> [4, 3, 4]
    """
    newids = []
    i = 0
    while i < len(ids):
        # if not at the very last position AND the pair matches, replace it
        if ids[i] == pair[0] and i < len(ids) - 1 and ids[i+1] == pair[1]:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids

# first two helper functions...
def replace_control_characters(s: str) -> str:
    # we don't want to print control characters
    # which distort the output (e.g. \n or much worse)
    # https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python/19016117#19016117
    # http://www.unicode.org/reports/tr44/#GC_Values_Table
    chars = []
    for ch in s:
        if unicodedata.category(ch)[0] != "C":
            chars.append(ch) # this character is ok
        else:
            chars.append(f"\\u{ord(ch):04x}") # escape
    return "".join(chars)

def render_token(t: bytes) -> str:
    # pretty print a token, escaping control characters
    s = t.decode('utf-8', errors='replace')
    s = replace_control_characters(s)
    return s

# -----------------------------------------------------------------------------
# the base Tokenizer class

class Tokenizer:
    """Base class for Tokenizers"""

    def __init__(self):
        # default: vocab size of 256 (all bytes), no merges, no patterns
        self.merges = {} # (int, int) -> int
        self.pattern = "" # str
        self.special_tokens = {} # str -> int, e.g. {'<|endoftext|>': 100257}
        self.vocab = self._build_vocab() # int -> bytes

    def train(self, text, vocab_size, verbose=False):
        # Tokenizer can train a vocabulary of size vocab_size from text
        raise NotImplementedError

    def encode(self, text):
        # Tokenizer can encode a string into a list of integers
        raise NotImplementedError

    def decode(self, ids):
        # Tokenizer can decode a list of integers into a string
        raise NotImplementedError

    def _build_vocab(self):
        # vocab is simply and deterministically derived from merges
        vocab = {idx: bytes([idx]) for idx in range(256)}
        for (p0, p1), idx in self.merges.items():
            vocab[idx] = vocab[p0] + vocab[p1]
        for special, idx in self.special_tokens.items():
            vocab[idx] = special.encode("utf-8")
        return vocab

    def save(self, file_prefix):
        """
        Saves two files: file_prefix.vocab and file_prefix.model
        This is inspired (but not equivalent to!) sentencepiece's model saving:
        - model file is the critical one, intended for load()
        - vocab file is just a pretty printed version for human inspection only
        """
        # write the model: to be used in load() later
        model_file = file_prefix + ".model"
        with open(model_file, 'w') as f:
            # write the version, pattern and merges, that's all that's needed
            f.write("minbpe v1\n")
            f.write(f"{self.pattern}\n")
            # write the special tokens, first the number of them, then each one
            f.write(f"{len(self.special_tokens)}\n")
            for special, idx in self.special_tokens.items():
                f.write(f"{special} {idx}\n")
            # the merges dict
            for idx1, idx2 in self.merges:
                f.write(f"{idx1} {idx2}\n")
        # write the vocab: for the human to look at
        vocab_file = file_prefix + ".vocab"
        inverted_merges = {idx: pair for pair, idx in self.merges.items()}
        with open(vocab_file, "w", encoding="utf-8") as f:
            for idx, token in self.vocab.items():
                # note: many tokens may be partial utf-8 sequences
                # and cannot be decoded into valid strings. Here we're using
                # errors='replace' to replace them with the replacement char �.
                # this also means that we couldn't possibly use .vocab in load()
                # because decoding in this way is a lossy operation!
                s = render_token(token)
                # find the children of this token, if any
                if idx in inverted_merges:
                    # if this token has children, render it nicely as a merge
                    idx0, idx1 = inverted_merges[idx]
                    s0 = render_token(self.vocab[idx0])
                    s1 = render_token(self.vocab[idx1])
                    f.write(f"[{s0}][{s1}] -> [{s}] {idx}\n")
                else:
                    # otherwise this is leaf token, just print it
                    # (this should just be the first 256 tokens, the bytes)
                    f.write(f"[{s}] {idx}\n")

    def load(self, model_file):
        """Inverse of save() but only for the model file"""
        assert model_file.endswith(".model")
        # read the model file
        merges = {}
        special_tokens = {}
        idx = 256
        with open(model_file, 'r', encoding="utf-8") as f:
            # read the version
            version = f.readline().strip()
            assert version == "minbpe v1"
            # read the pattern
            self.pattern = f.readline().strip()
            # read the special tokens
            num_special = int(f.readline().strip())
            for _ in range(num_special):
                special, special_idx = f.readline().strip().split()
                special_tokens[special] = int(special_idx)
            # read the merges
            for line in f:
                idx1, idx2 = map(int, line.split())
                merges[(idx1, idx2)] = idx
                idx += 1
        self.merges = merges
        self.special_tokens = special_tokens
        self.vocab = self._build_vocab()
  2. basic.py
"""
Minimal (byte-level) Byte Pair Encoding tokenizer.

Algorithmically follows along the GPT tokenizer:
https://github.com/openai/gpt-2/blob/master/src/encoder.py

But:
- Does not handle the regular expression splitting pattern.
- Does not handle any special tokens.
"""

from .base import Tokenizer, get_stats, merge


class BasicTokenizer(Tokenizer):

    def __init__(self):
        super().__init__()

    def train(self, text, vocab_size, verbose=False):
        assert vocab_size >= 256
        num_merges = vocab_size - 256

        # input text preprocessing
        text_bytes = text.encode("utf-8") # raw bytes
        ids = list(text_bytes) # list of integers in range 0..255

        # iteratively merge the most common pairs to create new tokens
        merges = {} # (int, int) -> int
        vocab = {idx: bytes([idx]) for idx in range(256)} # int -> bytes
        for i in range(num_merges):
            # count up the number of times every consecutive pair appears
            stats = get_stats(ids)
            # find the pair with the highest count
            pair = max(stats, key=stats.get)
            # mint a new token: assign it the next available id
            idx = 256 + i
            # replace all occurrences of pair in ids with idx
            ids = merge(ids, pair, idx)
            # save the merge
            merges[pair] = idx
            vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
            # prints
            if verbose:
                print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences")

        # save class variables
        self.merges = merges # used in encode()
        self.vocab = vocab   # used in decode()

    def decode(self, ids):
        # given ids (list of integers), return Python string
        text_bytes = b"".join(self.vocab[idx] for idx in ids)
        text = text_bytes.decode("utf-8", errors="replace")
        return text

    def encode(self, text):
        # given a string text, return the token ids
        text_bytes = text.encode("utf-8") # raw bytes
        ids = list(text_bytes) # list of integers in range 0..255
        while len(ids) >= 2:
            # find the pair with the lowest merge index
            stats = get_stats(ids)
            pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
            # subtle: if there are no more merges available, the key will
            # result in an inf for every single pair, and the min will be
            # just the first pair in the list, arbitrarily
            # we can detect this terminating case by a membership check
            if pair not in self.merges:
                break # nothing else can be merged anymore
            # otherwise let's merge the best pair (lowest merge index)
            idx = self.merges[pair]
            ids = merge(ids, pair, idx)
        return ids
  3. train.py
"""
Train our Tokenizers on some data, just to see them in action.
The whole thing runs in ~25 seconds on my laptop.
"""

import os
import time
from minbpe import BasicTokenizer, RegexTokenizer

# open some text and train a vocab of 512 tokens
text = open("tests/taylorswift.txt", "r", encoding="utf-8").read()

# create a directory for models, so we don't pollute the current directory
os.makedirs("models", exist_ok=True)

t0 = time.time()
for TokenizerClass, name in zip([BasicTokenizer, RegexTokenizer], ["basic", "regex"]):

    # construct the Tokenizer object and kick off verbose training
    tokenizer = TokenizerClass()
    tokenizer.train(text, 512, verbose=True)
    # writes two files in the models directory: name.model, and name.vocab
    prefix = os.path.join("models", name)
    tokenizer.save(prefix)
t1 = time.time()

print(f"Training took {t1 - t0:.2f} seconds")