提交 c39c6e22 编写于 作者: S SunGaofeng

fix name mismatch with config txt

上级 b73e4832
...@@ -37,7 +37,7 @@ class AttentionLSTM(ModelBase): ...@@ -37,7 +37,7 @@ class AttentionLSTM(ModelBase):
# get mode configs # get mode configs
self.batch_size = self.get_config_from_sec(self.mode, 'batch_size', 1) self.batch_size = self.get_config_from_sec(self.mode, 'batch_size', 1)
self.gpu_num = self.get_config_from_sec(self.mode, 'gpu_num', 1) self.num_gpus = self.get_config_from_sec(self.mode, 'num_gpus', 1)
if self.mode == 'train': if self.mode == 'train':
self.learning_rate = self.get_config_from_sec('train', self.learning_rate = self.get_config_from_sec('train',
...@@ -134,7 +134,7 @@ class AttentionLSTM(ModelBase): ...@@ -134,7 +134,7 @@ class AttentionLSTM(ModelBase):
cost = fluid.layers.reduce_sum(cost, dim=-1) cost = fluid.layers.reduce_sum(cost, dim=-1)
sum_cost = fluid.layers.reduce_sum(cost) sum_cost = fluid.layers.reduce_sum(cost)
self.loss_ = fluid.layers.scale( self.loss_ = fluid.layers.scale(
sum_cost, scale=self.gpu_num, bias_after_scale=False) sum_cost, scale=self.num_gpus, bias_after_scale=False)
return self.loss_ return self.loss_
def outputs(self): def outputs(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册