What is MLflow Experiment Feature Flag Management?
MLflow is an open source platform for managing the full machine learning lifecycle, from experiment tracking and model packaging to deployment and the model registry. Feature flag management is the practice of turning features on and off in production without redeploying. Combining the two lets ML teams A/B test models, canary-deploy ML features, and roll back models safely, with feature flags controlling which model is served.
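As a first illustration of the idea, here is a minimal sketch of a boolean flag deciding which registered model gets loaded. The flag name USE_V2_MODEL, the environment-variable mechanism, and the model URIs are illustrative assumptions, not tied to any particular flag provider.
# flag_gated_serving.py - sketch: a boolean flag picks which model version gets served
import os
import mlflow.pyfunc

# Illustrative flag read from an environment variable; a real setup would query a
# flag service (LaunchDarkly, Flagsmith, Unleash, ...) instead of os.getenv.
USE_V2_MODEL = os.getenv("USE_V2_MODEL", "false").lower() == "true"

model_uri = (
    "models:/churn-model/Staging" if USE_V2_MODEL
    else "models:/churn-model/Production"
)
model = mlflow.pyfunc.load_model(model_uri)  # loads whichever version the flag selects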
MLflow Architecture
# mlflow_arch.py — MLflow architecture
import json
class MLflowArchitecture:
COMPONENTS = {
"tracking": {
"name": "MLflow Tracking",
"description": "บันทึก parameters, metrics, artifacts ของทุก experiment run",
"use": "เปรียบเทียบ models, reproduce experiments",
},
"projects": {
"name": "MLflow Projects",
"description": "Package ML code เป็น reproducible format (conda/docker)",
"use": "แชร์ experiments ข้าม team, reproduce results",
},
"models": {
"name": "MLflow Models",
"description": "Package models เป็น standard format — deploy ได้ทุก platform",
"flavors": ["sklearn", "pytorch", "tensorflow", "xgboost", "langchain", "custom"],
},
"registry": {
"name": "Model Registry",
"description": "Version control สำหรับ models — staging, production, archived",
"stages": ["None", "Staging", "Production", "Archived"],
},
}
SETUP = """
# MLflow setup
pip install mlflow
# Start tracking server
mlflow server --host 0.0.0.0 --port 5000 \\
--backend-store-uri postgresql://user:pass@db:5432/mlflow \\
--default-artifact-root s3://mlflow-artifacts/
# Or use MLflow with local file store
mlflow ui --port 5000
"""
def show_components(self):
print("=== MLflow Components ===\n")
for key, comp in self.COMPONENTS.items():
print(f"[{comp['name']}]")
print(f" {comp['description']}")
print()
def show_setup(self):
print("=== Quick Setup ===")
print(self.SETUP[:400])
arch = MLflowArchitecture()
arch.show_components()
arch.show_setup()
Experiment Tracking
# tracking.py — MLflow experiment tracking
import mlflow
import mlflow.sklearn
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, precision_score
import numpy as np
# Set tracking URI
mlflow.set_tracking_uri("http://localhost:5000")
mlflow.set_experiment("customer-churn-prediction")
# Generate sample data
np.random.seed(42)
X = np.random.randn(1000, 10)
y = (X[:, 0] + X[:, 1] * 0.5 + np.random.randn(1000) * 0.3 > 0).astype(int)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
# Hyperparameter search with MLflow tracking
params_list = [
{"n_estimators": 100, "max_depth": 5, "min_samples_split": 2},
{"n_estimators": 200, "max_depth": 10, "min_samples_split": 5},
{"n_estimators": 300, "max_depth": 15, "min_samples_split": 10},
{"n_estimators": 150, "max_depth": 8, "min_samples_split": 3},
]
best_run = None
best_f1 = 0
for params in params_list:
with mlflow.start_run(run_name=f"rf_d{params['max_depth']}_n{params['n_estimators']}"):
# Log parameters
mlflow.log_params(params)
mlflow.log_param("model_type", "RandomForest")
# Train model
model = RandomForestClassifier(**params, random_state=42)
model.fit(X_train, y_train)
# Evaluate
y_pred = model.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
f1 = f1_score(y_test, y_pred)
precision = precision_score(y_test, y_pred)
# Log metrics
mlflow.log_metrics({
"accuracy": accuracy,
"f1_score": f1,
"precision": precision,
})
# Log model
mlflow.sklearn.log_model(model, "model")
# Track best
if f1 > best_f1:
best_f1 = f1
best_run = mlflow.active_run().info.run_id
print(f" Params: {params} → F1: {f1:.4f}")
print(f"\nBest run: {best_run} (F1: {best_f1:.4f})")
Feature Flag Integration
# feature_flags.py — Feature flags for ML model serving
import json
import random
class MLFeatureFlags:
CODE = """
# ml_feature_flags.py — Feature flag controlled model serving
import mlflow
import json
class MLFeatureFlagManager:
def __init__(self, flag_store_url="http://flagsmith:8000"):
self.flags = {}
self.models = {}
def register_model_flag(self, flag_name, model_configs):
'''Register a feature flag for model selection'''
self.flags[flag_name] = {
"configs": model_configs,
"active": model_configs[0]["name"], # default to first
}
def set_active_model(self, flag_name, model_name):
'''Switch active model via feature flag'''
if flag_name in self.flags:
self.flags[flag_name]["active"] = model_name
def get_model(self, flag_name):
'''Get currently active model based on feature flag'''
if flag_name not in self.flags:
raise ValueError(f"Unknown flag: {flag_name}")
active = self.flags[flag_name]["active"]
for config in self.flags[flag_name]["configs"]:
if config["name"] == active:
if config["name"] not in self.models:
self.models[config["name"]] = mlflow.pyfunc.load_model(config["model_uri"])
return self.models[config["name"]]
raise ValueError(f"Model not found: {active}")
def predict_with_flag(self, flag_name, input_data):
'''Make prediction using flagged model'''
model = self.get_model(flag_name)
return model.predict(input_data)
def canary_predict(self, flag_name, input_data, canary_pct=0.1):
'''Canary deployment — send % traffic to new model'''
import random
configs = self.flags[flag_name]["configs"]
if random.random() < canary_pct and len(configs) > 1:
# Canary: use new model
model_uri = configs[1]["model_uri"]
model_name = configs[1]["name"]
else:
# Stable: use current model
model_uri = configs[0]["model_uri"]
model_name = configs[0]["name"]
if model_name not in self.models:
self.models[model_name] = mlflow.pyfunc.load_model(model_uri)
prediction = self.models[model_name].predict(input_data)
return {"prediction": prediction, "model": model_name}
# Usage
manager = MLFeatureFlagManager()
# Register models with feature flag
manager.register_model_flag("churn-model", [
{"name": "v1-stable", "model_uri": "models:/churn-model/Production"},
{"name": "v2-canary", "model_uri": "models:/churn-model/Staging"},
])
# Normal prediction (uses the stable model); input_data is your feature matrix (illustrative)
result = manager.predict_with_flag("churn-model", input_data)
# Canary (10% traffic to new model)
result = manager.canary_predict("churn-model", input_data, canary_pct=0.10)
# Instant rollback — switch back to v1
manager.set_active_model("churn-model", "v1-stable")
"""
def show_code(self):
print("=== ML Feature Flags ===")
print(self.CODE[:600])
def deployment_dashboard(self):
print(f"\n=== Model Deployment Dashboard ===")
models = [
{"name": "churn-model", "version": "v2.1", "stage": "Production", "traffic": "90%", "canary": "v2.2 (10%)"},
{"name": "recommendation", "version": "v3.0", "stage": "Production", "traffic": "100%", "canary": "None"},
{"name": "fraud-detection", "version": "v1.5", "stage": "Production", "traffic": "95%", "canary": "v1.6 (5%)"},
]
for m in models:
print(f" [{m['stage']:>10}] {m['name']:<20} {m['version']} | Traffic: {m['traffic']} | Canary: {m['canary']}")
ff = MLFeatureFlags()
ff.show_code()
ff.deployment_dashboard()
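One design note on canary_predict above: random.random() draws a fresh number per request, so the same user can bounce between model versions. For A/B tests it is usually better to bucket users deterministically. A minimal sketch, where the user_id hashing scheme is an assumption rather than part of MLflow or any flag provider:
# sticky_bucketing.py - deterministic user bucketing for gradual rollout / A/B tests
import hashlib

def bucket_user(user_id: str, canary_pct: float = 0.10) -> str:
    """Map a user to 'canary' or 'stable' consistently across requests."""
    # Hash the user id into [0, 1); the same user always lands in the same bucket
    digest = hashlib.sha256(user_id.encode()).hexdigest()
    fraction = int(digest[:8], 16) / 0xFFFFFFFF
    return "canary" if fraction < canary_pct else "stable"

# The assignment is sticky: repeated calls for the same user return the same bucket
print(bucket_user("user-1234"))
print(bucket_user("user-5678"))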
Model Registry & Promotion
# registry.py — Model Registry workflow
import json
class ModelRegistry:
WORKFLOW = """
# model_promotion.py — Model promotion workflow
import mlflow
from mlflow.tracking import MlflowClient
client = MlflowClient()
# 1. Register best model from experiment
best_run_id = "abc123..."
model_uri = f"runs:/{best_run_id}/model"
# Register model (creates version 1)
result = mlflow.register_model(model_uri, "churn-prediction")
print(f"Registered: {result.name} v{result.version}")
# 2. Promote to Staging (for testing)
client.transition_model_version_stage(
name="churn-prediction",
version=result.version,
stage="Staging",
)
print(f"Promoted to Staging: v{result.version}")
# 3. Run validation tests
def validate_model(model_name, version):
model = mlflow.pyfunc.load_model(f"models:/{model_name}/{version}")
# Run test predictions
test_accuracy = 0.92 # from test suite
test_latency = 15 # ms
checks = {
"accuracy >= 0.90": test_accuracy >= 0.90,
"latency <= 50ms": test_latency <= 50,
"no_errors": True,
}
all_passed = all(checks.values())
return all_passed, checks
passed, checks = validate_model("churn-prediction", result.version)
print(f"Validation: {'PASS' if passed else 'FAIL'}")
for check, status in checks.items():
print(f" [{('OK' if status else 'FAIL'):>4}] {check}")
# 4. Promote to Production (if validation passes)
if passed:
# Archive current production model
prod_versions = client.get_latest_versions("churn-prediction", stages=["Production"])
for pv in prod_versions:
client.transition_model_version_stage(
name="churn-prediction",
version=pv.version,
stage="Archived",
)
# Promote new version
client.transition_model_version_stage(
name="churn-prediction",
version=result.version,
stage="Production",
)
print(f"Promoted to Production: v{result.version}")
"""
def show_workflow(self):
print("=== Model Promotion Workflow ===")
print(self.WORKFLOW[:600])
def registry_status(self):
print(f"\n=== Model Registry Status ===")
models = [
{"name": "churn-prediction", "prod": "v2.1", "staging": "v2.2", "versions": 8},
{"name": "recommendation-engine", "prod": "v3.0", "staging": "-", "versions": 12},
{"name": "fraud-detector", "prod": "v1.5", "staging": "v1.6", "versions": 6},
{"name": "price-predictor", "prod": "v4.2", "staging": "v4.3", "versions": 15},
]
for m in models:
print(f" {m['name']:<30} Prod: {m['prod']:<5} | Staging: {m['staging']:<5} | Versions: {m['versions']}")
reg = ModelRegistry()
reg.show_workflow()
reg.registry_status()
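Once a version reaches Production, downstream services can load it through the stage URI instead of pinning a specific run or version. A short sketch, assuming the registry and tracking server from the workflow above; the point is that promotions and rollbacks then take effect without changing serving code:
# load_production.py - load whichever version is currently in Production
import mlflow
import mlflow.pyfunc

mlflow.set_tracking_uri("http://localhost:5000")

# "models:/<name>/<stage>" resolves to the current Production version at load time
model = mlflow.pyfunc.load_model("models:/churn-prediction/Production")
# predictions = model.predict(X_live)  # X_live: your live feature matrix (illustrative)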
Monitoring & Rollback
# monitoring.py — Model monitoring and rollback
import json
import random
class ModelMonitoring:
METRICS = {
"accuracy": "Model accuracy on live data (compare with baseline)",
"latency_p50": "Inference latency P50 (target: < 20ms)",
"latency_p99": "Inference latency P99 (target: < 100ms)",
"data_drift": "Feature distribution change (KS test, PSI)",
"prediction_drift": "Output distribution change",
"error_rate": "Prediction error rate (misclassification)",
}
ROLLBACK = """
# rollback.py — Instant model rollback
import mlflow
from mlflow.tracking import MlflowClient
client = MlflowClient()
def rollback_model(model_name, target_version=None):
'''Rollback to previous production version'''
if target_version:
# Rollback to specific version
version = target_version
else:
# Find last archived version (previous production)
archived = client.get_latest_versions(model_name, stages=["Archived"])
if not archived:
raise ValueError("No archived version to rollback to")
version = archived[-1].version
# Demote current production
prod = client.get_latest_versions(model_name, stages=["Production"])
for p in prod:
client.transition_model_version_stage(
name=model_name, version=p.version, stage="Archived"
)
# Promote rollback version
client.transition_model_version_stage(
name=model_name, version=version, stage="Production"
)
print(f"Rolled back {model_name} to v{version}")
return version
# Usage: rollback_model("churn-prediction")
"""
def show_metrics(self):
print("=== Monitoring Metrics ===\n")
for name, desc in self.METRICS.items():
print(f" [{name}] {desc}")
def show_rollback(self):
print(f"\n=== Rollback Script ===")
print(self.ROLLBACK[:500])
def dashboard(self):
print(f"\n=== Live Model Dashboard ===")
models = [
{"name": "churn-prediction", "accuracy": random.uniform(0.88, 0.95), "p50": random.uniform(8, 20), "drift": random.uniform(0, 0.15)},
{"name": "fraud-detector", "accuracy": random.uniform(0.92, 0.99), "p50": random.uniform(5, 15), "drift": random.uniform(0, 0.10)},
]
for m in models:
drift_status = "OK" if m["drift"] < 0.1 else "DRIFT"
print(f" [{drift_status:>5}] {m['name']:<25} Acc: {m['accuracy']:.3f} | P50: {m['p50']:.1f}ms | Drift: {m['drift']:.3f}")
mon = ModelMonitoring()
mon.show_metrics()
mon.dashboard()
FAQ - Frequently Asked Questions
Q: How do MLflow and Weights & Biases (W&B) differ?
A: MLflow: open source, can be self-hosted, built-in model registry, free. W&B: SaaS, more polished UI, strong collaboration features, hyperparameter sweeps. Use MLflow when you need self-hosting, model serving, full lifecycle coverage, or have a limited budget. Use W&B when the focus is experiment tracking and team collaboration and budget allows. They can also be combined: W&B to track experiments, MLflow to manage models.
Q: Are feature flags necessary for ML?
A: Very much so for production ML, because: 1) Canary deployment: test a new model on real traffic 2) Instant rollback: if the new model underperforms, roll back immediately 3) A/B testing: compare models on real users 4) Gradual rollout: shift traffic to the new model step by step. Tools: LaunchDarkly, Flagsmith, Unleash, or a custom solution.
Q: How is the Model Registry used?
A: Workflow: Train → Register → Staging → Validate → Production. Staging: test against held-out data plus integration tests. Production: serve real users. Archived: old versions (kept for rollback). Every transition should pass validation; never promote without testing.
Q: How do you monitor model drift?
A: Data drift: compare feature distributions (KS test, PSI). Prediction drift: compare output distributions. Accuracy drift: compare predictions with actual outcomes (usually delayed). Tools: Evidently AI, NannyML, WhyLabs, or custom Grafana dashboards. Alerting: if drift exceeds a threshold, investigate and retrain if needed.
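A minimal sketch of the two distribution checks mentioned above, using scipy's two-sample KS test plus a simple PSI computation; the 0.1/0.2 PSI thresholds are common rules of thumb, not universal constants:
# drift_check.py - data drift check with a KS test and PSI (sketch)
import numpy as np
from scipy.stats import ks_2samp

def psi(reference, current, bins=10):
    """Population Stability Index between reference and current feature samples."""
    edges = np.histogram_bin_edges(reference, bins=bins)
    ref_pct = np.histogram(reference, bins=edges)[0] / len(reference)
    cur_pct = np.histogram(current, bins=edges)[0] / len(current)
    # Clip to avoid log(0) / division by zero in empty bins
    ref_pct = np.clip(ref_pct, 1e-6, None)
    cur_pct = np.clip(cur_pct, 1e-6, None)
    return float(np.sum((cur_pct - ref_pct) * np.log(cur_pct / ref_pct)))

# Example: training-time feature vs live feature with a small mean shift
rng = np.random.default_rng(0)
train_feature = rng.normal(0.0, 1.0, 5000)
live_feature = rng.normal(0.3, 1.0, 5000)

ks_stat, p_value = ks_2samp(train_feature, live_feature)
score = psi(train_feature, live_feature)
print(f"KS stat: {ks_stat:.3f} (p={p_value:.4f}) | PSI: {score:.3f}")
# Rule of thumb: PSI < 0.1 stable, 0.1-0.2 moderate shift, > 0.2 significant drift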
