| import os |
| import glob |
| import pandas as pd |
| import matplotlib.pyplot as plt |
| import numpy as np |
| import re |
|
|
def extract_run_name(filename):
    """Extract the run name from a TensorBoard-export CSV filename.

    Filenames ending in ``_<run>_tensorboard.csv`` or
    ``_<run>-loss_tensorboard.csv`` yield ``<run>``. Any other name falls
    back to the second underscore-separated field, truncated at the first
    ``-``. A name containing no underscore at all falls back to its stem
    (basename without extension).

    Args:
        filename: Path (or bare basename) of the CSV file.

    Returns:
        The run name as a string.
    """
    basename = os.path.basename(filename)

    # Non-greedy capture so the optional "-loss" suffix is actually excluded
    # from the run name; a greedy ``[^_]+`` would swallow it into group 1.
    match = re.search(r'_([^_]+?)(?:-loss)?_tensorboard\.csv$', basename)
    if match:
        return match.group(1)

    parts = basename.split('_')
    if len(parts) < 2:
        # No underscore: indexing parts[1] would raise IndexError.
        return os.path.splitext(basename)[0]
    return parts[1].split('-')[0]
|
|
def setup_plot_style():
    """Install publication-oriented Matplotlib defaults globally.

    Mutates ``plt.rcParams`` in place:
    - serif fonts, with larger axis labels and titles;
    - 300 dpi figures at 10x6 inches and thicker lines;
    - a dashed, translucent grid;
    - no top or right axis spines.
    """
    typography = {
        'font.family': 'serif',
        'font.size': 12,
        'axes.labelsize': 14,
        'axes.titlesize': 16,
        'legend.fontsize': 10,
    }
    canvas = {
        'figure.dpi': 300,
        'figure.figsize': (10, 6),
        'lines.linewidth': 2.5,
    }
    decorations = {
        'axes.grid': True,
        'grid.linestyle': '--',
        'grid.alpha': 0.6,
        'axes.spines.top': False,
        'axes.spines.right': False,
    }
    # Applied in the same key order as a single combined update would be.
    for group in (typography, canvas, decorations):
        plt.rcParams.update(group)
|
|
def get_metric_label(metric_name):
    """Map a metric directory name to a human-readable axis label.

    Known metric names use a curated label. Any other name has its
    underscores turned into spaces and is then title-cased.
    """
    known_labels = {
        'loss_epoch': 'Loss',
        'perplexityval_epoch': 'Validation Perplexity',
        'topkacc_epoch': 'Top-K Accuracy',
        'acc_trainstep': 'Training Accuracy',
    }
    if metric_name in known_labels:
        return known_labels[metric_name]
    words = metric_name.split('_')
    return ' '.join(words).title()
|
|
def get_color_mapping(run_names, *, palette=None):
    """Create a consistent color mapping for all runs.

    Unique run names are sorted and assigned palette colors in order. The
    same set of runs therefore always gets the same colors, regardless of
    the order they were discovered in.

    Args:
        run_names: Iterable of run-name strings. Duplicates are ignored, so
            they do not consume palette slots.
        palette: Optional sequence of Matplotlib color specs. Defaults to a
            fixed list of ten distinct hex colors. Colors are reused
            cyclically when there are more runs than colors.

    Returns:
        dict mapping each unique run name to a color.

    Raises:
        ValueError: if ``palette`` is empty.
    """
    if palette is None:
        palette = (
            "#e6194b",
            "#f58231",
            "#ffe119",
            "#bfef45",
            "#3cb44b",
            "#42d4f4",
            "#4363d8",
            "#911eb4",
            "#f032e6",
            "#a9a9a9",
        )
    if not palette:
        raise ValueError("palette must contain at least one color")

    unique_names = sorted(set(run_names))
    return {name: palette[i % len(palette)] for i, name in enumerate(unique_names)}
|
|
def plot_metric(metric_dir, color_mapping, output_dir):
    """Plot all runs for one metric on shared axes and save the figure as PNG.

    Each ``*.csv`` file in ``metric_dir`` is read with pandas, and its
    ``Step`` column is plotted against its ``Value`` column. The image is
    written to ``<output_dir>/<metric_name>_plot.png``, where
    ``metric_name`` is the basename of ``metric_dir``.

    Args:
        metric_dir: Directory holding one CSV per run for a single metric.
        color_mapping: Mapping of run name -> Matplotlib color. Runs missing
            from it are drawn in gray.
        output_dir: Directory to write the PNG into (created if missing).

    Files that cannot be read or plotted are reported and skipped. If no file
    could be plotted, no image is written.
    """
    metric_name = os.path.basename(metric_dir)
    csv_files = sorted(glob.glob(os.path.join(metric_dir, '*.csv')))

    if not csv_files:
        print(f"No CSV files found in {metric_dir}")
        return

    metric_label = get_metric_label(metric_name)
    fig, ax = plt.subplots(figsize=(12, 7))
    try:
        n_plotted = 0
        for csv_file in csv_files:
            try:
                df = pd.read_csv(csv_file)
                run_name = extract_run_name(csv_file)
                color = color_mapping.get(run_name, 'gray')
                ax.plot(df['Step'], df['Value'], label=run_name, color=color, alpha=0.9)
            except Exception as e:
                # Best-effort: one malformed CSV should not abort the whole plot.
                print(f"Error processing {csv_file}: {e}")
            else:
                n_plotted += 1

        if n_plotted == 0:
            # Avoid writing an empty figure (and a "no labelled artists" legend warning).
            print(f"No plottable data for {metric_name}; skipping")
            return

        ax.set_xlabel('Step')
        ax.set_ylabel(metric_label)

        # NOTE(review): the x data is always the 'Step' column, so "vs. Epoch"
        # in the title may be misleading for *_epoch metrics — confirm intent.
        comparison = "Epoch" if "epoch" in metric_name else "Step"
        ax.set_title(f'{metric_label} vs. {comparison}', fontweight='bold')

        # Two legend columns once there are more than five plotted series.
        ax.legend(loc='best', frameon=True, fancybox=True, framealpha=0.9,
                  shadow=True, borderpad=1, ncol=2 if n_plotted > 5 else 1)

        ax.grid(True, linestyle='--', alpha=0.7)
        fig.tight_layout()

        os.makedirs(output_dir, exist_ok=True)
        output_path = os.path.join(output_dir, f'{metric_name}_plot.png')
        fig.savefig(output_path, bbox_inches='tight')
        print(f"Saved plot to {output_path}")
    finally:
        # Always release the figure, even if saving fails, so open figures
        # do not accumulate across metrics.
        plt.close(fig)
|
|
def main():
    """Generate one comparison plot per metric directory.

    Reads metric sub-directories from ``runs_jsons`` next to this script.
    Each sub-directory holds one CSV per run. One PNG per metric is written
    into ``plots`` next to this script. Runs are colored consistently across
    all metric plots.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    base_dir = os.path.join(script_dir, 'runs_jsons')
    output_dir = os.path.join(script_dir, 'plots')

    if not os.path.isdir(base_dir):
        print(f"Input directory not found: {base_dir}")
        return

    # Sorted so plots are produced (and logged) in a deterministic order.
    metric_dirs = sorted(
        d for d in glob.glob(os.path.join(base_dir, '*')) if os.path.isdir(d)
    )
    if not metric_dirs:
        print(f"No metric directories found in {base_dir}")
        return

    os.makedirs(output_dir, exist_ok=True)
    setup_plot_style()

    # Gather every run name across all metrics first, so each run keeps the
    # same color in every plot.
    all_run_names = set()
    for metric_dir in metric_dirs:
        for csv_file in glob.glob(os.path.join(metric_dir, '*.csv')):
            all_run_names.add(extract_run_name(csv_file))

    color_mapping = get_color_mapping(all_run_names)

    for metric_dir in metric_dirs:
        plot_metric(metric_dir, color_mapping, output_dir)

    print(f"All plots have been generated in {output_dir}")
|
|
if __name__ == "__main__":
    # Generate plots only when run as a script, not when imported.
    main()
|
|