import torch
from functools import partial

from trl import (
    get_kbit_device_map,
    # get_peft_config,
    get_quantization_config,
)

from utils.args import ContinualModelConfig


def get_model(model_args: ContinualModelConfig):
    """Load a Qwen2-family multimodal model, its processor, and the matching
    train/evaluate collate functions, configured from ``model_args``."""
    # "auto" and None pass through unchanged; any other string (e.g.
    # "bfloat16") is resolved to the corresponding torch dtype.
    torch_dtype = (
        model_args.torch_dtype
        if model_args.torch_dtype in ["auto", None]
        else getattr(torch, model_args.torch_dtype)
    )
    quantization_config = get_quantization_config(model_args)
    model_kwargs = dict(
        revision=model_args.model_revision,
        attn_implementation=model_args.attn_implementation,
        torch_dtype=torch_dtype,
        # k-bit quantized weights need an explicit device map; otherwise let
        # from_pretrained place the model as usual.
        device_map=(get_kbit_device_map() if quantization_config is not None else None),
        quantization_config=quantization_config,
    )

if model_args.model_name_or_path == "Qwen/Qwen2-VL-7B-Instruct":
|
|
from transformers import Qwen2VLProcessor, Qwen2VLForConditionalGeneration
|
|
|
|
# from .qwen2vl import Qwen2VLForConditionalGeneration_modified
|
|
|
|
model = Qwen2VLForConditionalGeneration.from_pretrained(
|
|
model_args.model_name_or_path,
|
|
trust_remote_code=model_args.trust_remote_code,
|
|
**model_kwargs,
|
|
)
|
|
processor = Qwen2VLProcessor.from_pretrained(
|
|
model_args.model_name_or_path,
|
|
trust_remote_code=model_args.trust_remote_code,
|
|
padding_side="left",
|
|
)
|
|
from .qwen2vl import (
|
|
collate_fn_for_train,
|
|
collate_fn_for_evaluate,
|
|
)
|
|
from functools import partial
|
|
|
|
collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
|
|
collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)
|
|
|
|
if model_args.model_name_or_path == "Qwen/Qwen2-Audio-7B-Instruct":
|
|
from transformers import Qwen2AudioProcessor, Qwen2AudioForConditionalGeneration
|
|
|
|
model = Qwen2AudioForConditionalGeneration.from_pretrained(
|
|
model_args.model_name_or_path,
|
|
trust_remote_code=model_args.trust_remote_code,
|
|
**model_kwargs,
|
|
)
|
|
processor = Qwen2AudioProcessor.from_pretrained(
|
|
model_args.model_name_or_path,
|
|
trust_remote_code=model_args.trust_remote_code,
|
|
padding_side="left",
|
|
)
|
|
from .qwen2audio import (
|
|
collate_fn_for_train,
|
|
collate_fn_for_evaluate,
|
|
)
|
|
from functools import partial
|
|
|
|
collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
|
|
collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)
|
|
|
|
if model_args.model_name_or_path == "Qwen/Qwen2.5-VL-3B-Instruct":
|
|
from transformers import Qwen2_5_VLProcessor, Qwen2_5_VLForConditionalGeneration
|
|
|
|
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
|
model_args.model_name_or_path,
|
|
trust_remote_code=model_args.trust_remote_code,
|
|
**model_kwargs,
|
|
)
|
|
|
|
processor = Qwen2_5_VLProcessor.from_pretrained(
|
|
model_args.model_name_or_path,
|
|
trust_remote_code=model_args.trust_remote_code,
|
|
padding_side="left",
|
|
)
|
|
|
|
from model_library.qwen2vl import collate_fn_for_train, collate_fn_for_evaluate
|
|
from functools import partial
|
|
|
|
collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
|
|
collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)
|
|
|
|
if model_args.model_name_or_path == "Qwen/Qwen2.5-Omni-3B":
|
|
from transformers.models.qwen2_5_omni import (
|
|
Qwen2_5OmniThinkerForConditionalGeneration,
|
|
Qwen2_5OmniProcessor
|
|
)
|
|
|
|
model = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained(
|
|
model_args.model_name_or_path,
|
|
trust_remote_code=model_args.trust_remote_code,
|
|
**model_kwargs,
|
|
)
|
|
|
|
processor = Qwen2_5OmniProcessor.from_pretrained(
|
|
model_args.model_name_or_path,
|
|
trust_remote_code=model_args.trust_remote_code,
|
|
padding_side="left",
|
|
)
|
|
|
|
from model_library.qwen2vl import collate_fn_for_train, collate_fn_for_evaluate
|
|
from functools import partial
|
|
|
|
collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
|
|
collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)
|
|
|
|
    return model, processor, collate_fn_for_train, collate_fn_for_evaluate
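
# Usage sketch. ContinualModelConfig is assumed here to be a dataclass-style
# config (see utils/args.py); the field values below are illustrative, and any
# constructor arguments beyond the fields read in get_model are a guess.
#
#   from utils.args import ContinualModelConfig
#
#   model_args = ContinualModelConfig(
#       model_name_or_path="Qwen/Qwen2-VL-7B-Instruct",
#       torch_dtype="bfloat16",
#   )
#   model, processor, train_collate, eval_collate = get_model(model_args)
#   # train_collate / eval_collate already have the processor bound, so they
#   # can be passed directly as a Trainer's data_collator.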