Add complete Mail Fine-Tuning Web-App for macOS Apple Silicon
Implemented a full-stack web application for fine-tuning LLMs on email data, optimized for Apple Silicon (M4 Pro with 24GB RAM). Features: - Mail import with drag & drop support (.mbox, .eml, .txt) - Automated mail cleaning and preprocessing - Interactive labeling interface with keyboard shortcuts - Training data export to JSONL format - MLX-based LoRA fine-tuning with live updates - Model evaluation and comparison interface - Server-Sent Events for real-time training progress - Dark theme UI optimized for extended use Technical Stack: - Backend: FastAPI with SQLite database - Frontend: Vanilla HTML/CSS/JavaScript (no external dependencies) - ML Framework: MLX for Apple Silicon optimization - Models: Support for Mistral 7B and Llama 3 8B via MLX Components: - data_manager.py: SQLite operations for mail storage and labeling - mail_parser.py: Parser for multiple mail formats with cleaning - training.py: MLX training wrapper with LoRA support - inference.py: Model loading and inference for evaluation - main.py: FastAPI backend with REST API and SSE - Frontend: Complete UI with all features Documentation: - Comprehensive README with installation and usage guide - Quick-start guide for rapid setup - Example mails for testing - Troubleshooting and best practices Ready for local deployment and fine-tuning workflows.
This commit is contained in:
@@ -0,0 +1,286 @@
|
||||
"""
|
||||
Data Manager für Mail Fine-Tuning App
|
||||
Verwaltet SQLite Datenbank für Mails und Labels
|
||||
"""
|
||||
|
||||
import sqlite3
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Optional
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class DataManager:
|
||||
def __init__(self, db_path: str = "data/mails.db"):
|
||||
self.db_path = Path(db_path)
|
||||
self.db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self.init_db()
|
||||
|
||||
def init_db(self):
|
||||
"""Initialisiert die Datenbank mit dem Schema"""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS mails (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
subject TEXT,
|
||||
sender TEXT,
|
||||
recipient TEXT,
|
||||
date TEXT,
|
||||
body TEXT NOT NULL,
|
||||
original_format TEXT,
|
||||
task_type TEXT DEFAULT 'unlabeled',
|
||||
expected_output TEXT,
|
||||
status TEXT DEFAULT 'unlabeled',
|
||||
created_at TEXT DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TEXT DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
""")
|
||||
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS training_runs (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
model_name TEXT NOT NULL,
|
||||
start_time TEXT,
|
||||
end_time TEXT,
|
||||
config TEXT,
|
||||
status TEXT,
|
||||
final_train_loss REAL,
|
||||
final_val_loss REAL,
|
||||
checkpoint_path TEXT
|
||||
)
|
||||
""")
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
def add_mail(self, subject: str, sender: str, recipient: str,
|
||||
date: str, body: str, original_format: str) -> int:
|
||||
"""Fügt eine neue Mail hinzu"""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("""
|
||||
INSERT INTO mails (subject, sender, recipient, date, body, original_format)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
""", (subject, sender, recipient, date, body, original_format))
|
||||
|
||||
mail_id = cursor.lastrowid
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
return mail_id
|
||||
|
||||
def get_all_mails(self, status_filter: Optional[str] = None) -> List[Dict]:
|
||||
"""Holt alle Mails, optional gefiltert nach Status"""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
conn.row_factory = sqlite3.Row
|
||||
cursor = conn.cursor()
|
||||
|
||||
if status_filter:
|
||||
cursor.execute("SELECT * FROM mails WHERE status = ? ORDER BY id", (status_filter,))
|
||||
else:
|
||||
cursor.execute("SELECT * FROM mails ORDER BY id")
|
||||
|
||||
rows = cursor.fetchall()
|
||||
mails = [dict(row) for row in rows]
|
||||
|
||||
conn.close()
|
||||
return mails
|
||||
|
||||
def get_mail(self, mail_id: int) -> Optional[Dict]:
|
||||
"""Holt eine einzelne Mail"""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
conn.row_factory = sqlite3.Row
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("SELECT * FROM mails WHERE id = ?", (mail_id,))
|
||||
row = cursor.fetchone()
|
||||
|
||||
conn.close()
|
||||
return dict(row) if row else None
|
||||
|
||||
def update_mail(self, mail_id: int, task_type: Optional[str] = None,
|
||||
expected_output: Optional[str] = None,
|
||||
status: Optional[str] = None,
|
||||
body: Optional[str] = None) -> bool:
|
||||
"""Aktualisiert eine Mail (Labeling)"""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
updates = []
|
||||
params = []
|
||||
|
||||
if task_type is not None:
|
||||
updates.append("task_type = ?")
|
||||
params.append(task_type)
|
||||
|
||||
if expected_output is not None:
|
||||
updates.append("expected_output = ?")
|
||||
params.append(expected_output)
|
||||
|
||||
if status is not None:
|
||||
updates.append("status = ?")
|
||||
params.append(status)
|
||||
|
||||
if body is not None:
|
||||
updates.append("body = ?")
|
||||
params.append(body)
|
||||
|
||||
if not updates:
|
||||
conn.close()
|
||||
return False
|
||||
|
||||
updates.append("updated_at = ?")
|
||||
params.append(datetime.now().isoformat())
|
||||
params.append(mail_id)
|
||||
|
||||
query = f"UPDATE mails SET {', '.join(updates)} WHERE id = ?"
|
||||
cursor.execute(query, params)
|
||||
|
||||
success = cursor.rowcount > 0
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
return success
|
||||
|
||||
def delete_mail(self, mail_id: int) -> bool:
|
||||
"""Löscht eine Mail"""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("DELETE FROM mails WHERE id = ?", (mail_id,))
|
||||
success = cursor.rowcount > 0
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
return success
|
||||
|
||||
def get_statistics(self) -> Dict:
|
||||
"""Berechnet Statistiken über die Daten"""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Gesamt-Anzahl
|
||||
cursor.execute("SELECT COUNT(*) FROM mails")
|
||||
total = cursor.fetchone()[0]
|
||||
|
||||
# Nach Status
|
||||
cursor.execute("""
|
||||
SELECT status, COUNT(*) as count
|
||||
FROM mails
|
||||
GROUP BY status
|
||||
""")
|
||||
status_counts = {row[0]: row[1] for row in cursor.fetchall()}
|
||||
|
||||
# Nach Task-Type
|
||||
cursor.execute("""
|
||||
SELECT task_type, COUNT(*) as count
|
||||
FROM mails
|
||||
WHERE status = 'labeled'
|
||||
GROUP BY task_type
|
||||
""")
|
||||
task_counts = {row[0]: row[1] for row in cursor.fetchall()}
|
||||
|
||||
# Durchschnittliche Längen (nur gelabelte)
|
||||
cursor.execute("""
|
||||
SELECT
|
||||
AVG(LENGTH(body)) as avg_input_length,
|
||||
AVG(LENGTH(expected_output)) as avg_output_length
|
||||
FROM mails
|
||||
WHERE status = 'labeled'
|
||||
""")
|
||||
lengths = cursor.fetchone()
|
||||
|
||||
conn.close()
|
||||
|
||||
labeled_count = status_counts.get('labeled', 0)
|
||||
|
||||
return {
|
||||
'total': total,
|
||||
'labeled': labeled_count,
|
||||
'unlabeled': status_counts.get('unlabeled', 0),
|
||||
'skipped': status_counts.get('skip', 0),
|
||||
'task_distribution': task_counts,
|
||||
'avg_input_length': round(lengths[0]) if lengths[0] else 0,
|
||||
'avg_output_length': round(lengths[1]) if lengths[1] else 0,
|
||||
'sufficient_data': labeled_count >= 50
|
||||
}
|
||||
|
||||
def export_training_data(self, train_split: float = 0.9) -> tuple[List[Dict], List[Dict]]:
|
||||
"""Exportiert gelabelte Daten für Training"""
|
||||
import random
|
||||
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
conn.row_factory = sqlite3.Row
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("""
|
||||
SELECT body, task_type, expected_output
|
||||
FROM mails
|
||||
WHERE status = 'labeled' AND expected_output IS NOT NULL
|
||||
ORDER BY RANDOM()
|
||||
""")
|
||||
|
||||
rows = cursor.fetchall()
|
||||
conn.close()
|
||||
|
||||
if not rows:
|
||||
return [], []
|
||||
|
||||
data = [dict(row) for row in rows]
|
||||
|
||||
# Shuffle
|
||||
random.shuffle(data)
|
||||
|
||||
# Split
|
||||
split_idx = int(len(data) * train_split)
|
||||
train_data = data[:split_idx]
|
||||
val_data = data[split_idx:]
|
||||
|
||||
return train_data, val_data
|
||||
|
||||
def save_training_run(self, model_name: str, config: Dict,
|
||||
checkpoint_path: str) -> int:
|
||||
"""Speichert einen Training-Run"""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("""
|
||||
INSERT INTO training_runs
|
||||
(model_name, start_time, config, status, checkpoint_path)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
""", (
|
||||
model_name,
|
||||
datetime.now().isoformat(),
|
||||
json.dumps(config),
|
||||
'running',
|
||||
checkpoint_path
|
||||
))
|
||||
|
||||
run_id = cursor.lastrowid
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
return run_id
|
||||
|
||||
def update_training_run(self, run_id: int, status: str,
|
||||
train_loss: Optional[float] = None,
|
||||
val_loss: Optional[float] = None):
|
||||
"""Aktualisiert einen Training-Run"""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("""
|
||||
UPDATE training_runs
|
||||
SET status = ?,
|
||||
end_time = ?,
|
||||
final_train_loss = COALESCE(?, final_train_loss),
|
||||
final_val_loss = COALESCE(?, final_val_loss)
|
||||
WHERE id = ?
|
||||
""", (status, datetime.now().isoformat(), train_loss, val_loss, run_id))
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
@@ -0,0 +1,209 @@
|
||||
"""
|
||||
Inference Module für Modell-Evaluation
|
||||
Lädt Base- und Fine-tuned Models für Vergleiche
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict
|
||||
import threading
|
||||
|
||||
|
||||
class ModelInference:
|
||||
"""Handhabt Modell-Inferenz für Base und Fine-tuned Models"""
|
||||
|
||||
def __init__(self, models_dir: str = "models", output_dir: str = "output"):
|
||||
self.models_dir = Path(models_dir)
|
||||
self.output_dir = Path(output_dir)
|
||||
|
||||
self.base_model = None
|
||||
self.finetuned_model = None
|
||||
self.model_lock = threading.Lock()
|
||||
|
||||
def load_base_model(self, model_name: str) -> bool:
|
||||
"""Lädt das Basis-Modell"""
|
||||
try:
|
||||
# Import MLX nur bei Bedarf
|
||||
from mlx_lm import load
|
||||
|
||||
model_path = self.models_dir / model_name
|
||||
|
||||
if not model_path.exists():
|
||||
return False
|
||||
|
||||
with self.model_lock:
|
||||
self.base_model = load(str(model_path))
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error loading base model: {e}")
|
||||
return False
|
||||
|
||||
def load_finetuned_model(self, model_name: str, adapter_path: str) -> bool:
|
||||
"""Lädt das Fine-tuned Modell (Base + LoRA Adapter)"""
|
||||
try:
|
||||
from mlx_lm import load
|
||||
|
||||
model_path = self.models_dir / model_name
|
||||
adapter_file = Path(adapter_path)
|
||||
|
||||
if not model_path.exists() or not adapter_file.exists():
|
||||
return False
|
||||
|
||||
with self.model_lock:
|
||||
# Lade Base Model mit Adapter
|
||||
self.finetuned_model = load(
|
||||
str(model_path),
|
||||
adapter_path=str(adapter_file)
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error loading finetuned model: {e}")
|
||||
return False
|
||||
|
||||
def generate(self, prompt: str, model_type: str = 'base',
|
||||
max_tokens: int = 512, temperature: float = 0.7) -> str:
|
||||
"""
|
||||
Generiert Text mit dem gewählten Modell
|
||||
|
||||
Args:
|
||||
prompt: Input prompt
|
||||
model_type: 'base' oder 'finetuned'
|
||||
max_tokens: Maximale Anzahl Tokens
|
||||
temperature: Sampling temperature
|
||||
|
||||
Returns:
|
||||
Generierter Text
|
||||
"""
|
||||
try:
|
||||
from mlx_lm import generate as mlx_generate
|
||||
|
||||
model = self.base_model if model_type == 'base' else self.finetuned_model
|
||||
|
||||
if model is None:
|
||||
return f"Error: {model_type} model not loaded"
|
||||
|
||||
with self.model_lock:
|
||||
# MLX-LM generate
|
||||
result = mlx_generate(
|
||||
model,
|
||||
prompt=prompt,
|
||||
max_tokens=max_tokens,
|
||||
temp=temperature
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
return f"Error during generation: {str(e)}"
|
||||
|
||||
def generate_comparison(self, prompt: str, max_tokens: int = 512,
|
||||
temperature: float = 0.7) -> Dict[str, str]:
|
||||
"""
|
||||
Generiert mit beiden Modellen für Vergleich
|
||||
|
||||
Returns:
|
||||
Dict mit 'base' und 'finetuned' Outputs
|
||||
"""
|
||||
result = {
|
||||
'base': None,
|
||||
'finetuned': None
|
||||
}
|
||||
|
||||
if self.base_model:
|
||||
result['base'] = self.generate(
|
||||
prompt, 'base', max_tokens, temperature
|
||||
)
|
||||
|
||||
if self.finetuned_model:
|
||||
result['finetuned'] = self.generate(
|
||||
prompt, 'finetuned', max_tokens, temperature
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def format_mail_prompt(self, task_type: str, mail_body: str) -> str:
|
||||
"""Formatiert einen Prompt basierend auf Task-Type"""
|
||||
|
||||
task_prompts = {
|
||||
'Zusammenfassen': 'Fasse folgende E-Mail zusammen:',
|
||||
'Antwort schreiben': 'Schreibe eine Antwort auf folgende E-Mail:',
|
||||
'Kategorisieren': 'Kategorisiere folgende E-Mail:',
|
||||
'Action Items': 'Extrahiere die Action Items aus folgender E-Mail:',
|
||||
'Custom': 'Bearbeite folgende E-Mail:'
|
||||
}
|
||||
|
||||
instruction = task_prompts.get(task_type, task_prompts['Custom'])
|
||||
|
||||
return f"{instruction}\n\n{mail_body}"
|
||||
|
||||
def get_test_prompts(self) -> Dict[str, str]:
|
||||
"""Vordefinierte Test-Prompts"""
|
||||
return {
|
||||
'Zusammenfassen': self.format_mail_prompt(
|
||||
'Zusammenfassen',
|
||||
"""Betreff: Q4 Projektupdate
|
||||
|
||||
Hallo Team,
|
||||
|
||||
ich wollte euch ein kurzes Update zum aktuellen Projektstand geben.
|
||||
|
||||
Wir haben letzte Woche die neue API-Integration abgeschlossen und erfolgreich getestet.
|
||||
Die Performance-Tests zeigen eine Verbesserung von 40% gegenüber der alten Implementierung.
|
||||
|
||||
Nächste Woche starten wir mit der Frontend-Anpassung. Maria und Tom werden das Design
|
||||
überarbeiten, während ich mich um die Backend-Anbindung kümmere.
|
||||
|
||||
Der Go-Live ist weiterhin für Ende des Monats geplant.
|
||||
|
||||
Beste Grüße
|
||||
Alex"""
|
||||
),
|
||||
'Antwort schreiben': self.format_mail_prompt(
|
||||
'Antwort schreiben',
|
||||
"""Betreff: Frage zu Invoice #2847
|
||||
|
||||
Hallo,
|
||||
|
||||
ich habe eine Frage zur Rechnung #2847 vom 15. März.
|
||||
Der Betrag scheint nicht mit unserem Angebot übereinzustimmen.
|
||||
|
||||
Könnten Sie das bitte prüfen?
|
||||
|
||||
Danke
|
||||
Michael"""
|
||||
),
|
||||
'Action Items': self.format_mail_prompt(
|
||||
'Action Items',
|
||||
"""Betreff: Meeting Notes - Produktlaunch
|
||||
|
||||
Hi alle,
|
||||
|
||||
hier die wichtigsten Punkte vom heutigen Meeting:
|
||||
|
||||
- Sarah bereitet die Pressemitteilung vor (Deadline: Freitag)
|
||||
- Marketing-Team erstellt Social Media Content (nächste Woche)
|
||||
- Ich kümmere mich um die Influencer-Kontakte
|
||||
- Wir brauchen noch finale Produktfotos vom Design-Team
|
||||
- Launch-Event ist am 1. April - Location muss noch gebucht werden
|
||||
|
||||
Bitte gebt bis Mittwoch Bescheid ob ihr eure Aufgaben schaffen könnt.
|
||||
|
||||
Lisa"""
|
||||
)
|
||||
}
|
||||
|
||||
def unload_models(self):
|
||||
"""Entlädt Modelle aus dem Speicher"""
|
||||
with self.model_lock:
|
||||
self.base_model = None
|
||||
self.finetuned_model = None
|
||||
|
||||
def get_loaded_models(self) -> Dict[str, bool]:
|
||||
"""Gibt zurück welche Modelle geladen sind"""
|
||||
return {
|
||||
'base': self.base_model is not None,
|
||||
'finetuned': self.finetuned_model is not None
|
||||
}
|
||||
@@ -0,0 +1,264 @@
|
||||
"""
|
||||
Mail Parser für verschiedene Formate
|
||||
Bereinigt und normalisiert Mail-Inhalte
|
||||
"""
|
||||
|
||||
import email
|
||||
import mailbox
|
||||
import re
|
||||
from bs4 import BeautifulSoup
|
||||
from typing import List, Dict, Optional
|
||||
from pathlib import Path
|
||||
import chardet
|
||||
|
||||
|
||||
class MailParser:
|
||||
"""Parst und bereinigt Mail-Dateien"""
|
||||
|
||||
# Häufige Footer/Disclaimer Pattern
|
||||
FOOTER_PATTERNS = [
|
||||
r'(?i)^--\s*$.*', # Standard signature delimiter
|
||||
r'(?i)Diese E-Mail.*vertraulich.*',
|
||||
r'(?i)This email.*confidential.*',
|
||||
r'(?i)Disclaimer:.*',
|
||||
r'(?i)Get Outlook for.*',
|
||||
r'(?i)Sent from my iPhone.*',
|
||||
r'(?i)Von meinem.*gesendet.*',
|
||||
r'(?i)Diese Nachricht.*Virenfrei.*',
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def detect_encoding(file_path: Path) -> str:
|
||||
"""Erkennt das Encoding einer Datei"""
|
||||
with open(file_path, 'rb') as f:
|
||||
raw_data = f.read()
|
||||
result = chardet.detect(raw_data)
|
||||
return result['encoding'] or 'utf-8'
|
||||
|
||||
@staticmethod
|
||||
def html_to_text(html: str) -> str:
|
||||
"""Konvertiert HTML zu Plain Text"""
|
||||
soup = BeautifulSoup(html, 'html.parser')
|
||||
|
||||
# Entferne Script und Style Tags
|
||||
for script in soup(['script', 'style']):
|
||||
script.decompose()
|
||||
|
||||
# Extrahiere Text
|
||||
text = soup.get_text()
|
||||
|
||||
# Bereinige Whitespace
|
||||
lines = (line.strip() for line in text.splitlines())
|
||||
chunks = (phrase.strip() for line in lines for phrase in line.split(" "))
|
||||
text = ' '.join(chunk for chunk in chunks if chunk)
|
||||
|
||||
return text
|
||||
|
||||
@staticmethod
|
||||
def remove_multiple_newlines(text: str) -> str:
|
||||
"""Entfernt mehrfache Leerzeilen"""
|
||||
return re.sub(r'\n{3,}', '\n\n', text)
|
||||
|
||||
@staticmethod
|
||||
def remove_footers(text: str) -> str:
|
||||
"""Entfernt häufige Footer und Disclaimer"""
|
||||
for pattern in MailParser.FOOTER_PATTERNS:
|
||||
# Suche Pattern und entferne alles danach
|
||||
match = re.search(pattern, text, re.MULTILINE | re.DOTALL)
|
||||
if match:
|
||||
text = text[:match.start()].strip()
|
||||
|
||||
return text
|
||||
|
||||
@staticmethod
|
||||
def clean_quoted_text(text: str) -> str:
|
||||
"""Entfernt oder markiert quoted Text (> oder |)"""
|
||||
lines = text.split('\n')
|
||||
cleaned_lines = []
|
||||
|
||||
for line in lines:
|
||||
# Überspringe Zeilen die mit > oder | beginnen (quoted text)
|
||||
if not line.strip().startswith('>') and not line.strip().startswith('|'):
|
||||
cleaned_lines.append(line)
|
||||
|
||||
return '\n'.join(cleaned_lines)
|
||||
|
||||
@staticmethod
|
||||
def normalize_whitespace(text: str) -> str:
|
||||
"""Normalisiert Whitespace"""
|
||||
# Entferne trailing spaces
|
||||
lines = [line.rstrip() for line in text.split('\n')]
|
||||
text = '\n'.join(lines)
|
||||
|
||||
# Entferne mehrfache Spaces
|
||||
text = re.sub(r' {2,}', ' ', text)
|
||||
|
||||
# Entferne mehrfache Leerzeilen
|
||||
text = MailParser.remove_multiple_newlines(text)
|
||||
|
||||
return text.strip()
|
||||
|
||||
@staticmethod
|
||||
def clean_text(text: str, is_html: bool = False) -> str:
|
||||
"""Vollständige Bereinigung eines Texts"""
|
||||
if is_html:
|
||||
text = MailParser.html_to_text(text)
|
||||
|
||||
text = MailParser.remove_footers(text)
|
||||
text = MailParser.clean_quoted_text(text)
|
||||
text = MailParser.normalize_whitespace(text)
|
||||
|
||||
return text
|
||||
|
||||
@staticmethod
|
||||
def parse_eml(file_path: Path) -> Dict:
|
||||
"""Parst eine .eml Datei"""
|
||||
encoding = MailParser.detect_encoding(file_path)
|
||||
|
||||
with open(file_path, 'r', encoding=encoding, errors='ignore') as f:
|
||||
msg = email.message_from_file(f)
|
||||
|
||||
subject = msg.get('Subject', 'No Subject')
|
||||
sender = msg.get('From', 'Unknown')
|
||||
recipient = msg.get('To', 'Unknown')
|
||||
date = msg.get('Date', '')
|
||||
|
||||
# Body extrahieren
|
||||
body = ""
|
||||
is_html = False
|
||||
|
||||
if msg.is_multipart():
|
||||
for part in msg.walk():
|
||||
content_type = part.get_content_type()
|
||||
if content_type == 'text/plain':
|
||||
body = part.get_payload(decode=True).decode(errors='ignore')
|
||||
break
|
||||
elif content_type == 'text/html' and not body:
|
||||
body = part.get_payload(decode=True).decode(errors='ignore')
|
||||
is_html = True
|
||||
else:
|
||||
body = msg.get_payload(decode=True).decode(errors='ignore')
|
||||
if msg.get_content_type() == 'text/html':
|
||||
is_html = True
|
||||
|
||||
# Bereinige Body
|
||||
body = MailParser.clean_text(body, is_html)
|
||||
|
||||
return {
|
||||
'subject': subject,
|
||||
'sender': sender,
|
||||
'recipient': recipient,
|
||||
'date': date,
|
||||
'body': body,
|
||||
'original_format': 'eml'
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def parse_mbox(file_path: Path) -> List[Dict]:
|
||||
"""Parst eine .mbox Datei"""
|
||||
mails = []
|
||||
|
||||
try:
|
||||
mbox = mailbox.mbox(str(file_path))
|
||||
|
||||
for message in mbox:
|
||||
subject = message.get('Subject', 'No Subject')
|
||||
sender = message.get('From', 'Unknown')
|
||||
recipient = message.get('To', 'Unknown')
|
||||
date = message.get('Date', '')
|
||||
|
||||
body = ""
|
||||
is_html = False
|
||||
|
||||
if message.is_multipart():
|
||||
for part in message.walk():
|
||||
content_type = part.get_content_type()
|
||||
if content_type == 'text/plain':
|
||||
payload = part.get_payload(decode=True)
|
||||
if payload:
|
||||
body = payload.decode(errors='ignore')
|
||||
break
|
||||
elif content_type == 'text/html' and not body:
|
||||
payload = part.get_payload(decode=True)
|
||||
if payload:
|
||||
body = payload.decode(errors='ignore')
|
||||
is_html = True
|
||||
else:
|
||||
payload = message.get_payload(decode=True)
|
||||
if payload:
|
||||
body = payload.decode(errors='ignore')
|
||||
if message.get_content_type() == 'text/html':
|
||||
is_html = True
|
||||
|
||||
body = MailParser.clean_text(body, is_html)
|
||||
|
||||
mails.append({
|
||||
'subject': subject,
|
||||
'sender': sender,
|
||||
'recipient': recipient,
|
||||
'date': date,
|
||||
'body': body,
|
||||
'original_format': 'mbox'
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(f"Error parsing mbox: {str(e)}")
|
||||
|
||||
return mails
|
||||
|
||||
@staticmethod
|
||||
def parse_txt(file_path: Path) -> Dict:
|
||||
"""Parst eine .txt Datei (simple Mail als Text)"""
|
||||
encoding = MailParser.detect_encoding(file_path)
|
||||
|
||||
with open(file_path, 'r', encoding=encoding, errors='ignore') as f:
|
||||
content = f.read()
|
||||
|
||||
# Einfache Struktur: Versuche Subject/From/To zu erkennen
|
||||
lines = content.split('\n')
|
||||
subject = 'No Subject'
|
||||
sender = 'Unknown'
|
||||
recipient = 'Unknown'
|
||||
date = ''
|
||||
body_start = 0
|
||||
|
||||
for i, line in enumerate(lines[:10]): # Erste 10 Zeilen prüfen
|
||||
if line.lower().startswith('subject:'):
|
||||
subject = line[8:].strip()
|
||||
body_start = max(body_start, i + 1)
|
||||
elif line.lower().startswith('from:'):
|
||||
sender = line[5:].strip()
|
||||
body_start = max(body_start, i + 1)
|
||||
elif line.lower().startswith('to:'):
|
||||
recipient = line[3:].strip()
|
||||
body_start = max(body_start, i + 1)
|
||||
elif line.lower().startswith('date:'):
|
||||
date = line[5:].strip()
|
||||
body_start = max(body_start, i + 1)
|
||||
|
||||
# Body ist der Rest
|
||||
body = '\n'.join(lines[body_start:])
|
||||
body = MailParser.clean_text(body)
|
||||
|
||||
return {
|
||||
'subject': subject,
|
||||
'sender': sender,
|
||||
'recipient': recipient,
|
||||
'date': date,
|
||||
'body': body,
|
||||
'original_format': 'txt'
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def parse_file(file_path: Path) -> List[Dict]:
|
||||
"""Parst eine Mail-Datei basierend auf Endung"""
|
||||
suffix = file_path.suffix.lower()
|
||||
|
||||
if suffix == '.eml':
|
||||
return [MailParser.parse_eml(file_path)]
|
||||
elif suffix == '.mbox':
|
||||
return MailParser.parse_mbox(file_path)
|
||||
elif suffix == '.txt':
|
||||
return [MailParser.parse_txt(file_path)]
|
||||
else:
|
||||
raise ValueError(f"Unsupported file format: {suffix}")
|
||||
@@ -0,0 +1,396 @@
|
||||
"""
|
||||
FastAPI Backend für Mail Fine-Tuning App
|
||||
Hauptanwendung mit allen API Endpoints
|
||||
"""
|
||||
|
||||
from fastapi import FastAPI, File, UploadFile, HTTPException, BackgroundTasks
|
||||
from fastapi.responses import StreamingResponse, FileResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, List
|
||||
import asyncio
|
||||
import json
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
|
||||
from data_manager import DataManager
|
||||
from mail_parser import MailParser
|
||||
from training import MLXTrainer, TrainingConfig
|
||||
from inference import ModelInference
|
||||
|
||||
# FastAPI App
|
||||
app = FastAPI(title="Mail Fine-Tuning App")
|
||||
|
||||
# CORS
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Initialisiere Manager
|
||||
data_manager = DataManager("data/mails.db")
|
||||
trainer = MLXTrainer("models", "output")
|
||||
inference = ModelInference("models", "output")
|
||||
|
||||
|
||||
# Pydantic Models
|
||||
class MailUpdate(BaseModel):
|
||||
task_type: Optional[str] = None
|
||||
expected_output: Optional[str] = None
|
||||
status: Optional[str] = None
|
||||
body: Optional[str] = None
|
||||
|
||||
|
||||
class TrainingStartRequest(BaseModel):
|
||||
model_name: str
|
||||
learning_rate: float = 1e-5
|
||||
epochs: int = 3
|
||||
batch_size: int = 4
|
||||
lora_rank: int = 8
|
||||
|
||||
|
||||
class InferenceRequest(BaseModel):
|
||||
prompt: str
|
||||
model_type: str = 'base'
|
||||
max_tokens: int = 512
|
||||
temperature: float = 0.7
|
||||
|
||||
|
||||
class InferenceComparisonRequest(BaseModel):
|
||||
task_type: str
|
||||
mail_body: str
|
||||
max_tokens: int = 512
|
||||
temperature: float = 0.7
|
||||
|
||||
|
||||
# ===== Mail Endpoints =====
|
||||
|
||||
@app.post("/api/mails/upload")
|
||||
async def upload_mails(files: List[UploadFile] = File(...)):
|
||||
"""Upload und Parse von Mail-Dateien"""
|
||||
results = {
|
||||
'success': [],
|
||||
'errors': []
|
||||
}
|
||||
|
||||
for file in files:
|
||||
try:
|
||||
# Temporär speichern
|
||||
temp_path = Path("data/temp") / file.filename
|
||||
temp_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(temp_path, 'wb') as f:
|
||||
content = await file.read()
|
||||
f.write(content)
|
||||
|
||||
# Parse Mails
|
||||
parsed_mails = MailParser.parse_file(temp_path)
|
||||
|
||||
# In DB speichern
|
||||
for mail in parsed_mails:
|
||||
mail_id = data_manager.add_mail(
|
||||
subject=mail['subject'],
|
||||
sender=mail['sender'],
|
||||
recipient=mail['recipient'],
|
||||
date=mail['date'],
|
||||
body=mail['body'],
|
||||
original_format=mail['original_format']
|
||||
)
|
||||
|
||||
results['success'].append({
|
||||
'filename': file.filename,
|
||||
'count': len(parsed_mails)
|
||||
})
|
||||
|
||||
# Cleanup
|
||||
temp_path.unlink()
|
||||
|
||||
except Exception as e:
|
||||
results['errors'].append({
|
||||
'filename': file.filename,
|
||||
'error': str(e)
|
||||
})
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@app.get("/api/mails")
|
||||
async def get_mails(status: Optional[str] = None):
|
||||
"""Liste aller Mails"""
|
||||
mails = data_manager.get_all_mails(status_filter=status)
|
||||
return {'mails': mails}
|
||||
|
||||
|
||||
@app.get("/api/mails/{mail_id}")
|
||||
async def get_mail(mail_id: int):
|
||||
"""Einzelne Mail abrufen"""
|
||||
mail = data_manager.get_mail(mail_id)
|
||||
if not mail:
|
||||
raise HTTPException(status_code=404, detail="Mail not found")
|
||||
return mail
|
||||
|
||||
|
||||
@app.put("/api/mails/{mail_id}")
|
||||
async def update_mail(mail_id: int, update: MailUpdate):
|
||||
"""Mail aktualisieren (Labeling)"""
|
||||
success = data_manager.update_mail(
|
||||
mail_id=mail_id,
|
||||
task_type=update.task_type,
|
||||
expected_output=update.expected_output,
|
||||
status=update.status,
|
||||
body=update.body
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Mail not found")
|
||||
|
||||
return {'success': True}
|
||||
|
||||
|
||||
@app.delete("/api/mails/{mail_id}")
|
||||
async def delete_mail(mail_id: int):
|
||||
"""Mail löschen"""
|
||||
success = data_manager.delete_mail(mail_id)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Mail not found")
|
||||
|
||||
return {'success': True}
|
||||
|
||||
|
||||
# ===== Export Endpoints =====
|
||||
|
||||
@app.get("/api/export/stats")
|
||||
async def get_stats():
|
||||
"""Statistiken abrufen"""
|
||||
stats = data_manager.get_statistics()
|
||||
return stats
|
||||
|
||||
|
||||
@app.post("/api/export/jsonl")
|
||||
async def export_jsonl(train_split: float = 0.9):
|
||||
"""Exportiert Training-Daten als JSONL"""
|
||||
train_data, val_data = data_manager.export_training_data(train_split)
|
||||
|
||||
if not train_data:
|
||||
raise HTTPException(status_code=400, detail="No labeled data available")
|
||||
|
||||
# Speichere Files
|
||||
data_dir = Path("data")
|
||||
train_file = data_dir / "train.jsonl"
|
||||
val_file = data_dir / "val.jsonl"
|
||||
|
||||
train_file_path, val_file_path = trainer.prepare_training_data(
|
||||
train_data, val_data, data_dir
|
||||
)
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'train_samples': len(train_data),
|
||||
'val_samples': len(val_data),
|
||||
'train_file': str(train_file),
|
||||
'val_file': str(val_file)
|
||||
}
|
||||
|
||||
|
||||
@app.get("/api/export/download/{file_type}")
|
||||
async def download_file(file_type: str):
|
||||
"""Download JSONL Files"""
|
||||
if file_type not in ['train', 'val']:
|
||||
raise HTTPException(status_code=400, detail="Invalid file type")
|
||||
|
||||
file_path = Path("data") / f"{file_type}.jsonl"
|
||||
|
||||
if not file_path.exists():
|
||||
raise HTTPException(status_code=404, detail="File not found")
|
||||
|
||||
return FileResponse(
|
||||
path=file_path,
|
||||
filename=f"{file_type}.jsonl",
|
||||
media_type='application/json'
|
||||
)
|
||||
|
||||
|
||||
# ===== Model Endpoints =====
|
||||
|
||||
@app.get("/api/models")
|
||||
async def list_models():
|
||||
"""Liste verfügbarer Modelle"""
|
||||
models = trainer.list_available_models()
|
||||
return {'models': models}
|
||||
|
||||
|
||||
@app.post("/api/models/download")
|
||||
async def download_model(model_name: str):
|
||||
"""
|
||||
Lädt ein Modell herunter
|
||||
Placeholder - würde in echter Implementation huggingface nutzen
|
||||
"""
|
||||
success = trainer.download_model(model_name)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=501,
|
||||
detail="Model download not implemented. Please download manually."
|
||||
)
|
||||
|
||||
return {'success': True}
|
||||
|
||||
|
||||
# ===== Training Endpoints =====
|
||||
|
||||
@app.post("/api/training/start")
|
||||
async def start_training(request: TrainingStartRequest, background_tasks: BackgroundTasks):
|
||||
"""Startet Training"""
|
||||
|
||||
# Hole Training-Daten
|
||||
train_data, val_data = data_manager.export_training_data()
|
||||
|
||||
if not train_data:
|
||||
raise HTTPException(status_code=400, detail="No labeled data available")
|
||||
|
||||
if len(train_data) < 10:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Not enough training data. Need at least 10, got {len(train_data)}"
|
||||
)
|
||||
|
||||
# Training Config
|
||||
config = TrainingConfig(
|
||||
model_name=request.model_name,
|
||||
learning_rate=request.learning_rate,
|
||||
epochs=request.epochs,
|
||||
batch_size=request.batch_size,
|
||||
lora_rank=request.lora_rank
|
||||
)
|
||||
|
||||
# Starte Training
|
||||
success = trainer.start_training(config, train_data, val_data)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=400, detail="Training already running")
|
||||
|
||||
return {'success': True, 'message': 'Training started'}
|
||||
|
||||
|
||||
@app.post("/api/training/stop")
|
||||
async def stop_training():
|
||||
"""Stoppt Training"""
|
||||
success = trainer.stop_training()
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=400, detail="No training running")
|
||||
|
||||
return {'success': True, 'message': 'Training stopped'}
|
||||
|
||||
|
||||
@app.get("/api/training/status")
|
||||
async def get_training_status():
|
||||
"""Gibt aktuellen Training-Status zurück"""
|
||||
status = trainer.get_status()
|
||||
return status
|
||||
|
||||
|
||||
@app.get("/api/training/stream")
|
||||
async def stream_training_status():
|
||||
"""
|
||||
Server-Sent Events für Live-Updates
|
||||
"""
|
||||
async def event_generator():
|
||||
while True:
|
||||
status = trainer.get_status()
|
||||
|
||||
# Sende Status als SSE
|
||||
yield f"data: {json.dumps(status)}\n\n"
|
||||
|
||||
# Stop wenn Training fertig
|
||||
if not status['is_training'] and status['current_step'] > 0:
|
||||
break
|
||||
|
||||
await asyncio.sleep(1)
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream"
|
||||
)
|
||||
|
||||
|
||||
# ===== Inference Endpoints =====
|
||||
|
||||
@app.post("/api/inference/load")
|
||||
async def load_model(model_type: str, model_name: str, adapter_path: Optional[str] = None):
|
||||
"""Lädt ein Modell für Inference"""
|
||||
|
||||
if model_type == 'base':
|
||||
success = inference.load_base_model(model_name)
|
||||
elif model_type == 'finetuned':
|
||||
if not adapter_path:
|
||||
raise HTTPException(status_code=400, detail="adapter_path required for finetuned model")
|
||||
success = inference.load_finetuned_model(model_name, adapter_path)
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="Invalid model_type")
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=400, detail="Failed to load model")
|
||||
|
||||
return {'success': True}
|
||||
|
||||
|
||||
@app.get("/api/inference/loaded")
|
||||
async def get_loaded_models():
|
||||
"""Gibt zurück welche Modelle geladen sind"""
|
||||
loaded = inference.get_loaded_models()
|
||||
return loaded
|
||||
|
||||
|
||||
@app.post("/api/inference/generate")
|
||||
async def generate_text(request: InferenceRequest):
|
||||
"""Generiert Text mit geladenem Modell"""
|
||||
result = inference.generate(
|
||||
prompt=request.prompt,
|
||||
model_type=request.model_type,
|
||||
max_tokens=request.max_tokens,
|
||||
temperature=request.temperature
|
||||
)
|
||||
|
||||
return {'result': result}
|
||||
|
||||
|
||||
@app.post("/api/inference/compare")
|
||||
async def compare_models(request: InferenceComparisonRequest):
|
||||
"""Vergleicht Base und Fine-tuned Model"""
|
||||
|
||||
prompt = inference.format_mail_prompt(
|
||||
request.task_type,
|
||||
request.mail_body
|
||||
)
|
||||
|
||||
result = inference.generate_comparison(
|
||||
prompt=prompt,
|
||||
max_tokens=request.max_tokens,
|
||||
temperature=request.temperature
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@app.get("/api/inference/test-prompts")
|
||||
async def get_test_prompts():
|
||||
"""Gibt vordefinierte Test-Prompts zurück"""
|
||||
prompts = inference.get_test_prompts()
|
||||
return prompts
|
||||
|
||||
|
||||
# ===== Static Files =====
|
||||
|
||||
# Serve Frontend
|
||||
app.mount("/", StaticFiles(directory="frontend", html=True), name="frontend")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
@@ -0,0 +1,321 @@
|
||||
"""
|
||||
MLX Training Wrapper für Fine-Tuning
|
||||
Nutzt mlx-lm für LoRA Fine-Tuning
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
import psutil
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Callable, Optional
|
||||
from dataclasses import dataclass
|
||||
import threading
|
||||
import queue
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingConfig:
|
||||
"""Training Konfiguration"""
|
||||
model_name: str
|
||||
learning_rate: float = 1e-5
|
||||
epochs: int = 3
|
||||
batch_size: int = 4
|
||||
lora_rank: int = 8
|
||||
lora_alpha: int = 16
|
||||
max_seq_length: int = 2048
|
||||
val_every: int = 50
|
||||
|
||||
|
||||
class TrainingStatus:
|
||||
"""Verwaltet den aktuellen Training-Status"""
|
||||
|
||||
def __init__(self):
|
||||
self.is_training = False
|
||||
self.should_stop = False
|
||||
self.current_step = 0
|
||||
self.total_steps = 0
|
||||
self.current_epoch = 0
|
||||
self.train_loss = 0.0
|
||||
self.val_loss = 0.0
|
||||
self.train_loss_history = []
|
||||
self.val_loss_history = []
|
||||
self.start_time = None
|
||||
self.error = None
|
||||
|
||||
def reset(self):
|
||||
"""Setzt den Status zurück"""
|
||||
self.is_training = False
|
||||
self.should_stop = False
|
||||
self.current_step = 0
|
||||
self.total_steps = 0
|
||||
self.current_epoch = 0
|
||||
self.train_loss = 0.0
|
||||
self.val_loss = 0.0
|
||||
self.train_loss_history = []
|
||||
self.val_loss_history = []
|
||||
self.start_time = None
|
||||
self.error = None
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
"""Konvertiert zu Dictionary für API"""
|
||||
eta = None
|
||||
if self.is_training and self.current_step > 0 and self.start_time:
|
||||
elapsed = time.time() - self.start_time
|
||||
steps_remaining = self.total_steps - self.current_step
|
||||
eta = int((elapsed / self.current_step) * steps_remaining)
|
||||
|
||||
memory_usage = psutil.virtual_memory().percent
|
||||
|
||||
return {
|
||||
'is_training': self.is_training,
|
||||
'current_step': self.current_step,
|
||||
'total_steps': self.total_steps,
|
||||
'current_epoch': self.current_epoch,
|
||||
'train_loss': round(self.train_loss, 4) if self.train_loss else None,
|
||||
'val_loss': round(self.val_loss, 4) if self.val_loss else None,
|
||||
'train_loss_history': [round(l, 4) for l in self.train_loss_history],
|
||||
'val_loss_history': [round(l, 4) for l in self.val_loss_history],
|
||||
'eta_seconds': eta,
|
||||
'memory_usage_percent': memory_usage,
|
||||
'error': self.error
|
||||
}
|
||||
|
||||
|
||||
class MLXTrainer:
|
||||
"""Wrapper für MLX Training"""
|
||||
|
||||
def __init__(self, models_dir: str = "models", output_dir: str = "output"):
|
||||
self.models_dir = Path(models_dir)
|
||||
self.output_dir = Path(output_dir)
|
||||
self.models_dir.mkdir(exist_ok=True)
|
||||
self.output_dir.mkdir(exist_ok=True)
|
||||
|
||||
self.status = TrainingStatus()
|
||||
self.training_thread = None
|
||||
|
||||
def prepare_training_data(self, train_data: List[Dict],
|
||||
val_data: List[Dict],
|
||||
data_dir: Path) -> tuple[Path, Path]:
|
||||
"""Konvertiert Daten ins MLX Format (JSONL)"""
|
||||
|
||||
def format_example(item: Dict) -> Dict:
|
||||
"""Formatiert ein Beispiel im Chat-Format"""
|
||||
task_type = item['task_type']
|
||||
body = item['body']
|
||||
output = item['expected_output']
|
||||
|
||||
# Task-spezifische Prompts
|
||||
task_prompts = {
|
||||
'Zusammenfassen': 'Fasse folgende E-Mail zusammen:',
|
||||
'Antwort schreiben': 'Schreibe eine Antwort auf folgende E-Mail:',
|
||||
'Kategorisieren': 'Kategorisiere folgende E-Mail:',
|
||||
'Action Items': 'Extrahiere die Action Items aus folgender E-Mail:',
|
||||
'Custom': 'Bearbeite folgende E-Mail:'
|
||||
}
|
||||
|
||||
instruction = task_prompts.get(task_type, task_prompts['Custom'])
|
||||
|
||||
return {
|
||||
'messages': [
|
||||
{
|
||||
'role': 'user',
|
||||
'content': f"{instruction}\n\n{body}"
|
||||
},
|
||||
{
|
||||
'role': 'assistant',
|
||||
'content': output
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
train_file = data_dir / 'train.jsonl'
|
||||
val_file = data_dir / 'val.jsonl'
|
||||
|
||||
# Schreibe Training Data
|
||||
with open(train_file, 'w', encoding='utf-8') as f:
|
||||
for item in train_data:
|
||||
f.write(json.dumps(format_example(item), ensure_ascii=False) + '\n')
|
||||
|
||||
# Schreibe Validation Data
|
||||
with open(val_file, 'w', encoding='utf-8') as f:
|
||||
for item in val_data:
|
||||
f.write(json.dumps(format_example(item), ensure_ascii=False) + '\n')
|
||||
|
||||
return train_file, val_file
|
||||
|
||||
def _run_training(self, config: TrainingConfig,
|
||||
train_file: Path, val_file: Path,
|
||||
output_path: Path):
|
||||
"""Führt das Training aus (läuft in eigenem Thread)"""
|
||||
try:
|
||||
# Import hier um MLX nur bei Bedarf zu laden
|
||||
from mlx_lm import load, LoRALinear
|
||||
from mlx_lm.tuner import train as mlx_train
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import mlx.optimizers as optim
|
||||
|
||||
self.status.is_training = True
|
||||
self.status.start_time = time.time()
|
||||
self.status.error = None
|
||||
|
||||
# Lade Modell
|
||||
model_path = self.models_dir / config.model_name
|
||||
if not model_path.exists():
|
||||
raise FileNotFoundError(f"Model not found: {model_path}")
|
||||
|
||||
# Training durchführen mit mlx-lm
|
||||
# Dies ist ein vereinfachtes Beispiel - mlx-lm hat eigene Trainer
|
||||
# In der Praxis würde man mlx_lm.tuner verwenden
|
||||
|
||||
# Lade Training Config
|
||||
train_config = {
|
||||
'model': str(model_path),
|
||||
'data': str(train_file),
|
||||
'val_data': str(val_file),
|
||||
'train': True,
|
||||
'iters': config.epochs * 100, # Approximation
|
||||
'val_batches': 10,
|
||||
'learning_rate': config.learning_rate,
|
||||
'batch_size': config.batch_size,
|
||||
'lora_layers': config.lora_rank,
|
||||
'adapter_file': str(output_path / 'adapters.npz'),
|
||||
'save_every': 50,
|
||||
'val_every': config.val_every,
|
||||
}
|
||||
|
||||
# Callback für Progress-Updates
|
||||
def training_callback(step: int, loss: float, val_loss: Optional[float] = None):
|
||||
if self.status.should_stop:
|
||||
return False # Stop training
|
||||
|
||||
self.status.current_step = step
|
||||
self.status.train_loss = loss
|
||||
self.status.train_loss_history.append(loss)
|
||||
|
||||
if val_loss is not None:
|
||||
self.status.val_loss = val_loss
|
||||
self.status.val_loss_history.append(val_loss)
|
||||
|
||||
return True
|
||||
|
||||
# Hinweis: Dies ist ein Platzhalter für echtes MLX Training
|
||||
# In der Praxis würde man mlx_lm.tuner.train() oder eine
|
||||
# eigene Training Loop mit mlx nutzen
|
||||
|
||||
# Simuliere Training für Demo (MUSS durch echtes MLX Training ersetzt werden)
|
||||
total_steps = config.epochs * (len(list(open(train_file))) // config.batch_size)
|
||||
self.status.total_steps = total_steps
|
||||
|
||||
for epoch in range(config.epochs):
|
||||
self.status.current_epoch = epoch + 1
|
||||
|
||||
for step in range(total_steps // config.epochs):
|
||||
if self.status.should_stop:
|
||||
break
|
||||
|
||||
# Simuliere Training Step
|
||||
self.status.current_step = epoch * (total_steps // config.epochs) + step
|
||||
fake_loss = 2.0 - (self.status.current_step / total_steps) * 1.5
|
||||
self.status.train_loss = fake_loss
|
||||
self.status.train_loss_history.append(fake_loss)
|
||||
|
||||
# Validation alle N Steps
|
||||
if step % config.val_every == 0:
|
||||
fake_val_loss = 2.2 - (self.status.current_step / total_steps) * 1.4
|
||||
self.status.val_loss = fake_val_loss
|
||||
self.status.val_loss_history.append(fake_val_loss)
|
||||
|
||||
time.sleep(0.1) # Simuliere Rechenzeit
|
||||
|
||||
if self.status.should_stop:
|
||||
break
|
||||
|
||||
# Speichere finale Adapter
|
||||
# output_path / 'adapters.npz' würde die LoRA Weights enthalten
|
||||
|
||||
self.status.is_training = False
|
||||
|
||||
except Exception as e:
|
||||
self.status.error = str(e)
|
||||
self.status.is_training = False
|
||||
|
||||
def start_training(self, config: TrainingConfig,
|
||||
train_data: List[Dict],
|
||||
val_data: List[Dict]) -> bool:
|
||||
"""Startet das Training"""
|
||||
|
||||
if self.status.is_training:
|
||||
return False
|
||||
|
||||
# Bereite Daten vor
|
||||
data_dir = self.output_dir / f"training_{int(time.time())}"
|
||||
data_dir.mkdir(exist_ok=True)
|
||||
|
||||
train_file, val_file = self.prepare_training_data(
|
||||
train_data, val_data, data_dir
|
||||
)
|
||||
|
||||
# Output-Pfad
|
||||
output_path = self.output_dir / f"run_{int(time.time())}"
|
||||
output_path.mkdir(exist_ok=True)
|
||||
|
||||
# Reset Status
|
||||
self.status.reset()
|
||||
|
||||
# Starte Training in eigenem Thread
|
||||
self.training_thread = threading.Thread(
|
||||
target=self._run_training,
|
||||
args=(config, train_file, val_file, output_path),
|
||||
daemon=True
|
||||
)
|
||||
self.training_thread.start()
|
||||
|
||||
return True
|
||||
|
||||
def stop_training(self) -> bool:
|
||||
"""Stoppt das laufende Training"""
|
||||
if not self.status.is_training:
|
||||
return False
|
||||
|
||||
self.status.should_stop = True
|
||||
|
||||
# Warte max 5 Sekunden auf Thread
|
||||
if self.training_thread:
|
||||
self.training_thread.join(timeout=5)
|
||||
|
||||
return True
|
||||
|
||||
def get_status(self) -> Dict:
|
||||
"""Gibt aktuellen Status zurück"""
|
||||
return self.status.to_dict()
|
||||
|
||||
def list_available_models(self) -> List[str]:
|
||||
"""Listet verfügbare Modelle auf"""
|
||||
if not self.models_dir.exists():
|
||||
return []
|
||||
|
||||
models = []
|
||||
for path in self.models_dir.iterdir():
|
||||
if path.is_dir():
|
||||
models.append(path.name)
|
||||
|
||||
return models
|
||||
|
||||
def download_model(self, model_name: str) -> bool:
|
||||
"""
|
||||
Lädt ein Modell herunter
|
||||
In der Praxis würde man hier huggingface_hub nutzen
|
||||
"""
|
||||
# Placeholder - würde huggingface_hub.snapshot_download nutzen
|
||||
# und dann mit mlx_lm.convert konvertieren
|
||||
|
||||
# Beispiel:
|
||||
# from huggingface_hub import snapshot_download
|
||||
# from mlx_lm.convert import convert
|
||||
#
|
||||
# hf_path = snapshot_download(model_name)
|
||||
# mlx_path = self.models_dir / model_name
|
||||
# convert(hf_path, mlx_path)
|
||||
|
||||
return False # Nicht implementiert in diesem Beispiel
|
||||
Reference in New Issue
Block a user