AnnsKhan commited on
Commit
baa5edc
·
1 Parent(s): 7ae953d
Files changed (2) hide show
  1. mlflow_config.yaml +12 -0
  2. mlflow_tracker.py +66 -0
mlflow_config.yaml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mlflow_config.yaml
2
+
3
+ experiment_name: "billion-row-analysis"
4
+ run_name: "benchmarking"
5
+ tracking_uri: "http://localhost:5000" # Optional: set if you have a remote tracking server
6
+ metrics:
7
+ - Library
8
+ - Load Time (s)
9
+ - CPU Load (%)
10
+ - Memory Load (%)
11
+ - Peak Memory (%)
12
+
mlflow_tracker.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import mlflow
2
+ import yaml
3
+ import os
4
+
5
+ class MLflowTracker:
6
+ """
7
+ A reusable MLflow tracking class that reads configuration from a YAML file.
8
+ This class sets up the MLflow experiment and run, and exposes methods to log parameters,
9
+ metrics, and artifacts.
10
+ """
11
+ def __init__(self, config_file="mlflow_config.yaml"):
12
+ # Load configuration from the YAML file.
13
+ if not os.path.exists(config_file):
14
+ raise FileNotFoundError(f"Config file '{config_file}' not found.")
15
+ with open(config_file, "r") as f:
16
+ self.config = yaml.safe_load(f)
17
+
18
+ # Set up configuration parameters
19
+ self.experiment_name = self.config.get("experiment_name", "Default_Experiment")
20
+ self.run_name = self.config.get("run_name", "Default_Run")
21
+ self.tracking_uri = self.config.get("tracking_uri", None)
22
+ self.metrics_to_track = self.config.get("metrics", [])
23
+
24
+ # Set tracking URI if provided
25
+ if self.tracking_uri:
26
+ mlflow.set_tracking_uri(self.tracking_uri)
27
+
28
+ # Set the experiment
29
+ mlflow.set_experiment(self.experiment_name)
30
+
31
+ # Start the run
32
+ self.run = mlflow.start_run(run_name=self.run_name)
33
+ print(f"MLflow run started: Experiment='{self.experiment_name}', Run='{self.run_name}'")
34
+
35
+ def log_param(self, key, value):
36
+ """Log a single parameter."""
37
+ mlflow.log_param(key, value)
38
+
39
+ def log_params(self, params: dict):
40
+ """Log multiple parameters from a dictionary."""
41
+ mlflow.log_params(params)
42
+
43
+ def log_metric(self, key, value, step=None):
44
+ """Log a single metric. Optionally include a step value."""
45
+ mlflow.log_metric(key, value, step=step)
46
+
47
+ def log_metrics(self, metrics: dict, step=None):
48
+ """Log multiple metrics from a dictionary."""
49
+ for key, value in metrics.items():
50
+ self.log_metric(key, value, step=step)
51
+
52
+ def log_artifact(self, file_path, artifact_path=None):
53
+ """Log an artifact (file) to MLflow."""
54
+ mlflow.log_artifact(file_path, artifact_path=artifact_path)
55
+
56
+ def end_run(self):
57
+ """End the current MLflow run."""
58
+ mlflow.end_run()
59
+ print("MLflow run ended.")
60
+
61
+ # Example usage (can be removed or placed in a separate test script):
62
+ if __name__ == "__main__":
63
+ tracker = MLflowTracker("mlflow_config.yaml")
64
+ tracker.log_param("example_param", 123)
65
+ tracker.log_metric("example_metric", 0.95)
66
+ tracker.end_run()