What is Weights and Biases?
Weights and Biases (W&B, or wandb) is an ML platform for experiment tracking, model versioning, dataset management, and collaboration. It is one of the most widely used tools among ML engineers and data scientists for tracking experiments.
W&B's main features include Experiment Tracking (automatic logging of metrics, hyperparameters, and system info), Sweeps (hyperparameter optimization), Artifacts (versioned datasets and models), Reports (interactive reports for collaboration), Tables (visualizing and querying datasets), Model Registry (managing the model lifecycle), and Launch (distributed training orchestration).
An ML observability stack is a system that monitors both the training pipeline (experiment metrics, resource usage, data quality) and production inference (prediction latency, model drift, data drift, accuracy degradation). W&B is one of the core components of such a stack.
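Most of these features appear in the examples later in this article; Tables, which the later sections do not cover, can be logged in a few lines. The sketch below is illustrative only (the project name and rows are placeholders, not part of any real setup):
# Sketch: logging a W&B Table (project name and rows are placeholders)
import wandb

run = wandb.init(project="tables-demo")
predictions = wandb.Table(
    columns=["request_id", "prediction", "label"],
    data=[["req-0001", 0.93, 1], ["req-0002", 0.12, 0]],
)
wandb.log({"eval/predictions": predictions})  # browsable and queryable in the W&B UI
run.finish()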
Installing and Getting Started with W&B
How to install and configure W&B
# === Installing Weights & Biases ===
# 1. Install wandb
pip install wandb
# 2. Login (requires an API key from wandb.ai)
wandb login
# or set an environment variable
export WANDB_API_KEY="your-api-key-here"
# 3. Self-hosted W&B Server (optional)
# Docker Compose
cat > docker-compose.yml << 'EOF'
version: '3'
services:
  wandb:
    image: wandb/local:latest
    ports:
      - "8080:8080"
    environment:
      - MYSQL_HOST=db
      - MYSQL_PORT=3306
      - MYSQL_DATABASE=wandb
      - MYSQL_USER=wandb
      - MYSQL_PASSWORD=wandb_pass
      - BUCKET=s3://wandb-artifacts
    volumes:
      - wandb_data:/vol
  db:
    image: mysql:8.0
    environment:
      MYSQL_ROOT_PASSWORD: root_pass
      MYSQL_DATABASE: wandb
      MYSQL_USER: wandb
      MYSQL_PASSWORD: wandb_pass
    volumes:
      - mysql_data:/var/lib/mysql
volumes:
  wandb_data:
  mysql_data:
EOF
docker-compose up -d
# 4. Basic Usage Test
python3 << 'PYEOF'
import wandb
import random
# Initialize a run
run = wandb.init(
    project="test-project",
    config={
        "learning_rate": 0.001,
        "epochs": 10,
        "batch_size": 32,
        "model": "resnet50",
    },
)
# Simulate training
for epoch in range(10):
    loss = 2.0 * (0.9 ** epoch) + random.gauss(0, 0.1)
    acc = 0.5 + 0.05 * epoch + random.gauss(0, 0.02)
    wandb.log({
        "epoch": epoch,
        "train/loss": loss,
        "train/accuracy": min(acc, 1.0),
        "val/loss": loss * 1.1,
        "val/accuracy": min(acc * 0.95, 1.0),
        "learning_rate": 0.001 * (0.95 ** epoch),
    })
wandb.finish()
print("W&B test run complete")
PYEOF
echo "W&B installed and tested"
Experiment Tracking with W&B
Track experiments systematically
#!/usr/bin/env python3
# experiment_tracker.py — W&B Experiment Tracking
import json
import logging
import time
from datetime import datetime
from typing import Dict, List, Optional
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("tracker")
class ExperimentTracker:
    def __init__(self, project_name, entity=None):
        self.project = project_name
        self.entity = entity
        self.experiments = []

    def create_experiment(self, name, config):
        # import wandb
        # run = wandb.init(
        #     project=self.project,
        #     entity=self.entity,
        #     name=name,
        #     config=config,
        #     tags=config.get("tags", []),
        # )
        experiment = {
            "name": name,
            "config": config,
            "metrics": [],
            "start_time": datetime.utcnow().isoformat(),
            "status": "running",
        }
        self.experiments.append(experiment)
        logger.info(f"Created experiment: {name}")
        return experiment

    def log_metrics(self, experiment, step, metrics):
        metrics["step"] = step
        metrics["timestamp"] = datetime.utcnow().isoformat()
        experiment["metrics"].append(metrics)
        # wandb.log(metrics, step=step)

    def log_model(self, experiment, model_path, metadata=None):
        # artifact = wandb.Artifact(
        #     name=f"model-{experiment['name']}",
        #     type="model",
        #     metadata=metadata,
        # )
        # artifact.add_file(model_path)
        # wandb.log_artifact(artifact)
        experiment["model_path"] = model_path
        experiment["model_metadata"] = metadata
        logger.info(f"Logged model: {model_path}")

    def log_dataset(self, name, path, metadata=None):
        # artifact = wandb.Artifact(name=name, type="dataset", metadata=metadata)
        # artifact.add_dir(path)
        # wandb.log_artifact(artifact)
        logger.info(f"Logged dataset: {name} from {path}")

    def finish_experiment(self, experiment):
        experiment["status"] = "completed"
        experiment["end_time"] = datetime.utcnow().isoformat()
        if experiment["metrics"]:
            losses = [m.get("train/loss", float('inf')) for m in experiment["metrics"]]
            accs = [m.get("train/accuracy", 0) for m in experiment["metrics"]]
            experiment["summary"] = {
                "best_loss": min(losses),
                "best_accuracy": max(accs),
                "total_steps": len(experiment["metrics"]),
            }
        # wandb.finish()
        logger.info(f"Experiment completed: {experiment['name']}")

    def compare_experiments(self):
        results = []
        for exp in self.experiments:
            if exp.get("summary"):
                results.append({
                    "name": exp["name"],
                    "config": {k: v for k, v in exp["config"].items() if k != "tags"},
                    "best_loss": exp["summary"]["best_loss"],
                    "best_accuracy": exp["summary"]["best_accuracy"],
                })
        results.sort(key=lambda x: x["best_loss"])
        return results
# Hyperparameter Sweep Configuration
SWEEP_CONFIG = {
    "method": "bayes",
    "metric": {"name": "val/loss", "goal": "minimize"},
    "parameters": {
        "learning_rate": {"distribution": "log_uniform_values", "min": 1e-5, "max": 1e-2},
        "batch_size": {"values": [16, 32, 64, 128]},
        "optimizer": {"values": ["adam", "adamw", "sgd"]},
        "dropout": {"distribution": "uniform", "min": 0.0, "max": 0.5},
        "hidden_size": {"values": [128, 256, 512]},
    },
}
# Example
tracker = ExperimentTracker("ml-observability")
configs = [
    {"learning_rate": 0.001, "batch_size": 32, "optimizer": "adam"},
    {"learning_rate": 0.0005, "batch_size": 64, "optimizer": "adamw"},
]
import random
random.seed(42)
for i, config in enumerate(configs):
    exp = tracker.create_experiment(f"run-{i}", config)
    for step in range(50):
        loss = 2.0 * (0.95 ** step) + random.gauss(0, 0.05) - i * 0.1
        acc = 0.5 + 0.01 * step + random.gauss(0, 0.01)
        tracker.log_metrics(exp, step, {
            "train/loss": max(loss, 0.01),
            "train/accuracy": min(max(acc, 0), 1),
        })
    tracker.finish_experiment(exp)
print(json.dumps(tracker.compare_experiments(), indent=2))
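SWEEP_CONFIG above uses W&B's sweep format. As a rough sketch of how it could be launched (the train() stub, project name, and trial count here are illustrative, not prescribed by W&B):
# Sketch: launching the sweep defined in SWEEP_CONFIG (train() stub and count are illustrative)
import wandb

def train():
    # Each agent run receives its hyperparameters via wandb.config
    run = wandb.init()
    config = wandb.config
    for step in range(10):
        # Replace with a real training step using config.learning_rate, config.batch_size, ...
        val_loss = (1.0 / (step + 1)) * config.learning_rate * 1000
        wandb.log({"val/loss": val_loss}, step=step)
    wandb.finish()

sweep_id = wandb.sweep(SWEEP_CONFIG, project="ml-observability")
wandb.agent(sweep_id, function=train, count=20)  # run 20 trials with the Bayesian search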
Building an Observability Stack for ML
An end-to-end observability system
#!/usr/bin/env python3
# ml_observability.py — ML Observability Stack
import json
import logging
import hashlib
from datetime import datetime
from typing import Dict, List, Optional
from dataclasses import dataclass, field
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("observability")
@dataclass
class PredictionLog:
    request_id: str
    model_name: str
    model_version: str
    input_features: Dict
    prediction: float
    latency_ms: float
    timestamp: str

class MLObservabilityStack:
    def __init__(self):
        self.predictions: List[PredictionLog] = []
        self.feature_baselines: Dict = {}
        self.prediction_baselines: Dict = {}

    def set_baseline(self, model_name, feature_stats, prediction_stats):
        self.feature_baselines[model_name] = feature_stats
        self.prediction_baselines[model_name] = prediction_stats

    def log_prediction(self, log: PredictionLog):
        self.predictions.append(log)
        # wandb.log({
        #     "inference/latency_ms": log.latency_ms,
        #     "inference/prediction": log.prediction,
        #     "inference/model_version": log.model_version,
        # })

    def detect_data_drift(self, model_name, recent_window=100):
        model_preds = [p for p in self.predictions if p.model_name == model_name]
        if len(model_preds) < recent_window:
            return {"status": "insufficient_data"}
        recent = model_preds[-recent_window:]
        baseline = self.feature_baselines.get(model_name, {})
        if not baseline:
            return {"status": "no_baseline"}
        drift_results = {}
        for feature_name, baseline_stats in baseline.items():
            recent_values = [p.input_features.get(feature_name, 0) for p in recent]
            if not recent_values:
                continue
            recent_mean = sum(recent_values) / len(recent_values)
            recent_std = (sum((x - recent_mean) ** 2 for x in recent_values) / len(recent_values)) ** 0.5
            baseline_mean = baseline_stats.get("mean", 0)
            baseline_std = baseline_stats.get("std", 1)
            # Simple drift detection: z-score of mean shift
            if baseline_std > 0:
                z_score = abs(recent_mean - baseline_mean) / baseline_std
            else:
                z_score = 0
            drift_results[feature_name] = {
                "baseline_mean": round(baseline_mean, 4),
                "recent_mean": round(recent_mean, 4),
                "z_score": round(z_score, 2),
                "drifted": z_score > 2.0,
            }
        drifted_features = [f for f, r in drift_results.items() if r["drifted"]]
        return {
            "model": model_name,
            "window_size": recent_window,
            "features_checked": len(drift_results),
            "features_drifted": len(drifted_features),
            "drifted_features": drifted_features,
            "status": "drift_detected" if drifted_features else "no_drift",
            "details": drift_results,
        }

    def detect_prediction_drift(self, model_name, recent_window=100):
        model_preds = [p for p in self.predictions if p.model_name == model_name]
        if len(model_preds) < recent_window:
            return {"status": "insufficient_data"}
        recent = model_preds[-recent_window:]
        baseline = self.prediction_baselines.get(model_name, {})
        recent_preds = [p.prediction for p in recent]
        recent_mean = sum(recent_preds) / len(recent_preds)
        baseline_mean = baseline.get("mean", 0)
        baseline_std = baseline.get("std", 1)
        z_score = abs(recent_mean - baseline_mean) / max(baseline_std, 0.001)
        return {
            "model": model_name,
            "baseline_mean": round(baseline_mean, 4),
            "recent_mean": round(recent_mean, 4),
            "z_score": round(z_score, 2),
            "drifted": z_score > 2.0,
        }

    def get_performance_metrics(self, model_name, window=100):
        model_preds = [p for p in self.predictions if p.model_name == model_name][-window:]
        if not model_preds:
            return {}
        latencies = [p.latency_ms for p in model_preds]
        latencies.sort()
        return {
            "model": model_name,
            "total_predictions": len(model_preds),
            "avg_latency_ms": round(sum(latencies) / len(latencies), 1),
            "p50_latency_ms": round(latencies[len(latencies) // 2], 1),
            "p95_latency_ms": round(latencies[int(len(latencies) * 0.95)], 1),
            "p99_latency_ms": round(latencies[int(len(latencies) * 0.99)], 1),
            "throughput_rps": round(len(latencies) / (sum(latencies) / 1000), 1),
        }
# Example
import random
random.seed(42)
stack = MLObservabilityStack()
stack.set_baseline(
    "fraud-detector",
    {"amount": {"mean": 100, "std": 50}, "frequency": {"mean": 5, "std": 2}},
    {"mean": 0.15, "std": 0.1},
)
for i in range(200):
    drift = 20 if i > 150 else 0  # Inject drift after 150 predictions
    stack.log_prediction(PredictionLog(
        request_id=f"req-{i:04d}",
        model_name="fraud-detector",
        model_version="v1.2",
        input_features={"amount": 100 + drift + random.gauss(0, 50), "frequency": 5 + random.gauss(0, 2)},
        prediction=random.random() * 0.3,
        latency_ms=random.gauss(15, 3),
        timestamp=datetime.utcnow().isoformat(),
    ))
print("Data Drift:", json.dumps(stack.detect_data_drift("fraud-detector"), indent=2))
print("Performance:", json.dumps(stack.get_performance_metrics("fraud-detector"), indent=2))
Model Registry and Artifact Management
Managing models and artifacts
# === W&B Model Registry & Artifacts ===
# 1. Log Model as Artifact
# ===================================
# import wandb
#
# run = wandb.init(project="my-project")
#
# # Log model artifact
# model_artifact = wandb.Artifact(
# name="fraud-detector",
# type="model",
# description="Fraud detection model v1.2",
# metadata={
# "framework": "pytorch",
# "accuracy": 0.95,
# "f1_score": 0.92,
# "training_data": "dataset-v3",
# "features": ["amount", "frequency", "location"],
# }
# )
# model_artifact.add_file("model.pt")
# model_artifact.add_file("config.json")
# run.log_artifact(model_artifact)
# 2. Link to Model Registry
# ===================================
# run.link_artifact(
# model_artifact,
# "model-registry/fraud-detector",
# aliases=["latest", "v1.2", "production"]
# )
# 3. Download Artifact for Inference
# ===================================
# run = wandb.init(project="my-project")
# artifact = run.use_artifact("fraud-detector:production")
# artifact_dir = artifact.download()
# # Load model from artifact_dir
# 4. Dataset Versioning
# ===================================
# dataset_artifact = wandb.Artifact(
# name="training-data",
# type="dataset",
# metadata={
# "rows": 100000,
# "features": 50,
# "source": "postgresql",
# "split": "train",
# }
# )
# dataset_artifact.add_dir("data/train/")
# run.log_artifact(dataset_artifact)
# 5. Model Promotion Pipeline
# ===================================
#!/usr/bin/env python3
# model_promotion.py
import json
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("promotion")
class ModelRegistry:
    def __init__(self):
        self.models = {}
        self.promotion_history = []

    def register(self, name, version, metrics, artifacts):
        key = f"{name}:{version}"
        self.models[key] = {
            "name": name,
            "version": version,
            "metrics": metrics,
            "artifacts": artifacts,
            "stage": "staging",
            "registered_at": "2025-01-15T10:00:00",
        }
        logger.info(f"Registered: {key}")

    def promote(self, name, version, target_stage, approver=None):
        key = f"{name}:{version}"
        model = self.models.get(key)
        if not model:
            return {"error": f"Model not found: {key}"}
        # Validation checks
        checks = self._run_promotion_checks(model, target_stage)
        if not all(c["passed"] for c in checks):
            failed = [c for c in checks if not c["passed"]]
            return {"error": "Promotion checks failed", "failed_checks": failed}
        old_stage = model["stage"]
        model["stage"] = target_stage
        self.promotion_history.append({
            "model": key,
            "from_stage": old_stage,
            "to_stage": target_stage,
            "approver": approver,
            "checks": checks,
        })
        logger.info(f"Promoted {key}: {old_stage} -> {target_stage}")
        return {"status": "promoted", "model": key, "stage": target_stage}

    def _run_promotion_checks(self, model, target_stage):
        checks = []
        metrics = model.get("metrics", {})
        if target_stage == "production":
            checks.append({
                "check": "accuracy >= 0.90",
                "passed": metrics.get("accuracy", 0) >= 0.90,
                "value": metrics.get("accuracy"),
            })
            checks.append({
                "check": "latency_p99 <= 100ms",
                "passed": metrics.get("latency_p99_ms", 999) <= 100,
                "value": metrics.get("latency_p99_ms"),
            })
            checks.append({
                "check": "test_coverage >= 80%",
                "passed": metrics.get("test_coverage", 0) >= 80,
                "value": metrics.get("test_coverage"),
            })
        return checks
registry = ModelRegistry()
registry.register("fraud-detector", "v1.2",
{"accuracy": 0.95, "f1": 0.92, "latency_p99_ms": 45, "test_coverage": 85},
["model.pt", "config.json"])
result = registry.promote("fraud-detector", "v1.2", "production", approver="admin")
print(json.dumps(result, indent=2))
Production Monitoring and Alerting
Monitoring ML models in production
# === Production ML Monitoring ===
# 1. Prometheus Metrics for ML
# ===================================
# Custom metrics:
# ml_prediction_latency_seconds{model="fraud-detector", version="v1.2"}
# ml_prediction_total{model="fraud-detector", result="fraud"}
# ml_prediction_total{model="fraud-detector", result="legit"}
# ml_data_drift_score{model="fraud-detector", feature="amount"}
# ml_model_accuracy{model="fraud-detector"}
# 2. Grafana Dashboard Queries
# ===================================
# Prediction Rate
# rate(ml_prediction_total[5m])
# P99 Latency
# histogram_quantile(0.99, rate(ml_prediction_latency_seconds_bucket[5m]))
# Error Rate
# rate(ml_prediction_errors_total[5m]) / rate(ml_prediction_total[5m])
# Data Drift Score
# ml_data_drift_score > 2.0
# 3. Alert Rules
# ===================================
# groups:
#   - name: ml-monitoring
#     rules:
#       - alert: HighPredictionLatency
#         expr: histogram_quantile(0.99, rate(ml_prediction_latency_seconds_bucket[5m])) > 0.1
#         for: 5m
#         labels:
#           severity: warning
#
#       - alert: DataDriftDetected
#         expr: ml_data_drift_score > 2.0
#         for: 15m
#         labels:
#           severity: warning
#         annotations:
#           summary: "Data drift detected for {{ $labels.feature }}"
#
#       - alert: ModelAccuracyDrop
#         expr: ml_model_accuracy < 0.85
#         for: 30m
#         labels:
#           severity: critical
#
#       - alert: PredictionVolumeAnomaly
#         expr: |
#           abs(rate(ml_prediction_total[1h]) -
#             avg_over_time(rate(ml_prediction_total[1h])[7d:1h]))
#           > 2 * stddev_over_time(rate(ml_prediction_total[1h])[7d:1h])
#         for: 15m
#         labels:
#           severity: info
# 4. W&B Integration for Production
# ===================================
# import wandb
#
# # Log production metrics to W&B
# wandb.init(project="production-monitoring", job_type="inference")
#
# # Periodic logging
# wandb.log({
# "production/latency_p99": 45,
# "production/request_rate": 150,
# "production/error_rate": 0.001,
# "production/drift_score": 1.2,
# "production/accuracy_estimate": 0.94,
# })
# 5. Automated Retraining Trigger
# ===================================
# if drift_score > threshold:
# trigger_retraining_pipeline()
# notify_team("Data drift detected, retraining initiated")
echo "ML production monitoring configured"
Frequently Asked Questions (FAQ)
Q: Is W&B free?
A: W&B has a free tier for personal use with unlimited experiments and 100 GB of artifact storage, which is enough for individual researchers and small teams. The Team plan starts at $50/user/month and adds collaboration features, and the Enterprise plan covers on-premises deployment and advanced security. W&B is free for open source projects, and the self-hosted version (W&B Server) can be deployed on your own infrastructure.
Q: How does W&B differ from MLflow?
A: MLflow is open source and can be self-hosted for free; it provides experiment tracking, a model registry, and model serving. W&B is SaaS-first, with a considerably more polished UI, stronger collaboration features (reports, teams), easier-to-use hyperparameter sweeps, and more framework integrations (PyTorch, TensorFlow, Hugging Face, LangChain). W&B is generally stronger for experiment tracking and visualization, while MLflow is stronger for model serving and deployment. Teams that must self-host usually pick MLflow; if the budget allows, W&B tends to win on productivity.
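For a sense of how similar the two tracking APIs are, here is the same tiny run logged to each (the project/experiment names are placeholders):
# Sketch: the same run tracked with W&B and MLflow (names are placeholders)
import wandb
import mlflow

params = {"learning_rate": 0.001, "batch_size": 32}

# W&B
run = wandb.init(project="compare-demo", config=params)
for step in range(5):
    wandb.log({"loss": 1.0 / (step + 1)}, step=step)
wandb.finish()

# MLflow
mlflow.set_experiment("compare-demo")
with mlflow.start_run():
    mlflow.log_params(params)
    for step in range(5):
        mlflow.log_metric("loss", 1.0 / (step + 1), step=step)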
Q: How do you detect model drift?
A: There are several approaches. Data drift: detect distribution shift in the input features with statistical tests (KS test, PSI, Wasserstein distance) by comparing recent data against the training data. Concept drift: detect changes in the relationship between features and the target, using ground truth labels (if available) or proxy metrics. Prediction drift: detect changes in the distribution of the model's predictions; this is the easiest to implement. W&B offers built-in drift monitoring in its enterprise plan.
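As an illustration of the statistical tests mentioned above, here is a simplified sketch of a KS test and a PSI calculation on one feature; the synthetic data and the PSI > 0.2 threshold are common rules of thumb, not a fixed standard:
# Sketch: data drift via KS test and PSI (synthetic data; thresholds are rules of thumb)
import numpy as np
from scipy.stats import ks_2samp

def psi(expected, actual, bins=10):
    # Population Stability Index over bins derived from the reference distribution;
    # actual values outside the reference range are ignored in this simplified version
    edges = np.histogram_bin_edges(expected, bins=bins)
    e_pct = np.histogram(expected, bins=edges)[0] / len(expected)
    a_pct = np.histogram(actual, bins=edges)[0] / len(actual)
    e_pct = np.clip(e_pct, 1e-6, None)  # avoid log(0) and division by zero
    a_pct = np.clip(a_pct, 1e-6, None)
    return float(np.sum((a_pct - e_pct) * np.log(a_pct / e_pct)))

rng = np.random.default_rng(0)
train_amounts = rng.normal(100, 50, 10_000)   # training-time feature distribution
recent_amounts = rng.normal(120, 50, 1_000)   # recent production data (shifted)

ks_stat, p_value = ks_2samp(train_amounts, recent_amounts)
print(f"KS statistic={ks_stat:.3f}, p-value={p_value:.4f}")  # small p-value suggests drift
print(f"PSI={psi(train_amounts, recent_amounts):.3f}")       # PSI > 0.2 is often treated as significant drift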
Q: What does an ML observability stack need to include?
A: A complete stack includes Training Observability (W&B/MLflow for experiment tracking), Data Quality (Great Expectations or dbt tests for validating data), a Model Registry (W&B Registry/MLflow for versioning models), Inference Monitoring (Prometheus + Grafana for latency, throughput, and errors), Drift Detection (Evidently AI or NannyML for data/prediction drift), Alerting (PagerDuty or Slack for notifications), and Logging (ELK Stack/Loki for prediction logs).
