From 7f3b08772273795fb6845d248603addc6adccfe8 Mon Sep 17 00:00:00 2001
From: Leo Chen
Date: Fri, 7 Jan 2022 10:52:30 +0800
Subject: [PATCH] [new-exec] support pten kernel (#38770)

---
 .../framework/new_executor/interpretercore.cc | 18 ++++++-
 .../new_executor/interpretercore_util.cc      | 48 +++++++++++++++----
 .../new_executor/new_executor_defs.cc         |  8 ++++
 .../new_executor/new_executor_defs.h          |  9 ++++
 paddle/fluid/framework/operator.cc            |  3 ++
 paddle/fluid/framework/operator.h             | 22 +++++----
 6 files changed, 89 insertions(+), 19 deletions(-)

diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc
index 7c0bbac6180..950756c0394 100644
--- a/paddle/fluid/framework/new_executor/interpretercore.cc
+++ b/paddle/fluid/framework/new_executor/interpretercore.cc
@@ -413,7 +413,23 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
   if (op_with_kernel == nullptr) {
     instr_node.OpBase()->Run(*local_scope, place_);
   } else {
-    instr_node.KernelFunc()(*instr_node.InnerExecutionContext().get());
+    // fit for pten
+    if (instr_node.PtenKernel() && instr_node.PtenKernel()->IsValid()) {
+      VLOG(4) << "Run pten kernel: " << op->Type();
+      VLOG(4) << instr_node.InnerRuntimeContext().get() << " "
+              << &instr_node.DeviceContext();
+      op_with_kernel->BuildPtenKernelContext(
+          *instr_node.InnerRuntimeContext().get(),
+          const_cast<platform::DeviceContext*>(&instr_node.DeviceContext()));
+
+      (*instr_node.PtenKernel())(instr_node.PtenKernelContext());
+
+      op_with_kernel->WriteBackToOutputs(
+          instr_node.InnerRuntimeContext().get());
+      instr_node.PtenKernelContext()->ClearData();
+    } else {
+      instr_node.KernelFunc()(*instr_node.InnerExecutionContext().get());
+    }
   }
 }

diff --git a/paddle/fluid/framework/new_executor/interpretercore_util.cc b/paddle/fluid/framework/new_executor/interpretercore_util.cc
index 3817a11b9af..41c4faa67fb 100644
--- a/paddle/fluid/framework/new_executor/interpretercore_util.cc
+++ b/paddle/fluid/framework/new_executor/interpretercore_util.cc
@@ -19,10 +19,13 @@
 #include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h"
 #include "paddle/fluid/operators/controlflow/recurrent_op_helper.h"
 #include "paddle/fluid/operators/controlflow/while_op_helper.h"
+#include "paddle/pten/core/kernel_factory.h"

 PADDLE_DEFINE_EXPORTED_bool(
     new_executor_sequential_run, false,
     "Enable sequential execution for standalone executor, used for debug");
+DECLARE_bool(run_pten_kernel);
+
 namespace paddle {
 namespace framework {
 namespace interpreter {
@@ -338,6 +341,8 @@ void build_op_func_list(const platform::Place& place,
       // op is not a operatorwithkernel, so direcly run OperatorBase::Run()
       deal_operator_base(place, var_scope, ops[i], &op_func_node, local_scope);
     } else {
+      auto op_with_kernel =
+          static_cast<const framework::OperatorWithKernel*>(op);
       // construct RuntimeContext and analysis KernelType
       RuntimeContext runtime_context({}, {});
       runtime_context.inputs.swap(ins_map);
@@ -350,8 +355,7 @@ void build_op_func_list(const platform::Place& place,
         // TODO(Aurelius84): In case of control flow ops, they are NOT
         // inheritted
        // from OperatorWithKernel.
-        static_cast<const framework::OperatorWithKernel*>(op)->InferShape(
-            &infer_shape_ctx);
+        op_with_kernel->InferShape(&infer_shape_ctx);
       }

       auto kernels_iter = all_op_kernels.find(op->Type());
@@ -367,10 +371,8 @@ void build_op_func_list(const platform::Place& place,
           platform::DeviceContextPool::Instance();
       auto* dev_ctx = pool.Get(place);
       Scope scope;
-      auto expected_kernel_key =
-          dynamic_cast<const framework::OperatorWithKernel*>(op)
-              ->GetExpectedKernelType(
-                  ExecutionContext(*op, scope, *dev_ctx, runtime_context));
+      auto expected_kernel_key = op_with_kernel->GetExpectedKernelType(
+          ExecutionContext(*op, scope, *dev_ctx, runtime_context));

       // change device by the device_guard()
       apply_device_guard(op, place, &expected_kernel_key);
@@ -378,10 +380,16 @@ void build_op_func_list(const platform::Place& place,

       // step 3. apply data transforms and insert data transfer ops
       VariableValueMap& ins_map_temp = runtime_context.inputs;
+
+      // NOTE(zhiqiu): op_func_node->operator_base_ maybe changed in
+      // ApplyDataTransform
       ApplyDataTransform(expected_kernel_key, place, &ins_map_temp, var_scope,
                          &op_func_node, vec_func_list, use_local_scope);
+      op_with_kernel = static_cast<const framework::OperatorWithKernel*>(
+          op_func_node.operator_base_.get());
+
       // step 4. Run op kernel
-      VLOG(3) << op->Type()
+      VLOG(3) << op_with_kernel->Type()
               << " : expected_kernel_key : " << expected_kernel_key;

       if (platform::is_gpu_place(expected_kernel_key.place_)) {
@@ -397,7 +405,8 @@ void build_op_func_list(const platform::Place& place,
       }
       op_func_node.dev_ctx_ = dev_ctx;

-      auto exec_ctx = ExecutionContext(*op, scope, *dev_ctx, runtime_context);
+      auto exec_ctx =
+          ExecutionContext(*op_with_kernel, scope, *dev_ctx, runtime_context);

       auto kernel_iter = kernels.find(expected_kernel_key);
       PADDLE_ENFORCE_NE(
@@ -406,8 +415,27 @@ void build_op_func_list(const platform::Place& place,
               "Operator (%s) does not have kernel for %s.", op->Type(),
               KernelTypeToString(expected_kernel_key)));

-      op_func_node.kernel_func_ = OpKernelComputeFunc(kernel_iter->second);
-      op_func_node.kernel_func_(exec_ctx);
+      auto run_pten_kernel = false;
+
+      if (FLAGS_run_pten_kernel &&
+          pten::KernelFactory::Instance().HasCompatiblePtenKernel(
+              op_with_kernel->Type())) {
+        op_with_kernel->ChoosePtenKernel(exec_ctx);
+        run_pten_kernel = op_with_kernel->PtenKernel()->IsValid();
+      }
+
+      if (run_pten_kernel) {
+        op_with_kernel->BuildPtenKernelContext(runtime_context, dev_ctx);
+        op_func_node.pt_kernel_ = op_with_kernel->PtenKernel();
+        op_func_node.pt_kernel_context_ = op_with_kernel->PtenKernelContext();
+
+        (*op_func_node.pt_kernel_)(op_func_node.pt_kernel_context_);
+        op_with_kernel->WriteBackToOutputs(&runtime_context);
+        op_func_node.pt_kernel_context_->ClearData();
+      } else {
+        op_func_node.kernel_func_ = OpKernelComputeFunc(kernel_iter->second);
+        op_func_node.kernel_func_(exec_ctx);
+      }

       // post-process grad_op.outputs if need cast complex grad into real grad.
       // NOTE(Aurelius84): insert a transfer_dtype_op inplacely to cast it.
diff --git a/paddle/fluid/framework/new_executor/new_executor_defs.cc b/paddle/fluid/framework/new_executor/new_executor_defs.cc
index 73f16fe3e9c..4b9404fd178 100644
--- a/paddle/fluid/framework/new_executor/new_executor_defs.cc
+++ b/paddle/fluid/framework/new_executor/new_executor_defs.cc
@@ -673,6 +673,14 @@ OpKernelComputeFunc Instruction::KernelFunc() const {
   return op_func_node_.kernel_func_;
 }

+pten::Kernel* Instruction::PtenKernel() const {
+  return op_func_node_.pt_kernel_;
+}
+
+pten::KernelContext* Instruction::PtenKernelContext() const {
+  return op_func_node_.pt_kernel_context_;
+}
+
 OpFuncType Instruction::KernelType() const { return op_func_node_.type_; }

 OperatorBase* Instruction::OpBase() const {
diff --git a/paddle/fluid/framework/new_executor/new_executor_defs.h b/paddle/fluid/framework/new_executor/new_executor_defs.h
index d691a75a6d3..ca49e7f5670 100644
--- a/paddle/fluid/framework/new_executor/new_executor_defs.h
+++ b/paddle/fluid/framework/new_executor/new_executor_defs.h
@@ -295,6 +295,11 @@ struct OpFuncNode {

   OpKernelComputeFunc kernel_func_;
   platform::DeviceContext* dev_ctx_;  // not owned
+
+  // fit for pten kernel
+  pten::Kernel* pt_kernel_{nullptr};  // not owned
+  pten::KernelContext* pt_kernel_context_{nullptr};  // not onwed
+
   OpFuncType type_;
 };

@@ -313,6 +318,10 @@ class Instruction {

   OpKernelComputeFunc KernelFunc() const;

+  pten::Kernel* PtenKernel() const;
+
+  pten::KernelContext* PtenKernelContext() const;
+
   OpFuncType KernelType() const;

   OperatorBase* OpBase() const;
diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc
index 50e16920a67..2d2e198ef40 100644
--- a/paddle/fluid/framework/operator.cc
+++ b/paddle/fluid/framework/operator.cc
@@ -1791,6 +1791,9 @@ KernelSignature OperatorWithKernel::GetExpectedPtenKernelArgs(

 void OperatorWithKernel::BuildPtenKernelContext(
     const RuntimeContext& ctx, platform::DeviceContext* dev_ctx) const {
+  if (pt_kernel_context_ == nullptr) {
+    pt_kernel_context_.reset(new pten::KernelContext());
+  }
   // TODO(chenweihang): now only work for very simple case,
   // many cases need to be deal with later:
   // 1. the input and output are not tensor
diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h
index 842ef0457d7..59bc4813d98 100644
--- a/paddle/fluid/framework/operator.h
+++ b/paddle/fluid/framework/operator.h
@@ -555,6 +555,20 @@ class OperatorWithKernel : public OperatorBase {
   virtual KernelSignature GetExpectedPtenKernelArgs(
       const ExecutionContext& ctx) const;

+  /* member functions for adapting to pten lib */
+  void ChoosePtenKernel(const ExecutionContext& ctx) const;
+
+  void BuildPtenKernelContext(const RuntimeContext& ctx,
+                              platform::DeviceContext* dev_ctx) const;
+
+  void WriteBackToOutputs(RuntimeContext* ctx) const;
+
+  pten::Kernel* PtenKernel() const { return pt_kernel_.get(); }
+
+  pten::KernelContext* PtenKernelContext() const {
+    return pt_kernel_context_.get();
+  }
+
  private:
   void RunImpl(const Scope& scope, const platform::Place& place) const final;
   void RunImpl(const Scope& scope, const platform::Place& place,
@@ -595,14 +609,6 @@ class OperatorWithKernel : public OperatorBase {
   Tensor* GetTensorFormInputSafely(const ExecutionContext& ctx,
                                    const std::string& name) const;

-  /* member functions for adapting to pten lib */
-  void ChoosePtenKernel(const ExecutionContext& ctx) const;
-
-  void BuildPtenKernelContext(const RuntimeContext& ctx,
-                              platform::DeviceContext* dev_ctx) const;
-
-  void WriteBackToOutputs(RuntimeContext* ctx) const;
-
  protected:
   mutable std::unique_ptr<OpKernelType> kernel_type_;
   mutable std::unique_ptr<OpKernelFunc> kernel_func_;
--
GitLab
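
Taken together, the patch gives the new executor two kernel paths: when FLAGS_run_pten_kernel is on and a compatible, valid pten kernel is found, it builds a pten::KernelContext, runs the pten kernel, writes outputs back, and clears the context; otherwise it falls back to the existing fluid OpKernel. The snippet below is a minimal, self-contained C++ sketch of that dispatch pattern only; every type and name in it (toy::PtenKernel, toy::Instruction, toy::RunInstruction) is a hypothetical stand-in for illustration and is not part of the Paddle codebase.

// Toy model of the "prefer pten kernel, else fall back to fluid kernel"
// dispatch added in InterpreterCore::RunInstruction. All names are
// hypothetical; none of this is Paddle API.
#include <functional>
#include <iostream>

namespace toy {

// Stand-in for pten::Kernel: may be missing or invalid for a given op.
struct PtenKernel {
  bool valid = false;
  std::function<void()> fn;
  bool IsValid() const { return valid && fn != nullptr; }
};

// Stand-in for an Instruction that carries both kernel flavors.
struct Instruction {
  const PtenKernel* pten_kernel = nullptr;  // not owned, may be null
  std::function<void()> fluid_kernel;       // always present
};

// Mirrors the new control flow: run the pten kernel when it exists and is
// valid, otherwise fall back to the fluid OpKernel path.
void RunInstruction(const Instruction& instr) {
  if (instr.pten_kernel && instr.pten_kernel->IsValid()) {
    instr.pten_kernel->fn();
  } else {
    instr.fluid_kernel();
  }
}

}  // namespace toy

int main() {
  toy::PtenKernel pt{true, [] { std::cout << "pten kernel path\n"; }};
  auto fluid = [] { std::cout << "fluid kernel path\n"; };
  toy::RunInstruction({&pt, fluid});      // prints "pten kernel path"
  toy::RunInstruction({nullptr, fluid});  // prints "fluid kernel path"
  return 0;
}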