import torch
from trl import (
    get_kbit_device_map,
    # get_peft_config,
    get_quantization_config,
)
from utils.args import ContinualModelConfig


def get_model(model_args: ContinualModelConfig):
    """Load a supported Qwen checkpoint and its matching processor/collators.

    Args:
        model_args: Continual-learning model configuration; the fields read
            here are ``model_name_or_path``, ``torch_dtype``,
            ``model_revision``, ``attn_implementation`` and
            ``trust_remote_code``.

    Returns:
        Tuple ``(model, processor, collate_fn_for_train,
        collate_fn_for_evaluate)`` where both collate functions are already
        bound to ``processor`` via ``functools.partial``.

    Raises:
        ValueError: If ``model_args.model_name_or_path`` is not one of the
            supported checkpoint names.
    """
    from functools import partial

    # "auto"/None are passed through to from_pretrained unchanged; any other
    # string is resolved to the corresponding torch dtype (e.g. "bfloat16").
    torch_dtype = (
        model_args.torch_dtype
        if model_args.torch_dtype in ["auto", None]
        else getattr(torch, model_args.torch_dtype)
    )
    quantization_config = get_quantization_config(model_args)
    model_kwargs = dict(
        revision=model_args.model_revision,
        attn_implementation=model_args.attn_implementation,
        torch_dtype=torch_dtype,
        # A k-bit device map is only needed when quantization is active.
        device_map=(get_kbit_device_map() if quantization_config is not None else None),
        quantization_config=quantization_config,
    )

    name = model_args.model_name_or_path
    # Each branch selects the model/processor classes and the collate-fn
    # module for one supported checkpoint; the loading tail is shared below.
    if name == "Qwen/Qwen2-VL-7B-Instruct":
        from transformers import Qwen2VLProcessor, Qwen2VLForConditionalGeneration

        model_cls, processor_cls = Qwen2VLForConditionalGeneration, Qwen2VLProcessor
        from .qwen2vl import collate_fn_for_train, collate_fn_for_evaluate
    elif name == "Qwen/Qwen2-Audio-7B-Instruct":
        from transformers import Qwen2AudioProcessor, Qwen2AudioForConditionalGeneration

        model_cls, processor_cls = (
            Qwen2AudioForConditionalGeneration,
            Qwen2AudioProcessor,
        )
        from .qwen2audio import collate_fn_for_train, collate_fn_for_evaluate
    elif name == "Qwen/Qwen2.5-VL-3B-Instruct":
        from transformers import Qwen2_5_VLProcessor, Qwen2_5_VLForConditionalGeneration

        model_cls, processor_cls = (
            Qwen2_5_VLForConditionalGeneration,
            Qwen2_5_VLProcessor,
        )
        # Qwen2.5-VL reuses the Qwen2-VL collators.
        from model_library.qwen2vl import collate_fn_for_train, collate_fn_for_evaluate
    elif name == "Qwen/Qwen2.5-Omni-3B":
        from transformers.models.qwen2_5_omni import (
            Qwen2_5OmniThinkerForConditionalGeneration,
            Qwen2_5OmniProcessor,
        )

        model_cls, processor_cls = (
            Qwen2_5OmniThinkerForConditionalGeneration,
            Qwen2_5OmniProcessor,
        )
        # The Omni thinker shares the Qwen2-VL collators as well.
        from model_library.qwen2vl import collate_fn_for_train, collate_fn_for_evaluate
    else:
        # Previously an unknown name fell through and the return statement
        # raised a confusing NameError; fail fast with a clear message.
        raise ValueError(f"Unsupported model_name_or_path: {name!r}")

    model = model_cls.from_pretrained(
        name,
        trust_remote_code=model_args.trust_remote_code,
        **model_kwargs,
    )
    # Left padding so batched generation aligns prompts at the right edge.
    processor = processor_cls.from_pretrained(
        name,
        trust_remote_code=model_args.trust_remote_code,
        padding_side="left",
    )
    collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
    collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)
    return model, processor, collate_fn_for_train, collate_fn_for_evaluate