diff --git a/paddle/fluid/extension/include/ext_op_meta_info.h b/paddle/fluid/extension/include/ext_op_meta_info.h
index bad1d6ad9f06a39552d4eaad96aad0e01c40352d..c400164c7543da9878d0fb51a6f239dfaff5beff 100644
--- a/paddle/fluid/extension/include/ext_op_meta_info.h
+++ b/paddle/fluid/extension/include/ext_op_meta_info.h
@@ -204,38 +204,68 @@ struct KernelFuncImpl {
 // Record Op infershape core function
 using InferShapeFunc = std::vector<std::vector<int64_t>> (*)(
     const std::vector<std::vector<int64_t>>& input_shapes,
-    const std::vector<std::vector<std::vector<int64_t>>>& vec_input_shapes);
+    const std::vector<std::vector<std::vector<int64_t>>>& vec_input_shapes,
+    const std::vector<boost::any>& attrs);
 
-#define PD_SPECIALIZE_InferShapeCallHelper_FOR_SHAPE(input_type) \
-  template <typename... Tail> \
-  struct InferShapeCallHelper<input_type, Tail...> { \
-    template <int in_idx, int vec_in_idx, typename... PreviousArgs> \
-    static Return InferShape( \
-        const std::vector<std::vector<int64_t>>& input_shapes, \
-        const std::vector<std::vector<std::vector<int64_t>>>& \
-            vec_input_shapes, \
-        const PreviousArgs&... pargs) { \
-      input_type arg = input_shapes[in_idx]; \
-      return InferShapeCallHelper<Tail...>::template InferShape<in_idx + 1, \
-                                                                vec_in_idx>( \
-          input_shapes, vec_input_shapes, pargs..., arg); \
-    } \
-  }
+#define PD_SPECIALIZE_InferShapeCallHelper_FOR_SHAPE(input_type) \
+  template <typename... Tail> \
+  struct InferShapeCallHelper<input_type, Tail...> { \
+    template <int in_idx, int vec_in_idx, int attr_idx, \
+              typename... PreviousArgs> \
+    static Return InferShape( \
+        const std::vector<std::vector<int64_t>>& input_shapes, \
+        const std::vector<std::vector<std::vector<int64_t>>>& \
+            vec_input_shapes, \
+        const std::vector<boost::any>& attrs, const PreviousArgs&... pargs) { \
+      input_type arg = input_shapes[in_idx]; \
+      return InferShapeCallHelper<Tail...>::template InferShape< \
+          in_idx + 1, vec_in_idx, attr_idx>(input_shapes, vec_input_shapes, \
+                                            attrs, pargs..., arg); \
+    } \
+  }
 
-#define PD_SPECIALIZE_InferShapeCallHelper_FOR_SHAPES(input_type) \
-  template <typename... Tail> \
-  struct InferShapeCallHelper<input_type, Tail...> { \
-    template <int in_idx, int vec_in_idx, typename... PreviousArgs> \
-    static Return InferShape( \
-        const std::vector<std::vector<int64_t>>& input_shapes, \
-        const std::vector<std::vector<std::vector<int64_t>>>& \
-            vec_input_shapes, \
-        const PreviousArgs&... pargs) { \
-      input_type arg = vec_input_shapes[vec_in_idx]; \
-      return InferShapeCallHelper<Tail...>::template InferShape< \
-          in_idx, vec_in_idx + 1>(input_shapes, vec_input_shapes, pargs..., \
-                                  arg); \
-    } \
-  }
+#define PD_SPECIALIZE_InferShapeCallHelper_FOR_SHAPES(input_type) \
+  template <typename... Tail> \
+  struct InferShapeCallHelper<input_type, Tail...> { \
+    template <int in_idx, int vec_in_idx, int attr_idx, \
+              typename... PreviousArgs> \
+    static Return InferShape( \
+        const std::vector<std::vector<int64_t>>& input_shapes, \
+        const std::vector<std::vector<std::vector<int64_t>>>& \
+            vec_input_shapes, \
+        const std::vector<boost::any>& attrs, const PreviousArgs&... pargs) { \
+      input_type arg = vec_input_shapes[vec_in_idx]; \
+      return InferShapeCallHelper<Tail...>::template InferShape< \
+          in_idx, vec_in_idx + 1, attr_idx>(input_shapes, vec_input_shapes, \
+                                            attrs, pargs..., arg); \
+    } \
+  }
+
+#define PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(attr_type) \
+  template <typename... Tail> \
+  struct InferShapeCallHelper<attr_type, Tail...> { \
+    template <int in_idx, int vec_in_idx, int attr_idx, \
+              typename... PreviousArgs> \
+    static Return InferShape( \
+        const std::vector<std::vector<int64_t>>& input_shapes, \
+        const std::vector<std::vector<std::vector<int64_t>>>& \
+            vec_input_shapes, \
+        const std::vector<boost::any>& attrs, const PreviousArgs&... pargs) { \
+      try { \
+        attr_type arg = boost::any_cast<attr_type>(attrs[attr_idx]); \
+        return InferShapeCallHelper<Tail...>::template InferShape< \
+            in_idx, vec_in_idx, attr_idx + 1>(input_shapes, vec_input_shapes, \
+            attrs, pargs..., arg); \
+      } catch (boost::bad_any_cast&) { \
+        PD_THROW( \
+            "Attribute cast error in custom operator InferShapeFn. " \
+            "Expected " #attr_type \
+            " value. InferShapeFn's attribute list must be exactly same as " \
+            "Forward " \
+            "KernelFn's attribute list except std::vector<int64_t> " \
+            "attribute."); \
+      } \
+    } \
+  }
 
 template <typename F, F f>
@@ -245,10 +275,10 @@ template <typename Return, typename... Args, Return (*impl_fn)(Args...)>
 struct InferShapeFuncImpl<Return (*)(Args...), impl_fn> {
   static Return InferShape(
       const std::vector<std::vector<int64_t>>& input_shapes,
-      const std::vector<std::vector<std::vector<int64_t>>>& vec_input_shapes) {
-    return InferShapeCallHelper<Args..., TypeTag<int>>::template InferShape<0,
-                                                                            0>(
-        input_shapes, vec_input_shapes);
+      const std::vector<std::vector<std::vector<int64_t>>>& vec_input_shapes,
+      const std::vector<boost::any>& attrs) {
+    return InferShapeCallHelper<Args..., TypeTag<int>>::template InferShape<
+        0, 0, 0>(input_shapes, vec_input_shapes, attrs);
   }
 
  private:
@@ -265,14 +295,26 @@ struct InferShapeFuncImpl<Return (*)(Args...), impl_fn> {
   PD_SPECIALIZE_InferShapeCallHelper_FOR_SHAPES(
       std::vector<std::vector<int64_t>>);
 
+  PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(const bool&);
+  PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(const int&);
+  PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(const float&);
+  PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(const int64_t&);
+  PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(const std::string&);
+  PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(const std::vector<int>&);
+  PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(const std::vector<float>&);
+  PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(const std::vector<std::string>&);
+  // NOTE(chenweihang): InferShape can't support std::vector<int64_t> attr type,
+  // because the input type is std::vector<int64_t>, only can use one rule to
+  // parse std::vector<int64_t> parameter
+
   // end: base template
   template <typename T>
   struct InferShapeCallHelper<TypeTag<T>> {
-    template <int in_idx, int vec_in_idx>
+    template <int in_idx, int vec_in_idx, int attr_idx>
     static Return InferShape(
         const std::vector<std::vector<int64_t>>& input_shapes,
         const std::vector<std::vector<std::vector<int64_t>>>& vec_input_shapes,
-        const Args&... args) {
+        const std::vector<boost::any>& attrs, const Args&... args) {
       return impl_fn(args...);
     }
   };
diff --git a/paddle/fluid/framework/custom_operator.cc b/paddle/fluid/framework/custom_operator.cc
index 69a9be603e677d9abc52bca58bfb357285e3613f..1ebb8998c854ebe662a51be7d833a3312c1d6876 100644
--- a/paddle/fluid/framework/custom_operator.cc
+++ b/paddle/fluid/framework/custom_operator.cc
@@ -178,7 +178,7 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx,
             "Unsupported `%s` type value as custom attribute now. "
             "Supported data types include `bool`, `int`, `float`, "
             "`int64_t`, `std::string`, `std::vector<int>`, "
-            "`std::vector<float>`, `std::vector<int64_t>, "
+            "`std::vector<float>`, `std::vector<int64_t>`, "
             "`std::vector<std::string>`, Please check whether "
             "the attribute data type and data type string are matched.",
             attr_type_str));
@@ -327,7 +327,7 @@ class CustomOpMaker : public OpProtoAndCheckerMaker {
             "Unsupported `%s` type value as custom attribute now. "
             "Supported data types include `bool`, `int`, `float`, "
             "`int64_t`, `std::string`, `std::vector<int>`, "
-            "`std::vector<float>`, `std::vector<int64_t>, "
+            "`std::vector<float>`, `std::vector<int64_t>`, "
            "`std::vector<std::string>`, Please check whether "
             "the attribute data type and data type string are matched.",
             attr_type_str));
@@ -581,7 +581,7 @@ void RegisterOperatorWithMetaInfo(
       ctx->ShareDim(op_inputs[0], op_outputs[0]);
     };
   } else {
-    info.infer_shape_ = [op_inputs, op_outputs,
+    info.infer_shape_ = [op_inputs, op_outputs, op_attrs,
                         infer_shape_func](InferShapeContext* ctx) {
       std::vector<std::vector<int64_t>> input_shapes;
      std::vector<std::vector<std::vector<int64_t>>> vec_input_shapes;
@@ -606,8 +606,50 @@ void RegisterOperatorWithMetaInfo(
         }
       }
 
+      std::vector<boost::any> custom_attrs;
+      for (auto& attr_str : op_attrs) {
+        auto attr_name_and_type = detail::ParseAttrStr(attr_str);
+        auto attr_name = attr_name_and_type[0];
+        auto attr_type_str = attr_name_and_type[1];
+        if (attr_type_str == "bool") {
+          custom_attrs.emplace_back(ctx->Attrs().Get<bool>(attr_name));
+        } else if (attr_type_str == "int") {
+          custom_attrs.emplace_back(ctx->Attrs().Get<int>(attr_name));
+        } else if (attr_type_str == "float") {
+          custom_attrs.emplace_back(ctx->Attrs().Get<float>(attr_name));
+        } else if (attr_type_str == "int64_t") {
+          custom_attrs.emplace_back(ctx->Attrs().Get<int64_t>(attr_name));
+        } else if (attr_type_str == "std::string") {
+          custom_attrs.emplace_back(ctx->Attrs().Get<std::string>(attr_name));
+        } else if (attr_type_str == "std::vector<int>") {
+          custom_attrs.emplace_back(
+              ctx->Attrs().Get<std::vector<int>>(attr_name));
+        } else if (attr_type_str == "std::vector<float>") {
+          custom_attrs.emplace_back(
+              ctx->Attrs().Get<std::vector<float>>(attr_name));
+        } else if (attr_type_str == "std::vector<int64_t>") {
+          // NOTE(chenweihang): InferShape can't support std::vector<int64_t>
+          // attr type, because the input type is std::vector<int64_t>, only
+          // can use one rule to parse std::vector<int64_t> parameter
+          continue;
+        } else if (attr_type_str == "std::vector<std::string>") {
+          custom_attrs.emplace_back(
+              ctx->Attrs().Get<std::vector<std::string>>(attr_name));
+        } else {
+          PADDLE_THROW(platform::errors::Unimplemented(
+              "Unsupported `%s` type value as custom attribute now. "
+              "Supported data types include `bool`, `int`, `float`, "
+              "`int64_t`, `std::string`, `std::vector<int>`, "
+              "`std::vector<float>`, `std::vector<std::string>`, "
+              "Please check whether the attribute data type and "
+              "data type string are matched.",
+              attr_type_str));
+        }
+      }
+
       VLOG(1) << "Custom Operator: InferShape - calc output ddim.";
-      auto output_shapes = infer_shape_func(input_shapes, vec_input_shapes);
+      auto output_shapes =
+          infer_shape_func(input_shapes, vec_input_shapes, custom_attrs);
 
       VLOG(1) << "Custom Operator: InferShape - set output ddim.";
       for (size_t i = 0; i < op_outputs.size(); ++i) {
diff --git a/python/paddle/fluid/tests/custom_op/custom_concat_op.cc b/python/paddle/fluid/tests/custom_op/custom_concat_op.cc
index 2d8d0ccb88f80e86323f058e80aa7a2417c05815..a01e01f2bc59239e5ce6aec4a1d9ea9a27bc00d1 100644
--- a/python/paddle/fluid/tests/custom_op/custom_concat_op.cc
+++ b/python/paddle/fluid/tests/custom_op/custom_concat_op.cc
@@ -144,3 +144,93 @@ PD_BUILD_GRAD_OP(custom_concat)
     .Inputs({paddle::Vec("X"), paddle::Grad("Out"), "Axis"})
     .Outputs({paddle::Grad(paddle::Vec("X"))})
     .SetKernelFn(PD_KERNEL(ConcatBackwardDynamicAxis));
+
+std::vector<paddle::Tensor> ConcatForwardStaticAxis(
+    const std::vector<paddle::Tensor>& inputs, const int64_t& axis) {
+  // check inputs
+  PD_CHECK(inputs.size() >= 1, "No Tensor need to be concat.");
+  for (auto& t : inputs) {
+    CHECK_INPUT(t);
+  }
+
+  // compute output shape
+  int64_t rank = static_cast<int64_t>(inputs[0].shape().size());
+  auto final_axis = ComputeAxis(axis, rank);
+  std::vector<std::vector<int64_t>> in_shapes;
+  for (auto& t : inputs) {
+    in_shapes.emplace_back(t.shape());
+  }
+  auto out_shape = ComputeOutShape(in_shapes, final_axis);
+
+  // create output
+  auto out = paddle::Tensor(paddle::PlaceType::kCPU);
+  out.reshape(out_shape);
+
+  // calc
+  PD_DISPATCH_FLOATING_AND_INTEGRAL_TYPES(
+      inputs[0].type(), "ConcatCpuKernel", ([&] {
+        ConcatCpuKernel<data_t>(inputs, &out, final_axis);
+      }));
+
+  return {out};
+}
+
+std::vector<paddle::Tensor> ConcatBackwardStaticAxis(
+    const std::vector<paddle::Tensor>& inputs,
+    const paddle::Tensor& grad_out,
+    const int64_t& axis) {
+  // check input
+  PD_CHECK(inputs.size() >= 1, "No Tensor need to be concat.");
+  for (auto& t : inputs) {
+    CHECK_INPUT(t);
+  }
+  CHECK_INPUT(grad_out);
+
+  // compute axis
+  int64_t rank = static_cast<int64_t>(inputs[0].shape().size());
+  auto final_axis = ComputeAxis(axis, rank);
+
+  // create outputs
+  std::vector<paddle::Tensor> grad_inputs;
+  for (auto& t : inputs) {
+    auto grad = paddle::Tensor(paddle::PlaceType::kCPU);
+    grad.reshape(t.shape());
+    grad_inputs.emplace_back(grad);
+  }
+
+  // calc
+  PD_DISPATCH_FLOATING_AND_INTEGRAL_TYPES(
+      grad_out.type(), "SplitCpuKernel", ([&] {
+        SplitCpuKernel<data_t>(grad_out, inputs, &grad_inputs, final_axis);
+      }));
+
+  return grad_inputs;
+}
+
+std::vector<std::vector<int64_t>> ConcatInferShapeStaticAxis(
+    const std::vector<std::vector<int64_t>>& input_shapes,
+    const int64_t& axis) {
+  int64_t rank = static_cast<int64_t>(input_shapes[0].size());
+  auto final_axis = ComputeAxis(axis, rank);
+  auto out_shape = ComputeOutShape(input_shapes, final_axis);
+  return {out_shape};
+}
+
+std::vector<paddle::DataType> ConcatInferDtypeStaticAxis(
+    const std::vector<paddle::DataType>& input_dtypes) {
+  return {input_dtypes[0]};
+}
+
+PD_BUILD_OP(custom_concat_with_attr)
+    .Inputs({paddle::Vec("X")})
+    .Outputs({"Out"})
+    .Attrs({"axis: int64_t"})
+    .SetKernelFn(PD_KERNEL(ConcatForwardStaticAxis))
+    .SetInferShapeFn(PD_INFER_SHAPE(ConcatInferShapeStaticAxis))
+    .SetInferDtypeFn(PD_INFER_DTYPE(ConcatInferDtypeStaticAxis));
+
+PD_BUILD_GRAD_OP(custom_concat_with_attr)
+    .Inputs({paddle::Vec("X"), paddle::Grad("Out")})
+    .Outputs({paddle::Grad(paddle::Vec("X"))})
+    .Attrs({"axis: int64_t"})
+    .SetKernelFn(PD_KERNEL(ConcatBackwardStaticAxis));
diff --git a/python/paddle/fluid/tests/custom_op/test_custom_concat.py b/python/paddle/fluid/tests/custom_op/test_custom_concat.py
index 4086224cd7b8d2d5d86f9e67e2fab093d666232a..ea41126c1c471d4026c4940eafeebf1141ce2b91 100644
--- a/python/paddle/fluid/tests/custom_op/test_custom_concat.py
+++ b/python/paddle/fluid/tests/custom_op/test_custom_concat.py
@@ -45,14 +45,16 @@ custom_ops = load(
     verbose=True)
 
 
-def concat_dynamic(func, device, dtype, np_inputs, axis_v):
-    paddle.set_device(device)
+def concat_dynamic(func, dtype, np_inputs, axis_v, with_attr=False):
+    paddle.set_device("cpu")
     inputs = [
         paddle.to_tensor(
-            x, dtype=dtype, place=device, stop_gradient=False)
-        for x in np_inputs
+            x, dtype=dtype, stop_gradient=False) for x in np_inputs
     ]
-    axis = paddle.full(shape=[1], dtype='int64', fill_value=axis_v)
+    if with_attr:
+        axis = axis_v
+    else:
+        axis = paddle.full(shape=[1], dtype='int64', fill_value=axis_v)
     out = func(inputs, axis)
     out.stop_gradient = False
     out.backward()
@@ -60,14 +62,17 @@
     return out.numpy(), grad_inputs
 
 
-def concat_static(func, device, dtype, np_inputs, axis_v):
+def concat_static(func, dtype, np_inputs, axis_v, with_attr=False):
     paddle.enable_static()
-    paddle.set_device(device)
+    paddle.set_device("cpu")
     with static.scope_guard(static.Scope()):
         with static.program_guard(static.Program()):
             x1 = static.data(name="x1", shape=[2, 3], dtype=dtype)
             x2 = static.data(name="x2", shape=[2, 3], dtype=dtype)
-            axis = paddle.full(shape=[1], dtype='int64', fill_value=axis_v)
+            if with_attr:
+                axis = axis_v
+            else:
+                axis = paddle.full(shape=[1], dtype='int64', fill_value=axis_v)
             x1.stop_gradient = False
             x2.stop_gradient = False
             out = func([x1, x2], axis)
@@ -78,13 +83,20 @@
 
             exe = static.Executor()
             exe.run(static.default_startup_program())
-            out_v, x1_grad_v, x2_grad_v = exe.run(
-                static.default_main_program(),
-                feed={
+            if with_attr:
+                feed_dict = {
+                    "x1": np_inputs[0].astype(dtype),
+                    "x2": np_inputs[1].astype(dtype)
+                }
+            else:
+                feed_dict = {
                     "x1": np_inputs[0].astype(dtype),
                     "x2": np_inputs[1].astype(dtype),
                     "axis": axis
-                },
+                }
+            out_v, x1_grad_v, x2_grad_v = exe.run(
+                static.default_main_program(),
+                feed=feed_dict,
                 fetch_list=[out.name, x1.name + "@GRAD", x2.name + "@GRAD"])
     paddle.disable_static()
     return out_v, x1_grad_v, x2_grad_v
@@ -93,55 +105,67 @@ class TestCustomConcatDynamicAxisJit(unittest.TestCase):
     def setUp(self):
         self.dtypes = ['float32', 'float64', 'int32', 'int64']
-        self.devices = ['cpu']
         self.np_inputs = [
             np.array([[1, 2, 3], [4, 5, 6]]),
             np.array([[11, 12, 13], [14, 15, 16]])
         ]
         self.axises = [0, 1]
 
+    def check_output(self, out, pd_out, name):
+        self.assertTrue(
+            np.array_equal(out, pd_out),
+            "custom op {}: {},\n paddle api {}: {}".format(name, out, name,
+                                                           pd_out))
+
     def test_dynamic(self):
-        for device in self.devices:
-            for dtype in self.dtypes:
-                for axis in self.axises:
-                    out, grad_inputs = concat_dynamic(custom_ops.custom_concat,
-                                                      device, dtype,
-                                                      self.np_inputs, axis)
-                    pd_out, pd_grad_inputs = concat_dynamic(
-                        paddle.concat, device, dtype, self.np_inputs, axis)
-
-                    self.assertTrue(
-                        np.array_equal(out, pd_out),
-                        "custom op out: {},\n paddle api out: {}".format(
-                            out, pd_out))
-                    for x_grad, pd_x_grad in zip(grad_inputs,
-                                                 pd_grad_inputs):
-                        self.assertTrue(
-                            np.array_equal(x_grad, pd_x_grad),
-                            "custom op x grad: {},\n paddle api x grad: {}".
-                            format(x_grad, pd_x_grad))
+        for dtype in self.dtypes:
+            for axis in self.axises:
+                out, grad_inputs = concat_dynamic(custom_ops.custom_concat,
+                                                  dtype, self.np_inputs, axis)
+                pd_out, pd_grad_inputs = concat_dynamic(paddle.concat, dtype,
+                                                        self.np_inputs, axis)
+
+                self.check_output(out, pd_out, "out")
+                for x_grad, pd_x_grad in zip(grad_inputs, pd_grad_inputs):
+                    self.check_output(x_grad, pd_x_grad, "x_grad")
 
     def test_static(self):
-        for device in self.devices:
-            for dtype in self.dtypes:
-                for axis in self.axises:
-                    out, x1_grad, x2_grad = concat_static(
-                        custom_ops.custom_concat, device, dtype, self.np_inputs,
-                        axis)
-                    pd_out, pd_x1_grad, pd_x2_grad = concat_static(
-                        paddle.concat, device, dtype, self.np_inputs, axis)
-
-                    self.assertTrue(
-                        np.array_equal(out, pd_out),
-                        "custom op out: {},\n paddle api out: {}".format(
-                            out, pd_out))
-                    self.assertTrue(
-                        np.array_equal(x1_grad, pd_x1_grad),
-                        "custom op x1_grad: {},\n paddle api x1_grad: {}".
-                        format(x1_grad, pd_x1_grad))
-                    self.assertTrue(
-                        np.array_equal(x2_grad, pd_x2_grad),
-                        "custom op x2_grad: {},\n paddle api x2_grad: {}".
-                        format(x2_grad, pd_x2_grad))
+        for dtype in self.dtypes:
+            for axis in self.axises:
+                out, x1_grad, x2_grad = concat_static(
+                    custom_ops.custom_concat, dtype, self.np_inputs, axis)
+                pd_out, pd_x1_grad, pd_x2_grad = concat_static(
+                    paddle.concat, dtype, self.np_inputs, axis)
+
+                self.check_output(out, pd_out, "out")
+                self.check_output(x1_grad, pd_x1_grad, "x1_grad")
+                self.check_output(x2_grad, pd_x2_grad, "x2_grad")
+
+    def test_dynamic_with_attr(self):
+        for dtype in self.dtypes:
+            for axis in self.axises:
+                out, grad_inputs = concat_dynamic(
+                    custom_ops.custom_concat_with_attr, dtype, self.np_inputs,
+                    axis, True)
+                pd_out, pd_grad_inputs = concat_dynamic(
+                    paddle.concat, dtype, self.np_inputs, axis, True)
+
+                self.check_output(out, pd_out, "out")
+                for x_grad, pd_x_grad in zip(grad_inputs, pd_grad_inputs):
+                    self.check_output(x_grad, pd_x_grad, "x_grad")
+
+    def test_static_with_attr(self):
+        for dtype in self.dtypes:
+            for axis in self.axises:
+                out, x1_grad, x2_grad = concat_static(
+                    custom_ops.custom_concat_with_attr, dtype, self.np_inputs,
+                    axis, True)
+                pd_out, pd_x1_grad, pd_x2_grad = concat_static(
+                    paddle.concat, dtype, self.np_inputs, axis, True)
+
+                self.check_output(out, pd_out, "out")
+                self.check_output(x1_grad, pd_x1_grad, "x1_grad")
+                self.check_output(x2_grad, pd_x2_grad, "x2_grad")
 
 
 if __name__ == "__main__":