未验证 提交 1460b761 编写于 作者: A Aurelius84 提交者: GitHub

Fix data transform bug in new executor (#37280)

上级 8c44ad47
...@@ -289,7 +289,6 @@ std::tuple<std::string, OpFuncNode> apply_place_transform_for_var( ...@@ -289,7 +289,6 @@ std::tuple<std::string, OpFuncNode> apply_place_transform_for_var(
const OpKernelType& expected_kernel_key, const platform::Place& place, const OpKernelType& expected_kernel_key, const platform::Place& place,
const std::string& var_name, const std::string& outer_name, const std::string& var_name, const std::string& outer_name,
const OpFuncNode& op_func_node, Variable* var, VariableScope* var_scope) { const OpFuncNode& op_func_node, Variable* var, VariableScope* var_scope) {
auto& ins_name2id = op_func_node.input_index;
auto& all_op_kernels = OperatorWithKernel::AllOpKernels(); auto& all_op_kernels = OperatorWithKernel::AllOpKernels();
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
std::string new_var_name = std::string new_var_name =
...@@ -307,7 +306,7 @@ std::tuple<std::string, OpFuncNode> apply_place_transform_for_var( ...@@ -307,7 +306,7 @@ std::tuple<std::string, OpFuncNode> apply_place_transform_for_var(
: is_gpu_place(expected_kernel_key.place_) ? 1 : -1; : is_gpu_place(expected_kernel_key.place_) ? 1 : -1;
std::map<std::string, std::vector<int>> copy_ins_name2id; std::map<std::string, std::vector<int>> copy_ins_name2id;
copy_ins_name2id["X"] = ins_name2id.at(outer_name); copy_ins_name2id["X"] = {var_scope->VarId(var_name)};
std::map<std::string, std::vector<int>> copy_out_name2id; std::map<std::string, std::vector<int>> copy_out_name2id;
copy_out_name2id["Out"] = {var_scope->VarId(new_var_name)}; copy_out_name2id["Out"] = {var_scope->VarId(new_var_name)};
......
...@@ -61,6 +61,8 @@ USE_OP(elementwise_max); ...@@ -61,6 +61,8 @@ USE_OP(elementwise_max);
USE_OP(elementwise_div); USE_OP(elementwise_div);
USE_OP(sgd); USE_OP(sgd);
USE_OP(squared_l2_norm); USE_OP(squared_l2_norm);
USE_OP(memcpy_h2d);
USE_OP(memcpy_d2h);
paddle::framework::ProgramDesc load_from_file(const std::string& file_name) { paddle::framework::ProgramDesc load_from_file(const std::string& file_name) {
std::ifstream fin(file_name, std::ios::in | std::ios::binary); std::ifstream fin(file_name, std::ios::in | std::ios::binary);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册