Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
e429deb0
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
e429deb0
编写于
3月 19, 2021
作者:
C
Chen Weihang
提交者:
GitHub
3月 19, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[CustomOp] Support attribute in infershape function (#31713)
* support attribute in infershape * polish details
上级
a4a2b77d
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
289 addition
and
91 deletion
+289
-91
paddle/fluid/extension/include/ext_op_meta_info.h
paddle/fluid/extension/include/ext_op_meta_info.h
+77
-35
paddle/fluid/framework/custom_operator.cc
paddle/fluid/framework/custom_operator.cc
+46
-4
python/paddle/fluid/tests/custom_op/custom_concat_op.cc
python/paddle/fluid/tests/custom_op/custom_concat_op.cc
+90
-0
python/paddle/fluid/tests/custom_op/test_custom_concat.py
python/paddle/fluid/tests/custom_op/test_custom_concat.py
+76
-52
未找到文件。
paddle/fluid/extension/include/ext_op_meta_info.h
浏览文件 @
e429deb0
...
@@ -204,37 +204,67 @@ struct KernelFuncImpl<Return (*)(Args...), impl_fn> {
...
@@ -204,37 +204,67 @@ struct KernelFuncImpl<Return (*)(Args...), impl_fn> {
// Record Op infershape core function
// Record Op infershape core function
using
InferShapeFunc
=
std
::
vector
<
std
::
vector
<
int64_t
>>
(
*
)(
using
InferShapeFunc
=
std
::
vector
<
std
::
vector
<
int64_t
>>
(
*
)(
const
std
::
vector
<
std
::
vector
<
int64_t
>>&
input_shapes
,
const
std
::
vector
<
std
::
vector
<
int64_t
>>&
input_shapes
,
const
std
::
vector
<
std
::
vector
<
std
::
vector
<
int64_t
>>>&
vec_input_shapes
);
const
std
::
vector
<
std
::
vector
<
std
::
vector
<
int64_t
>>>&
vec_input_shapes
,
const
std
::
vector
<
boost
::
any
>&
attrs
);
#define PD_SPECIALIZE_InferShapeCallHelper_FOR_SHAPE(input_type) \
#define PD_SPECIALIZE_InferShapeCallHelper_FOR_SHAPE(input_type) \
template <typename... Tail> \
template <typename... Tail> \
struct InferShapeCallHelper<input_type, Tail...> { \
struct InferShapeCallHelper<input_type, Tail...> { \
template <int in_idx, int vec_in_idx, typename... PreviousArgs> \
template <int in_idx, int vec_in_idx, int attr_idx, \
typename... PreviousArgs> \
static Return InferShape( \
static Return InferShape( \
const std::vector<std::vector<int64_t>>& input_shapes, \
const std::vector<std::vector<int64_t>>& input_shapes, \
const std::vector<std::vector<std::vector<int64_t>>>& \
const std::vector<std::vector<std::vector<int64_t>>>& \
vec_input_shapes, \
vec_input_shapes, \
const
PreviousArgs&... pargs) {
\
const
std::vector<boost::any>& attrs, const PreviousArgs&... pargs) {
\
input_type arg = input_shapes[in_idx]; \
input_type arg = input_shapes[in_idx]; \
return InferShapeCallHelper<Tail...>::template InferShape<
in_idx + 1,
\
return InferShapeCallHelper<Tail...>::template InferShape<
\
vec_in_idx>(
\
in_idx + 1, vec_in_idx, attr_idx>(input_shapes, vec_input_shapes,
\
input_shapes, vec_input_shapes, pargs..., arg);
\
attrs, pargs..., arg);
\
} \
} \
}
}
#define PD_SPECIALIZE_InferShapeCallHelper_FOR_SHAPES(input_type) \
#define PD_SPECIALIZE_InferShapeCallHelper_FOR_SHAPES(input_type) \
template <typename... Tail> \
template <typename... Tail> \
struct InferShapeCallHelper<input_type, Tail...> { \
struct InferShapeCallHelper<input_type, Tail...> { \
template <int in_idx, int vec_in_idx, typename... PreviousArgs> \
template <int in_idx, int vec_in_idx, int attr_idx, \
typename... PreviousArgs> \
static Return InferShape( \
static Return InferShape( \
const std::vector<std::vector<int64_t>>& input_shapes, \
const std::vector<std::vector<int64_t>>& input_shapes, \
const std::vector<std::vector<std::vector<int64_t>>>& \
const std::vector<std::vector<std::vector<int64_t>>>& \
vec_input_shapes, \
vec_input_shapes, \
const
PreviousArgs&... pargs) {
\
const
std::vector<boost::any>& attrs, const PreviousArgs&... pargs) {
\
input_type arg = vec_input_shapes[vec_in_idx]; \
input_type arg = vec_input_shapes[vec_in_idx]; \
return InferShapeCallHelper<Tail...>::template InferShape< \
return InferShapeCallHelper<Tail...>::template InferShape< \
in_idx, vec_in_idx + 1>(input_shapes, vec_input_shapes, pargs..., \
in_idx, vec_in_idx + 1, attr_idx>(input_shapes, vec_input_shapes, \
arg); \
attrs, pargs..., arg); \
} \
}
#define PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(attr_type) \
template <typename... Tail> \
struct InferShapeCallHelper<attr_type, Tail...> { \
template <int in_idx, int vec_in_idx, int attr_idx, \
typename... PreviousArgs> \
static Return InferShape( \
const std::vector<std::vector<int64_t>>& input_shapes, \
const std::vector<std::vector<std::vector<int64_t>>>& \
vec_input_shapes, \
const std::vector<boost::any>& attrs, const PreviousArgs&... pargs) { \
try { \
attr_type arg = boost::any_cast<attr_type>(attrs[attr_idx]); \
return InferShapeCallHelper<Tail...>::template InferShape< \
in_idx, vec_in_idx, attr_idx + 1>(input_shapes, vec_input_shapes, \
attrs, pargs..., arg); \
} catch (boost::bad_any_cast&) { \
PD_THROW( \
"Attribute cast error in custom operator InferShapeFn. " \
"Expected " #attr_type \
" value. InferShapeFn's attribute list must be exactly same as " \
"Forward " \
"KernelFn's attribute list except std::vector<int64_t> " \
"attribute."); \
} \
} \
} \
}
}
...
@@ -245,10 +275,10 @@ template <typename Return, typename... Args, Return (*impl_fn)(Args...)>
...
@@ -245,10 +275,10 @@ template <typename Return, typename... Args, Return (*impl_fn)(Args...)>
struct
InferShapeFuncImpl
<
Return
(
*
)(
Args
...),
impl_fn
>
{
struct
InferShapeFuncImpl
<
Return
(
*
)(
Args
...),
impl_fn
>
{
static
Return
InferShape
(
static
Return
InferShape
(
const
std
::
vector
<
std
::
vector
<
int64_t
>>&
input_shapes
,
const
std
::
vector
<
std
::
vector
<
int64_t
>>&
input_shapes
,
const
std
::
vector
<
std
::
vector
<
std
::
vector
<
int64_t
>>>&
vec_input_shapes
)
{
const
std
::
vector
<
std
::
vector
<
std
::
vector
<
int64_t
>>>&
vec_input_shapes
,
return
InferShapeCallHelper
<
Args
...,
TypeTag
<
int
>>::
template
InferShape
<
0
,
const
std
::
vector
<
boost
::
any
>&
attrs
)
{
0
>(
return
InferShapeCallHelper
<
Args
...,
TypeTag
<
int
>>::
template
InferShape
<
input_shapes
,
vec_input_shape
s
);
0
,
0
,
0
>(
input_shapes
,
vec_input_shapes
,
attr
s
);
}
}
private:
private:
...
@@ -265,14 +295,26 @@ struct InferShapeFuncImpl<Return (*)(Args...), impl_fn> {
...
@@ -265,14 +295,26 @@ struct InferShapeFuncImpl<Return (*)(Args...), impl_fn> {
PD_SPECIALIZE_InferShapeCallHelper_FOR_SHAPES
(
PD_SPECIALIZE_InferShapeCallHelper_FOR_SHAPES
(
std
::
vector
<
std
::
vector
<
int64_t
>>
);
std
::
vector
<
std
::
vector
<
int64_t
>>
);
PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR
(
const
bool
&
);
PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR
(
const
int
&
);
PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR
(
const
float
&
);
PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR
(
const
int64_t
&
);
PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR
(
const
std
::
string
&
);
PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR
(
const
std
::
vector
<
int
>&
);
PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR
(
const
std
::
vector
<
float
>&
);
PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR
(
const
std
::
vector
<
std
::
string
>&
);
// NOTE(chenweihang): InferShape can't support std::vector<int64_t> attr type,
// because the input type is std::vector<int64_t>, only can use one rule to
// parse std::vector<int64_t> parameter
// end: base template
// end: base template
template
<
typename
T
>
template
<
typename
T
>
struct
InferShapeCallHelper
<
TypeTag
<
T
>>
{
struct
InferShapeCallHelper
<
TypeTag
<
T
>>
{
template
<
int
in_idx
,
int
vec_in_idx
>
template
<
int
in_idx
,
int
vec_in_idx
,
int
attr_idx
>
static
Return
InferShape
(
static
Return
InferShape
(
const
std
::
vector
<
std
::
vector
<
int64_t
>>&
input_shapes
,
const
std
::
vector
<
std
::
vector
<
int64_t
>>&
input_shapes
,
const
std
::
vector
<
std
::
vector
<
std
::
vector
<
int64_t
>>>&
vec_input_shapes
,
const
std
::
vector
<
std
::
vector
<
std
::
vector
<
int64_t
>>>&
vec_input_shapes
,
const
Args
&
...
args
)
{
const
std
::
vector
<
boost
::
any
>&
attrs
,
const
Args
&
...
args
)
{
return
impl_fn
(
args
...);
return
impl_fn
(
args
...);
}
}
};
};
...
...
paddle/fluid/framework/custom_operator.cc
浏览文件 @
e429deb0
...
@@ -178,7 +178,7 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx,
...
@@ -178,7 +178,7 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx,
"Unsupported `%s` type value as custom attribute now. "
"Unsupported `%s` type value as custom attribute now. "
"Supported data types include `bool`, `int`, `float`, "
"Supported data types include `bool`, `int`, `float`, "
"`int64_t`, `std::string`, `std::vector<int>`, "
"`int64_t`, `std::string`, `std::vector<int>`, "
"`std::vector<float>`, `std::vector<int64_t>, "
"`std::vector<float>`, `std::vector<int64_t>
`
, "
"`std::vector<std::string>`, Please check whether "
"`std::vector<std::string>`, Please check whether "
"the attribute data type and data type string are matched."
,
"the attribute data type and data type string are matched."
,
attr_type_str
));
attr_type_str
));
...
@@ -327,7 +327,7 @@ class CustomOpMaker : public OpProtoAndCheckerMaker {
...
@@ -327,7 +327,7 @@ class CustomOpMaker : public OpProtoAndCheckerMaker {
"Unsupported `%s` type value as custom attribute now. "
"Unsupported `%s` type value as custom attribute now. "
"Supported data types include `bool`, `int`, `float`, "
"Supported data types include `bool`, `int`, `float`, "
"`int64_t`, `std::string`, `std::vector<int>`, "
"`int64_t`, `std::string`, `std::vector<int>`, "
"`std::vector<float>`, `std::vector<int64_t>, "
"`std::vector<float>`, `std::vector<int64_t>
`
, "
"`std::vector<std::string>`, Please check whether "
"`std::vector<std::string>`, Please check whether "
"the attribute data type and data type string are matched."
,
"the attribute data type and data type string are matched."
,
attr_type_str
));
attr_type_str
));
...
@@ -581,7 +581,7 @@ void RegisterOperatorWithMetaInfo(
...
@@ -581,7 +581,7 @@ void RegisterOperatorWithMetaInfo(
ctx
->
ShareDim
(
op_inputs
[
0
],
op_outputs
[
0
]);
ctx
->
ShareDim
(
op_inputs
[
0
],
op_outputs
[
0
]);
};
};
}
else
{
}
else
{
info
.
infer_shape_
=
[
op_inputs
,
op_outputs
,
info
.
infer_shape_
=
[
op_inputs
,
op_outputs
,
op_attrs
,
infer_shape_func
](
InferShapeContext
*
ctx
)
{
infer_shape_func
](
InferShapeContext
*
ctx
)
{
std
::
vector
<
std
::
vector
<
int64_t
>>
input_shapes
;
std
::
vector
<
std
::
vector
<
int64_t
>>
input_shapes
;
std
::
vector
<
std
::
vector
<
std
::
vector
<
int64_t
>>>
vec_input_shapes
;
std
::
vector
<
std
::
vector
<
std
::
vector
<
int64_t
>>>
vec_input_shapes
;
...
@@ -606,8 +606,50 @@ void RegisterOperatorWithMetaInfo(
...
@@ -606,8 +606,50 @@ void RegisterOperatorWithMetaInfo(
}
}
}
}
std
::
vector
<
boost
::
any
>
custom_attrs
;
for
(
auto
&
attr_str
:
op_attrs
)
{
auto
attr_name_and_type
=
detail
::
ParseAttrStr
(
attr_str
);
auto
attr_name
=
attr_name_and_type
[
0
];
auto
attr_type_str
=
attr_name_and_type
[
1
];
if
(
attr_type_str
==
"bool"
)
{
custom_attrs
.
emplace_back
(
ctx
->
Attrs
().
Get
<
bool
>
(
attr_name
));
}
else
if
(
attr_type_str
==
"int"
)
{
custom_attrs
.
emplace_back
(
ctx
->
Attrs
().
Get
<
int
>
(
attr_name
));
}
else
if
(
attr_type_str
==
"float"
)
{
custom_attrs
.
emplace_back
(
ctx
->
Attrs
().
Get
<
float
>
(
attr_name
));
}
else
if
(
attr_type_str
==
"int64_t"
)
{
custom_attrs
.
emplace_back
(
ctx
->
Attrs
().
Get
<
int64_t
>
(
attr_name
));
}
else
if
(
attr_type_str
==
"std::string"
)
{
custom_attrs
.
emplace_back
(
ctx
->
Attrs
().
Get
<
std
::
string
>
(
attr_name
));
}
else
if
(
attr_type_str
==
"std::vector<int>"
)
{
custom_attrs
.
emplace_back
(
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
attr_name
));
}
else
if
(
attr_type_str
==
"std::vector<float>"
)
{
custom_attrs
.
emplace_back
(
ctx
->
Attrs
().
Get
<
std
::
vector
<
float
>>
(
attr_name
));
}
else
if
(
attr_type_str
==
"std::vector<int64_t>"
)
{
// NOTE(chenweihang): InferShape can't support std::vector<int64_t>
// attr type, because the input type is std::vector<int64_t>, only
// can use one rule to parse std::vector<int64_t> parameter
continue
;
}
else
if
(
attr_type_str
==
"std::vector<std::string>"
)
{
custom_attrs
.
emplace_back
(
ctx
->
Attrs
().
Get
<
std
::
vector
<
std
::
string
>>
(
attr_name
));
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Unsupported `%s` type value as custom attribute now. "
"Supported data types include `bool`, `int`, `float`, "
"`int64_t`, `std::string`, `std::vector<int>`, "
"`std::vector<float>`, `std::vector<std::string>`, "
"Please check whether the attribute data type and "
"data type string are matched."
,
attr_type_str
));
}
}
VLOG
(
1
)
<<
"Custom Operator: InferShape - calc output ddim."
;
VLOG
(
1
)
<<
"Custom Operator: InferShape - calc output ddim."
;
auto
output_shapes
=
infer_shape_func
(
input_shapes
,
vec_input_shapes
);
auto
output_shapes
=
infer_shape_func
(
input_shapes
,
vec_input_shapes
,
custom_attrs
);
VLOG
(
1
)
<<
"Custom Operator: InferShape - set output ddim."
;
VLOG
(
1
)
<<
"Custom Operator: InferShape - set output ddim."
;
for
(
size_t
i
=
0
;
i
<
op_outputs
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
op_outputs
.
size
();
++
i
)
{
...
...
python/paddle/fluid/tests/custom_op/custom_concat_op.cc
浏览文件 @
e429deb0
...
@@ -144,3 +144,93 @@ PD_BUILD_GRAD_OP(custom_concat)
...
@@ -144,3 +144,93 @@ PD_BUILD_GRAD_OP(custom_concat)
.
Inputs
({
paddle
::
Vec
(
"X"
),
paddle
::
Grad
(
"Out"
),
"Axis"
})
.
Inputs
({
paddle
::
Vec
(
"X"
),
paddle
::
Grad
(
"Out"
),
"Axis"
})
.
Outputs
({
paddle
::
Grad
(
paddle
::
Vec
(
"X"
))})
.
Outputs
({
paddle
::
Grad
(
paddle
::
Vec
(
"X"
))})
.
SetKernelFn
(
PD_KERNEL
(
ConcatBackwardDynamicAxis
));
.
SetKernelFn
(
PD_KERNEL
(
ConcatBackwardDynamicAxis
));
std
::
vector
<
paddle
::
Tensor
>
ConcatForwardStaticAxis
(
const
std
::
vector
<
paddle
::
Tensor
>&
inputs
,
const
int64_t
&
axis
)
{
// check inputs
PD_CHECK
(
inputs
.
size
()
>=
1
,
"No Tensor need to be concat."
);
for
(
auto
&
t
:
inputs
)
{
CHECK_INPUT
(
t
);
}
// compute output shape
int64_t
rank
=
static_cast
<
int64_t
>
(
inputs
[
0
].
shape
().
size
());
auto
final_axis
=
ComputeAxis
(
axis
,
rank
);
std
::
vector
<
std
::
vector
<
int64_t
>>
in_shapes
;
for
(
auto
&
t
:
inputs
)
{
in_shapes
.
emplace_back
(
t
.
shape
());
}
auto
out_shape
=
ComputeOutShape
(
in_shapes
,
final_axis
);
// create output
auto
out
=
paddle
::
Tensor
(
paddle
::
PlaceType
::
kCPU
);
out
.
reshape
(
out_shape
);
// calc
PD_DISPATCH_FLOATING_AND_INTEGRAL_TYPES
(
inputs
[
0
].
type
(),
"ConcatCpuKernel"
,
([
&
]
{
ConcatCpuKernel
<
data_t
>
(
inputs
,
&
out
,
final_axis
);
}));
return
{
out
};
}
std
::
vector
<
paddle
::
Tensor
>
ConcatBackwardStaticAxis
(
const
std
::
vector
<
paddle
::
Tensor
>&
inputs
,
const
paddle
::
Tensor
&
grad_out
,
const
int64_t
&
axis
)
{
// check input
PD_CHECK
(
inputs
.
size
()
>=
1
,
"No Tensor need to be concat."
);
for
(
auto
&
t
:
inputs
)
{
CHECK_INPUT
(
t
);
}
CHECK_INPUT
(
grad_out
);
// compate axis
int64_t
rank
=
static_cast
<
int64_t
>
(
inputs
[
0
].
shape
().
size
());
auto
final_axis
=
ComputeAxis
(
axis
,
rank
);
// create outputs
std
::
vector
<
paddle
::
Tensor
>
grad_inputs
;
for
(
auto
&
t
:
inputs
)
{
auto
grad
=
paddle
::
Tensor
(
paddle
::
PlaceType
::
kCPU
);
grad
.
reshape
(
t
.
shape
());
grad_inputs
.
emplace_back
(
grad
);
}
// calc
PD_DISPATCH_FLOATING_AND_INTEGRAL_TYPES
(
grad_out
.
type
(),
"SplitCpuKernel"
,
([
&
]
{
SplitCpuKernel
<
data_t
>
(
grad_out
,
inputs
,
&
grad_inputs
,
final_axis
);
}));
return
grad_inputs
;
}
std
::
vector
<
std
::
vector
<
int64_t
>>
ConcatInferShapeStaticAxis
(
const
std
::
vector
<
std
::
vector
<
int64_t
>>&
input_shapes
,
const
int64_t
&
axis
)
{
int64_t
rank
=
static_cast
<
int64_t
>
(
input_shapes
[
0
].
size
());
auto
final_axis
=
ComputeAxis
(
axis
,
rank
);
auto
out_shape
=
ComputeOutShape
(
input_shapes
,
final_axis
);
return
{
out_shape
};
}
std
::
vector
<
paddle
::
DataType
>
ConcatInferDtypeStaticAxis
(
const
std
::
vector
<
paddle
::
DataType
>&
input_dtypes
)
{
return
{
input_dtypes
[
0
]};
}
PD_BUILD_OP
(
custom_concat_with_attr
)
.
Inputs
({
paddle
::
Vec
(
"X"
)})
.
Outputs
({
"Out"
})
.
Attrs
({
"axis: int64_t"
})
.
SetKernelFn
(
PD_KERNEL
(
ConcatForwardStaticAxis
))
.
SetInferShapeFn
(
PD_INFER_SHAPE
(
ConcatInferShapeStaticAxis
))
.
SetInferDtypeFn
(
PD_INFER_DTYPE
(
ConcatInferDtypeStaticAxis
));
PD_BUILD_GRAD_OP
(
custom_concat_with_attr
)
.
Inputs
({
paddle
::
Vec
(
"X"
),
paddle
::
Grad
(
"Out"
)})
.
Outputs
({
paddle
::
Grad
(
paddle
::
Vec
(
"X"
))})
.
Attrs
({
"axis: int64_t"
})
.
SetKernelFn
(
PD_KERNEL
(
ConcatBackwardStaticAxis
));
python/paddle/fluid/tests/custom_op/test_custom_concat.py
浏览文件 @
e429deb0
...
@@ -45,13 +45,15 @@ custom_ops = load(
...
@@ -45,13 +45,15 @@ custom_ops = load(
verbose
=
True
)
verbose
=
True
)
def
concat_dynamic
(
func
,
d
evice
,
dtype
,
np_inputs
,
axis_v
):
def
concat_dynamic
(
func
,
d
type
,
np_inputs
,
axis_v
,
with_attr
=
False
):
paddle
.
set_device
(
device
)
paddle
.
set_device
(
"cpu"
)
inputs
=
[
inputs
=
[
paddle
.
to_tensor
(
paddle
.
to_tensor
(
x
,
dtype
=
dtype
,
place
=
device
,
stop_gradient
=
False
)
x
,
dtype
=
dtype
,
stop_gradient
=
False
)
for
x
in
np_inputs
for
x
in
np_inputs
]
]
if
with_attr
:
axis
=
axis_v
else
:
axis
=
paddle
.
full
(
shape
=
[
1
],
dtype
=
'int64'
,
fill_value
=
axis_v
)
axis
=
paddle
.
full
(
shape
=
[
1
],
dtype
=
'int64'
,
fill_value
=
axis_v
)
out
=
func
(
inputs
,
axis
)
out
=
func
(
inputs
,
axis
)
out
.
stop_gradient
=
False
out
.
stop_gradient
=
False
...
@@ -60,13 +62,16 @@ def concat_dynamic(func, device, dtype, np_inputs, axis_v):
...
@@ -60,13 +62,16 @@ def concat_dynamic(func, device, dtype, np_inputs, axis_v):
return
out
.
numpy
(),
grad_inputs
return
out
.
numpy
(),
grad_inputs
def
concat_static
(
func
,
d
evice
,
dtype
,
np_inputs
,
axis_v
):
def
concat_static
(
func
,
d
type
,
np_inputs
,
axis_v
,
with_attr
=
False
):
paddle
.
enable_static
()
paddle
.
enable_static
()
paddle
.
set_device
(
device
)
paddle
.
set_device
(
"cpu"
)
with
static
.
scope_guard
(
static
.
Scope
()):
with
static
.
scope_guard
(
static
.
Scope
()):
with
static
.
program_guard
(
static
.
Program
()):
with
static
.
program_guard
(
static
.
Program
()):
x1
=
static
.
data
(
name
=
"x1"
,
shape
=
[
2
,
3
],
dtype
=
dtype
)
x1
=
static
.
data
(
name
=
"x1"
,
shape
=
[
2
,
3
],
dtype
=
dtype
)
x2
=
static
.
data
(
name
=
"x2"
,
shape
=
[
2
,
3
],
dtype
=
dtype
)
x2
=
static
.
data
(
name
=
"x2"
,
shape
=
[
2
,
3
],
dtype
=
dtype
)
if
with_attr
:
axis
=
axis_v
else
:
axis
=
paddle
.
full
(
shape
=
[
1
],
dtype
=
'int64'
,
fill_value
=
axis_v
)
axis
=
paddle
.
full
(
shape
=
[
1
],
dtype
=
'int64'
,
fill_value
=
axis_v
)
x1
.
stop_gradient
=
False
x1
.
stop_gradient
=
False
x2
.
stop_gradient
=
False
x2
.
stop_gradient
=
False
...
@@ -78,13 +83,20 @@ def concat_static(func, device, dtype, np_inputs, axis_v):
...
@@ -78,13 +83,20 @@ def concat_static(func, device, dtype, np_inputs, axis_v):
exe
=
static
.
Executor
()
exe
=
static
.
Executor
()
exe
.
run
(
static
.
default_startup_program
())
exe
.
run
(
static
.
default_startup_program
())
out_v
,
x1_grad_v
,
x2_grad_v
=
exe
.
run
(
if
with_attr
:
static
.
default_main_program
(),
feed_dict
=
{
feed
=
{
"x1"
:
np_inputs
[
0
].
astype
(
dtype
),
"x2"
:
np_inputs
[
1
].
astype
(
dtype
)
}
else
:
feed_dict
=
{
"x1"
:
np_inputs
[
0
].
astype
(
dtype
),
"x1"
:
np_inputs
[
0
].
astype
(
dtype
),
"x2"
:
np_inputs
[
1
].
astype
(
dtype
),
"x2"
:
np_inputs
[
1
].
astype
(
dtype
),
"axis"
:
axis
"axis"
:
axis
},
}
out_v
,
x1_grad_v
,
x2_grad_v
=
exe
.
run
(
static
.
default_main_program
(),
feed
=
feed_dict
,
fetch_list
=
[
out
.
name
,
x1
.
name
+
"@GRAD"
,
x2
.
name
+
"@GRAD"
])
fetch_list
=
[
out
.
name
,
x1
.
name
+
"@GRAD"
,
x2
.
name
+
"@GRAD"
])
paddle
.
disable_static
()
paddle
.
disable_static
()
return
out_v
,
x1_grad_v
,
x2_grad_v
return
out_v
,
x1_grad_v
,
x2_grad_v
...
@@ -93,55 +105,67 @@ def concat_static(func, device, dtype, np_inputs, axis_v):
...
@@ -93,55 +105,67 @@ def concat_static(func, device, dtype, np_inputs, axis_v):
class
TestCustomConcatDynamicAxisJit
(
unittest
.
TestCase
):
class
TestCustomConcatDynamicAxisJit
(
unittest
.
TestCase
):
def
setUp
(
self
):
def
setUp
(
self
):
self
.
dtypes
=
[
'float32'
,
'float64'
,
'int32'
,
'int64'
]
self
.
dtypes
=
[
'float32'
,
'float64'
,
'int32'
,
'int64'
]
self
.
devices
=
[
'cpu'
]
self
.
np_inputs
=
[
self
.
np_inputs
=
[
np
.
array
([[
1
,
2
,
3
],
[
4
,
5
,
6
]]),
np
.
array
([[
1
,
2
,
3
],
[
4
,
5
,
6
]]),
np
.
array
([[
11
,
12
,
13
],
[
14
,
15
,
16
]])
np
.
array
([[
11
,
12
,
13
],
[
14
,
15
,
16
]])
]
]
self
.
axises
=
[
0
,
1
]
self
.
axises
=
[
0
,
1
]
def
check_output
(
self
,
out
,
pd_out
,
name
):
self
.
assertTrue
(
np
.
array_equal
(
out
,
pd_out
),
"custom op {}: {},
\n
paddle api {}: {}"
.
format
(
name
,
out
,
name
,
pd_out
))
def
test_dynamic
(
self
):
def
test_dynamic
(
self
):
for
device
in
self
.
devices
:
for
dtype
in
self
.
dtypes
:
for
dtype
in
self
.
dtypes
:
for
axis
in
self
.
axises
:
for
axis
in
self
.
axises
:
out
,
grad_inputs
=
concat_dynamic
(
custom_ops
.
custom_concat
,
out
,
grad_inputs
=
concat_dynamic
(
custom_ops
.
custom_concat
,
device
,
dtype
,
dtype
,
self
.
np_inputs
,
axis
)
pd_out
,
pd_grad_inputs
=
concat_dynamic
(
paddle
.
concat
,
dtype
,
self
.
np_inputs
,
axis
)
self
.
np_inputs
,
axis
)
pd_out
,
pd_grad_inputs
=
concat_dynamic
(
paddle
.
concat
,
device
,
dtype
,
self
.
np_inputs
,
axis
)
self
.
assertTrue
(
self
.
check_output
(
out
,
pd_out
,
"out"
)
np
.
array_equal
(
out
,
pd_out
),
"custom op out: {},
\n
paddle api out: {}"
.
format
(
out
,
pd_out
))
for
x_grad
,
pd_x_grad
in
zip
(
grad_inputs
,
pd_grad_inputs
):
for
x_grad
,
pd_x_grad
in
zip
(
grad_inputs
,
pd_grad_inputs
):
self
.
assertTrue
(
self
.
check_output
(
x_grad
,
pd_x_grad
,
"x_grad"
)
np
.
array_equal
(
x_grad
,
pd_x_grad
),
"custom op x grad: {},
\n
paddle api x grad: {}"
.
format
(
x_grad
,
pd_x_grad
))
def
test_static
(
self
):
def
test_static
(
self
):
for
device
in
self
.
devices
:
for
dtype
in
self
.
dtypes
:
for
dtype
in
self
.
dtypes
:
for
axis
in
self
.
axises
:
for
axis
in
self
.
axises
:
out
,
x1_grad
,
x2_grad
=
concat_static
(
out
,
x1_grad
,
x2_grad
=
concat_static
(
custom_ops
.
custom_concat
,
device
,
dtype
,
self
.
np_inputs
,
custom_ops
.
custom_concat
,
dtype
,
self
.
np_inputs
,
axis
)
axis
)
pd_out
,
pd_x1_grad
,
pd_x2_grad
=
concat_static
(
pd_out
,
pd_x1_grad
,
pd_x2_grad
=
concat_static
(
paddle
.
concat
,
device
,
dtype
,
self
.
np_inputs
,
axis
)
paddle
.
concat
,
dtype
,
self
.
np_inputs
,
axis
)
self
.
assertTrue
(
self
.
check_output
(
out
,
pd_out
,
"out"
)
np
.
array_equal
(
out
,
pd_out
),
self
.
check_output
(
x1_grad
,
pd_x1_grad
,
"x1_grad"
)
"custom op out: {},
\n
paddle api out: {}"
.
format
(
self
.
check_output
(
x2_grad
,
pd_x2_grad
,
"x2_grad"
)
out
,
pd_out
))
self
.
assertTrue
(
def
test_dynamic_with_attr
(
self
):
np
.
array_equal
(
x1_grad
,
pd_x1_grad
),
for
dtype
in
self
.
dtypes
:
"custom op x1_grad: {},
\n
paddle api x1_grad: {}"
.
for
axis
in
self
.
axises
:
format
(
x1_grad
,
pd_x1_grad
))
out
,
grad_inputs
=
concat_dynamic
(
self
.
assertTrue
(
custom_ops
.
custom_concat_with_attr
,
dtype
,
self
.
np_inputs
,
np
.
array_equal
(
x2_grad
,
pd_x2_grad
),
axis
,
True
)
"custom op x2_grad: {},
\n
paddle api x2_grad: {}"
.
pd_out
,
pd_grad_inputs
=
concat_dynamic
(
format
(
x2_grad
,
pd_x2_grad
))
paddle
.
concat
,
dtype
,
self
.
np_inputs
,
axis
,
True
)
self
.
check_output
(
out
,
pd_out
,
"out"
)
for
x_grad
,
pd_x_grad
in
zip
(
grad_inputs
,
pd_grad_inputs
):
self
.
check_output
(
x_grad
,
pd_x_grad
,
"x_grad"
)
def
test_static_with_attr
(
self
):
for
dtype
in
self
.
dtypes
:
for
axis
in
self
.
axises
:
out
,
x1_grad
,
x2_grad
=
concat_static
(
custom_ops
.
custom_concat_with_attr
,
dtype
,
self
.
np_inputs
,
axis
,
True
)
pd_out
,
pd_x1_grad
,
pd_x2_grad
=
concat_static
(
paddle
.
concat
,
dtype
,
self
.
np_inputs
,
axis
,
True
)
self
.
check_output
(
out
,
pd_out
,
"out"
)
self
.
check_output
(
x1_grad
,
pd_x1_grad
,
"x1_grad"
)
self
.
check_output
(
x2_grad
,
pd_x2_grad
,
"x2_grad"
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录