Commit 1041e18c authored by Liu Yiqun

Refine codes.

test=develop
Parent d8a939d8
--- a/paddle/fluid/framework/operator.cc
+++ b/paddle/fluid/framework/operator.cc
@@ -922,62 +922,20 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
   auto* dev_ctx = pool.Get(place);
   if (!kernel_type_) {
-    // LOG(INFO) << "1, kernel_type is not set.";
-    // check if op[type] has kernel registered.
-    auto& all_op_kernels = AllOpKernels();
-    auto kernels_iter = all_op_kernels.find(type_);
-    if (kernels_iter == all_op_kernels.end()) {
-      PADDLE_THROW(
-          "There are no kernels which are registered in the %s operator.",
-          type_);
-    }
-    OpKernelMap& kernels = kernels_iter->second;
-    auto expected_kernel_key = this->GetExpectedKernelType(
-        ExecutionContext(*this, scope, *dev_ctx, ctx, nullptr));
-    VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
-    auto kernel_iter = kernels.find(expected_kernel_key);
-#ifdef PADDLE_WITH_MKLDNN
-    // workaround for missing MKLDNN kernel when FLAGS_use_mkldnn env var is set
-    if (kernel_iter == kernels.end() &&
-        expected_kernel_key.library_type_ == LibraryType::kMKLDNN) {
-      VLOG(3) << "missing MKLDNN kernel: fallbacking to PLAIN one";
-      expected_kernel_key.library_type_ = LibraryType::kPlain;
-      expected_kernel_key.data_layout_ = DataLayout::kAnyLayout;
-      kernel_iter = kernels.find(expected_kernel_key);
-    }
-#endif
-    if (kernel_iter == kernels.end()) {
-      PADDLE_THROW("op %s does not have kernel for %s", type_,
-                   KernelTypeToString(expected_kernel_key));
-    }
-    kernel_type_.reset(new OpKernelType(expected_kernel_key));
-    kernel_func_.reset(new OpKernelFunc(kernel_iter->second));
+    ChooseKernel(ctx, scope, place);
   }
-  // std::shared_ptr<OpKernelType> kernel_type = kernel_type_;
-  // std::shared_ptr<OpKernelFunc> kernel_func = kernel_func_;
-  std::vector<KernelConfig>* kernel_configs =
-      // GetKernelConfig(expected_kernel_key);
-      GetKernelConfig(*kernel_type_);
+  std::vector<KernelConfig>* kernel_configs = GetKernelConfig(*kernel_type_);
   // do data transformScope &transfer_scope;
   std::vector<std::string> transfered_inplace_vars;
   auto* transfer_scope =
-      // PrepareData(scope, expected_kernel_key, &transfered_inplace_vars,
-      // &ctx);
       PrepareData(scope, *kernel_type_, &transfered_inplace_vars, &ctx);
   // exec scope is the scope that kernel actually executed on.
   const Scope& exec_scope =
       (transfer_scope == nullptr ? scope : *transfer_scope);
-  // if (!(expected_kernel_key.place_ == dev_ctx->GetPlace())) {
-  //   dev_ctx = pool.Get(expected_kernel_key.place_);
   if (!(kernel_type_->place_ == dev_ctx->GetPlace())) {
     dev_ctx = pool.Get(kernel_type_->place_);
   }
@@ -986,8 +944,6 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
   this->InferShape(&infer_shape_ctx);
   // TODO(panyx0718): ExecutionContext should only depend on RuntimeContext
   // not Scope. Imperative mode only pass inputs and get outputs.
-  // kernel_iter->second(
-  //     ExecutionContext(*this, exec_scope, *dev_ctx, ctx, kernel_configs));
   (*kernel_func_)(
       ExecutionContext(*this, exec_scope, *dev_ctx, ctx, kernel_configs));
@@ -1015,6 +971,46 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
   }
 }

+void OperatorWithKernel::ChooseKernel(const RuntimeContext& ctx,
+                                      const Scope& scope,
+                                      const platform::Place& place) const {
+  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
+  auto* dev_ctx = pool.Get(place);
+
+  // check if op[type] has kernel registered.
+  auto& all_op_kernels = AllOpKernels();
+  auto kernels_iter = all_op_kernels.find(type_);
+  if (kernels_iter == all_op_kernels.end()) {
+    PADDLE_THROW(
+        "There are no kernels which are registered in the %s operator.", type_);
+  }
+
+  OpKernelMap& kernels = kernels_iter->second;
+
+  auto expected_kernel_key = this->GetExpectedKernelType(
+      ExecutionContext(*this, scope, *dev_ctx, ctx, nullptr));
+  VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
+
+  auto kernel_iter = kernels.find(expected_kernel_key);
+#ifdef PADDLE_WITH_MKLDNN
+  // workaround for missing MKLDNN kernel when FLAGS_use_mkldnn env var is set
+  if (kernel_iter == kernels.end() &&
+      expected_kernel_key.library_type_ == LibraryType::kMKLDNN) {
+    VLOG(3) << "missing MKLDNN kernel: fallbacking to PLAIN one";
+    expected_kernel_key.library_type_ = LibraryType::kPlain;
+    expected_kernel_key.data_layout_ = DataLayout::kAnyLayout;
+    kernel_iter = kernels.find(expected_kernel_key);
+  }
+#endif
+  if (kernel_iter == kernels.end()) {
+    PADDLE_THROW("op %s does not have kernel for %s", type_,
+                 KernelTypeToString(expected_kernel_key));
+  }
+
+  kernel_type_.reset(new OpKernelType(expected_kernel_key));
+  kernel_func_.reset(new OpKernelFunc(kernel_iter->second));
+}
+
 void OperatorWithKernel::TransferInplaceVarsBack(
     const Scope& scope, const std::vector<std::string>& inplace_vars,
     const Scope& transfer_scope) const {
......
--- a/paddle/fluid/framework/operator.h
+++ b/paddle/fluid/framework/operator.h
@@ -16,9 +16,11 @@ limitations under the License. */
 #include <algorithm>
 #include <atomic>
+#include <memory>
 #include <string>
 #include <tuple>
 #include <unordered_map>
+#include <utility>
 #include <vector>

 #include "glog/logging.h"  // For VLOG
@@ -539,10 +541,13 @@ class OperatorWithKernel : public OperatorBase {
                                const std::vector<std::string>& inplace_vars,
                                const Scope& exec_scope) const;

+  void ChooseKernel(const RuntimeContext& ctx, const Scope& scope,
+                    const platform::Place& place) const;
+
 protected:
   mutable OpKernelConfigsMap kernel_configs_map_;
-  mutable std::shared_ptr<OpKernelType> kernel_type_;
-  mutable std::shared_ptr<OpKernelFunc> kernel_func_;
+  mutable std::unique_ptr<OpKernelType> kernel_type_;
+  mutable std::unique_ptr<OpKernelFunc> kernel_func_;
 };

 extern bool OpSupportGPU(const std::string& op_type);
......
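
Taken together, the two files settle on a lazy choose-once-and-cache pattern: the first RunImpl call resolves the kernel through the new ChooseKernel and stores the result in the operator's kernel_type_/kernel_func_ members (now std::unique_ptr, since each operator instance is the sole owner of its cached choice), and every later call reuses that cached kernel instead of searching the registry again. The sketch below is a minimal, self-contained illustration of that pattern, not PaddlePaddle's real API; DemoOp, KernelKey, KernelFn, and the registry map are hypothetical stand-ins for OperatorWithKernel, OpKernelType, OpKernelFunc, and AllOpKernels().

// Minimal sketch of the choose-once-and-cache pattern (hypothetical names).
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>

using KernelKey = std::string;              // stand-in for OpKernelType
using KernelFn = std::function<void(int)>;  // stand-in for OpKernelFunc

class DemoOp {
 public:
  explicit DemoOp(std::map<KernelKey, KernelFn> kernels)
      : kernels_(std::move(kernels)) {}

  // Mirrors RunImpl: resolve the kernel on the first call, then reuse it.
  void Run(const KernelKey& expected, int input) const {
    if (!kernel_fn_) {
      ChooseKernel(expected);  // first call only
    }
    (*kernel_fn_)(input);      // later calls skip the registry lookup
  }

 private:
  // Mirrors the extracted ChooseKernel: look up the kernel and cache it.
  void ChooseKernel(const KernelKey& expected) const {
    auto it = kernels_.find(expected);
    if (it == kernels_.end()) {
      throw std::runtime_error("no kernel registered for " + expected);
    }
    kernel_key_.reset(new KernelKey(expected));
    kernel_fn_.reset(new KernelFn(it->second));
  }

  std::map<KernelKey, KernelFn> kernels_;
  // unique_ptr: the operator alone owns its cached kernel choice.
  mutable std::unique_ptr<KernelKey> kernel_key_;
  mutable std::unique_ptr<KernelFn> kernel_fn_;
};

int main() {
  DemoOp op({{"CPU", [](int x) { std::cout << "cpu kernel: " << x << "\n"; }}});
  op.Run("CPU", 1);  // resolves and caches the kernel
  op.Run("CPU", 2);  // reuses the cached kernel
  return 0;
}

Built as an ordinary C++11 program, the second Run call prints through the cached kernel without touching the registry again, which is why kernel_type_ and kernel_func_ only need to be set once per operator instance and need no shared ownership.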