diff --git a/paddle/pten/api/ext/op_meta_info.h b/paddle/pten/api/ext/op_meta_info.h index 351e88b57bd8b8dd767ec266018604e568c8d5ba..66336c7466bd7194974d037adcc1e82779e5d68a 100644 --- a/paddle/pten/api/ext/op_meta_info.h +++ b/paddle/pten/api/ext/op_meta_info.h @@ -164,10 +164,10 @@ struct KernelFuncImpl { } }; - PD_SPECIALIZE_ComputeCallHelper(const bool&); - PD_SPECIALIZE_ComputeCallHelper(const int&); - PD_SPECIALIZE_ComputeCallHelper(const float&); - PD_SPECIALIZE_ComputeCallHelper(const int64_t&); + PD_SPECIALIZE_ComputeCallHelper(bool); + PD_SPECIALIZE_ComputeCallHelper(int); + PD_SPECIALIZE_ComputeCallHelper(float); + PD_SPECIALIZE_ComputeCallHelper(int64_t); PD_SPECIALIZE_ComputeCallHelper(const std::string&); PD_SPECIALIZE_ComputeCallHelper(const std::vector&); PD_SPECIALIZE_ComputeCallHelper(const std::vector&); @@ -181,15 +181,10 @@ struct KernelFuncImpl { // NOTE(chenweihang): Used to be compatible with the 2.0.1 released // interface, and will be deprecated in the future - PD_SPECIALIZE_ComputeCallHelper(bool); - PD_SPECIALIZE_ComputeCallHelper(int); - PD_SPECIALIZE_ComputeCallHelper(float); - PD_SPECIALIZE_ComputeCallHelper(int64_t); - PD_SPECIALIZE_ComputeCallHelper(std::string); - PD_SPECIALIZE_ComputeCallHelper(std::vector); - PD_SPECIALIZE_ComputeCallHelper(std::vector); - PD_SPECIALIZE_ComputeCallHelper(std::vector); - PD_SPECIALIZE_ComputeCallHelper(std::vector); + PD_SPECIALIZE_ComputeCallHelper(const bool&); + PD_SPECIALIZE_ComputeCallHelper(const int&); + PD_SPECIALIZE_ComputeCallHelper(const float&); + PD_SPECIALIZE_ComputeCallHelper(const int64_t&); // end: base template template @@ -315,10 +310,10 @@ struct InferShapeFuncImpl { PD_SPECIALIZE_InferShapeCallHelper_FOR_SHAPES( std::vector>); - PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(const bool&); - PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(const int&); - PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(const float&); - PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(const int64_t&); + PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(bool); + PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(int); + PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(float); + PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(int64_t); PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(const std::string&); PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(const std::vector&); PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(const std::vector&); @@ -327,6 +322,13 @@ struct InferShapeFuncImpl { // because the input type is std::vector, only can use one rule to // parse std::vector parameter + // NOTE(chenweihang): Used to be compatible with the 2.0.1 released + // interface, and will be deprecated in the future + PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(const bool&); + PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(const int&); + PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(const float&); + PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(const int64_t&); + // end: base template template struct InferShapeCallHelper> { diff --git a/python/paddle/fluid/tests/custom_op/attr_test_op.cc b/python/paddle/fluid/tests/custom_op/attr_test_op.cc index 297acd602086541cbdfc274a9b248b08b5c94758..0b9f4284aa70be7362f39296db7735c98c117f24 100644 --- a/python/paddle/fluid/tests/custom_op/attr_test_op.cc +++ b/python/paddle/fluid/tests/custom_op/attr_test_op.cc @@ -127,11 +127,11 @@ std::vector AttrTestForward( int int_attr, float float_attr, int64_t int64_attr, - std::string str_attr, - std::vector int_vec_attr, - std::vector float_vec_attr, - std::vector int64_vec_attr, - std::vector str_vec_attr) { + const std::string& str_attr, + const std::vector& int_vec_attr, + const std::vector& float_vec_attr, + const std::vector& int64_vec_attr, + const std::vector& str_vec_attr) { auto out = paddle::Tensor(paddle::PlaceType::kCPU, x.shape()); PD_DISPATCH_FLOATING_TYPES( @@ -154,12 +154,25 @@ std::vector AttrTestForward( return {out}; } +std::vector> AttrTestInferShape( + const std::vector& x_shape, + bool bool_attr, + int int_attr, + float float_attr, + int64_t int64_attr, + const std::string& str_attr, + const std::vector& int_vec_attr, + const std::vector& float_vec_attr, + const std::vector& str_vec_attr) { + return {x_shape}; +} + // The attrs of backward op must be the subset of attrs of forward op std::vector AttrTestBackward( const paddle::Tensor& grad_out, int int_attr, - std::vector float_vec_attr, - std::vector str_vec_attr) { + const std::vector& float_vec_attr, + const std::vector& str_vec_attr) { auto grad_x = paddle::Tensor(paddle::PlaceType::kCPU, grad_out.shape()); PD_DISPATCH_FLOATING_TYPES(grad_out.type(), "assign_cpu_kernel", ([&] { @@ -207,6 +220,19 @@ std::vector ConstAttrTestForward( return {out}; } +std::vector> ConstAttrTestInferShape( + const std::vector& x_shape, + const bool& bool_attr, + const int& int_attr, + const float& float_attr, + const int64_t& int64_attr, + const std::string& str_attr, + const std::vector& int_vec_attr, + const std::vector& float_vec_attr, + const std::vector& str_vec_attr) { + return {x_shape}; +} + // The attrs of backward op must be the subset of attrs of forward op std::vector ConstAttrTestBackward( const paddle::Tensor& grad_out, @@ -239,7 +265,8 @@ PD_BUILD_OP(attr_test) "float_vec_attr: std::vector", "int64_vec_attr: std::vector", "str_vec_attr: std::vector"}) - .SetKernelFn(PD_KERNEL(AttrTestForward)); + .SetKernelFn(PD_KERNEL(AttrTestForward)) + .SetInferShapeFn(PD_INFER_SHAPE(AttrTestInferShape)); PD_BUILD_GRAD_OP(attr_test) .Inputs({paddle::Grad("Out")}) @@ -261,7 +288,8 @@ PD_BUILD_OP(const_attr_test) "float_vec_attr: std::vector", "int64_vec_attr: std::vector", "str_vec_attr: std::vector"}) - .SetKernelFn(PD_KERNEL(AttrTestForward)); + .SetKernelFn(PD_KERNEL(ConstAttrTestForward)) + .SetInferShapeFn(PD_INFER_SHAPE(ConstAttrTestInferShape)); PD_BUILD_GRAD_OP(const_attr_test) .Inputs({paddle::Grad("Out")}) @@ -269,4 +297,4 @@ PD_BUILD_GRAD_OP(const_attr_test) .Attrs({"int_attr: int", "float_vec_attr: std::vector", "str_vec_attr: std::vector"}) - .SetKernelFn(PD_KERNEL(AttrTestBackward)); + .SetKernelFn(PD_KERNEL(ConstAttrTestBackward));