Commit 2d6bda5c authored by phlrain

add memcpy; test=develop

Parent 6ee54b49
@@ -17,12 +17,13 @@
 #include "paddle/fluid/framework/variable.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/framework/variable_helper.h"
 #include "paddle/fluid/platform/init.h"
 #include <chrono>
 #include <gperftools/profiler.h>
 //USE_OP(fill_constant);
 //USE_OP(elementwise_add);
@@ -524,7 +525,8 @@ void build_variable_scope( const framework::ProgramDesc& pdesc, VariableScope* v
   }
   auto v = new Variable();
-  v->GetMutable<LoDTensor>();
+  //v->GetMutable<LoDTensor>();
+  InitializeVariable(v, var->GetType());
   var_scope->var_list.push_back(std::unique_ptr<Variable>(v));
  }
 }
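`InitializeVariable` comes from `variable_helper.h` (newly added to the includes above) and replaces the hard-coded `GetMutable<LoDTensor>()`, so variables declared as something other than `LOD_TENSOR` get the correct holder type. A minimal sketch of the dispatch that helper performs, assuming the `variable_helper.cc` implementation of this era (the real helper covers more types, e.g. `STEP_SCOPES` and `READER`):

```cpp
#include <stdexcept>
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/variable.h"

namespace fw = paddle::framework;

// Sketch only: allocate the holder matching the variable's declared type.
void InitializeVariableSketch(fw::Variable* var,
                              fw::proto::VarType::Type type) {
  if (type == fw::proto::VarType::LOD_TENSOR) {
    var->GetMutable<fw::LoDTensor>();      // what the old code always did
  } else if (type == fw::proto::VarType::SELECTED_ROWS) {
    var->GetMutable<fw::SelectedRows>();   // e.g. sparse gradients
  } else if (type == fw::proto::VarType::FETCH_LIST) {
    var->GetMutable<fw::FeedFetchList>();  // fetch targets
  } else {
    throw std::runtime_error("variable type not handled in this sketch");
  }
}
```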
@@ -537,9 +539,9 @@ void build_op_func_list( const framework::ProgramDesc& pdesc, std::vector<Operat
  for ( auto& op : global_block.AllOps() )
  {
-  //cerr << op->Type() << endl;
+  cerr << op->Type() << endl;
   //bool debug = op->Type() == "softmax_with_cross_entropy_grad";
-  bool debug = false;
+  bool debug = true;
   //cerr << "create op" << endl;
   //auto op_base_u = OpRegistry::CreateOp(*op);
@@ -634,16 +636,25 @@ void build_op_func_list( const framework::ProgramDesc& pdesc, std::vector<Operat
   //cerr << "22" << endl;
   // add transfer log
-  for( auto& var_name_item : ins_map )
+  //cerr << "in map size " << ins_map.size() << endl;
+  VariableValueMap& ins_map_temp = runtime_context.inputs;
+  cerr << "ins map size " << ins_map_temp.size() << endl;
+  for( auto& var_name_item : ins_map_temp )
   {
-   auto& vec_ids = ins_name2id[ var_name_item.first ];
+   cerr << "in name " << var_name_item.first << endl;
+   //auto& vec_ids = ins_name2id[ var_name_item.first ];
    for( size_t i = 0; i < var_name_item.second.size(); ++i )
    {
     auto var = var_name_item.second[i];
     auto tensor_in = static_cast<const Tensor*>(&(var->Get<LoDTensor>()));
+    cerr << "i " << i << "\t" << tensor_in->IsInitialized() << endl;
     auto kernel_type_for_var = static_cast<const framework::OperatorWithKernel*>(op_base)->GetKernelTypeForVar(
       var_name_item.first, *tensor_in, expected_kernel_key);
+    if( debug )
+    {
+     cerr << "var name " << var_name_item.first << endl;
+     cerr << expected_kernel_key.place_ << "\t" << kernel_type_for_var.place_ << endl;
+    }
    if ( !platform::is_same_place(kernel_type_for_var.place_,
          expected_kernel_key.place_) )
    {
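The transfer decision above is driven purely by comparing `platform::Place` values: `GetKernelTypeForVar` reports where the kernel expects this input, and `is_same_place` compares that against where the tensor actually lives. A standalone illustration of the predicate (a sketch assuming a CUDA build; the place types and `is_same_place` come from `paddle/fluid/platform/place.h`):

```cpp
#include <iostream>
#include "paddle/fluid/platform/place.h"

int main() {
  paddle::platform::CPUPlace cpu;
  paddle::platform::CUDAPlace gpu0(0);
  paddle::platform::CUDAPlace gpu1(1);
  // Different device types never compare equal, so a CPU-resident input
  // feeding a GPU kernel takes the memcpy branch in the diff above.
  std::cout << paddle::platform::is_same_place(cpu, gpu0) << "\n";   // 0
  // Same device type but different device ids are also distinct places.
  std::cout << paddle::platform::is_same_place(gpu0, gpu1) << "\n";  // 0
  std::cout << paddle::platform::is_same_place(gpu0, gpu0) << "\n";  // 1
  return 0;
}
```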
@@ -658,7 +669,8 @@ void build_op_func_list( const framework::ProgramDesc& pdesc, std::vector<Operat
     var_scope->var_list.push_back(std::unique_ptr<Variable>(v));
     VariableNameMap copy_in_map;
-    copy_in_map["X"] = input_names[var_name_item.first];
+    cerr << "input name is " << input_names[var_name_item.first][i] << endl;
+    copy_in_map["X"] = { input_names[var_name_item.first][i] };
     VariableNameMap copy_out_map;
     copy_out_map["Out"] = { new_var_name };
     AttributeMap attr_map;
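The fixed line binds only the i-th input name to the copy op's `X` slot instead of the whole name list, since only the single mismatched `var` is being transferred. Because Paddle's `VariableNameMap` is just a map from slot name to a vector of variable names, the difference is easy to see in isolation (a sketch with hypothetical names):

```cpp
#include <map>
#include <string>
#include <vector>

// Same shape as paddle::framework::VariableNameMap.
using VariableNameMap = std::map<std::string, std::vector<std::string>>;

int main() {
  // Stand-in for input_names[var_name_item.first].
  std::vector<std::string> slot_names = {"x0", "x1", "x2"};
  size_t i = 1;  // the input whose place mismatched
  VariableNameMap old_map, new_map;
  old_map["X"] = slot_names;         // old: bound all three names
  new_map["X"] = { slot_names[i] };  // new: binds only the one being copied
  return 0;
}
```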
@@ -669,25 +681,32 @@ void build_op_func_list( const framework::ProgramDesc& pdesc, std::vector<Operat
     std::map< std::string, std::vector<int> > copy_out_name2id;
     copy_out_name2id["Out"] = { var_scope->name2id[new_var_name] };
-    vec_ids[i] = var_scope->name2id[new_var_name];
+    //vec_ids[i] = var_scope->name2id[new_var_name];
+    // update the input index so later runs read the copied variable
+    op_func_node.input_index[ var_name_item.first ][i] = var_scope->name2id[new_var_name];
     VariableValueMap copy_ins_value_map;
-    copy_ins_value_map["X"] = ins_map[ var_name_item.first ];
+    copy_ins_value_map["X"] = { var };
     VariableValueMap copy_outs_value_map;
     copy_outs_value_map["Out"] = { v };
-    auto copy_op = info.Creator()( "memcpy", copy_in_map, copy_out_map, attr_map);
+    auto& copy_info = OpInfoMap::Instance().Get( "memcpy" );
+    auto copy_op = copy_info.Creator()( "memcpy", copy_in_map, copy_out_map, attr_map);
+    if(debug) cerr << "create memcpy" << endl;
     OpFuncNode copy_op_func_node;
     copy_op_func_node.input_index = copy_ins_name2id;
     copy_op_func_node.output_index = copy_out_name2id;
-    RuntimeContext runtime_context( {}, {});
-    runtime_context.inputs.swap( copy_ins_value_map );
-    runtime_context.outputs.swap( copy_outs_value_map );
+    RuntimeContext copy_runtime_context( {}, {});
+    copy_runtime_context.inputs.swap( copy_ins_value_map );
+    copy_runtime_context.outputs.swap( copy_outs_value_map );
     //cerr << "create runtime context" << endl;
-    RuntimeInferShapeContext infer_shape_ctx(*copy_op, runtime_context);
-    static_cast<const framework::OperatorWithKernel*>(copy_op)->InferShape( &infer_shape_ctx );
+    RuntimeInferShapeContext copy_infer_shape_ctx(*copy_op, copy_runtime_context);
+    if(debug) cerr << "before infer shape" << endl;
+    static_cast<const framework::OperatorWithKernel*>(copy_op)->InferShape( &copy_infer_shape_ctx );
+    if(debug) cerr << "infer shape" << endl;
     //cerr << "fin infer shape" << endl;
     auto& all_op_kernels = OperatorWithKernel::AllOpKernels();
     auto kernels_iter = all_op_kernels.find( "memcpy" );
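The old code created the copy op through the enclosing op's `info`, i.e. the wrong registry entry; the fix fetches the `OpInfo` registered for `memcpy` itself. The creation pattern in isolation (a sketch that assumes it runs inside the Paddle framework with the `memcpy` op registered; the `dst_place_type` attribute is an assumption about that op's interface, not something this diff shows):

```cpp
#include <memory>
#include <string>
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_registry.h"

using namespace paddle::framework;

// Build a "memcpy" operator from the registry, as the diff does.
std::unique_ptr<OperatorBase> MakeMemcpyOp(const std::string& src_name,
                                           const std::string& dst_name) {
  VariableNameMap ins{{"X", {src_name}}};
  VariableNameMap outs{{"Out", {dst_name}}};
  AttributeMap attrs;
  attrs["dst_place_type"] = 1;  // assumed attribute: which place to copy to
  auto& info = OpInfoMap::Instance().Get("memcpy");
  // Creator() is the registered factory: type name, inputs, outputs, attrs.
  return std::unique_ptr<OperatorBase>(
      info.Creator()("memcpy", ins, outs, attrs));
}
```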
@@ -704,20 +723,25 @@ void build_op_func_list( const framework::ProgramDesc& pdesc, std::vector<Operat
     OpKernelMap& kernels = kernels_iter->second;
     //auto place = platform::CPUPlace();
     //auto place = platform::CUDAPlace(0);
     platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
     auto* dev_ctx = pool.Get(place);
     Scope scope;
-    auto exec_ctx = ExecutionContext(*copy_op, scope, *dev_ctx, runtime_context );
+    auto copy_exec_ctx = ExecutionContext(*copy_op, scope, *dev_ctx, copy_runtime_context );
     if (debug ) cerr << "21" << endl;
-    auto expected_kernel_key = dynamic_cast<const framework::OperatorWithKernel*>(copy_op)->GetExpectedKernelType( exec_ctx );
+    auto expected_kernel_key = dynamic_cast<const framework::OperatorWithKernel*>(copy_op)->GetExpectedKernelType( copy_exec_ctx );
     if (debug ) cerr << "22" << endl;
     //cerr << "22" << endl;
     auto kernel_iter = kernels.find(expected_kernel_key);
     copy_op_func_node.kernel_func_ = OpKernelFunc( kernel_iter->second );
     copy_op_func_node.kernel_func_( copy_exec_ctx );
+    if(debug) cerr << "run exec ctx" << endl;
     op_list.push_back( copy_op );
     vec_func_list.push_back( copy_op_func_node );
+    // point the op's input at the freshly copied variable
+    var_name_item.second[i] = v;
    }
   }
  }
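Running the freshly created copy op follows the same kernel-dispatch sequence used everywhere in this file: find the op's kernel map, build an `ExecutionContext`, ask the op for its expected kernel key, then invoke the matching kernel. A condensed sketch of that sequence with the error handling the diff skips (assumes the surrounding Paddle types; the diff also caches the resulting functor on `OpFuncNode` so replays skip the lookup):

```cpp
#include <stdexcept>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/device_context.h"

using namespace paddle::framework;

// Sketch: dispatch one already-constructed op through the kernel registry.
void RunOpOnce(const OperatorBase* op, const Scope& scope,
               const platform::Place& place, RuntimeContext& runtime_ctx) {
  auto& all_op_kernels = OperatorWithKernel::AllOpKernels();
  auto kernels_iter = all_op_kernels.find(op->Type());
  if (kernels_iter == all_op_kernels.end()) {
    throw std::runtime_error("no kernels registered for " + op->Type());
  }
  auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place);
  auto exec_ctx = ExecutionContext(*op, scope, *dev_ctx, runtime_ctx);
  // Which variant (place / data type / layout) does the op want here?
  auto key = static_cast<const OperatorWithKernel*>(op)
                 ->GetExpectedKernelType(exec_ctx);
  auto kernel_iter = kernels_iter->second.find(key);
  if (kernel_iter == kernels_iter->second.end()) {
    throw std::runtime_error(op->Type() + " has no kernel for expected key");
  }
  kernel_iter->second(exec_ctx);  // invoke the selected kernel
}
```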
@@ -833,8 +857,10 @@ public:
    paddle::framework::build_op_func_list( prog_, op_list, vec_func_list, &global_scope, place_);
    is_build = true;
   }
+  else
+  {
    paddle::framework::exec_op_func_list( vec_func_list, op_list, global_scope, place_ );
+  }
   for( size_t i = 0; i < vec_fetch_name.size(); ++i )
   {
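The new `else` matters because `build_op_func_list` executes each kernel as it records it (note the `kernel_func_( copy_exec_ctx )` call above), so without the branch the first `Run` would execute the whole program twice. The resulting build-once / replay structure, condensed (a sketch; `is_build`, `vec_func_list`, and the other names are members of the surrounding class):

```cpp
void Run() {
  if (!is_build) {
    // First call: walk the ProgramDesc, materialize ops and kernels, and
    // execute each one as a side effect of recording it.
    paddle::framework::build_op_func_list(prog_, op_list, vec_func_list,
                                          &global_scope, place_);
    is_build = true;
  } else {
    // Later calls: replay the cached kernel functors; no op creation,
    // no kernel lookup, no InferShape-driven rebuilding.
    paddle::framework::exec_op_func_list(vec_func_list, op_list,
                                         global_scope, place_);
  }
}
```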
@@ -845,10 +871,24 @@ public:
    //cerr << "out " << fetch_tensor->data<float>()[0] << endl;
+   if ( platform::is_gpu_place(fetch_tensor->place()) )
+   {
+    cerr << "fetch gpu" << endl;
+    Tensor out;
+    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
+    auto* dev_ctx = pool.Get(place_);
+    dev_ctx->Wait();
+    TensorCopySync(*fetch_tensor, platform::CPUPlace(), &out);
+    dev_ctx->Wait();
+    cerr << "out " << out << endl;
+   }
+   else
+   {
    cerr << "out " << *fetch_tensor << endl;
   }
  }
+ }
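A fetched tensor that lives in GPU memory cannot be streamed to `cerr` directly, so the new branch stages it through host memory. The same idea as a small helper (a sketch; `TensorCopySync` from `paddle/fluid/framework/tensor_util.h` is a blocking copy, and the preceding `Wait()` drains the device stream so the kernels producing the tensor have finished):

```cpp
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/device_context.h"

namespace fw = paddle::framework;
namespace plat = paddle::platform;

// Return a CPU-resident copy of `src` (or `src` itself if already on host),
// suitable for printing or numeric checks.
fw::Tensor FetchToCPU(const fw::Tensor& src, const plat::Place& place) {
  if (!plat::is_gpu_place(src.place())) {
    return src;  // already host-resident, nothing to do
  }
  fw::Tensor cpu_copy;
  auto* dev_ctx = plat::DeviceContextPool::Instance().Get(place);
  dev_ctx->Wait();  // ensure producers of `src` have finished
  fw::TensorCopySync(src, plat::CPUPlace(), &cpu_copy);
  return cpu_copy;
}
```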
 private:
  const platform::Place& place_;
  const ProgramDesc& prog_;
...