提交 68b00b25 编写于 作者: Q qiuxuezhong

need nott bcast_param

上级 5b60a6f7
...@@ -404,7 +404,6 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler, ...@@ -404,7 +404,6 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
feed_dict[sum_cost.name + "@GRAD"] = 1. / total_num_token feed_dict[sum_cost.name + "@GRAD"] = 1. / total_num_token
outs = train_exe.run(fetch_list=[sum_cost.name, token_num.name], outs = train_exe.run(fetch_list=[sum_cost.name, token_num.name],
feed=feed_list) feed=feed_list)
train_exe.bcast_params()
sum_cost_val, token_num_val = np.array(outs[0]), np.array(outs[1]) sum_cost_val, token_num_val = np.array(outs[0]), np.array(outs[1])
total_sum_cost = sum_cost_val.sum( total_sum_cost = sum_cost_val.sum(
) # sum the cost from multi-devices ) # sum the cost from multi-devices
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册