1456995462
Implemented a full-stack web application for fine-tuning LLMs on email data, optimized for Apple Silicon (M4 Pro with 24GB RAM). Features: - Mail import with drag & drop support (.mbox, .eml, .txt) - Automated mail cleaning and preprocessing - Interactive labeling interface with keyboard shortcuts - Training data export to JSONL format - MLX-based LoRA fine-tuning with live updates - Model evaluation and comparison interface - Server-Sent Events for real-time training progress - Dark theme UI optimized for extended use Technical Stack: - Backend: FastAPI with SQLite database - Frontend: Vanilla HTML/CSS/JavaScript (no external dependencies) - ML Framework: MLX for Apple Silicon optimization - Models: Support for Mistral 7B and Llama 3 8B via MLX Components: - data_manager.py: SQLite operations for mail storage and labeling - mail_parser.py: Parser for multiple mail formats with cleaning - training.py: MLX training wrapper with LoRA support - inference.py: Model loading and inference for evaluation - main.py: FastAPI backend with REST API and SSE - Frontend: Complete UI with all features Documentation: - Comprehensive README with installation and usage guide - Quick-start guide for rapid setup - Example mails for testing - Troubleshooting and best practices Ready for local deployment and fine-tuning workflows.
287 lines
8.5 KiB
Python
287 lines
8.5 KiB
Python
"""
|
|
Data manager for the mail fine-tuning app.

Manages the SQLite database for mails and labels.
|
|
"""
|
|
|
|
import sqlite3
|
|
import json
|
|
from datetime import datetime
|
|
from typing import List, Dict, Optional
|
|
from pathlib import Path
|
|
|
|
|
|
class DataManager:
    """SQLite-backed store for imported mails, their labels and training runs.

    Every public method opens its own short-lived connection to ``db_path``
    and closes it again before returning, including when a statement raises.

    NOTE: ``created_at`` (and the initial ``updated_at``) are filled in by
    SQLite's ``CURRENT_TIMESTAMP`` default, whereas later ``updated_at``,
    ``start_time`` and ``end_time`` values are written from
    ``datetime.now().isoformat()``; the two sources do not use the same
    textual format.
    """

    # Number of labeled mails from which get_statistics() reports
    # 'sufficient_data' as True.
    MIN_LABELED_MAILS = 50

    def __init__(self, db_path: str = "data/mails.db"):
        """Remember *db_path*, create its parent directory and the schema.

        Args:
            db_path: Location of the SQLite database file.
        """
        self.db_path = Path(db_path)
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self.init_db()

    def _connect(self, *, as_rows: bool = False) -> sqlite3.Connection:
        """Open a new connection to the database file.

        Args:
            as_rows: When True, rows are returned as ``sqlite3.Row`` objects
                so they can be converted to dicts keyed by column name.

        Returns:
            An open connection; the caller is responsible for closing it.
        """
        conn = sqlite3.connect(self.db_path)
        if as_rows:
            conn.row_factory = sqlite3.Row
        return conn

    def init_db(self):
        """Create the ``mails`` and ``training_runs`` tables if they are missing."""
        conn = self._connect()
        try:
            conn.execute("""
                CREATE TABLE IF NOT EXISTS mails (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    subject TEXT,
                    sender TEXT,
                    recipient TEXT,
                    date TEXT,
                    body TEXT NOT NULL,
                    original_format TEXT,
                    task_type TEXT DEFAULT 'unlabeled',
                    expected_output TEXT,
                    status TEXT DEFAULT 'unlabeled',
                    created_at TEXT DEFAULT CURRENT_TIMESTAMP,
                    updated_at TEXT DEFAULT CURRENT_TIMESTAMP
                )
            """)
            conn.execute("""
                CREATE TABLE IF NOT EXISTS training_runs (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    model_name TEXT NOT NULL,
                    start_time TEXT,
                    end_time TEXT,
                    config TEXT,
                    status TEXT,
                    final_train_loss REAL,
                    final_val_loss REAL,
                    checkpoint_path TEXT
                )
            """)
            conn.commit()
        finally:
            conn.close()

    def add_mail(self, subject: str, sender: str, recipient: str,
                 date: str, body: str, original_format: str) -> int:
        """Insert a new mail and return its row id.

        ``task_type`` and ``status`` are not set here, so the row receives
        the schema defaults (both ``'unlabeled'``).

        Returns:
            The ``id`` of the inserted row.
        """
        conn = self._connect()
        try:
            cursor = conn.execute("""
                INSERT INTO mails (subject, sender, recipient, date, body, original_format)
                VALUES (?, ?, ?, ?, ?, ?)
            """, (subject, sender, recipient, date, body, original_format))
            conn.commit()
            return cursor.lastrowid
        finally:
            conn.close()

    def get_all_mails(self, status_filter: Optional[str] = None) -> List[Dict]:
        """Return all mails ordered by id, optionally restricted to one status.

        Args:
            status_filter: If truthy, only mails whose ``status`` equals this
                value are returned; ``None`` or ``""`` returns every mail.

        Returns:
            One dict per mail, keyed by column name.
        """
        conn = self._connect(as_rows=True)
        try:
            if status_filter:
                cursor = conn.execute(
                    "SELECT * FROM mails WHERE status = ? ORDER BY id",
                    (status_filter,),
                )
            else:
                cursor = conn.execute("SELECT * FROM mails ORDER BY id")
            return [dict(row) for row in cursor.fetchall()]
        finally:
            conn.close()

    def get_mail(self, mail_id: int) -> Optional[Dict]:
        """Return the mail with the given id as a dict, or None if it does not exist."""
        conn = self._connect(as_rows=True)
        try:
            row = conn.execute(
                "SELECT * FROM mails WHERE id = ?", (mail_id,)
            ).fetchone()
            return dict(row) if row else None
        finally:
            conn.close()

    def update_mail(self, mail_id: int, task_type: Optional[str] = None,
                    expected_output: Optional[str] = None,
                    status: Optional[str] = None,
                    body: Optional[str] = None) -> bool:
        """Update the labeling fields of a mail.

        Only arguments that are not ``None`` are written; ``updated_at`` is
        refreshed whenever at least one field is written.

        Returns:
            True if a row was updated; False if no field was supplied or no
            mail with ``mail_id`` exists.
        """
        candidates = (
            ("task_type", task_type),
            ("expected_output", expected_output),
            ("status", status),
            ("body", body),
        )
        changes = [(column, value) for column, value in candidates
                   if value is not None]
        if not changes:
            # Nothing to write, so there is no need to open a connection.
            return False

        changes.append(("updated_at", datetime.now().isoformat()))
        # Column names come from the fixed tuple above, never from callers,
        # so interpolating them into the statement is safe; all values are
        # bound as parameters.
        assignments = ", ".join(f"{column} = ?" for column, _ in changes)
        params = [value for _, value in changes]
        params.append(mail_id)

        conn = self._connect()
        try:
            cursor = conn.execute(
                f"UPDATE mails SET {assignments} WHERE id = ?", params
            )
            conn.commit()
            return cursor.rowcount > 0
        finally:
            conn.close()

    def delete_mail(self, mail_id: int) -> bool:
        """Delete a mail; return True if a row was actually removed."""
        conn = self._connect()
        try:
            cursor = conn.execute("DELETE FROM mails WHERE id = ?", (mail_id,))
            conn.commit()
            return cursor.rowcount > 0
        finally:
            conn.close()

    def get_statistics(self) -> Dict:
        """Compute summary statistics over the stored mails.

        Returns:
            A dict with the keys ``total``, ``labeled``, ``unlabeled``,
            ``skipped`` (mails whose status is ``'skip'``),
            ``task_distribution`` (task_type -> count, labeled mails only),
            ``avg_input_length`` / ``avg_output_length`` (rounded average
            length of ``body`` / ``expected_output`` over labeled mails,
            0 when there are none) and ``sufficient_data`` (True once at
            least ``MIN_LABELED_MAILS`` mails are labeled).
        """
        conn = self._connect()
        try:
            # Total number of mails
            total = conn.execute("SELECT COUNT(*) FROM mails").fetchone()[0]

            # Counts per status
            status_counts = dict(conn.execute("""
                SELECT status, COUNT(*) as count
                FROM mails
                GROUP BY status
            """).fetchall())

            # Counts per task type (labeled mails only)
            task_counts = dict(conn.execute("""
                SELECT task_type, COUNT(*) as count
                FROM mails
                WHERE status = 'labeled'
                GROUP BY task_type
            """).fetchall())

            # Average lengths (labeled mails only); both values are NULL
            # when no mail is labeled.
            avg_input, avg_output = conn.execute("""
                SELECT
                    AVG(LENGTH(body)) as avg_input_length,
                    AVG(LENGTH(expected_output)) as avg_output_length
                FROM mails
                WHERE status = 'labeled'
            """).fetchone()
        finally:
            conn.close()

        labeled_count = status_counts.get('labeled', 0)

        return {
            'total': total,
            'labeled': labeled_count,
            'unlabeled': status_counts.get('unlabeled', 0),
            'skipped': status_counts.get('skip', 0),
            'task_distribution': task_counts,
            'avg_input_length': round(avg_input) if avg_input else 0,
            'avg_output_length': round(avg_output) if avg_output else 0,
            'sufficient_data': labeled_count >= self.MIN_LABELED_MAILS,
        }

    def export_training_data(self, train_split: float = 0.9) -> tuple[List[Dict], List[Dict]]:
        """Export labeled mails as shuffled training and validation sets.

        Only mails with status ``'labeled'`` and a non-NULL
        ``expected_output`` are exported. Rows are read in id order and then
        shuffled once with the ``random`` module, so seeding ``random``
        beforehand makes the split reproducible.

        Args:
            train_split: Fraction of the examples that goes into the training
                set; the cut index is ``int(len(data) * train_split)``.

        Returns:
            ``(train_data, val_data)``; each item is a dict with the keys
            ``body``, ``task_type`` and ``expected_output``. Both lists are
            empty when nothing is labeled.
        """
        import random

        conn = self._connect(as_rows=True)
        try:
            rows = conn.execute("""
                SELECT body, task_type, expected_output
                FROM mails
                WHERE status = 'labeled' AND expected_output IS NOT NULL
                ORDER BY id
            """).fetchall()
        finally:
            conn.close()

        if not rows:
            return [], []

        data = [dict(row) for row in rows]

        # Single shuffle, driven by Python's RNG only.
        random.shuffle(data)

        split_idx = int(len(data) * train_split)
        return data[:split_idx], data[split_idx:]

    def save_training_run(self, model_name: str, config: Dict,
                          checkpoint_path: str) -> int:
        """Record the start of a training run and return its id.

        The run is stored with status ``'running'``, the current time as
        ``start_time`` and *config* serialized with ``json.dumps``.

        Returns:
            The ``id`` of the new ``training_runs`` row.
        """
        conn = self._connect()
        try:
            cursor = conn.execute("""
                INSERT INTO training_runs
                (model_name, start_time, config, status, checkpoint_path)
                VALUES (?, ?, ?, ?, ?)
            """, (
                model_name,
                datetime.now().isoformat(),
                json.dumps(config),
                'running',
                checkpoint_path,
            ))
            conn.commit()
            return cursor.lastrowid
        finally:
            conn.close()

    def update_training_run(self, run_id: int, status: str,
                            train_loss: Optional[float] = None,
                            val_loss: Optional[float] = None):
        """Update the status and, optionally, the final losses of a run.

        ``end_time`` is overwritten with the current time on every call,
        whatever *status* is. A loss passed as ``None`` leaves the stored
        value untouched (``COALESCE``).
        """
        conn = self._connect()
        try:
            conn.execute("""
                UPDATE training_runs
                SET status = ?,
                    end_time = ?,
                    final_train_loss = COALESCE(?, final_train_loss),
                    final_val_loss = COALESCE(?, final_val_loss)
                WHERE id = ?
            """, (status, datetime.now().isoformat(), train_loss, val_loss, run_id))
            conn.commit()
        finally:
            conn.close()
|