diff --git a/paddle/fluid/framework/new_exec.h b/paddle/fluid/framework/new_exec.h
index 755c88a53f39c0c495c19f86e4753e245b778cf4..39be163eb1eb494dc00e044b8d70c1a6cb89684a 100644
--- a/paddle/fluid/framework/new_exec.h
+++ b/paddle/fluid/framework/new_exec.h
@@ -17,12 +17,13 @@
 #include "paddle/fluid/framework/variable.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
-
+#include "paddle/fluid/framework/variable_helper.h"
 #include "paddle/fluid/platform/init.h"
 
 #include
 #include
+
 //USE_OP(fill_constant);
 //USE_OP(elementwise_add);
@@ -524,7 +525,8 @@ void build_variable_scope( const framework::ProgramDesc& pdesc, VariableScope* v
     }
     auto v = new Variable();
-    v->GetMutable<LoDTensor>();
+    //v->GetMutable<LoDTensor>();
+    InitializeVariable(v, var->GetType());
     var_scope->var_list.push_back(std::unique_ptr<Variable>(v));
   }
 }
@@ -537,9 +539,9 @@ void build_op_func_list( const framework::ProgramDesc& pdesc, std::vector
-    cout << op->Type() << endl;
+    cerr << op->Type() << endl;
     //bool debug = op->Type() == "softmax_with_cross_entropy_grad";
-    bool debug = false;
+    bool debug = true;
     //cerr << "create op" << endl;
     //auto op_base_u = OpRegistry::CreateOp(*op);
@@ -634,16 +636,25 @@ void build_op_func_list( const framework::ProgramDesc& pdesc, std::vector
         auto tensor_in = static_cast<const Tensor*>(&(var->Get<LoDTensor>()));
+        cerr << "i " << i << "\t" << tensor_in->IsInitialized() << endl;
         auto kernel_type_for_var = static_cast<const framework::OperatorWithKernel*>(op_base)->GetKernelTypeForVar( var_name_item.first, *tensor_in, expected_kernel_key);
-
+        if( debug)
+        {
+          cerr << "var name " << var_name_item.first << endl;
+          cerr << expected_kernel_key.place_ << "\t" << kernel_type_for_var.place_ << endl;
+        }
         if ( !platform::is_same_place(kernel_type_for_var.place_, expected_kernel_key.place_) )
         {
@@ -658,7 +669,8 @@ void build_op_func_list( const framework::ProgramDesc& pdesc, std::vector
           var_scope->var_list.push_back(std::unique_ptr<Variable>(v));
 
           VariableNameMap copy_in_map;
-          copy_in_map["X"] = input_names[var_name_item.first];
+          cerr << "ints name is " << input_names[var_name_item.first][i] << endl;
+          copy_in_map["X"] = { input_names[var_name_item.first][i] };
           VariableNameMap copy_out_map;
           copy_out_map["Out"] = { new_var_name };
           AttributeMap attr_map;
@@ -669,25 +681,32 @@ void build_op_func_list( const framework::ProgramDesc& pdesc, std::vector
           std::map<std::string, std::vector<int> > copy_out_name2id;
           copy_out_name2id["Out"] = { var_scope->name2id[new_var_name]};
-          vec_ids[i] = var_scope->name2id[new_var_name];
+          //vec_ids[i] = var_scope->name2id[new_var_name];
+          // update out runtime_context
+          op_func_node.input_index[ var_name_item.first ][i] = var_scope->name2id[new_var_name];
 
-          VariableValueMap copy_ins_value_map;
-          copy_ins_value_map["X"] = ins_map[ var_name_item.first ];
+          VariableValueMap copy_ins_value_map;
+          copy_ins_value_map["X"] = { var };
           VariableValueMap copy_outs_value_map;
           copy_outs_value_map["Out"] = { v };
-
-          auto copy_op = info.Creator()( "memcpy", copy_in_map, copy_out_map, attr_map);
+
+
+          auto& copy_info = OpInfoMap::Instance().Get( "memcpy" );
+          auto copy_op = copy_info.Creator()( "memcpy", copy_in_map, copy_out_map, attr_map);
+          if(debug) cerr << "create memcpy" << endl;
 
           OpFuncNode copy_op_func_node;
           copy_op_func_node.input_index = copy_ins_name2id;
           copy_op_func_node.output_index = copy_out_name2id;
-          RuntimeContext runtime_context( {}, {});
-          runtime_context.inputs.swap( copy_ins_value_map );
-          runtime_context.outputs.swap( copy_outs_value_map );
+          RuntimeContext copy_runtime_context( {}, {});
+          copy_runtime_context.inputs.swap( copy_ins_value_map );
+          copy_runtime_context.outputs.swap( copy_outs_value_map );
           //cerr << "create runtime context" << endl;
-          RuntimeInferShapeContext infer_shape_ctx(*copy_op, runtime_context);
-          static_cast<const framework::OperatorWithKernel*>(copy_op)->InferShape( &infer_shape_ctx );
+          RuntimeInferShapeContext copy_infer_shape_ctx(*copy_op, copy_runtime_context);
+          if(debug) cerr << "before infer shape" << endl;
+          static_cast<const framework::OperatorWithKernel*>(copy_op)->InferShape( &copy_infer_shape_ctx );
+          if(debug) cerr << "infer shape" << endl;
           //cerr << "fin infer shape" << endl;
           auto& all_op_kernels = OperatorWithKernel::AllOpKernels();
           auto kernels_iter = all_op_kernels.find( "memcpy" );
@@ -704,20 +723,25 @@ void build_op_func_list( const framework::ProgramDesc& pdesc, std::vector
           auto& kernels = kernels_iter->second;
           //auto place = platform::CPUPlace();
           //auto place = platform::CUDAPlace(0);
+          platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
           auto* dev_ctx = pool.Get(place);
           Scope scope;
-          auto exec_ctx = ExecutionContext(*copy_op, scope, *dev_ctx, runtime_context );
+          auto copy_exec_ctx = ExecutionContext(*copy_op, scope, *dev_ctx, copy_runtime_context );
           if (debug ) cerr << "21" << endl;
-          auto expected_kernel_key = dynamic_cast<const framework::OperatorWithKernel*>(copy_op)->GetExpectedKernelType( exec_ctx );
+          auto expected_kernel_key = dynamic_cast<const framework::OperatorWithKernel*>(copy_op)->GetExpectedKernelType( copy_exec_ctx );
           if (debug ) cerr << "22" << endl;
           //cerr << "22" << endl;
           auto kernel_iter = kernels.find(expected_kernel_key);
           copy_op_func_node.kernel_func_ = OpKernelFunc( kernel_iter->second );
-          copy_op_func_node.kernel_func_( exec_ctx );
+          copy_op_func_node.kernel_func_( copy_exec_ctx );
+          if(debug) cerr << "run exe ctx" << endl;
 
           op_list.push_back( copy_op );
           vec_func_list.push_back( copy_op_func_node);
+
+
+          var_name_item.second[i] = v;
         }
       }
     }
@@ -833,8 +857,10 @@ public:
       paddle::framework::build_op_func_list( prog_, op_list, vec_func_list, &global_scope, place_);
       is_build = true;
     }
-
-    paddle::framework::exec_op_func_list( vec_func_list, op_list, global_scope, place_ );
+    else
+    {
+      paddle::framework::exec_op_func_list( vec_func_list, op_list, global_scope, place_ );
+    }
 
     for( size_t i = 0; i < vec_fetch_name.size(); ++i )
     {
@@ -845,8 +871,22 @@ public:
       //cerr << "out " << fetch_tensor->data<float>()[0] << endl;
 
+      if ( platform::is_gpu_place(fetch_tensor->place() ) )
+      {
+        cerr << "fetch gpu" << endl;
+        Tensor out;
+        platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
+        auto* dev_ctx = pool.Get(place_);
+        dev_ctx->Wait();
+        TensorCopySync(*fetch_tensor, platform::CPUPlace(), &out);
+        dev_ctx->Wait();
+        cerr << "out " << out << endl;
+      }
+      else
+      {
-      cerr << "out " << *fetch_tensor << endl;
+        cerr << "out " << *fetch_tensor << endl;
+      }
     }
   }
 private: