feat ✨: add Gigaspeech dataset support, update the training script to use the new dataset, streamline model loading, and add the Qwen2-Audio model
This commit is contained in: parent e4c4a7b0a0, commit 52b8952bdc
pyproject.toml
@@ -4,11 +4,14 @@ dependencies = [
     "datasets==3.2.0",
     "deepspeed==0.16.2",
     "evaluate==0.4.3",
+    "librosa>=0.10.2.post1",
     "markupsafe==2.1.5",
+    "numba>=0.60.0",
     "peft==0.14.0",
     "pip==24.3.1",
     "requests==2.32.3",
     "setuptools>=70.0.0",
+    "soundfile>=0.13.0",
     "torch==2.5.1+cu124",
     "torchaudio==2.5.1+cu124",
     "torchvision==0.20.1+cu124",
src/dataset_library/GigaspeechDataset.py (new file, 97 lines)

from PIL import Image
from torch.utils.data import Dataset
import json
import os
from datasets import load_dataset


class GigaspeechDataset(Dataset):
    def __init__(self, audio_processor=None, text_processor=None, split="train"):
        """
        audio_processor (callable, optional): transform applied to the raw waveform
        text_processor (callable, optional): transform applied to the transcript
        """

        self.audio_processor = audio_processor
        self.text_processor = text_processor
        gs = load_dataset("speechcolab/gigaspeech", "xs")
        self.data = gs[split]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data[index]
        audio = sample["audio"]["array"]
        sampling_rate = sample["audio"]["sampling_rate"]
        text = sample["text"]

        if self.audio_processor is not None:
            audio = self.audio_processor(audio)
        if self.text_processor is not None:
            text = self.text_processor(text)

        chat = [
            {
                "role": "user",
                "content": [
                    {"type": "audio", "audio_url": ""},
                    {
                        "type": "text",
                        "text": "Please convert the audio to text",
                    },
                ],
            },
            {"role": "assistant", "content": [{"type": "text", "text": text}]},
        ]
        return {
            "audio": (audio, sampling_rate),
            "chat": chat,
        }


class GigaspeechDatasetForGeneration(GigaspeechDataset):

    def __getitem__(self, index):
        sample = self.data[index]
        audio = sample["audio"]["array"]
        sampling_rate = sample["audio"]["sampling_rate"]
        text = sample["text"]

        if self.audio_processor is not None:
            audio = self.audio_processor(audio)
        if self.text_processor is not None:
            text = self.text_processor(text)

        chat = [
            {
                "role": "user",
                "content": [
                    {"type": "audio", "audio_url": ""},
                    {
                        "type": "text",
                        "text": "Please convert the audio to text",
                    },
                ],
            },
        ]

        return {
            "audio": (audio, sampling_rate),
            "chat": chat,
            "answer": text,
        }


if __name__ == "__main__":
    dataset = GigaspeechDataset(
        split="train",
    )
    print(len(dataset))
    print(dataset[0])
    dataset = GigaspeechDatasetForGeneration(
        split="train",
    )
    print(len(dataset))
    print(dataset[0])
    pass
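A minimal sketch of how the new class is exercised, assuming src/ is on PYTHONPATH and that access to speechcolab/gigaspeech has been granted on the Hugging Face Hub; the identity collate function here is illustrative only, since real batching is done by the model-specific collate functions added below.

from torch.utils.data import DataLoader
from dataset_library.GigaspeechDataset import GigaspeechDataset  # assumes src/ on PYTHONPATH

# Each item is {"audio": (waveform, sampling_rate), "chat": [...]}; batching is
# deferred to a model-specific collate function, so keep raw lists here.
dataset = GigaspeechDataset(split="train")
loader = DataLoader(dataset, batch_size=2, collate_fn=lambda batch: batch)

batch = next(iter(loader))
waveform, sr = batch[0]["audio"]
print(len(batch), sr, batch[0]["chat"][0]["role"])  # e.g. 2 16000 user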
src/dataset_library/factory.py
@@ -48,4 +48,13 @@ def get_dataset(
             split="test",
         ),
     }
+
+    if dataset_name == "gigaspeech":
+        from .GigaspeechDataset import GigaspeechDataset, GigaspeechDatasetForGeneration
+
+        dataset = {
+            "train": GigaspeechDataset(split="train"),
+            "test": GigaspeechDataset(split="test"),
+            "generation": GigaspeechDatasetForGeneration(split="test"),
+        }
     return dataset
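For orientation, the factory is consumed the same way as the existing datasets (see the loop in src/train.py); a minimal sketch, assuming src/ is on PYTHONPATH:

from dataset_library.factory import get_dataset

dataset = get_dataset("gigaspeech")
print(dataset.keys())         # dict_keys(['train', 'test', 'generation'])
print(len(dataset["train"]))  # number of clips in the GigaSpeech "xs" train split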
src/model_library/factory.py (new file, 105 lines)

import torch
from trl import (
    get_kbit_device_map,
    # get_peft_config,
    get_quantization_config,
)


def get_model(model_args):
    if model_args.model_name_or_path == "Qwen/Qwen2-VL-7B-Instruct":
        torch_dtype = (
            model_args.torch_dtype
            if model_args.torch_dtype in ["auto", None]
            else getattr(torch, model_args.torch_dtype)
        )
        quantization_config = get_quantization_config(model_args)
        model_kwargs = dict(
            revision=model_args.model_revision,
            attn_implementation=model_args.attn_implementation,
            torch_dtype=torch_dtype,
            device_map=(
                get_kbit_device_map() if quantization_config is not None else None
            ),
            quantization_config=quantization_config,
        )
        from transformers import Qwen2VLProcessor
        from model_library.qwen2vl import Qwen2VLForConditionalGeneration_modified

        model = Qwen2VLForConditionalGeneration_modified.from_pretrained(
            model_args.model_name_or_path,
            trust_remote_code=model_args.trust_remote_code,
            **model_kwargs,
        )
        processor = Qwen2VLProcessor.from_pretrained(
            model_args.model_name_or_path,
            trust_remote_code=model_args.trust_remote_code,
            padding_side="left",
        )
        print(model)
        from model_library.qwen2vl import (
            collate_fn_for_train,
            collate_fn_for_evaluate,
        )
        from functools import partial

        collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
        collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)

    if model_args.model_name_or_path == "Qwen/Qwen2-Audio-7B-Instruct":
        torch_dtype = (
            model_args.torch_dtype
            if model_args.torch_dtype in ["auto", None]
            else getattr(torch, model_args.torch_dtype)
        )
        quantization_config = get_quantization_config(model_args)
        model_kwargs = dict(
            revision=model_args.model_revision,
            attn_implementation=model_args.attn_implementation,
            torch_dtype=torch_dtype,
            device_map=(
                get_kbit_device_map() if quantization_config is not None else None
            ),
            quantization_config=quantization_config,
        )
        from transformers import Qwen2AudioProcessor, Qwen2AudioForConditionalGeneration

        model = Qwen2AudioForConditionalGeneration.from_pretrained(
            model_args.model_name_or_path,
            trust_remote_code=model_args.trust_remote_code,
            **model_kwargs,
        )
        processor = Qwen2AudioProcessor.from_pretrained(
            model_args.model_name_or_path,
            trust_remote_code=model_args.trust_remote_code,
            padding_side="left",
        )
        print(model)
        from model_library.qwen2audio import (
            collate_fn_for_train,
            collate_fn_for_evaluate,
        )
        from functools import partial

        collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
        collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)

    if model_args.model_name_or_path == "VITA-MLLM/VITA-1.5":
        # from transformers import
        # from model_library.qwen2vl import Qwen2VLForConditionalGeneration_modified

        model = Qwen2VLForConditionalGeneration_modified.from_pretrained(
            model_args.model_name_or_path,
            trust_remote_code=model_args.trust_remote_code,
            **model_kwargs,
        )
        print(model)
        from model_library.qwen2vl import (
            collate_fn_for_train,
            collate_fn_for_evaluate,
        )
        from functools import partial

        collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
        collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)

    return model, processor, collate_fn_for_train, collate_fn_for_evaluate
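A rough sketch of the call site. The real arguments come from trl's TrlParser in src/train.py; the SimpleNamespace below is a stand-in, and its field list is an assumption covering only the attributes read by get_model and by trl's get_quantization_config:

from types import SimpleNamespace
from model_library.factory import get_model  # assumes src/ on PYTHONPATH

model_args = SimpleNamespace(
    model_name_or_path="Qwen/Qwen2-Audio-7B-Instruct",
    model_revision="main",
    attn_implementation="sdpa",
    torch_dtype="bfloat16",
    trust_remote_code=False,
    load_in_4bit=False,  # assumed: get_quantization_config checks these flags
    load_in_8bit=False,
)

model, processor, train_collate, eval_collate = get_model(model_args)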
src/model_library/qwen2audio/__init__.py (new file, 4 lines)

from .collate_fn import collate_fn_for_evaluate, collate_fn_for_train
from .model import Qwen2VLForConditionalGeneration_modified

__all__ = ["collate_fn_for_train", "collate_fn_for_evaluate", "Qwen2VLForConditionalGeneration_modified"]
src/model_library/qwen2audio/collate_fn.py (new file, 87 lines)

from transformers import Qwen2AudioProcessor


def collate_fn_for_train(examples, processor: Qwen2AudioProcessor):
    # Get the texts and audios, and apply the chat template
    texts = [
        processor.apply_chat_template(example["chat"], tokenize=False)
        for example in examples
    ]
    audios = [example["audio"][0] for example in examples]
    # Tokenize the texts and process the audios
    batch = processor(
        text=texts,
        audios=audios,
        return_tensors="pt",
        padding=True,
        sampling_rate=examples[0]["audio"][1],
    )

    # The labels are the input_ids, and we mask the padding tokens in the loss computation
    labels = batch["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100
    # Ignore the prompt token indices in the loss computation (model specific):
    # mask the <|im_start|>system ... <|im_end|>\n spans
    im_start_token_id = processor.tokenizer.convert_tokens_to_ids("<|im_start|>")
    im_end_token_id = processor.tokenizer.convert_tokens_to_ids("<|im_end|>")
    system_token_id = processor.tokenizer.convert_tokens_to_ids("system")
    user_token_id = processor.tokenizer.convert_tokens_to_ids("user")
    assistant_token_id = processor.tokenizer.convert_tokens_to_ids("assistant")
    # enter_token_id = processor.tokenizer.convert_tokens_to_ids("\n")
    # print(im_start_token_id, im_end_token_id, system_token_id, user_token_id, assistant_token_id, enter_token_id, processor.tokenizer.pad_token_id)
    # 151644 151645 8948 872 77091 None 151643

    for i, label in enumerate(labels):
        now_index = 0
        while now_index < len(label):
            if label[now_index] == im_start_token_id:
                label[now_index] = -100
                now_index += 1
                if (
                    label[now_index] == system_token_id
                    or label[now_index] == user_token_id
                ):
                    while label[now_index] != im_end_token_id:
                        label[now_index] = -100
                        now_index += 1
                    label[now_index] = -100
                elif label[now_index] == assistant_token_id:
                    label[now_index] = -100
                    label[now_index + 1] = -100
                    now_index += 2
                    while (
                        now_index < len(label) and label[now_index] != im_end_token_id
                    ):
                        now_index += 1
            now_index += 1
    batch["labels"] = labels
    # batch["task_id"] = torch.tensor([0] * len(labels), dtype=torch.long)

    return batch


def collate_fn_for_evaluate(examples, processor: Qwen2AudioProcessor):
    # Get the texts and audios, and apply the chat template with a generation prompt
    texts = [
        processor.apply_chat_template(
            example["chat"], tokenize=False, add_generation_prompt=True
        )
        for example in examples
    ]
    # print(texts)
    audios = [example["audio"] for example in examples]

    # Tokenize the texts and process the audios
    batch = processor(text=texts, audios=audios, return_tensors="pt", padding=True)

    answers = [example["answer"] for example in examples]
    answers = processor(text=answers, return_tensors="pt", padding=True)
    batch["answers_ids"] = answers["input_ids"]
    batch["answers_mask"] = answers["attention_mask"]
    # e.g. input_ids torch.Size([3, 370])
    #      attention_mask torch.Size([3, 370])
    #      answers_ids torch.Size([3, 10])
    #      answers_mask torch.Size([3, 10])
    return batch
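To make the label-masking walk above easier to check, here is a self-contained toy run of the same logic with made-up token ids (the real ids come from the Qwen tokenizer, e.g. 151644 for <|im_start|>); prompt turns are masked to -100 while assistant answer tokens stay in the loss:

import torch

# Stand-in ids: 1=<|im_start|>, 2=<|im_end|>, 3=system, 4=user, 5=assistant,
# 6=the newline after the role, 90+=ordinary text tokens.
IM_START, IM_END, SYSTEM, USER, ASSISTANT = 1, 2, 3, 4, 5
label = torch.tensor([1, 4, 90, 91, 2, 1, 5, 6, 92, 93, 2])

now_index = 0
while now_index < len(label):
    if label[now_index] == IM_START:
        label[now_index] = -100
        now_index += 1
        if label[now_index] == SYSTEM or label[now_index] == USER:
            while label[now_index] != IM_END:   # mask the whole prompt turn
                label[now_index] = -100
                now_index += 1
            label[now_index] = -100             # and its closing <|im_end|>
        elif label[now_index] == ASSISTANT:
            label[now_index] = -100             # mask "assistant" and the newline
            label[now_index + 1] = -100
            now_index += 2
            while now_index < len(label) and label[now_index] != IM_END:
                now_index += 1                  # answer tokens are left unmasked
    now_index += 1

print(label)  # tensor([-100, -100, -100, -100, -100, -100, -100, -100, 92, 93, 2])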
src/model_library/qwen2audio/model.py (new file, 856 lines)

# from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutputWithPast, Qwen2VLModel, Qwen2VLForConditionalGeneration, logger, DynamicCache, Qwen2VLDecoderLayer, Qwen2VLConfig, Qwen2VLAttention,
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
    is_flash_attn_2_available,
    Qwen2VLAttention,
    Qwen2VLFlashAttention2,
    Qwen2VLDecoderLayer,
    Qwen2VLModel,
    Qwen2VLForConditionalGeneration,
    Qwen2VLConfig,
    logger,
    apply_rotary_pos_emb_vision,
    Cache,
    apply_multimodal_rotary_pos_emb,
    repeat_kv,
    is_flash_attn_greater_or_equal_2_10,
    math,
    Qwen2VLCausalLMOutputWithPast,
    Qwen2VisionTransformerPretrainedModel,
)

if is_flash_attn_2_available():
    from flash_attn import flash_attn_varlen_func

    from transformers.modeling_flash_attention_utils import _flash_attention_forward
else:
    flash_attn_varlen_func = None
from transformers.cache_utils import DynamicCache
from transformers.modeling_outputs import BaseModelOutputWithPast
import torch
from typing import Optional, List, Union, Tuple
from torch.nn import CrossEntropyLoss

from torch import nn
from torch.nn import functional as F
from torch import Tensor


class LinearLayer(nn.Linear):
    def forward(self, input: Tensor, **kwargs) -> Tensor:
        return F.linear(input, self.weight, self.bias)


class Qwen2VLAttention_modified(Qwen2VLAttention):
    def __init__(self, config: Qwen2VLConfig, layer_idx: Optional[int] = None):
        super().__init__(config, layer_idx)
        self.q_proj = LinearLayer(
            self.hidden_size, self.num_heads * self.head_dim, bias=True
        )
        self.k_proj = LinearLayer(
            self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True
        )
        self.v_proj = LinearLayer(
            self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True
        )
        self.o_proj = LinearLayer(
            self.num_heads * self.head_dim, self.hidden_size, bias=False
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[
            Tuple[torch.Tensor, torch.Tensor]
        ] = None,  # will become mandatory in v4.46
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(
            bsz, q_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        key_states = key_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)
        value_states = value_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)

        if position_embeddings is None:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings
        query_states, key_states = apply_multimodal_rotary_pos_emb(
            query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
        )

        if past_key_value is not None:
            cache_kwargs = {
                "sin": sin,
                "cos": cos,
                "cache_position": cache_position,
            }  # Specific to RoPE models
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(
            query_states, key_states.transpose(2, 3)
        ) / math.sqrt(self.head_dim)

        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # Fix precision issues in Qwen2-VL float16 inference
        # Replace inf values with zeros in attention weights to prevent NaN propagation
        if query_states.dtype == torch.float16:
            attn_weights = torch.where(
                torch.isinf(attn_weights), torch.zeros_like(attn_weights), attn_weights
            )

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(
            attn_weights, dim=-1, dtype=torch.float32
        ).to(query_states.dtype)
        attn_weights = nn.functional.dropout(
            attn_weights, p=self.attention_dropout, training=self.training
        )
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, -1)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value


class Qwen2VLSdpaAttention_modified(Qwen2VLAttention_modified):
    """
    Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
    `Qwen2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
    SDPA API.
    """

    # Adapted from Qwen2Attention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[
            Tuple[torch.Tensor, torch.Tensor]
        ] = None,  # will become mandatory in v4.46
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if output_attentions:
            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
            logger.warning_once(
                "Qwen2VLModel is using Qwen2VLSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                **kwargs,
            )

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states, **kwargs)
        key_states = self.k_proj(hidden_states, **kwargs)
        value_states = self.v_proj(hidden_states, **kwargs)

        query_states = query_states.view(
            bsz, q_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        key_states = key_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)
        value_states = value_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)

        if position_embeddings is None:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings
        query_states, key_states = apply_multimodal_rotary_pos_emb(
            query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
        )

        if past_key_value is not None:
            cache_kwargs = {
                "sin": sin,
                "cos": cos,
                "cache_position": cache_position,
            }  # Specific to RoPE models
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        causal_mask = attention_mask
        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and attention_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
        # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
        is_causal = True if causal_mask is None and q_len > 1 else False

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=is_causal,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value


class Qwen2VLFlashAttention2_modified(Qwen2VLAttention_modified):
    """
    Qwen2VL flash attention module, following Qwen2VL attention module. This module inherits from `Qwen2VLAttention`
    as the weights of the module stays untouched. The only required change would be on the forward pass
    where it needs to correctly call the public API of flash attention and deal with padding tokens
    in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom
    config.max_window_layers layers.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[
            Tuple[torch.Tensor, torch.Tensor]
        ] = None,  # will become mandatory in v4.46
        **kwargs,
    ):
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states, **kwargs)
        key_states = self.k_proj(hidden_states, **kwargs)
        value_states = self.v_proj(hidden_states, **kwargs)

        query_states = query_states.view(
            bsz, q_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        key_states = key_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)
        value_states = value_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)

        # Because the input can be padded, the absolute sequence length depends on the max position id.
        if position_embeddings is None:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings

        query_states, key_states = apply_multimodal_rotary_pos_emb(
            query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
        )

        if past_key_value is not None:
            cache_kwargs = {
                "sin": sin,
                "cos": cos,
                "cache_position": cache_position,
            }  # Specific to RoPE models
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        dropout_rate = 0.0 if not self.training else self.attention_dropout

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
        # therefore the input hidden states gets silently casted in float32. Hence, we need
        # cast them back in float16 just to be sure everything works as expected.
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        # Reshape to the expected shape for Flash Attention
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        if (
            self.config.use_sliding_window
            and getattr(self.config, "sliding_window", None) is not None
            and self.layer_idx >= self.config.max_window_layers
        ):
            sliding_window = self.config.sliding_window
        else:
            sliding_window = None

        attn_output = _flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=dropout_rate,
            sliding_window=sliding_window,
            is_causal=self.is_causal,
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
        )

        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value


QWEN2_VL_ATTENTION_CLASSES = {
    "eager": Qwen2VLAttention,
    "flash_attention_2": Qwen2VLFlashAttention2_modified,
    "sdpa": Qwen2VLSdpaAttention_modified,
}


class Qwen2VLDecoderLayer_modified(Qwen2VLDecoderLayer):
    def __init__(self, config: Qwen2VLConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        self.self_attn = QWEN2_VL_ATTENTION_CLASSES[config._attn_implementation](
            config, layer_idx
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[
            Tuple[torch.Tensor, torch.Tensor]
        ] = None,  # will become mandatory in v4.46
        **kwargs,
    ) -> Tuple[
        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
    ]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, sequence_length)` where padding elements are indicated by 0.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                Indices depicting the position of the input sequence tokens in the sequence.
            position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
                Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
                with `head_dim` being the embedding dimension of each attention head.
            kwargs (`dict`, *optional*):
                Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
                into the model
        """

        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs


class Qwen2VLModel_modified(Qwen2VLModel):
    def __init__(self, config):
        super().__init__(config)
        self.layers = nn.ModuleList(
            [
                Qwen2VLDecoderLayer_modified(config, layer_idx)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You must specify exactly one of input_ids or inputs_embeds"
            )

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # torch.jit.trace() doesn't support cache objects in the output
        if use_cache and past_key_values is None and not torch.jit.is_tracing():
            past_key_values = DynamicCache()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if cache_position is None:
            past_seen_tokens = (
                past_key_values.get_seq_length() if past_key_values is not None else 0
            )
            cache_position = torch.arange(
                past_seen_tokens,
                past_seen_tokens + inputs_embeds.shape[1],
                device=inputs_embeds.device,
            )

        # the hard coded `3` is for temporal, height and width.
        if position_ids is None:
            position_ids = cache_position.view(1, 1, -1).expand(
                3, inputs_embeds.shape[0], -1
            )
        elif position_ids.dim() == 2:
            position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)

        causal_mask = self._update_causal_mask(
            attention_mask,
            inputs_embeds,
            cache_position,
            past_key_values,
            output_attentions,
        )

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                    position_embeddings,
                    **kwargs,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                    **kwargs,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
                if v is not None
            )
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )


class Qwen2VLForConditionalGeneration_modified(Qwen2VLForConditionalGeneration):
    def __init__(self, config):
        super().__init__(config)
        self.visual = Qwen2VisionTransformerPretrainedModel._from_config(
            config.vision_config
        )
        self.model = Qwen2VLModel_modified(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.rope_deltas = None  # cache rope_deltas here

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        rope_deltas: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

        >>> model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
        >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")

        >>> messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": "What is shown in this image?"},
                ],
            },
        ]
        >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        >>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos])

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
        ```"""

        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if inputs_embeds is None:
            inputs_embeds = self.model.embed_tokens(input_ids)
            if pixel_values is not None:
                pixel_values = pixel_values.type(self.visual.get_dtype())
                image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
                n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
                n_image_features = image_embeds.shape[0]
                if n_image_tokens != n_image_features:
                    raise ValueError(
                        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                    )
                image_mask = (
                    (input_ids == self.config.image_token_id)
                    .unsqueeze(-1)
                    .expand_as(inputs_embeds)
                    .to(inputs_embeds.device)
                )
                image_embeds = image_embeds.to(
                    inputs_embeds.device, inputs_embeds.dtype
                )
                inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

            if pixel_values_videos is not None:
                pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
                video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
                n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
                n_video_features = video_embeds.shape[0]
                if n_video_tokens != n_video_features:
                    raise ValueError(
                        f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
                    )
                video_mask = (
                    (input_ids == self.config.video_token_id)
                    .unsqueeze(-1)
                    .expand_as(inputs_embeds)
                    .to(inputs_embeds.device)
                )
                video_embeds = video_embeds.to(
                    inputs_embeds.device, inputs_embeds.dtype
                )
                inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)

            if attention_mask is not None:
                attention_mask = attention_mask.to(inputs_embeds.device)

        # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
        if (
            position_ids is None
            and input_ids is not None
            and (attention_mask is None or attention_mask.ndim == 2)
        ):
            # calculate RoPE index once per generation in the pre-fill stage only
            if (
                cache_position is not None and cache_position[0] == 0
            ) or self.rope_deltas is None:
                position_ids, rope_deltas = self.get_rope_index(
                    input_ids, image_grid_thw, video_grid_thw, attention_mask
                )
                self.rope_deltas = rope_deltas
            # then use the prev pre-calculated rope-deltas to get the correct position ids
            else:
                batch_size, seq_length, _ = inputs_embeds.shape
                delta = (
                    cache_position[0] + self.rope_deltas
                    if cache_position is not None
                    else 0
                )
                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                if cache_position is not None:  # otherwise `deltas` is an int `0`
                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
                position_ids = position_ids.add(delta)
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)

        outputs = self.model(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Upcast to float if we need to compute the loss to avoid potential precision issues
            logits = logits.float()
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return Qwen2VLCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            rope_deltas=self.rope_deltas,
        )
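Relative to the upstream transformers implementation this file copies, the functional change is that the q/k/v/o projections are LinearLayer instances whose forward accepts and discards extra keyword arguments, so per-batch metadata (e.g. a task id) can later be threaded through the attention stack without TypeErrors. A standalone illustration of that property, independent of the model weights:

import torch
from torch import nn, Tensor
from torch.nn import functional as F


class LinearLayer(nn.Linear):
    # Mirrors the class above: extra kwargs are accepted and ignored.
    def forward(self, input: Tensor, **kwargs) -> Tensor:
        return F.linear(input, self.weight, self.bias)


proj = LinearLayer(8, 8)
x = torch.randn(2, 8)
assert torch.equal(proj(x), proj(x, task_id=0))  # kwargs are a no-op for now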
src/train.py (67 lines changed)
@@ -1,16 +1,10 @@
-import torch
 from dataset_library.factory import get_dataset
 from transformers import (
-    AutoModelForVision2Seq,
-    AutoProcessor,
     TrainingArguments,
 )

 from trl import (
     TrlParser,
-    get_kbit_device_map,
-    # get_peft_config,
-    get_quantization_config,
 )
 from peft_library import get_peft_model

@@ -32,64 +26,12 @@ if __name__ == "__main__":
     training_args.remove_unused_columns = False
     training_args.dataset_kwargs = {"skip_prepare_dataset": True}

-    # peft_config = get_peft_config(dict(**vars(model_args)))
-
-    torch_dtype = (
-        model_args.torch_dtype
-        if model_args.torch_dtype in ["auto", None]
-        else getattr(torch, model_args.torch_dtype)
-    )
-    quantization_config = get_quantization_config(model_args)
-    model_kwargs = dict(
-        revision=model_args.model_revision,
-        attn_implementation=model_args.attn_implementation,
-        torch_dtype=torch_dtype,
-        device_map=get_kbit_device_map() if quantization_config is not None else None,
-        quantization_config=quantization_config,
-    )
-    processor = AutoProcessor.from_pretrained(
-        model_args.model_name_or_path,
-        trust_remote_code=model_args.trust_remote_code,
-        padding_side="left",
-    )
-
-    if model_args.model_name_or_path == "Qwen/Qwen2-VL-7B-Instruct":
-        # from transformers import Qwen2VLForConditionalGeneration
-        from model_library.qwen2vl import Qwen2VLForConditionalGeneration_modified
-
-        model = Qwen2VLForConditionalGeneration_modified.from_pretrained(
-            model_args.model_name_or_path,
-            trust_remote_code=model_args.trust_remote_code,
-            **model_kwargs,
-        )
-        print(model)
-        from model_library.qwen2vl import (
-            collate_fn_for_train,
-            collate_fn_for_evaluate,
-        )
-        from functools import partial
-
-        collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
-        collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)
-
-    if model_args.model_name_or_path == "VITA-MLLM/VITA-1.5":
-        # from transformers import
-        # from model_library.qwen2vl import Qwen2VLForConditionalGeneration_modified
-
-        model = Qwen2VLForConditionalGeneration_modified.from_pretrained(
-            model_args.model_name_or_path,
-            trust_remote_code=model_args.trust_remote_code,
-            **model_kwargs,
-        )
-        print(model)
-        from model_library.qwen2vl import (
-            collate_fn_for_train,
-            collate_fn_for_evaluate,
-        )
-        from functools import partial
-
-        collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
-        collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)
+    from model_library.factory import get_model
+
+    model, processor, collate_fn_for_train, collate_fn_for_evaluate = get_model(
+        model_args
+    )

     ################
     # Dataset
     ################
@@ -110,7 +52,7 @@ if __name__ == "__main__":
     elif model_args.peft_type == "LORA":
         from peft.tuners.lora import LoraConfig

-        peft_config = LoraConfig(target_modules=model_args.lora_target_modules)
+        peft_config = LoraConfig(target_modules=model_args.lora_target_modules, r=2)

         model = get_peft_model(model, peft_config)

@@ -122,7 +64,6 @@ if __name__ == "__main__":

     for dataset_name in script_args.dataset_name:
         dataset = get_dataset(dataset_name)
-        print(dataset)
     model.train()

     trainer = ContinualTrainer(
@@ -1,14 +1,13 @@
 #!/bin/bash

 accelerate launch --config_file configs/accelerate_configs/deepspeed_zero2.yaml train.py \
-    --dataset_name OCR-VQA-200K \
+    --dataset_name gigaspeech \
     --use_peft \
-    --peft_type MMOELORA \
-    --model_name_or_path Qwen/Qwen2-VL-7B-Instruct \
+    --peft_type LORA \
+    --model_name_or_path Qwen/Qwen2-Audio-7B-Instruct \
     --lora_target_modules q_proj v_proj \
     --per_device_train_batch_size 1 \
     --per_device_eval_batch_size 2 \
-    --attn_implementation flash_attention_2 \
     --gradient_accumulation_steps 16 \
     --output_dir checkpoint/sft-llava-1.5-7b-hf \
     --bf16 \
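For reference, these flags plausibly land in train.py through HfArgumentParser dataclasses; the exact dataclass split below is an assumption, only the flag names come from the script:

```python
from dataclasses import dataclass, field
from typing import List

from transformers import HfArgumentParser


@dataclass
class ScriptArguments:
    dataset_name: List[str] = field(default_factory=list)


@dataclass
class ModelArguments:
    model_name_or_path: str = "Qwen/Qwen2-Audio-7B-Instruct"
    use_peft: bool = False
    peft_type: str = "LORA"
    lora_target_modules: List[str] = field(
        default_factory=lambda: ["q_proj", "v_proj"]
    )


script_args, model_args = HfArgumentParser(
    (ScriptArguments, ModelArguments)
).parse_args_into_dataclasses()
```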
@@ -52,7 +52,6 @@ class ContinualTrainer(Trainer):
         # TODO: this needs to be fixed and made cleaner later.
         if self.args.past_index >= 0:
             self._past = outputs[self.args.past_index]
-        print(labels)

         if labels is not None:
             unwrapped_model = self.accelerator.unwrap_model(model)
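The hunk above sits in the tail of a `compute_loss` override; with the debug print gone it matches the stock Trainer flow. A simplified sketch of that flow, not the exact code from this repository:

```python
def compute_loss(self, model, inputs, return_outputs=False):
    """Simplified sketch of the Trainer.compute_loss logic around the hunk above."""
    labels = inputs.pop("labels") if "labels" in inputs else None
    outputs = model(**inputs)
    # TODO: this needs to be fixed and made cleaner later.
    if self.args.past_index >= 0:
        self._past = outputs[self.args.past_index]

    if labels is not None:
        unwrapped_model = self.accelerator.unwrap_model(model)
        # loss is recomputed against the labels (e.g. via the label smoother)
        loss = self.label_smoother(outputs, labels, shift_labels=True)
    else:
        loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
    return (loss, outputs) if return_outputs else loss
```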
376 uv.lock generated
@@ -124,6 +124,15 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/89/aa/ab0f7891a01eeb2d2e338ae8fecbe57fcebea1a24dbb64d45801bfab481d/attrs-24.3.0-py3-none-any.whl", hash = "sha256:ac96cd038792094f438ad1f6ff80837353805ac950cd2aa0e0625ef19850c308", size = 63397 },
 ]

+[[package]]
+name = "audioread"
+version = "3.0.1"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/db/d2/87016ca9f083acadffb2d8da59bfa3253e4da7eeb9f71fb8e7708dc97ecd/audioread-3.0.1.tar.gz", hash = "sha256:ac5460a5498c48bdf2e8e767402583a4dcd13f4414d286f42ce4379e8b35066d", size = 116513 }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/57/8d/30aa32745af16af0a9a650115fbe81bde7c610ed5c21b381fca0196f3a7f/audioread-3.0.1-py3-none-any.whl", hash = "sha256:4cdce70b8adc0da0a3c9e0d85fb10b3ace30fbdf8d1670fd443929b61d117c33", size = 23492 },
+]
+
 [[package]]
 name = "certifi"
 version = "2024.12.14"
@@ -133,6 +142,51 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/a5/32/8f6669fc4798494966bf446c8c4a162e0b5d893dff088afddf76414f70e1/certifi-2024.12.14-py3-none-any.whl", hash = "sha256:1275f7a45be9464efc1173084eaa30f866fe2e47d389406136d332ed4967ec56", size = 164927 },
 ]

+[[package]]
+name = "cffi"
+version = "1.17.1"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "pycparser" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/fc/97/c783634659c2920c3fc70419e3af40972dbaf758daa229a7d6ea6135c90d/cffi-1.17.1.tar.gz", hash = "sha256:1c39c6016c32bc48dd54561950ebd6836e1670f2ae46128f67cf49e789c52824", size = 516621 }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/6b/f4/927e3a8899e52a27fa57a48607ff7dc91a9ebe97399b357b85a0c7892e00/cffi-1.17.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a45e3c6913c5b87b3ff120dcdc03f6131fa0065027d0ed7ee6190736a74cd401", size = 182264 },
+    { url = "https://files.pythonhosted.org/packages/6c/f5/6c3a8efe5f503175aaddcbea6ad0d2c96dad6f5abb205750d1b3df44ef29/cffi-1.17.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:30c5e0cb5ae493c04c8b42916e52ca38079f1b235c2f8ae5f4527b963c401caf", size = 178651 },
+    { url = "https://files.pythonhosted.org/packages/94/dd/a3f0118e688d1b1a57553da23b16bdade96d2f9bcda4d32e7d2838047ff7/cffi-1.17.1-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f75c7ab1f9e4aca5414ed4d8e5c0e303a34f4421f8a0d47a4d019ceff0ab6af4", size = 445259 },
+    { url = "https://files.pythonhosted.org/packages/2e/ea/70ce63780f096e16ce8588efe039d3c4f91deb1dc01e9c73a287939c79a6/cffi-1.17.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a1ed2dd2972641495a3ec98445e09766f077aee98a1c896dcb4ad0d303628e41", size = 469200 },
+    { url = "https://files.pythonhosted.org/packages/1c/a0/a4fa9f4f781bda074c3ddd57a572b060fa0df7655d2a4247bbe277200146/cffi-1.17.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:46bf43160c1a35f7ec506d254e5c890f3c03648a4dbac12d624e4490a7046cd1", size = 477235 },
+    { url = "https://files.pythonhosted.org/packages/62/12/ce8710b5b8affbcdd5c6e367217c242524ad17a02fe5beec3ee339f69f85/cffi-1.17.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a24ed04c8ffd54b0729c07cee15a81d964e6fee0e3d4d342a27b020d22959dc6", size = 459721 },
+    { url = "https://files.pythonhosted.org/packages/ff/6b/d45873c5e0242196f042d555526f92aa9e0c32355a1be1ff8c27f077fd37/cffi-1.17.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:610faea79c43e44c71e1ec53a554553fa22321b65fae24889706c0a84d4ad86d", size = 467242 },
+    { url = "https://files.pythonhosted.org/packages/1a/52/d9a0e523a572fbccf2955f5abe883cfa8bcc570d7faeee06336fbd50c9fc/cffi-1.17.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:a9b15d491f3ad5d692e11f6b71f7857e7835eb677955c00cc0aefcd0669adaf6", size = 477999 },
+    { url = "https://files.pythonhosted.org/packages/44/74/f2a2460684a1a2d00ca799ad880d54652841a780c4c97b87754f660c7603/cffi-1.17.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:de2ea4b5833625383e464549fec1bc395c1bdeeb5f25c4a3a82b5a8c756ec22f", size = 454242 },
+    { url = "https://files.pythonhosted.org/packages/f8/4a/34599cac7dfcd888ff54e801afe06a19c17787dfd94495ab0c8d35fe99fb/cffi-1.17.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:fc48c783f9c87e60831201f2cce7f3b2e4846bf4d8728eabe54d60700b318a0b", size = 478604 },
+    { url = "https://files.pythonhosted.org/packages/34/33/e1b8a1ba29025adbdcda5fb3a36f94c03d771c1b7b12f726ff7fef2ebe36/cffi-1.17.1-cp311-cp311-win32.whl", hash = "sha256:85a950a4ac9c359340d5963966e3e0a94a676bd6245a4b55bc43949eee26a655", size = 171727 },
+    { url = "https://files.pythonhosted.org/packages/3d/97/50228be003bb2802627d28ec0627837ac0bf35c90cf769812056f235b2d1/cffi-1.17.1-cp311-cp311-win_amd64.whl", hash = "sha256:caaf0640ef5f5517f49bc275eca1406b0ffa6aa184892812030f04c2abf589a0", size = 181400 },
+    { url = "https://files.pythonhosted.org/packages/5a/84/e94227139ee5fb4d600a7a4927f322e1d4aea6fdc50bd3fca8493caba23f/cffi-1.17.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:805b4371bf7197c329fcb3ead37e710d1bca9da5d583f5073b799d5c5bd1eee4", size = 183178 },
+    { url = "https://files.pythonhosted.org/packages/da/ee/fb72c2b48656111c4ef27f0f91da355e130a923473bf5ee75c5643d00cca/cffi-1.17.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:733e99bc2df47476e3848417c5a4540522f234dfd4ef3ab7fafdf555b082ec0c", size = 178840 },
+    { url = "https://files.pythonhosted.org/packages/cc/b6/db007700f67d151abadf508cbfd6a1884f57eab90b1bb985c4c8c02b0f28/cffi-1.17.1-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1257bdabf294dceb59f5e70c64a3e2f462c30c7ad68092d01bbbfb1c16b1ba36", size = 454803 },
+    { url = "https://files.pythonhosted.org/packages/1a/df/f8d151540d8c200eb1c6fba8cd0dfd40904f1b0682ea705c36e6c2e97ab3/cffi-1.17.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:da95af8214998d77a98cc14e3a3bd00aa191526343078b530ceb0bd710fb48a5", size = 478850 },
+    { url = "https://files.pythonhosted.org/packages/28/c0/b31116332a547fd2677ae5b78a2ef662dfc8023d67f41b2a83f7c2aa78b1/cffi-1.17.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d63afe322132c194cf832bfec0dc69a99fb9bb6bbd550f161a49e9e855cc78ff", size = 485729 },
+    { url = "https://files.pythonhosted.org/packages/91/2b/9a1ddfa5c7f13cab007a2c9cc295b70fbbda7cb10a286aa6810338e60ea1/cffi-1.17.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f79fc4fc25f1c8698ff97788206bb3c2598949bfe0fef03d299eb1b5356ada99", size = 471256 },
+    { url = "https://files.pythonhosted.org/packages/b2/d5/da47df7004cb17e4955df6a43d14b3b4ae77737dff8bf7f8f333196717bf/cffi-1.17.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b62ce867176a75d03a665bad002af8e6d54644fad99a3c70905c543130e39d93", size = 479424 },
+    { url = "https://files.pythonhosted.org/packages/0b/ac/2a28bcf513e93a219c8a4e8e125534f4f6db03e3179ba1c45e949b76212c/cffi-1.17.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:386c8bf53c502fff58903061338ce4f4950cbdcb23e2902d86c0f722b786bbe3", size = 484568 },
+    { url = "https://files.pythonhosted.org/packages/d4/38/ca8a4f639065f14ae0f1d9751e70447a261f1a30fa7547a828ae08142465/cffi-1.17.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:4ceb10419a9adf4460ea14cfd6bc43d08701f0835e979bf821052f1805850fe8", size = 488736 },
+    { url = "https://files.pythonhosted.org/packages/86/c5/28b2d6f799ec0bdecf44dced2ec5ed43e0eb63097b0f58c293583b406582/cffi-1.17.1-cp312-cp312-win32.whl", hash = "sha256:a08d7e755f8ed21095a310a693525137cfe756ce62d066e53f502a83dc550f65", size = 172448 },
+    { url = "https://files.pythonhosted.org/packages/50/b9/db34c4755a7bd1cb2d1603ac3863f22bcecbd1ba29e5ee841a4bc510b294/cffi-1.17.1-cp312-cp312-win_amd64.whl", hash = "sha256:51392eae71afec0d0c8fb1a53b204dbb3bcabcb3c9b807eedf3e1e6ccf2de903", size = 181976 },
+    { url = "https://files.pythonhosted.org/packages/8d/f8/dd6c246b148639254dad4d6803eb6a54e8c85c6e11ec9df2cffa87571dbe/cffi-1.17.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:f3a2b4222ce6b60e2e8b337bb9596923045681d71e5a082783484d845390938e", size = 182989 },
+    { url = "https://files.pythonhosted.org/packages/8b/f1/672d303ddf17c24fc83afd712316fda78dc6fce1cd53011b839483e1ecc8/cffi-1.17.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:0984a4925a435b1da406122d4d7968dd861c1385afe3b45ba82b750f229811e2", size = 178802 },
+    { url = "https://files.pythonhosted.org/packages/0e/2d/eab2e858a91fdff70533cab61dcff4a1f55ec60425832ddfdc9cd36bc8af/cffi-1.17.1-cp313-cp313-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d01b12eeeb4427d3110de311e1774046ad344f5b1a7403101878976ecd7a10f3", size = 454792 },
+    { url = "https://files.pythonhosted.org/packages/75/b2/fbaec7c4455c604e29388d55599b99ebcc250a60050610fadde58932b7ee/cffi-1.17.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:706510fe141c86a69c8ddc029c7910003a17353970cff3b904ff0686a5927683", size = 478893 },
+    { url = "https://files.pythonhosted.org/packages/4f/b7/6e4a2162178bf1935c336d4da8a9352cccab4d3a5d7914065490f08c0690/cffi-1.17.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:de55b766c7aa2e2a3092c51e0483d700341182f08e67c63630d5b6f200bb28e5", size = 485810 },
+    { url = "https://files.pythonhosted.org/packages/c7/8a/1d0e4a9c26e54746dc08c2c6c037889124d4f59dffd853a659fa545f1b40/cffi-1.17.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c59d6e989d07460165cc5ad3c61f9fd8f1b4796eacbd81cee78957842b834af4", size = 471200 },
+    { url = "https://files.pythonhosted.org/packages/26/9f/1aab65a6c0db35f43c4d1b4f580e8df53914310afc10ae0397d29d697af4/cffi-1.17.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd398dbc6773384a17fe0d3e7eeb8d1a21c2200473ee6806bb5e6a8e62bb73dd", size = 479447 },
+    { url = "https://files.pythonhosted.org/packages/5f/e4/fb8b3dd8dc0e98edf1135ff067ae070bb32ef9d509d6cb0f538cd6f7483f/cffi-1.17.1-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:3edc8d958eb099c634dace3c7e16560ae474aa3803a5df240542b305d14e14ed", size = 484358 },
+    { url = "https://files.pythonhosted.org/packages/f1/47/d7145bf2dc04684935d57d67dff9d6d795b2ba2796806bb109864be3a151/cffi-1.17.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:72e72408cad3d5419375fc87d289076ee319835bdfa2caad331e377589aebba9", size = 488469 },
+    { url = "https://files.pythonhosted.org/packages/bf/ee/f94057fa6426481d663b88637a9a10e859e492c73d0384514a17d78ee205/cffi-1.17.1-cp313-cp313-win32.whl", hash = "sha256:e03eab0a8677fa80d646b5ddece1cbeaf556c313dcfac435ba11f107ba117b5d", size = 172475 },
+    { url = "https://files.pythonhosted.org/packages/7c/fc/6a8cb64e5f0324877d503c854da15d76c1e50eb722e320b15345c4d0c6de/cffi-1.17.1-cp313-cp313-win_amd64.whl", hash = "sha256:f6a16c31041f09ead72d69f583767292f750d24913dadacf5756b966aacb3f1a", size = 182009 },
+]
+
 [[package]]
 name = "charset-normalizer"
 version = "3.4.1"
@@ -190,11 +244,14 @@ dependencies = [
     { name = "datasets" },
     { name = "deepspeed" },
     { name = "evaluate" },
+    { name = "librosa" },
     { name = "markupsafe" },
+    { name = "numba" },
     { name = "peft" },
     { name = "pip" },
     { name = "requests" },
     { name = "setuptools" },
+    { name = "soundfile" },
     { name = "torch" },
     { name = "torchaudio" },
     { name = "torchvision" },
@@ -215,11 +272,14 @@ requires-dist = [
     { name = "deepspeed", specifier = "==0.16.2" },
     { name = "evaluate", specifier = "==0.4.3" },
     { name = "flash-attn", marker = "extra == 'compile'", specifier = ">=2.7.2.post1" },
+    { name = "librosa", specifier = ">=0.10.2.post1" },
     { name = "markupsafe", specifier = "==2.1.5", index = "https://download.pytorch.org/whl/cu124" },
+    { name = "numba", specifier = ">=0.60.0" },
     { name = "peft", specifier = "==0.14.0" },
     { name = "pip", specifier = "==24.3.1" },
     { name = "requests", specifier = "==2.32.3", index = "https://pypi.org/simple" },
     { name = "setuptools", specifier = ">=70.0.0" },
+    { name = "soundfile", specifier = ">=0.13.0" },
     { name = "torch", specifier = "==2.5.1+cu124", index = "https://download.pytorch.org/whl/cu124" },
     { name = "torchaudio", specifier = "==2.5.1+cu124", index = "https://download.pytorch.org/whl/cu124" },
     { name = "torchvision", specifier = "==0.20.1+cu124", index = "https://download.pytorch.org/whl/cu124" },
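These three additions (librosa, numba, soundfile) are the audio stack behind the new Gigaspeech path: soundfile and audioread decode the clips, and numba JIT-compiles librosa's DSP kernels. A quick sanity check that a clip decodes and resamples; the 16 kHz target rate is an assumption, chosen because most speech encoders expect it:

```python
import librosa

# decode an audio file and resample it to 16 kHz in one call
waveform, sr = librosa.load("example.wav", sr=16_000)
print(waveform.shape, sr)
```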
@@ -262,6 +322,15 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/d7/84/0df6c5981f5fc722381662ff8cfbdf8aad64bec875f75d80b55bfef394ce/datasets-3.2.0-py3-none-any.whl", hash = "sha256:f3d2ba2698b7284a4518019658596a6a8bc79f31e51516524249d6c59cf0fe2a", size = 480647 },
 ]

+[[package]]
+name = "decorator"
+version = "5.1.1"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/66/0c/8d907af351aa16b42caae42f9d6aa37b900c67308052d10fdce809f8d952/decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330", size = 35016 }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/d5/50/83c593b07763e1161326b3b8c6686f0f4b0f24d5526546bee538c89837d6/decorator-5.1.1-py3-none-any.whl", hash = "sha256:b8c3f85900b9dc423225913c5aace94729fe1fa9763b38939a95226f02d37186", size = 9073 },
+]
+
 [[package]]
 name = "deepspeed"
 version = "0.16.2"
@@ -456,6 +525,69 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/bd/0f/2ba5fbcd631e3e88689309dbe978c5769e883e4b84ebfe7da30b43275c5a/jinja2-3.1.5-py3-none-any.whl", hash = "sha256:aba0f4dc9ed8013c424088f68a5c226f7d6097ed89b246d7749c2ec4175c6adb", size = 134596 },
 ]

+[[package]]
+name = "joblib"
+version = "1.4.2"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/64/33/60135848598c076ce4b231e1b1895170f45fbcaeaa2c9d5e38b04db70c35/joblib-1.4.2.tar.gz", hash = "sha256:2382c5816b2636fbd20a09e0f4e9dad4736765fdfb7dca582943b9c1366b3f0e", size = 2116621 }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/91/29/df4b9b42f2be0b623cbd5e2140cafcaa2bef0759a00b7b70104dcfe2fb51/joblib-1.4.2-py3-none-any.whl", hash = "sha256:06d478d5674cbc267e7496a410ee875abd68e4340feff4490bcb7afb88060ae6", size = 301817 },
+]
+
+[[package]]
+name = "lazy-loader"
+version = "0.4"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "packaging" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/6f/6b/c875b30a1ba490860c93da4cabf479e03f584eba06fe5963f6f6644653d8/lazy_loader-0.4.tar.gz", hash = "sha256:47c75182589b91a4e1a85a136c074285a5ad4d9f39c63e0d7fb76391c4574cd1", size = 15431 }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/83/60/d497a310bde3f01cb805196ac61b7ad6dc5dcf8dce66634dc34364b20b4f/lazy_loader-0.4-py3-none-any.whl", hash = "sha256:342aa8e14d543a154047afb4ba8ef17f5563baad3fc610d7b15b213b0f119efc", size = 12097 },
+]
+
+[[package]]
+name = "librosa"
+version = "0.10.2.post1"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "audioread" },
+    { name = "decorator" },
+    { name = "joblib" },
+    { name = "lazy-loader" },
+    { name = "msgpack" },
+    { name = "numba" },
+    { name = "numpy" },
+    { name = "pooch" },
+    { name = "scikit-learn" },
+    { name = "scipy" },
+    { name = "soundfile" },
+    { name = "soxr" },
+    { name = "typing-extensions" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/0e/2d/77783a52641a21ff7e2230aa588e4fb4a61422a64673096a36776b7e5bd9/librosa-0.10.2.post1.tar.gz", hash = "sha256:cd99f16717cbcd1e0983e37308d1db46a6f7dfc2e396e5a9e61e6821e44bd2e7", size = 325533 }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/8c/8a/2d231b35456506b7c98b3ab9bbf07917b205fed8615d2e59e976ab497fff/librosa-0.10.2.post1-py3-none-any.whl", hash = "sha256:dc882750e8b577a63039f25661b7e39ec4cfbacc99c1cffba666cd664fb0a7a0", size = 260089 },
+]
+
+[[package]]
+name = "llvmlite"
+version = "0.43.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/9f/3d/f513755f285db51ab363a53e898b85562e950f79a2e6767a364530c2f645/llvmlite-0.43.0.tar.gz", hash = "sha256:ae2b5b5c3ef67354824fb75517c8db5fbe93bc02cd9671f3c62271626bc041d5", size = 157069 }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/95/8c/de3276d773ab6ce3ad676df5fab5aac19696b2956319d65d7dd88fb10f19/llvmlite-0.43.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3e8d0618cb9bfe40ac38a9633f2493d4d4e9fcc2f438d39a4e854f39cc0f5f98", size = 31064409 },
+    { url = "https://files.pythonhosted.org/packages/ee/e1/38deed89ced4cf378c61e232265cfe933ccde56ae83c901aa68b477d14b1/llvmlite-0.43.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e0a9a1a39d4bf3517f2af9d23d479b4175ead205c592ceeb8b89af48a327ea57", size = 28793149 },
+    { url = "https://files.pythonhosted.org/packages/2f/b2/4429433eb2dc8379e2cb582502dca074c23837f8fd009907f78a24de4c25/llvmlite-0.43.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c1da416ab53e4f7f3bc8d4eeba36d801cc1894b9fbfbf2022b29b6bad34a7df2", size = 42857277 },
+    { url = "https://files.pythonhosted.org/packages/6b/99/5d00a7d671b1ba1751fc9f19d3b36f3300774c6eebe2bcdb5f6191763eb4/llvmlite-0.43.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:977525a1e5f4059316b183fb4fd34fa858c9eade31f165427a3977c95e3ee749", size = 43871781 },
+    { url = "https://files.pythonhosted.org/packages/20/ab/ed5ed3688c6ba4f0b8d789da19fd8e30a9cf7fc5852effe311bc5aefe73e/llvmlite-0.43.0-cp311-cp311-win_amd64.whl", hash = "sha256:d5bd550001d26450bd90777736c69d68c487d17bf371438f975229b2b8241a91", size = 28107433 },
+    { url = "https://files.pythonhosted.org/packages/0b/67/9443509e5d2b6d8587bae3ede5598fa8bd586b1c7701696663ea8af15b5b/llvmlite-0.43.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f99b600aa7f65235a5a05d0b9a9f31150c390f31261f2a0ba678e26823ec38f7", size = 31064409 },
+    { url = "https://files.pythonhosted.org/packages/a2/9c/24139d3712d2d352e300c39c0e00d167472c08b3bd350c3c33d72c88ff8d/llvmlite-0.43.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:35d80d61d0cda2d767f72de99450766250560399edc309da16937b93d3b676e7", size = 28793145 },
+    { url = "https://files.pythonhosted.org/packages/bf/f1/4c205a48488e574ee9f6505d50e84370a978c90f08dab41a42d8f2c576b6/llvmlite-0.43.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eccce86bba940bae0d8d48ed925f21dbb813519169246e2ab292b5092aba121f", size = 42857276 },
+    { url = "https://files.pythonhosted.org/packages/00/5f/323c4d56e8401c50185fd0e875fcf06b71bf825a863699be1eb10aa2a9cb/llvmlite-0.43.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:df6509e1507ca0760787a199d19439cc887bfd82226f5af746d6977bd9f66844", size = 43871781 },
+    { url = "https://files.pythonhosted.org/packages/c6/94/dea10e263655ce78d777e78d904903faae39d1fc440762be4a9dc46bed49/llvmlite-0.43.0-cp312-cp312-win_amd64.whl", hash = "sha256:7a2872ee80dcf6b5dbdc838763d26554c2a18aa833d31a2635bff16aafefb9c9", size = 28107442 },
+]
+
 [[package]]
 name = "markdown-it-py"
 version = "3.0.0"
@@ -650,51 +782,53 @@ wheels = [
 ]

 [[package]]
-name = "numpy"
-version = "2.2.1"
-source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/f2/a5/fdbf6a7871703df6160b5cf3dd774074b086d278172285c52c2758b76305/numpy-2.2.1.tar.gz", hash = "sha256:45681fd7128c8ad1c379f0ca0776a8b0c6583d2f69889ddac01559dfe4390918", size = 20227662 }
-wheels = [
-    { url = "https://files.pythonhosted.org/packages/59/14/645887347124e101d983e1daf95b48dc3e136bf8525cb4257bf9eab1b768/numpy-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:40f9e544c1c56ba8f1cf7686a8c9b5bb249e665d40d626a23899ba6d5d9e1484", size = 21217379 },
-    { url = "https://files.pythonhosted.org/packages/9f/fd/2279000cf29f58ccfd3778cbf4670dfe3f7ce772df5e198c5abe9e88b7d7/numpy-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f9b57eaa3b0cd8db52049ed0330747b0364e899e8a606a624813452b8203d5f7", size = 14388520 },
-    { url = "https://files.pythonhosted.org/packages/58/b0/034eb5d5ba12d66ab658ff3455a31f20add0b78df8203c6a7451bd1bee21/numpy-2.2.1-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:bc8a37ad5b22c08e2dbd27df2b3ef7e5c0864235805b1e718a235bcb200cf1cb", size = 5389286 },
-    { url = "https://files.pythonhosted.org/packages/5d/69/6f3cccde92e82e7835fdb475c2bf439761cbf8a1daa7c07338e1e132dfec/numpy-2.2.1-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:9036d6365d13b6cbe8f27a0eaf73ddcc070cae584e5ff94bb45e3e9d729feab5", size = 6930345 },
-    { url = "https://files.pythonhosted.org/packages/d1/72/1cd38e91ab563e67f584293fcc6aca855c9ae46dba42e6b5ff4600022899/numpy-2.2.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:51faf345324db860b515d3f364eaa93d0e0551a88d6218a7d61286554d190d73", size = 14335748 },
-    { url = "https://files.pythonhosted.org/packages/f2/d4/f999444e86986f3533e7151c272bd8186c55dda554284def18557e013a2a/numpy-2.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:38efc1e56b73cc9b182fe55e56e63b044dd26a72128fd2fbd502f75555d92591", size = 16391057 },
-    { url = "https://files.pythonhosted.org/packages/99/7b/85cef6a3ae1b19542b7afd97d0b296526b6ef9e3c43ea0c4d9c4404fb2d0/numpy-2.2.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:31b89fa67a8042e96715c68e071a1200c4e172f93b0fbe01a14c0ff3ff820fc8", size = 15556943 },
-    { url = "https://files.pythonhosted.org/packages/69/7e/b83cc884c3508e91af78760f6b17ab46ad649831b1fa35acb3eb26d9e6d2/numpy-2.2.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:4c86e2a209199ead7ee0af65e1d9992d1dce7e1f63c4b9a616500f93820658d0", size = 18180785 },
-    { url = "https://files.pythonhosted.org/packages/b2/9f/eb4a9a38867de059dcd4b6e18d47c3867fbd3795d4c9557bb49278f94087/numpy-2.2.1-cp311-cp311-win32.whl", hash = "sha256:b34d87e8a3090ea626003f87f9392b3929a7bbf4104a05b6667348b6bd4bf1cd", size = 6568983 },
-    { url = "https://files.pythonhosted.org/packages/6d/1e/be3b9f3073da2f8c7fa361fcdc231b548266b0781029fdbaf75eeab997fd/numpy-2.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:360137f8fb1b753c5cde3ac388597ad680eccbbbb3865ab65efea062c4a1fd16", size = 12917260 },
-    { url = "https://files.pythonhosted.org/packages/62/12/b928871c570d4a87ab13d2cc19f8817f17e340d5481621930e76b80ffb7d/numpy-2.2.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:694f9e921a0c8f252980e85bce61ebbd07ed2b7d4fa72d0e4246f2f8aa6642ab", size = 20909861 },
-    { url = "https://files.pythonhosted.org/packages/3d/c3/59df91ae1d8ad7c5e03efd63fd785dec62d96b0fe56d1f9ab600b55009af/numpy-2.2.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:3683a8d166f2692664262fd4900f207791d005fb088d7fdb973cc8d663626faa", size = 14095776 },
-    { url = "https://files.pythonhosted.org/packages/af/4e/8ed5868efc8e601fb69419644a280e9c482b75691466b73bfaab7d86922c/numpy-2.2.1-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:780077d95eafc2ccc3ced969db22377b3864e5b9a0ea5eb347cc93b3ea900315", size = 5126239 },
-    { url = "https://files.pythonhosted.org/packages/1a/74/dd0bbe650d7bc0014b051f092f2de65e34a8155aabb1287698919d124d7f/numpy-2.2.1-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:55ba24ebe208344aa7a00e4482f65742969a039c2acfcb910bc6fcd776eb4355", size = 6659296 },
-    { url = "https://files.pythonhosted.org/packages/7f/11/4ebd7a3f4a655764dc98481f97bd0a662fb340d1001be6050606be13e162/numpy-2.2.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9b1d07b53b78bf84a96898c1bc139ad7f10fda7423f5fd158fd0f47ec5e01ac7", size = 14047121 },
-    { url = "https://files.pythonhosted.org/packages/7f/a7/c1f1d978166eb6b98ad009503e4d93a8c1962d0eb14a885c352ee0276a54/numpy-2.2.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5062dc1a4e32a10dc2b8b13cedd58988261416e811c1dc4dbdea4f57eea61b0d", size = 16096599 },
-    { url = "https://files.pythonhosted.org/packages/3d/6d/0e22afd5fcbb4d8d0091f3f46bf4e8906399c458d4293da23292c0ba5022/numpy-2.2.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:fce4f615f8ca31b2e61aa0eb5865a21e14f5629515c9151850aa936c02a1ee51", size = 15243932 },
-    { url = "https://files.pythonhosted.org/packages/03/39/e4e5832820131ba424092b9610d996b37e5557180f8e2d6aebb05c31ae54/numpy-2.2.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:67d4cda6fa6ffa073b08c8372aa5fa767ceb10c9a0587c707505a6d426f4e046", size = 17861032 },
-    { url = "https://files.pythonhosted.org/packages/5f/8a/3794313acbf5e70df2d5c7d2aba8718676f8d054a05abe59e48417fb2981/numpy-2.2.1-cp312-cp312-win32.whl", hash = "sha256:32cb94448be47c500d2c7a95f93e2f21a01f1fd05dd2beea1ccd049bb6001cd2", size = 6274018 },
-    { url = "https://files.pythonhosted.org/packages/17/c1/c31d3637f2641e25c7a19adf2ae822fdaf4ddd198b05d79a92a9ce7cb63e/numpy-2.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:ba5511d8f31c033a5fcbda22dd5c813630af98c70b2661f2d2c654ae3cdfcfc8", size = 12613843 },
-    { url = "https://files.pythonhosted.org/packages/20/d6/91a26e671c396e0c10e327b763485ee295f5a5a7a48c553f18417e5a0ed5/numpy-2.2.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:f1d09e520217618e76396377c81fba6f290d5f926f50c35f3a5f72b01a0da780", size = 20896464 },
-    { url = "https://files.pythonhosted.org/packages/8c/40/5792ccccd91d45e87d9e00033abc4f6ca8a828467b193f711139ff1f1cd9/numpy-2.2.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:3ecc47cd7f6ea0336042be87d9e7da378e5c7e9b3c8ad0f7c966f714fc10d821", size = 14111350 },
-    { url = "https://files.pythonhosted.org/packages/c0/2a/fb0a27f846cb857cef0c4c92bef89f133a3a1abb4e16bba1c4dace2e9b49/numpy-2.2.1-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:f419290bc8968a46c4933158c91a0012b7a99bb2e465d5ef5293879742f8797e", size = 5111629 },
-    { url = "https://files.pythonhosted.org/packages/eb/e5/8e81bb9d84db88b047baf4e8b681a3e48d6390bc4d4e4453eca428ecbb49/numpy-2.2.1-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:5b6c390bfaef8c45a260554888966618328d30e72173697e5cabe6b285fb2348", size = 6645865 },
-    { url = "https://files.pythonhosted.org/packages/7a/1a/a90ceb191dd2f9e2897c69dde93ccc2d57dd21ce2acbd7b0333e8eea4e8d/numpy-2.2.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:526fc406ab991a340744aad7e25251dd47a6720a685fa3331e5c59fef5282a59", size = 14043508 },
-    { url = "https://files.pythonhosted.org/packages/f1/5a/e572284c86a59dec0871a49cd4e5351e20b9c751399d5f1d79628c0542cb/numpy-2.2.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f74e6fdeb9a265624ec3a3918430205dff1df7e95a230779746a6af78bc615af", size = 16094100 },
-    { url = "https://files.pythonhosted.org/packages/0c/2c/a79d24f364788386d85899dd280a94f30b0950be4b4a545f4fa4ed1d4ca7/numpy-2.2.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:53c09385ff0b72ba79d8715683c1168c12e0b6e84fb0372e97553d1ea91efe51", size = 15239691 },
-    { url = "https://files.pythonhosted.org/packages/cf/79/1e20fd1c9ce5a932111f964b544facc5bb9bde7865f5b42f00b4a6a9192b/numpy-2.2.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:f3eac17d9ec51be534685ba877b6ab5edc3ab7ec95c8f163e5d7b39859524716", size = 17856571 },
-    { url = "https://files.pythonhosted.org/packages/be/5b/cc155e107f75d694f562bdc84a26cc930569f3dfdfbccb3420b626065777/numpy-2.2.1-cp313-cp313-win32.whl", hash = "sha256:9ad014faa93dbb52c80d8f4d3dcf855865c876c9660cb9bd7553843dd03a4b1e", size = 6270841 },
-    { url = "https://files.pythonhosted.org/packages/44/be/0e5cd009d2162e4138d79a5afb3b5d2341f0fe4777ab6e675aa3d4a42e21/numpy-2.2.1-cp313-cp313-win_amd64.whl", hash = "sha256:164a829b6aacf79ca47ba4814b130c4020b202522a93d7bff2202bfb33b61c60", size = 12606618 },
-    { url = "https://files.pythonhosted.org/packages/a8/87/04ddf02dd86fb17c7485a5f87b605c4437966d53de1e3745d450343a6f56/numpy-2.2.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:4dfda918a13cc4f81e9118dea249e192ab167a0bb1966272d5503e39234d694e", size = 20921004 },
-    { url = "https://files.pythonhosted.org/packages/6e/3e/d0e9e32ab14005425d180ef950badf31b862f3839c5b927796648b11f88a/numpy-2.2.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:733585f9f4b62e9b3528dd1070ec4f52b8acf64215b60a845fa13ebd73cd0712", size = 14119910 },
-    { url = "https://files.pythonhosted.org/packages/b5/5b/aa2d1905b04a8fb681e08742bb79a7bddfc160c7ce8e1ff6d5c821be0236/numpy-2.2.1-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:89b16a18e7bba224ce5114db863e7029803c179979e1af6ad6a6b11f70545008", size = 5153612 },
-    { url = "https://files.pythonhosted.org/packages/ce/35/6831808028df0648d9b43c5df7e1051129aa0d562525bacb70019c5f5030/numpy-2.2.1-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:676f4eebf6b2d430300f1f4f4c2461685f8269f94c89698d832cdf9277f30b84", size = 6668401 },
-    { url = "https://files.pythonhosted.org/packages/b1/38/10ef509ad63a5946cc042f98d838daebfe7eaf45b9daaf13df2086b15ff9/numpy-2.2.1-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:27f5cdf9f493b35f7e41e8368e7d7b4bbafaf9660cba53fb21d2cd174ec09631", size = 14014198 },
-    { url = "https://files.pythonhosted.org/packages/df/f8/c80968ae01df23e249ee0a4487fae55a4c0fe2f838dfe9cc907aa8aea0fa/numpy-2.2.1-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c1ad395cf254c4fbb5b2132fee391f361a6e8c1adbd28f2cd8e79308a615fe9d", size = 16076211 },
-    { url = "https://files.pythonhosted.org/packages/09/69/05c169376016a0b614b432967ac46ff14269eaffab80040ec03ae1ae8e2c/numpy-2.2.1-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:08ef779aed40dbc52729d6ffe7dd51df85796a702afbf68a4f4e41fafdc8bda5", size = 15220266 },
-    { url = "https://files.pythonhosted.org/packages/f1/ff/94a4ce67ea909f41cf7ea712aebbe832dc67decad22944a1020bb398a5ee/numpy-2.2.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:26c9c4382b19fcfbbed3238a14abf7ff223890ea1936b8890f058e7ba35e8d71", size = 17852844 },
-    { url = "https://files.pythonhosted.org/packages/46/72/8a5dbce4020dfc595592333ef2fbb0a187d084ca243b67766d29d03e0096/numpy-2.2.1-cp313-cp313t-win32.whl", hash = "sha256:93cf4e045bae74c90ca833cba583c14b62cb4ba2cba0abd2b141ab52548247e2", size = 6326007 },
-    { url = "https://files.pythonhosted.org/packages/7b/9c/4fce9cf39dde2562584e4cfd351a0140240f82c0e3569ce25a250f47037d/numpy-2.2.1-cp313-cp313t-win_amd64.whl", hash = "sha256:bff7d8ec20f5f42607599f9994770fa65d76edca264a87b5e4ea5629bce12268", size = 12693107 },
+name = "numba"
+version = "0.60.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "llvmlite" },
+    { name = "numpy" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/3c/93/2849300a9184775ba274aba6f82f303343669b0592b7bb0849ea713dabb0/numba-0.60.0.tar.gz", hash = "sha256:5df6158e5584eece5fc83294b949fd30b9f1125df7708862205217e068aabf16", size = 2702171 }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/98/ad/df18d492a8f00d29a30db307904b9b296e37507034eedb523876f3a2e13e/numba-0.60.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a17b70fc9e380ee29c42717e8cc0bfaa5556c416d94f9aa96ba13acb41bdece8", size = 2647254 },
+    { url = "https://files.pythonhosted.org/packages/9a/51/a4dc2c01ce7a850b8e56ff6d5381d047a5daea83d12bad08aa071d34b2ee/numba-0.60.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3fb02b344a2a80efa6f677aa5c40cd5dd452e1b35f8d1c2af0dfd9ada9978e4b", size = 2649970 },
+    { url = "https://files.pythonhosted.org/packages/f9/4c/8889ac94c0b33dca80bed11564b8c6d9ea14d7f094e674c58e5c5b05859b/numba-0.60.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5f4fde652ea604ea3c86508a3fb31556a6157b2c76c8b51b1d45eb40c8598703", size = 3412492 },
+    { url = "https://files.pythonhosted.org/packages/57/03/2b4245b05b71c0cee667e6a0b51606dfa7f4157c9093d71c6b208385a611/numba-0.60.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4142d7ac0210cc86432b818338a2bc368dc773a2f5cf1e32ff7c5b378bd63ee8", size = 3705018 },
+    { url = "https://files.pythonhosted.org/packages/79/89/2d924ca60dbf949f18a6fec223a2445f5f428d9a5f97a6b29c2122319015/numba-0.60.0-cp311-cp311-win_amd64.whl", hash = "sha256:cac02c041e9b5bc8cf8f2034ff6f0dbafccd1ae9590dc146b3a02a45e53af4e2", size = 2686920 },
+    { url = "https://files.pythonhosted.org/packages/eb/5c/b5ec752c475e78a6c3676b67c514220dbde2725896bbb0b6ec6ea54b2738/numba-0.60.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:d7da4098db31182fc5ffe4bc42c6f24cd7d1cb8a14b59fd755bfee32e34b8404", size = 2647866 },
+    { url = "https://files.pythonhosted.org/packages/65/42/39559664b2e7c15689a638c2a38b3b74c6e69a04e2b3019b9f7742479188/numba-0.60.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:38d6ea4c1f56417076ecf8fc327c831ae793282e0ff51080c5094cb726507b1c", size = 2650208 },
+    { url = "https://files.pythonhosted.org/packages/67/88/c4459ccc05674ef02119abf2888ccd3e2fed12a323f52255f4982fc95876/numba-0.60.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:62908d29fb6a3229c242e981ca27e32a6e606cc253fc9e8faeb0e48760de241e", size = 3466946 },
+    { url = "https://files.pythonhosted.org/packages/8b/41/ac11cf33524def12aa5bd698226ae196a1185831c05ed29dc0c56eaa308b/numba-0.60.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:0ebaa91538e996f708f1ab30ef4d3ddc344b64b5227b67a57aa74f401bb68b9d", size = 3761463 },
+    { url = "https://files.pythonhosted.org/packages/ca/bd/0fe29fcd1b6a8de479a4ed25c6e56470e467e3611c079d55869ceef2b6d1/numba-0.60.0-cp312-cp312-win_amd64.whl", hash = "sha256:f75262e8fe7fa96db1dca93d53a194a38c46da28b112b8a4aca168f0df860347", size = 2707588 },
+]
+
+[[package]]
+name = "numpy"
+version = "2.0.2"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/a9/75/10dd1f8116a8b796cb2c737b674e02d02e80454bda953fa7e65d8c12b016/numpy-2.0.2.tar.gz", hash = "sha256:883c987dee1880e2a864ab0dc9892292582510604156762362d9326444636e78", size = 18902015 }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/8b/cf/034500fb83041aa0286e0fb16e7c76e5c8b67c0711bb6e9e9737a717d5fe/numpy-2.0.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:49ca4decb342d66018b01932139c0961a8f9ddc7589611158cb3c27cbcf76448", size = 21169137 },
+    { url = "https://files.pythonhosted.org/packages/4a/d9/32de45561811a4b87fbdee23b5797394e3d1504b4a7cf40c10199848893e/numpy-2.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:11a76c372d1d37437857280aa142086476136a8c0f373b2e648ab2c8f18fb195", size = 13703552 },
+    { url = "https://files.pythonhosted.org/packages/c1/ca/2f384720020c7b244d22508cb7ab23d95f179fcfff33c31a6eeba8d6c512/numpy-2.0.2-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:807ec44583fd708a21d4a11d94aedf2f4f3c3719035c76a2bbe1fe8e217bdc57", size = 5298957 },
+    { url = "https://files.pythonhosted.org/packages/0e/78/a3e4f9fb6aa4e6fdca0c5428e8ba039408514388cf62d89651aade838269/numpy-2.0.2-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:8cafab480740e22f8d833acefed5cc87ce276f4ece12fdaa2e8903db2f82897a", size = 6905573 },
+    { url = "https://files.pythonhosted.org/packages/a0/72/cfc3a1beb2caf4efc9d0b38a15fe34025230da27e1c08cc2eb9bfb1c7231/numpy-2.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a15f476a45e6e5a3a79d8a14e62161d27ad897381fecfa4a09ed5322f2085669", size = 13914330 },
+    { url = "https://files.pythonhosted.org/packages/ba/a8/c17acf65a931ce551fee11b72e8de63bf7e8a6f0e21add4c937c83563538/numpy-2.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:13e689d772146140a252c3a28501da66dfecd77490b498b168b501835041f951", size = 19534895 },
+    { url = "https://files.pythonhosted.org/packages/ba/86/8767f3d54f6ae0165749f84648da9dcc8cd78ab65d415494962c86fac80f/numpy-2.0.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:9ea91dfb7c3d1c56a0e55657c0afb38cf1eeae4544c208dc465c3c9f3a7c09f9", size = 19937253 },
+    { url = "https://files.pythonhosted.org/packages/df/87/f76450e6e1c14e5bb1eae6836478b1028e096fd02e85c1c37674606ab752/numpy-2.0.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:c1c9307701fec8f3f7a1e6711f9089c06e6284b3afbbcd259f7791282d660a15", size = 14414074 },
+    { url = "https://files.pythonhosted.org/packages/5c/ca/0f0f328e1e59f73754f06e1adfb909de43726d4f24c6a3f8805f34f2b0fa/numpy-2.0.2-cp311-cp311-win32.whl", hash = "sha256:a392a68bd329eafac5817e5aefeb39038c48b671afd242710b451e76090e81f4", size = 6470640 },
+    { url = "https://files.pythonhosted.org/packages/eb/57/3a3f14d3a759dcf9bf6e9eda905794726b758819df4663f217d658a58695/numpy-2.0.2-cp311-cp311-win_amd64.whl", hash = "sha256:286cd40ce2b7d652a6f22efdfc6d1edf879440e53e76a75955bc0c826c7e64dc", size = 15910230 },
+    { url = "https://files.pythonhosted.org/packages/45/40/2e117be60ec50d98fa08c2f8c48e09b3edea93cfcabd5a9ff6925d54b1c2/numpy-2.0.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:df55d490dea7934f330006d0f81e8551ba6010a5bf035a249ef61a94f21c500b", size = 20895803 },
+    { url = "https://files.pythonhosted.org/packages/46/92/1b8b8dee833f53cef3e0a3f69b2374467789e0bb7399689582314df02651/numpy-2.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8df823f570d9adf0978347d1f926b2a867d5608f434a7cff7f7908c6570dcf5e", size = 13471835 },
+    { url = "https://files.pythonhosted.org/packages/7f/19/e2793bde475f1edaea6945be141aef6c8b4c669b90c90a300a8954d08f0a/numpy-2.0.2-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:9a92ae5c14811e390f3767053ff54eaee3bf84576d99a2456391401323f4ec2c", size = 5038499 },
+    { url = "https://files.pythonhosted.org/packages/e3/ff/ddf6dac2ff0dd50a7327bcdba45cb0264d0e96bb44d33324853f781a8f3c/numpy-2.0.2-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:a842d573724391493a97a62ebbb8e731f8a5dcc5d285dfc99141ca15a3302d0c", size = 6633497 },
+    { url = "https://files.pythonhosted.org/packages/72/21/67f36eac8e2d2cd652a2e69595a54128297cdcb1ff3931cfc87838874bd4/numpy-2.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c05e238064fc0610c840d1cf6a13bf63d7e391717d247f1bf0318172e759e692", size = 13621158 },
+    { url = "https://files.pythonhosted.org/packages/39/68/e9f1126d757653496dbc096cb429014347a36b228f5a991dae2c6b6cfd40/numpy-2.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0123ffdaa88fa4ab64835dcbde75dcdf89c453c922f18dced6e27c90d1d0ec5a", size = 19236173 },
+    { url = "https://files.pythonhosted.org/packages/d1/e9/1f5333281e4ebf483ba1c888b1d61ba7e78d7e910fdd8e6499667041cc35/numpy-2.0.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:96a55f64139912d61de9137f11bf39a55ec8faec288c75a54f93dfd39f7eb40c", size = 19634174 },
+    { url = "https://files.pythonhosted.org/packages/71/af/a469674070c8d8408384e3012e064299f7a2de540738a8e414dcfd639996/numpy-2.0.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:ec9852fb39354b5a45a80bdab5ac02dd02b15f44b3804e9f00c556bf24b4bded", size = 14099701 },
+    { url = "https://files.pythonhosted.org/packages/d0/3d/08ea9f239d0e0e939b6ca52ad403c84a2bce1bde301a8eb4888c1c1543f1/numpy-2.0.2-cp312-cp312-win32.whl", hash = "sha256:671bec6496f83202ed2d3c8fdc486a8fc86942f2e69ff0e986140339a63bcbe5", size = 6174313 },
+    { url = "https://files.pythonhosted.org/packages/b2/b5/4ac39baebf1fdb2e72585c8352c56d063b6126be9fc95bd2bb5ef5770c20/numpy-2.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:cfd41e13fdc257aa5778496b8caa5e856dc4896d4ccf01841daee1d96465467a", size = 15606179 },
 ]

 [[package]]
@@ -946,6 +1080,29 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/ef/7d/500c9ad20238fcfcb4cb9243eede163594d7020ce87bd9610c9e02771876/pip-24.3.1-py3-none-any.whl", hash = "sha256:3790624780082365f47549d032f3770eeb2b1e8bd1f7b2e02dace1afa361b4ed", size = 1822182 },
 ]

+[[package]]
+name = "platformdirs"
+version = "4.3.6"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/13/fc/128cc9cb8f03208bdbf93d3aa862e16d376844a14f9a0ce5cf4507372de4/platformdirs-4.3.6.tar.gz", hash = "sha256:357fb2acbc885b0419afd3ce3ed34564c13c9b95c89360cd9563f73aa5e2b907", size = 21302 }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/3c/a6/bc1012356d8ece4d66dd75c4b9fc6c1f6650ddd5991e421177d9f8f671be/platformdirs-4.3.6-py3-none-any.whl", hash = "sha256:73e575e1408ab8103900836b97580d5307456908a03e92031bab39e4554cc3fb", size = 18439 },
+]
+
+[[package]]
+name = "pooch"
+version = "1.8.2"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "packaging" },
+    { name = "platformdirs" },
+    { name = "requests" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/c6/77/b3d3e00c696c16cf99af81ef7b1f5fe73bd2a307abca41bd7605429fe6e5/pooch-1.8.2.tar.gz", hash = "sha256:76561f0de68a01da4df6af38e9955c4c9d1a5c90da73f7e40276a5728ec83d10", size = 59353 }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/a8/87/77cc11c7a9ea9fd05503def69e3d18605852cd0d4b0d3b8f15bbeb3ef1d1/pooch-1.8.2-py3-none-any.whl", hash = "sha256:3529a57096f7198778a5ceefd5ac3ef0e4d06a6ddaf9fc2d609b806f25302c47", size = 64574 },
+]
+
 [[package]]
 name = "propcache"
 version = "0.2.1"
@@ -1062,6 +1219,15 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/c8/11/fabf6ecabb1fe5b7d96889228ca2a9158c4c3bb732e3b8ee3f7f6d40b703/pyarrow-18.1.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:b76130d835261b38f14fc41fdfb39ad8d672afb84c447126b84d5472244cfaba", size = 40043567 },
 ]

+[[package]]
+name = "pycparser"
+version = "2.22"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/1d/b2/31537cf4b1ca988837256c910a668b553fceb8f069bedc4b1c826024b52c/pycparser-2.22.tar.gz", hash = "sha256:491c8be9c040f5390f5bf44a5b07752bd07f56edf992381b05c701439eec10f6", size = 172736 }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/13/a3/a812df4e2dd5696d1f351d58b8fe16a405b234ad2886a0dab9183fb78109/pycparser-2.22-py3-none-any.whl", hash = "sha256:c3702b6d3dd8c7abc1afa565d7e63d53a1d0bd86cdc24edd75470f4de499cfcc", size = 117552 },
+]
+
 [[package]]
 name = "pydantic"
 version = "2.10.5"
@ -1297,6 +1463,81 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/86/ca/aa489392ec6fb59223ffce825461e1f811a3affd417121a2088be7a5758b/safetensors-0.5.2-cp38-abi3-win_amd64.whl", hash = "sha256:78abdddd03a406646107f973c7843276e7b64e5e32623529dc17f3d94a20f589", size = 303756 },
|
{ url = "https://files.pythonhosted.org/packages/86/ca/aa489392ec6fb59223ffce825461e1f811a3affd417121a2088be7a5758b/safetensors-0.5.2-cp38-abi3-win_amd64.whl", hash = "sha256:78abdddd03a406646107f973c7843276e7b64e5e32623529dc17f3d94a20f589", size = 303756 },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "scikit-learn"
|
||||||
|
version = "1.6.1"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
dependencies = [
|
||||||
|
{ name = "joblib" },
|
||||||
|
{ name = "numpy" },
|
||||||
|
{ name = "scipy" },
|
||||||
|
{ name = "threadpoolctl" },
|
||||||
|
]
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/9e/a5/4ae3b3a0755f7b35a280ac90b28817d1f380318973cff14075ab41ef50d9/scikit_learn-1.6.1.tar.gz", hash = "sha256:b4fc2525eca2c69a59260f583c56a7557c6ccdf8deafdba6e060f94c1c59738e", size = 7068312 }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/6c/2a/e291c29670795406a824567d1dfc91db7b699799a002fdaa452bceea8f6e/scikit_learn-1.6.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:72abc587c75234935e97d09aa4913a82f7b03ee0b74111dcc2881cba3c5a7b33", size = 12102620 },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/25/92/ee1d7a00bb6b8c55755d4984fd82608603a3cc59959245068ce32e7fb808/scikit_learn-1.6.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:b3b00cdc8f1317b5f33191df1386c0befd16625f49d979fe77a8d44cae82410d", size = 11116234 },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/30/cd/ed4399485ef364bb25f388ab438e3724e60dc218c547a407b6e90ccccaef/scikit_learn-1.6.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dc4765af3386811c3ca21638f63b9cf5ecf66261cc4815c1db3f1e7dc7b79db2", size = 12592155 },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/a8/f3/62fc9a5a659bb58a03cdd7e258956a5824bdc9b4bb3c5d932f55880be569/scikit_learn-1.6.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:25fc636bdaf1cc2f4a124a116312d837148b5e10872147bdaf4887926b8c03d8", size = 13497069 },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/a1/a6/c5b78606743a1f28eae8f11973de6613a5ee87366796583fb74c67d54939/scikit_learn-1.6.1-cp311-cp311-win_amd64.whl", hash = "sha256:fa909b1a36e000a03c382aade0bd2063fd5680ff8b8e501660c0f59f021a6415", size = 11139809 },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/0a/18/c797c9b8c10380d05616db3bfb48e2a3358c767affd0857d56c2eb501caa/scikit_learn-1.6.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:926f207c804104677af4857b2c609940b743d04c4c35ce0ddc8ff4f053cddc1b", size = 12104516 },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/c4/b7/2e35f8e289ab70108f8cbb2e7a2208f0575dc704749721286519dcf35f6f/scikit_learn-1.6.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:2c2cae262064e6a9b77eee1c8e768fc46aa0b8338c6a8297b9b6759720ec0ff2", size = 11167837 },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/a4/f6/ff7beaeb644bcad72bcfd5a03ff36d32ee4e53a8b29a639f11bcb65d06cd/scikit_learn-1.6.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1061b7c028a8663fb9a1a1baf9317b64a257fcb036dae5c8752b2abef31d136f", size = 12253728 },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/29/7a/8bce8968883e9465de20be15542f4c7e221952441727c4dad24d534c6d99/scikit_learn-1.6.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2e69fab4ebfc9c9b580a7a80111b43d214ab06250f8a7ef590a4edf72464dd86", size = 13147700 },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/62/27/585859e72e117fe861c2079bcba35591a84f801e21bc1ab85bce6ce60305/scikit_learn-1.6.1-cp312-cp312-win_amd64.whl", hash = "sha256:70b1d7e85b1c96383f872a519b3375f92f14731e279a7b4c6cfd650cf5dffc52", size = 11110613 },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/2e/59/8eb1872ca87009bdcdb7f3cdc679ad557b992c12f4b61f9250659e592c63/scikit_learn-1.6.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:2ffa1e9e25b3d93990e74a4be2c2fc61ee5af85811562f1288d5d055880c4322", size = 12010001 },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/9d/05/f2fc4effc5b32e525408524c982c468c29d22f828834f0625c5ef3d601be/scikit_learn-1.6.1-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:dc5cf3d68c5a20ad6d571584c0750ec641cc46aeef1c1507be51300e6003a7e1", size = 11096360 },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/c8/e4/4195d52cf4f113573fb8ebc44ed5a81bd511a92c0228889125fac2f4c3d1/scikit_learn-1.6.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c06beb2e839ecc641366000ca84f3cf6fa9faa1777e29cf0c04be6e4d096a348", size = 12209004 },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/94/be/47e16cdd1e7fcf97d95b3cb08bde1abb13e627861af427a3651fcb80b517/scikit_learn-1.6.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e8ca8cb270fee8f1f76fa9bfd5c3507d60c6438bbee5687f81042e2bb98e5a97", size = 13171776 },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/34/b0/ca92b90859070a1487827dbc672f998da95ce83edce1270fc23f96f1f61a/scikit_learn-1.6.1-cp313-cp313-win_amd64.whl", hash = "sha256:7a1c43c8ec9fde528d664d947dc4c0789be4077a3647f232869f41d9bf50e0fb", size = 11071865 },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/12/ae/993b0fb24a356e71e9a894e42b8a9eec528d4c70217353a1cd7a48bc25d4/scikit_learn-1.6.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:a17c1dea1d56dcda2fac315712f3651a1fea86565b64b48fa1bc090249cbf236", size = 11955804 },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/d6/54/32fa2ee591af44507eac86406fa6bba968d1eb22831494470d0a2e4a1eb1/scikit_learn-1.6.1-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:6a7aa5f9908f0f28f4edaa6963c0a6183f1911e63a69aa03782f0d924c830a35", size = 11100530 },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/3f/58/55856da1adec655bdce77b502e94a267bf40a8c0b89f8622837f89503b5a/scikit_learn-1.6.1-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0650e730afb87402baa88afbf31c07b84c98272622aaba002559b614600ca691", size = 12433852 },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/ff/4f/c83853af13901a574f8f13b645467285a48940f185b690936bb700a50863/scikit_learn-1.6.1-cp313-cp313t-win_amd64.whl", hash = "sha256:3f59fe08dc03ea158605170eb52b22a105f238a5d512c4470ddeca71feae8e5f", size = 11337256 },
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "scipy"
|
||||||
|
version = "1.15.1"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
dependencies = [
|
||||||
|
{ name = "numpy" },
|
||||||
|
]
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/76/c6/8eb0654ba0c7d0bb1bf67bf8fbace101a8e4f250f7722371105e8b6f68fc/scipy-1.15.1.tar.gz", hash = "sha256:033a75ddad1463970c96a88063a1df87ccfddd526437136b6ee81ff0312ebdf6", size = 59407493 }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/8e/2e/7b71312da9c2dabff53e7c9a9d08231bc34d9d8fdabe88a6f1155b44591c/scipy-1.15.1-cp311-cp311-macosx_10_13_x86_64.whl", hash = "sha256:5bd8d27d44e2c13d0c1124e6a556454f52cd3f704742985f6b09e75e163d20d2", size = 41424362 },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/81/8c/ab85f1aa1cc200c796532a385b6ebf6a81089747adc1da7482a062acc46c/scipy-1.15.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:be3deeb32844c27599347faa077b359584ba96664c5c79d71a354b80a0ad0ce0", size = 32535910 },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/3b/9c/6f4b787058daa8d8da21ddff881b4320e28de4704a65ec147adb50cb2230/scipy-1.15.1-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:5eb0ca35d4b08e95da99a9f9c400dc9f6c21c424298a0ba876fdc69c7afacedf", size = 24809398 },
{ url = "https://files.pythonhosted.org/packages/16/2b/949460a796df75fc7a1ee1becea202cf072edbe325ebe29f6d2029947aa7/scipy-1.15.1-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:74bb864ff7640dea310a1377d8567dc2cb7599c26a79ca852fc184cc851954ac", size = 27918045 },
{ url = "https://files.pythonhosted.org/packages/5f/36/67fe249dd7ccfcd2a38b25a640e3af7e59d9169c802478b6035ba91dfd6d/scipy-1.15.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:667f950bf8b7c3a23b4199db24cb9bf7512e27e86d0e3813f015b74ec2c6e3df", size = 38332074 },
{ url = "https://files.pythonhosted.org/packages/fc/da/452e1119e6f720df3feb588cce3c42c5e3d628d4bfd4aec097bd30b7de0c/scipy-1.15.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:395be70220d1189756068b3173853029a013d8c8dd5fd3d1361d505b2aa58fa7", size = 40588469 },
{ url = "https://files.pythonhosted.org/packages/7f/71/5f94aceeac99a4941478af94fe9f459c6752d497035b6b0761a700f5f9ff/scipy-1.15.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:ce3a000cd28b4430426db2ca44d96636f701ed12e2b3ca1f2b1dd7abdd84b39a", size = 42965214 },
{ url = "https://files.pythonhosted.org/packages/af/25/caa430865749d504271757cafd24066d596217e83326155993980bc22f97/scipy-1.15.1-cp311-cp311-win_amd64.whl", hash = "sha256:3fe1d95944f9cf6ba77aa28b82dd6bb2a5b52f2026beb39ecf05304b8392864b", size = 43896034 },
{ url = "https://files.pythonhosted.org/packages/d8/6e/a9c42d0d39e09ed7fd203d0ac17adfea759cba61ab457671fe66e523dbec/scipy-1.15.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:c09aa9d90f3500ea4c9b393ee96f96b0ccb27f2f350d09a47f533293c78ea776", size = 41478318 },
{ url = "https://files.pythonhosted.org/packages/04/ee/e3e535c81828618878a7433992fecc92fa4df79393f31a8fea1d05615091/scipy-1.15.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:0ac102ce99934b162914b1e4a6b94ca7da0f4058b6d6fd65b0cef330c0f3346f", size = 32596696 },
{ url = "https://files.pythonhosted.org/packages/c4/5e/b1b0124be8e76f87115f16b8915003eec4b7060298117715baf13f51942c/scipy-1.15.1-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:09c52320c42d7f5c7748b69e9f0389266fd4f82cf34c38485c14ee976cb8cb04", size = 24870366 },
{ url = "https://files.pythonhosted.org/packages/14/36/c00cb73eefda85946172c27913ab995c6ad4eee00fa4f007572e8c50cd51/scipy-1.15.1-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:cdde8414154054763b42b74fe8ce89d7f3d17a7ac5dd77204f0e142cdc9239e9", size = 28007461 },
{ url = "https://files.pythonhosted.org/packages/68/94/aff5c51b3799349a9d1e67a056772a0f8a47db371e83b498d43467806557/scipy-1.15.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4c9d8fc81d6a3b6844235e6fd175ee1d4c060163905a2becce8e74cb0d7554ce", size = 38068174 },
{ url = "https://files.pythonhosted.org/packages/b0/3c/0de11ca154e24a57b579fb648151d901326d3102115bc4f9a7a86526ce54/scipy-1.15.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0fb57b30f0017d4afa5fe5f5b150b8f807618819287c21cbe51130de7ccdaed2", size = 40249869 },
{ url = "https://files.pythonhosted.org/packages/15/09/472e8d0a6b33199d1bb95e49bedcabc0976c3724edd9b0ef7602ccacf41e/scipy-1.15.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:491d57fe89927fa1aafbe260f4cfa5ffa20ab9f1435025045a5315006a91b8f5", size = 42629068 },
{ url = "https://files.pythonhosted.org/packages/ff/ba/31c7a8131152822b3a2cdeba76398ffb404d81d640de98287d236da90c49/scipy-1.15.1-cp312-cp312-win_amd64.whl", hash = "sha256:900f3fa3db87257510f011c292a5779eb627043dd89731b9c461cd16ef76ab3d", size = 43621992 },
{ url = "https://files.pythonhosted.org/packages/2b/bf/dd68965a4c5138a630eeed0baec9ae96e5d598887835bdde96cdd2fe4780/scipy-1.15.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:100193bb72fbff37dbd0bf14322314fc7cbe08b7ff3137f11a34d06dc0ee6b85", size = 41441136 },
{ url = "https://files.pythonhosted.org/packages/ef/5e/4928581312922d7e4d416d74c416a660addec4dd5ea185401df2269ba5a0/scipy-1.15.1-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:2114a08daec64980e4b4cbdf5bee90935af66d750146b1d2feb0d3ac30613692", size = 32533699 },
{ url = "https://files.pythonhosted.org/packages/32/90/03f99c43041852837686898c66767787cd41c5843d7a1509c39ffef683e9/scipy-1.15.1-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:6b3e71893c6687fc5e29208d518900c24ea372a862854c9888368c0b267387ab", size = 24807289 },
{ url = "https://files.pythonhosted.org/packages/9d/52/bfe82b42ae112eaba1af2f3e556275b8727d55ac6e4932e7aef337a9d9d4/scipy-1.15.1-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:837299eec3d19b7e042923448d17d95a86e43941104d33f00da7e31a0f715d3c", size = 27929844 },
{ url = "https://files.pythonhosted.org/packages/f6/77/54ff610bad600462c313326acdb035783accc6a3d5f566d22757ad297564/scipy-1.15.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:82add84e8a9fb12af5c2c1a3a3f1cb51849d27a580cb9e6bd66226195142be6e", size = 38031272 },
{ url = "https://files.pythonhosted.org/packages/f1/26/98585cbf04c7cf503d7eb0a1966df8a268154b5d923c5fe0c1ed13154c49/scipy-1.15.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:070d10654f0cb6abd295bc96c12656f948e623ec5f9a4eab0ddb1466c000716e", size = 40210217 },
{ url = "https://files.pythonhosted.org/packages/fd/3f/3d2285eb6fece8bc5dbb2f9f94d61157d61d155e854fd5fea825b8218f12/scipy-1.15.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:55cc79ce4085c702ac31e49b1e69b27ef41111f22beafb9b49fea67142b696c4", size = 42587785 },
{ url = "https://files.pythonhosted.org/packages/48/7d/5b5251984bf0160d6533695a74a5fddb1fa36edd6f26ffa8c871fbd4782a/scipy-1.15.1-cp313-cp313-win_amd64.whl", hash = "sha256:c352c1b6d7cac452534517e022f8f7b8d139cd9f27e6fbd9f3cbd0bfd39f5bef", size = 43640439 },
{ url = "https://files.pythonhosted.org/packages/e7/b8/0e092f592d280496de52e152582030f8a270b194f87f890e1a97c5599b81/scipy-1.15.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:0458839c9f873062db69a03de9a9765ae2e694352c76a16be44f93ea45c28d2b", size = 41619862 },
{ url = "https://files.pythonhosted.org/packages/f6/19/0b6e1173aba4db9e0b7aa27fe45019857fb90d6904038b83927cbe0a6c1d/scipy-1.15.1-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:af0b61c1de46d0565b4b39c6417373304c1d4f5220004058bdad3061c9fa8a95", size = 32610387 },
{ url = "https://files.pythonhosted.org/packages/e7/02/754aae3bd1fa0f2479ade3cfdf1732ecd6b05853f63eee6066a32684563a/scipy-1.15.1-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:71ba9a76c2390eca6e359be81a3e879614af3a71dfdabb96d1d7ab33da6f2364", size = 24883814 },
{ url = "https://files.pythonhosted.org/packages/1f/ac/d7906201604a2ea3b143bb0de51b3966f66441ba50b7dc182c4505b3edf9/scipy-1.15.1-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:14eaa373c89eaf553be73c3affb11ec6c37493b7eaaf31cf9ac5dffae700c2e0", size = 27944865 },
{ url = "https://files.pythonhosted.org/packages/84/9d/8f539002b5e203723af6a6f513a45e0a7671e9dabeedb08f417ac17e4edc/scipy-1.15.1-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f735bc41bd1c792c96bc426dece66c8723283695f02df61dcc4d0a707a42fc54", size = 39883261 },
{ url = "https://files.pythonhosted.org/packages/97/c0/62fd3bab828bcccc9b864c5997645a3b86372a35941cdaf677565c25c98d/scipy-1.15.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:2722a021a7929d21168830790202a75dbb20b468a8133c74a2c0230c72626b6c", size = 42093299 },
{ url = "https://files.pythonhosted.org/packages/e4/1f/5d46a8d94e9f6d2c913cbb109e57e7eed914de38ea99e2c4d69a9fc93140/scipy-1.15.1-cp313-cp313t-win_amd64.whl", hash = "sha256:bc7136626261ac1ed988dca56cfc4ab5180f75e0ee52e58f1e6aa74b5f3eacd5", size = 43181730 },
]

[[package]]
name = "setuptools"
version = "75.8.0"
@ -1315,6 +1556,46 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/b7/ce/149a00dd41f10bc29e5921b496af8b574d8413afcd5e30dfa0ed46c2cc5e/six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274", size = 11050 },
]

[[package]]
name = "soundfile"
version = "0.13.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "cffi" },
{ name = "numpy" },
]
sdist = { url = "https://files.pythonhosted.org/packages/76/7c/4054082696ee09397eb88947f91cd2429bb7d20db0d4e82bfeb64c234589/soundfile-0.13.0.tar.gz", hash = "sha256:e83399d8bde7d73b117c33d6a1ec8571318338f89ce72f4c3d457e9768798355", size = 45971 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/d3/35/79f0e4934eba4807dfcdc5589de4733186e1d89ecc313708f1ac0e255c45/soundfile-0.13.0-py2.py3-none-any.whl", hash = "sha256:6a732002843d267267de52367cbdb16dfea6c1f4e3de5e4ddfb2bd0b9b65dddf", size = 25706 },
{ url = "https://files.pythonhosted.org/packages/8d/d4/d21a409d4002b585e755fd0301dd9bc59c243e8be8013cade60506c8730c/soundfile-0.13.0-py2.py3-none-macosx_10_9_x86_64.whl", hash = "sha256:c28024e59ebf2e5b12f5d77a16eb3ef1527e7580d7bbebfb5645253368391385", size = 1142204 },
{ url = "https://files.pythonhosted.org/packages/24/d9/be447d3eb12d0e0c432f3c3d9c6baa6f27be475bcfd82d4e9c4f8f7e1838/soundfile-0.13.0-py2.py3-none-macosx_11_0_arm64.whl", hash = "sha256:af3d06019040dd3db0e0b173baca76db9f7d59b5750a347892f2bf7763a8083c", size = 1101360 },
{ url = "https://files.pythonhosted.org/packages/99/c8/cf31904a8c17f23fa2d31852b0cb509b5ea017783d6cbdf62d7c4ac0e631/soundfile-0.13.0-py2.py3-none-manylinux_2_28_aarch64.whl", hash = "sha256:3adcf4b70d2c872f6cb5a1b3eeb916b90233860a7d13652652753c4e628ac9f5", size = 1235683 },
{ url = "https://files.pythonhosted.org/packages/21/2e/93c9283e6d8a95900e390d888f9f3e2ae3ce274d3a11bd79117eeb8f0879/soundfile-0.13.0-py2.py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:5f2c7d389d6cde9af49ec73f5a32a06f7397a2add3808ceb94ebff4a116e3446", size = 1313600 },
{ url = "https://files.pythonhosted.org/packages/ff/36/5b0758658194eaa4f90c126ebdbd56f6c35c9f15e1640afaaa584a476b4b/soundfile-0.13.0-py2.py3-none-win32.whl", hash = "sha256:217fd97b9a515b6b92d817c917bd7c3bc838e4ffe9b68c2a0659b70b9af1a5dd", size = 899836 },
{ url = "https://files.pythonhosted.org/packages/03/69/eda7a076709ada14a11604347ea3b80b3888101c93bbc739071d2f553df5/soundfile-0.13.0-py2.py3-none-win_amd64.whl", hash = "sha256:9fd67b1867fb7ce4a1bf1fd6600dfe9bf2af26b7ae3671719196c1d5632fa462", size = 1019114 },
]

[[package]]
name = "soxr"
version = "0.5.0.post1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "numpy" },
]
sdist = { url = "https://files.pythonhosted.org/packages/02/c0/4429bf9b3be10e749149e286aa5c53775399ec62891c6b970456c6dca325/soxr-0.5.0.post1.tar.gz", hash = "sha256:7092b9f3e8a416044e1fa138c8172520757179763b85dc53aa9504f4813cff73", size = 170853 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/29/28/dc62dae260a77603e8257e9b79078baa2ca4c0b4edc6f9f82c9113d6ef18/soxr-0.5.0.post1-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:6fb77b626773a966e3d8f6cb24f6f74b5327fa5dc90f1ff492450e9cdc03a378", size = 203648 },
{ url = "https://files.pythonhosted.org/packages/0e/48/3e88329a695f6e0e38a3b171fff819d75d7cc055dae1ec5d5074f34d61e3/soxr-0.5.0.post1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:39e0f791ba178d69cd676485dbee37e75a34f20daa478d90341ecb7f6d9d690f", size = 159933 },
{ url = "https://files.pythonhosted.org/packages/9c/a5/6b439164be6871520f3d199554568a7656e96a867adbbe5bac179caf5776/soxr-0.5.0.post1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4f0b558f445ba4b64dbcb37b5f803052eee7d93b1dbbbb97b3ec1787cb5a28eb", size = 221010 },
{ url = "https://files.pythonhosted.org/packages/9f/e5/400e3bf7f29971abad85cb877e290060e5ec61fccd2fa319e3d85709c1be/soxr-0.5.0.post1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ca6903671808e0a6078b0d146bb7a2952b118dfba44008b2aa60f221938ba829", size = 252471 },
{ url = "https://files.pythonhosted.org/packages/86/94/6a7e91bea7e6ca193ee429869b8f18548cd79759e064021ecb5756024c7c/soxr-0.5.0.post1-cp311-cp311-win_amd64.whl", hash = "sha256:c4d8d5283ed6f5efead0df2c05ae82c169cfdfcf5a82999c2d629c78b33775e8", size = 166723 },
{ url = "https://files.pythonhosted.org/packages/5d/e3/d422d279e51e6932e7b64f1170a4f61a7ee768e0f84c9233a5b62cd2c832/soxr-0.5.0.post1-cp312-abi3-macosx_10_14_x86_64.whl", hash = "sha256:fef509466c9c25f65eae0ce1e4b9ac9705d22c6038c914160ddaf459589c6e31", size = 199993 },
{ url = "https://files.pythonhosted.org/packages/20/f1/88adaca3c52e03bcb66b63d295df2e2d35bf355d19598c6ce84b20be7fca/soxr-0.5.0.post1-cp312-abi3-macosx_11_0_arm64.whl", hash = "sha256:4704ba6b13a3f1e41d12acf192878384c1c31f71ce606829c64abdf64a8d7d32", size = 156373 },
{ url = "https://files.pythonhosted.org/packages/b8/38/bad15a9e615215c8219652ca554b601663ac3b7ac82a284aca53ec2ff48c/soxr-0.5.0.post1-cp312-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bd052a66471a7335b22a6208601a9d0df7b46b8d087dce4ff6e13eed6a33a2a1", size = 216564 },
{ url = "https://files.pythonhosted.org/packages/e1/1a/569ea0420a0c4801c2c8dd40d8d544989522f6014d51def689125f3f2935/soxr-0.5.0.post1-cp312-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a3f16810dd649ab1f433991d2a9661e9e6a116c2b4101039b53b3c3e90a094fc", size = 248455 },
{ url = "https://files.pythonhosted.org/packages/bc/10/440f1ba3d4955e0dc740bbe4ce8968c254a3d644d013eb75eea729becdb8/soxr-0.5.0.post1-cp312-abi3-win_amd64.whl", hash = "sha256:b1be9fee90afb38546bdbd7bde714d1d9a8c5a45137f97478a83b65e7f3146f6", size = 164937 },
]

[[package]]
name = "sympy"
version = "1.13.1"
@ -1327,6 +1608,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/b2/fe/81695a1aa331a842b582453b605175f419fe8540355886031328089d840a/sympy-1.13.1-py3-none-any.whl", hash = "sha256:db36cdc64bf61b9b24578b6f7bab1ecdd2452cf008f34faa33776680c26d66f8", size = 6189177 },
]

[[package]]
name = "threadpoolctl"
version = "3.5.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/bd/55/b5148dcbf72f5cde221f8bfe3b6a540da7aa1842f6b491ad979a6c8b84af/threadpoolctl-3.5.0.tar.gz", hash = "sha256:082433502dd922bf738de0d8bcc4fdcbf0979ff44c42bd40f5af8a282f6fa107", size = 41936 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/4b/2c/ffbf7a134b9ab11a67b0cf0726453cedd9c5043a4fe7a35d1cefa9a1bcfb/threadpoolctl-3.5.0-py3-none-any.whl", hash = "sha256:56c1e26c150397e58c4926da8eeee87533b1e32bef131bd4bf6a2f45f3185467", size = 18414 },
]

[[package]]
name = "tokenizers"
version = "0.21.0"