Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
4b6d2f5f
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
4b6d2f5f
编写于
7月 13, 2023
作者:
H
hong
提交者:
GitHub
7月 13, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[NewIR]new ir support builtin slice op (#55381)
* new ir support builtin slice op * fix phi kernel adaptor bug
上级
0dad9458
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
225 addition
and
31 deletion
+225
-31
paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc
...id/framework/new_executor/interpreter/interpreter_util.cc
+3
-1
paddle/fluid/framework/tensor_ref_array.h
paddle/fluid/framework/tensor_ref_array.h
+3
-3
paddle/fluid/framework/type_info.cc
paddle/fluid/framework/type_info.cc
+1
-1
paddle/fluid/framework/var_type_traits.h
paddle/fluid/framework/var_type_traits.h
+1
-1
paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_adaptor.h
paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_adaptor.h
+2
-0
paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc
paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc
+56
-12
paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h
paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h
+27
-12
paddle/fluid/ir_adaptor/translator/op_translator.cc
paddle/fluid/ir_adaptor/translator/op_translator.cc
+99
-0
paddle/phi/api/yaml/op_compat.yaml
paddle/phi/api/yaml/op_compat.yaml
+11
-1
test/ir/new_ir/test_standalone_new_ir.py
test/ir/new_ir/test_standalone_new_ir.py
+22
-0
未找到文件。
paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc
浏览文件 @
4b6d2f5f
...
...
@@ -957,7 +957,7 @@ void BuildOpFuncList(
if
(
op_name
==
"builtin.combine"
||
op_name
==
"pd.feed"
||
op_name
==
"builtin.set_parameter"
||
op_name
==
"builtin.get_parameter"
)
{
op_name
==
"builtin.get_parameter"
||
op_name
==
"builtin.slice"
)
{
VLOG
(
6
)
<<
"skip process "
<<
op_name
;
continue
;
}
...
...
@@ -977,6 +977,7 @@ void BuildOpFuncList(
phi
::
MetaTensor
,
phi
::
MetaTensor
,
paddle
::
small_vector
<
phi
::
MetaTensor
,
phi
::
kInputSmallVectorSize
>
,
paddle
::
small_vector
<
phi
::
MetaTensor
,
phi
::
kInputSmallVectorSize
>
,
false
>
((
*
it
),
value_2_name_map
,
scope
,
...
...
@@ -1003,6 +1004,7 @@ void BuildOpFuncList(
const
phi
::
TensorBase
*
,
phi
::
TensorBase
*
,
paddle
::
small_vector
<
const
phi
::
TensorBase
*>
,
paddle
::
small_vector
<
phi
::
TensorBase
*>
,
true
>
((
*
it
),
value_2_name_map
,
scope
,
...
...
paddle/fluid/framework/tensor_ref_array.h
浏览文件 @
4b6d2f5f
...
...
@@ -20,11 +20,11 @@ namespace paddle {
namespace
framework
{
template
<
>
struct
PhiVectorType
<
const
phi
::
DenseTensor
*>
{
const
char
*
type_name
=
"
PhiTensor
RefArray"
;
struct
PhiVectorType
<
const
framework
::
Variable
*>
{
const
char
*
type_name
=
"
Variable
RefArray"
;
};
using
TensorRefArray
=
PhiVector
<
const
phi
::
DenseTensor
*>
;
using
VariableRefArray
=
PhiVector
<
const
framework
::
Variable
*>
;
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/type_info.cc
浏览文件 @
4b6d2f5f
...
...
@@ -41,6 +41,6 @@ template class TypeInfoTraits<phi::TensorBase, paddle::framework::FeedList>;
template
class
TypeInfoTraits
<
phi
::
TensorBase
,
egr
::
VariableCompatTensor
>;
template
class
TypeInfoTraits
<
phi
::
TensorBase
,
paddle
::
prim
::
DescTensor
>;
template
class
TypeInfoTraits
<
phi
::
TensorBase
,
paddle
::
framework
::
Tensor
RefArray
>;
paddle
::
framework
::
Variable
RefArray
>;
}
// namespace phi
paddle/fluid/framework/var_type_traits.h
浏览文件 @
4b6d2f5f
...
...
@@ -212,7 +212,7 @@ using VarTypeRegistry = detail::VarTypeRegistryImpl<
std
::
vector
<
float
>
,
std
::
vector
<
std
::
string
>
,
RawTensor
,
Tensor
RefArray
>
;
Variable
RefArray
>
;
template
<
typename
T
>
struct
VarTypeTrait
{
static_assert
(
VarTypeRegistry
::
IsRegistered
<
T
>
(),
"Must be registered type"
);
...
...
paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_adaptor.h
浏览文件 @
4b6d2f5f
...
...
@@ -87,6 +87,7 @@ class PhiKernelAdaptor {
phi
::
MetaTensor
,
phi
::
MetaTensor
,
paddle
::
small_vector
<
phi
::
MetaTensor
,
phi
::
kInputSmallVectorSize
>
,
paddle
::
small_vector
<
phi
::
MetaTensor
,
phi
::
kInputSmallVectorSize
>
,
false
>
((
*
it
),
name_map
,
scope_
,
nullptr
,
op_yaml_info_parser
,
&
ctx
);
infer_meta_impl
->
infer_meta_
(
&
ctx
);
...
...
@@ -106,6 +107,7 @@ class PhiKernelAdaptor {
const
phi
::
TensorBase
*
,
phi
::
TensorBase
*
,
paddle
::
small_vector
<
const
phi
::
TensorBase
*>
,
paddle
::
small_vector
<
phi
::
TensorBase
*>
,
true
>
(
(
*
it
),
name_map
,
scope_
,
nullptr
,
op_yaml_info_parser
,
&
kernel_ctx
);
kernel_fn
(
&
kernel_ctx
);
...
...
paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc
浏览文件 @
4b6d2f5f
...
...
@@ -43,6 +43,9 @@
namespace
ir
{
using
VariableNameMap
=
std
::
unordered_map
<
const
paddle
::
framework
::
Variable
*
,
std
::
string
>
;
paddle
::
framework
::
Variable
*
CreateVar
(
ir
::
Value
value
,
const
std
::
string
&
name
,
paddle
::
framework
::
Scope
*
scope
,
...
...
@@ -89,6 +92,7 @@ void BuildValue(ir::Value value,
paddle
::
framework
::
Scope
*
scope
,
paddle
::
framework
::
Scope
*
local_scope
,
std
::
unordered_map
<
ir
::
Value
,
std
::
string
>*
name_map
,
VariableNameMap
*
variable_name_map
,
int
&
count
)
{
// NOLINT
auto
inner_local_scope
=
local_scope
!=
nullptr
?
local_scope
:
scope
;
std
::
string
name
;
...
...
@@ -107,7 +111,7 @@ void BuildValue(ir::Value value,
}
else
if
(
value
.
type
().
isa
<
paddle
::
dialect
::
AllocatedSelectedRowsType
>
())
{
var
->
GetMutable
<
phi
::
SelectedRows
>
();
}
else
if
(
value
.
type
().
isa
<
ir
::
VectorType
>
())
{
auto
tensor_array
=
var
->
GetMutable
<
paddle
::
framework
::
Tensor
RefArray
>
();
auto
tensor_array
=
var
->
GetMutable
<
paddle
::
framework
::
Variable
RefArray
>
();
for
(
size_t
i
=
0
;
i
<
value
.
type
().
dyn_cast
<
ir
::
VectorType
>
().
size
();
i
++
)
{
PADDLE_ENFORCE
(
value
.
type
()
...
...
@@ -118,7 +122,9 @@ void BuildValue(ir::Value value,
"DenseTensorType"
));
std
::
string
name_i
=
"inner_var_"
+
std
::
to_string
(
count
++
);
auto
var_i
=
CreateVar
(
value
,
name_i
,
scope
,
inner_local_scope
);
tensor_array
->
emplace_back
(
var_i
->
GetMutable
<
phi
::
DenseTensor
>
());
var_i
->
GetMutable
<
phi
::
DenseTensor
>
();
tensor_array
->
emplace_back
(
var_i
);
variable_name_map
->
emplace
(
var_i
,
name_i
);
}
}
else
{
PADDLE_THROW
(
phi
::
errors
::
PreconditionNotMet
(
...
...
@@ -127,6 +133,7 @@ void BuildValue(ir::Value value,
}
void
HandleForSpecialOp
(
ir
::
Operation
*
op
,
const
VariableNameMap
&
variable_name_map
,
paddle
::
framework
::
Scope
*
scope
,
paddle
::
framework
::
Scope
*
local_scope
,
std
::
unordered_map
<
ir
::
Value
,
std
::
string
>*
name_map
,
...
...
@@ -180,7 +187,7 @@ void HandleForSpecialOp(ir::Operation* op,
}
auto
var
=
CreateVar
(
out_value
,
name
,
scope
,
local_scope
);
auto
tensor_array
=
var
->
GetMutable
<
paddle
::
framework
::
Tensor
RefArray
>
();
auto
tensor_array
=
var
->
GetMutable
<
paddle
::
framework
::
Variable
RefArray
>
();
// clear tensor array
tensor_array
->
clear
();
...
...
@@ -192,8 +199,7 @@ void HandleForSpecialOp(ir::Operation* op,
true
,
phi
::
errors
::
PreconditionNotMet
(
"can not found input of combine op"
));
tensor_array
->
emplace_back
(
&
(
CreateVar
(
value
,
name_map
->
at
(
value
),
scope
,
local_scope
)
->
Get
<
phi
::
DenseTensor
>
()));
CreateVar
(
value
,
name_map
->
at
(
value
),
scope
,
local_scope
));
}
}
...
...
@@ -223,6 +229,34 @@ void HandleForSpecialOp(ir::Operation* op,
auto
out_ptr
=
op
->
result
(
0
);
name_map
->
emplace
(
out_ptr
,
param_name
);
}
if
(
op_name
==
"builtin.slice"
)
{
VLOG
(
6
)
<<
"Handle for builtin.slice"
;
auto
out_value
=
op
->
result
(
0
);
auto
in_value
=
op
->
operand
(
0
);
PADDLE_ENFORCE_EQ
(
name_map
->
count
(
in_value
),
true
,
phi
::
errors
::
PreconditionNotMet
(
"input of buildin slice not in name map"
));
int
index
=
op
->
attributes
().
at
(
"index"
).
dyn_cast
<
ir
::
Int32Attribute
>
().
data
();
auto
in_var
=
scope
->
FindVar
(
name_map
->
at
(
in_value
));
auto
variable_array
=
in_var
->
Get
<
paddle
::
framework
::
VariableRefArray
>
();
PADDLE_ENFORCE_EQ
(
variable_name_map
.
count
(
variable_array
[
index
]),
true
,
phi
::
errors
::
PreconditionNotMet
(
"[%d] the variable in build slice "
"input MUST in variable name map"
,
index
));
std
::
string
var_name
=
variable_name_map
.
at
(
variable_array
[
index
]);
name_map
->
emplace
(
out_value
,
var_name
);
}
}
void
HandleForInplaceOp
(
ir
::
Operation
*
op
,
...
...
@@ -242,7 +276,7 @@ void HandleForInplaceOp(ir::Operation* op,
paddle
::
dialect
::
OpYamlInfoParser
yaml_parser
(
op_info
.
GetInterfaceImpl
<
paddle
::
dialect
::
OpYamlInfoInterface
>
()
->
get_op_info_
());
VariableNameMap
variable_name_map
;
for
(
size_t
i
=
0
;
i
<
op
->
num_results
();
++
i
)
{
ir
::
Value
value
=
op
->
result
(
i
);
std
::
string
value_name
=
yaml_parser
.
OutputNames
()[
i
];
...
...
@@ -255,7 +289,8 @@ void HandleForInplaceOp(ir::Operation* op,
<<
" (var: "
<<
var_name
<<
")"
;
name_map
->
emplace
(
value
,
var_name
);
}
else
{
BuildValue
(
value
,
scope
,
local_scope
,
name_map
,
count
);
BuildValue
(
value
,
scope
,
local_scope
,
name_map
,
&
variable_name_map
,
count
);
}
}
}
...
...
@@ -273,8 +308,11 @@ void BuildScope(const ir::Block& block,
VLOG
(
6
)
<<
"Build: scope ["
<<
scope
<<
"] inner_local_scope ["
<<
inner_local_scope
<<
"]"
;
std
::
unordered_map
<
const
paddle
::
framework
::
Variable
*
,
std
::
string
>
variable_name_map
;
// int count = name_map->size();
int
count
=
inner_local_scope
->
S
ize
();
int
count
=
name_map
->
s
ize
();
for
(
auto
it
=
block
.
begin
();
it
!=
block
.
end
();
++
it
)
{
ir
::
Operation
*
op
=
*
it
;
...
...
@@ -288,9 +326,10 @@ void BuildScope(const ir::Block& block,
if
(
op_name
==
"pd.feed"
||
op_name
==
"pd.fetch"
||
op_name
==
"builtin.combine"
||
op_name
==
"builtin.set_parameter"
||
op_name
==
"builtin.get_parameter"
)
{
VLOG
(
4
)
<<
"HandleForSpecialOp: "
<<
op_name
;
HandleForSpecialOp
(
op
,
scope
,
inner_local_scope
,
name_map
,
count
);
op_name
==
"builtin.get_parameter"
||
op_name
==
"builtin.slice"
)
{
VLOG
(
6
)
<<
"HandleForSpecialOp: "
<<
op_name
;
HandleForSpecialOp
(
op
,
variable_name_map
,
scope
,
inner_local_scope
,
name_map
,
count
);
continue
;
}
...
...
@@ -306,7 +345,12 @@ void BuildScope(const ir::Block& block,
continue
;
}
else
{
for
(
size_t
i
=
0
;
i
<
op
->
num_results
();
++
i
)
{
BuildValue
(
op
->
result
(
i
),
scope
,
local_scope
,
name_map
,
count
);
BuildValue
(
op
->
result
(
i
),
scope
,
local_scope
,
name_map
,
&
variable_name_map
,
count
);
}
}
}
...
...
paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h
浏览文件 @
4b6d2f5f
...
...
@@ -75,7 +75,8 @@ void BuildScope(const ir::Block& block,
template
<
typename
Context
,
typename
InType
,
typename
OutType
,
typename
ListType
,
typename
InListType
,
typename
OutListType
,
bool
is_kernel
>
void
BuildPhiContext
(
ir
::
Operation
*
op
,
...
...
@@ -121,11 +122,12 @@ void BuildPhiContext(
if
(
var
->
IsType
<
phi
::
DenseTensor
>
())
{
const
phi
::
TensorBase
*
tensor_in
=
&
(
var
->
Get
<
phi
::
DenseTensor
>
());
ctx
->
EmplaceBackInput
(
InType
(
tensor_in
));
}
else
if
(
var
->
IsType
<
paddle
::
framework
::
TensorRefArray
>
())
{
ListType
inputs
;
auto
&
tensor_array
=
var
->
Get
<
paddle
::
framework
::
TensorRefArray
>
();
for
(
size_t
i
=
0
;
i
<
tensor_array
.
size
();
++
i
)
{
inputs
.
emplace_back
(
InType
(
tensor_array
[
i
]));
}
else
if
(
var
->
IsType
<
paddle
::
framework
::
VariableRefArray
>
())
{
InListType
inputs
;
auto
&
variable_array
=
var
->
Get
<
paddle
::
framework
::
VariableRefArray
>
();
for
(
size_t
i
=
0
;
i
<
variable_array
.
size
();
++
i
)
{
inputs
.
emplace_back
(
InType
(
const_cast
<
phi
::
DenseTensor
*>
(
&
(
variable_array
[
i
]
->
Get
<
phi
::
DenseTensor
>
()))));
}
ctx
->
EmplaceBackInputs
(
inputs
);
}
else
{
...
...
@@ -157,18 +159,21 @@ void BuildPhiContext(
VLOG
(
6
)
<<
"ctx->EmplaceBack mutable attr: "
<<
t
<<
"
\t
"
<<
in_var_name
;
if
(
tensor_attr_type
==
"paddle::dialect::IntArrayAttribute"
)
{
if
(
ptr
.
type
().
isa
<
paddle
::
dialect
::
AllocatedDenseTensorType
>
())
{
phi
::
Attribute
r1
=
phi
::
TensorRef
(
phi
::
Attribute
attr
=
phi
::
TensorRef
(
&
(
inner_scope
->
FindVar
(
in_var_name
)
->
Get
<
phi
::
DenseTensor
>
()));
ctx
->
EmplaceBackAttr
(
r1
);
ctx
->
EmplaceBackAttr
(
attr
);
}
else
if
(
ptr
.
type
().
isa
<
ir
::
VectorType
>
())
{
auto
&
tensor_array
=
inner_scope
->
FindVar
(
in_var_name
)
->
Get
<
paddle
::
framework
::
Tensor
RefArray
>
();
->
Get
<
paddle
::
framework
::
Variable
RefArray
>
();
if
(
tensor_array
.
size
()
==
1
)
{
ctx
->
EmplaceBackAttr
(
phi
::
TensorRef
(
tensor_array
[
0
]));
phi
::
Attribute
attr
=
phi
::
TensorRef
(
&
(
tensor_array
[
0
]
->
Get
<
phi
::
DenseTensor
>
()));
ctx
->
EmplaceBackAttr
(
attr
);
}
else
{
std
::
vector
<
phi
::
TensorRef
>
vec_ref
;
for
(
size_t
i
=
0
;
i
<
tensor_array
.
size
();
++
i
)
{
vec_ref
.
emplace_back
(
phi
::
TensorRef
(
tensor_array
[
i
]));
vec_ref
.
emplace_back
(
phi
::
TensorRef
(
&
(
tensor_array
[
i
]
->
Get
<
phi
::
DenseTensor
>
())));
}
ctx
->
EmplaceBackAttr
(
vec_ref
);
}
...
...
@@ -328,8 +333,18 @@ void BuildPhiContext(
}
else
if
(
out_type
.
isa
<
paddle
::
dialect
::
AllocatedSelectedRowsType
>
())
{
ctx
->
EmplaceBackOutput
(
OutType
(
const_cast
<
phi
::
SelectedRows
*>
(
&
(
scope
->
Var
(
name
)
->
Get
<
phi
::
SelectedRows
>
()))));
}
else
if
(
out_type
.
isa
<
ir
::
VectorType
>
())
{
OutListType
outputs
;
auto
&
variable_array
=
scope
->
Var
(
name
)
->
Get
<
paddle
::
framework
::
VariableRefArray
>
();
for
(
size_t
i
=
0
;
i
<
variable_array
.
size
();
++
i
)
{
outputs
.
emplace_back
(
OutType
(
const_cast
<
phi
::
DenseTensor
*>
(
&
(
variable_array
[
i
]
->
Get
<
phi
::
DenseTensor
>
()))));
}
ctx
->
EmplaceBackOutputs
(
outputs
);
}
else
{
PADDLE_THROW
(
"not support type"
);
PADDLE_THROW
(
phi
::
errors
::
Unimplemented
(
"only support DenseTensor and vector "
));
}
if
(
output_map
!=
nullptr
)
{
...
...
paddle/fluid/ir_adaptor/translator/op_translator.cc
浏览文件 @
4b6d2f5f
...
...
@@ -955,6 +955,104 @@ struct FeedOpTranscriber : public OpTranscriber {
}
};
struct
SplitOpTranscriber
:
public
OpTranscriber
{
std
::
vector
<
ir
::
OpResult
>
GenerateOperationInput
(
ir
::
IrContext
*
ctx
,
TranslationContext
*
param_map
,
const
OpDesc
&
op_desc
,
const
std
::
string
&
normalized_op_name
,
const
OpInputInfoList
&
input_infos
,
ir
::
Program
*
program
)
override
{
// input of pslit is [Tensor x, IntArray sections, Scalar(int) axis)]
VLOG
(
10
)
<<
"[op:split][input] start"
;
std
::
vector
<
ir
::
OpResult
>
op_inputs
;
// process first input
auto
x_input_vars
=
op_desc
.
Input
(
"X"
);
IR_ENFORCE
(
x_input_vars
.
size
()
==
1
,
"x input of split MUST be a tensor"
);
auto
x_defining_info
=
(
*
param_map
)[
x_input_vars
[
0
]];
op_inputs
.
push_back
(
x_defining_info
.
value
);
// process sections
int
num
=
paddle
::
get
<
int
>
(
op_desc
.
GetAttr
(
"num"
));
if
(
num
<=
0
)
{
if
(
op_desc
.
HasInput
(
"SectionsTensorList"
))
{
// get SectionsTensorList from input
auto
sec_tensor_list
=
op_desc
.
Input
(
"SectionsTensorList"
);
auto
*
combine_op
=
InsertCombineOperationForTarget
(
ctx
,
param_map
,
program
,
sec_tensor_list
);
op_inputs
.
push_back
(
combine_op
->
result
(
0
));
}
else
{
auto
&
attribute_translator
=
AttributeTranslator
::
instance
();
ir
::
Attribute
new_attr
=
attribute_translator
(
"paddle::dialect::IntArrayAttribute"
,
op_desc
.
GetAttr
(
"sections"
));
auto
sec_defin_op
=
InsertFullOperationForAttributeInput
(
ctx
,
program
,
new_attr
);
op_inputs
.
push_back
(
sec_defin_op
->
result
(
0
));
}
}
// process axis
if
(
op_desc
.
HasInput
(
"AxisTensor"
)
&&
op_desc
.
Input
(
"AxisTensor"
).
size
()
>
0
)
{
// get axis from input
auto
axis_var_list
=
op_desc
.
Input
(
"AxisTensor"
);
IR_ENFORCE
(
axis_var_list
.
size
()
==
1
,
"axis tensor input of split MUST be a tensor"
);
auto
axis_defining_info
=
(
*
param_map
)[
axis_var_list
[
0
]];
op_inputs
.
push_back
(
axis_defining_info
.
value
);
}
else
{
auto
&
attribute_translator
=
AttributeTranslator
::
instance
();
ir
::
Attribute
new_attr
=
attribute_translator
(
"ir::Int32Attribute"
,
op_desc
.
GetAttr
(
"axis"
));
auto
sec_defin_op
=
InsertFullOperationForAttributeInput
(
ctx
,
program
,
new_attr
);
op_inputs
.
push_back
(
sec_defin_op
->
result
(
0
));
}
return
op_inputs
;
}
ir
::
AttributeMap
TranslateOpAttribute
(
ir
::
IrContext
*
ctx
,
const
std
::
string
&
normalized_op_name
,
const
OpAttributeInfoList
&
op_attr_infos
,
const
OpDesc
&
op_desc
)
override
{
int
num
=
paddle
::
get
<
int
>
(
op_desc
.
GetAttr
(
"num"
));
if
(
num
>
0
)
{
ir
::
AttributeMap
attribute_map
=
{
{
"num"
,
ir
::
Int32Attribute
::
get
(
ctx
,
op_desc
.
GetAttrIfExists
<
int
>
(
"num"
))},
};
return
attribute_map
;
}
return
{};
}
ir
::
OpInfo
LoopkUpOpInfo
(
ir
::
IrContext
*
ctx
,
const
OpDesc
&
op_desc
)
override
{
int
num
=
paddle
::
get
<
int
>
(
op_desc
.
GetAttr
(
"num"
));
std
::
string
target_op_name
;
if
(
num
>
0
)
{
target_op_name
=
"pd.split_with_num"
;
}
else
{
target_op_name
=
"pd.split"
;
}
const
auto
&
op_info
=
ctx
->
GetRegisteredOpInfo
(
target_op_name
);
if
(
!
op_info
)
{
IR_THROW
(
"Op assign_value should have corresponding OpInfo pd.split"
);
}
return
op_info
;
}
};
struct
FetchOpTranscriber
:
public
OpTranscriber
{
ir
::
Operation
*
operator
()(
ir
::
IrContext
*
ctx
,
TranslationContext
*
param_map
,
...
...
@@ -994,6 +1092,7 @@ OpTranslator::OpTranslator() {
special_handlers
[
"feed"
]
=
FeedOpTranscriber
();
special_handlers
[
"fetch_v2"
]
=
FetchOpTranscriber
();
special_handlers
[
"cast"
]
=
CastOpTranscriber
();
special_handlers
[
"split"
]
=
SplitOpTranscriber
();
special_handlers
[
"lookup_table_v2"
]
=
EmbeddingOpTranscriber
();
special_handlers
[
"lookup_table_v2_grad"
]
=
EmbeddingGradOpTranscriber
();
special_handlers
[
"assign_value"
]
=
AssignValueOpTranscriber
();
...
...
paddle/phi/api/yaml/op_compat.yaml
浏览文件 @
4b6d2f5f
...
...
@@ -2532,7 +2532,17 @@
int_array
:
sections
:
data_type
:
int
tensor_name
:
AxesTensor
scalar
:
axis
:
data_type
:
int
support_tensor
:
true
-
op
:
split_with_num
scalar
:
axis
:
data_type
:
int
support_tensor
:
true
tensor_name
:
AxisTensor
-
op
:
sqrt
backward
:
sqrt_grad, sqrt_double_grad (sqrt_grad_grad)
...
...
test/ir/new_ir/test_standalone_new_ir.py
浏览文件 @
4b6d2f5f
...
...
@@ -141,5 +141,27 @@ class TestAddGradOp(unittest.TestCase):
np
.
testing
.
assert_array_equal
(
out
[
0
],
gold_res
)
class
TestSplitOp
(
unittest
.
TestCase
):
def
test_with_new_ir
(
self
):
place
=
paddle
.
CPUPlace
()
exe
=
paddle
.
static
.
Executor
(
place
)
main_program
=
paddle
.
static
.
Program
()
new_scope
=
paddle
.
static
.
Scope
()
with
paddle
.
static
.
scope_guard
(
new_scope
):
with
paddle
.
static
.
program_guard
(
main_program
):
x
=
paddle
.
static
.
data
(
"x"
,
[
6
,
2
],
dtype
=
"float32"
)
out0
,
out1
,
out2
=
paddle
.
split
(
x
,
num_or_sections
=
3
,
axis
=
0
)
np_a
=
np
.
random
.
rand
(
6
,
2
).
astype
(
"float32"
)
out
=
exe
.
run
(
main_program
,
feed
=
{
"x"
:
np_a
},
fetch_list
=
[
out0
.
name
],
)
np
.
testing
.
assert_array_equal
(
out
[
0
],
np_a
[
0
:
2
])
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录