import os
import pathspec
import torch
from src.train.utils.dist_utils import is_main_process, DistSummaryWriter
[docs]def save_model(net, optimizer, epoch, save_path, distributed):
if is_main_process():
model_state_dict = net.state_dict()
state = {'model': model_state_dict, 'optimizer': optimizer.state_dict()}
# state = {'model': model_state_dict}
assert os.path.exists(save_path)
model_path = os.path.join(save_path, 'ep%03d.pth' % epoch)
torch.save(state, model_path)
[docs]def cp_projects(to_path):
if is_main_process():
with open('./.gitignore', 'r') as fp:
ign = fp.read()
ign += '\n.git'
spec = pathspec.PathSpec.from_lines(pathspec.patterns.GitWildMatchPattern, ign.splitlines())
all_files = {os.path.join(root, name) for root, dirs, files in os.walk('./') for name in files}
matches = spec.match_files(all_files)
matches = set(matches)
to_cp_files = all_files - matches
# to_cp_files = [f[2:] for f in to_cp_files]
for f in to_cp_files:
dirs = os.path.join(to_path, 'code', os.path.split(f[2:])[0])
if not os.path.exists(dirs):
os.makedirs(dirs)
os.system('cp "%s" "%s"' % (f, os.path.join(to_path, 'code', f[2:])))
[docs]def get_logger(work_dir, cfg):
logger = DistSummaryWriter(work_dir)
config_txt = os.path.join(work_dir, 'cfg.txt')
if is_main_process():
with open(config_txt, 'w') as fp:
fp.write(str(cfg))
return logger