提交 37f90f5a 编写于 作者: Y Yu Yang

Use var

上级 0fd337bf
......@@ -202,7 +202,7 @@ def main():
tensor = position_encoding_init(ModelHyperParams.max_length + 1, ModelHyperParams.d_model)
for place_id in xrange(dev_count):
local_scope = train_exe.executor.local_scope(place_id)
local_scope.find_var(pos_enc_param_name).get_tensor().set(tensor, fluid.CUDAPlace(place_id))
local_scope.var(pos_enc_param_name).get_tensor().set(tensor, fluid.CUDAPlace(place_id))
train_data = read_multiple(reader=train_data, count=dev_count)
for pass_id in xrange(TrainTaskConfig.pass_num):
......@@ -221,12 +221,12 @@ def main():
fluid.CUDAPlace(place_id))
for var_name in data_input_dict:
local_scope.find_var(var_name).get_tensor().set(data_input_dict[var_name],
local_scope.var(var_name).get_tensor().set(data_input_dict[var_name],
fluid.CUDAPlace(place_id))
for var_name in util_input_dict:
print var_name, local_scope.find_var(var_name)
local_scope.find_var(var_name).get_tensor().set(util_input_dict[var_name],
local_scope.var(var_name).get_tensor().set(util_input_dict[var_name],
fluid.CUDAPlace(place_id))
outs = train_exe.run(fetch_list=[sum_cost.name, token_num.name])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册