""" MLX Training Wrapper für Fine-Tuning Nutzt mlx-lm für LoRA Fine-Tuning """ import json import time import psutil from pathlib import Path from typing import Dict, List, Callable, Optional from dataclasses import dataclass import threading import queue @dataclass class TrainingConfig: """Training Konfiguration""" model_name: str learning_rate: float = 1e-5 epochs: int = 3 batch_size: int = 4 lora_rank: int = 8 lora_alpha: int = 16 max_seq_length: int = 2048 val_every: int = 50 class TrainingStatus: """Verwaltet den aktuellen Training-Status""" def __init__(self): self.is_training = False self.should_stop = False self.current_step = 0 self.total_steps = 0 self.current_epoch = 0 self.train_loss = 0.0 self.val_loss = 0.0 self.train_loss_history = [] self.val_loss_history = [] self.start_time = None self.error = None def reset(self): """Setzt den Status zurück""" self.is_training = False self.should_stop = False self.current_step = 0 self.total_steps = 0 self.current_epoch = 0 self.train_loss = 0.0 self.val_loss = 0.0 self.train_loss_history = [] self.val_loss_history = [] self.start_time = None self.error = None def to_dict(self) -> Dict: """Konvertiert zu Dictionary für API""" eta = None if self.is_training and self.current_step > 0 and self.start_time: elapsed = time.time() - self.start_time steps_remaining = self.total_steps - self.current_step eta = int((elapsed / self.current_step) * steps_remaining) memory_usage = psutil.virtual_memory().percent return { 'is_training': self.is_training, 'current_step': self.current_step, 'total_steps': self.total_steps, 'current_epoch': self.current_epoch, 'train_loss': round(self.train_loss, 4) if self.train_loss else None, 'val_loss': round(self.val_loss, 4) if self.val_loss else None, 'train_loss_history': [round(l, 4) for l in self.train_loss_history], 'val_loss_history': [round(l, 4) for l in self.val_loss_history], 'eta_seconds': eta, 'memory_usage_percent': memory_usage, 'error': self.error } class MLXTrainer: """Wrapper für MLX Training""" def __init__(self, models_dir: str = "models", output_dir: str = "output"): self.models_dir = Path(models_dir) self.output_dir = Path(output_dir) self.models_dir.mkdir(exist_ok=True) self.output_dir.mkdir(exist_ok=True) self.status = TrainingStatus() self.training_thread = None def prepare_training_data(self, train_data: List[Dict], val_data: List[Dict], data_dir: Path) -> tuple[Path, Path]: """Konvertiert Daten ins MLX Format (JSONL)""" def format_example(item: Dict) -> Dict: """Formatiert ein Beispiel im Chat-Format""" task_type = item['task_type'] body = item['body'] output = item['expected_output'] # Task-spezifische Prompts task_prompts = { 'Zusammenfassen': 'Fasse folgende E-Mail zusammen:', 'Antwort schreiben': 'Schreibe eine Antwort auf folgende E-Mail:', 'Kategorisieren': 'Kategorisiere folgende E-Mail:', 'Action Items': 'Extrahiere die Action Items aus folgender E-Mail:', 'Custom': 'Bearbeite folgende E-Mail:' } instruction = task_prompts.get(task_type, task_prompts['Custom']) return { 'messages': [ { 'role': 'user', 'content': f"{instruction}\n\n{body}" }, { 'role': 'assistant', 'content': output } ] } train_file = data_dir / 'train.jsonl' val_file = data_dir / 'val.jsonl' # Schreibe Training Data with open(train_file, 'w', encoding='utf-8') as f: for item in train_data: f.write(json.dumps(format_example(item), ensure_ascii=False) + '\n') # Schreibe Validation Data with open(val_file, 'w', encoding='utf-8') as f: for item in val_data: f.write(json.dumps(format_example(item), ensure_ascii=False) + '\n') return train_file, val_file def _run_training(self, config: TrainingConfig, train_file: Path, val_file: Path, output_path: Path): """Führt das Training aus (läuft in eigenem Thread)""" try: # Import hier um MLX nur bei Bedarf zu laden from mlx_lm import load, LoRALinear from mlx_lm.tuner import train as mlx_train import mlx.core as mx import mlx.nn as nn import mlx.optimizers as optim self.status.is_training = True self.status.start_time = time.time() self.status.error = None # Lade Modell model_path = self.models_dir / config.model_name if not model_path.exists(): raise FileNotFoundError(f"Model not found: {model_path}") # Training durchführen mit mlx-lm # Dies ist ein vereinfachtes Beispiel - mlx-lm hat eigene Trainer # In der Praxis würde man mlx_lm.tuner verwenden # Lade Training Config train_config = { 'model': str(model_path), 'data': str(train_file), 'val_data': str(val_file), 'train': True, 'iters': config.epochs * 100, # Approximation 'val_batches': 10, 'learning_rate': config.learning_rate, 'batch_size': config.batch_size, 'lora_layers': config.lora_rank, 'adapter_file': str(output_path / 'adapters.npz'), 'save_every': 50, 'val_every': config.val_every, } # Callback für Progress-Updates def training_callback(step: int, loss: float, val_loss: Optional[float] = None): if self.status.should_stop: return False # Stop training self.status.current_step = step self.status.train_loss = loss self.status.train_loss_history.append(loss) if val_loss is not None: self.status.val_loss = val_loss self.status.val_loss_history.append(val_loss) return True # Hinweis: Dies ist ein Platzhalter für echtes MLX Training # In der Praxis würde man mlx_lm.tuner.train() oder eine # eigene Training Loop mit mlx nutzen # Simuliere Training für Demo (MUSS durch echtes MLX Training ersetzt werden) total_steps = config.epochs * (len(list(open(train_file))) // config.batch_size) self.status.total_steps = total_steps for epoch in range(config.epochs): self.status.current_epoch = epoch + 1 for step in range(total_steps // config.epochs): if self.status.should_stop: break # Simuliere Training Step self.status.current_step = epoch * (total_steps // config.epochs) + step fake_loss = 2.0 - (self.status.current_step / total_steps) * 1.5 self.status.train_loss = fake_loss self.status.train_loss_history.append(fake_loss) # Validation alle N Steps if step % config.val_every == 0: fake_val_loss = 2.2 - (self.status.current_step / total_steps) * 1.4 self.status.val_loss = fake_val_loss self.status.val_loss_history.append(fake_val_loss) time.sleep(0.1) # Simuliere Rechenzeit if self.status.should_stop: break # Speichere finale Adapter # output_path / 'adapters.npz' würde die LoRA Weights enthalten self.status.is_training = False except Exception as e: self.status.error = str(e) self.status.is_training = False def start_training(self, config: TrainingConfig, train_data: List[Dict], val_data: List[Dict]) -> bool: """Startet das Training""" if self.status.is_training: return False # Bereite Daten vor data_dir = self.output_dir / f"training_{int(time.time())}" data_dir.mkdir(exist_ok=True) train_file, val_file = self.prepare_training_data( train_data, val_data, data_dir ) # Output-Pfad output_path = self.output_dir / f"run_{int(time.time())}" output_path.mkdir(exist_ok=True) # Reset Status self.status.reset() # Starte Training in eigenem Thread self.training_thread = threading.Thread( target=self._run_training, args=(config, train_file, val_file, output_path), daemon=True ) self.training_thread.start() return True def stop_training(self) -> bool: """Stoppt das laufende Training""" if not self.status.is_training: return False self.status.should_stop = True # Warte max 5 Sekunden auf Thread if self.training_thread: self.training_thread.join(timeout=5) return True def get_status(self) -> Dict: """Gibt aktuellen Status zurück""" return self.status.to_dict() def list_available_models(self) -> List[str]: """Listet verfügbare Modelle auf""" if not self.models_dir.exists(): return [] models = [] for path in self.models_dir.iterdir(): if path.is_dir(): models.append(path.name) return models def download_model(self, model_name: str) -> bool: """ Lädt ein Modell herunter In der Praxis würde man hier huggingface_hub nutzen """ # Placeholder - würde huggingface_hub.snapshot_download nutzen # und dann mit mlx_lm.convert konvertieren # Beispiel: # from huggingface_hub import snapshot_download # from mlx_lm.convert import convert # # hf_path = snapshot_download(model_name) # mlx_path = self.models_dir / model_name # convert(hf_path, mlx_path) return False # Nicht implementiert in diesem Beispiel