未验证 提交 7f3b0877 编写于 作者: L Leo Chen 提交者: GitHub

[new-exec] support pten kernel (#38770)

上级 1b6e4664
...@@ -413,7 +413,23 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) { ...@@ -413,7 +413,23 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
if (op_with_kernel == nullptr) { if (op_with_kernel == nullptr) {
instr_node.OpBase()->Run(*local_scope, place_); instr_node.OpBase()->Run(*local_scope, place_);
} else { } else {
instr_node.KernelFunc()(*instr_node.InnerExecutionContext().get()); // fit for pten
if (instr_node.PtenKernel() && instr_node.PtenKernel()->IsValid()) {
VLOG(4) << "Run pten kernel: " << op->Type();
VLOG(4) << instr_node.InnerRuntimeContext().get() << " "
<< &instr_node.DeviceContext();
op_with_kernel->BuildPtenKernelContext(
*instr_node.InnerRuntimeContext().get(),
const_cast<platform::DeviceContext*>(&instr_node.DeviceContext()));
(*instr_node.PtenKernel())(instr_node.PtenKernelContext());
op_with_kernel->WriteBackToOutputs(
instr_node.InnerRuntimeContext().get());
instr_node.PtenKernelContext()->ClearData();
} else {
instr_node.KernelFunc()(*instr_node.InnerExecutionContext().get());
}
} }
} }
......
...@@ -19,10 +19,13 @@ ...@@ -19,10 +19,13 @@
#include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h" #include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h"
#include "paddle/fluid/operators/controlflow/recurrent_op_helper.h" #include "paddle/fluid/operators/controlflow/recurrent_op_helper.h"
#include "paddle/fluid/operators/controlflow/while_op_helper.h" #include "paddle/fluid/operators/controlflow/while_op_helper.h"
#include "paddle/pten/core/kernel_factory.h"
PADDLE_DEFINE_EXPORTED_bool( PADDLE_DEFINE_EXPORTED_bool(
new_executor_sequential_run, false, new_executor_sequential_run, false,
"Enable sequential execution for standalone executor, used for debug"); "Enable sequential execution for standalone executor, used for debug");
DECLARE_bool(run_pten_kernel);
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace interpreter { namespace interpreter {
...@@ -338,6 +341,8 @@ void build_op_func_list(const platform::Place& place, ...@@ -338,6 +341,8 @@ void build_op_func_list(const platform::Place& place,
// op is not an OperatorWithKernel, so directly run OperatorBase::Run() // op is not an OperatorWithKernel, so directly run OperatorBase::Run()
deal_operator_base(place, var_scope, ops[i], &op_func_node, local_scope); deal_operator_base(place, var_scope, ops[i], &op_func_node, local_scope);
} else { } else {
auto op_with_kernel =
static_cast<const framework::OperatorWithKernel*>(op);
// construct RuntimeContext and analysis KernelType // construct RuntimeContext and analysis KernelType
RuntimeContext runtime_context({}, {}); RuntimeContext runtime_context({}, {});
runtime_context.inputs.swap(ins_map); runtime_context.inputs.swap(ins_map);
...@@ -350,8 +355,7 @@ void build_op_func_list(const platform::Place& place, ...@@ -350,8 +355,7 @@ void build_op_func_list(const platform::Place& place,
// TODO(Aurelius84): In case of control flow ops, they are NOT // TODO(Aurelius84): In case of control flow ops, they are NOT
// inherited // inherited
// from OperatorWithKernel. // from OperatorWithKernel.
static_cast<const framework::OperatorWithKernel*>(op)->InferShape( op_with_kernel->InferShape(&infer_shape_ctx);
&infer_shape_ctx);
} }
auto kernels_iter = all_op_kernels.find(op->Type()); auto kernels_iter = all_op_kernels.find(op->Type());
...@@ -367,10 +371,8 @@ void build_op_func_list(const platform::Place& place, ...@@ -367,10 +371,8 @@ void build_op_func_list(const platform::Place& place,
platform::DeviceContextPool::Instance(); platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place); auto* dev_ctx = pool.Get(place);
Scope scope; Scope scope;
auto expected_kernel_key = auto expected_kernel_key = op_with_kernel->GetExpectedKernelType(
dynamic_cast<const framework::OperatorWithKernel*>(op) ExecutionContext(*op, scope, *dev_ctx, runtime_context));
->GetExpectedKernelType(
ExecutionContext(*op, scope, *dev_ctx, runtime_context));
// change device by the device_guard() // change device by the device_guard()
apply_device_guard(op, place, &expected_kernel_key); apply_device_guard(op, place, &expected_kernel_key);
...@@ -378,10 +380,16 @@ void build_op_func_list(const platform::Place& place, ...@@ -378,10 +380,16 @@ void build_op_func_list(const platform::Place& place,
// step 3. apply data transforms and insert data transfer ops // step 3. apply data transforms and insert data transfer ops
VariableValueMap& ins_map_temp = runtime_context.inputs; VariableValueMap& ins_map_temp = runtime_context.inputs;
// NOTE(zhiqiu): op_func_node->operator_base_ may be changed in
// ApplyDataTransform
ApplyDataTransform(expected_kernel_key, place, &ins_map_temp, var_scope, ApplyDataTransform(expected_kernel_key, place, &ins_map_temp, var_scope,
&op_func_node, vec_func_list, use_local_scope); &op_func_node, vec_func_list, use_local_scope);
op_with_kernel = static_cast<const framework::OperatorWithKernel*>(
op_func_node.operator_base_.get());
// step 4. Run op kernel // step 4. Run op kernel
VLOG(3) << op->Type() VLOG(3) << op_with_kernel->Type()
<< " : expected_kernel_key : " << expected_kernel_key; << " : expected_kernel_key : " << expected_kernel_key;
if (platform::is_gpu_place(expected_kernel_key.place_)) { if (platform::is_gpu_place(expected_kernel_key.place_)) {
...@@ -397,7 +405,8 @@ void build_op_func_list(const platform::Place& place, ...@@ -397,7 +405,8 @@ void build_op_func_list(const platform::Place& place,
} }
op_func_node.dev_ctx_ = dev_ctx; op_func_node.dev_ctx_ = dev_ctx;
auto exec_ctx = ExecutionContext(*op, scope, *dev_ctx, runtime_context); auto exec_ctx =
ExecutionContext(*op_with_kernel, scope, *dev_ctx, runtime_context);
auto kernel_iter = kernels.find(expected_kernel_key); auto kernel_iter = kernels.find(expected_kernel_key);
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
...@@ -406,8 +415,27 @@ void build_op_func_list(const platform::Place& place, ...@@ -406,8 +415,27 @@ void build_op_func_list(const platform::Place& place,
"Operator (%s) does not have kernel for %s.", op->Type(), "Operator (%s) does not have kernel for %s.", op->Type(),
KernelTypeToString(expected_kernel_key))); KernelTypeToString(expected_kernel_key)));
op_func_node.kernel_func_ = OpKernelComputeFunc(kernel_iter->second); auto run_pten_kernel = false;
op_func_node.kernel_func_(exec_ctx);
if (FLAGS_run_pten_kernel &&
pten::KernelFactory::Instance().HasCompatiblePtenKernel(
op_with_kernel->Type())) {
op_with_kernel->ChoosePtenKernel(exec_ctx);
run_pten_kernel = op_with_kernel->PtenKernel()->IsValid();
}
if (run_pten_kernel) {
op_with_kernel->BuildPtenKernelContext(runtime_context, dev_ctx);
op_func_node.pt_kernel_ = op_with_kernel->PtenKernel();
op_func_node.pt_kernel_context_ = op_with_kernel->PtenKernelContext();
(*op_func_node.pt_kernel_)(op_func_node.pt_kernel_context_);
op_with_kernel->WriteBackToOutputs(&runtime_context);
op_func_node.pt_kernel_context_->ClearData();
} else {
op_func_node.kernel_func_ = OpKernelComputeFunc(kernel_iter->second);
op_func_node.kernel_func_(exec_ctx);
}
// post-process grad_op.outputs if need cast complex grad into real grad. // post-process grad_op.outputs if need cast complex grad into real grad.
// NOTE(Aurelius84): insert a transfer_dtype_op in place to cast it. // NOTE(Aurelius84): insert a transfer_dtype_op in place to cast it.
......
...@@ -673,6 +673,14 @@ OpKernelComputeFunc Instruction::KernelFunc() const { ...@@ -673,6 +673,14 @@ OpKernelComputeFunc Instruction::KernelFunc() const {
return op_func_node_.kernel_func_; return op_func_node_.kernel_func_;
} }
// Accessor for the pten kernel pointer stored in this instruction's
// OpFuncNode. The pointer is handed out as-is; no ownership is transferred.
pten::Kernel* Instruction::PtenKernel() const {
  pten::Kernel* const stored_kernel = op_func_node_.pt_kernel_;
  return stored_kernel;
}
// Accessor for the pten kernel context pointer stored in this instruction's
// OpFuncNode. The pointer is handed out as-is; no ownership is transferred.
pten::KernelContext* Instruction::PtenKernelContext() const {
  pten::KernelContext* const stored_context = op_func_node_.pt_kernel_context_;
  return stored_context;
}
OpFuncType Instruction::KernelType() const { return op_func_node_.type_; } OpFuncType Instruction::KernelType() const { return op_func_node_.type_; }
OperatorBase* Instruction::OpBase() const { OperatorBase* Instruction::OpBase() const {
......
...@@ -295,6 +295,11 @@ struct OpFuncNode { ...@@ -295,6 +295,11 @@ struct OpFuncNode {
OpKernelComputeFunc kernel_func_; OpKernelComputeFunc kernel_func_;
platform::DeviceContext* dev_ctx_; // not owned platform::DeviceContext* dev_ctx_; // not owned
// Fields used when the op is run through a pten kernel. Both default to
// nullptr and are plain (non-owning) pointers.
pten::Kernel* pt_kernel_{nullptr}; // not owned
pten::KernelContext* pt_kernel_context_{nullptr}; // not owned
OpFuncType type_; OpFuncType type_;
}; };
...@@ -313,6 +318,10 @@ class Instruction { ...@@ -313,6 +318,10 @@ class Instruction {
OpKernelComputeFunc KernelFunc() const; OpKernelComputeFunc KernelFunc() const;
// Returns the pten kernel pointer held by op_func_node_ (pt_kernel_).
pten::Kernel* PtenKernel() const;
// Returns the pten kernel context pointer held by op_func_node_
// (pt_kernel_context_).
pten::KernelContext* PtenKernelContext() const;
OpFuncType KernelType() const; OpFuncType KernelType() const;
OperatorBase* OpBase() const; OperatorBase* OpBase() const;
......
...@@ -1791,6 +1791,9 @@ KernelSignature OperatorWithKernel::GetExpectedPtenKernelArgs( ...@@ -1791,6 +1791,9 @@ KernelSignature OperatorWithKernel::GetExpectedPtenKernelArgs(
void OperatorWithKernel::BuildPtenKernelContext( void OperatorWithKernel::BuildPtenKernelContext(
const RuntimeContext& ctx, platform::DeviceContext* dev_ctx) const { const RuntimeContext& ctx, platform::DeviceContext* dev_ctx) const {
if (pt_kernel_context_ == nullptr) {
pt_kernel_context_.reset(new pten::KernelContext());
}
// TODO(chenweihang): now only work for very simple case, // TODO(chenweihang): now only work for very simple case,
// many cases need to be deal with later: // many cases need to be deal with later:
// 1. the input and output are not tensor // 1. the input and output are not tensor
......
...@@ -555,6 +555,20 @@ class OperatorWithKernel : public OperatorBase { ...@@ -555,6 +555,20 @@ class OperatorWithKernel : public OperatorBase {
virtual KernelSignature GetExpectedPtenKernelArgs( virtual KernelSignature GetExpectedPtenKernelArgs(
const ExecutionContext& ctx) const; const ExecutionContext& ctx) const;
/* member functions for adapting to pten lib */
// NOTE(review): presumably selects the pten kernel for this op; callers check
// PtenKernel()->IsValid() after calling it. Confirm against the definition.
void ChoosePtenKernel(const ExecutionContext& ctx) const;
// Prepares the pten kernel context for a run. The definition creates
// pt_kernel_context_ on first use if it is still null; the remaining setup
// from ctx/dev_ctx is in the definition.
void BuildPtenKernelContext(const RuntimeContext& ctx,
platform::DeviceContext* dev_ctx) const;
// NOTE(review): presumably copies kernel results back into ctx's outputs; it
// is called right after the pten kernel is invoked and before the kernel
// context's ClearData(). Confirm against the definition.
void WriteBackToOutputs(RuntimeContext* ctx) const;
// Returns the raw pointer held by pt_kernel_ (ownership is not transferred).
pten::Kernel* PtenKernel() const { return pt_kernel_.get(); }
// Returns the raw pointer held by pt_kernel_context_ (ownership is not
// transferred).
pten::KernelContext* PtenKernelContext() const {
return pt_kernel_context_.get();
}
private: private:
void RunImpl(const Scope& scope, const platform::Place& place) const final; void RunImpl(const Scope& scope, const platform::Place& place) const final;
void RunImpl(const Scope& scope, const platform::Place& place, void RunImpl(const Scope& scope, const platform::Place& place,
...@@ -595,14 +609,6 @@ class OperatorWithKernel : public OperatorBase { ...@@ -595,14 +609,6 @@ class OperatorWithKernel : public OperatorBase {
Tensor* GetTensorFormInputSafely(const ExecutionContext& ctx, Tensor* GetTensorFormInputSafely(const ExecutionContext& ctx,
const std::string& name) const; const std::string& name) const;
/* member functions for adapting to pten lib */
void ChoosePtenKernel(const ExecutionContext& ctx) const;
void BuildPtenKernelContext(const RuntimeContext& ctx,
platform::DeviceContext* dev_ctx) const;
void WriteBackToOutputs(RuntimeContext* ctx) const;
protected: protected:
mutable std::unique_ptr<OpKernelType> kernel_type_; mutable std::unique_ptr<OpKernelType> kernel_type_;
mutable std::unique_ptr<OpKernelFunc> kernel_func_; mutable std::unique_ptr<OpKernelFunc> kernel_func_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册