diff --git a/paddle/fluid/framework/ngraph_operator.cc b/paddle/fluid/framework/ngraph_operator.cc index e37f0915c5dbea45c195f8d769baea2865a172e6..23f681ce886fd0d8c113ffe4e80e25e6a803e31b 100644 --- a/paddle/fluid/framework/ngraph_operator.cc +++ b/paddle/fluid/framework/ngraph_operator.cc @@ -278,19 +278,7 @@ std::shared_ptr NgraphEngine::backend_ = ngraph::runtime::Backend::create("CPU"); void NgraphEngine::GetNgInputShape(std::shared_ptr op) { - RuntimeContext ctx; - for (auto& var_name_item : op->Inputs()) { - std::vector input_vars = ctx.inputs[var_name_item.first]; - for (auto& var_name : var_name_item.second) { - input_vars.push_back(scope_.FindVar(var_name)); - } - } - for (auto& var_name_item : op->Outputs()) { - std::vector output_vars = ctx.outputs[var_name_item.first]; - for (auto& var_name : var_name_item.second) { - output_vars.push_back(scope_.FindVar(var_name)); - } - } + RuntimeContext ctx(op->Inputs(), op->Outputs(), scope_); op->RuntimeInferShape(scope_, place_, ctx); for (auto& var_name_item : op->Inputs()) { for (auto& var_name : var_name_item.second) { diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 79e3d29a63b7a9a097d63589aaa36fb3cf6bc9ec..461d35752743ac2d6bf1640667d624b86e9ff48a 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -137,6 +137,23 @@ static LoD GetLoD(const Scope& scope, const std::string& name) { } } +RuntimeContext::RuntimeContext(const VariableNameMap& innames, + const VariableNameMap& outnames, + const Scope& scope) { + for (auto& var_name_item : innames) { + std::vector& input_vars = inputs[var_name_item.first]; + for (auto& var_name : var_name_item.second) { + input_vars.push_back(scope.FindVar(var_name)); + } + } + for (auto& var_name_item : outnames) { + std::vector& output_vars = outputs[var_name_item.first]; + for (auto& var_name : var_name_item.second) { + output_vars.push_back(scope.FindVar(var_name)); + } + } +} + void OperatorBase::Run(const Scope& scope, const platform::Place& place) { VLOG(4) << place << " " << DebugStringEx(&scope); if (platform::is_gpu_place(place)) { @@ -704,6 +721,7 @@ void OperatorWithKernel::RuntimeInferShape(const Scope& scope, void OperatorWithKernel::RunImpl(const Scope& scope, const platform::Place& place) const { + RuntimeContext ctx(Inputs(), Outputs(), scope); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto* dev_ctx = pool.Get(place); @@ -717,15 +735,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope, OpKernelMap& kernels = kernels_iter->second; - // TODO(dzhwinter) : kernel fallback mechanism will be added when all the - // transform functions are ready. - - // for (auto& candidate : kKernelPriority) { - // Do selection - // } - - auto expected_kernel_key = - this->GetExpectedKernelType(ExecutionContext(*this, scope, *dev_ctx)); + auto expected_kernel_key = this->GetExpectedKernelType( + ExecutionContext(*this, scope, *dev_ctx, ctx)); VLOG(3) << "expected_kernel_key:" << expected_kernel_key; auto kernel_iter = kernels.find(expected_kernel_key); @@ -744,7 +755,6 @@ void OperatorWithKernel::RunImpl(const Scope& scope, KernelTypeToString(expected_kernel_key)); } - RuntimeContext ctx; // do data transformScope &transfer_scope; std::vector transfered_inplace_vars; auto* transfer_scope = @@ -760,7 +770,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, RuntimeInferShapeContext infer_shape_ctx(*this, exec_scope, ctx); this->InferShape(&infer_shape_ctx); - kernel_iter->second(ExecutionContext(*this, exec_scope, *dev_ctx)); + kernel_iter->second(ExecutionContext(*this, exec_scope, *dev_ctx, ctx)); if (!transfered_inplace_vars.empty()) { // there is inplace variable has been transfered. @@ -784,6 +794,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, } } } + void OperatorWithKernel::TransferInplaceVarsBack( const Scope& scope, const std::vector& inplace_vars, const Scope& transfer_scope) const { @@ -806,7 +817,6 @@ Scope* OperatorWithKernel::PrepareData( Scope* new_scope = nullptr; for (auto& var_name_item : Inputs()) { std::vector& input_vars = ctx->inputs[var_name_item.first]; - input_vars.resize(var_name_item.second.size()); for (size_t i = 0; i < var_name_item.second.size(); ++i) { auto& var_name = var_name_item.second[i]; @@ -869,8 +879,6 @@ Scope* OperatorWithKernel::PrepareData( } for (auto& var_name_item : Outputs()) { std::vector& output_vars = ctx->outputs[var_name_item.first]; - output_vars.resize(var_name_item.second.size()); - for (size_t i = 0; i < var_name_item.second.size(); ++i) { auto& var_name = var_name_item.second[i]; output_vars[i] = scope.FindVar(var_name); diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index 438ae2539874dec8ae9cd357071fa61a67e6b268..e359414d151b90bc835caa174035e0baf8852f1b 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -72,7 +72,8 @@ class ExecutionContext; class RuntimeContext { public: - RuntimeContext() {} + RuntimeContext(const VariableNameMap& innames, + const VariableNameMap& outnames, const Scope& scope); VariableValueMap inputs; VariableValueMap outputs; @@ -165,8 +166,9 @@ class OperatorBase { class ExecutionContext { public: ExecutionContext(const OperatorBase& op, const Scope& scope, - const platform::DeviceContext& device_context) - : op_(op), scope_(scope), device_context_(device_context) {} + const platform::DeviceContext& device_context, + const RuntimeContext& ctx) + : op_(op), scope_(scope), device_context_(device_context), ctx_(ctx) {} const OperatorBase& op() const { return op_; } @@ -295,6 +297,7 @@ class ExecutionContext { const OperatorBase& op_; const Scope& scope_; const platform::DeviceContext& device_context_; + const RuntimeContext& ctx_; }; template <> diff --git a/paddle/fluid/operators/beam_search_decode_op.cc b/paddle/fluid/operators/beam_search_decode_op.cc index ae9765b76138a34935619b662a8ffb7f46c8300c..7f2bde55c98277b9fd4b3374657001c42d673d43 100644 --- a/paddle/fluid/operators/beam_search_decode_op.cc +++ b/paddle/fluid/operators/beam_search_decode_op.cc @@ -122,7 +122,8 @@ class BeamSearchDecodeOp : public framework::OperatorBase { platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto& dev_ctx = *pool.Get(dev_place); - framework::ExecutionContext ctx(*this, scope, dev_ctx); + framework::RuntimeContext run_ctx(Inputs(), Outputs(), scope); + framework::ExecutionContext ctx(*this, scope, dev_ctx, run_ctx); const LoDTensorArray* ids = ctx.Input("Ids"); const LoDTensorArray* scores = ctx.Input("Scores");