Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
7f3b0877
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
7f3b0877
编写于
1月 07, 2022
作者:
L
Leo Chen
提交者:
GitHub
1月 07, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[new-exec] support pten kernel (#38770)
上级
1b6e4664
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
89 addition
and
19 deletion
+89
-19
paddle/fluid/framework/new_executor/interpretercore.cc
paddle/fluid/framework/new_executor/interpretercore.cc
+17
-1
paddle/fluid/framework/new_executor/interpretercore_util.cc
paddle/fluid/framework/new_executor/interpretercore_util.cc
+38
-10
paddle/fluid/framework/new_executor/new_executor_defs.cc
paddle/fluid/framework/new_executor/new_executor_defs.cc
+8
-0
paddle/fluid/framework/new_executor/new_executor_defs.h
paddle/fluid/framework/new_executor/new_executor_defs.h
+9
-0
paddle/fluid/framework/operator.cc
paddle/fluid/framework/operator.cc
+3
-0
paddle/fluid/framework/operator.h
paddle/fluid/framework/operator.h
+14
-8
未找到文件。
paddle/fluid/framework/new_executor/interpretercore.cc
浏览文件 @
7f3b0877
...
@@ -412,10 +412,26 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
...
@@ -412,10 +412,26 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
platform
::
RecordEvent
compute_event
(
"Compute"
);
platform
::
RecordEvent
compute_event
(
"Compute"
);
if
(
op_with_kernel
==
nullptr
)
{
if
(
op_with_kernel
==
nullptr
)
{
instr_node
.
OpBase
()
->
Run
(
*
local_scope
,
place_
);
instr_node
.
OpBase
()
->
Run
(
*
local_scope
,
place_
);
}
else
{
// fit for pten
if
(
instr_node
.
PtenKernel
()
&&
instr_node
.
PtenKernel
()
->
IsValid
())
{
VLOG
(
4
)
<<
"Run pten kernel: "
<<
op
->
Type
();
VLOG
(
4
)
<<
instr_node
.
InnerRuntimeContext
().
get
()
<<
" "
<<
&
instr_node
.
DeviceContext
();
op_with_kernel
->
BuildPtenKernelContext
(
*
instr_node
.
InnerRuntimeContext
().
get
(),
const_cast
<
platform
::
DeviceContext
*>
(
&
instr_node
.
DeviceContext
()));
(
*
instr_node
.
PtenKernel
())(
instr_node
.
PtenKernelContext
());
op_with_kernel
->
WriteBackToOutputs
(
instr_node
.
InnerRuntimeContext
().
get
());
instr_node
.
PtenKernelContext
()
->
ClearData
();
}
else
{
}
else
{
instr_node
.
KernelFunc
()(
*
instr_node
.
InnerExecutionContext
().
get
());
instr_node
.
KernelFunc
()(
*
instr_node
.
InnerExecutionContext
().
get
());
}
}
}
}
}
VLOG
(
4
)
<<
"End run "
<<
place
<<
" "
<<
op
->
DebugStringEx
(
global_scope_
);
VLOG
(
4
)
<<
"End run "
<<
place
<<
" "
<<
op
->
DebugStringEx
(
global_scope_
);
...
...
paddle/fluid/framework/new_executor/interpretercore_util.cc
浏览文件 @
7f3b0877
...
@@ -19,10 +19,13 @@
...
@@ -19,10 +19,13 @@
#include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h"
#include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h"
#include "paddle/fluid/operators/controlflow/recurrent_op_helper.h"
#include "paddle/fluid/operators/controlflow/recurrent_op_helper.h"
#include "paddle/fluid/operators/controlflow/while_op_helper.h"
#include "paddle/fluid/operators/controlflow/while_op_helper.h"
#include "paddle/pten/core/kernel_factory.h"
PADDLE_DEFINE_EXPORTED_bool
(
PADDLE_DEFINE_EXPORTED_bool
(
new_executor_sequential_run
,
false
,
new_executor_sequential_run
,
false
,
"Enable sequential execution for standalone executor, used for debug"
);
"Enable sequential execution for standalone executor, used for debug"
);
DECLARE_bool
(
run_pten_kernel
);
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
namespace
interpreter
{
namespace
interpreter
{
...
@@ -338,6 +341,8 @@ void build_op_func_list(const platform::Place& place,
...
@@ -338,6 +341,8 @@ void build_op_func_list(const platform::Place& place,
// op is not a operatorwithkernel, so direcly run OperatorBase::Run()
// op is not a operatorwithkernel, so direcly run OperatorBase::Run()
deal_operator_base
(
place
,
var_scope
,
ops
[
i
],
&
op_func_node
,
local_scope
);
deal_operator_base
(
place
,
var_scope
,
ops
[
i
],
&
op_func_node
,
local_scope
);
}
else
{
}
else
{
auto
op_with_kernel
=
static_cast
<
const
framework
::
OperatorWithKernel
*>
(
op
);
// construct RuntimeContext and analysis KernelType
// construct RuntimeContext and analysis KernelType
RuntimeContext
runtime_context
({},
{});
RuntimeContext
runtime_context
({},
{});
runtime_context
.
inputs
.
swap
(
ins_map
);
runtime_context
.
inputs
.
swap
(
ins_map
);
...
@@ -350,8 +355,7 @@ void build_op_func_list(const platform::Place& place,
...
@@ -350,8 +355,7 @@ void build_op_func_list(const platform::Place& place,
// TODO(Aurelius84): In case of control flow ops, they are NOT
// TODO(Aurelius84): In case of control flow ops, they are NOT
// inheritted
// inheritted
// from OperatorWithKernel.
// from OperatorWithKernel.
static_cast
<
const
framework
::
OperatorWithKernel
*>
(
op
)
->
InferShape
(
op_with_kernel
->
InferShape
(
&
infer_shape_ctx
);
&
infer_shape_ctx
);
}
}
auto
kernels_iter
=
all_op_kernels
.
find
(
op
->
Type
());
auto
kernels_iter
=
all_op_kernels
.
find
(
op
->
Type
());
...
@@ -367,9 +371,7 @@ void build_op_func_list(const platform::Place& place,
...
@@ -367,9 +371,7 @@ void build_op_func_list(const platform::Place& place,
platform
::
DeviceContextPool
::
Instance
();
platform
::
DeviceContextPool
::
Instance
();
auto
*
dev_ctx
=
pool
.
Get
(
place
);
auto
*
dev_ctx
=
pool
.
Get
(
place
);
Scope
scope
;
Scope
scope
;
auto
expected_kernel_key
=
auto
expected_kernel_key
=
op_with_kernel
->
GetExpectedKernelType
(
dynamic_cast
<
const
framework
::
OperatorWithKernel
*>
(
op
)
->
GetExpectedKernelType
(
ExecutionContext
(
*
op
,
scope
,
*
dev_ctx
,
runtime_context
));
ExecutionContext
(
*
op
,
scope
,
*
dev_ctx
,
runtime_context
));
// change device by the device_guard()
// change device by the device_guard()
...
@@ -378,10 +380,16 @@ void build_op_func_list(const platform::Place& place,
...
@@ -378,10 +380,16 @@ void build_op_func_list(const platform::Place& place,
// step 3. apply data transforms and insert data transfer ops
// step 3. apply data transforms and insert data transfer ops
VariableValueMap
&
ins_map_temp
=
runtime_context
.
inputs
;
VariableValueMap
&
ins_map_temp
=
runtime_context
.
inputs
;
// NOTE(zhiqiu): op_func_node->operator_base_ maybe changed in
// ApplyDataTransform
ApplyDataTransform
(
expected_kernel_key
,
place
,
&
ins_map_temp
,
var_scope
,
ApplyDataTransform
(
expected_kernel_key
,
place
,
&
ins_map_temp
,
var_scope
,
&
op_func_node
,
vec_func_list
,
use_local_scope
);
&
op_func_node
,
vec_func_list
,
use_local_scope
);
op_with_kernel
=
static_cast
<
const
framework
::
OperatorWithKernel
*>
(
op_func_node
.
operator_base_
.
get
());
// step 4. Run op kernel
// step 4. Run op kernel
VLOG
(
3
)
<<
op
->
Type
()
VLOG
(
3
)
<<
op
_with_kernel
->
Type
()
<<
" : expected_kernel_key : "
<<
expected_kernel_key
;
<<
" : expected_kernel_key : "
<<
expected_kernel_key
;
if
(
platform
::
is_gpu_place
(
expected_kernel_key
.
place_
))
{
if
(
platform
::
is_gpu_place
(
expected_kernel_key
.
place_
))
{
...
@@ -397,7 +405,8 @@ void build_op_func_list(const platform::Place& place,
...
@@ -397,7 +405,8 @@ void build_op_func_list(const platform::Place& place,
}
}
op_func_node
.
dev_ctx_
=
dev_ctx
;
op_func_node
.
dev_ctx_
=
dev_ctx
;
auto
exec_ctx
=
ExecutionContext
(
*
op
,
scope
,
*
dev_ctx
,
runtime_context
);
auto
exec_ctx
=
ExecutionContext
(
*
op_with_kernel
,
scope
,
*
dev_ctx
,
runtime_context
);
auto
kernel_iter
=
kernels
.
find
(
expected_kernel_key
);
auto
kernel_iter
=
kernels
.
find
(
expected_kernel_key
);
PADDLE_ENFORCE_NE
(
PADDLE_ENFORCE_NE
(
...
@@ -406,8 +415,27 @@ void build_op_func_list(const platform::Place& place,
...
@@ -406,8 +415,27 @@ void build_op_func_list(const platform::Place& place,
"Operator (%s) does not have kernel for %s."
,
op
->
Type
(),
"Operator (%s) does not have kernel for %s."
,
op
->
Type
(),
KernelTypeToString
(
expected_kernel_key
)));
KernelTypeToString
(
expected_kernel_key
)));
auto
run_pten_kernel
=
false
;
if
(
FLAGS_run_pten_kernel
&&
pten
::
KernelFactory
::
Instance
().
HasCompatiblePtenKernel
(
op_with_kernel
->
Type
()))
{
op_with_kernel
->
ChoosePtenKernel
(
exec_ctx
);
run_pten_kernel
=
op_with_kernel
->
PtenKernel
()
->
IsValid
();
}
if
(
run_pten_kernel
)
{
op_with_kernel
->
BuildPtenKernelContext
(
runtime_context
,
dev_ctx
);
op_func_node
.
pt_kernel_
=
op_with_kernel
->
PtenKernel
();
op_func_node
.
pt_kernel_context_
=
op_with_kernel
->
PtenKernelContext
();
(
*
op_func_node
.
pt_kernel_
)(
op_func_node
.
pt_kernel_context_
);
op_with_kernel
->
WriteBackToOutputs
(
&
runtime_context
);
op_func_node
.
pt_kernel_context_
->
ClearData
();
}
else
{
op_func_node
.
kernel_func_
=
OpKernelComputeFunc
(
kernel_iter
->
second
);
op_func_node
.
kernel_func_
=
OpKernelComputeFunc
(
kernel_iter
->
second
);
op_func_node
.
kernel_func_
(
exec_ctx
);
op_func_node
.
kernel_func_
(
exec_ctx
);
}
// post-process grad_op.outputs if need cast complex grad into real grad.
// post-process grad_op.outputs if need cast complex grad into real grad.
// NOTE(Aurelius84): insert a transfer_dtype_op inplacely to cast it.
// NOTE(Aurelius84): insert a transfer_dtype_op inplacely to cast it.
...
...
paddle/fluid/framework/new_executor/new_executor_defs.cc
浏览文件 @
7f3b0877
...
@@ -673,6 +673,14 @@ OpKernelComputeFunc Instruction::KernelFunc() const {
...
@@ -673,6 +673,14 @@ OpKernelComputeFunc Instruction::KernelFunc() const {
return
op_func_node_
.
kernel_func_
;
return
op_func_node_
.
kernel_func_
;
}
}
pten
::
Kernel
*
Instruction
::
PtenKernel
()
const
{
return
op_func_node_
.
pt_kernel_
;
}
pten
::
KernelContext
*
Instruction
::
PtenKernelContext
()
const
{
return
op_func_node_
.
pt_kernel_context_
;
}
OpFuncType
Instruction
::
KernelType
()
const
{
return
op_func_node_
.
type_
;
}
OpFuncType
Instruction
::
KernelType
()
const
{
return
op_func_node_
.
type_
;
}
OperatorBase
*
Instruction
::
OpBase
()
const
{
OperatorBase
*
Instruction
::
OpBase
()
const
{
...
...
paddle/fluid/framework/new_executor/new_executor_defs.h
浏览文件 @
7f3b0877
...
@@ -295,6 +295,11 @@ struct OpFuncNode {
...
@@ -295,6 +295,11 @@ struct OpFuncNode {
OpKernelComputeFunc
kernel_func_
;
OpKernelComputeFunc
kernel_func_
;
platform
::
DeviceContext
*
dev_ctx_
;
// not owned
platform
::
DeviceContext
*
dev_ctx_
;
// not owned
// fit for pten kernel
pten
::
Kernel
*
pt_kernel_
{
nullptr
};
// not owned
pten
::
KernelContext
*
pt_kernel_context_
{
nullptr
};
// not onwed
OpFuncType
type_
;
OpFuncType
type_
;
};
};
...
@@ -313,6 +318,10 @@ class Instruction {
...
@@ -313,6 +318,10 @@ class Instruction {
OpKernelComputeFunc
KernelFunc
()
const
;
OpKernelComputeFunc
KernelFunc
()
const
;
pten
::
Kernel
*
PtenKernel
()
const
;
pten
::
KernelContext
*
PtenKernelContext
()
const
;
OpFuncType
KernelType
()
const
;
OpFuncType
KernelType
()
const
;
OperatorBase
*
OpBase
()
const
;
OperatorBase
*
OpBase
()
const
;
...
...
paddle/fluid/framework/operator.cc
浏览文件 @
7f3b0877
...
@@ -1791,6 +1791,9 @@ KernelSignature OperatorWithKernel::GetExpectedPtenKernelArgs(
...
@@ -1791,6 +1791,9 @@ KernelSignature OperatorWithKernel::GetExpectedPtenKernelArgs(
void
OperatorWithKernel
::
BuildPtenKernelContext
(
void
OperatorWithKernel
::
BuildPtenKernelContext
(
const
RuntimeContext
&
ctx
,
platform
::
DeviceContext
*
dev_ctx
)
const
{
const
RuntimeContext
&
ctx
,
platform
::
DeviceContext
*
dev_ctx
)
const
{
if
(
pt_kernel_context_
==
nullptr
)
{
pt_kernel_context_
.
reset
(
new
pten
::
KernelContext
());
}
// TODO(chenweihang): now only work for very simple case,
// TODO(chenweihang): now only work for very simple case,
// many cases need to be deal with later:
// many cases need to be deal with later:
// 1. the input and output are not tensor
// 1. the input and output are not tensor
...
...
paddle/fluid/framework/operator.h
浏览文件 @
7f3b0877
...
@@ -555,6 +555,20 @@ class OperatorWithKernel : public OperatorBase {
...
@@ -555,6 +555,20 @@ class OperatorWithKernel : public OperatorBase {
virtual
KernelSignature
GetExpectedPtenKernelArgs
(
virtual
KernelSignature
GetExpectedPtenKernelArgs
(
const
ExecutionContext
&
ctx
)
const
;
const
ExecutionContext
&
ctx
)
const
;
/* member functions for adapting to pten lib */
void
ChoosePtenKernel
(
const
ExecutionContext
&
ctx
)
const
;
void
BuildPtenKernelContext
(
const
RuntimeContext
&
ctx
,
platform
::
DeviceContext
*
dev_ctx
)
const
;
void
WriteBackToOutputs
(
RuntimeContext
*
ctx
)
const
;
pten
::
Kernel
*
PtenKernel
()
const
{
return
pt_kernel_
.
get
();
}
pten
::
KernelContext
*
PtenKernelContext
()
const
{
return
pt_kernel_context_
.
get
();
}
private:
private:
void
RunImpl
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
final
;
void
RunImpl
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
final
;
void
RunImpl
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
,
void
RunImpl
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
,
...
@@ -595,14 +609,6 @@ class OperatorWithKernel : public OperatorBase {
...
@@ -595,14 +609,6 @@ class OperatorWithKernel : public OperatorBase {
Tensor
*
GetTensorFormInputSafely
(
const
ExecutionContext
&
ctx
,
Tensor
*
GetTensorFormInputSafely
(
const
ExecutionContext
&
ctx
,
const
std
::
string
&
name
)
const
;
const
std
::
string
&
name
)
const
;
/* member functions for adapting to pten lib */
void
ChoosePtenKernel
(
const
ExecutionContext
&
ctx
)
const
;
void
BuildPtenKernelContext
(
const
RuntimeContext
&
ctx
,
platform
::
DeviceContext
*
dev_ctx
)
const
;
void
WriteBackToOutputs
(
RuntimeContext
*
ctx
)
const
;
protected:
protected:
mutable
std
::
unique_ptr
<
OpKernelType
>
kernel_type_
;
mutable
std
::
unique_ptr
<
OpKernelType
>
kernel_type_
;
mutable
std
::
unique_ptr
<
OpKernelFunc
>
kernel_func_
;
mutable
std
::
unique_ptr
<
OpKernelFunc
>
kernel_func_
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录