Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
16e40513
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
16e40513
编写于
9月 14, 2021
作者:
W
wanghuancoder
提交者:
GitHub
9月 14, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
skip sharelod, test=develop (#35625)
上级
eb810c1b
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
462 addition
and
436 deletion
+462
-436
paddle/fluid/framework/new_executor/interpretercore.cc
paddle/fluid/framework/new_executor/interpretercore.cc
+24
-2
paddle/fluid/framework/new_executor/interpretercore.h
paddle/fluid/framework/new_executor/interpretercore.h
+2
-0
paddle/fluid/framework/new_executor/interpretercore_util.cc
paddle/fluid/framework/new_executor/interpretercore_util.cc
+3
-3
paddle/fluid/framework/new_executor/interpretercore_util.h
paddle/fluid/framework/new_executor/interpretercore_util.h
+0
-428
paddle/fluid/framework/new_executor/new_executor_defs.h
paddle/fluid/framework/new_executor/new_executor_defs.h
+433
-3
未找到文件。
paddle/fluid/framework/new_executor/interpretercore.cc
浏览文件 @
16e40513
...
@@ -188,6 +188,8 @@ void InterpreterCore::Convert() {
...
@@ -188,6 +188,8 @@ void InterpreterCore::Convert() {
BuildAndCacheInstructionCtx
(
&
vec_instruction_
[
i
],
*
global_scope_
,
place_
);
BuildAndCacheInstructionCtx
(
&
vec_instruction_
[
i
],
*
global_scope_
,
place_
);
}
}
BuildSkipShareLoDInfo
();
for
(
size_t
i
=
0
;
i
<
vec_instruction_
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
vec_instruction_
.
size
();
++
i
)
{
gc_event_
.
emplace_back
(
vec_instruction_
[
i
].
execution_ctx_
.
get
()
->
GetPlace
(),
gc_event_
.
emplace_back
(
vec_instruction_
[
i
].
execution_ctx_
.
get
()
->
GetPlace
(),
platform
::
GenerateDeviceEventFlag
());
platform
::
GenerateDeviceEventFlag
());
...
@@ -225,8 +227,8 @@ void InterpreterCore::BuildAndCacheInstructionCtx(
...
@@ -225,8 +227,8 @@ void InterpreterCore::BuildAndCacheInstructionCtx(
instr_node
->
runtime_ctx_
->
inputs
.
swap
(
ins_map
);
instr_node
->
runtime_ctx_
->
inputs
.
swap
(
ins_map
);
instr_node
->
runtime_ctx_
->
outputs
.
swap
(
outs_map
);
instr_node
->
runtime_ctx_
->
outputs
.
swap
(
outs_map
);
instr_node
->
infershape_ctx_
.
reset
(
instr_node
->
infershape_ctx_
.
reset
(
new
InterpretercoreInferShapeContext
(
new
RuntimeInferShapeContext
(
*
op_base
,
*
instr_node
->
runtime_ctx_
.
get
()));
*
op_base
,
*
instr_node
->
runtime_ctx_
.
get
()));
auto
*
dev_ctx
=
instr_node
->
dev_ctx_
;
auto
*
dev_ctx
=
instr_node
->
dev_ctx_
;
Scope
scope
;
Scope
scope
;
...
@@ -235,6 +237,26 @@ void InterpreterCore::BuildAndCacheInstructionCtx(
...
@@ -235,6 +237,26 @@ void InterpreterCore::BuildAndCacheInstructionCtx(
*
op_base
,
scope
,
*
dev_ctx
,
*
instr_node
->
runtime_ctx_
.
get
()));
*
op_base
,
scope
,
*
dev_ctx
,
*
instr_node
->
runtime_ctx_
.
get
()));
}
}
void
InterpreterCore
::
BuildSkipShareLoDInfo
()
{
for
(
size_t
i
=
0
;
i
<
vec_instruction_
.
size
();
++
i
)
{
bool
can_skip_lod
=
true
;
for
(
auto
&
input
:
vec_instruction_
[
i
].
runtime_ctx_
.
get
()
->
inputs
)
{
for
(
auto
&
var
:
input
.
second
)
{
if
(
var
->
IsType
<
LoDTensor
>
())
{
if
(
var
->
Get
<
LoDTensor
>
().
lod
().
size
()
!=
0
)
{
can_skip_lod
=
false
;
break
;
}
}
else
{
can_skip_lod
=
false
;
break
;
}
}
}
vec_instruction_
[
i
].
infershape_ctx_
.
get
()
->
SetSkipLoD
(
can_skip_lod
);
}
}
void
InterpreterCore
::
RunInstruction
(
const
Instruction
&
instr_node
)
{
void
InterpreterCore
::
RunInstruction
(
const
Instruction
&
instr_node
)
{
VLOG
(
3
)
<<
"RunInstruction: "
VLOG
(
3
)
<<
"RunInstruction: "
<<
instr_node
.
kernel_func_
.
operator_base_
->
Type
();
<<
instr_node
.
kernel_func_
.
operator_base_
->
Type
();
...
...
paddle/fluid/framework/new_executor/interpretercore.h
浏览文件 @
16e40513
...
@@ -68,6 +68,8 @@ class InterpreterCore {
...
@@ -68,6 +68,8 @@ class InterpreterCore {
void
AddFetch
(
const
std
::
vector
<
std
::
string
>&
fetch_names
);
void
AddFetch
(
const
std
::
vector
<
std
::
string
>&
fetch_names
);
void
BuildSkipShareLoDInfo
();
bool
is_build_
;
bool
is_build_
;
const
platform
::
Place
&
place_
;
const
platform
::
Place
&
place_
;
...
...
paddle/fluid/framework/new_executor/interpretercore_util.cc
浏览文件 @
16e40513
...
@@ -206,7 +206,7 @@ void build_op_func_list(const platform::Place& place,
...
@@ -206,7 +206,7 @@ void build_op_func_list(const platform::Place& place,
RuntimeContext
runtime_context
({},
{});
RuntimeContext
runtime_context
({},
{});
runtime_context
.
inputs
.
swap
(
ins_map
);
runtime_context
.
inputs
.
swap
(
ins_map
);
runtime_context
.
outputs
.
swap
(
outs_map
);
runtime_context
.
outputs
.
swap
(
outs_map
);
Runtim
eInferShapeContext
infer_shape_ctx
(
*
op_base
,
runtime_context
);
Interpretercor
eInferShapeContext
infer_shape_ctx
(
*
op_base
,
runtime_context
);
static_cast
<
const
framework
::
OperatorWithKernel
*>
(
op_base
)
->
InferShape
(
static_cast
<
const
framework
::
OperatorWithKernel
*>
(
op_base
)
->
InferShape
(
&
infer_shape_ctx
);
&
infer_shape_ctx
);
auto
kernels_iter
=
all_op_kernels
.
find
(
op
->
Type
());
auto
kernels_iter
=
all_op_kernels
.
find
(
op
->
Type
());
...
@@ -320,8 +320,8 @@ void build_op_func_list(const platform::Place& place,
...
@@ -320,8 +320,8 @@ void build_op_func_list(const platform::Place& place,
RuntimeContext
copy_runtime_context
({},
{});
RuntimeContext
copy_runtime_context
({},
{});
copy_runtime_context
.
inputs
.
swap
(
copy_ins_value_map
);
copy_runtime_context
.
inputs
.
swap
(
copy_ins_value_map
);
copy_runtime_context
.
outputs
.
swap
(
copy_outs_value_map
);
copy_runtime_context
.
outputs
.
swap
(
copy_outs_value_map
);
RuntimeInferShapeContext
copy_infer_shape_ctx
(
*
copy_op
,
InterpretercoreInferShapeContext
copy_infer_shape_ctx
(
copy_runtime_context
);
*
copy_op
,
copy_runtime_context
);
static_cast
<
const
framework
::
OperatorWithKernel
*>
(
copy_op
)
static_cast
<
const
framework
::
OperatorWithKernel
*>
(
copy_op
)
->
InferShape
(
&
copy_infer_shape_ctx
);
->
InferShape
(
&
copy_infer_shape_ctx
);
...
...
paddle/fluid/framework/new_executor/interpretercore_util.h
浏览文件 @
16e40513
...
@@ -46,434 +46,6 @@
...
@@ -46,434 +46,6 @@
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
class
RuntimeInferShapeContext
:
public
InferShapeContext
{
public:
RuntimeInferShapeContext
(
const
OperatorBase
&
op
,
const
RuntimeContext
&
ctx
)
:
op_
(
op
),
ctx_
(
ctx
)
{}
bool
HasInput
(
const
std
::
string
&
name
)
const
override
{
// has only one input
const
auto
&
ins
=
ctx_
.
inputs
;
auto
it
=
ins
.
find
(
name
);
if
(
it
==
ins
.
end
())
{
return
false
;
}
const
auto
&
in
=
it
->
second
;
if
(
in
.
size
()
==
0
)
return
false
;
PADDLE_ENFORCE_EQ
(
in
.
size
(),
1UL
,
platform
::
errors
::
InvalidArgument
(
"Input %s should not contain more than one inputs."
,
name
));
return
in
[
0
]
!=
nullptr
;
}
bool
HasOutput
(
const
std
::
string
&
name
)
const
override
{
// has only one output
const
auto
&
outs
=
ctx_
.
outputs
;
auto
it
=
outs
.
find
(
name
);
if
(
it
==
outs
.
end
())
{
return
false
;
}
const
auto
&
out
=
it
->
second
;
if
(
out
.
size
()
==
0
)
{
return
false
;
}
PADDLE_ENFORCE_EQ
(
out
.
size
(),
1UL
,
platform
::
errors
::
InvalidArgument
(
"Output %s should not contain more than one outputs."
,
name
));
return
out
[
0
]
!=
nullptr
;
}
bool
HasInputs
(
const
std
::
string
&
name
)
const
override
{
const
auto
&
ins
=
ctx_
.
inputs
;
auto
it
=
ins
.
find
(
name
);
if
(
it
==
ins
.
end
()
||
it
->
second
.
empty
())
{
return
false
;
}
for
(
auto
&
input
:
it
->
second
)
{
if
(
input
==
nullptr
)
{
return
false
;
}
}
return
true
;
}
bool
HasOutputs
(
const
std
::
string
&
name
)
const
override
{
const
auto
&
outs
=
ctx_
.
outputs
;
auto
it
=
outs
.
find
(
name
);
if
(
it
==
outs
.
end
()
||
it
->
second
.
empty
())
{
return
false
;
}
for
(
auto
&
output
:
it
->
second
)
{
if
(
output
==
nullptr
)
{
return
false
;
}
}
return
true
;
}
AttrReader
Attrs
()
const
override
{
return
AttrReader
(
op_
.
Attrs
());
}
std
::
vector
<
std
::
string
>
Inputs
(
const
std
::
string
&
name
)
const
override
{
return
op_
.
Inputs
(
name
);
}
std
::
vector
<
std
::
string
>
Outputs
(
const
std
::
string
&
name
)
const
override
{
return
op_
.
Outputs
(
name
);
}
std
::
string
GetInputNameByIdx
(
size_t
idx
)
const
override
{
auto
&
op_proto
=
paddle
::
framework
::
OpInfoMap
::
Instance
().
Get
(
op_
.
Type
()).
proto_
;
PADDLE_ENFORCE_LT
(
idx
,
op_proto
->
inputs
().
size
(),
platform
::
errors
::
OutOfRange
(
"The index should be less than the size of inputs of "
"operator %s, but got index is %d and size is %d"
,
op_
.
Type
(),
idx
,
op_proto
->
inputs
().
size
()));
return
op_proto
->
inputs
()[
idx
].
name
();
}
std
::
string
GetOutputNameByIdx
(
size_t
idx
)
const
override
{
auto
&
op_proto
=
paddle
::
framework
::
OpInfoMap
::
Instance
().
Get
(
op_
.
Type
()).
proto_
;
PADDLE_ENFORCE_LT
(
idx
,
op_proto
->
outputs
().
size
(),
platform
::
errors
::
OutOfRange
(
"The index should be less than the size of outputs of "
"operator %s, but got index is %d and size is %d"
,
op_
.
Type
(),
idx
,
op_proto
->
outputs
().
size
()));
return
op_proto
->
outputs
()[
idx
].
name
();
}
void
ShareDim
(
const
std
::
string
&
in
,
const
std
::
string
&
out
,
size_t
i
=
0
,
size_t
j
=
0
)
override
{
auto
in_it
=
ctx_
.
inputs
.
find
(
in
);
auto
out_it
=
ctx_
.
outputs
.
find
(
out
);
PADDLE_ENFORCE_NE
(
in_it
,
ctx_
.
inputs
.
end
(),
platform
::
errors
::
NotFound
(
"Input %s does not exist."
,
in
));
PADDLE_ENFORCE_NE
(
out_it
,
ctx_
.
outputs
.
end
(),
platform
::
errors
::
NotFound
(
"Output %s does not exist."
,
out
));
PADDLE_ENFORCE_LT
(
i
,
in_it
->
second
.
size
(),
platform
::
errors
::
InvalidArgument
(
"The index of input dimension is out of range, "
"excepted index less than %zu, but received %zu."
,
in_it
->
second
.
size
(),
i
));
PADDLE_ENFORCE_LT
(
j
,
out_it
->
second
.
size
(),
platform
::
errors
::
InvalidArgument
(
"The index of output dimension is out of range, "
"excepted index less than %zu, but received %zu."
,
out_it
->
second
.
size
(),
j
));
Variable
*
in_var
=
in_it
->
second
[
i
];
Variable
*
out_var
=
out_it
->
second
[
j
];
PADDLE_ENFORCE_EQ
(
in_var
->
Type
(),
out_var
->
Type
(),
platform
::
errors
::
InvalidArgument
(
"The type of input (%s) and output (%s) are inconsistent."
,
in
,
out
));
if
(
in_var
->
IsType
<
framework
::
SelectedRows
>
())
{
auto
&
in_sele_rows
=
in_var
->
Get
<
framework
::
SelectedRows
>
();
auto
out_sele_rows
=
out_var
->
GetMutable
<
framework
::
SelectedRows
>
();
out_sele_rows
->
mutable_value
()
->
Resize
(
in_sele_rows
.
value
().
dims
());
out_sele_rows
->
set_rows
(
in_sele_rows
.
rows
());
out_sele_rows
->
set_height
(
in_sele_rows
.
height
());
}
else
if
(
in_var
->
IsType
<
framework
::
LoDTensor
>
())
{
auto
&
in_lod_tensor
=
in_var
->
Get
<
framework
::
LoDTensor
>
();
auto
*
out_lod_tensor
=
out_var
->
GetMutable
<
framework
::
LoDTensor
>
();
out_lod_tensor
->
Resize
(
in_lod_tensor
.
dims
());
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Currently, the input type of ShareDim only can be LoDTensor "
"or SelectedRows."
));
}
}
void
ShareAllLoD
(
const
std
::
string
&
in
,
const
std
::
string
&
out
)
const
override
{
auto
in_it
=
ctx_
.
inputs
.
find
(
in
);
auto
out_it
=
ctx_
.
outputs
.
find
(
out
);
PADDLE_ENFORCE_NE
(
in_it
,
ctx_
.
inputs
.
end
(),
platform
::
errors
::
NotFound
(
"Input [%s] found error in Op [%s]"
,
in
,
op_
.
Type
()));
PADDLE_ENFORCE_NE
(
out_it
,
ctx_
.
outputs
.
end
(),
platform
::
errors
::
NotFound
(
"Output [%s] found error in Op [%s]"
,
out
,
op_
.
Type
()));
auto
&
in_var_list
=
in_it
->
second
;
auto
&
out_var_list
=
out_it
->
second
;
PADDLE_ENFORCE_EQ
(
in_var_list
.
size
(),
out_var_list
.
size
(),
platform
::
errors
::
PreconditionNotMet
(
"Op [%s]: Input var size should be equal with output var size"
,
op_
.
Type
()));
auto
&
out_var_names
=
op_
.
Outputs
(
out
);
for
(
size_t
i
=
0
;
i
<
in_var_list
.
size
();
++
i
)
{
if
(
out_var_names
[
i
]
==
framework
::
kEmptyVarName
)
{
continue
;
}
Variable
*
in_var
=
in_var_list
[
i
];
if
(
!
in_var
->
IsType
<
LoDTensor
>
())
return
;
Variable
*
out_var
=
out_var_list
[
i
];
PADDLE_ENFORCE_EQ
(
out_var
->
IsType
<
LoDTensor
>
(),
true
,
platform
::
errors
::
PreconditionNotMet
(
"The %d-th output of Output(%s) must be LoDTensor."
,
i
,
out_var_names
[
i
]));
auto
&
in_tensor
=
in_var
->
Get
<
LoDTensor
>
();
auto
*
out_tensor
=
out_var
->
GetMutable
<
LoDTensor
>
();
out_tensor
->
set_lod
(
in_tensor
.
lod
());
#ifdef PADDLE_WITH_MKLDNN
if
(
in_tensor
.
layout
()
!=
DataLayout
::
kMKLDNN
)
#endif
out_tensor
->
set_layout
(
in_tensor
.
layout
());
}
}
void
ShareLoD
(
const
std
::
string
&
in
,
const
std
::
string
&
out
,
size_t
i
=
0
,
size_t
j
=
0
)
const
override
{
auto
in_it
=
ctx_
.
inputs
.
find
(
in
);
PADDLE_ENFORCE_NE
(
in_it
,
ctx_
.
inputs
.
end
(),
platform
::
errors
::
NotFound
(
"Input %s does not exist."
,
in
));
PADDLE_ENFORCE_LT
(
i
,
in_it
->
second
.
size
(),
platform
::
errors
::
InvalidArgument
(
"The index of input dimension is out of range, "
"excepted index less than %zu, but received %zu."
,
in_it
->
second
.
size
(),
i
));
Variable
*
in_var
=
in_it
->
second
.
at
(
i
);
if
(
!
in_var
->
IsType
<
LoDTensor
>
())
return
;
auto
out_it
=
ctx_
.
outputs
.
find
(
out
);
PADDLE_ENFORCE_NE
(
out_it
,
ctx_
.
outputs
.
end
(),
platform
::
errors
::
NotFound
(
"Output %s does not exist."
,
out
));
PADDLE_ENFORCE_LT
(
j
,
out_it
->
second
.
size
(),
platform
::
errors
::
InvalidArgument
(
"The index of output dimension is out of range, "
"excepted index less than %zu, but received %zu."
,
out_it
->
second
.
size
(),
j
));
Variable
*
out_var
=
out_it
->
second
.
at
(
j
);
PADDLE_ENFORCE_EQ
(
out_var
->
IsType
<
LoDTensor
>
(),
true
,
platform
::
errors
::
InvalidArgument
(
"The %zu-th output of Output(%s) must be LoDTensor."
,
j
,
out
));
auto
&
in_tensor
=
in_var
->
Get
<
LoDTensor
>
();
auto
*
out_tensor
=
out_var
->
GetMutable
<
LoDTensor
>
();
out_tensor
->
set_lod
(
in_tensor
.
lod
());
// TODO(dzhwinter) : reuse ShareLoD in most operators.
// Need to call ShareLayout explicitly in sequence related ops.
// Shall we have a better method to shared info between in/out Tensor?
#ifdef PADDLE_WITH_MKLDNN
// Fix me: ugly workaround below
// Correct solution:
// set_layout() should NOT be called here (i.e. ShareLoD). Instead,
// layout of output tensor should be set "manually" in Compute()
// of each OPKernel. The reason layout should NOT be shared between
// input and output "automatically" (now by InferShape()->ShareLoD())
// is that layout transform may occur after InferShape().
// Workaround:
// Skip set_layout() when input layout is kMKLDNN
// This is to avoid kMKLDNN is populated wrongly into a non-MKLDNN
// OPKernel. In all MKLDNN OPkernel, set_layout(kMKLDNN) should be called
// in Compute()
if
(
in_tensor
.
layout
()
!=
DataLayout
::
kMKLDNN
)
#endif
out_tensor
->
set_layout
(
in_tensor
.
layout
());
}
int32_t
GetLoDLevel
(
const
std
::
string
&
in
,
size_t
i
=
0
)
const
override
{
PADDLE_THROW
(
platform
::
errors
::
PreconditionNotMet
(
"GetLoDLevel is only used in compile time. The calculation of "
"output's actual lod is different among operators so that should be "
"set in the runtime kernel."
));
}
void
SetLoDLevel
(
const
std
::
string
&
out
,
int32_t
lod_level
,
size_t
j
=
0
)
const
override
{
PADDLE_THROW
(
platform
::
errors
::
PreconditionNotMet
(
"SetLoDLevel is only used in compile time. The calculation of "
"output's actual lod is different among operators so that should be "
"set in the runtime kernel."
));
}
bool
IsRuntime
()
const
override
{
return
true
;
}
// TODO(paddle-dev): Can this be template?
std
::
vector
<
InferShapeVarPtr
>
GetInputVarPtrs
(
const
std
::
string
&
name
)
override
{
const
std
::
vector
<
Variable
*>&
vars
=
InputVars
(
name
);
std
::
vector
<
InferShapeVarPtr
>
res
;
res
.
reserve
(
vars
.
size
());
res
.
insert
(
res
.
begin
(),
vars
.
begin
(),
vars
.
end
());
return
res
;
}
std
::
vector
<
InferShapeVarPtr
>
GetOutputVarPtrs
(
const
std
::
string
&
name
)
override
{
const
std
::
vector
<
Variable
*>&
vars
=
OutputVars
(
name
);
std
::
vector
<
InferShapeVarPtr
>
res
;
res
.
reserve
(
vars
.
size
());
res
.
insert
(
res
.
begin
(),
vars
.
begin
(),
vars
.
end
());
return
res
;
}
DDim
GetInputDim
(
const
std
::
string
&
name
)
const
override
{
const
std
::
vector
<
Variable
*>&
vars
=
InputVars
(
name
);
PADDLE_ENFORCE_EQ
(
vars
.
size
(),
1UL
,
platform
::
errors
::
InvalidArgument
(
"Input(%s) should hold one element, but now it holds %zu elements."
,
name
,
vars
.
size
()));
return
this
->
GetDim
(
vars
[
0
]);
}
std
::
vector
<
DDim
>
GetInputsDim
(
const
std
::
string
&
name
)
const
override
{
const
std
::
vector
<
Variable
*>&
vars
=
InputVars
(
name
);
return
GetDims
(
vars
);
}
std
::
vector
<
proto
::
VarType
::
Type
>
GetInputsVarType
(
const
std
::
string
&
name
)
const
override
{
return
GetVarTypes
(
InputVars
(
name
));
}
std
::
vector
<
proto
::
VarType
::
Type
>
GetOutputsVarType
(
const
std
::
string
&
name
)
const
override
{
return
GetVarTypes
(
OutputVars
(
name
));
}
void
SetOutputDim
(
const
std
::
string
&
name
,
const
DDim
&
dim
)
override
{
auto
&
vars
=
OutputVars
(
name
);
PADDLE_ENFORCE_EQ
(
vars
.
size
(),
1UL
,
platform
::
errors
::
InvalidArgument
(
"Output(%s) should hold one element, "
"but now it holds %zu elements."
,
name
,
vars
.
size
()));
SetDim
(
vars
[
0
],
dim
);
}
void
SetOutputsDim
(
const
std
::
string
&
name
,
const
std
::
vector
<
DDim
>&
dims
)
override
{
auto
&
vars
=
OutputVars
(
name
);
SetDims
(
vars
,
dims
);
}
protected:
DDim
GetDim
(
Variable
*
var
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
var
,
platform
::
errors
::
InvalidArgument
(
"Input variable is nullptr."
));
if
(
var
->
IsType
<
LoDTensor
>
())
{
return
var
->
Get
<
LoDTensor
>
().
dims
();
}
else
if
(
var
->
IsType
<
SelectedRows
>
())
{
return
var
->
Get
<
SelectedRows
>
().
GetCompleteDims
();
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Only LoDTensor or SelectedRows support 'GetDim', but input "
"Variable's type is %s."
,
ToTypeName
(
var
->
Type
())));
}
}
std
::
vector
<
DDim
>
GetDims
(
const
std
::
vector
<
Variable
*>&
vars
)
const
{
std
::
vector
<
DDim
>
ret
;
ret
.
reserve
(
vars
.
size
());
std
::
transform
(
vars
.
begin
(),
vars
.
end
(),
std
::
back_inserter
(
ret
),
[
this
](
Variable
*
var
)
{
return
this
->
GetDim
(
var
);
});
return
ret
;
}
std
::
vector
<
DDim
>
GetRepeatedDims
(
const
std
::
string
&
name
)
const
override
{
PADDLE_THROW
(
platform
::
errors
::
PreconditionNotMet
(
"GetRepeatedDims method only ban be used in compile time."
));
}
void
SetDim
(
Variable
*
var
,
const
DDim
&
dim
)
{
if
(
var
->
IsType
<
LoDTensor
>
())
{
var
->
GetMutable
<
LoDTensor
>
()
->
Resize
(
dim
);
}
else
if
(
var
->
IsType
<
SelectedRows
>
())
{
var
->
GetMutable
<
SelectedRows
>
()
->
set_height
(
dim
[
0
]);
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Variable type error, expect LoDTensor or SelectedRows, but received "
"(%s)."
,
ToTypeName
(
var
->
Type
())));
}
}
void
SetDims
(
const
std
::
vector
<
Variable
*>&
vars
,
const
std
::
vector
<
DDim
>&
dims
)
{
size_t
length
=
vars
.
size
();
PADDLE_ENFORCE_EQ
(
length
,
dims
.
size
(),
platform
::
errors
::
InvalidArgument
(
"The number of input variables do not match the "
"number of input dimensions, the number of variables "
"is %zu, the number of dimensions is %zu."
,
length
,
dims
.
size
()));
for
(
size_t
i
=
0
;
i
<
length
;
++
i
)
{
if
(
vars
[
i
]
==
nullptr
)
{
continue
;
}
SetDim
(
vars
[
i
],
dims
[
i
]);
}
}
void
SetRepeatedDims
(
const
std
::
string
&
name
,
const
std
::
vector
<
DDim
>&
dims
)
override
{
PADDLE_THROW
(
platform
::
errors
::
PreconditionNotMet
(
"SetRepeatedDims method only can be used in compile time."
));
}
std
::
vector
<
proto
::
VarType
::
Type
>
GetVarTypes
(
const
std
::
vector
<
Variable
*>&
vars
)
const
{
std
::
vector
<
proto
::
VarType
::
Type
>
retv
;
retv
.
resize
(
vars
.
size
());
std
::
transform
(
vars
.
begin
(),
vars
.
end
(),
retv
.
begin
(),
std
::
bind
(
std
::
mem_fn
(
&
RuntimeInferShapeContext
::
GetVarType
),
this
,
std
::
placeholders
::
_1
));
return
retv
;
}
proto
::
VarType
::
Type
GetVarType
(
Variable
*
var
)
const
{
return
ToVarType
(
var
->
Type
());
}
private:
const
std
::
vector
<
Variable
*>&
InputVars
(
const
std
::
string
&
name
)
const
{
auto
it
=
ctx_
.
inputs
.
find
(
name
);
PADDLE_ENFORCE_NE
(
it
,
ctx_
.
inputs
.
end
(),
platform
::
errors
::
NotFound
(
"Operator (%s) does not have the input (%s)."
,
op_
.
Type
(),
name
));
return
it
->
second
;
}
const
std
::
vector
<
Variable
*>&
OutputVars
(
const
std
::
string
&
name
)
const
{
auto
it
=
ctx_
.
outputs
.
find
(
name
);
PADDLE_ENFORCE_NE
(
it
,
ctx_
.
outputs
.
end
(),
platform
::
errors
::
NotFound
(
"Operator (%s) does not have the outputs (%s)."
,
op_
.
Type
(),
name
));
return
it
->
second
;
}
const
OperatorBase
&
op_
;
const
RuntimeContext
&
ctx_
;
};
namespace
interpretercore
{
namespace
interpretercore
{
std
::
string
get_memcpy_type
(
const
platform
::
Place
&
src_place
,
std
::
string
get_memcpy_type
(
const
platform
::
Place
&
src_place
,
...
...
paddle/fluid/framework/new_executor/new_executor_defs.h
浏览文件 @
16e40513
...
@@ -34,6 +34,438 @@ using OpKernelComputeFunc = std::function<void(const ExecutionContext&)>;
...
@@ -34,6 +34,438 @@ using OpKernelComputeFunc = std::function<void(const ExecutionContext&)>;
using
OpKernelMap
=
using
OpKernelMap
=
std
::
unordered_map
<
OpKernelType
,
OpKernelComputeFunc
,
OpKernelType
::
Hash
>
;
std
::
unordered_map
<
OpKernelType
,
OpKernelComputeFunc
,
OpKernelType
::
Hash
>
;
class
InterpretercoreInferShapeContext
:
public
InferShapeContext
{
public:
InterpretercoreInferShapeContext
(
const
OperatorBase
&
op
,
const
RuntimeContext
&
ctx
)
:
op_
(
op
),
ctx_
(
ctx
),
can_skip_lod_
(
false
)
{}
bool
HasInput
(
const
std
::
string
&
name
)
const
override
{
// has only one input
const
auto
&
ins
=
ctx_
.
inputs
;
auto
it
=
ins
.
find
(
name
);
if
(
it
==
ins
.
end
())
{
return
false
;
}
const
auto
&
in
=
it
->
second
;
if
(
in
.
size
()
==
0
)
return
false
;
PADDLE_ENFORCE_EQ
(
in
.
size
(),
1UL
,
platform
::
errors
::
InvalidArgument
(
"Input %s should not contain more than one inputs."
,
name
));
return
in
[
0
]
!=
nullptr
;
}
bool
HasOutput
(
const
std
::
string
&
name
)
const
override
{
// has only one output
const
auto
&
outs
=
ctx_
.
outputs
;
auto
it
=
outs
.
find
(
name
);
if
(
it
==
outs
.
end
())
{
return
false
;
}
const
auto
&
out
=
it
->
second
;
if
(
out
.
size
()
==
0
)
{
return
false
;
}
PADDLE_ENFORCE_EQ
(
out
.
size
(),
1UL
,
platform
::
errors
::
InvalidArgument
(
"Output %s should not contain more than one outputs."
,
name
));
return
out
[
0
]
!=
nullptr
;
}
bool
HasInputs
(
const
std
::
string
&
name
)
const
override
{
const
auto
&
ins
=
ctx_
.
inputs
;
auto
it
=
ins
.
find
(
name
);
if
(
it
==
ins
.
end
()
||
it
->
second
.
empty
())
{
return
false
;
}
for
(
auto
&
input
:
it
->
second
)
{
if
(
input
==
nullptr
)
{
return
false
;
}
}
return
true
;
}
bool
HasOutputs
(
const
std
::
string
&
name
)
const
override
{
const
auto
&
outs
=
ctx_
.
outputs
;
auto
it
=
outs
.
find
(
name
);
if
(
it
==
outs
.
end
()
||
it
->
second
.
empty
())
{
return
false
;
}
for
(
auto
&
output
:
it
->
second
)
{
if
(
output
==
nullptr
)
{
return
false
;
}
}
return
true
;
}
AttrReader
Attrs
()
const
override
{
return
AttrReader
(
op_
.
Attrs
());
}
std
::
vector
<
std
::
string
>
Inputs
(
const
std
::
string
&
name
)
const
override
{
return
op_
.
Inputs
(
name
);
}
std
::
vector
<
std
::
string
>
Outputs
(
const
std
::
string
&
name
)
const
override
{
return
op_
.
Outputs
(
name
);
}
std
::
string
GetInputNameByIdx
(
size_t
idx
)
const
override
{
auto
&
op_proto
=
paddle
::
framework
::
OpInfoMap
::
Instance
().
Get
(
op_
.
Type
()).
proto_
;
PADDLE_ENFORCE_LT
(
idx
,
op_proto
->
inputs
().
size
(),
platform
::
errors
::
OutOfRange
(
"The index should be less than the size of inputs of "
"operator %s, but got index is %d and size is %d"
,
op_
.
Type
(),
idx
,
op_proto
->
inputs
().
size
()));
return
op_proto
->
inputs
()[
idx
].
name
();
}
std
::
string
GetOutputNameByIdx
(
size_t
idx
)
const
override
{
auto
&
op_proto
=
paddle
::
framework
::
OpInfoMap
::
Instance
().
Get
(
op_
.
Type
()).
proto_
;
PADDLE_ENFORCE_LT
(
idx
,
op_proto
->
outputs
().
size
(),
platform
::
errors
::
OutOfRange
(
"The index should be less than the size of outputs of "
"operator %s, but got index is %d and size is %d"
,
op_
.
Type
(),
idx
,
op_proto
->
outputs
().
size
()));
return
op_proto
->
outputs
()[
idx
].
name
();
}
void
ShareDim
(
const
std
::
string
&
in
,
const
std
::
string
&
out
,
size_t
i
=
0
,
size_t
j
=
0
)
override
{
auto
in_it
=
ctx_
.
inputs
.
find
(
in
);
auto
out_it
=
ctx_
.
outputs
.
find
(
out
);
PADDLE_ENFORCE_NE
(
in_it
,
ctx_
.
inputs
.
end
(),
platform
::
errors
::
NotFound
(
"Input %s does not exist."
,
in
));
PADDLE_ENFORCE_NE
(
out_it
,
ctx_
.
outputs
.
end
(),
platform
::
errors
::
NotFound
(
"Output %s does not exist."
,
out
));
PADDLE_ENFORCE_LT
(
i
,
in_it
->
second
.
size
(),
platform
::
errors
::
InvalidArgument
(
"The index of input dimension is out of range, "
"excepted index less than %zu, but received %zu."
,
in_it
->
second
.
size
(),
i
));
PADDLE_ENFORCE_LT
(
j
,
out_it
->
second
.
size
(),
platform
::
errors
::
InvalidArgument
(
"The index of output dimension is out of range, "
"excepted index less than %zu, but received %zu."
,
out_it
->
second
.
size
(),
j
));
Variable
*
in_var
=
in_it
->
second
[
i
];
Variable
*
out_var
=
out_it
->
second
[
j
];
PADDLE_ENFORCE_EQ
(
in_var
->
Type
(),
out_var
->
Type
(),
platform
::
errors
::
InvalidArgument
(
"The type of input (%s) and output (%s) are inconsistent."
,
in
,
out
));
if
(
in_var
->
IsType
<
framework
::
SelectedRows
>
())
{
auto
&
in_sele_rows
=
in_var
->
Get
<
framework
::
SelectedRows
>
();
auto
out_sele_rows
=
out_var
->
GetMutable
<
framework
::
SelectedRows
>
();
out_sele_rows
->
mutable_value
()
->
Resize
(
in_sele_rows
.
value
().
dims
());
out_sele_rows
->
set_rows
(
in_sele_rows
.
rows
());
out_sele_rows
->
set_height
(
in_sele_rows
.
height
());
}
else
if
(
in_var
->
IsType
<
framework
::
LoDTensor
>
())
{
auto
&
in_lod_tensor
=
in_var
->
Get
<
framework
::
LoDTensor
>
();
auto
*
out_lod_tensor
=
out_var
->
GetMutable
<
framework
::
LoDTensor
>
();
out_lod_tensor
->
Resize
(
in_lod_tensor
.
dims
());
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Currently, the input type of ShareDim only can be LoDTensor "
"or SelectedRows."
));
}
}
void
ShareAllLoD
(
const
std
::
string
&
in
,
const
std
::
string
&
out
)
const
override
{
auto
in_it
=
ctx_
.
inputs
.
find
(
in
);
auto
out_it
=
ctx_
.
outputs
.
find
(
out
);
PADDLE_ENFORCE_NE
(
in_it
,
ctx_
.
inputs
.
end
(),
platform
::
errors
::
NotFound
(
"Input [%s] found error in Op [%s]"
,
in
,
op_
.
Type
()));
PADDLE_ENFORCE_NE
(
out_it
,
ctx_
.
outputs
.
end
(),
platform
::
errors
::
NotFound
(
"Output [%s] found error in Op [%s]"
,
out
,
op_
.
Type
()));
auto
&
in_var_list
=
in_it
->
second
;
auto
&
out_var_list
=
out_it
->
second
;
PADDLE_ENFORCE_EQ
(
in_var_list
.
size
(),
out_var_list
.
size
(),
platform
::
errors
::
PreconditionNotMet
(
"Op [%s]: Input var size should be equal with output var size"
,
op_
.
Type
()));
auto
&
out_var_names
=
op_
.
Outputs
(
out
);
for
(
size_t
i
=
0
;
i
<
in_var_list
.
size
();
++
i
)
{
if
(
out_var_names
[
i
]
==
framework
::
kEmptyVarName
)
{
continue
;
}
Variable
*
in_var
=
in_var_list
[
i
];
if
(
!
in_var
->
IsType
<
LoDTensor
>
())
return
;
Variable
*
out_var
=
out_var_list
[
i
];
PADDLE_ENFORCE_EQ
(
out_var
->
IsType
<
LoDTensor
>
(),
true
,
platform
::
errors
::
PreconditionNotMet
(
"The %d-th output of Output(%s) must be LoDTensor."
,
i
,
out_var_names
[
i
]));
auto
&
in_tensor
=
in_var
->
Get
<
LoDTensor
>
();
auto
*
out_tensor
=
out_var
->
GetMutable
<
LoDTensor
>
();
out_tensor
->
set_lod
(
in_tensor
.
lod
());
#ifdef PADDLE_WITH_MKLDNN
if
(
in_tensor
.
layout
()
!=
DataLayout
::
kMKLDNN
)
#endif
out_tensor
->
set_layout
(
in_tensor
.
layout
());
}
}
void
ShareLoD
(
const
std
::
string
&
in
,
const
std
::
string
&
out
,
size_t
i
=
0
,
size_t
j
=
0
)
const
override
{
if
(
can_skip_lod_
)
{
return
;
}
auto
in_it
=
ctx_
.
inputs
.
find
(
in
);
auto
out_it
=
ctx_
.
outputs
.
find
(
out
);
PADDLE_ENFORCE_NE
(
in_it
,
ctx_
.
inputs
.
end
(),
platform
::
errors
::
NotFound
(
"Input %s does not exist."
,
in
));
PADDLE_ENFORCE_NE
(
out_it
,
ctx_
.
outputs
.
end
(),
platform
::
errors
::
NotFound
(
"Output %s does not exist."
,
out
));
PADDLE_ENFORCE_LT
(
i
,
in_it
->
second
.
size
(),
platform
::
errors
::
InvalidArgument
(
"The index of input dimension is out of range, "
"excepted index less than %zu, but received %zu."
,
in_it
->
second
.
size
(),
i
));
PADDLE_ENFORCE_LT
(
j
,
out_it
->
second
.
size
(),
platform
::
errors
::
InvalidArgument
(
"The index of output dimension is out of range, "
"excepted index less than %zu, but received %zu."
,
out_it
->
second
.
size
(),
j
));
Variable
*
in_var
=
in_it
->
second
.
at
(
i
);
if
(
!
in_var
->
IsType
<
LoDTensor
>
())
return
;
Variable
*
out_var
=
out_it
->
second
.
at
(
j
);
PADDLE_ENFORCE_EQ
(
out_var
->
IsType
<
LoDTensor
>
(),
true
,
platform
::
errors
::
InvalidArgument
(
"The %zu-th output of Output(%s) must be LoDTensor."
,
j
,
out
));
auto
&
in_tensor
=
in_var
->
Get
<
LoDTensor
>
();
auto
*
out_tensor
=
out_var
->
GetMutable
<
LoDTensor
>
();
out_tensor
->
set_lod
(
in_tensor
.
lod
());
// TODO(dzhwinter) : reuse ShareLoD in most operators.
// Need to call ShareLayout explicitly in sequence related ops.
// Shall we have a better method to shared info between in/out Tensor?
#ifdef PADDLE_WITH_MKLDNN
// Fix me: ugly workaround below
// Correct solution:
// set_layout() should NOT be called here (i.e. ShareLoD). Instead,
// layout of output tensor should be set "manually" in Compute()
// of each OPKernel. The reason layout should NOT be shared between
// input and output "automatically" (now by InferShape()->ShareLoD())
// is that layout transform may occur after InferShape().
// Workaround:
// Skip set_layout() when input layout is kMKLDNN
// This is to avoid kMKLDNN is populated wrongly into a non-MKLDNN
// OPKernel. In all MKLDNN OPkernel, set_layout(kMKLDNN) should be called
// in Compute()
if
(
in_tensor
.
layout
()
!=
DataLayout
::
kMKLDNN
)
#endif
out_tensor
->
set_layout
(
in_tensor
.
layout
());
}
int32_t
GetLoDLevel
(
const
std
::
string
&
in
,
size_t
i
=
0
)
const
override
{
PADDLE_THROW
(
platform
::
errors
::
PreconditionNotMet
(
"GetLoDLevel is only used in compile time. The calculation of "
"output's actual lod is different among operators so that should be "
"set in the runtime kernel."
));
}
void
SetLoDLevel
(
const
std
::
string
&
out
,
int32_t
lod_level
,
size_t
j
=
0
)
const
override
{
PADDLE_THROW
(
platform
::
errors
::
PreconditionNotMet
(
"SetLoDLevel is only used in compile time. The calculation of "
"output's actual lod is different among operators so that should be "
"set in the runtime kernel."
));
}
bool
IsRuntime
()
const
override
{
return
true
;
}
// TODO(paddle-dev): Can this be template?
std
::
vector
<
InferShapeVarPtr
>
GetInputVarPtrs
(
const
std
::
string
&
name
)
override
{
const
std
::
vector
<
Variable
*>&
vars
=
InputVars
(
name
);
std
::
vector
<
InferShapeVarPtr
>
res
;
res
.
reserve
(
vars
.
size
());
res
.
insert
(
res
.
begin
(),
vars
.
begin
(),
vars
.
end
());
return
res
;
}
std
::
vector
<
InferShapeVarPtr
>
GetOutputVarPtrs
(
const
std
::
string
&
name
)
override
{
const
std
::
vector
<
Variable
*>&
vars
=
OutputVars
(
name
);
std
::
vector
<
InferShapeVarPtr
>
res
;
res
.
reserve
(
vars
.
size
());
res
.
insert
(
res
.
begin
(),
vars
.
begin
(),
vars
.
end
());
return
res
;
}
DDim
GetInputDim
(
const
std
::
string
&
name
)
const
override
{
const
std
::
vector
<
Variable
*>&
vars
=
InputVars
(
name
);
PADDLE_ENFORCE_EQ
(
vars
.
size
(),
1UL
,
platform
::
errors
::
InvalidArgument
(
"Input(%s) should hold one element, but now it holds %zu elements."
,
name
,
vars
.
size
()));
return
this
->
GetDim
(
vars
[
0
]);
}
std
::
vector
<
DDim
>
GetInputsDim
(
const
std
::
string
&
name
)
const
override
{
const
std
::
vector
<
Variable
*>&
vars
=
InputVars
(
name
);
return
GetDims
(
vars
);
}
std
::
vector
<
proto
::
VarType
::
Type
>
GetInputsVarType
(
const
std
::
string
&
name
)
const
override
{
return
GetVarTypes
(
InputVars
(
name
));
}
std
::
vector
<
proto
::
VarType
::
Type
>
GetOutputsVarType
(
const
std
::
string
&
name
)
const
override
{
return
GetVarTypes
(
OutputVars
(
name
));
}
void
SetOutputDim
(
const
std
::
string
&
name
,
const
DDim
&
dim
)
override
{
auto
&
vars
=
OutputVars
(
name
);
PADDLE_ENFORCE_EQ
(
vars
.
size
(),
1UL
,
platform
::
errors
::
InvalidArgument
(
"Output(%s) should hold one element, "
"but now it holds %zu elements."
,
name
,
vars
.
size
()));
SetDim
(
vars
[
0
],
dim
);
}
void
SetOutputsDim
(
const
std
::
string
&
name
,
const
std
::
vector
<
DDim
>&
dims
)
override
{
auto
&
vars
=
OutputVars
(
name
);
SetDims
(
vars
,
dims
);
}
void
SetSkipLoD
(
bool
skip
)
{
can_skip_lod_
=
skip
;
}
protected:
DDim
GetDim
(
Variable
*
var
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
var
,
platform
::
errors
::
InvalidArgument
(
"Input variable is nullptr."
));
if
(
var
->
IsType
<
LoDTensor
>
())
{
return
var
->
Get
<
LoDTensor
>
().
dims
();
}
else
if
(
var
->
IsType
<
SelectedRows
>
())
{
return
var
->
Get
<
SelectedRows
>
().
GetCompleteDims
();
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Only LoDTensor or SelectedRows support 'GetDim', but input "
"Variable's type is %s."
,
ToTypeName
(
var
->
Type
())));
}
}
std
::
vector
<
DDim
>
GetDims
(
const
std
::
vector
<
Variable
*>&
vars
)
const
{
std
::
vector
<
DDim
>
ret
;
ret
.
reserve
(
vars
.
size
());
std
::
transform
(
vars
.
begin
(),
vars
.
end
(),
std
::
back_inserter
(
ret
),
[
this
](
Variable
*
var
)
{
return
this
->
GetDim
(
var
);
});
return
ret
;
}
std
::
vector
<
DDim
>
GetRepeatedDims
(
const
std
::
string
&
name
)
const
override
{
PADDLE_THROW
(
platform
::
errors
::
PreconditionNotMet
(
"GetRepeatedDims method only ban be used in compile time."
));
}
void
SetDim
(
Variable
*
var
,
const
DDim
&
dim
)
{
if
(
var
->
IsType
<
LoDTensor
>
())
{
var
->
GetMutable
<
LoDTensor
>
()
->
Resize
(
dim
);
}
else
if
(
var
->
IsType
<
SelectedRows
>
())
{
var
->
GetMutable
<
SelectedRows
>
()
->
set_height
(
dim
[
0
]);
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Variable type error, expect LoDTensor or SelectedRows, but received "
"(%s)."
,
ToTypeName
(
var
->
Type
())));
}
}
void
SetDims
(
const
std
::
vector
<
Variable
*>&
vars
,
const
std
::
vector
<
DDim
>&
dims
)
{
size_t
length
=
vars
.
size
();
PADDLE_ENFORCE_EQ
(
length
,
dims
.
size
(),
platform
::
errors
::
InvalidArgument
(
"The number of input variables do not match the "
"number of input dimensions, the number of variables "
"is %zu, the number of dimensions is %zu."
,
length
,
dims
.
size
()));
for
(
size_t
i
=
0
;
i
<
length
;
++
i
)
{
if
(
vars
[
i
]
==
nullptr
)
{
continue
;
}
SetDim
(
vars
[
i
],
dims
[
i
]);
}
}
void
SetRepeatedDims
(
const
std
::
string
&
name
,
const
std
::
vector
<
DDim
>&
dims
)
override
{
PADDLE_THROW
(
platform
::
errors
::
PreconditionNotMet
(
"SetRepeatedDims method only can be used in compile time."
));
}
std
::
vector
<
proto
::
VarType
::
Type
>
GetVarTypes
(
const
std
::
vector
<
Variable
*>&
vars
)
const
{
std
::
vector
<
proto
::
VarType
::
Type
>
retv
;
retv
.
resize
(
vars
.
size
());
std
::
transform
(
vars
.
begin
(),
vars
.
end
(),
retv
.
begin
(),
std
::
bind
(
std
::
mem_fn
(
&
InterpretercoreInferShapeContext
::
GetVarType
),
this
,
std
::
placeholders
::
_1
));
return
retv
;
}
proto
::
VarType
::
Type
GetVarType
(
Variable
*
var
)
const
{
return
ToVarType
(
var
->
Type
());
}
private:
const
std
::
vector
<
Variable
*>&
InputVars
(
const
std
::
string
&
name
)
const
{
auto
it
=
ctx_
.
inputs
.
find
(
name
);
PADDLE_ENFORCE_NE
(
it
,
ctx_
.
inputs
.
end
(),
platform
::
errors
::
NotFound
(
"Operator (%s) does not have the input (%s)."
,
op_
.
Type
(),
name
));
return
it
->
second
;
}
const
std
::
vector
<
Variable
*>&
OutputVars
(
const
std
::
string
&
name
)
const
{
auto
it
=
ctx_
.
outputs
.
find
(
name
);
PADDLE_ENFORCE_NE
(
it
,
ctx_
.
outputs
.
end
(),
platform
::
errors
::
NotFound
(
"Operator (%s) does not have the outputs (%s)."
,
op_
.
Type
(),
name
));
return
it
->
second
;
}
const
OperatorBase
&
op_
;
const
RuntimeContext
&
ctx_
;
bool
can_skip_lod_
;
};
struct
OpKernelFunc
{
struct
OpKernelFunc
{
OpKernelComputeFunc
compute_func_
;
OpKernelComputeFunc
compute_func_
;
OperatorBase
*
operator_base_
;
OperatorBase
*
operator_base_
;
...
@@ -75,12 +507,10 @@ struct InstructionInfo {
...
@@ -75,12 +507,10 @@ struct InstructionInfo {
std
::
vector
<
size_t
>
dependecy_count_
;
std
::
vector
<
size_t
>
dependecy_count_
;
};
};
class
RuntimeInferShapeContext
;
struct
Instruction
{
struct
Instruction
{
OpKernelFunc
kernel_func_
;
OpKernelFunc
kernel_func_
;
std
::
shared_ptr
<
RuntimeContext
>
runtime_ctx_
;
std
::
shared_ptr
<
RuntimeContext
>
runtime_ctx_
;
std
::
shared_ptr
<
Runtim
eInferShapeContext
>
infershape_ctx_
;
std
::
shared_ptr
<
Interpretercor
eInferShapeContext
>
infershape_ctx_
;
std
::
shared_ptr
<
ExecutionContext
>
execution_ctx_
;
std
::
shared_ptr
<
ExecutionContext
>
execution_ctx_
;
std
::
map
<
std
::
string
,
std
::
vector
<
int
>>
input_index_
;
std
::
map
<
std
::
string
,
std
::
vector
<
int
>>
input_index_
;
std
::
map
<
std
::
string
,
std
::
vector
<
int
>>
output_index_
;
std
::
map
<
std
::
string
,
std
::
vector
<
int
>>
output_index_
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录