Byte-Pair Encoding tokenization and Tiktoken

2024. 7. 8. 22:08 · ArtificialIntelligence/DeepLearning

 

 

 

Byte-Pair Encoding tokenization

https://youtu.be/HEikzVL-lZU

์–ด๋–ป๊ฒŒ ํ† ํฐํ™”์˜ ๋‹จ์œ„๊ฐ€ ๊ฒฐ์ •๋˜๋Š”์ง€ ์•Œ ์ˆ˜ ์žˆ๋‹ค.

 

 

 

+ Byte-Pair Encoding (BPE) was initially developed as an algorithm to compress texts, and then used by OpenAI for tokenization when pretraining the GPT model. It’s used by a lot of Transformer models, including GPT, GPT-2, RoBERTa, BART, and DeBERTa.

 

1) ์บ๋ฆญํ„ฐ ๋ณ„๋กœ ๋ชจ๋‘ ๋ถ„๋ฆฌํ•˜๊ธฐ 

2) Pair ๋‹จ์œ„๋กœ ๋นˆ๋„ ์ˆ˜ count 

3) ๊ฐ€์žฅ ๋งŽ์€ ๋นˆ๋„๋ฅผ ๋ณด์—ฌ์ฃผ๋Š” ์Œ์„ voca์— ์ƒˆ๋กญ๊ฒŒ ์ถ”๊ฐ€ 

4) ํ•ฉ์ณ์ง„ ๋‘ char๋ฅผ ํ•˜๋‚˜์˜ ํ† ํฐ์œผ๋กœ merge, ๋‹ค์‹œ pair ๋‹จ์œ„ ๋นˆ๋„์ˆ˜ count 

5) ์œ„ ๊ณผ์ •์„ ๋ฐ˜๋ณตํ•˜๋ฉด์„œ voca๋ฅผ ์™„์„ฑ์‹œํ‚จ๋‹ค.

6) ์ƒˆ๋กœ์šด input word๊ฐ€ ๋“ค์–ด์™”์„ ๋•Œ, ๋ฏธ๋ฆฌ ์ •์˜๋œ voca์˜ ํ† ํฐ๋“ค์„ ํ•˜๋‚˜ํ•˜๋‚˜ ์ˆœ์ฐจํƒ์ƒ‰ํ•˜๋ฉด์„œ, ํ•ด๋‹น ๋‹จ์–ด๋ฅผ ํ† ํฐํ™” ์‹œํ‚จ๋‹ค. 

-> ๊ธฐ์กด cache ๋ฐฉ์‹์ด ๋น„ํšจ์œจ์ ์ธ ๋ฉ”์ปค๋‹ˆ์ฆ˜์œผ๋กœ ๊ฐ€์ •ํ•˜๊ธฐ
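
Steps 1-5 (training) and step 6 (applying the merges) can be sketched in plain Python. This is only a toy illustration on a tiny word-frequency corpus; the helper names (train_bpe, merge_pair, tokenize) are my own and not taken from any library.

from collections import Counter

def merge_pair(symbols: list[str], pair: tuple[str, str]) -> list[str]:
    """Replace every adjacent occurrence of `pair` in `symbols` with the merged symbol."""
    merged, i = [], 0
    while i < len(symbols):
        if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == pair:
            merged.append(symbols[i] + symbols[i + 1])
            i += 2
        else:
            merged.append(symbols[i])
            i += 1
    return merged

def train_bpe(word_freqs: dict[str, int], num_merges: int):
    # 1) split every word into individual characters
    splits = {word: list(word) for word in word_freqs}
    vocab = sorted({ch for word in word_freqs for ch in word})
    merges: list[tuple[str, str]] = []

    for _ in range(num_merges):
        # 2) count the frequency of every adjacent pair
        pair_counts = Counter()
        for word, freq in word_freqs.items():
            for a, b in zip(splits[word], splits[word][1:]):
                pair_counts[(a, b)] += freq
        if not pair_counts:
            break
        # 3) add the most frequent pair to the vocab
        best = max(pair_counts, key=pair_counts.get)
        merges.append(best)
        vocab.append(best[0] + best[1])
        # 4) merge that pair everywhere; pair counts are recomputed next iteration
        for word in splits:
            splits[word] = merge_pair(splits[word], best)
    # 5) after enough iterations the vocab is complete
    return merges, vocab

def tokenize(word: str, merges: list[tuple[str, str]]) -> list[str]:
    # 6) apply the learned merges to a new word, in the order they were learned
    symbols = list(word)
    for pair in merges:
        symbols = merge_pair(symbols, pair)
    return symbols

if __name__ == "__main__":
    corpus = {"hug": 10, "pug": 5, "pun": 12, "bun": 4, "hugs": 5}
    merges, vocab = train_bpe(corpus, num_merges=3)
    print(merges)                    # [('u', 'g'), ('u', 'n'), ('h', 'ug')]
    print(tokenize("hugs", merges))  # ['hug', 's']
    print(tokenize("bug", merges))   # ['b', 'ug']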

 

 

 

# https://huggingface.co/learn/nlp-course/en/chapter6/5

hugging face์—์„œ ๋‹ค์Œ๊ณผ ๊ฐ™์ด ๊ตฌ์ฒด์ ์œผ๋กœ ๊ตฌํ˜„๋œ ์ฝ”๋“œ๋ฅผ ํ™•์ธํ•  ์ˆ˜ ์žˆ๋‹ค.

 

 

 

Tiktoken by OpenAI

Tiktoken is a fast BPE tokenizer for use with OpenAI's models.

https://github.com/openai/tiktoken
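
Basic usage looks like this (a minimal sketch assuming tiktoken is installed; cl100k_base is one of the bundled encodings):

import tiktoken

enc = tiktoken.get_encoding("cl100k_base")   # look up a registered encoding by name
tokens = enc.encode("tiktoken is great!")
print(tokens)                                # a list of integer token ids
print(enc.decode(tokens))                    # "tiktoken is great!"

# Or resolve the encoding for a given model name:
enc = tiktoken.encoding_for_model("gpt-4")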

 


 

 

 

# get_encoding func

https://github.com/openai/tiktoken/blob/main/tiktoken/registry.py

from __future__ import annotations

import functools
import importlib
import pkgutil
import threading
from typing import Any, Callable, Optional, Sequence

import tiktoken_ext

from tiktoken.core import Encoding

_lock = threading.RLock()
ENCODINGS: dict[str, Encoding] = {}
ENCODING_CONSTRUCTORS: Optional[dict[str, Callable[[], dict[str, Any]]]] = None


@functools.lru_cache()
def _available_plugin_modules() -> Sequence[str]:
    # tiktoken_ext is a namespace package
    # submodules inside tiktoken_ext will be inspected for ENCODING_CONSTRUCTORS attributes
    # - we use namespace package pattern so `pkgutil.iter_modules` is fast
    # - it's a separate top-level package because namespace subpackages of non-namespace
    #   packages don't quite do what you want with editable installs
    mods = []
    plugin_mods = pkgutil.iter_modules(tiktoken_ext.__path__, tiktoken_ext.__name__ + ".")
    for _, mod_name, _ in plugin_mods:
        mods.append(mod_name)
    return mods


def _find_constructors() -> None:
    global ENCODING_CONSTRUCTORS
    with _lock:
        if ENCODING_CONSTRUCTORS is not None:
            return
        ENCODING_CONSTRUCTORS = {}

        for mod_name in _available_plugin_modules():
            mod = importlib.import_module(mod_name)
            try:
                constructors = mod.ENCODING_CONSTRUCTORS
            except AttributeError as e:
                raise ValueError(
                    f"tiktoken plugin {mod_name} does not define ENCODING_CONSTRUCTORS"
                ) from e
            for enc_name, constructor in constructors.items():
                if enc_name in ENCODING_CONSTRUCTORS:
                    raise ValueError(
                        f"Duplicate encoding name {enc_name} in tiktoken plugin {mod_name}"
                    )
                ENCODING_CONSTRUCTORS[enc_name] = constructor


def get_encoding(encoding_name: str) -> Encoding:
    if encoding_name in ENCODINGS:
        return ENCODINGS[encoding_name]

    with _lock:
        if encoding_name in ENCODINGS:
            return ENCODINGS[encoding_name]

        if ENCODING_CONSTRUCTORS is None:
            _find_constructors()
            assert ENCODING_CONSTRUCTORS is not None

        if encoding_name not in ENCODING_CONSTRUCTORS:
            raise ValueError(
                f"Unknown encoding {encoding_name}. Plugins found: {_available_plugin_modules()}"
            )

        constructor = ENCODING_CONSTRUCTORS[encoding_name]
        enc = Encoding(**constructor())
        ENCODINGS[encoding_name] = enc
        return enc


def list_encoding_names() -> list[str]:
    with _lock:
        if ENCODING_CONSTRUCTORS is None:
            _find_constructors()
            assert ENCODING_CONSTRUCTORS is not None
        return list(ENCODING_CONSTRUCTORS)
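
In short, the registry lazily discovers ENCODING_CONSTRUCTORS from the submodules of the tiktoken_ext namespace package, constructs each Encoding on first request under a lock, and caches it in ENCODINGS. A quick sketch of how that caching behaves from the caller's side (assuming tiktoken is installed):

import tiktoken

# Encoding names come from the plugin modules found under tiktoken_ext.
print(tiktoken.list_encoding_names())        # e.g. ['gpt2', 'r50k_base', 'p50k_base', ...]

enc1 = tiktoken.get_encoding("cl100k_base")  # constructed on first call, stored in ENCODINGS
enc2 = tiktoken.get_encoding("cl100k_base")  # subsequent calls return the cached object
assert enc1 is enc2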

 

 

 

 

# Encoding 

https://github.com/openai/tiktoken/blob/main/tiktoken/core.py

from __future__ import annotations

import functools
from concurrent.futures import ThreadPoolExecutor
from typing import AbstractSet, Collection, Literal, NoReturn, Optional, Union

import regex

from tiktoken import _tiktoken


class Encoding:
    def __init__(
        self,
        name: str,
        *,
        pat_str: str,
        mergeable_ranks: dict[bytes, int],
        special_tokens: dict[str, int],
        explicit_n_vocab: Optional[int] = None,
    ):
        """Creates an Encoding object.

        See openai_public.py for examples of how to construct an Encoding object.

        Args:
            name: The name of the encoding. It should be clear from the name of the encoding
                what behaviour to expect, in particular, encodings with different special tokens
                should have different names.
            pat_str: A regex pattern string that is used to split the input text.
            mergeable_ranks: A dictionary mapping mergeable token bytes to their ranks. The ranks
                must correspond to merge priority.
            special_tokens: A dictionary mapping special token strings to their token values.
            explicit_n_vocab: The number of tokens in the vocabulary. If provided, it is checked
                that the number of mergeable tokens and special tokens is equal to this number.
        """
        self.name = name

        self._pat_str = pat_str
        self._mergeable_ranks = mergeable_ranks
        self._special_tokens = special_tokens

        self.max_token_value = max(
            max(mergeable_ranks.values()), max(special_tokens.values(), default=0)
        )
        if explicit_n_vocab:
            assert len(mergeable_ranks) + len(special_tokens) == explicit_n_vocab
            assert self.max_token_value == explicit_n_vocab - 1

        self._core_bpe = _tiktoken.CoreBPE(mergeable_ranks, special_tokens, pat_str)

    def __repr__(self) -> str:
        return f"<Encoding {self.name!r}>"

    # ====================
    # Encoding
    # ====================

    def encode_ordinary(self, text: str) -> list[int]:
        """Encodes a string into tokens, ignoring special tokens.

        This is equivalent to `encode(text, disallowed_special=())` (but slightly faster).

        ```
        >>> enc.encode_ordinary("hello world")
        [31373, 995]
        ```
        """
        try:
            return self._core_bpe.encode_ordinary(text)
        except UnicodeEncodeError:
            # See comment in encode
            text = text.encode("utf-16", "surrogatepass").decode("utf-16", "replace")
            return self._core_bpe.encode_ordinary(text)

    def encode(
        self,
        text: str,
        *,
        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),  # noqa: B006
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
    ) -> list[int]:
        """Encodes a string into tokens.

        Special tokens are artificial tokens used to unlock capabilities from a model,
        such as fill-in-the-middle. So we want to be careful about accidentally encoding special
        tokens, since they can be used to trick a model into doing something we don't want it to do.

        Hence, by default, encode will raise an error if it encounters text that corresponds
        to a special token. This can be controlled on a per-token level using the `allowed_special`
        and `disallowed_special` parameters. In particular:
        - Setting `disallowed_special` to () will prevent this function from raising errors and
          cause all text corresponding to special tokens to be encoded as natural text.
        - Setting `allowed_special` to "all" will cause this function to treat all text
          corresponding to special tokens to be encoded as special tokens.

        ```
        >>> enc.encode("hello world")
        [31373, 995]
        >>> enc.encode("<|endoftext|>", allowed_special={"<|endoftext|>"})
        [50256]
        >>> enc.encode("<|endoftext|>", allowed_special="all")
        [50256]
        >>> enc.encode("<|endoftext|>")
        # Raises ValueError
        >>> enc.encode("<|endoftext|>", disallowed_special=())
        [27, 91, 437, 1659, 5239, 91, 29]
        ```
        """
        if allowed_special == "all":
            allowed_special = self.special_tokens_set
        if disallowed_special == "all":
            disallowed_special = self.special_tokens_set - allowed_special
        if disallowed_special:
            if not isinstance(disallowed_special, frozenset):
                disallowed_special = frozenset(disallowed_special)
            if match := _special_token_regex(disallowed_special).search(text):
                raise_disallowed_special_token(match.group())

        # https://github.com/PyO3/pyo3/pull/3632
        if isinstance(allowed_special, frozenset):
            allowed_special = set(allowed_special)

        try:
            return self._core_bpe.encode(text, allowed_special)
        except UnicodeEncodeError:
            # BPE operates on bytes, but the regex operates on unicode. If we pass a str that is
            # invalid UTF-8 to Rust, it will rightfully complain. Here we do a quick and dirty
            # fixup for any surrogate pairs that may have sneaked their way into the text.
            # Technically, this introduces a place where encode + decode doesn't roundtrip a Python
            # string, but given that this is input we want to support, maybe that's okay.
            # Also we use errors="replace" to handle weird things like lone surrogates.
            text = text.encode("utf-16", "surrogatepass").decode("utf-16", "replace")
            return self._core_bpe.encode(text, allowed_special)

    def encode_ordinary_batch(self, text: list[str], *, num_threads: int = 8) -> list[list[int]]:
        """Encodes a list of strings into tokens, in parallel, ignoring special tokens.

        This is equivalent to `encode_batch(text, disallowed_special=())` (but slightly faster).

        ```
        >>> enc.encode_ordinary_batch(["hello world", "goodbye world"])
        [[31373, 995], [11274, 16390, 995]]
        ```
        """
        encoder = functools.partial(self.encode_ordinary)
        with ThreadPoolExecutor(num_threads) as e:
            return list(e.map(encoder, text))

    def encode_batch(
        self,
        text: list[str],
        *,
        num_threads: int = 8,
        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),  # noqa: B006
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
    ) -> list[list[int]]:
        """Encodes a list of strings into tokens, in parallel.

        See `encode` for more details on `allowed_special` and `disallowed_special`.

        ```
        >>> enc.encode_batch(["hello world", "goodbye world"])
        [[31373, 995], [11274, 16390, 995]]
        ```
        """
        if allowed_special == "all":
            allowed_special = self.special_tokens_set
        if disallowed_special == "all":
            disallowed_special = self.special_tokens_set - allowed_special
        if not isinstance(disallowed_special, frozenset):
            disallowed_special = frozenset(disallowed_special)

        encoder = functools.partial(
            self.encode, allowed_special=allowed_special, disallowed_special=disallowed_special
        )
        with ThreadPoolExecutor(num_threads) as e:
            return list(e.map(encoder, text))

    def encode_with_unstable(
        self,
        text: str,
        *,
        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),  # noqa: B006
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
    ) -> tuple[list[int], list[list[int]]]:
        """Encodes a string into stable tokens and possible completion sequences.

        Note that the stable tokens will only represent a substring of `text`.

        See `encode` for more details on `allowed_special` and `disallowed_special`.

        This API should itself be considered unstable.

        ```
        >>> enc.encode_with_unstable("hello fanta")
        ([31373], [(277, 4910), (5113, 265), ..., (8842,)])

        >>> text = "..."
        >>> stable_tokens, completions = enc.encode_with_unstable(text)
        >>> assert text.encode().startswith(enc.decode_bytes(stable_tokens))
        >>> assert all(enc.decode_bytes(stable_tokens + seq).startswith(text.encode()) for seq in completions)
        ```
        """
        if allowed_special == "all":
            allowed_special = self.special_tokens_set
        if disallowed_special == "all":
            disallowed_special = self.special_tokens_set - allowed_special
        if disallowed_special:
            if not isinstance(disallowed_special, frozenset):
                disallowed_special = frozenset(disallowed_special)
            if match := _special_token_regex(disallowed_special).search(text):
                raise_disallowed_special_token(match.group())

        return self._core_bpe.encode_with_unstable(text, allowed_special)

    def encode_single_token(self, text_or_bytes: Union[str, bytes]) -> int:
        """Encodes text corresponding to a single token to its token value.

        NOTE: this will encode all special tokens.

        Raises `KeyError` if the token is not in the vocabulary.

        ```
        >>> enc.encode_single_token("hello")
        31373
        ```
        """
        if isinstance(text_or_bytes, str):
            text_or_bytes = text_or_bytes.encode("utf-8")
        return self._core_bpe.encode_single_token(text_or_bytes)

    # ====================
    # Decoding
    # ====================

    def decode_bytes(self, tokens: list[int]) -> bytes:
        """Decodes a list of tokens into bytes.

        ```
        >>> enc.decode_bytes([31373, 995])
        b'hello world'
        ```
        """
        return self._core_bpe.decode_bytes(tokens)

    def decode(self, tokens: list[int], errors: str = "replace") -> str:
        """Decodes a list of tokens into a string.

        WARNING: the default behaviour of this function is lossy, since decoded bytes are not
        guaranteed to be valid UTF-8. You can control this behaviour using the `errors` parameter,
        for instance, setting `errors=strict`.

        ```
        >>> enc.decode([31373, 995])
        'hello world'
        ```
        """
        return self._core_bpe.decode_bytes(tokens).decode("utf-8", errors=errors)

    def decode_single_token_bytes(self, token: int) -> bytes:
        """Decodes a token into bytes.

        NOTE: this will decode all special tokens.

        Raises `KeyError` if the token is not in the vocabulary.

        ```
        >>> enc.decode_single_token_bytes(31373)
        b'hello'
        ```
        """
        return self._core_bpe.decode_single_token_bytes(token)

    def decode_tokens_bytes(self, tokens: list[int]) -> list[bytes]:
        """Decodes a list of tokens into a list of bytes.

        Useful for visualising tokenisation.
        >>> enc.decode_tokens_bytes([31373, 995])
        [b'hello', b' world']
        """
        return [self.decode_single_token_bytes(token) for token in tokens]

    def decode_with_offsets(self, tokens: list[int]) -> tuple[str, list[int]]:
        """Decodes a list of tokens into a string and a list of offsets.

        Each offset is the index into text corresponding to the start of each token.
        If UTF-8 character boundaries do not line up with token boundaries, the offset is the index
        of the first character that contains bytes from the token.

        This will currently raise if given tokens that decode to invalid UTF-8; this behaviour may
        change in the future to be more permissive.

        >>> enc.decode_with_offsets([31373, 995])
        ('hello world', [0, 5])
        """
        token_bytes = self.decode_tokens_bytes(tokens)

        text_len = 0
        offsets = []
        for token in token_bytes:
            offsets.append(max(0, text_len - (0x80 <= token[0] < 0xC0)))
            text_len += sum(1 for c in token if not 0x80 <= c < 0xC0)

        # TODO: assess correctness for errors="ignore" and errors="replace"
        text = b"".join(token_bytes).decode("utf-8", errors="strict")
        return text, offsets

    def decode_batch(
        self, batch: list[list[int]], *, errors: str = "replace", num_threads: int = 8
    ) -> list[str]:
        """Decodes a batch (list of lists of tokens) into a list of strings."""
        decoder = functools.partial(self.decode, errors=errors)
        with ThreadPoolExecutor(num_threads) as e:
            return list(e.map(decoder, batch))

    def decode_bytes_batch(self, batch: list[list[int]], *, num_threads: int = 8) -> list[bytes]:
        """Decodes a batch (list of lists of tokens) into a list of bytes."""
        with ThreadPoolExecutor(num_threads) as e:
            return list(e.map(self.decode_bytes, batch))

    # ====================
    # Miscellaneous
    # ====================

    def token_byte_values(self) -> list[bytes]:
        """Returns the list of all token byte values."""
        return self._core_bpe.token_byte_values()

    @property
    def eot_token(self) -> int:
        return self._special_tokens["<|endoftext|>"]

    @functools.cached_property
    def special_tokens_set(self) -> set[str]:
        return set(self._special_tokens.keys())

    @property
    def n_vocab(self) -> int:
        """For backwards compatibility. Prefer to use `enc.max_token_value + 1`."""
        return self.max_token_value + 1

    # ====================
    # Private
    # ====================

    def _encode_single_piece(self, text_or_bytes: Union[str, bytes]) -> list[int]:
        """Encodes text corresponding to bytes without a regex split.

        NOTE: this will not encode any special tokens.

        ```
        >>> enc.encode_single_piece("helloqqqq")
        [31373, 38227, 38227]
        ```
        """
        if isinstance(text_or_bytes, str):
            text_or_bytes = text_or_bytes.encode("utf-8")
        return self._core_bpe.encode_single_piece(text_or_bytes)

    def _encode_only_native_bpe(self, text: str) -> list[int]:
        """Encodes a string into tokens, but do regex splitting in Python."""
        _unused_pat = regex.compile(self._pat_str)
        ret = []
        for piece in regex.findall(_unused_pat, text):
            ret.extend(self._core_bpe.encode_single_piece(piece))
        return ret

    def _encode_bytes(self, text: bytes) -> list[int]:
        return self._core_bpe._encode_bytes(text)

    def __getstate__(self) -> object:
        import tiktoken.registry

        # As an optimisation, pickle registered encodings by reference
        if self is tiktoken.registry.ENCODINGS.get(self.name):
            return self.name
        return {
            "name": self.name,
            "pat_str": self._pat_str,
            "mergeable_ranks": self._mergeable_ranks,
            "special_tokens": self._special_tokens,
        }

    def __setstate__(self, value: object) -> None:
        import tiktoken.registry

        if isinstance(value, str):
            self.__dict__ = tiktoken.registry.get_encoding(value).__dict__
            return
        self.__init__(**value)


@functools.lru_cache(maxsize=128)
def _special_token_regex(tokens: frozenset[str]) -> "regex.Pattern[str]":
    inner = "|".join(regex.escape(token) for token in tokens)
    return regex.compile(f"({inner})")


def raise_disallowed_special_token(token: str) -> NoReturn:
    raise ValueError(
        f"Encountered text corresponding to disallowed special token {token!r}.\n"
        "If you want this text to be encoded as a special token, "
        f"pass it to `allowed_special`, e.g. `allowed_special={{{token!r}, ...}}`.\n"
        f"If you want this text to be encoded as normal text, disable the check for this token "
        f"by passing `disallowed_special=(enc.special_tokens_set - {{{token!r}}})`.\n"
        "To disable this check for all special tokens, pass `disallowed_special=()`.\n"
    )
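
A short sketch exercising the special-token behaviour documented in encode above (assuming tiktoken is installed; the token ids are for the GPT-2 encoding, matching the docstring examples):

import tiktoken

enc = tiktoken.get_encoding("gpt2")

print(enc.encode("hello world"))              # [31373, 995]
print(enc.decode_tokens_bytes([31373, 995]))  # [b'hello', b' world']

# By default, text matching a special token raises a ValueError.
try:
    enc.encode("<|endoftext|>")
except ValueError as err:
    print("disallowed special token:", err)

# Explicitly allow it to get the special-token id ...
print(enc.encode("<|endoftext|>", allowed_special={"<|endoftext|>"}))  # [50256]
# ... or disable the check and encode it as ordinary text.
print(enc.encode("<|endoftext|>", disallowed_special=()))             # [27, 91, 437, 1659, 5239, 91, 29]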

 

 

 

Valgrind์„ ํ™œ์šฉํ•œ ๋ฉ”๋ชจ๋ฆฌ ์ถ”์  

pin์ด ARM Processor์„ ์ง€์›ํ•˜์ง€ ์•Š์•„, ๋Œ€์•ˆ์œผ๋กœ ํ™œ์šฉํ•œ valgrind 

pintrace์™€ ๊ฐ™์ด ๋ฉ”๋ชจ๋ฆฌ ์ถ”์  ๊ธฐ๋Šฅ์„ ํ™œ์šฉํ•  ์ˆ˜ ์žˆ๋‹ค. 

๊ฐ„๋‹จํ•œ ์˜ˆ์ œ๋ฅผ ์‹คํ–‰ํ•ด๋ณด์•˜๋‹ค.

https://valgrind.org

 

 

 

1) Install Valgrind via brew

brew install valgrind

 

 

 

2) hello.c ์ƒ์„ฑ 

// hello.c ์˜ˆ์‹œ ํŒŒ์ผ์„ ์ €์žฅํ•ด์ค€๋‹ค. 

#include <stdio.h>
#include <stdlib.h>

void memory_access_example() {
    int *array = (int*)malloc(10 * sizeof(int)); // dynamic memory allocation
    for (int i = 0; i < 10; i++) {
        array[i] = i;
    }
    for (int i = 0; i < 10; i++) {
        printf("%d ", array[i]);
    }
    printf("\n");
    free(array); // release the allocated memory
}

// main func 
int main() {
    memory_access_example();
    return 0;
}

 

 

 

3) ์ปดํŒŒ์ผ

gcc -o hello hello.c

 

 

 

4) valgrind memcheck 

valgrind --tool=memcheck --leak-check=yes ./hello

 

ARM ์ง€์›์„ ํ™•์ธํ•˜์˜€์œผ๋‚˜, MacOS M1์—์„œ ์—ฌ์ „ํžˆ valgrind ์‹คํ–‰ X 

-> ๊ธฐ๋ณธ์ ์œผ๋กœ ํ™œ์šฉ ๊ฐ€๋Šฅํ•œ LLDB๋Š” ๋””๋ฒ„๊น…์— ์ ํ•ฉ or ํŠน์ • ์ฃผ์†Œ ์žก๊ณ  ๋ถ„์„ ex) lldb ./hello

-> ๋” ์•Œ์•„๋ด์•ผ ํ•  ๊ฒƒ ๊ฐ™๋‹ค (๋ฉ”๋ชจ๋ฆฌ trace or cache miss ํŒŒ์•…ํ•˜๊ธฐ) 

 

 

 

 

 

References

  1. https://huggingface.co/learn/nlp-course/en/chapter6/5 (the tokenization process implemented in code)
  2. https://github.com/openai/tiktoken
  3. https://valgrind.org
  4. https://hyeyoo.com/64
  5. https://www.intel.com/content/www/us/en/developer/articles/tool/pin-a-dynamic-binary-instrumentation-tool.html