File size: 3,187 Bytes
bf5116f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
import os
import sys
import numpy as np
import argparse
import h5py
import pickle
import matplotlib.pyplot as plt
from utilities import create_folder
def plot(args):
    """Plot per-metric learning curves for one experiment and save them to a PDF.

    Args:
      args: argparse namespace providing ``workspace`` (root directory that
        holds ``statistics/``) and ``select`` (figure identifier). Only
        ``'1a'`` draws curves; any other value saves an empty 2x3 grid.

    Side effects:
      Creates the ``results/`` folder if needed, writes
      ``results/<select>.pdf`` and prints its path.
    """
    # Arguments & parameters
    workspace = args.workspace
    select = args.select

    max_plot_iteration = 300001
    # Only the number of points is used: it sets the x-range, and the tick
    # labels below assume one recorded entry per 5000 iterations
    # (20 entries == 100k iterations).
    iterations = np.arange(0, max_plot_iteration, 5000)

    metric_types = ['frame_ap', 'reg_onset_mae', 'reg_offset_mae',
                    'velocity_mae', 'reg_pedal_onset_mae', 'reg_pedal_offset_mae']
    # Per-subplot y-limits and legend placement, same order as metric_types.
    ylims = [(0, 1), (0, 0.5), (0, 0.5), (0, 0.3), (0, 0.3), (0, 0.3)]
    legend_locs = ['lower right'] + ['upper right'] * 5

    save_out_path = 'results/{}.pdf'.format(select)
    create_folder(os.path.dirname(save_out_path))

    def _load_metrics(filename, model_type, loss_type, augmentation,
                      max_note_shift, batch_size, data_type, metric_type):
        """Return the ``metric_type`` curve recorded for ``data_type``, or None.

        Reads ``<workspace>/statistics/<filename>/<model_type>/loss_type=.../
        augmentation=.../max_note_shift=.../batch_size=.../statistics.pkl``.
        The loaded object is indexed by ``data_type`` to get a sequence of
        dicts, one per recorded evaluation. Returns None when that sequence
        is empty or its first entry has no ``metric_type`` key.
        """
        statistics_path = os.path.join(
            workspace, 'statistics', filename, model_type,
            'loss_type={}'.format(loss_type),
            'augmentation={}'.format(augmentation),
            'max_note_shift={}'.format(max_note_shift),
            'batch_size={}'.format(batch_size),
            'statistics.pkl')
        # pickle can execute arbitrary code on load: only use trusted workspaces.
        with open(statistics_path, 'rb') as f:
            statistics_dict = pickle.load(f)
        records = statistics_dict[data_type]
        if not records or metric_type not in records[0]:
            return None
        return np.array([record[metric_type] for record in records])

    # Plot
    fig, axes = plt.subplots(2, 3, figsize=(8, 5))

    if select == '1a':
        for j, metric_type in enumerate(metric_types):
            ax = axes[j // 3, j % 3]
            lines = []
            for data_type in ['train', 'test']:
                metrics = _load_metrics(
                    'main', 'Regress_onset_offset_frame_velocity_CRNN',
                    'regress_onset_offset_frame_velocity_bce', 'none', 0, 12,
                    data_type, metric_type)
                if metrics is not None:
                    line, = ax.plot(metrics, label=data_type)
                    lines.append(line)
            ax.set_title(metric_type)
            ax.legend(handles=lines, loc=legend_locs[j])
            ax.set_ylim(*ylims[j])
            ax.set_xlim(0, len(iterations))
            ax.xaxis.set_ticks(np.arange(0, len(iterations), 20))
            ax.xaxis.set_ticklabels(['0', '100k', '200k', '300k'])
            ax.set_xlabel('Iterations')

    # tight_layout's parameters are keyword-only in current Matplotlib
    # (positional use was deprecated), so the old call (0, 1, 0) is spelled
    # out as pad/h_pad/w_pad.
    fig.tight_layout(pad=0, h_pad=1, w_pad=0)
    fig.savefig(save_out_path)
    plt.close(fig)
    print('Write out to {}'.format(save_out_path))
if __name__ == '__main__':
    # Command-line entry point: `python <script> plot --workspace W --select S`.
    parser = argparse.ArgumentParser(description='')
    subparsers = parser.add_subparsers(dest='mode')

    parser_plot = subparsers.add_parser('plot')
    parser_plot.add_argument('--workspace', type=str, required=True)
    parser_plot.add_argument('--select', type=str, required=True)

    args = parser.parse_args()

    if args.mode == 'plot':
        plot(args)
    else:
        # argparse already rejects unknown sub-commands, so this is reached
        # only when none was given (mode is None). Report it as a usage
        # error (prints usage, exits with status 2) rather than a traceback.
        parser.error('Error argument!')