diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc index 7c0bbac61807eb1ff3259764a8e5d8a31288ae55..950756c0394a5e6e851644e431b990418abffaaa 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.cc +++ b/paddle/fluid/framework/new_executor/interpretercore.cc @@ -413,7 +413,23 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) { if (op_with_kernel == nullptr) { instr_node.OpBase()->Run(*local_scope, place_); } else { - instr_node.KernelFunc()(*instr_node.InnerExecutionContext().get()); + // fit for pten + if (instr_node.PtenKernel() && instr_node.PtenKernel()->IsValid()) { + VLOG(4) << "Run pten kernel: " << op->Type(); + VLOG(4) << instr_node.InnerRuntimeContext().get() << " " + << &instr_node.DeviceContext(); + op_with_kernel->BuildPtenKernelContext( + *instr_node.InnerRuntimeContext().get(), + const_cast<platform::DeviceContext*>(&instr_node.DeviceContext())); + + (*instr_node.PtenKernel())(instr_node.PtenKernelContext()); + + op_with_kernel->WriteBackToOutputs( + instr_node.InnerRuntimeContext().get()); + instr_node.PtenKernelContext()->ClearData(); + } else { + instr_node.KernelFunc()(*instr_node.InnerExecutionContext().get()); + } } } diff --git a/paddle/fluid/framework/new_executor/interpretercore_util.cc b/paddle/fluid/framework/new_executor/interpretercore_util.cc index 3817a11b9afe4e6d6b3e2526db612b2c8893800d..41c4faa67fbebec194c86c1eecbe33ddbc2df9c2 100644 --- a/paddle/fluid/framework/new_executor/interpretercore_util.cc +++ b/paddle/fluid/framework/new_executor/interpretercore_util.cc @@ -19,10 +19,13 @@ #include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h" #include "paddle/fluid/operators/controlflow/recurrent_op_helper.h" #include "paddle/fluid/operators/controlflow/while_op_helper.h" +#include "paddle/pten/core/kernel_factory.h" PADDLE_DEFINE_EXPORTED_bool( new_executor_sequential_run, false, "Enable sequential execution for standalone executor, used 
for debug"); +DECLARE_bool(run_pten_kernel); + namespace paddle { namespace framework { namespace interpreter { @@ -338,6 +341,8 @@ void build_op_func_list(const platform::Place& place, // op is not a operatorwithkernel, so direcly run OperatorBase::Run() deal_operator_base(place, var_scope, ops[i], &op_func_node, local_scope); } else { + auto op_with_kernel = + static_cast<const framework::OperatorWithKernel*>(op); // construct RuntimeContext and analysis KernelType RuntimeContext runtime_context({}, {}); runtime_context.inputs.swap(ins_map); @@ -350,8 +355,7 @@ void build_op_func_list(const platform::Place& place, // TODO(Aurelius84): In case of control flow ops, they are NOT // inheritted // from OperatorWithKernel. - static_cast<const framework::OperatorWithKernel*>(op)->InferShape( - &infer_shape_ctx); + op_with_kernel->InferShape(&infer_shape_ctx); } auto kernels_iter = all_op_kernels.find(op->Type()); @@ -367,10 +371,8 @@ void build_op_func_list(const platform::Place& place, platform::DeviceContextPool::Instance(); auto* dev_ctx = pool.Get(place); Scope scope; - auto expected_kernel_key = - dynamic_cast<const framework::OperatorWithKernel*>(op) - ->GetExpectedKernelType( - ExecutionContext(*op, scope, *dev_ctx, runtime_context)); + auto expected_kernel_key = op_with_kernel->GetExpectedKernelType( + ExecutionContext(*op, scope, *dev_ctx, runtime_context)); // change device by the device_guard() apply_device_guard(op, place, &expected_kernel_key); @@ -378,10 +380,16 @@ void build_op_func_list(const platform::Place& place, // step 3. apply data transforms and insert data transfer ops VariableValueMap& ins_map_temp = runtime_context.inputs; + + // NOTE(zhiqiu): op_func_node->operator_base_ maybe changed in + // ApplyDataTransform ApplyDataTransform(expected_kernel_key, place, &ins_map_temp, var_scope, &op_func_node, vec_func_list, use_local_scope); + op_with_kernel = static_cast<const framework::OperatorWithKernel*>( + op_func_node.operator_base_.get()); + // step 4. 
Run op kernel - VLOG(3) << op->Type() + VLOG(3) << op_with_kernel->Type() << " : expected_kernel_key : " << expected_kernel_key; if (platform::is_gpu_place(expected_kernel_key.place_)) { @@ -397,7 +405,8 @@ void build_op_func_list(const platform::Place& place, } op_func_node.dev_ctx_ = dev_ctx; - auto exec_ctx = ExecutionContext(*op, scope, *dev_ctx, runtime_context); + auto exec_ctx = + ExecutionContext(*op_with_kernel, scope, *dev_ctx, runtime_context); auto kernel_iter = kernels.find(expected_kernel_key); PADDLE_ENFORCE_NE( @@ -406,8 +415,27 @@ void build_op_func_list(const platform::Place& place, "Operator (%s) does not have kernel for %s.", op->Type(), KernelTypeToString(expected_kernel_key))); - op_func_node.kernel_func_ = OpKernelComputeFunc(kernel_iter->second); - op_func_node.kernel_func_(exec_ctx); + auto run_pten_kernel = false; + + if (FLAGS_run_pten_kernel && + pten::KernelFactory::Instance().HasCompatiblePtenKernel( + op_with_kernel->Type())) { + op_with_kernel->ChoosePtenKernel(exec_ctx); + run_pten_kernel = op_with_kernel->PtenKernel()->IsValid(); + } + + if (run_pten_kernel) { + op_with_kernel->BuildPtenKernelContext(runtime_context, dev_ctx); + op_func_node.pt_kernel_ = op_with_kernel->PtenKernel(); + op_func_node.pt_kernel_context_ = op_with_kernel->PtenKernelContext(); + + (*op_func_node.pt_kernel_)(op_func_node.pt_kernel_context_); + op_with_kernel->WriteBackToOutputs(&runtime_context); + op_func_node.pt_kernel_context_->ClearData(); + } else { + op_func_node.kernel_func_ = OpKernelComputeFunc(kernel_iter->second); + op_func_node.kernel_func_(exec_ctx); + } // post-process grad_op.outputs if need cast complex grad into real grad. // NOTE(Aurelius84): insert a transfer_dtype_op inplacely to cast it. 
diff --git a/paddle/fluid/framework/new_executor/new_executor_defs.cc b/paddle/fluid/framework/new_executor/new_executor_defs.cc index 73f16fe3e9cc79c0ff3481f52bc45bcae80c9da9..4b9404fd178fd6b8073f7b9438701cd2ebe11627 100644 --- a/paddle/fluid/framework/new_executor/new_executor_defs.cc +++ b/paddle/fluid/framework/new_executor/new_executor_defs.cc @@ -673,6 +673,14 @@ OpKernelComputeFunc Instruction::KernelFunc() const { return op_func_node_.kernel_func_; } +pten::Kernel* Instruction::PtenKernel() const { + return op_func_node_.pt_kernel_; +} + +pten::KernelContext* Instruction::PtenKernelContext() const { + return op_func_node_.pt_kernel_context_; +} + OpFuncType Instruction::KernelType() const { return op_func_node_.type_; } OperatorBase* Instruction::OpBase() const { diff --git a/paddle/fluid/framework/new_executor/new_executor_defs.h b/paddle/fluid/framework/new_executor/new_executor_defs.h index d691a75a6d35b46563564bdabb2e7599b3d165aa..ca49e7f5670d6cd27e614a91563bcc0f608393c0 100644 --- a/paddle/fluid/framework/new_executor/new_executor_defs.h +++ b/paddle/fluid/framework/new_executor/new_executor_defs.h @@ -295,6 +295,11 @@ struct OpFuncNode { OpKernelComputeFunc kernel_func_; platform::DeviceContext* dev_ctx_; // not owned + + // fit for pten kernel + pten::Kernel* pt_kernel_{nullptr}; // not owned + pten::KernelContext* pt_kernel_context_{nullptr}; // not onwed + OpFuncType type_; }; @@ -313,6 +318,10 @@ class Instruction { OpKernelComputeFunc KernelFunc() const; + pten::Kernel* PtenKernel() const; + + pten::KernelContext* PtenKernelContext() const; + OpFuncType KernelType() const; OperatorBase* OpBase() const; diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 50e16920a673703c0dc9774ce46629a83a8a8829..2d2e198ef40eca168cc5ae933805cb42447dc62b 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -1791,6 +1791,9 @@ KernelSignature OperatorWithKernel::GetExpectedPtenKernelArgs( 
void OperatorWithKernel::BuildPtenKernelContext( const RuntimeContext& ctx, platform::DeviceContext* dev_ctx) const { + if (pt_kernel_context_ == nullptr) { + pt_kernel_context_.reset(new pten::KernelContext()); + } // TODO(chenweihang): now only work for very simple case, // many cases need to be deal with later: // 1. the input and output are not tensor diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index 842ef0457d7bd247c339a66372376f520d7d4386..59bc4813d985b9c1b443793d0e106e803a4c5aff 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -555,6 +555,20 @@ class OperatorWithKernel : public OperatorBase { virtual KernelSignature GetExpectedPtenKernelArgs( const ExecutionContext& ctx) const; + /* member functions for adapting to pten lib */ + void ChoosePtenKernel(const ExecutionContext& ctx) const; + + void BuildPtenKernelContext(const RuntimeContext& ctx, + platform::DeviceContext* dev_ctx) const; + + void WriteBackToOutputs(RuntimeContext* ctx) const; + + pten::Kernel* PtenKernel() const { return pt_kernel_.get(); } + + pten::KernelContext* PtenKernelContext() const { + return pt_kernel_context_.get(); + } + private: void RunImpl(const Scope& scope, const platform::Place& place) const final; void RunImpl(const Scope& scope, const platform::Place& place, @@ -595,14 +609,6 @@ class OperatorWithKernel : public OperatorBase { Tensor* GetTensorFormInputSafely(const ExecutionContext& ctx, const std::string& name) const; - /* member functions for adapting to pten lib */ - void ChoosePtenKernel(const ExecutionContext& ctx) const; - - void BuildPtenKernelContext(const RuntimeContext& ctx, - platform::DeviceContext* dev_ctx) const; - - void WriteBackToOutputs(RuntimeContext* ctx) const; - protected: mutable std::unique_ptr<OpKernelType> kernel_type_; mutable std::unique_ptr<OpKernelFunc> kernel_func_;