Refactor code structure for improved readability and maintainability
This commit is contained in:
parent
56e46f0e0c
commit
3fe2c85f6b
@ -6,24 +6,25 @@ import concurrent.futures
|
|||||||
import threading
|
import threading
|
||||||
|
|
||||||
# Set the file paths for your Google Drive
|
# Set the file paths for your Google Drive
|
||||||
dataset_path = './dataset.json'
|
dataset_path = "./dataset.json"
|
||||||
images_path = './images'
|
images_path = "./images"
|
||||||
download = 1 # Set to 0 if images are already downloaded
|
download = 1 # Set to 0 if images are already downloaded
|
||||||
|
|
||||||
# Load dataset json file
|
# Load dataset json file
|
||||||
with open(dataset_path, 'r') as fp:
|
with open(dataset_path, "r") as fp:
|
||||||
data = json.load(fp)
|
data = json.load(fp)
|
||||||
|
|
||||||
# Initialize a counter and a lock for thread-safe counting
|
# Initialize a counter and a lock for thread-safe counting
|
||||||
downloaded_count = 0
|
downloaded_count = 0
|
||||||
count_lock = threading.Lock()
|
count_lock = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
# Function to download an image
|
# Function to download an image
|
||||||
def download_image(k):
|
def download_image(k):
|
||||||
global downloaded_count
|
global downloaded_count
|
||||||
imageURL = data[k]['imageURL']
|
imageURL = data[k]["imageURL"]
|
||||||
ext = os.path.splitext(imageURL)[1]
|
ext = os.path.splitext(imageURL)[1]
|
||||||
outputFile = os.path.join(images_path, f'{k}{ext}')
|
outputFile = os.path.join(images_path, f"{k}{ext}")
|
||||||
|
|
||||||
# Only download the image if it doesn't exist
|
# Only download the image if it doesn't exist
|
||||||
if not os.path.exists(outputFile):
|
if not os.path.exists(outputFile):
|
||||||
@ -33,9 +34,10 @@ def download_image(k):
|
|||||||
with count_lock:
|
with count_lock:
|
||||||
downloaded_count += 1
|
downloaded_count += 1
|
||||||
if downloaded_count % 100 == 0:
|
if downloaded_count % 100 == 0:
|
||||||
print(f'{downloaded_count} images downloaded.')
|
print(f"{downloaded_count} images downloaded.")
|
||||||
except urllib.error.URLError as e:
|
except urllib.error.URLError as e:
|
||||||
print(f'Error downloading {outputFile}: {e}')
|
print(f"Error downloading {outputFile}: {e}")
|
||||||
|
|
||||||
|
|
||||||
# Download images using multiple threads
|
# Download images using multiple threads
|
||||||
if download == 1:
|
if download == 1:
|
||||||
@ -45,5 +47,5 @@ if download == 1:
|
|||||||
# Create a thread pool and download the images in parallel
|
# Create a thread pool and download the images in parallel
|
||||||
# Increase max_workers to potentially speed up downloads for many small files.
|
# Increase max_workers to potentially speed up downloads for many small files.
|
||||||
# The optimal number may vary based on your network and the server's capacity.
|
# The optimal number may vary based on your network and the server's capacity.
|
||||||
with concurrent.futures.ThreadPoolExecutor(max_workers=50) as executor:
|
with concurrent.futures.ThreadPoolExecutor(max_workers=400) as executor:
|
||||||
executor.map(download_image, data.keys())
|
executor.map(download_image, data.keys())
|
||||||
|
@ -5,6 +5,7 @@ dependencies = [
|
|||||||
"datasets==3.2.0",
|
"datasets==3.2.0",
|
||||||
"deepspeed==0.16.2",
|
"deepspeed==0.16.2",
|
||||||
"evaluate==0.4.3",
|
"evaluate==0.4.3",
|
||||||
|
"huggingface-hub==0.30.1",
|
||||||
"librosa>=0.10.2.post1",
|
"librosa>=0.10.2.post1",
|
||||||
"markupsafe==2.1.5",
|
"markupsafe==2.1.5",
|
||||||
"ms-swift>=1.3.0",
|
"ms-swift>=1.3.0",
|
||||||
|
3
src/.gitignore
vendored
3
src/.gitignore
vendored
@ -1 +1,2 @@
|
|||||||
checkpoint/*
|
checkpoint/*
|
||||||
|
wandb/*
|
@ -2,7 +2,7 @@ compute_environment: LOCAL_MACHINE
|
|||||||
debug: false
|
debug: false
|
||||||
deepspeed_config:
|
deepspeed_config:
|
||||||
deepspeed_multinode_launcher: standard
|
deepspeed_multinode_launcher: standard
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 4
|
||||||
zero3_init_flag: false
|
zero3_init_flag: false
|
||||||
zero_stage: 1
|
zero_stage: 1
|
||||||
distributed_type: DEEPSPEED
|
distributed_type: DEEPSPEED
|
||||||
@ -11,7 +11,7 @@ machine_rank: 0
|
|||||||
main_training_function: main
|
main_training_function: main
|
||||||
mixed_precision: 'bf16'
|
mixed_precision: 'bf16'
|
||||||
num_machines: 1
|
num_machines: 1
|
||||||
num_processes: 8
|
num_processes: 4
|
||||||
rdzv_backend: static
|
rdzv_backend: static
|
||||||
same_network: true
|
same_network: true
|
||||||
tpu_env: []
|
tpu_env: []
|
||||||
|
@ -12,7 +12,7 @@ machine_rank: 0
|
|||||||
main_training_function: main
|
main_training_function: main
|
||||||
mixed_precision: 'bf16'
|
mixed_precision: 'bf16'
|
||||||
num_machines: 1
|
num_machines: 1
|
||||||
num_processes: 8
|
num_processes: 4
|
||||||
rdzv_backend: static
|
rdzv_backend: static
|
||||||
same_network: true
|
same_network: true
|
||||||
tpu_env: []
|
tpu_env: []
|
||||||
|
@ -38,12 +38,44 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
from model_library.factory import get_model
|
from model_library.factory import get_model
|
||||||
|
|
||||||
if model_args.model_name_or_path == "Qwen/Qwen2-VL-7B-Instruct":
|
if model_args.model_name_or_path == "Qwen/Qwen2.5-VL-3B-Instruct":
|
||||||
torch_dtype = (
|
torch_dtype = (
|
||||||
model_args.torch_dtype
|
model_args.torch_dtype
|
||||||
if model_args.torch_dtype in ["auto", None]
|
if model_args.torch_dtype in ["auto", None]
|
||||||
else getattr(torch, model_args.torch_dtype)
|
else getattr(torch, model_args.torch_dtype)
|
||||||
)
|
)
|
||||||
|
quantization_config = get_quantization_config(model_args)
|
||||||
|
model_kwargs = dict(
|
||||||
|
attn_implementation=model_args.attn_implementation,
|
||||||
|
torch_dtype=torch_dtype,
|
||||||
|
quantization_config=quantization_config,
|
||||||
|
)
|
||||||
|
from transformers import Qwen2_5_VLProcessor, Qwen2_5_VLForConditionalGeneration
|
||||||
|
|
||||||
|
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
||||||
|
training_args.output_dir,
|
||||||
|
**model_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
processor = Qwen2_5_VLProcessor.from_pretrained(
|
||||||
|
model_args.model_name_or_path,
|
||||||
|
trust_remote_code=model_args.trust_remote_code,
|
||||||
|
padding_side="left",
|
||||||
|
)
|
||||||
|
|
||||||
|
from model_library.qwen2vl import collate_fn_for_train, collate_fn_for_evaluate
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
|
||||||
|
collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)
|
||||||
|
|
||||||
|
elif model_args.model_name_or_path == "Qwen/Qwen2-VL-7B-Instruct":
|
||||||
|
torch_dtype = (
|
||||||
|
model_args.torch_dtype
|
||||||
|
if model_args.torch_dtype in ["auto", None]
|
||||||
|
else getattr(torch, model_args.torch_dtype)
|
||||||
|
)
|
||||||
|
|
||||||
quantization_config = get_quantization_config(model_args)
|
quantization_config = get_quantization_config(model_args)
|
||||||
model_kwargs = dict(
|
model_kwargs = dict(
|
||||||
attn_implementation=model_args.attn_implementation,
|
attn_implementation=model_args.attn_implementation,
|
||||||
|
@ -68,4 +68,25 @@ def get_model(model_args: ContinualModelConfig):
|
|||||||
collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
|
collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
|
||||||
collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)
|
collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)
|
||||||
|
|
||||||
|
if model_args.model_name_or_path == "Qwen/Qwen2.5-VL-3B-Instruct":
|
||||||
|
from transformers import Qwen2_5_VLProcessor, Qwen2_5_VLForConditionalGeneration
|
||||||
|
|
||||||
|
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
||||||
|
model_args.model_name_or_path,
|
||||||
|
trust_remote_code=model_args.trust_remote_code,
|
||||||
|
**model_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
processor = Qwen2_5_VLProcessor.from_pretrained(
|
||||||
|
model_args.model_name_or_path,
|
||||||
|
trust_remote_code=model_args.trust_remote_code,
|
||||||
|
padding_side="left",
|
||||||
|
)
|
||||||
|
|
||||||
|
from model_library.qwen2vl import collate_fn_for_train, collate_fn_for_evaluate
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
|
||||||
|
collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)
|
||||||
|
|
||||||
return model, processor, collate_fn_for_train, collate_fn_for_evaluate
|
return model, processor, collate_fn_for_train, collate_fn_for_evaluate
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
from .collate_fn import collate_fn_for_evaluate, collate_fn_for_train
|
from .collate_fn import collate_fn_for_evaluate, collate_fn_for_train
|
||||||
from .model import Qwen2VLForConditionalGeneration_modified
|
|
||||||
|
# from .model import Qwen2VLForConditionalGeneration_modified
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"collate_fn_for_train",
|
"collate_fn_for_train",
|
||||||
|
@ -73,7 +73,6 @@ from peft.tuners import (
|
|||||||
from .tuners import MMOELoraModel, MMOELoraConfig
|
from .tuners import MMOELoraModel, MMOELoraConfig
|
||||||
from peft.tuners.tuners_utils import BaseTuner
|
from peft.tuners.tuners_utils import BaseTuner
|
||||||
from peft.utils import _prepare_prompt_learning_config
|
from peft.utils import _prepare_prompt_learning_config
|
||||||
from peft.utils.constants import PEFT_TYPE_TO_PREFIX_MAPPING
|
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -46,7 +46,7 @@ from transformers.modeling_outputs import (
|
|||||||
)
|
)
|
||||||
from transformers.utils import PushToHubMixin
|
from transformers.utils import PushToHubMixin
|
||||||
|
|
||||||
from peft.utils.constants import DUMMY_MODEL_CONFIG, PEFT_TYPE_TO_PREFIX_MAPPING
|
from peft.utils.constants import DUMMY_MODEL_CONFIG
|
||||||
|
|
||||||
from peft import __version__
|
from peft import __version__
|
||||||
from peft.config import PeftConfig
|
from peft.config import PeftConfig
|
||||||
|
@ -1 +1 @@
|
|||||||
Subproject commit 65c3c43cd195bd90b8cb339c1ba883b4c6c66b43
|
Subproject commit 83111347f3df66f04bd6759b1a3dcce719380628
|
10
src/train.sh
10
src/train.sh
@ -1,15 +1,15 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
accelerate launch --config_file configs/accelerate_configs/deepspeed_zero2.yaml train.py \
|
accelerate launch --config_file configs/accelerate_configs/deepspeed_zero1.yaml train.py \
|
||||||
--dataset_name chem \
|
--dataset_name refcoco \
|
||||||
--use_peft \
|
--use_peft \
|
||||||
--peft_type LORA \
|
--peft_type LORA \
|
||||||
--model_name_or_path Qwen/Qwen2-VL-7B-Instruct \
|
--model_name_or_path Qwen/Qwen2.5-VL-3B-Instruct \
|
||||||
--lora_target_modules .\*proj.\*\|.\*fc.\*\|.\*mlp\.0\|.\*mlp\.2 \
|
--lora_target_modules .\*proj.\*\|.\*fc.\*\|.\*mlp\.0\|.\*mlp\.2 \
|
||||||
--lora_r 8 \
|
--lora_r 8 \
|
||||||
--lora_alpha 32 \
|
--lora_alpha 32 \
|
||||||
--per_device_train_batch_size 2 \
|
--per_device_train_batch_size 1 \
|
||||||
--per_device_eval_batch_size 2 \
|
--per_device_eval_batch_size 1 \
|
||||||
--gradient_accumulation_steps 4 \
|
--gradient_accumulation_steps 4 \
|
||||||
--output_dir checkpoint/qwen2_alllinear/ \
|
--output_dir checkpoint/qwen2_alllinear/ \
|
||||||
--learning_rate 1e-4 \
|
--learning_rate 1e-4 \
|
||||||
|
@ -1 +1 @@
|
|||||||
Subproject commit 7961d291b338d568fa2160f7deac85baa21c49dc
|
Subproject commit 684f12be1c8f26c46b1eebad50ce21ce6e3378b3
|
@ -1,15 +1,6 @@
|
|||||||
# _________________________________________________________
|
# _________________________________________________________
|
||||||
|
|
||||||
|
|
||||||
from transformers.trainer import (
|
|
||||||
Trainer,
|
|
||||||
_is_peft_model,
|
|
||||||
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
|
|
||||||
tpu_spmd_dataloader,
|
|
||||||
logger,
|
|
||||||
has_length,
|
|
||||||
sys,
|
|
||||||
)
|
|
||||||
from transformers.trainer import *
|
from transformers.trainer import *
|
||||||
from transformers import (
|
from transformers import (
|
||||||
TrainingArguments,
|
TrainingArguments,
|
||||||
@ -32,12 +23,12 @@ class ContinualTrainer(Trainer):
|
|||||||
self.accelerator = accelerator
|
self.accelerator = accelerator
|
||||||
super().__init__(model, args, data_collator, train_dataset, eval_dataset)
|
super().__init__(model, args, data_collator, train_dataset, eval_dataset)
|
||||||
|
|
||||||
if regularization_args.ewc_enable:
|
# if regularization_args.ewc_enable:
|
||||||
self.ewc_lambda = regularization_args.ewc_lambda
|
# self.ewc_lambda = regularization_args.ewc_lambda
|
||||||
# fisher = t
|
# # fisher = t
|
||||||
|
|
||||||
if regularization_args.lwf_enable:
|
# if regularization_args.lwf_enable:
|
||||||
self.lwf_lambda = regularization_args.lwf_lambda
|
# self.lwf_lambda = regularization_args.lwf_lambda
|
||||||
|
|
||||||
def create_accelerator_and_postprocess(self):
|
def create_accelerator_and_postprocess(self):
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user