Unverified commit 3a339cc0 — authored by Chen Weihang, committed via GitHub

fix custom op infershape error (#38045)

Parent commit: d3569c7e
......@@ -164,10 +164,10 @@ struct KernelFuncImpl<Return (*)(Args...), impl_fn> {
}
};
PD_SPECIALIZE_ComputeCallHelper(const bool&);
PD_SPECIALIZE_ComputeCallHelper(const int&);
PD_SPECIALIZE_ComputeCallHelper(const float&);
PD_SPECIALIZE_ComputeCallHelper(const int64_t&);
PD_SPECIALIZE_ComputeCallHelper(bool);
PD_SPECIALIZE_ComputeCallHelper(int);
PD_SPECIALIZE_ComputeCallHelper(float);
PD_SPECIALIZE_ComputeCallHelper(int64_t);
PD_SPECIALIZE_ComputeCallHelper(const std::string&);
PD_SPECIALIZE_ComputeCallHelper(const std::vector<int>&);
PD_SPECIALIZE_ComputeCallHelper(const std::vector<float>&);
......@@ -181,15 +181,10 @@ struct KernelFuncImpl<Return (*)(Args...), impl_fn> {
// NOTE(chenweihang): Used to be compatible with the 2.0.1 released
// interface, and will be deprecated in the future
PD_SPECIALIZE_ComputeCallHelper(bool);
PD_SPECIALIZE_ComputeCallHelper(int);
PD_SPECIALIZE_ComputeCallHelper(float);
PD_SPECIALIZE_ComputeCallHelper(int64_t);
PD_SPECIALIZE_ComputeCallHelper(std::string);
PD_SPECIALIZE_ComputeCallHelper(std::vector<int>);
PD_SPECIALIZE_ComputeCallHelper(std::vector<float>);
PD_SPECIALIZE_ComputeCallHelper(std::vector<int64_t>);
PD_SPECIALIZE_ComputeCallHelper(std::vector<std::string>);
PD_SPECIALIZE_ComputeCallHelper(const bool&);
PD_SPECIALIZE_ComputeCallHelper(const int&);
PD_SPECIALIZE_ComputeCallHelper(const float&);
PD_SPECIALIZE_ComputeCallHelper(const int64_t&);
// end: base template
template <typename T>
......@@ -315,10 +310,10 @@ struct InferShapeFuncImpl<Return (*)(Args...), impl_fn> {
PD_SPECIALIZE_InferShapeCallHelper_FOR_SHAPES(
std::vector<std::vector<int64_t>>);
PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(const bool&);
PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(const int&);
PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(const float&);
PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(const int64_t&);
PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(bool);
PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(int);
PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(float);
PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(int64_t);
PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(const std::string&);
PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(const std::vector<int>&);
PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(const std::vector<float>&);
......@@ -327,6 +322,13 @@ struct InferShapeFuncImpl<Return (*)(Args...), impl_fn> {
// because the input type is std::vector<int64_t>, only can use one rule to
// parse std::vector<int64_t> parameter
// NOTE(chenweihang): Used to be compatible with the 2.0.1 released
// interface, and will be deprecated in the future
PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(const bool&);
PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(const int&);
PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(const float&);
PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(const int64_t&);
// end: base template
template <typename T>
struct InferShapeCallHelper<TypeTag<T>> {
......
......@@ -127,11 +127,11 @@ std::vector<paddle::Tensor> AttrTestForward(
int int_attr,
float float_attr,
int64_t int64_attr,
std::string str_attr,
std::vector<int> int_vec_attr,
std::vector<float> float_vec_attr,
std::vector<int64_t> int64_vec_attr,
std::vector<std::string> str_vec_attr) {
const std::string& str_attr,
const std::vector<int>& int_vec_attr,
const std::vector<float>& float_vec_attr,
const std::vector<int64_t>& int64_vec_attr,
const std::vector<std::string>& str_vec_attr) {
auto out = paddle::Tensor(paddle::PlaceType::kCPU, x.shape());
PD_DISPATCH_FLOATING_TYPES(
......@@ -154,12 +154,25 @@ std::vector<paddle::Tensor> AttrTestForward(
return {out};
}
// InferShape function for the attr_test op: the single output tensor keeps
// exactly the shape of input x. The attribute parameters are present only so
// the custom-op framework can bind this function against the op's registered
// attrs; none of them affect the inferred shape.
// NOTE: no std::vector<int64_t> attr appears here — that type is reserved for
// parsing input shapes, so only one parsing rule can exist for it (see the
// framework-side comment in the same change).
std::vector<std::vector<int64_t>> AttrTestInferShape(
    const std::vector<int64_t>& x_shape,
    bool bool_attr,
    int int_attr,
    float float_attr,
    int64_t int64_attr,
    const std::string& str_attr,
    const std::vector<int>& int_vec_attr,
    const std::vector<float>& float_vec_attr,
    const std::vector<std::string>& str_vec_attr) {
  // One output, whose shape mirrors the input shape unchanged.
  std::vector<std::vector<int64_t>> output_shapes;
  output_shapes.push_back(x_shape);
  return output_shapes;
}
// The attrs of backward op must be the subset of attrs of forward op
std::vector<paddle::Tensor> AttrTestBackward(
const paddle::Tensor& grad_out,
int int_attr,
std::vector<float> float_vec_attr,
std::vector<std::string> str_vec_attr) {
const std::vector<float>& float_vec_attr,
const std::vector<std::string>& str_vec_attr) {
auto grad_x = paddle::Tensor(paddle::PlaceType::kCPU, grad_out.shape());
PD_DISPATCH_FLOATING_TYPES(grad_out.type(), "assign_cpu_kernel", ([&] {
......@@ -207,6 +220,19 @@ std::vector<paddle::Tensor> ConstAttrTestForward(
return {out};
}
// InferShape function for the const_attr_test op, which uses the legacy
// 2.0.1-era const-reference attribute signature. The output shape is simply
// the input shape; the attribute arguments exist only so the framework can
// match this function to the op's registered attrs.
std::vector<std::vector<int64_t>> ConstAttrTestInferShape(
    const std::vector<int64_t>& x_shape,
    const bool& bool_attr,
    const int& int_attr,
    const float& float_attr,
    const int64_t& int64_attr,
    const std::string& str_attr,
    const std::vector<int>& int_vec_attr,
    const std::vector<float>& float_vec_attr,
    const std::vector<std::string>& str_vec_attr) {
  // Pass the input shape straight through as the single output shape.
  std::vector<std::vector<int64_t>> shapes;
  shapes.emplace_back(x_shape);
  return shapes;
}
// The attrs of backward op must be the subset of attrs of forward op
std::vector<paddle::Tensor> ConstAttrTestBackward(
const paddle::Tensor& grad_out,
......@@ -239,7 +265,8 @@ PD_BUILD_OP(attr_test)
"float_vec_attr: std::vector<float>",
"int64_vec_attr: std::vector<int64_t>",
"str_vec_attr: std::vector<std::string>"})
.SetKernelFn(PD_KERNEL(AttrTestForward));
.SetKernelFn(PD_KERNEL(AttrTestForward))
.SetInferShapeFn(PD_INFER_SHAPE(AttrTestInferShape));
PD_BUILD_GRAD_OP(attr_test)
.Inputs({paddle::Grad("Out")})
......@@ -261,7 +288,8 @@ PD_BUILD_OP(const_attr_test)
"float_vec_attr: std::vector<float>",
"int64_vec_attr: std::vector<int64_t>",
"str_vec_attr: std::vector<std::string>"})
.SetKernelFn(PD_KERNEL(AttrTestForward));
.SetKernelFn(PD_KERNEL(ConstAttrTestForward))
.SetInferShapeFn(PD_INFER_SHAPE(ConstAttrTestInferShape));
PD_BUILD_GRAD_OP(const_attr_test)
.Inputs({paddle::Grad("Out")})
......@@ -269,4 +297,4 @@ PD_BUILD_GRAD_OP(const_attr_test)
.Attrs({"int_attr: int",
"float_vec_attr: std::vector<float>",
"str_vec_attr: std::vector<std::string>"})
.SetKernelFn(PD_KERNEL(AttrTestBackward));
.SetKernelFn(PD_KERNEL(ConstAttrTestBackward));
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Register to post a comment.