"""Common training utilities for src.train (checkpointing, code snapshot, logging)."""

import os
import shutil

import pathspec
import torch

from src.train.utils.dist_utils import is_main_process, DistSummaryWriter


def save_model(net, optimizer, epoch, save_path, distributed):
    """Save a training checkpoint from the main process only.

    Writes ``{'model': ..., 'optimizer': ...}`` to
    ``<save_path>/ep%03d.pth`` via :func:`torch.save`. Non-main ranks
    return immediately.

    Args:
        net: model whose ``state_dict()`` is checkpointed.
        optimizer: optimizer whose ``state_dict()`` is stored alongside.
        epoch: epoch number, zero-padded into the file name.
        save_path: directory to write into (created if missing).
        distributed: unused; kept for backward compatibility with callers.
    """
    if not is_main_process():
        return
    state = {
        'model': net.state_dict(),
        'optimizer': optimizer.state_dict(),
    }
    # Create the directory instead of asserting its existence: `assert`
    # is stripped under `python -O`, silently skipping the check.
    os.makedirs(save_path, exist_ok=True)
    model_path = os.path.join(save_path, 'ep%03d.pth' % epoch)
    torch.save(state, model_path)
def cp_projects(to_path):
    """Snapshot the current project tree into ``<to_path>/code``.

    Walks ``./`` and copies every file not matched by ``.gitignore``
    (with ``.git`` appended as an extra ignore pattern). Runs on the
    main process only.

    Args:
        to_path: destination root; files land under ``<to_path>/code/``.
    """
    if not is_main_process():
        return
    with open('./.gitignore', 'r') as fp:
        ign = fp.read()
    ign += '\n.git'
    spec = pathspec.PathSpec.from_lines(
        pathspec.patterns.GitWildMatchPattern, ign.splitlines())
    all_files = {
        os.path.join(root, name)
        for root, _dirs, files in os.walk('./')
        for name in files
    }
    to_cp_files = all_files - set(spec.match_files(all_files))
    for src in to_cp_files:
        rel = src[2:]  # strip the leading './' from os.walk paths
        dst_dir = os.path.join(to_path, 'code', os.path.split(rel)[0])
        # exist_ok avoids the check-then-create race of the original
        # `if not os.path.exists(...)` guard.
        os.makedirs(dst_dir, exist_ok=True)
        # shutil.copy is portable and immune to shell-quoting issues,
        # unlike the previous os.system('cp "%s" "%s"') call, which
        # broke on filenames containing double quotes.
        shutil.copy(src, os.path.join(to_path, 'code', rel))
def get_logger(work_dir, cfg):
    """Build a DistSummaryWriter for ``work_dir`` and dump the config.

    On the main process, ``str(cfg)`` is written to ``<work_dir>/cfg.txt``
    for later reference; other ranks only construct the writer.

    Args:
        work_dir: directory the summary writer logs into.
        cfg: configuration object; its string form is persisted.

    Returns:
        The constructed :class:`DistSummaryWriter`.
    """
    writer = DistSummaryWriter(work_dir)
    if is_main_process():
        cfg_file = os.path.join(work_dir, 'cfg.txt')
        with open(cfg_file, 'w') as handle:
            handle.write(str(cfg))
    return writer