提交 b91a7a9d 编写于 作者: X Xin Pan

clear operator changes

test=develop
上级 f52b514d
...@@ -179,11 +179,6 @@ void OperatorBase::Run(const Scope& scope, const platform::Place& place) { ...@@ -179,11 +179,6 @@ void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
VLOG(3) << place << " " << DebugStringEx(&scope); VLOG(3) << place << " " << DebugStringEx(&scope);
} }
void OperatorBase::RunPrepared(const RuntimeContext& ctx,
const platform::Place& place) {
RunImplPrepared(ctx, place);
}
bool OperatorBase::HasInputs(const std::string& name) const { bool OperatorBase::HasInputs(const std::string& name) const {
return inputs_.find(name) != inputs_.end(); return inputs_.find(name) != inputs_.end();
} }
...@@ -958,51 +953,6 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -958,51 +953,6 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
} }
} }
void OperatorWithKernel::RunImplPrepared(const RuntimeContext& ctx,
const platform::Place& place) const {
Scope dummy_scope;
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place);
// check if op[type] has kernel registered.
auto& all_op_kernels = AllOpKernels();
auto kernels_iter = all_op_kernels.find(type_);
if (kernels_iter == all_op_kernels.end()) {
PADDLE_THROW(
"There are no kernels which are registered in the %s operator.", type_);
}
OpKernelMap& kernels = kernels_iter->second;
auto expected_kernel_key = this->GetExpectedKernelType(
ExecutionContext(*this, dummy_scope, *dev_ctx, ctx));
VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
auto kernel_iter = kernels.find(expected_kernel_key);
#ifdef PADDLE_WITH_MKLDNN
// workaround for missing MKLDNN kernel when FLAGS_use_mkldnn env var is set
if (kernel_iter == kernels.end() &&
expected_kernel_key.library_type_ == LibraryType::kMKLDNN) {
VLOG(3) << "missing MKLDNN kernel: fallbacking to PLAIN one";
expected_kernel_key.library_type_ = LibraryType::kPlain;
expected_kernel_key.data_layout_ = DataLayout::kAnyLayout;
kernel_iter = kernels.find(expected_kernel_key);
}
#endif
if (kernel_iter == kernels.end()) {
PADDLE_THROW("op %s does not have kernel for %s", type_,
KernelTypeToString(expected_kernel_key));
}
if (!(expected_kernel_key.place_ == dev_ctx->GetPlace())) {
dev_ctx = pool.Get(expected_kernel_key.place_);
}
RuntimeInferShapeContext infer_shape_ctx(*this, dummy_scope, ctx);
this->InferShape(&infer_shape_ctx);
kernel_iter->second(ExecutionContext(*this, dummy_scope, *dev_ctx, ctx));
}
void OperatorWithKernel::TransferInplaceVarsBack( void OperatorWithKernel::TransferInplaceVarsBack(
const Scope& scope, const std::vector<std::string>& inplace_vars, const Scope& scope, const std::vector<std::string>& inplace_vars,
const Scope& transfer_scope) const { const Scope& transfer_scope) const {
......
...@@ -105,7 +105,6 @@ class OperatorBase { ...@@ -105,7 +105,6 @@ class OperatorBase {
/// Executor will call this interface function to Run an op. /// Executor will call this interface function to Run an op.
// The implementation should be written at RunImpl // The implementation should be written at RunImpl
void Run(const Scope& scope, const platform::Place& place); void Run(const Scope& scope, const platform::Place& place);
void RunPrepared(const RuntimeContext& ctx, const platform::Place& place);
// FIXME(typhoonzero): this is only used for recv_op to stop event_loop. // FIXME(typhoonzero): this is only used for recv_op to stop event_loop.
virtual void Stop() {} virtual void Stop() {}
...@@ -172,11 +171,6 @@ class OperatorBase { ...@@ -172,11 +171,6 @@ class OperatorBase {
void CheckAllInputOutputSet() const; void CheckAllInputOutputSet() const;
virtual void RunImpl(const Scope& scope, virtual void RunImpl(const Scope& scope,
const platform::Place& place) const = 0; const platform::Place& place) const = 0;
virtual void RunImplPrepared(const RuntimeContext& ctx,
const platform::Place& place) const {
PADDLE_THROW("%s doesn't support RunPreparedImpl", Type());
}
}; };
class ExecutionContext { class ExecutionContext {
...@@ -469,8 +463,6 @@ class OperatorWithKernel : public OperatorBase { ...@@ -469,8 +463,6 @@ class OperatorWithKernel : public OperatorBase {
// same. // same.
proto::VarType::Type IndicateDataType(const ExecutionContext& ctx) const; proto::VarType::Type IndicateDataType(const ExecutionContext& ctx) const;
void RunImpl(const Scope& scope, const platform::Place& place) const final; void RunImpl(const Scope& scope, const platform::Place& place) const final;
void RunImplPrepared(const RuntimeContext& ctx,
const platform::Place& place) const final;
/** /**
* Transfer data from scope to a transfered scope. If there is no data need to * Transfer data from scope to a transfered scope. If there is no data need to
......
...@@ -118,7 +118,6 @@ class Tracer { ...@@ -118,7 +118,6 @@ class Tracer {
VLOG(3) << "tracer running " << op_desc->Type(); VLOG(3) << "tracer running " << op_desc->Type();
framework::RuntimeContext ctx(invars_map, outvars_map); framework::RuntimeContext ctx(invars_map, outvars_map);
// op_base->RunPrepared(ctx, platform::CPUPlace());
// TODO(panyx0718): Cache p. // TODO(panyx0718): Cache p.
framework::OperatorWithKernel* op_kernel = framework::OperatorWithKernel* op_kernel =
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册