-
-
Notifications
You must be signed in to change notification settings - Fork 88
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support new OS models: Zephyr and Yi #392
Draft
rmitsch
wants to merge
58
commits into
develop
Choose a base branch
from
feat/new-os-models
base: develop
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from all commits
Commits
Show all changes
58 commits
Select commit
Hold shift + click to select a range
3aba660
Add context length info. Refactor BuiltinTask and models to facilitat…
rmitsch 5699773
Merge branch 'develop' into feat/inf-doc-len
rmitsch 4213372
Add token count estimator plumbing.
rmitsch f440ca4
Add plumbing for mapper and reducer.
rmitsch e47f762
Add ShardMapper prototype.
rmitsch 89a5510
Integrating mapping into prompt generation workflow.
rmitsch 086dec9
Update response parsing and component to support sharding (WIP).
rmitsch 23718fc
Fix shard & prompt flow.
rmitsch 7ce670d
Fix shard & prompt flow.
rmitsch 0d75ea8
Remove todo comments.
rmitsch 9da7098
Fix Anthropic, Cohere, NoOp model tests.
rmitsch 0cb9afd
Merge branch 'develop' into feat/inf-doc-len
rmitsch f368412
Fix test_llm_pipe().
rmitsch b1f111d
Fix type checking test.
rmitsch 44a2787
Fix span parsing tests.
rmitsch 6d8cdc7
Fix internal tests.
rmitsch e712f41
Fix _CountTask.
rmitsch 985fd68
Fix sentiment and summarization tasks and tests.
rmitsch 98842a2
Fix Azure connection URL. Fix Model test pings.
rmitsch b54a3d9
Fix Lemma parsing.
rmitsch 9bf365d
Start work on doc-to-shard property copying.
rmitsch dddfaab
Fix REL doc preprocessing.
rmitsch 3af21b5
Remove comment on doc attribute handling during sharding, as this is …
rmitsch fee9ca7
Add reducer implementations.
rmitsch e508499
Implement outstanding task reducers.
rmitsch 3218541
Resolve merge conflicts.
rmitsch c104387
Add shardable/non-shardable LLM task typing distinction. Add support …
rmitsch 2c6d899
Merge branch 'develop' into feat/inf-doc-len
rmitsch 2502c4d
Fix EL task.
rmitsch 03055c5
Fix EL tokenization and highlighting partially.
rmitsch 4e4a2cd
Fix tokenization and whitespaces for EL task.
rmitsch 865acec
Fix merge conflicts.
rmitsch 694d5da
Add new registry handlers (with context length and arbitrary model na…
rmitsch 5295400
Add sharding test with simple count task.
rmitsch 70e3643
Fix sharding algorithm.
rmitsch 4321483
Add test with simple count task.
rmitsch ef6e738
Add context length as init arg in HF models.
rmitsch e3ff37d
Fix tests. Don't stringify IO lists if sharded.
rmitsch 056730a
Fix tests.
rmitsch 196c235
Add NER sharding test.
rmitsch 1f51a4a
Add REL and sentiment sharding tests.
rmitsch e18b302
Add summary sharding tests.
rmitsch 7c092ca
Add EL sharding task. Fix bug in shard mapper.
rmitsch 358ba72
Fix REL error with RELExample parsing.
rmitsch 0c96fb6
Use regex for punctuation in REL conversion.
rmitsch dc926bd
Maintain custom doc attributes, incl. test.
rmitsch 5585174
Filter merge warnings in textcat reduction.
rmitsch 6d3a4c8
Add Zephyr and Yi classes.
rmitsch 57acfe4
Fix Yi model.
rmitsch 2f1a905
Fix Yi model.
rmitsch 9821063
Fix Yi and Zephyr processing.
rmitsch 98e3e6c
Remove deprecated comment.
rmitsch 513c2fb
Fix merge conflicts.
rmitsch 482af35
Merge branch 'develop' into feat/new-os-models
rmitsch 3747a2f
Change model used for Yi tests.
rmitsch b2dff8f
Incorporate feedback.
rmitsch dfe89ee
Skip Yi test failing in CI, but suceeding locally.
rmitsch 69c3c76
Extend readme with links for Zephyr and Yi.
rmitsch File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
from typing import Any, Dict, Iterable, List, Optional, Tuple | ||
|
||
from confection import SimpleFrozenDict | ||
|
||
from ...compat import Literal, transformers | ||
from ...registry.util import registry | ||
from .base import HuggingFace | ||
|
||
|
||
class Yi(HuggingFace): | ||
MODEL_NAMES = Literal[ # noqa: F722 | ||
"Yi-34B", | ||
"Yi-34B-chat-8bits", | ||
"Yi-6B-chat", | ||
"Yi-6B", | ||
"Yi-6B-200K", | ||
"Yi-34B-chat", | ||
"Yi-34B-chat-4bits", | ||
"Yi-34B-200K", | ||
] | ||
|
||
def __init__( | ||
self, | ||
name: MODEL_NAMES, | ||
config_init: Optional[Dict[str, Any]], | ||
config_run: Optional[Dict[str, Any]], | ||
context_length: int, | ||
): | ||
self._tokenizer: Optional["transformers.AutoTokenizer"] = None | ||
self._is_instruct = "instruct" in name | ||
super().__init__( | ||
name=name, | ||
config_init=config_init, | ||
config_run=config_run, | ||
context_length=context_length, | ||
) | ||
|
||
assert isinstance(self._tokenizer, transformers.PreTrainedTokenizerBase) | ||
|
||
# Instantiate GenerationConfig object from config dict. | ||
self._hf_config_run = transformers.GenerationConfig.from_pretrained( | ||
self._name, **self._config_run | ||
) | ||
# To avoid deprecation warning regarding usage of `max_length`. | ||
self._hf_config_run.max_new_tokens = self._hf_config_run.max_length | ||
|
||
def init_model(self) -> Any: | ||
self._tokenizer = transformers.AutoTokenizer.from_pretrained( | ||
self._name, use_fast=False | ||
) | ||
init_cfg = self._config_init | ||
device: Optional[str] = None | ||
if "device" in init_cfg: | ||
device = init_cfg.pop("device") | ||
|
||
model = transformers.AutoModelForCausalLM.from_pretrained( | ||
self._name, **init_cfg, resume_download=True | ||
).eval() | ||
if device: | ||
model.to(device) | ||
|
||
return model | ||
|
||
@property | ||
def hf_account(self) -> str: | ||
return "01-ai" | ||
|
||
def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: # type: ignore[override] | ||
assert hasattr(self._model, "generate") | ||
assert hasattr(self._tokenizer, "apply_chat_template") | ||
assert self._tokenizer | ||
|
||
responses: List[List[str]] = [] | ||
|
||
for prompts_for_doc in prompts: | ||
prompts_for_doc = list(prompts_for_doc) | ||
|
||
tokenized_input_ids = [ | ||
self._tokenizer.apply_chat_template( | ||
conversation=[{"role": "user", "content": prompt}], | ||
tokenize=True, | ||
add_generation_prompt=True, | ||
return_tensors="pt", | ||
) | ||
for prompt in prompts_for_doc | ||
] | ||
tokenized_input_ids = [ | ||
tp.to(self._model.device) for tp in tokenized_input_ids | ||
] | ||
|
||
responses.append( | ||
[ | ||
self._tokenizer.decode( | ||
self._model.generate( | ||
input_ids=tok_ii, generation_config=self._hf_config_run | ||
)[:, tok_ii.shape[1] :][0], | ||
skip_special_tokens=True, | ||
).strip("\n") | ||
for tok_ii in tokenized_input_ids | ||
] | ||
) | ||
|
||
return responses | ||
|
||
@staticmethod | ||
def compile_default_configs() -> Tuple[Dict[str, Any], Dict[str, Any]]: | ||
default_cfg_init, default_cfg_run = HuggingFace.compile_default_configs() | ||
return {**default_cfg_init, **{"torch_dtype": "auto"}}, default_cfg_run | ||
|
||
|
||
@registry.llm_models("spacy.Yi.v1") | ||
def yi_hf( | ||
name: Yi.MODEL_NAMES, | ||
config_init: Optional[Dict[str, Any]] = SimpleFrozenDict(), | ||
config_run: Optional[Dict[str, Any]] = SimpleFrozenDict(), | ||
) -> Yi: | ||
"""Generates Yi instance that can execute a set of prompts and return the raw responses. | ||
name (Literal): Name of the Yi model. Has to be one of Yi.get_model_names(). | ||
config_init (Optional[Dict[str, Any]]): HF config for initializing the model. | ||
config_run (Optional[Dict[str, Any]]): HF config for running the model. | ||
RETURNS (Yi): Yi instance that can execute a set of prompts and return the raw responses. | ||
""" | ||
return Yi( | ||
name=name, | ||
config_init=config_init, | ||
config_run=config_run, | ||
context_length=200000 if "200K" in name else 32000, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
from typing import Any, Dict, Iterable, List, Optional, Tuple | ||
|
||
from confection import SimpleFrozenDict | ||
|
||
from ...compat import Literal, transformers | ||
from ...registry.util import registry | ||
from .base import HuggingFace | ||
|
||
|
||
class Zephyr(HuggingFace): | ||
MODEL_NAMES = Literal["zephyr-7b-beta"] # noqa: F722 | ||
|
||
def __init__( | ||
self, | ||
name: MODEL_NAMES, | ||
config_init: Optional[Dict[str, Any]], | ||
config_run: Optional[Dict[str, Any]], | ||
context_length: int, | ||
): | ||
super().__init__( | ||
name=name, | ||
config_init=config_init, | ||
config_run=config_run, | ||
context_length=context_length, | ||
) | ||
|
||
# Instantiate GenerationConfig object from config dict. | ||
self._hf_config_run = transformers.GenerationConfig.from_pretrained( | ||
self._name, **self._config_run | ||
) | ||
# To avoid deprecation warning regarding usage of `max_length`. | ||
self._hf_config_run.max_new_tokens = self._hf_config_run.max_length | ||
|
||
def init_model(self) -> Any: | ||
return transformers.pipeline( | ||
"text-generation", | ||
model=self._name, | ||
return_full_text=False, | ||
**self._config_init | ||
) | ||
|
||
@property | ||
def hf_account(self) -> str: | ||
return "HuggingFaceH4" | ||
|
||
def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: # type: ignore[override] | ||
responses: List[List[str]] = [] | ||
|
||
for prompts_for_doc in prompts: | ||
formatted_prompts_for_doc = [ | ||
self._model.tokenizer.apply_chat_template( | ||
[{"role": "user", "content": prompt}], | ||
tokenize=False, | ||
add_generation_prompt=False, | ||
) | ||
for prompt in prompts_for_doc | ||
] | ||
|
||
responses.append( | ||
[ | ||
self._model(prompt, generation_config=self._hf_config_run)[0][ | ||
"generated_text" | ||
] | ||
.replace("<|assistant|>", "") | ||
.strip("\n") | ||
for prompt in formatted_prompts_for_doc | ||
] | ||
) | ||
|
||
return responses | ||
|
||
@staticmethod | ||
def compile_default_configs() -> Tuple[Dict[str, Any], Dict[str, Any]]: | ||
default_cfg_init, default_cfg_run = HuggingFace.compile_default_configs() | ||
return default_cfg_init, { | ||
**default_cfg_run, | ||
**{ | ||
"max_new_tokens": 256, | ||
"do_sample": True, | ||
"temperature": 0.7, | ||
"top_k": 50, | ||
"top_p": 0.95, | ||
}, | ||
} | ||
|
||
|
||
@registry.llm_models("spacy.Zephyr.v1") | ||
def zephyr_hf( | ||
name: Zephyr.MODEL_NAMES, | ||
config_init: Optional[Dict[str, Any]] = SimpleFrozenDict(), | ||
config_run: Optional[Dict[str, Any]] = SimpleFrozenDict(), | ||
) -> Zephyr: | ||
"""Generates Zephyr instance that can execute a set of prompts and return the raw responses. | ||
name (Literal): Name of the Zephyr model. Has to be one of Zephyr.get_model_names(). | ||
config_init (Optional[Dict[str, Any]]): HF config for initializing the model. | ||
config_run (Optional[Dict[str, Any]]): HF config for running the model. | ||
RETURNS (Zephyr): Zephyr instance that can execute a set of prompts and return the raw responses. | ||
""" | ||
return Zephyr( | ||
name=name, config_init=config_init, config_run=config_run, context_length=8000 | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,6 +27,7 @@ | |
|
||
[components.llm] | ||
factory = "llm" | ||
save_io = True | ||
|
||
[components.llm.task] | ||
@llm_tasks = "spacy.NoOp.v1" | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
import copy | ||
|
||
import pytest | ||
import spacy | ||
from confection import Config # type: ignore[import] | ||
from thinc.compat import has_torch_cuda_gpu | ||
|
||
from ...compat import torch | ||
|
||
_PIPE_CFG = { | ||
"model": { | ||
"@llm_models": "spacy.Yi.v1", | ||
"name": "Yi-6B-chat", | ||
}, | ||
"task": {"@llm_tasks": "spacy.NoOp.v1"}, | ||
} | ||
|
||
_NLP_CONFIG = """ | ||
|
||
[nlp] | ||
lang = "en" | ||
pipeline = ["llm"] | ||
batch_size = 128 | ||
|
||
[components] | ||
|
||
[components.llm] | ||
factory = "llm" | ||
|
||
[components.llm.task] | ||
@llm_tasks = "spacy.NoOp.v1" | ||
|
||
[components.llm.model] | ||
@llm_models = "spacy.Yi.v1" | ||
name = "Yi-6B" | ||
""" | ||
|
||
|
||
@pytest.mark.gpu | ||
@pytest.mark.skipif(not has_torch_cuda_gpu, reason="needs GPU & CUDA") | ||
@pytest.mark.skip( | ||
reason="CI runner fails with 'cutlassF: no kernel found to launch!' - to be investigated" | ||
) | ||
def test_init(): | ||
"""Test initialization and simple run.""" | ||
nlp = spacy.blank("en") | ||
cfg = copy.deepcopy(_PIPE_CFG) | ||
nlp.add_pipe("llm", config=cfg) | ||
nlp("This is a test.") | ||
torch.cuda.empty_cache() | ||
|
||
|
||
@pytest.mark.gpu | ||
@pytest.mark.skip(reason="CI runner needs more GPU memory") | ||
@pytest.mark.skipif(not has_torch_cuda_gpu, reason="needs GPU & CUDA") | ||
def test_init_from_config(): | ||
orig_config = Config().from_str(_NLP_CONFIG) | ||
nlp = spacy.util.load_model_from_config(orig_config, auto_fill=True) | ||
assert nlp.pipe_names == ["llm"] | ||
torch.cuda.empty_cache() | ||
|
||
|
||
@pytest.mark.gpu | ||
@pytest.mark.skipif(not has_torch_cuda_gpu, reason="needs GPU & CUDA") | ||
def test_invalid_model(): | ||
orig_config = Config().from_str(_NLP_CONFIG) | ||
config = copy.deepcopy(orig_config) | ||
config["components"]["llm"]["model"]["name"] = "x" | ||
with pytest.raises(ValueError, match="unexpected value; permitted"): | ||
spacy.util.load_model_from_config(config, auto_fill=True) | ||
torch.cuda.empty_cache() |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure where this edit is coming from?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's just a drive-by because I noticed the warnings filter is missing here 🙃 I can move this into a separate PR, if you mind having it in here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's all just a bit confusing with the huge (mostly unrelated) git history etc - I do in general appreciate more "atomic" PRs ;-)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I know 🫣
Yeah, I don't know why that's the case. The branches should all be updated.