test llama_factory
This commit is contained in:
@@ -0,0 +1,59 @@
|
||||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import random
|
||||
|
||||
import pytest
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from llamafactory.extras.constants import IGNORE_INDEX
|
||||
from llamafactory.train.test_utils import load_train_dataset
|
||||
|
||||
|
||||
DEMO_DATA = os.getenv("DEMO_DATA", "llamafactory/demo_data")
|
||||
|
||||
TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
|
||||
|
||||
TRAIN_ARGS = {
|
||||
"model_name_or_path": TINY_LLAMA,
|
||||
"stage": "kto",
|
||||
"do_train": True,
|
||||
"finetuning_type": "full",
|
||||
"dataset": "kto_en_demo",
|
||||
"dataset_dir": "REMOTE:" + DEMO_DATA,
|
||||
"template": "llama3",
|
||||
"cutoff_len": 8192,
|
||||
"overwrite_cache": True,
|
||||
"output_dir": "dummy_dir",
|
||||
"overwrite_output_dir": True,
|
||||
"fp16": True,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_samples", [16])
|
||||
def test_feedback_data(num_samples: int):
|
||||
train_dataset = load_train_dataset(**TRAIN_ARGS)
|
||||
ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
|
||||
original_data = load_dataset(DEMO_DATA, name="kto_en_demo", split="train")
|
||||
indexes = random.choices(range(len(original_data)), k=num_samples)
|
||||
for index in indexes:
|
||||
messages = original_data["messages"][index]
|
||||
ref_input_ids = ref_tokenizer.apply_chat_template(messages)
|
||||
prompt_len = len(ref_tokenizer.apply_chat_template(messages[:-1], add_generation_prompt=True))
|
||||
ref_labels = [IGNORE_INDEX] * prompt_len + ref_input_ids[prompt_len:]
|
||||
assert train_dataset["input_ids"][index] == ref_input_ids
|
||||
assert train_dataset["labels"][index] == ref_labels
|
||||
assert train_dataset["kto_tags"][index] == original_data["label"][index]
|
||||
@@ -0,0 +1,78 @@
|
||||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import random
|
||||
from typing import Dict, List
|
||||
|
||||
import pytest
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from llamafactory.extras.constants import IGNORE_INDEX
|
||||
from llamafactory.train.test_utils import load_train_dataset
|
||||
|
||||
|
||||
DEMO_DATA = os.getenv("DEMO_DATA", "llamafactory/demo_data")
|
||||
|
||||
TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
|
||||
|
||||
TRAIN_ARGS = {
|
||||
"model_name_or_path": TINY_LLAMA,
|
||||
"stage": "rm",
|
||||
"do_train": True,
|
||||
"finetuning_type": "full",
|
||||
"dataset": "dpo_en_demo",
|
||||
"dataset_dir": "REMOTE:" + DEMO_DATA,
|
||||
"template": "llama3",
|
||||
"cutoff_len": 8192,
|
||||
"overwrite_cache": True,
|
||||
"output_dir": "dummy_dir",
|
||||
"overwrite_output_dir": True,
|
||||
"fp16": True,
|
||||
}
|
||||
|
||||
|
||||
def _convert_sharegpt_to_openai(messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
|
||||
role_mapping = {"human": "user", "gpt": "assistant", "system": "system"}
|
||||
new_messages = []
|
||||
for message in messages:
|
||||
new_messages.append({"role": role_mapping[message["from"]], "content": message["value"]})
|
||||
|
||||
return new_messages
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_samples", [16])
|
||||
def test_pairwise_data(num_samples: int):
|
||||
train_dataset = load_train_dataset(**TRAIN_ARGS)
|
||||
ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
|
||||
original_data = load_dataset(DEMO_DATA, name="dpo_en_demo", split="train")
|
||||
indexes = random.choices(range(len(original_data)), k=num_samples)
|
||||
for index in indexes:
|
||||
chosen_messages = original_data["conversations"][index] + [original_data["chosen"][index]]
|
||||
rejected_messages = original_data["conversations"][index] + [original_data["rejected"][index]]
|
||||
chosen_messages = _convert_sharegpt_to_openai(chosen_messages)
|
||||
rejected_messages = _convert_sharegpt_to_openai(rejected_messages)
|
||||
ref_chosen_input_ids = ref_tokenizer.apply_chat_template(chosen_messages)
|
||||
chosen_prompt_len = len(ref_tokenizer.apply_chat_template(chosen_messages[:-1], add_generation_prompt=True))
|
||||
ref_chosen_labels = [IGNORE_INDEX] * chosen_prompt_len + ref_chosen_input_ids[chosen_prompt_len:]
|
||||
ref_rejected_input_ids = ref_tokenizer.apply_chat_template(rejected_messages)
|
||||
rejected_prompt_len = len(
|
||||
ref_tokenizer.apply_chat_template(rejected_messages[:-1], add_generation_prompt=True)
|
||||
)
|
||||
ref_rejected_labels = [IGNORE_INDEX] * rejected_prompt_len + ref_rejected_input_ids[rejected_prompt_len:]
|
||||
assert train_dataset["chosen_input_ids"][index] == ref_chosen_input_ids
|
||||
assert train_dataset["chosen_labels"][index] == ref_chosen_labels
|
||||
assert train_dataset["rejected_input_ids"][index] == ref_rejected_input_ids
|
||||
assert train_dataset["rejected_labels"][index] == ref_rejected_labels
|
||||
@@ -0,0 +1,35 @@
|
||||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
from llamafactory.data.processors.processor_utils import infer_seqlen
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"test_input,test_output",
|
||||
[
|
||||
((3000, 2000, 1000), (600, 400)),
|
||||
((2000, 3000, 1000), (400, 600)),
|
||||
((1000, 100, 1000), (900, 100)),
|
||||
((100, 1000, 1000), (100, 900)),
|
||||
((100, 500, 1000), (100, 500)),
|
||||
((500, 100, 1000), (500, 100)),
|
||||
((10, 10, 1000), (10, 10)),
|
||||
],
|
||||
)
|
||||
def test_infer_seqlen(test_input: Tuple[int, int, int], test_output: Tuple[int, int]):
|
||||
assert test_output == infer_seqlen(*test_input)
|
||||
@@ -0,0 +1,104 @@
|
||||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import random
|
||||
|
||||
import pytest
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from llamafactory.extras.constants import IGNORE_INDEX
|
||||
from llamafactory.train.test_utils import load_train_dataset
|
||||
|
||||
|
||||
DEMO_DATA = os.getenv("DEMO_DATA", "llamafactory/demo_data")
|
||||
|
||||
TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
|
||||
|
||||
TINY_DATA = os.getenv("TINY_DATA", "llamafactory/tiny-supervised-dataset")
|
||||
|
||||
TRAIN_ARGS = {
|
||||
"model_name_or_path": TINY_LLAMA,
|
||||
"stage": "sft",
|
||||
"do_train": True,
|
||||
"finetuning_type": "full",
|
||||
"template": "llama3",
|
||||
"cutoff_len": 8192,
|
||||
"overwrite_cache": True,
|
||||
"output_dir": "dummy_dir",
|
||||
"overwrite_output_dir": True,
|
||||
"fp16": True,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_samples", [16])
|
||||
def test_supervised_single_turn(num_samples: int):
|
||||
train_dataset = load_train_dataset(dataset_dir="ONLINE", dataset=TINY_DATA, **TRAIN_ARGS)
|
||||
ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
|
||||
original_data = load_dataset(TINY_DATA, split="train")
|
||||
indexes = random.choices(range(len(original_data)), k=num_samples)
|
||||
for index in indexes:
|
||||
prompt = original_data["instruction"][index]
|
||||
if original_data["input"][index]:
|
||||
prompt += "\n" + original_data["input"][index]
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
{"role": "assistant", "content": original_data["output"][index]},
|
||||
]
|
||||
ref_input_ids = ref_tokenizer.apply_chat_template(messages)
|
||||
assert train_dataset["input_ids"][index] == ref_input_ids
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_samples", [8])
|
||||
def test_supervised_multi_turn(num_samples: int):
|
||||
train_dataset = load_train_dataset(dataset_dir="REMOTE:" + DEMO_DATA, dataset="system_chat", **TRAIN_ARGS)
|
||||
ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
|
||||
original_data = load_dataset(DEMO_DATA, name="system_chat", split="train")
|
||||
indexes = random.choices(range(len(original_data)), k=num_samples)
|
||||
for index in indexes:
|
||||
ref_input_ids = ref_tokenizer.apply_chat_template(original_data["messages"][index])
|
||||
assert train_dataset["input_ids"][index] == ref_input_ids
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_samples", [4])
|
||||
def test_supervised_train_on_prompt(num_samples: int):
|
||||
train_dataset = load_train_dataset(
|
||||
dataset_dir="REMOTE:" + DEMO_DATA, dataset="system_chat", train_on_prompt=True, **TRAIN_ARGS
|
||||
)
|
||||
ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
|
||||
original_data = load_dataset(DEMO_DATA, name="system_chat", split="train")
|
||||
indexes = random.choices(range(len(original_data)), k=num_samples)
|
||||
for index in indexes:
|
||||
ref_ids = ref_tokenizer.apply_chat_template(original_data["messages"][index])
|
||||
assert train_dataset["input_ids"][index] == ref_ids
|
||||
assert train_dataset["labels"][index] == ref_ids
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_samples", [4])
|
||||
def test_supervised_mask_history(num_samples: int):
|
||||
train_dataset = load_train_dataset(
|
||||
dataset_dir="REMOTE:" + DEMO_DATA, dataset="system_chat", mask_history=True, **TRAIN_ARGS
|
||||
)
|
||||
ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
|
||||
original_data = load_dataset(DEMO_DATA, name="system_chat", split="train")
|
||||
indexes = random.choices(range(len(original_data)), k=num_samples)
|
||||
for index in indexes:
|
||||
messages = original_data["messages"][index]
|
||||
ref_input_ids = ref_tokenizer.apply_chat_template(messages)
|
||||
prompt_len = len(ref_tokenizer.apply_chat_template(messages[:-1], add_generation_prompt=True))
|
||||
ref_label_ids = [IGNORE_INDEX] * prompt_len + ref_input_ids[prompt_len:]
|
||||
assert train_dataset["input_ids"][index] == ref_input_ids
|
||||
assert train_dataset["labels"][index] == ref_label_ids
|
||||
@@ -0,0 +1,61 @@
|
||||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import random
|
||||
|
||||
import pytest
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from llamafactory.train.test_utils import load_train_dataset
|
||||
|
||||
|
||||
DEMO_DATA = os.getenv("DEMO_DATA", "llamafactory/demo_data")
|
||||
|
||||
TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
|
||||
|
||||
TINY_DATA = os.getenv("TINY_DATA", "llamafactory/tiny-supervised-dataset")
|
||||
|
||||
TRAIN_ARGS = {
|
||||
"model_name_or_path": TINY_LLAMA,
|
||||
"stage": "ppo",
|
||||
"do_train": True,
|
||||
"finetuning_type": "full",
|
||||
"reward_model": "",
|
||||
"reward_model_type": "full",
|
||||
"dataset": "system_chat",
|
||||
"dataset_dir": "REMOTE:" + DEMO_DATA,
|
||||
"template": "llama3",
|
||||
"cutoff_len": 8192,
|
||||
"overwrite_cache": True,
|
||||
"output_dir": "dummy_dir",
|
||||
"overwrite_output_dir": True,
|
||||
"fp16": True,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_samples", [16])
|
||||
def test_unsupervised_data(num_samples: int):
|
||||
train_dataset = load_train_dataset(**TRAIN_ARGS)
|
||||
ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
|
||||
original_data = load_dataset(DEMO_DATA, name="system_chat", split="train")
|
||||
indexes = random.choices(range(len(original_data)), k=num_samples)
|
||||
for index in indexes:
|
||||
messages = original_data["messages"][index]
|
||||
ref_ids = ref_tokenizer.apply_chat_template(messages)
|
||||
ref_input_ids = ref_tokenizer.apply_chat_template(messages[:-1], add_generation_prompt=True)
|
||||
ref_labels = ref_ids[len(ref_input_ids) :]
|
||||
assert train_dataset["input_ids"][index] == ref_input_ids
|
||||
assert train_dataset["labels"][index] == ref_labels
|
||||
@@ -0,0 +1,152 @@
|
||||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from llamafactory.data import get_template_and_fix_tokenizer
|
||||
from llamafactory.data.collator import MultiModalDataCollatorForSeq2Seq, prepare_4d_attention_mask
|
||||
from llamafactory.extras.constants import IGNORE_INDEX
|
||||
from llamafactory.hparams import get_infer_args
|
||||
from llamafactory.model import load_tokenizer
|
||||
|
||||
|
||||
TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
|
||||
|
||||
|
||||
def test_base_collator():
|
||||
model_args, data_args, *_ = get_infer_args({"model_name_or_path": TINY_LLAMA, "template": "default"})
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
template = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args)
|
||||
data_collator = MultiModalDataCollatorForSeq2Seq(
|
||||
template=template,
|
||||
pad_to_multiple_of=8,
|
||||
label_pad_token_id=IGNORE_INDEX,
|
||||
**tokenizer_module,
|
||||
)
|
||||
p = tokenizer_module["tokenizer"].pad_token_id
|
||||
q = IGNORE_INDEX
|
||||
features = [
|
||||
{
|
||||
"input_ids": [0, 1, 2, 3, 4, 5],
|
||||
"attention_mask": [1, 1, 1, 1, 1, 1],
|
||||
"labels": [q, q, 2, 3, 4, 5],
|
||||
},
|
||||
{
|
||||
"input_ids": [6, 7],
|
||||
"attention_mask": [1, 1],
|
||||
"labels": [q, 7],
|
||||
},
|
||||
]
|
||||
batch_input = data_collator(features)
|
||||
expected_input = {
|
||||
"input_ids": [
|
||||
[0, 1, 2, 3, 4, 5, p, p],
|
||||
[6, 7, p, p, p, p, p, p],
|
||||
],
|
||||
"attention_mask": [
|
||||
[1, 1, 1, 1, 1, 1, 0, 0],
|
||||
[1, 1, 0, 0, 0, 0, 0, 0],
|
||||
],
|
||||
"labels": [
|
||||
[q, q, 2, 3, 4, 5, q, q],
|
||||
[q, 7, q, q, q, q, q, q],
|
||||
],
|
||||
}
|
||||
for k in batch_input.keys():
|
||||
assert batch_input[k].eq(torch.tensor(expected_input[k])).all()
|
||||
|
||||
|
||||
def test_multimodal_collator():
|
||||
model_args, data_args, *_ = get_infer_args(
|
||||
{"model_name_or_path": "Qwen/Qwen2-VL-7B-Instruct", "template": "qwen2_vl"}
|
||||
)
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
template = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args)
|
||||
data_collator = MultiModalDataCollatorForSeq2Seq(
|
||||
template=template,
|
||||
pad_to_multiple_of=4,
|
||||
label_pad_token_id=IGNORE_INDEX,
|
||||
**tokenizer_module,
|
||||
)
|
||||
p = tokenizer_module["tokenizer"].pad_token_id
|
||||
q = IGNORE_INDEX
|
||||
s = tokenizer_module["tokenizer"].convert_tokens_to_ids("<|vision_start|>")
|
||||
e = tokenizer_module["tokenizer"].convert_tokens_to_ids("<|vision_end|>")
|
||||
m = tokenizer_module["tokenizer"].convert_tokens_to_ids("<|image_pad|>")
|
||||
fake_image = Image.new("RGB", (64, 64), (255, 255, 255))
|
||||
|
||||
features = [
|
||||
{
|
||||
"input_ids": [0, 1, 2, 3],
|
||||
"attention_mask": [1, 1, 1, 1],
|
||||
"labels": [0, 1, 2, 3],
|
||||
},
|
||||
]
|
||||
batch_input = data_collator(features)
|
||||
expected_input = {
|
||||
"input_ids": [
|
||||
[0, 1, 2, 3, s, m, m, m, m, e, p, p],
|
||||
],
|
||||
"attention_mask": [
|
||||
[1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
],
|
||||
"labels": [
|
||||
[0, 1, 2, 3, q, q, q, q, q, q, q, q],
|
||||
],
|
||||
**tokenizer_module["processor"].image_processor(fake_image),
|
||||
}
|
||||
for k in batch_input.keys():
|
||||
assert batch_input[k].eq(torch.tensor(expected_input[k])).all()
|
||||
|
||||
|
||||
def test_4d_attention_mask():
|
||||
o = 0.0
|
||||
x = torch.finfo(torch.float16).min
|
||||
attention_mask_with_indices = torch.tensor(
|
||||
[
|
||||
[1, 1, 2, 2, 2, 0],
|
||||
[1, 2, 2, 3, 3, 3],
|
||||
]
|
||||
)
|
||||
attention_mask_computed = prepare_4d_attention_mask(attention_mask_with_indices, torch.float16)
|
||||
attention_mask_expected = torch.tensor(
|
||||
[
|
||||
[
|
||||
[
|
||||
[o, x, x, x, x, x],
|
||||
[o, o, x, x, x, x],
|
||||
[x, x, o, x, x, x],
|
||||
[x, x, o, o, x, x],
|
||||
[x, x, o, o, o, x],
|
||||
[x, x, x, x, x, x],
|
||||
]
|
||||
],
|
||||
[
|
||||
[
|
||||
[o, x, x, x, x, x],
|
||||
[x, o, x, x, x, x],
|
||||
[x, o, o, x, x, x],
|
||||
[x, x, x, o, x, x],
|
||||
[x, x, x, o, o, x],
|
||||
[x, x, x, o, o, o],
|
||||
]
|
||||
],
|
||||
],
|
||||
dtype=torch.float16,
|
||||
)
|
||||
assert list(attention_mask_computed.size()) == [2, 1, 6, 6]
|
||||
assert torch.all(attention_mask_computed == attention_mask_expected)
|
||||
@@ -0,0 +1,246 @@
|
||||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
from llamafactory.data.formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
|
||||
|
||||
|
||||
FUNCTION = {"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}
|
||||
|
||||
TOOLS = [
|
||||
{
|
||||
"name": "test_tool",
|
||||
"description": "tool_desc",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"foo": {"type": "string", "description": "foo_desc"},
|
||||
"bar": {"type": "number", "description": "bar_desc"},
|
||||
},
|
||||
"required": ["foo"],
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def test_empty_formatter():
|
||||
formatter = EmptyFormatter(slots=["\n"])
|
||||
assert formatter.apply() == ["\n"]
|
||||
|
||||
|
||||
def test_string_formatter():
|
||||
formatter = StringFormatter(slots=["<s>", "Human: {{content}}\nAssistant:"])
|
||||
assert formatter.apply(content="Hi") == ["<s>", "Human: Hi\nAssistant:"]
|
||||
|
||||
|
||||
def test_function_formatter():
|
||||
formatter = FunctionFormatter(slots=["{{content}}", "</s>"], tool_format="default")
|
||||
tool_calls = json.dumps(FUNCTION)
|
||||
assert formatter.apply(content=tool_calls) == [
|
||||
"""Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n""",
|
||||
"</s>",
|
||||
]
|
||||
|
||||
|
||||
def test_multi_function_formatter():
|
||||
formatter = FunctionFormatter(slots=["{{content}}", "</s>"], tool_format="default")
|
||||
tool_calls = json.dumps([FUNCTION] * 2)
|
||||
assert formatter.apply(content=tool_calls) == [
|
||||
"""Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n"""
|
||||
"""Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n""",
|
||||
"</s>",
|
||||
]
|
||||
|
||||
|
||||
def test_default_tool_formatter():
|
||||
formatter = ToolFormatter(tool_format="default")
|
||||
assert formatter.apply(content=json.dumps(TOOLS)) == [
|
||||
"You have access to the following tools:\n"
|
||||
"> Tool Name: test_tool\n"
|
||||
"Tool Description: tool_desc\n"
|
||||
"Tool Args:\n"
|
||||
" - foo (string, required): foo_desc\n"
|
||||
" - bar (number): bar_desc\n\n"
|
||||
"Use the following format if using a tool:\n"
|
||||
"```\n"
|
||||
"Action: tool name (one of [test_tool])\n"
|
||||
"Action Input: the input to the tool, in a JSON format representing the kwargs "
|
||||
"""(e.g. ```{"input": "hello world", "num_beams": 5}```)\n"""
|
||||
"```\n"
|
||||
]
|
||||
|
||||
|
||||
def test_default_tool_extractor():
|
||||
formatter = ToolFormatter(tool_format="default")
|
||||
result = """Action: test_tool\nAction Input: {"foo": "bar", "size": 10}\n"""
|
||||
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
|
||||
|
||||
|
||||
def test_default_multi_tool_extractor():
|
||||
formatter = ToolFormatter(tool_format="default")
|
||||
result = (
|
||||
"""Action: test_tool\nAction Input: {"foo": "bar", "size": 10}\n"""
|
||||
"""Action: another_tool\nAction Input: {"foo": "job", "size": 2}\n"""
|
||||
)
|
||||
assert formatter.extract(result) == [
|
||||
("test_tool", """{"foo": "bar", "size": 10}"""),
|
||||
("another_tool", """{"foo": "job", "size": 2}"""),
|
||||
]
|
||||
|
||||
|
||||
def test_glm4_function_formatter():
|
||||
formatter = FunctionFormatter(slots=["{{content}}"], tool_format="glm4")
|
||||
tool_calls = json.dumps(FUNCTION)
|
||||
assert formatter.apply(content=tool_calls) == ["""tool_name\n{"foo": "bar", "size": 10}"""]
|
||||
|
||||
|
||||
def test_glm4_tool_formatter():
|
||||
formatter = ToolFormatter(tool_format="glm4")
|
||||
assert formatter.apply(content=json.dumps(TOOLS)) == [
|
||||
"你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
|
||||
"你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具\n\n"
|
||||
f"## test_tool\n\n{json.dumps(TOOLS[0], indent=4, ensure_ascii=False)}\n在调用上述函数时,请使用 Json 格式表示调用的参数。"
|
||||
]
|
||||
|
||||
|
||||
def test_glm4_tool_extractor():
|
||||
formatter = ToolFormatter(tool_format="glm4")
|
||||
result = """test_tool\n{"foo": "bar", "size": 10}\n"""
|
||||
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
|
||||
|
||||
|
||||
def test_llama3_function_formatter():
|
||||
formatter = FunctionFormatter(slots=["{{content}}", "<|eot_id|>"], tool_format="llama3")
|
||||
tool_calls = json.dumps({"name": "tool_name", "arguments": {"foo": "bar", "size": 10}})
|
||||
assert formatter.apply(content=tool_calls) == [
|
||||
"""{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}""",
|
||||
"<|eot_id|>",
|
||||
]
|
||||
|
||||
|
||||
def test_llama3_tool_formatter():
|
||||
formatter = ToolFormatter(tool_format="llama3")
|
||||
date = datetime.now().strftime("%d %b %Y")
|
||||
wrapped_tool = {"type": "function", "function": TOOLS[0]}
|
||||
assert formatter.apply(content=json.dumps(TOOLS)) == [
|
||||
f"Cutting Knowledge Date: December 2023\nToday Date: {date}\n\n"
|
||||
"You have access to the following functions. To call a function, please respond with JSON for a function call. "
|
||||
"""Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. """
|
||||
f"Do not use variables.\n\n{json.dumps(wrapped_tool, indent=4, ensure_ascii=False)}\n\n"
|
||||
]
|
||||
|
||||
|
||||
def test_llama3_tool_extractor():
|
||||
formatter = ToolFormatter(tool_format="llama3")
|
||||
result = """{"name": "test_tool", "parameters": {"foo": "bar", "size": 10}}\n"""
|
||||
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
|
||||
|
||||
|
||||
def test_mistral_function_formatter():
|
||||
formatter = FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", "</s>"], tool_format="mistral")
|
||||
tool_calls = json.dumps(FUNCTION)
|
||||
assert formatter.apply(content=tool_calls) == [
|
||||
"[TOOL_CALLS] ",
|
||||
"""[{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}]""",
|
||||
"</s>",
|
||||
]
|
||||
|
||||
|
||||
def test_mistral_multi_function_formatter():
|
||||
formatter = FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", "</s>"], tool_format="mistral")
|
||||
tool_calls = json.dumps([FUNCTION] * 2)
|
||||
assert formatter.apply(content=tool_calls) == [
|
||||
"[TOOL_CALLS] ",
|
||||
"""[{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}, """
|
||||
"""{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}]""",
|
||||
"</s>",
|
||||
]
|
||||
|
||||
|
||||
def test_mistral_tool_formatter():
|
||||
formatter = ToolFormatter(tool_format="mistral")
|
||||
wrapped_tool = {"type": "function", "function": TOOLS[0]}
|
||||
assert formatter.apply(content=json.dumps(TOOLS)) == [
|
||||
"[AVAILABLE_TOOLS] " + json.dumps([wrapped_tool], ensure_ascii=False) + "[/AVAILABLE_TOOLS]"
|
||||
]
|
||||
|
||||
|
||||
def test_mistral_tool_extractor():
|
||||
formatter = ToolFormatter(tool_format="mistral")
|
||||
result = """{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}"""
|
||||
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
|
||||
|
||||
|
||||
def test_mistral_multi_tool_extractor():
|
||||
formatter = ToolFormatter(tool_format="mistral")
|
||||
result = (
|
||||
"""[{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}, """
|
||||
"""{"name": "another_tool", "arguments": {"foo": "job", "size": 2}}]"""
|
||||
)
|
||||
assert formatter.extract(result) == [
|
||||
("test_tool", """{"foo": "bar", "size": 10}"""),
|
||||
("another_tool", """{"foo": "job", "size": 2}"""),
|
||||
]
|
||||
|
||||
|
||||
def test_qwen_function_formatter():
|
||||
formatter = FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen")
|
||||
tool_calls = json.dumps(FUNCTION)
|
||||
assert formatter.apply(content=tool_calls) == [
|
||||
"""<tool_call>\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>""",
|
||||
"<|im_end|>",
|
||||
]
|
||||
|
||||
|
||||
def test_qwen_multi_function_formatter():
|
||||
formatter = FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen")
|
||||
tool_calls = json.dumps([FUNCTION] * 2)
|
||||
assert formatter.apply(content=tool_calls) == [
|
||||
"""<tool_call>\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>\n"""
|
||||
"""<tool_call>\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>""",
|
||||
"<|im_end|>",
|
||||
]
|
||||
|
||||
|
||||
def test_qwen_tool_formatter():
|
||||
formatter = ToolFormatter(tool_format="qwen")
|
||||
wrapped_tool = {"type": "function", "function": TOOLS[0]}
|
||||
assert formatter.apply(content=json.dumps(TOOLS)) == [
|
||||
"\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
|
||||
"You are provided with function signatures within <tools></tools> XML tags:\n<tools>"
|
||||
f"\n{json.dumps(wrapped_tool, ensure_ascii=False)}"
|
||||
"\n</tools>\n\nFor each function call, return a json object with function name and arguments within "
|
||||
"""<tool_call></tool_call> XML tags:\n<tool_call>\n{"name": <function-name>, """
|
||||
""""arguments": <args-json-object>}\n</tool_call><|im_end|>\n"""
|
||||
]
|
||||
|
||||
|
||||
def test_qwen_tool_extractor():
|
||||
formatter = ToolFormatter(tool_format="qwen")
|
||||
result = """<tool_call>\n{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>"""
|
||||
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
|
||||
|
||||
|
||||
def test_qwen_multi_tool_extractor():
|
||||
formatter = ToolFormatter(tool_format="qwen")
|
||||
result = (
|
||||
"""<tool_call>\n{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>\n"""
|
||||
"""<tool_call>\n{"name": "another_tool", "arguments": {"foo": "job", "size": 2}}\n</tool_call>"""
|
||||
)
|
||||
assert formatter.extract(result) == [
|
||||
("test_tool", """{"foo": "bar", "size": 10}"""),
|
||||
("another_tool", """{"foo": "job", "size": 2}"""),
|
||||
]
|
||||
@@ -0,0 +1,236 @@
|
||||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Sequence
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from llamafactory.data.mm_plugin import get_mm_plugin
|
||||
from llamafactory.hparams import get_infer_args
|
||||
from llamafactory.model import load_tokenizer
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedTokenizer, ProcessorMixin
|
||||
from transformers.image_processing_utils import BaseImageProcessor
|
||||
|
||||
from llamafactory.data.mm_plugin import BasePlugin
|
||||
from llamafactory.model.loader import TokenizerModule
|
||||
|
||||
|
||||
HF_TOKEN = os.getenv("HF_TOKEN")
|
||||
|
||||
TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
|
||||
|
||||
MM_MESSAGES = [
|
||||
{"role": "user", "content": "<image>What is in this image?"},
|
||||
{"role": "assistant", "content": "A cat."},
|
||||
]
|
||||
|
||||
TEXT_MESSAGES = [
|
||||
{"role": "user", "content": "How are you"},
|
||||
{"role": "assistant", "content": "I am fine!"},
|
||||
]
|
||||
|
||||
IMAGES = [Image.new("RGB", (32, 32), (255, 255, 255))]
|
||||
|
||||
NO_IMAGES = []
|
||||
|
||||
NO_VIDEOS = []
|
||||
|
||||
IMGLENS = [1]
|
||||
|
||||
NO_IMGLENS = [0]
|
||||
|
||||
NO_VIDLENS = [0]
|
||||
|
||||
INPUT_IDS = [0, 1, 2, 3, 4]
|
||||
|
||||
LABELS = [0, 1, 2, 3, 4]
|
||||
|
||||
BATCH_IDS = [[1] * 1024]
|
||||
|
||||
|
||||
def _get_mm_inputs(processor: "ProcessorMixin") -> Dict[str, "torch.Tensor"]:
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||
return image_processor(images=IMAGES, return_tensors="pt")
|
||||
|
||||
|
||||
def _is_close(batch_a: Dict[str, Any], batch_b: Dict[str, Any]) -> None:
|
||||
assert batch_a.keys() == batch_b.keys()
|
||||
for key in batch_a.keys():
|
||||
if isinstance(batch_a[key], torch.Tensor):
|
||||
assert torch.allclose(batch_a[key], batch_b[key], rtol=1e-4, atol=1e-5)
|
||||
elif isinstance(batch_a[key], list) and all(isinstance(item, torch.Tensor) for item in batch_a[key]):
|
||||
assert len(batch_a[key]) == len(batch_b[key])
|
||||
for tensor_a, tensor_b in zip(batch_a[key], batch_b[key]):
|
||||
assert torch.allclose(tensor_a, tensor_b, rtol=1e-4, atol=1e-5)
|
||||
else:
|
||||
assert batch_a[key] == batch_b[key]
|
||||
|
||||
|
||||
def _load_tokenizer_module(model_name_or_path: str) -> "TokenizerModule":
|
||||
model_args, *_ = get_infer_args({"model_name_or_path": model_name_or_path, "template": "default"})
|
||||
return load_tokenizer(model_args)
|
||||
|
||||
|
||||
def _check_plugin(
|
||||
plugin: "BasePlugin",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: "ProcessorMixin",
|
||||
expected_mm_messages: Sequence[Dict[str, str]] = MM_MESSAGES,
|
||||
expected_input_ids: List[int] = INPUT_IDS,
|
||||
expected_labels: List[int] = LABELS,
|
||||
expected_mm_inputs: Dict[str, Any] = {},
|
||||
expected_no_mm_inputs: Dict[str, Any] = {},
|
||||
) -> None:
|
||||
# test mm_messages
|
||||
assert plugin.process_messages(MM_MESSAGES, IMAGES, NO_VIDEOS, processor) == expected_mm_messages
|
||||
assert plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, NO_VIDEOS, tokenizer, processor) == (
|
||||
expected_input_ids,
|
||||
expected_labels,
|
||||
)
|
||||
_is_close(
|
||||
plugin.get_mm_inputs(IMAGES, NO_VIDEOS, IMGLENS, NO_VIDLENS, BATCH_IDS, processor),
|
||||
expected_mm_inputs,
|
||||
)
|
||||
# test text_messages
|
||||
assert plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, NO_VIDEOS, processor) == TEXT_MESSAGES
|
||||
assert plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, NO_VIDEOS, tokenizer, processor) == (
|
||||
INPUT_IDS,
|
||||
LABELS,
|
||||
)
|
||||
_is_close(
|
||||
plugin.get_mm_inputs(NO_IMAGES, NO_VIDEOS, NO_IMGLENS, NO_VIDLENS, BATCH_IDS, processor),
|
||||
expected_no_mm_inputs,
|
||||
)
|
||||
|
||||
|
||||
def test_base_plugin():
|
||||
tokenizer_module = _load_tokenizer_module(model_name_or_path=TINY_LLAMA)
|
||||
base_plugin = get_mm_plugin(name="base", image_token="<image>")
|
||||
check_inputs = {"plugin": base_plugin, **tokenizer_module}
|
||||
_check_plugin(**check_inputs)
|
||||
|
||||
|
||||
def test_llava_plugin():
|
||||
image_seqlen = 576
|
||||
tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/llava-1.5-7b-hf")
|
||||
llava_plugin = get_mm_plugin(name="llava", image_token="<image>")
|
||||
check_inputs = {"plugin": llava_plugin, **tokenizer_module}
|
||||
check_inputs["expected_mm_messages"] = [
|
||||
{key: value.replace("<image>", "<image>" * image_seqlen) for key, value in message.items()}
|
||||
for message in MM_MESSAGES
|
||||
]
|
||||
check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
|
||||
_check_plugin(**check_inputs)
|
||||
|
||||
|
||||
def test_llava_next_plugin():
|
||||
image_seqlen = 1176
|
||||
tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/llava-v1.6-vicuna-7b-hf")
|
||||
llava_next_plugin = get_mm_plugin(name="llava_next", image_token="<image>")
|
||||
check_inputs = {"plugin": llava_next_plugin, **tokenizer_module}
|
||||
check_inputs["expected_mm_messages"] = [
|
||||
{key: value.replace("<image>", "<image>" * image_seqlen) for key, value in message.items()}
|
||||
for message in MM_MESSAGES
|
||||
]
|
||||
check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
|
||||
_check_plugin(**check_inputs)
|
||||
|
||||
|
||||
def test_llava_next_video_plugin():
|
||||
image_seqlen = 1176
|
||||
tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/LLaVA-NeXT-Video-7B-hf")
|
||||
llava_next_video_plugin = get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>")
|
||||
check_inputs = {"plugin": llava_next_video_plugin, **tokenizer_module}
|
||||
check_inputs["expected_mm_messages"] = [
|
||||
{key: value.replace("<image>", "<image>" * image_seqlen) for key, value in message.items()}
|
||||
for message in MM_MESSAGES
|
||||
]
|
||||
check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
|
||||
_check_plugin(**check_inputs)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
|
||||
def test_paligemma_plugin():
|
||||
image_seqlen = 256
|
||||
tokenizer_module = _load_tokenizer_module(model_name_or_path="google/paligemma-3b-pt-224")
|
||||
paligemma_plugin = get_mm_plugin(name="paligemma", image_token="<image>")
|
||||
check_inputs = {"plugin": paligemma_plugin, **tokenizer_module}
|
||||
check_inputs["expected_mm_messages"] = [
|
||||
{key: value.replace("<image>", "") for key, value in message.items()} for message in MM_MESSAGES
|
||||
]
|
||||
check_inputs["expected_input_ids"] = [
|
||||
tokenizer_module["tokenizer"].convert_tokens_to_ids(paligemma_plugin.image_token)
|
||||
] * image_seqlen + INPUT_IDS
|
||||
check_inputs["expected_labels"] = [-100] * image_seqlen + LABELS
|
||||
check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
|
||||
check_inputs["expected_mm_inputs"]["token_type_ids"] = [[0] * image_seqlen + [1] * (1024 - image_seqlen)]
|
||||
check_inputs["expected_no_mm_inputs"] = {"token_type_ids": [[1] * 1024]}
|
||||
_check_plugin(**check_inputs)
|
||||
|
||||
|
||||
def test_pixtral_plugin():
|
||||
image_slice_height, image_slice_width = 2, 2
|
||||
tokenizer_module = _load_tokenizer_module(model_name_or_path="mistral-community/pixtral-12b")
|
||||
pixtral_plugin = get_mm_plugin(name="pixtral", image_token="[IMG]")
|
||||
check_inputs = {"plugin": pixtral_plugin, **tokenizer_module}
|
||||
check_inputs["expected_mm_messages"] = [
|
||||
{
|
||||
key: value.replace(
|
||||
"<image>",
|
||||
("{}[IMG_BREAK]".format("[IMG]" * image_slice_width) * image_slice_height).rsplit("[IMG_BREAK]", 1)[0]
|
||||
+ "[IMG_END]",
|
||||
)
|
||||
for key, value in message.items()
|
||||
}
|
||||
for message in MM_MESSAGES
|
||||
]
|
||||
check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
|
||||
check_inputs["expected_mm_inputs"].pop("image_sizes")
|
||||
check_inputs["expected_mm_inputs"]["pixel_values"] = check_inputs["expected_mm_inputs"]["pixel_values"][0]
|
||||
_check_plugin(**check_inputs)
|
||||
|
||||
|
||||
def test_qwen2_vl_plugin():
|
||||
image_seqlen = 4
|
||||
tokenizer_module = _load_tokenizer_module(model_name_or_path="Qwen/Qwen2-VL-7B-Instruct")
|
||||
qwen2_vl_plugin = get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>")
|
||||
check_inputs = {"plugin": qwen2_vl_plugin, **tokenizer_module}
|
||||
check_inputs["expected_mm_messages"] = [
|
||||
{
|
||||
key: value.replace("<image>", "<|vision_start|>{}<|vision_end|>".format("<|image_pad|>" * image_seqlen))
|
||||
for key, value in message.items()
|
||||
}
|
||||
for message in MM_MESSAGES
|
||||
]
|
||||
check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
|
||||
_check_plugin(**check_inputs)
|
||||
|
||||
|
||||
def test_video_llava_plugin():
|
||||
image_seqlen = 256
|
||||
tokenizer_module = _load_tokenizer_module(model_name_or_path="LanguageBind/Video-LLaVA-7B-hf")
|
||||
video_llava_plugin = get_mm_plugin(name="video_llava", image_token="<image>", video_token="<video>")
|
||||
check_inputs = {"plugin": video_llava_plugin, **tokenizer_module}
|
||||
check_inputs["expected_mm_messages"] = [
|
||||
{key: value.replace("<image>", "<image>" * image_seqlen) for key, value in message.items()}
|
||||
for message in MM_MESSAGES
|
||||
]
|
||||
check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
|
||||
_check_plugin(**check_inputs)
|
||||
@@ -0,0 +1,172 @@
|
||||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from typing import TYPE_CHECKING, List, Sequence
|
||||
|
||||
import pytest
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from llamafactory.data import get_template_and_fix_tokenizer
|
||||
from llamafactory.data.template import _get_jinja_template
|
||||
from llamafactory.hparams import DataArguments
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
|
||||
HF_TOKEN = os.getenv("HF_TOKEN")
|
||||
|
||||
TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
|
||||
|
||||
MESSAGES = [
|
||||
{"role": "user", "content": "How are you"},
|
||||
{"role": "assistant", "content": "I am fine!"},
|
||||
{"role": "user", "content": "你好"},
|
||||
{"role": "assistant", "content": "很高兴认识你!"},
|
||||
]
|
||||
|
||||
|
||||
def _check_tokenization(
|
||||
tokenizer: "PreTrainedTokenizer", batch_input_ids: Sequence[Sequence[int]], batch_text: Sequence[str]
|
||||
) -> None:
|
||||
for input_ids, text in zip(batch_input_ids, batch_text):
|
||||
assert input_ids == tokenizer.encode(text, add_special_tokens=False)
|
||||
assert tokenizer.decode(input_ids) == text
|
||||
|
||||
|
||||
def _check_single_template(
|
||||
model_id: str, template_name: str, prompt_str: str, answer_str: str, extra_str: str, use_fast: bool
|
||||
) -> List[str]:
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=use_fast, token=HF_TOKEN)
|
||||
content_str = tokenizer.apply_chat_template(MESSAGES, tokenize=False)
|
||||
content_ids = tokenizer.apply_chat_template(MESSAGES, tokenize=True)
|
||||
template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template=template_name))
|
||||
prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
|
||||
assert content_str == prompt_str + answer_str + extra_str
|
||||
assert content_ids == prompt_ids + answer_ids + tokenizer.encode(extra_str, add_special_tokens=False)
|
||||
_check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str))
|
||||
return content_ids
|
||||
|
||||
|
||||
def _check_template(model_id: str, template_name: str, prompt_str: str, answer_str: str, extra_str: str = "") -> None:
|
||||
"""
|
||||
Checks template for both the slow tokenizer and the fast tokenizer.
|
||||
|
||||
Args:
|
||||
model_id: the model id on hugging face hub.
|
||||
template_name: the template name.
|
||||
prompt_str: the string corresponding to the prompt part.
|
||||
answer_str: the string corresponding to the answer part.
|
||||
extra_str: the extra string in the jinja template of the original tokenizer.
|
||||
"""
|
||||
slow_ids = _check_single_template(model_id, template_name, prompt_str, answer_str, extra_str, use_fast=False)
|
||||
fast_ids = _check_single_template(model_id, template_name, prompt_str, answer_str, extra_str, use_fast=True)
|
||||
assert slow_ids == fast_ids
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_fast", [True, False])
|
||||
def test_encode_oneturn(use_fast: bool):
|
||||
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA, use_fast=use_fast)
|
||||
template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
|
||||
prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
|
||||
prompt_str = (
|
||||
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>"
|
||||
"<|start_header_id|>assistant<|end_header_id|>\n\nI am fine!<|eot_id|>"
|
||||
"<|start_header_id|>user<|end_header_id|>\n\n你好<|eot_id|>"
|
||||
"<|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
)
|
||||
answer_str = "很高兴认识你!<|eot_id|>"
|
||||
_check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_fast", [True, False])
|
||||
def test_encode_multiturn(use_fast: bool):
|
||||
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA, use_fast=use_fast)
|
||||
template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
|
||||
encoded_pairs = template.encode_multiturn(tokenizer, MESSAGES)
|
||||
prompt_str_1 = (
|
||||
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>"
|
||||
"<|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
)
|
||||
answer_str_1 = "I am fine!<|eot_id|>"
|
||||
prompt_str_2 = (
|
||||
"<|start_header_id|>user<|end_header_id|>\n\n你好<|eot_id|>"
|
||||
"<|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
)
|
||||
answer_str_2 = "很高兴认识你!<|eot_id|>"
|
||||
_check_tokenization(
|
||||
tokenizer,
|
||||
(encoded_pairs[0][0], encoded_pairs[0][1], encoded_pairs[1][0], encoded_pairs[1][1]),
|
||||
(prompt_str_1, answer_str_1, prompt_str_2, answer_str_2),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_fast", [True, False])
|
||||
def test_jinja_template(use_fast: bool):
|
||||
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA, use_fast=use_fast)
|
||||
ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA, use_fast=use_fast)
|
||||
template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
|
||||
tokenizer.chat_template = _get_jinja_template(template, tokenizer) # llama3 template no replace
|
||||
assert tokenizer.chat_template != ref_tokenizer.chat_template
|
||||
assert tokenizer.apply_chat_template(MESSAGES) == ref_tokenizer.apply_chat_template(MESSAGES)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
|
||||
def test_gemma_template():
|
||||
prompt_str = (
|
||||
"<bos><start_of_turn>user\nHow are you<end_of_turn>\n"
|
||||
"<start_of_turn>model\nI am fine!<end_of_turn>\n"
|
||||
"<start_of_turn>user\n你好<end_of_turn>\n"
|
||||
"<start_of_turn>model\n"
|
||||
)
|
||||
answer_str = "很高兴认识你!"
|
||||
_check_template("google/gemma-2-9b-it", "gemma", prompt_str, answer_str, extra_str="<end_of_turn>\n")
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
|
||||
def test_llama3_template():
|
||||
prompt_str = (
|
||||
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>"
|
||||
"<|start_header_id|>assistant<|end_header_id|>\n\nI am fine!<|eot_id|>"
|
||||
"<|start_header_id|>user<|end_header_id|>\n\n你好<|eot_id|>"
|
||||
"<|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
)
|
||||
answer_str = "很高兴认识你!<|eot_id|>"
|
||||
_check_template("meta-llama/Meta-Llama-3-8B-Instruct", "llama3", prompt_str, answer_str)
|
||||
|
||||
|
||||
def test_qwen_template():
|
||||
prompt_str = (
|
||||
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
|
||||
"<|im_start|>user\nHow are you<|im_end|>\n"
|
||||
"<|im_start|>assistant\nI am fine!<|im_end|>\n"
|
||||
"<|im_start|>user\n你好<|im_end|>\n"
|
||||
"<|im_start|>assistant\n"
|
||||
)
|
||||
answer_str = "很高兴认识你!<|im_end|>"
|
||||
_check_template("Qwen/Qwen2-7B-Instruct", "qwen", prompt_str, answer_str, extra_str="\n")
|
||||
|
||||
|
||||
@pytest.mark.xfail(reason="The fast tokenizer of Yi model is corrupted.")
|
||||
def test_yi_template():
|
||||
prompt_str = (
|
||||
"<|im_start|>user\nHow are you<|im_end|>\n"
|
||||
"<|im_start|>assistant\nI am fine!<|im_end|>\n"
|
||||
"<|im_start|>user\n你好<|im_end|>\n"
|
||||
"<|im_start|>assistant\n"
|
||||
)
|
||||
answer_str = "很高兴认识你!<|im_end|>"
|
||||
_check_template("01-ai/Yi-1.5-6B-Chat", "yi", prompt_str, answer_str)
|
||||
Reference in New Issue
Block a user