| | import os |
| | import csv |
| | import json |
| | import shutil |
| | from typing import Optional, List, Any |
| | from huggingface_hub import login |
| | from transformers import AutoTokenizer, AutoModelForCausalLM |
| | from tools import DEFAULT_SYSTEM_MSG |
| | |
| |
|
def authenticate_hf(token: Optional[str]) -> None:
    """Authenticate against the Hugging Face Hub when a token is provided.

    A falsy *token* (``None`` or ``""``) skips login entirely, which is the
    expected state when the HF_TOKEN environment variable is unset.
    """
    if not token:
        print("Skipping Hugging Face login: HF_TOKEN not set.")
        return
    print("Logging into Hugging Face Hub...")
    login(token=token)
| |
|
def load_model_and_tokenizer(model_name: str):
    """Load a causal LM and its tokenizer from the Hub or a local path.

    If *model_name* looks like a relative local path (starts with "..") but
    does not exist on disk, falls back to the default hub model
    "google/gemma-2b-it" instead of failing.

    Args:
        model_name: Hub repo id or local directory path.

    Returns:
        Tuple of ``(model, tokenizer)``.

    Raises:
        Exception: re-raises whatever ``from_pretrained`` raised, after
            logging which model target failed.
    """
    print(f"Loading Transformer model: {model_name}")
    # Resolve the target outside the try block so the except handler can
    # always reference it, and so only the fallible calls sit in the try.
    target_model = model_name
    if model_name.startswith("..") and not os.path.exists(model_name):
        print(f"Warning: Local path {model_name} not found. Falling back to default hub model.")
        target_model = "google/gemma-2b-it"

    try:
        tokenizer = AutoTokenizer.from_pretrained(target_model)
        model = AutoModelForCausalLM.from_pretrained(target_model)
    except Exception as e:
        print(f"Error loading Transformer model {target_model}: {e}")
        # Bare raise preserves the original traceback (raise e adds a frame).
        raise
    print("Model loaded successfully.")
    return model, tokenizer
| |
|
| | |
def create_conversation_format(sample, tools_list):
    """Turn one dataset row into the conversational structure used for SFT.

    Malformed or missing JSON in ``sample["tool_arguments"]`` degrades to an
    empty argument dict rather than failing the whole row.
    """
    try:
        parsed_args = json.loads(sample["tool_arguments"])
    except (json.JSONDecodeError, TypeError):
        parsed_args = {}

    tool_call = {
        "type": "function",
        "function": {"name": sample["tool_name"], "arguments": parsed_args},
    }
    return {
        "messages": [
            {"role": "developer", "content": DEFAULT_SYSTEM_MSG},
            {"role": "user", "content": sample["user_content"]},
            {"role": "assistant", "tool_calls": [tool_call]},
        ],
        "tools": tools_list,
    }
| |
|
def parse_csv_dataset(file_path: str) -> List[List[str]]:
    """Parse an uploaded CSV file into rows of three trimmed cells.

    A first row whose leading cell mentions "user_content" is treated as a
    header and skipped; any other first row causes a rewind so it is parsed
    as data. Rows with fewer than three columns are dropped.
    """
    if not file_path:
        return []

    rows: List[List[str]] = []
    with open(file_path, 'r', newline='', encoding='utf-8') as handle:
        reader = csv.reader(handle)
        try:
            first = next(reader)
        except StopIteration:
            # Completely empty file.
            return rows
        if not (first and "user_content" in first[0].lower()):
            # First row was data, not a header — rewind so it is re-read.
            handle.seek(0)
        rows.extend(
            [cell.strip() for cell in row[:3]]
            for row in reader
            if len(row) >= 3
        )
    return rows
| |
|
def zip_directory(source_dir: str, output_name_base: str) -> str:
    """Zip *source_dir* into ``<output_name_base>.zip`` and return its path."""
    archive_path = shutil.make_archive(output_name_base, 'zip', root_dir=source_dir)
    return archive_path
| |
|