Granite code support #1336


Merged (17 commits) on Dec 19, 2024
4 changes: 4 additions & 0 deletions install/requirements.txt
@@ -9,6 +9,10 @@ gguf
# Tiktoken tokenizer for Llama 3 and other advanced models
tiktoken

# Tokenizers and jinja2 for other non-llama models that use HF tokenizers
Contributor (author) commented:
I added these here, but did not add pytest (yet). I think there's a pending conversation about introducing optional dependency sets, so it would make sense to add a test or dev set at that point, but I didn't want to accidentally carry pytest along as a runtime dependency.

tokenizers
jinja2

# Miscellaneous
snakeviz
sentencepiece
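
Regarding the comment above about optional dependency sets: a rough sketch (hypothetical, not part of this PR) of how a dev extra could keep pytest out of the runtime requirements once packaging metadata exists:

# setup.py -- hypothetical sketch only; this PR adds no packaging metadata.
from setuptools import find_packages, setup

setup(
    name="torchchat",
    packages=find_packages(),
    install_requires=["tokenizers", "jinja2"],  # runtime deps, mirroring requirements.txt
    extras_require={"dev": ["pytest"]},         # pulled in only via `pip install .[dev]`
)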
12 changes: 12 additions & 0 deletions tests/conftest.py
@@ -0,0 +1,12 @@
"""
Global pytest config, fixtures, and helpers go here!
"""

# Standard
import os
import sys

# Make sure tests can import torchchat
Contributor (author) commented:

This would be a lot cleaner if we move to having a pyproject.toml or setup.py to bundle torchchat as a package that could be installed with pip install -e.

Contributor replied:

on the list

sys.path.append(
os.path.realpath(os.path.join(os.path.dirname(__file__), ".."))
)
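
As a concrete illustration of what the sys.path tweak enables, here is a hypothetical test (file name and test made up for this sketch) that imports a project module without the package being installed:

# tests/test_import_smoke.py -- hypothetical example, not part of this PR.
# Because conftest.py appends the repository root to sys.path, torchchat
# modules resolve in tests even though the project is not pip-installed.
def test_can_import_chat_formatters():
    from torchchat.generate import HFTokenizerChatFormatter  # noqa: F401

    assert HFTokenizerChatFormatter is not None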
216 changes: 216 additions & 0 deletions tests/test_chat_formatters.py
@@ -0,0 +1,216 @@
"""
Unit tests for chat formatters
"""

# Third Party
import pytest

# Local
from torchchat.generate import (
HFTokenizerChatFormatter,
Llama2ChatFormatter,
Llama3ChatFormatter,
)

## Helpers #####################################################################

class DummyTokenizer:
"""Dummy tokenizer that encodes as strings so it's easy to check formatting"""
def encode(self, text, *_, **__):
return text


class DummySPTokenizer(DummyTokenizer):
"""Emulated Sentencepiece tokenizer with bos/eos"""
bos = "<s>"
eos = "</s>"


class DummyLlama3Tokenizer(DummyTokenizer):
class _IdentityDict:
def __getitem__(self, key):
return key
special_tokens = _IdentityDict()


class DummyHFTokenizer(DummyTokenizer):
"""Dummy made up chat template scheme"""
# Sequence
bos = "<bos>"
# Turn
bot = "<bot>"
eot = "<eot>"
# Role
bor = "<bor>"
eor = "<eor>"
def apply_chat_template(self, messages, add_generation_prompt):
out = [self.bos]
role = None
for msg in messages:
role = msg["role"]
content = msg["content"]
out.append(f"{self.bot}{self.bor}{role}{self.eor}{content}{self.eot}")
if add_generation_prompt and role != "assistant":
out.append(f"{self.bot}{self.bor}assistant{self.eor}")
return "\n".join(out)


def check_rendering(fmt, messages, expected, add_generation_prompt):
"""Render messages and compare to expected output"""
assert "".join(fmt.encode_dialog_prompt(messages, add_generation_prompt)) == expected


def make_message(role, text):
return {"role": role, "content": text}


SYSTEM_PROMPT = "You are a helpful assistant, feel free to ask me anything."
USER1 = "Hello world!"
ASSISTANT1 = "Greetings! How can I help you?"
USER2 = "Why is the sky blue?"
ASSISTANT2 = "The sky appears blue because of a phenomenon called Rayleigh scattering."


# Stock sets of messages to test
MSGS_NO_SYS = [
make_message("user", USER1),
]
MSGS_SYS_USR = [
make_message("system", SYSTEM_PROMPT),
make_message("user", USER1),
]
MSGS_SYS_USR_ASST = [
make_message("system", SYSTEM_PROMPT),
make_message("user", USER1),
make_message("assistant", ASSISTANT1),
]
MSGS_MULTI_TURN = [
make_message("system", SYSTEM_PROMPT),
make_message("user", USER1),
make_message("assistant", ASSISTANT1),
make_message("user", USER2),
make_message("assistant", ASSISTANT2),
]

## Llama2ChatFormatter #########################################################

@pytest.mark.parametrize(
["messages", "expected"],
[
# single user message (no system prompt)
(MSGS_NO_SYS, f"<s>[INST] {USER1} [/INST]"),
# sys, usr
(MSGS_SYS_USR, f"""<s>[INST] <<SYS>>
{SYSTEM_PROMPT}
<</SYS>>

{USER1} [/INST]"""),
# sys, usr, asst
(MSGS_SYS_USR_ASST, f"""<s>[INST] <<SYS>>
{SYSTEM_PROMPT}
<</SYS>>

{USER1} [/INST] {ASSISTANT1} </s>
"""),
# sys, usr, asst, usr, asst
(MSGS_MULTI_TURN, f"""<s>[INST] <<SYS>>
{SYSTEM_PROMPT}
<</SYS>>

{USER1} [/INST] {ASSISTANT1} </s>
<s>[INST] {USER2} [/INST] {ASSISTANT2} </s>
"""),
]
)
def test_llama2_chat_formatter(messages, expected):
"""Tests for Llama2 following the official guide
https://www.llama.com/docs/model-cards-and-prompt-formats/meta-llama-2/
"""
tok = DummySPTokenizer()
fmt = Llama2ChatFormatter(tok)
# NOTE: add_generation_prompt not used by Llama2
check_rendering(fmt, messages, expected, True)

## Llama3ChatFormatter #########################################################

@pytest.mark.parametrize(
["messages", "expected"],
[
# single user message (no system prompt)
(MSGS_NO_SYS, f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>

{USER1}<|eot_id|>"""),
# sys, usr
(MSGS_SYS_USR, f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>

{SYSTEM_PROMPT}<|eot_id|><|start_header_id|>user<|end_header_id|>

{USER1}<|eot_id|>"""),
# sys, usr, asst
(MSGS_SYS_USR_ASST, f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>

{SYSTEM_PROMPT}<|eot_id|><|start_header_id|>user<|end_header_id|>

{USER1}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

{ASSISTANT1}<|eot_id|>"""),
# sys, usr, asst, usr, asst
(MSGS_MULTI_TURN, f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>

{SYSTEM_PROMPT}<|eot_id|><|start_header_id|>user<|end_header_id|>

{USER1}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

{ASSISTANT1}<|eot_id|><|start_header_id|>user<|end_header_id|>

{USER2}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

{ASSISTANT2}<|eot_id|>"""),
]
)
@pytest.mark.parametrize("add_generation_prompt", [True, False])
def test_llama3_chat_formatter(messages, expected, add_generation_prompt):
"""Tests for Llama3 following the official guide
https://www.llama.com/docs/model-cards-and-prompt-formats/meta-llama-3/
"""
tok = DummyLlama3Tokenizer()
fmt = Llama3ChatFormatter(tok)
# No assistant prompt added if the last message is from the assistant
if add_generation_prompt and messages[-1]["role"] != "assistant":
expected += "<|start_header_id|>assistant<|end_header_id|>\n\n"
check_rendering(fmt, messages, expected, add_generation_prompt)

## HFTokenizerChatFormatter ####################################################

@pytest.mark.parametrize(
["messages", "expected"],
[
# single user message (no system prompt)
(MSGS_NO_SYS, f"""<bos>
<bot><bor>user<eor>{USER1}<eot>"""),
# sys, usr
(MSGS_SYS_USR, f"""<bos>
<bot><bor>system<eor>{SYSTEM_PROMPT}<eot>
<bot><bor>user<eor>{USER1}<eot>"""),
# sys, usr, asst
(MSGS_SYS_USR_ASST, f"""<bos>
<bot><bor>system<eor>{SYSTEM_PROMPT}<eot>
<bot><bor>user<eor>{USER1}<eot>
<bot><bor>assistant<eor>{ASSISTANT1}<eot>"""),
# sys, usr, asst, usr, asst
(MSGS_MULTI_TURN, f"""<bos>
<bot><bor>system<eor>{SYSTEM_PROMPT}<eot>
<bot><bor>user<eor>{USER1}<eot>
<bot><bor>assistant<eor>{ASSISTANT1}<eot>
<bot><bor>user<eor>{USER2}<eot>
<bot><bor>assistant<eor>{ASSISTANT2}<eot>"""),
]
)
@pytest.mark.parametrize("add_generation_prompt", [True, False])
def test_hf_chat_formatter(messages, expected, add_generation_prompt):
tok = DummyHFTokenizer()
fmt = HFTokenizerChatFormatter(tok)
# No assistant prompt added if the last message is from the assistant
if add_generation_prompt and messages[-1]["role"] != "assistant":
expected += f"\n{tok.bot}{tok.bor}assistant{tok.eor}"
check_rendering(fmt, messages, expected, add_generation_prompt)
28 changes: 27 additions & 1 deletion tokenizer/hf_tokenizer.py
@@ -5,11 +5,12 @@
# LICENSE file in the root directory of this source tree.

# Standard
from typing import List, Optional
from typing import Dict, List, Optional
import json
import os

# Third Party
import jinja2
from tokenizers import Tokenizer

# Local
@@ -37,6 +38,9 @@ def __init__(self, file_path: str):
# Load the tokenizer itself
self._tokenizer = Tokenizer.from_file(tokenizer_path)

# Load the chat template if we have a config path
self._chat_template: Optional[jinja2.Template] = None

# If available, parse bos/eos tokens from the tokenizer config
self._bos_id, self._eos_id = None, None
if tokenizer_config_path is not None:
@@ -48,6 +52,8 @@ def __init__(self, file_path: str):
self._bos_id = self._tokenizer.token_to_id(bos_token)
if eos_token is not None:
self._eos_id = self._tokenizer.token_to_id(eos_token)
if chat_template_str := tok_config.get("chat_template"):
self._chat_template = jinja2.Template(chat_template_str)

# If no eos/bos tokens found, go looking for them!
if None in [self._bos_id, self._eos_id]:
@@ -70,6 +76,8 @@ def _look_for_special_token(added_tokens: dict, search_strs: List[str]) -> Optio
if len(candidate_toks) == 1:
return candidate_toks[0]["id"]

## Interface ##

def encode(
self,
s: str,
@@ -90,3 +98,21 @@ def bos_id(self) -> int:

def eos_id(self) -> int:
return self._eos_id

## Additional Public Methods ##

def has_chat_template(self) -> bool:
return bool(self._chat_template)

def apply_chat_template(
self,
dialog: List[Dict[str, str]],
add_generation_prompt: bool = False,
) -> str:
"""If configured with a chat template, apply it to the list of messages
"""
if not self._chat_template:
raise ValueError("No chat template configured!")
return self._chat_template.render(
messages=dialog, add_generation_prompt=add_generation_prompt
)
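
A usage sketch for the new chat-template methods; the enclosing class name, import path, and constructor argument below are assumptions, since the diff only shows the method bodies:

# Hypothetical usage -- class name and import location are assumed.
from tokenizer.hf_tokenizer import HFTokenizer  # assumed location

tok = HFTokenizer("/path/to/downloaded/model")  # assumed constructor argument
dialog = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello world!"},
]
if tok.has_chat_template():
    # Renders the jinja2 template parsed from tokenizer_config.json
    prompt = tok.apply_chat_template(dialog, add_generation_prompt=True)
else:
    prompt = "\n".join(m["content"] for m in dialog)  # naive fallback, illustration only
print(prompt)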
10 changes: 9 additions & 1 deletion torchchat/cli/cli.py
@@ -17,7 +17,15 @@
allowable_params_table,
)

logging.basicConfig(level=logging.INFO, format="%(message)s")
_log_level_env = os.getenv("LOG_LEVEL", "INFO")
try:
_log_level = getattr(logging, _log_level_env.upper())
except AttributeError:
print(f"Invalid log level: {_log_level_env}", file=sys.stderr)
_log_level = logging.INFO


logging.basicConfig(level=_log_level, format="%(message)s")
logger = logging.getLogger(__name__)

default_device = os.getenv("TORCHCHAT_DEVICE", "fast")
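
For comparison, a sketch of an equivalent LOG_LEVEL fallback using getattr's default argument; this is a hypothetical alternative, not the code the PR ships:

# Hypothetical alternative: same fallback via getattr's default argument.
# Unlike the diff above, this drops the stderr warning on unknown names.
import logging
import os

_log_level = getattr(logging, os.getenv("LOG_LEVEL", "INFO").upper(), logging.INFO)
logging.basicConfig(level=_log_level, format="%(message)s")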