| | import threading |
| | import torch |
| | import time |
| | import json |
| | import queue |
| | import uuid |
| | import matplotlib.pyplot as plt |
| | from functools import partial |
| | from typing import Generator, Optional, List, Dict, Any, Tuple |
| | from datasets import Dataset, load_dataset |
| | from trl import SFTConfig, SFTTrainer |
| | from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl |
| | from huggingface_hub import HfApi, model_info, metadata_update |
| |
|
| | from config import AppConfig |
| | from tools import DEFAULT_TOOLS |
| | from utils import ( |
| | authenticate_hf, |
| | load_model_and_tokenizer, |
| | create_conversation_format, |
| | parse_csv_dataset, |
| | zip_directory |
| | ) |
| |
|
class AbortCallback(TrainerCallback):
    """Trainer callback that halts training once an external stop event fires."""

    def __init__(self, stop_event: threading.Event):
        self.stop_event = stop_event

    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        # Checked after every optimizer step; flipping the flag asks the
        # Trainer to wind down gracefully rather than killing the thread.
        if not self.stop_event.is_set():
            return
        control.should_training_stop = True
| |
|
class LogStreamingCallback(TrainerCallback):
    """Forwards trainer log events to a queue as (display string, raw payload) tuples."""

    def __init__(self, log_queue: queue.Queue):
        self.log_queue = log_queue

    def _get_string(self, value):
        # Floats get 4 decimals; everything else falls back to str().
        return f"{value:.4f}" if isinstance(value, float) else str(value)

    def on_log(self, args, state, control, logs=None, **kwargs):
        if not logs:
            return

        # (log key, display label) in the order they should appear.
        metric_labels = (
            ("loss", "Loss"),
            ("eval_loss", "Eval Loss"),
            ("learning_rate", "LR"),
            ("epoch", "Epoch"),
        )
        parts = [f"π [Step {state.global_step}]"]

        for key, label in metric_labels:
            if key not in logs:
                continue
            val = logs[key]
            if isinstance(val, (float, int)):
                # Tiny magnitudes (e.g. a decayed LR) switch to scientific notation.
                rendered = f"{val:.4f}" if val > 1e-4 else f"{val:.2e}"
            else:
                rendered = str(val)
            parts.append(f"{label}: {rendered}")

        payload = logs.copy()
        payload['step'] = state.global_step

        self.log_queue.put((" | ".join(parts), payload))
| |
|
class FunctionGemmaEngine:
    """Orchestrates model loading, dataset preparation, fine-tuning,
    evaluation, and Hub upload for one UI session."""

    def __init__(self, config: AppConfig):
        self.config = config

        # Per-session artifact directory so concurrent sessions never collide.
        self.session_id = str(uuid.uuid4())[:8]
        self.output_dir = self.config.ARTIFACTS_DIR.joinpath(f"session_{self.session_id}")
        self.output_dir.mkdir(parents=True, exist_ok=True)

        self.model = None
        self.tokenizer = None
        self.loaded_model_name = None          # name of the checkpoint currently in memory
        self.imported_dataset = []             # rows parsed from a user-supplied CSV
        self.stop_event = threading.Event()    # cooperative abort flag for train/eval loops
        self.current_tools = DEFAULT_TOOLS     # active tool schema (list of tool definitions)
        self.has_model_tuned = False

        # Best-effort initial load: a failure here leaves the engine usable;
        # refresh_model() can be retried from the UI later.
        authenticate_hf(self.config.HF_TOKEN)
        try:
            self.refresh_model()
        except Exception as e:
            print(f"Initial load warning: {e}")
|
| | |
| | def get_tools_json(self) -> str: |
| | return json.dumps(self.current_tools, indent=2) |
| |
|
| | def update_tools(self, json_str: str) -> str: |
| | try: |
| | new_tools = json.loads(json_str) |
| | if not isinstance(new_tools, list): |
| | return "Error: Schema must be a list of tool definitions." |
| | self.current_tools = new_tools |
| | return "β
Tool Schema Updated successfully." |
| | except json.JSONDecodeError as e: |
| | return f"β JSON Error: {e}" |
| | except Exception as e: |
| | return f"β Error: {e}" |
| |
|
| | |
| | |
| | def _load_model_weights(self): |
| | print(f"[{self.session_id}] Loading model: {self.config.MODEL_NAME}...") |
| | self.model, self.tokenizer = load_model_and_tokenizer(self.config.MODEL_NAME) |
| | self.loaded_model_name = self.config.MODEL_NAME |
| |
|
| | def refresh_model(self) -> str: |
| | self.has_model_tuned = False |
| | try: |
| | self._load_model_weights() |
| | return f"Model loaded: {self.loaded_model_name}\nData cleared.\nReady (Session {self.session_id})." |
| | except Exception as e: |
| | self.model = None |
| | self.tokenizer = None |
| | self.loaded_model_name = None |
| | return f"CRITICAL ERROR: Model failed to load. {e}" |
| |
|
| | def load_csv(self, file_path: str) -> str: |
| | try: |
| | new_data = parse_csv_dataset(file_path) |
| | if not new_data: |
| | return "Error: File empty or format invalid." |
| | self.imported_dataset = new_data |
| | return f"Successfully imported {len(new_data)} samples." |
| | except Exception as e: |
| | return f"Import failed: {e}" |
| |
|
    def trigger_stop(self):
        """Signal any running training/evaluation loop to abort cooperatively."""
        self.stop_event.set()
| |
|
| | def _ensure_model_consistency(self) -> Generator[str, None, bool]: |
| | """Checks if the requested model matches the loaded one. Reloads if necessary.""" |
| | if self.config.MODEL_NAME != self.loaded_model_name: |
| | yield f"π Model changed. Switching from '{self.loaded_model_name}' to '{self.config.MODEL_NAME}'...\n" |
| | try: |
| | self._load_model_weights() |
| | yield "β
Model reloaded successfully.\n" |
| | return True |
| | except Exception as e: |
| | yield f"β Failed to load model '{self.config.MODEL_NAME}': {e}\n" |
| | return False |
| | if self.model is None: |
| | yield "β Error: No model loaded.\n" |
| | return False |
| | return True |
| |
|
| | |
| | |
    def run_evaluation(self, test_size: float, shuffle_data: bool) -> Generator[str, None, None]:
        """Stream evaluation progress text for the test split of the current dataset.

        Each yield is the *full* accumulated log, so the UI can replace its
        textbox contents wholesale on every update.
        """
        self.stop_event.clear()
        output_buffer = ""

        try:
            # Drain the consistency-check generator; its return value (captured
            # from StopIteration.value) says whether a usable model is loaded.
            gen = self._ensure_model_consistency()
            try:
                while True:
                    msg = next(gen)
                    output_buffer += msg
                    yield output_buffer
            except StopIteration as e:
                if not e.value: return

            output_buffer += f"β³ Preparing Dataset for Eval (Test Split: {test_size})...\n"
            yield output_buffer

            dataset, log = self._prepare_dataset()
            output_buffer += log
            yield output_buffer

            if not dataset:
                output_buffer += "β Dataset creation failed.\n"
                yield output_buffer
                return

            # A single-sample dataset cannot be split; evaluate on itself.
            if len(dataset) > 1:
                dataset = dataset.train_test_split(test_size=test_size, shuffle=shuffle_data)
            else:
                dataset = {"train": dataset, "test": dataset}

            output_buffer += "\nπ Evaluating Model Success Rate on Test Split...\n"
            yield output_buffer

            for update in self._evaluate_model(dataset["test"]):
                yield f"{output_buffer}{update}"
                # Stop is checked after each sample's report has been yielded.
                if self.stop_event.is_set():
                    yield f"{output_buffer}{update}\n\nπ Evaluation interrupted by user."
                    break
        finally:
            # NOTE(review): setting the event here marks the session idle; it is
            # cleared again at the top of the next run — confirm the UI relies on this.
            self.stop_event.set()
| |
|
| | |
| |
|
    def run_training_pipeline(self, epochs: int, learning_rate: float, test_size: float, shuffle_data: bool) -> Generator[Tuple[str, Any], None, None]:
        """Run fine-tuning on a worker thread, streaming (log_text, loss_figure) tuples.

        Each yield carries the full accumulated log plus the most recent loss
        plot (None until the first metric arrives).
        """
        self.stop_event.clear()
        output_buffer = ""
        last_plot = None

        try:
            # Drain the consistency-check generator; its StopIteration.value
            # says whether a usable model is loaded.
            gen = self._ensure_model_consistency()
            try:
                while True:
                    msg = next(gen)
                    output_buffer += f"{msg}"
                    yield output_buffer, None
            except StopIteration as e:
                if not e.value: return

            output_buffer += f"β³ Preparing Dataset (Test Split: {test_size}, Shuffle: {shuffle_data})...\n"
            yield output_buffer, None

            dataset, log = self._prepare_dataset()
            if not dataset:
                yield "Dataset creation failed.", None
                return

            output_buffer += log
            yield output_buffer, None

            # A single-sample dataset cannot be split; reuse it for eval.
            if len(dataset) > 1:
                dataset = dataset.train_test_split(test_size=test_size, shuffle=shuffle_data)
            else:
                dataset = {"train": dataset, "test": dataset}

            output_buffer += f"\nπ Starting Fine-tuning (Epochs: {epochs}, LR: {learning_rate})...\n"
            yield output_buffer, None

            # Training blocks, so it runs on a worker thread; log lines come back
            # through log_queue so this generator can keep yielding UI updates.
            log_queue = queue.Queue()
            training_error = None
            running_history = []

            def train_wrapper():
                # Capture any trainer exception so it surfaces in the UI log
                # instead of dying silently on the worker thread.
                nonlocal training_error
                try:
                    self._execute_trainer(dataset, log_queue, epochs, learning_rate)
                except Exception as e:
                    training_error = e

            train_thread = threading.Thread(target=train_wrapper)
            train_thread.start()

            while train_thread.is_alive():
                while not log_queue.empty():
                    payload = log_queue.get()
                    if isinstance(payload, tuple):
                        # (display string, metrics dict) from LogStreamingCallback.
                        msg, log_data = payload
                        output_buffer += f"{msg}\n"
                        running_history.append(log_data)
                        try:
                            last_plot = self._generate_loss_plot(running_history)
                            yield output_buffer, last_plot
                        except Exception:
                            # A plotting failure must not kill the stream; reuse
                            # the last good figure.
                            yield output_buffer, last_plot
                    else:
                        output_buffer += f"{payload}\n"
                        yield output_buffer, last_plot

                if self.stop_event.is_set():
                    yield f"{output_buffer}π Stop signal sent. Waiting for trainer to wrap up...\n", last_plot

                time.sleep(0.1)  # poll interval for queue/thread liveness

            train_thread.join()

            # NOTE(review): set even when training errored or was stopped early —
            # confirm partially-trained weights should count as "tuned".
            self.has_model_tuned = True

            # Drain anything logged between the last poll and thread exit.
            while not log_queue.empty():
                payload = log_queue.get()
                if isinstance(payload, tuple):
                    msg, log_data = payload
                    output_buffer += f"{msg}\n"
                    running_history.append(log_data)
                    last_plot = self._generate_loss_plot(running_history)
                else:
                    output_buffer += f"{payload}\n"
                yield output_buffer, last_plot

            if training_error:
                output_buffer += f"β Error during training: {training_error}\n"
                yield output_buffer, last_plot
                return

            if self.stop_event.is_set():
                output_buffer += "π Training manually stopped.\n"
                yield output_buffer, last_plot
                return

            output_buffer += "β Training finished.\n"
            yield output_buffer, last_plot

        finally:
            self.stop_event.set()
| |
|
| | def _prepare_dataset(self): |
| | formatting_fn = partial(create_conversation_format, tools_list=self.current_tools) |
| |
|
| | if not self.imported_dataset: |
| | ds = load_dataset(self.config.DEFAULT_DATASET, split="train").map(formatting_fn) |
| | log = f" `-> using default dataset (size:{len(ds)})\n" |
| | else: |
| | dataset_as_dicts = [{ |
| | "user_content": row[0], "tool_name": row[1], "tool_arguments": row[2]} |
| | for row in self.imported_dataset |
| | ] |
| | ds = Dataset.from_list(dataset_as_dicts).map(formatting_fn) |
| | log = f" `-> using custom dataset (size:{len(ds)})\n" |
| | return ds, log |
| |
|
    def _execute_trainer(self, dataset, log_queue: queue.Queue, epochs: int, learning_rate: float) -> List[Dict]:
        """Run a blocking SFT training pass; intended to be called from a worker thread.

        Progress reaches the UI via the two callbacks: AbortCallback watches the
        shared stop event, LogStreamingCallback pushes lines into log_queue.
        Returns the trainer's accumulated log history.
        """
        # Match mixed-precision flags to however the model was loaded.
        torch_dtype = self.model.dtype
        args = SFTConfig(
            output_dir=str(self.output_dir),
            max_length=512,
            packing=False,
            num_train_epochs=epochs,
            per_device_train_batch_size=4,
            logging_steps=1,          # log every step so the UI plot stays live
            save_strategy="no",       # only the final save_model() below persists weights
            eval_strategy="epoch",
            learning_rate=learning_rate,
            fp16=(torch_dtype == torch.float16),
            bf16=(torch_dtype == torch.bfloat16),
            report_to="none",
            dataset_kwargs={"add_special_tokens": False, "append_concat_token": True}
        )

        trainer = SFTTrainer(
            model=self.model,
            args=args,
            train_dataset=dataset['train'],
            eval_dataset=dataset['test'],
            processing_class=self.tokenizer,
            callbacks=[
                AbortCallback(self.stop_event),
                LogStreamingCallback(log_queue)
            ]
        )
        trainer.train()
        trainer.save_model()
        return trainer.state.log_history
| | |
| | def _generate_loss_plot(self, history: list): |
| | if not history: return None |
| | plt.close('all') |
| | |
| | train_steps = [x['step'] for x in history if 'loss' in x] |
| | train_loss = [x['loss'] for x in history if 'loss' in x] |
| | eval_steps = [x['step'] for x in history if 'eval_loss' in x] |
| | eval_loss = [x['eval_loss'] for x in history if 'eval_loss' in x] |
| |
|
| | fig, ax = plt.subplots(figsize=(10, 5)) |
| | if train_steps: |
| | ax.plot(train_steps, train_loss, label='Training Loss', linestyle='-', marker=None) |
| | if eval_steps: |
| | ax.plot(eval_steps, eval_loss, label='Validation Loss', linestyle='--', marker='o') |
| |
|
| | ax.set_xlabel("Steps") |
| | ax.set_ylabel("Loss") |
| | ax.set_title("Training & Validation Loss") |
| | ax.legend() |
| | ax.grid(True, linestyle=':', alpha=0.6) |
| | plt.tight_layout() |
| | return fig |
| |
|
    def _evaluate_model(self, test_dataset) -> Generator[str, None, None]:
        """Greedy-decode each test sample and stream a running tool-accuracy report.

        Each yield is the full report so far plus a running success count.
        "Success" means the expected tool name appears anywhere in the decoded output.
        """
        results = []
        success_count = 0
        for idx, item in enumerate(test_dataset):
            # First two messages form the prompt context (presumably system + user
            # — confirm against create_conversation_format); message index 2 is
            # assumed to hold the gold assistant tool call.
            messages = item["messages"][:2]
            try:
                inputs = self.tokenizer.apply_chat_template(
                    messages, tools=self.current_tools, add_generation_prompt=True, return_dict=True, return_tensors="pt"
                )
                device = self.model.device
                inputs = {k: v.to(device) for k, v in inputs.items()}
                out = self.model.generate(
                    **inputs,
                    pad_token_id=self.tokenizer.eos_token_id,
                    max_new_tokens=128
                )
                # Slice off the prompt tokens; decode only newly generated text.
                output = self.tokenizer.decode(out[0][len(inputs["input_ids"][0]):], skip_special_tokens=True)
                log_entry = f"{idx+1}. Prompt: {messages[1]['content']}\n Output: {output[:100]}..."
                expected_tool = item['messages'][2]['tool_calls'][0]['function']['name']
                if expected_tool in output:
                    log_entry += "\n -> β Correct Tool"
                    success_count += 1
                else:
                    log_entry += f"\n -> β Wrong Tool (Expected: {expected_tool})"
                results.append(log_entry)
                yield "\n".join(results) + f"\n\nRunning Success Rate: {success_count}/{idx+1}"
            except Exception as e:
                # NOTE(review): a failed sample is reported but not recorded in
                # `results`, yet still counts in the denominator on later yields
                # — confirm this accounting is intended.
                yield f"Error during inference: {e}"
| |
|
| | def get_zip_path(self) -> Optional[str]: |
| | if not self.output_dir.exists(): return None |
| | base_name = str(self.config.ARTIFACTS_DIR.joinpath(f"functiongemma_finetuned_{self.session_id}")) |
| | return zip_directory(str(self.output_dir), base_name) |
| |
|
| | def upload_model_to_hub(self, repo_name: str, oauth_token: str) -> str: |
| | """Uploads the trained model to Hugging Face Hub.""" |
| | if not self.output_dir.exists() or not any(self.output_dir.iterdir()): |
| | return "β No trained model found in current session. Run training first." |
| | |
| | try: |
| | api = HfApi(token=oauth_token) |
| |
|
| | |
| | user_info = api.whoami() |
| | username = user_info['name'] |
| | |
| | |
| | repo_id = f"{username}/{repo_name}" |
| | print(f"Preparing to upload to: {repo_id}") |
| |
|
| | |
| | api.create_repo(repo_id=repo_id, exist_ok=True) |
| | |
| | |
| | print(f"Uploading to {repo_id}...") |
| | repo_url = api.upload_folder( |
| | folder_path=str(self.output_dir), |
| | repo_id=repo_id, |
| | repo_type="model" |
| | ) |
| |
|
| | info = model_info( |
| | repo_id=repo_id, |
| | token=oauth_token |
| | ) |
| | tags = ["functiongemma", "functiongemma-tuning-lab"] |
| | if info.card_data: |
| | tags = info.card_data.tags |
| | tags.append("functiongemma-tuning-lab") |
| |
|
| | metadata_update(repo_id, {"tags": tags}, overwrite=True, token=oauth_token) |
| |
|
| | return f"β
Success! Model uploaded to: {repo_url}" |
| | except Exception as e: |
| | return f"β Upload failed: {str(e)}" |