From 1041e18c4785159838c1531983432c507c291b2e Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Tue, 5 Mar 2019 15:19:20 +0800 Subject: [PATCH] Refine codes. test=develop --- paddle/fluid/framework/operator.cc | 88 ++++++++++++++---------------- paddle/fluid/framework/operator.h | 9 ++- 2 files changed, 49 insertions(+), 48 deletions(-) diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index fb99d2a2bc5..644af92799b 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -922,62 +922,20 @@ void OperatorWithKernel::RunImpl(const Scope& scope, auto* dev_ctx = pool.Get(place); if (!kernel_type_) { - // LOG(INFO) << "1, kernel_type is not set."; - // check if op[type] has kernel registered. - auto& all_op_kernels = AllOpKernels(); - auto kernels_iter = all_op_kernels.find(type_); - if (kernels_iter == all_op_kernels.end()) { - PADDLE_THROW( - "There are no kernels which are registered in the %s operator.", - type_); - } - - OpKernelMap& kernels = kernels_iter->second; - - auto expected_kernel_key = this->GetExpectedKernelType( - ExecutionContext(*this, scope, *dev_ctx, ctx, nullptr)); - VLOG(3) << "expected_kernel_key:" << expected_kernel_key; - - auto kernel_iter = kernels.find(expected_kernel_key); -#ifdef PADDLE_WITH_MKLDNN - // workaround for missing MKLDNN kernel when FLAGS_use_mkldnn env var is set - if (kernel_iter == kernels.end() && - expected_kernel_key.library_type_ == LibraryType::kMKLDNN) { - VLOG(3) << "missing MKLDNN kernel: fallbacking to PLAIN one"; - expected_kernel_key.library_type_ = LibraryType::kPlain; - expected_kernel_key.data_layout_ = DataLayout::kAnyLayout; - kernel_iter = kernels.find(expected_kernel_key); - } -#endif - if (kernel_iter == kernels.end()) { - PADDLE_THROW("op %s does not have kernel for %s", type_, - KernelTypeToString(expected_kernel_key)); - } - - kernel_type_.reset(new OpKernelType(expected_kernel_key)); - kernel_func_.reset(new OpKernelFunc(kernel_iter->second)); + ChooseKernel(ctx, scope, place); } - // std::shared_ptr kernel_type = kernel_type_; - // std::shared_ptr kernel_func = kernel_func_; - - std::vector* kernel_configs = - // GetKernelConfig(expected_kernel_key); - GetKernelConfig(*kernel_type_); + std::vector* kernel_configs = GetKernelConfig(*kernel_type_); // do data transformScope &transfer_scope; std::vector transfered_inplace_vars; auto* transfer_scope = - // PrepareData(scope, expected_kernel_key, &transfered_inplace_vars, - // &ctx); PrepareData(scope, *kernel_type_, &transfered_inplace_vars, &ctx); // exec scope is the scope that kernel actually executed on. const Scope& exec_scope = (transfer_scope == nullptr ? scope : *transfer_scope); - // if (!(expected_kernel_key.place_ == dev_ctx->GetPlace())) { - // dev_ctx = pool.Get(expected_kernel_key.place_); if (!(kernel_type_->place_ == dev_ctx->GetPlace())) { dev_ctx = pool.Get(kernel_type_->place_); } @@ -986,8 +944,6 @@ void OperatorWithKernel::RunImpl(const Scope& scope, this->InferShape(&infer_shape_ctx); // TODO(panyx0718): ExecutionContext should only depend on RuntimeContext // not Scope. Imperative mode only pass inputs and get outputs. - // kernel_iter->second( - // ExecutionContext(*this, exec_scope, *dev_ctx, ctx, kernel_configs)); (*kernel_func_)( ExecutionContext(*this, exec_scope, *dev_ctx, ctx, kernel_configs)); @@ -1015,6 +971,46 @@ void OperatorWithKernel::RunImpl(const Scope& scope, } } +void OperatorWithKernel::ChooseKernel(const RuntimeContext& ctx, + const Scope& scope, + const platform::Place& place) const { + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + auto* dev_ctx = pool.Get(place); + + // check if op[type] has kernel registered. + auto& all_op_kernels = AllOpKernels(); + auto kernels_iter = all_op_kernels.find(type_); + if (kernels_iter == all_op_kernels.end()) { + PADDLE_THROW( + "There are no kernels which are registered in the %s operator.", type_); + } + + OpKernelMap& kernels = kernels_iter->second; + + auto expected_kernel_key = this->GetExpectedKernelType( + ExecutionContext(*this, scope, *dev_ctx, ctx, nullptr)); + VLOG(3) << "expected_kernel_key:" << expected_kernel_key; + + auto kernel_iter = kernels.find(expected_kernel_key); +#ifdef PADDLE_WITH_MKLDNN + // workaround for missing MKLDNN kernel when FLAGS_use_mkldnn env var is set + if (kernel_iter == kernels.end() && + expected_kernel_key.library_type_ == LibraryType::kMKLDNN) { + VLOG(3) << "missing MKLDNN kernel: fallbacking to PLAIN one"; + expected_kernel_key.library_type_ = LibraryType::kPlain; + expected_kernel_key.data_layout_ = DataLayout::kAnyLayout; + kernel_iter = kernels.find(expected_kernel_key); + } +#endif + if (kernel_iter == kernels.end()) { + PADDLE_THROW("op %s does not have kernel for %s", type_, + KernelTypeToString(expected_kernel_key)); + } + + kernel_type_.reset(new OpKernelType(expected_kernel_key)); + kernel_func_.reset(new OpKernelFunc(kernel_iter->second)); +} + void OperatorWithKernel::TransferInplaceVarsBack( const Scope& scope, const std::vector& inplace_vars, const Scope& transfer_scope) const { diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index 11ef729706b..e2383312038 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -16,9 +16,11 @@ limitations under the License. */ #include #include +#include #include #include #include +#include #include #include "glog/logging.h" // For VLOG @@ -539,10 +541,13 @@ class OperatorWithKernel : public OperatorBase { const std::vector& inplace_vars, const Scope& exec_scope) const; + void ChooseKernel(const RuntimeContext& ctx, const Scope& scope, + const platform::Place& place) const; + protected: mutable OpKernelConfigsMap kernel_configs_map_; - mutable std::shared_ptr kernel_type_; - mutable std::shared_ptr kernel_func_; + mutable std::unique_ptr kernel_type_; + mutable std::unique_ptr kernel_func_; }; extern bool OpSupportGPU(const std::string& op_type); -- GitLab