提交 1b2ca6e6 编写于 作者: T tink2123

polish code

上级 c9e1077d
...@@ -326,6 +326,4 @@ class STN_ON(nn.Layer): ...@@ -326,6 +326,4 @@ class STN_ON(nn.Layer):
image, self.tps_inputsize, mode="bilinear", align_corners=True) image, self.tps_inputsize, mode="bilinear", align_corners=True)
stn_img_feat, ctrl_points = self.stn_head(stn_input) stn_img_feat, ctrl_points = self.stn_head(stn_input)
x, _ = self.tps(image, ctrl_points) x, _ = self.tps(image, ctrl_points)
#print("x:", np.sum(x.numpy()))
# print(x.shape)
return x return x
...@@ -215,9 +215,6 @@ def train(config, ...@@ -215,9 +215,6 @@ def train(config,
preds = model(images, data=batch[1:]) preds = model(images, data=batch[1:])
else: else:
preds = model(images) preds = model(images)
state_dict = model.state_dict()
# for key in state_dict:
# print(key)
loss = loss_class(preds, batch) loss = loss_class(preds, batch)
avg_loss = loss['loss'] avg_loss = loss['loss']
avg_loss.backward() avg_loss.backward()
...@@ -414,7 +411,6 @@ def preprocess(is_train=False): ...@@ -414,7 +411,6 @@ def preprocess(is_train=False):
yaml.dump( yaml.dump(
dict(config), f, default_flow_style=False, sort_keys=False) dict(config), f, default_flow_style=False, sort_keys=False)
log_file = '{}/train.log'.format(save_model_dir) log_file = '{}/train.log'.format(save_model_dir)
print("log has save in {}/train.log".format(save_model_dir))
else: else:
log_file = None log_file = None
logger = get_logger(name='root', log_file=log_file) logger = get_logger(name='root', log_file=log_file)
......
...@@ -72,8 +72,6 @@ def main(config, device, logger, vdl_writer): ...@@ -72,8 +72,6 @@ def main(config, device, logger, vdl_writer):
# for rec algorithm # for rec algorithm
if hasattr(post_process_class, 'character'): if hasattr(post_process_class, 'character'):
char_num = len(getattr(post_process_class, 'character')) char_num = len(getattr(post_process_class, 'character'))
character = getattr(post_process_class, 'character')
print("getattr character:", character)
if config['Architecture']["algorithm"] in ["Distillation", if config['Architecture']["algorithm"] in ["Distillation",
]: # distillation model ]: # distillation model
for key in config['Architecture']["Models"]: for key in config['Architecture']["Models"]:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册