| | import argparse |
| | import collections |
| | import datetime |
| | import json |
| | import os |
| |
|
| | import _jsonnet |
| | import attr |
| | import torch |
| |
|
| | |
| | from seq2struct import ast_util |
| | |
| | from seq2struct import datasets |
| | |
| | from seq2struct import models |
| | |
| | from seq2struct import optimizers |
| |
|
| | from seq2struct.utils import registry |
| | from seq2struct.utils import random_state |
| | from seq2struct.utils import saver as saver_mod |
| |
|
| | |
| | from seq2struct.utils import vocab |
| |
|
| |
|
@attr.s
class TrainConfig:
    """Hyperparameters for the training loop (built from config['train'])."""
    # How often (in steps) to evaluate / log progress / write a checkpoint.
    eval_every_n = attr.ib(default=100)
    report_every_n = attr.ib(default=100)
    save_every_n = attr.ib(default=100)
    # Passed through to the checkpoint saver's keep_every_n (checkpoint retention).
    keep_every_n = attr.ib(default=1000)

    batch_size = attr.ib(default=32)
    eval_batch_size = attr.ib(default=32)
    # Training stops once this step count is reached.
    max_steps = attr.ib(default=100000)
    # Cap on items scored per evaluation pass (None = evaluate the full loader).
    num_eval_items = attr.ib(default=None)
    eval_on_train = attr.ib(default=True)
    eval_on_val = attr.ib(default=True)

    # Seed for the data-order RNG stream (shuffling).
    data_seed = attr.ib(default=None)
    # Seed for the parameter-initialization RNG stream.
    init_seed = attr.ib(default=None)
    # Seed for model stochasticity during training (e.g. dropout).
    model_seed = attr.ib(default=None)

    # Number of batches whose gradients are accumulated per optimizer step.
    num_batch_accumulated = attr.ib(default=1)
    # Max gradient norm for clipping (None = no clipping).
    clip_grad = attr.ib(default=None)
| |
|
| |
|
class Logger:
    """Prints timestamped messages, optionally mirroring them to a file."""

    def __init__(self, log_path=None, reopen_to_flush=False):
        """Open `log_path` for appending when given; else log to stdout only.

        `reopen_to_flush` closes and reopens the file after every message
        instead of calling flush() (for filesystems where flush is unreliable).
        """
        self.reopen_to_flush = reopen_to_flush
        self.log_file = None
        if log_path is None:
            return
        os.makedirs(os.path.dirname(log_path), exist_ok=True)
        self.log_file = open(log_path, 'a+')

    def log(self, msg):
        """Write `msg` to stdout (and the log file, if any) with a timestamp."""
        timestamp = datetime.datetime.now().replace(microsecond=0).isoformat()
        formatted = '[{}] {}'.format(timestamp, msg)
        print(formatted)
        if not self.log_file:
            return
        self.log_file.write(formatted + '\n')
        if not self.reopen_to_flush:
            self.log_file.flush()
        else:
            # Close and reopen so the OS definitely sees the write.
            path = self.log_file.name
            self.log_file.close()
            self.log_file = open(path, 'a+')
| |
|
class Trainer:
    """Builds a model from `config` and runs the train/eval/checkpoint loop."""

    def __init__(self, logger, config):
        """Instantiate the preprocessor and model.

        Args:
            logger: Logger used for all progress reporting.
            config: full experiment config dict; config['train'] populates
                TrainConfig and config['model'] selects the model class.
        """
        # Prefer GPU when available; model and checkpoints follow this device.
        if torch.cuda.is_available():
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')

        self.logger = logger
        self.train_config = registry.instantiate(TrainConfig, config['train'])
        # Independent RNG streams: data order, parameter init, and model
        # stochasticity are seeded separately so each is reproducible alone.
        self.data_random = random_state.RandomContext(self.train_config.data_seed)
        self.model_random = random_state.RandomContext(self.train_config.model_seed)

        self.init_random = random_state.RandomContext(self.train_config.init_seed)
        with self.init_random:
            # 0. Construct the preprocessor and load its vocab/artifacts.
            self.model_preproc = registry.instantiate(
                registry.lookup('model', config['model']).Preproc,
                config['model'],
                unused_keys=('name',))
            self.model_preproc.load()

            # 1. Construct the model inside the init RNG context so weight
            # initialization is reproducible given init_seed.
            self.model = registry.construct('model', config['model'],
                unused_keys=('encoder_preproc', 'decoder_preproc'),
                preproc=self.model_preproc, device=self.device)
            self.model.to(self.device)

    def train(self, config, modeldir):
        """Run the training loop, checkpointing into `modeldir`."""
        with self.init_random:
            # 2. Construct the optimizer and LR scheduler. For bertAdamw, the
            # BERT encoder gets its own parameter group so it can use a
            # different learning rate / schedule than the rest of the model.
            if config["optimizer"].get("name", None) == 'bertAdamw':
                bert_params = list(self.model.encoder.bert_model.parameters())
                assert len(bert_params) > 0
                non_bert_params = []
                for name, _param in self.model.named_parameters():
                    if "bert" not in name:
                        non_bert_params.append(_param)
                # The two groups must partition the full parameter list.
                assert len(non_bert_params) + len(bert_params) == len(list(self.model.parameters()))

                optimizer = registry.construct('optimizer', config['optimizer'],
                    non_bert_params=non_bert_params, bert_params=bert_params)
                lr_scheduler = registry.construct('lr_scheduler',
                    config.get('lr_scheduler', {'name': 'noop'}),
                    param_groups=[optimizer.non_bert_param_group,
                                  optimizer.bert_param_group])
            else:
                optimizer = registry.construct('optimizer', config['optimizer'],
                    params=self.model.parameters())
                lr_scheduler = registry.construct('lr_scheduler',
                    config.get('lr_scheduler', {'name': 'noop'}),
                    param_groups=optimizer.param_groups)

        # 3. Restore the latest checkpoint (model + optimizer), if any.
        saver = saver_mod.Saver(
            {"model": self.model, "optimizer": optimizer},
            keep_every_n=self.train_config.keep_every_n)
        last_step = saver.restore(modeldir, map_location=self.device)

        # Optionally warm-start from a pretrained checkpoint, but only when
        # this run has not yet produced checkpoints of its own.
        if "pretrain" in config and last_step == 0:
            pretrain_config = config["pretrain"]
            _path = pretrain_config["pretrained_path"]
            _step = pretrain_config["checkpoint_step"]
            pretrain_step = saver.restore(
                _path, step=_step, map_location=self.device, item_keys=["model"])
            # Re-save under modeldir so subsequent restores resume from here.
            saver.save(modeldir, pretrain_step)
            last_step = pretrain_step

        # 4. Build data loaders under the data RNG so shuffle order is
        # reproducible given data_seed. collate_fn is the identity: batches
        # are plain lists of preprocessed items.
        with self.data_random:
            train_data = self.model_preproc.dataset('train')
            train_data_loader = self._yield_batches_from_epochs(
                torch.utils.data.DataLoader(
                    train_data,
                    batch_size=self.train_config.batch_size,
                    shuffle=True,
                    drop_last=True,
                    collate_fn=lambda x: x))
            train_eval_data_loader = torch.utils.data.DataLoader(
                train_data,
                batch_size=self.train_config.eval_batch_size,
                collate_fn=lambda x: x)

            val_data = self.model_preproc.dataset('val')
            val_data_loader = torch.utils.data.DataLoader(
                val_data,
                batch_size=self.train_config.eval_batch_size,
                collate_fn=lambda x: x)

        # 5. Main training loop.
        with self.data_random:
            for batch in train_data_loader:
                if last_step >= self.train_config.max_steps:
                    break

                # Periodic evaluation (also fires at step 0).
                if last_step % self.train_config.eval_every_n == 0:
                    if self.train_config.eval_on_train:
                        self._eval_model(
                            self.logger, self.model, last_step,
                            train_eval_data_loader, 'train',
                            num_eval_items=self.train_config.num_eval_items)
                    if self.train_config.eval_on_val:
                        self._eval_model(
                            self.logger, self.model, last_step,
                            val_data_loader, 'val',
                            num_eval_items=self.train_config.num_eval_items)

                # One optimizer step, with gradients accumulated over
                # num_batch_accumulated batches.
                with self.model_random:
                    for _i in range(self.train_config.num_batch_accumulated):
                        if _i > 0:
                            batch = next(train_data_loader)
                        loss = self.model.compute_loss(batch)
                        norm_loss = loss / self.train_config.num_batch_accumulated
                        norm_loss.backward()

                    if self.train_config.clip_grad:
                        # BUGFIX: the original unconditionally clipped
                        # optimizer.bert_param_group, which only exists on the
                        # bertAdamw path; with any other optimizer this raised
                        # AttributeError. Fall back to clipping all model
                        # parameters when no bert_param_group is present.
                        if hasattr(optimizer, 'bert_param_group'):
                            torch.nn.utils.clip_grad_norm_(
                                optimizer.bert_param_group["params"],
                                self.train_config.clip_grad)
                        else:
                            torch.nn.utils.clip_grad_norm_(
                                self.model.parameters(),
                                self.train_config.clip_grad)
                    optimizer.step()
                    lr_scheduler.update_lr(last_step)
                    optimizer.zero_grad()

                # Report the (unnormalized) loss of the last accumulated batch.
                if last_step % self.train_config.report_every_n == 0:
                    self.logger.log('Step {}: loss={:.4f}'.format(last_step, loss.item()))

                last_step += 1
                if last_step % self.train_config.save_every_n == 0:
                    saver.save(modeldir, last_step)

        # 6. Save the final checkpoint.
        saver.save(modeldir, last_step)

    @staticmethod
    def _yield_batches_from_epochs(loader):
        """Yield batches from `loader` forever, restarting at each epoch end."""
        while True:
            for batch in loader:
                yield batch

    @staticmethod
    def _eval_model(logger, model, last_step, eval_data_loader, eval_section, num_eval_items=None):
        """Evaluate `model` over `eval_data_loader` and log averaged stats.

        Sums the per-batch dicts returned by model.eval_on_batch (which must
        include a 'total' item count), averages every other stat by that
        total, and logs the result. Stops early once more than
        `num_eval_items` items have been seen, when given.
        """
        stats = collections.defaultdict(float)
        model.eval()
        with torch.no_grad():
            for eval_batch in eval_data_loader:
                batch_res = model.eval_on_batch(eval_batch)
                for k, v in batch_res.items():
                    stats[k] += v
                if num_eval_items and stats['total'] > num_eval_items:
                    break
        model.train()

        # Turn sums into per-item averages; drop the raw count.
        for k in stats:
            if k != 'total':
                stats[k] /= stats['total']
        if 'total' in stats:
            del stats['total']

        logger.log("Step {} stats, {}: {}".format(
            last_step, eval_section, ", ".join(
                "{} = {}".format(k, v) for k, v in stats.items())))
| |
|
def add_parser():
    """Build the CLI parser and return the parsed arguments.

    Required flags: --logdir (output/checkpoint dir) and --config (jsonnet
    config file). Optional --config-args passes TLA codes to jsonnet.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--logdir', required=True)
    arg_parser.add_argument('--config', required=True)
    arg_parser.add_argument('--config-args')
    return arg_parser.parse_args()
| |
|
def main(args):
    """Evaluate the jsonnet config, set up logging, and launch training."""
    if args.config_args:
        config = json.loads(_jsonnet.evaluate_file(
            args.config, tla_codes={'args': args.config_args}))
    else:
        config = json.loads(_jsonnet.evaluate_file(args.config))

    # Nest the logdir under the model name when the config declares one.
    if 'model_name' in config:
        args.logdir = os.path.join(args.logdir, config['model_name'])

    reopen_to_flush = config.get('log', {}).get('reopen_to_flush')
    logger = Logger(os.path.join(args.logdir, 'log.txt'), reopen_to_flush)

    # Snapshot the fully-evaluated config next to the logs for reproducibility.
    config_path = os.path.join(
        args.logdir,
        'config-{}.json'.format(
            datetime.datetime.now().strftime('%Y%m%dT%H%M%S%Z')))
    with open(config_path, 'w', encoding='utf8') as f:
        json.dump(config, f, sort_keys=True, indent=4, ensure_ascii=False)

    logger.log('Logging to {}'.format(args.logdir))

    trainer = Trainer(logger, config)
    trainer.train(config, modeldir=args.logdir)
| |
|
# Script entry point: parse CLI flags, then run training.
if __name__ == '__main__':
    args = add_parser()
    main(args)
| |
|