Unverified commit 7f3b0877, authored by Leo Chen, committed via GitHub

[new-exec] support pten kernel (#38770)

Parent commit: 1b6e4664
......@@ -413,7 +413,23 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
if (op_with_kernel == nullptr) {
instr_node.OpBase()->Run(*local_scope, place_);
} else {
instr_node.KernelFunc()(*instr_node.InnerExecutionContext().get());
// fit for pten
if (instr_node.PtenKernel() && instr_node.PtenKernel()->IsValid()) {
VLOG(4) << "Run pten kernel: " << op->Type();
VLOG(4) << instr_node.InnerRuntimeContext().get() << " "
<< &instr_node.DeviceContext();
op_with_kernel->BuildPtenKernelContext(
*instr_node.InnerRuntimeContext().get(),
const_cast<platform::DeviceContext*>(&instr_node.DeviceContext()));
(*instr_node.PtenKernel())(instr_node.PtenKernelContext());
op_with_kernel->WriteBackToOutputs(
instr_node.InnerRuntimeContext().get());
instr_node.PtenKernelContext()->ClearData();
} else {
instr_node.KernelFunc()(*instr_node.InnerExecutionContext().get());
}
}
}
......
......@@ -19,10 +19,13 @@
#include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h"
#include "paddle/fluid/operators/controlflow/recurrent_op_helper.h"
#include "paddle/fluid/operators/controlflow/while_op_helper.h"
#include "paddle/pten/core/kernel_factory.h"
PADDLE_DEFINE_EXPORTED_bool(
new_executor_sequential_run, false,
"Enable sequential execution for standalone executor, used for debug");
DECLARE_bool(run_pten_kernel);
namespace paddle {
namespace framework {
namespace interpreter {
......@@ -338,6 +341,8 @@ void build_op_func_list(const platform::Place& place,
// op is not a operatorwithkernel, so direcly run OperatorBase::Run()
deal_operator_base(place, var_scope, ops[i], &op_func_node, local_scope);
} else {
auto op_with_kernel =
static_cast<const framework::OperatorWithKernel*>(op);
// construct RuntimeContext and analysis KernelType
RuntimeContext runtime_context({}, {});
runtime_context.inputs.swap(ins_map);
......@@ -350,8 +355,7 @@ void build_op_func_list(const platform::Place& place,
// TODO(Aurelius84): In case of control flow ops, they are NOT
// inheritted
// from OperatorWithKernel.
static_cast<const framework::OperatorWithKernel*>(op)->InferShape(
&infer_shape_ctx);
op_with_kernel->InferShape(&infer_shape_ctx);
}
auto kernels_iter = all_op_kernels.find(op->Type());
......@@ -367,10 +371,8 @@ void build_op_func_list(const platform::Place& place,
platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place);
Scope scope;
auto expected_kernel_key =
dynamic_cast<const framework::OperatorWithKernel*>(op)
->GetExpectedKernelType(
ExecutionContext(*op, scope, *dev_ctx, runtime_context));
auto expected_kernel_key = op_with_kernel->GetExpectedKernelType(
ExecutionContext(*op, scope, *dev_ctx, runtime_context));
// change device by the device_guard()
apply_device_guard(op, place, &expected_kernel_key);
......@@ -378,10 +380,16 @@ void build_op_func_list(const platform::Place& place,
// step 3. apply data transforms and insert data transfer ops
VariableValueMap& ins_map_temp = runtime_context.inputs;
// NOTE(zhiqiu): op_func_node->operator_base_ maybe changed in
// ApplyDataTransform
ApplyDataTransform(expected_kernel_key, place, &ins_map_temp, var_scope,
&op_func_node, vec_func_list, use_local_scope);
op_with_kernel = static_cast<const framework::OperatorWithKernel*>(
op_func_node.operator_base_.get());
// step 4. Run op kernel
VLOG(3) << op->Type()
VLOG(3) << op_with_kernel->Type()
<< " : expected_kernel_key : " << expected_kernel_key;
if (platform::is_gpu_place(expected_kernel_key.place_)) {
......@@ -397,7 +405,8 @@ void build_op_func_list(const platform::Place& place,
}
op_func_node.dev_ctx_ = dev_ctx;
auto exec_ctx = ExecutionContext(*op, scope, *dev_ctx, runtime_context);
auto exec_ctx =
ExecutionContext(*op_with_kernel, scope, *dev_ctx, runtime_context);
auto kernel_iter = kernels.find(expected_kernel_key);
PADDLE_ENFORCE_NE(
......@@ -406,8 +415,27 @@ void build_op_func_list(const platform::Place& place,
"Operator (%s) does not have kernel for %s.", op->Type(),
KernelTypeToString(expected_kernel_key)));
op_func_node.kernel_func_ = OpKernelComputeFunc(kernel_iter->second);
op_func_node.kernel_func_(exec_ctx);
auto run_pten_kernel = false;
if (FLAGS_run_pten_kernel &&
pten::KernelFactory::Instance().HasCompatiblePtenKernel(
op_with_kernel->Type())) {
op_with_kernel->ChoosePtenKernel(exec_ctx);
run_pten_kernel = op_with_kernel->PtenKernel()->IsValid();
}
if (run_pten_kernel) {
op_with_kernel->BuildPtenKernelContext(runtime_context, dev_ctx);
op_func_node.pt_kernel_ = op_with_kernel->PtenKernel();
op_func_node.pt_kernel_context_ = op_with_kernel->PtenKernelContext();
(*op_func_node.pt_kernel_)(op_func_node.pt_kernel_context_);
op_with_kernel->WriteBackToOutputs(&runtime_context);
op_func_node.pt_kernel_context_->ClearData();
} else {
op_func_node.kernel_func_ = OpKernelComputeFunc(kernel_iter->second);
op_func_node.kernel_func_(exec_ctx);
}
// post-process grad_op.outputs if need cast complex grad into real grad.
// NOTE(Aurelius84): insert a transfer_dtype_op inplacely to cast it.
......
......@@ -673,6 +673,14 @@ OpKernelComputeFunc Instruction::KernelFunc() const {
return op_func_node_.kernel_func_;
}
// Accessor for the pten kernel recorded on this instruction's OpFuncNode.
// Returns a non-owning pointer; nullptr when no compatible pten kernel was
// selected for this op (the caller falls back to the fluid kernel path).
pten::Kernel* Instruction::PtenKernel() const { return op_func_node_.pt_kernel_; }
// Accessor for the pten kernel context recorded on this instruction's
// OpFuncNode. Non-owning; may be nullptr if the pten path is not used.
pten::KernelContext* Instruction::PtenKernelContext() const {
  pten::KernelContext* kernel_ctx = op_func_node_.pt_kernel_context_;
  return kernel_ctx;
}
// Returns the OpFuncType that was recorded when this op func node was built.
OpFuncType Instruction::KernelType() const {
  return op_func_node_.type_;
}
OperatorBase* Instruction::OpBase() const {
......
......@@ -295,6 +295,11 @@ struct OpFuncNode {
OpKernelComputeFunc kernel_func_;
platform::DeviceContext* dev_ctx_; // not owned
// fit for pten kernel
pten::Kernel* pt_kernel_{nullptr}; // not owned
pten::KernelContext* pt_kernel_context_{nullptr}; // not onwed
OpFuncType type_;
};
......@@ -313,6 +318,10 @@ class Instruction {
OpKernelComputeFunc KernelFunc() const;
pten::Kernel* PtenKernel() const;
pten::KernelContext* PtenKernelContext() const;
OpFuncType KernelType() const;
OperatorBase* OpBase() const;
......
......@@ -1791,6 +1791,9 @@ KernelSignature OperatorWithKernel::GetExpectedPtenKernelArgs(
void OperatorWithKernel::BuildPtenKernelContext(
const RuntimeContext& ctx, platform::DeviceContext* dev_ctx) const {
if (pt_kernel_context_ == nullptr) {
pt_kernel_context_.reset(new pten::KernelContext());
}
// TODO(chenweihang): now only work for very simple case,
// many cases need to be deal with later:
// 1. the input and output are not tensor
......
......@@ -555,6 +555,20 @@ class OperatorWithKernel : public OperatorBase {
virtual KernelSignature GetExpectedPtenKernelArgs(
const ExecutionContext& ctx) const;
/* member functions for adapting to pten lib */
void ChoosePtenKernel(const ExecutionContext& ctx) const;
void BuildPtenKernelContext(const RuntimeContext& ctx,
platform::DeviceContext* dev_ctx) const;
void WriteBackToOutputs(RuntimeContext* ctx) const;
pten::Kernel* PtenKernel() const { return pt_kernel_.get(); }
pten::KernelContext* PtenKernelContext() const {
return pt_kernel_context_.get();
}
private:
void RunImpl(const Scope& scope, const platform::Place& place) const final;
void RunImpl(const Scope& scope, const platform::Place& place,
......@@ -595,14 +609,6 @@ class OperatorWithKernel : public OperatorBase {
Tensor* GetTensorFormInputSafely(const ExecutionContext& ctx,
const std::string& name) const;
/* member functions for adapting to pten lib */
void ChoosePtenKernel(const ExecutionContext& ctx) const;
void BuildPtenKernelContext(const RuntimeContext& ctx,
platform::DeviceContext* dev_ctx) const;
void WriteBackToOutputs(RuntimeContext* ctx) const;
protected:
mutable std::unique_ptr<OpKernelType> kernel_type_;
mutable std::unique_ptr<OpKernelFunc> kernel_func_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register.