未验证 提交 e429deb0 编写于 作者: C Chen Weihang 提交者: GitHub

[CustomOp] Support attribute in infershape function (#31713)

* support attribute in infershape

* polish details
上级 a4a2b77d
...@@ -204,38 +204,68 @@ struct KernelFuncImpl<Return (*)(Args...), impl_fn> { ...@@ -204,38 +204,68 @@ struct KernelFuncImpl<Return (*)(Args...), impl_fn> {
// Record Op infershape core function // Record Op infershape core function
using InferShapeFunc = std::vector<std::vector<int64_t>> (*)( using InferShapeFunc = std::vector<std::vector<int64_t>> (*)(
const std::vector<std::vector<int64_t>>& input_shapes, const std::vector<std::vector<int64_t>>& input_shapes,
const std::vector<std::vector<std::vector<int64_t>>>& vec_input_shapes); const std::vector<std::vector<std::vector<int64_t>>>& vec_input_shapes,
const std::vector<boost::any>& attrs);
#define PD_SPECIALIZE_InferShapeCallHelper_FOR_SHAPE(input_type) \ #define PD_SPECIALIZE_InferShapeCallHelper_FOR_SHAPE(input_type) \
template <typename... Tail> \ template <typename... Tail> \
struct InferShapeCallHelper<input_type, Tail...> { \ struct InferShapeCallHelper<input_type, Tail...> { \
template <int in_idx, int vec_in_idx, typename... PreviousArgs> \ template <int in_idx, int vec_in_idx, int attr_idx, \
static Return InferShape( \ typename... PreviousArgs> \
const std::vector<std::vector<int64_t>>& input_shapes, \ static Return InferShape( \
const std::vector<std::vector<std::vector<int64_t>>>& \ const std::vector<std::vector<int64_t>>& input_shapes, \
vec_input_shapes, \ const std::vector<std::vector<std::vector<int64_t>>>& \
const PreviousArgs&... pargs) { \ vec_input_shapes, \
input_type arg = input_shapes[in_idx]; \ const std::vector<boost::any>& attrs, const PreviousArgs&... pargs) { \
return InferShapeCallHelper<Tail...>::template InferShape<in_idx + 1, \ input_type arg = input_shapes[in_idx]; \
vec_in_idx>( \ return InferShapeCallHelper<Tail...>::template InferShape< \
input_shapes, vec_input_shapes, pargs..., arg); \ in_idx + 1, vec_in_idx, attr_idx>(input_shapes, vec_input_shapes, \
} \ attrs, pargs..., arg); \
} \
} }
#define PD_SPECIALIZE_InferShapeCallHelper_FOR_SHAPES(input_type) \ #define PD_SPECIALIZE_InferShapeCallHelper_FOR_SHAPES(input_type) \
template <typename... Tail> \ template <typename... Tail> \
struct InferShapeCallHelper<input_type, Tail...> { \ struct InferShapeCallHelper<input_type, Tail...> { \
template <int in_idx, int vec_in_idx, typename... PreviousArgs> \ template <int in_idx, int vec_in_idx, int attr_idx, \
static Return InferShape( \ typename... PreviousArgs> \
const std::vector<std::vector<int64_t>>& input_shapes, \ static Return InferShape( \
const std::vector<std::vector<std::vector<int64_t>>>& \ const std::vector<std::vector<int64_t>>& input_shapes, \
vec_input_shapes, \ const std::vector<std::vector<std::vector<int64_t>>>& \
const PreviousArgs&... pargs) { \ vec_input_shapes, \
input_type arg = vec_input_shapes[vec_in_idx]; \ const std::vector<boost::any>& attrs, const PreviousArgs&... pargs) { \
return InferShapeCallHelper<Tail...>::template InferShape< \ input_type arg = vec_input_shapes[vec_in_idx]; \
in_idx, vec_in_idx + 1>(input_shapes, vec_input_shapes, pargs..., \ return InferShapeCallHelper<Tail...>::template InferShape< \
arg); \ in_idx, vec_in_idx + 1, attr_idx>(input_shapes, vec_input_shapes, \
} \ attrs, pargs..., arg); \
} \
}
#define PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(attr_type) \
template <typename... Tail> \
struct InferShapeCallHelper<attr_type, Tail...> { \
template <int in_idx, int vec_in_idx, int attr_idx, \
typename... PreviousArgs> \
static Return InferShape( \
const std::vector<std::vector<int64_t>>& input_shapes, \
const std::vector<std::vector<std::vector<int64_t>>>& \
vec_input_shapes, \
const std::vector<boost::any>& attrs, const PreviousArgs&... pargs) { \
try { \
attr_type arg = boost::any_cast<attr_type>(attrs[attr_idx]); \
return InferShapeCallHelper<Tail...>::template InferShape< \
in_idx, vec_in_idx, attr_idx + 1>(input_shapes, vec_input_shapes, \
attrs, pargs..., arg); \
} catch (boost::bad_any_cast&) { \
PD_THROW( \
"Attribute cast error in custom operator InferShapeFn. " \
"Expected " #attr_type \
" value. InferShapeFn's attribute list must be exactly same as " \
"Forward " \
"KernelFn's attribute list except std::vector<int64_t> " \
"attribute."); \
} \
} \
} }
template <typename F, F f> template <typename F, F f>
...@@ -245,10 +275,10 @@ template <typename Return, typename... Args, Return (*impl_fn)(Args...)> ...@@ -245,10 +275,10 @@ template <typename Return, typename... Args, Return (*impl_fn)(Args...)>
struct InferShapeFuncImpl<Return (*)(Args...), impl_fn> { struct InferShapeFuncImpl<Return (*)(Args...), impl_fn> {
static Return InferShape( static Return InferShape(
const std::vector<std::vector<int64_t>>& input_shapes, const std::vector<std::vector<int64_t>>& input_shapes,
const std::vector<std::vector<std::vector<int64_t>>>& vec_input_shapes) { const std::vector<std::vector<std::vector<int64_t>>>& vec_input_shapes,
return InferShapeCallHelper<Args..., TypeTag<int>>::template InferShape<0, const std::vector<boost::any>& attrs) {
0>( return InferShapeCallHelper<Args..., TypeTag<int>>::template InferShape<
input_shapes, vec_input_shapes); 0, 0, 0>(input_shapes, vec_input_shapes, attrs);
} }
private: private:
...@@ -265,14 +295,26 @@ struct InferShapeFuncImpl<Return (*)(Args...), impl_fn> { ...@@ -265,14 +295,26 @@ struct InferShapeFuncImpl<Return (*)(Args...), impl_fn> {
PD_SPECIALIZE_InferShapeCallHelper_FOR_SHAPES( PD_SPECIALIZE_InferShapeCallHelper_FOR_SHAPES(
std::vector<std::vector<int64_t>>); std::vector<std::vector<int64_t>>);
PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(const bool&);
PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(const int&);
PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(const float&);
PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(const int64_t&);
PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(const std::string&);
PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(const std::vector<int>&);
PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(const std::vector<float>&);
PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(const std::vector<std::string>&);
// NOTE(chenweihang): InferShape can't support std::vector<int64_t> attr type,
// because the input type is std::vector<int64_t>, only can use one rule to
// parse std::vector<int64_t> parameter
// end: base template // end: base template
template <typename T> template <typename T>
struct InferShapeCallHelper<TypeTag<T>> { struct InferShapeCallHelper<TypeTag<T>> {
template <int in_idx, int vec_in_idx> template <int in_idx, int vec_in_idx, int attr_idx>
static Return InferShape( static Return InferShape(
const std::vector<std::vector<int64_t>>& input_shapes, const std::vector<std::vector<int64_t>>& input_shapes,
const std::vector<std::vector<std::vector<int64_t>>>& vec_input_shapes, const std::vector<std::vector<std::vector<int64_t>>>& vec_input_shapes,
const Args&... args) { const std::vector<boost::any>& attrs, const Args&... args) {
return impl_fn(args...); return impl_fn(args...);
} }
}; };
......
...@@ -178,7 +178,7 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx, ...@@ -178,7 +178,7 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx,
"Unsupported `%s` type value as custom attribute now. " "Unsupported `%s` type value as custom attribute now. "
"Supported data types include `bool`, `int`, `float`, " "Supported data types include `bool`, `int`, `float`, "
"`int64_t`, `std::string`, `std::vector<int>`, " "`int64_t`, `std::string`, `std::vector<int>`, "
"`std::vector<float>`, `std::vector<int64_t>, " "`std::vector<float>`, `std::vector<int64_t>`, "
"`std::vector<std::string>`, Please check whether " "`std::vector<std::string>`, Please check whether "
"the attribute data type and data type string are matched.", "the attribute data type and data type string are matched.",
attr_type_str)); attr_type_str));
...@@ -327,7 +327,7 @@ class CustomOpMaker : public OpProtoAndCheckerMaker { ...@@ -327,7 +327,7 @@ class CustomOpMaker : public OpProtoAndCheckerMaker {
"Unsupported `%s` type value as custom attribute now. " "Unsupported `%s` type value as custom attribute now. "
"Supported data types include `bool`, `int`, `float`, " "Supported data types include `bool`, `int`, `float`, "
"`int64_t`, `std::string`, `std::vector<int>`, " "`int64_t`, `std::string`, `std::vector<int>`, "
"`std::vector<float>`, `std::vector<int64_t>, " "`std::vector<float>`, `std::vector<int64_t>`, "
"`std::vector<std::string>`, Please check whether " "`std::vector<std::string>`, Please check whether "
"the attribute data type and data type string are matched.", "the attribute data type and data type string are matched.",
attr_type_str)); attr_type_str));
...@@ -581,7 +581,7 @@ void RegisterOperatorWithMetaInfo( ...@@ -581,7 +581,7 @@ void RegisterOperatorWithMetaInfo(
ctx->ShareDim(op_inputs[0], op_outputs[0]); ctx->ShareDim(op_inputs[0], op_outputs[0]);
}; };
} else { } else {
info.infer_shape_ = [op_inputs, op_outputs, info.infer_shape_ = [op_inputs, op_outputs, op_attrs,
infer_shape_func](InferShapeContext* ctx) { infer_shape_func](InferShapeContext* ctx) {
std::vector<std::vector<int64_t>> input_shapes; std::vector<std::vector<int64_t>> input_shapes;
std::vector<std::vector<std::vector<int64_t>>> vec_input_shapes; std::vector<std::vector<std::vector<int64_t>>> vec_input_shapes;
...@@ -606,8 +606,50 @@ void RegisterOperatorWithMetaInfo( ...@@ -606,8 +606,50 @@ void RegisterOperatorWithMetaInfo(
} }
} }
std::vector<boost::any> custom_attrs;
for (auto& attr_str : op_attrs) {
auto attr_name_and_type = detail::ParseAttrStr(attr_str);
auto attr_name = attr_name_and_type[0];
auto attr_type_str = attr_name_and_type[1];
if (attr_type_str == "bool") {
custom_attrs.emplace_back(ctx->Attrs().Get<bool>(attr_name));
} else if (attr_type_str == "int") {
custom_attrs.emplace_back(ctx->Attrs().Get<int>(attr_name));
} else if (attr_type_str == "float") {
custom_attrs.emplace_back(ctx->Attrs().Get<float>(attr_name));
} else if (attr_type_str == "int64_t") {
custom_attrs.emplace_back(ctx->Attrs().Get<int64_t>(attr_name));
} else if (attr_type_str == "std::string") {
custom_attrs.emplace_back(ctx->Attrs().Get<std::string>(attr_name));
} else if (attr_type_str == "std::vector<int>") {
custom_attrs.emplace_back(
ctx->Attrs().Get<std::vector<int>>(attr_name));
} else if (attr_type_str == "std::vector<float>") {
custom_attrs.emplace_back(
ctx->Attrs().Get<std::vector<float>>(attr_name));
} else if (attr_type_str == "std::vector<int64_t>") {
// NOTE(chenweihang): InferShape can't support std::vector<int64_t>
// attr type, because the input type is std::vector<int64_t>, only
// can use one rule to parse std::vector<int64_t> parameter
continue;
} else if (attr_type_str == "std::vector<std::string>") {
custom_attrs.emplace_back(
ctx->Attrs().Get<std::vector<std::string>>(attr_name));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported `%s` type value as custom attribute now. "
"Supported data types include `bool`, `int`, `float`, "
"`int64_t`, `std::string`, `std::vector<int>`, "
"`std::vector<float>`, `std::vector<std::string>`, "
"Please check whether the attribute data type and "
"data type string are matched.",
attr_type_str));
}
}
VLOG(1) << "Custom Operator: InferShape - calc output ddim."; VLOG(1) << "Custom Operator: InferShape - calc output ddim.";
auto output_shapes = infer_shape_func(input_shapes, vec_input_shapes); auto output_shapes =
infer_shape_func(input_shapes, vec_input_shapes, custom_attrs);
VLOG(1) << "Custom Operator: InferShape - set output ddim."; VLOG(1) << "Custom Operator: InferShape - set output ddim.";
for (size_t i = 0; i < op_outputs.size(); ++i) { for (size_t i = 0; i < op_outputs.size(); ++i) {
......
...@@ -144,3 +144,93 @@ PD_BUILD_GRAD_OP(custom_concat) ...@@ -144,3 +144,93 @@ PD_BUILD_GRAD_OP(custom_concat)
.Inputs({paddle::Vec("X"), paddle::Grad("Out"), "Axis"}) .Inputs({paddle::Vec("X"), paddle::Grad("Out"), "Axis"})
.Outputs({paddle::Grad(paddle::Vec("X"))}) .Outputs({paddle::Grad(paddle::Vec("X"))})
.SetKernelFn(PD_KERNEL(ConcatBackwardDynamicAxis)); .SetKernelFn(PD_KERNEL(ConcatBackwardDynamicAxis));
std::vector<paddle::Tensor> ConcatForwardStaticAxis(
const std::vector<paddle::Tensor>& inputs, const int64_t& axis) {
// check inputs
PD_CHECK(inputs.size() >= 1, "No Tensor need to be concat.");
for (auto& t : inputs) {
CHECK_INPUT(t);
}
// compute output shape
int64_t rank = static_cast<int64_t>(inputs[0].shape().size());
auto final_axis = ComputeAxis(axis, rank);
std::vector<std::vector<int64_t>> in_shapes;
for (auto& t : inputs) {
in_shapes.emplace_back(t.shape());
}
auto out_shape = ComputeOutShape(in_shapes, final_axis);
// create output
auto out = paddle::Tensor(paddle::PlaceType::kCPU);
out.reshape(out_shape);
// calc
PD_DISPATCH_FLOATING_AND_INTEGRAL_TYPES(
inputs[0].type(), "ConcatCpuKernel", ([&] {
ConcatCpuKernel<data_t>(inputs, &out, final_axis);
}));
return {out};
}
std::vector<paddle::Tensor> ConcatBackwardStaticAxis(
const std::vector<paddle::Tensor>& inputs,
const paddle::Tensor& grad_out,
const int64_t& axis) {
// check input
PD_CHECK(inputs.size() >= 1, "No Tensor need to be concat.");
for (auto& t : inputs) {
CHECK_INPUT(t);
}
CHECK_INPUT(grad_out);
// compate axis
int64_t rank = static_cast<int64_t>(inputs[0].shape().size());
auto final_axis = ComputeAxis(axis, rank);
// create outputs
std::vector<paddle::Tensor> grad_inputs;
for (auto& t : inputs) {
auto grad = paddle::Tensor(paddle::PlaceType::kCPU);
grad.reshape(t.shape());
grad_inputs.emplace_back(grad);
}
// calc
PD_DISPATCH_FLOATING_AND_INTEGRAL_TYPES(
grad_out.type(), "SplitCpuKernel", ([&] {
SplitCpuKernel<data_t>(grad_out, inputs, &grad_inputs, final_axis);
}));
return grad_inputs;
}
std::vector<std::vector<int64_t>> ConcatInferShapeStaticAxis(
const std::vector<std::vector<int64_t>>& input_shapes,
const int64_t& axis) {
int64_t rank = static_cast<int64_t>(input_shapes[0].size());
auto final_axis = ComputeAxis(axis, rank);
auto out_shape = ComputeOutShape(input_shapes, final_axis);
return {out_shape};
}
std::vector<paddle::DataType> ConcatInferDtypeStaticAxis(
const std::vector<paddle::DataType>& input_dtypes) {
return {input_dtypes[0]};
}
PD_BUILD_OP(custom_concat_with_attr)
.Inputs({paddle::Vec("X")})
.Outputs({"Out"})
.Attrs({"axis: int64_t"})
.SetKernelFn(PD_KERNEL(ConcatForwardStaticAxis))
.SetInferShapeFn(PD_INFER_SHAPE(ConcatInferShapeStaticAxis))
.SetInferDtypeFn(PD_INFER_DTYPE(ConcatInferDtypeStaticAxis));
PD_BUILD_GRAD_OP(custom_concat_with_attr)
.Inputs({paddle::Vec("X"), paddle::Grad("Out")})
.Outputs({paddle::Grad(paddle::Vec("X"))})
.Attrs({"axis: int64_t"})
.SetKernelFn(PD_KERNEL(ConcatBackwardStaticAxis));
...@@ -45,14 +45,16 @@ custom_ops = load( ...@@ -45,14 +45,16 @@ custom_ops = load(
verbose=True) verbose=True)
def concat_dynamic(func, device, dtype, np_inputs, axis_v): def concat_dynamic(func, dtype, np_inputs, axis_v, with_attr=False):
paddle.set_device(device) paddle.set_device("cpu")
inputs = [ inputs = [
paddle.to_tensor( paddle.to_tensor(
x, dtype=dtype, place=device, stop_gradient=False) x, dtype=dtype, stop_gradient=False) for x in np_inputs
for x in np_inputs
] ]
axis = paddle.full(shape=[1], dtype='int64', fill_value=axis_v) if with_attr:
axis = axis_v
else:
axis = paddle.full(shape=[1], dtype='int64', fill_value=axis_v)
out = func(inputs, axis) out = func(inputs, axis)
out.stop_gradient = False out.stop_gradient = False
out.backward() out.backward()
...@@ -60,14 +62,17 @@ def concat_dynamic(func, device, dtype, np_inputs, axis_v): ...@@ -60,14 +62,17 @@ def concat_dynamic(func, device, dtype, np_inputs, axis_v):
return out.numpy(), grad_inputs return out.numpy(), grad_inputs
def concat_static(func, device, dtype, np_inputs, axis_v): def concat_static(func, dtype, np_inputs, axis_v, with_attr=False):
paddle.enable_static() paddle.enable_static()
paddle.set_device(device) paddle.set_device("cpu")
with static.scope_guard(static.Scope()): with static.scope_guard(static.Scope()):
with static.program_guard(static.Program()): with static.program_guard(static.Program()):
x1 = static.data(name="x1", shape=[2, 3], dtype=dtype) x1 = static.data(name="x1", shape=[2, 3], dtype=dtype)
x2 = static.data(name="x2", shape=[2, 3], dtype=dtype) x2 = static.data(name="x2", shape=[2, 3], dtype=dtype)
axis = paddle.full(shape=[1], dtype='int64', fill_value=axis_v) if with_attr:
axis = axis_v
else:
axis = paddle.full(shape=[1], dtype='int64', fill_value=axis_v)
x1.stop_gradient = False x1.stop_gradient = False
x2.stop_gradient = False x2.stop_gradient = False
out = func([x1, x2], axis) out = func([x1, x2], axis)
...@@ -78,13 +83,20 @@ def concat_static(func, device, dtype, np_inputs, axis_v): ...@@ -78,13 +83,20 @@ def concat_static(func, device, dtype, np_inputs, axis_v):
exe = static.Executor() exe = static.Executor()
exe.run(static.default_startup_program()) exe.run(static.default_startup_program())
out_v, x1_grad_v, x2_grad_v = exe.run( if with_attr:
static.default_main_program(), feed_dict = {
feed={ "x1": np_inputs[0].astype(dtype),
"x2": np_inputs[1].astype(dtype)
}
else:
feed_dict = {
"x1": np_inputs[0].astype(dtype), "x1": np_inputs[0].astype(dtype),
"x2": np_inputs[1].astype(dtype), "x2": np_inputs[1].astype(dtype),
"axis": axis "axis": axis
}, }
out_v, x1_grad_v, x2_grad_v = exe.run(
static.default_main_program(),
feed=feed_dict,
fetch_list=[out.name, x1.name + "@GRAD", x2.name + "@GRAD"]) fetch_list=[out.name, x1.name + "@GRAD", x2.name + "@GRAD"])
paddle.disable_static() paddle.disable_static()
return out_v, x1_grad_v, x2_grad_v return out_v, x1_grad_v, x2_grad_v
...@@ -93,55 +105,67 @@ def concat_static(func, device, dtype, np_inputs, axis_v): ...@@ -93,55 +105,67 @@ def concat_static(func, device, dtype, np_inputs, axis_v):
class TestCustomConcatDynamicAxisJit(unittest.TestCase): class TestCustomConcatDynamicAxisJit(unittest.TestCase):
def setUp(self): def setUp(self):
self.dtypes = ['float32', 'float64', 'int32', 'int64'] self.dtypes = ['float32', 'float64', 'int32', 'int64']
self.devices = ['cpu']
self.np_inputs = [ self.np_inputs = [
np.array([[1, 2, 3], [4, 5, 6]]), np.array([[1, 2, 3], [4, 5, 6]]),
np.array([[11, 12, 13], [14, 15, 16]]) np.array([[11, 12, 13], [14, 15, 16]])
] ]
self.axises = [0, 1] self.axises = [0, 1]
def check_output(self, out, pd_out, name):
self.assertTrue(
np.array_equal(out, pd_out),
"custom op {}: {},\n paddle api {}: {}".format(name, out, name,
pd_out))
def test_dynamic(self): def test_dynamic(self):
for device in self.devices: for dtype in self.dtypes:
for dtype in self.dtypes: for axis in self.axises:
for axis in self.axises: out, grad_inputs = concat_dynamic(custom_ops.custom_concat,
out, grad_inputs = concat_dynamic(custom_ops.custom_concat, dtype, self.np_inputs, axis)
device, dtype, pd_out, pd_grad_inputs = concat_dynamic(paddle.concat, dtype,
self.np_inputs, axis) self.np_inputs, axis)
pd_out, pd_grad_inputs = concat_dynamic(
paddle.concat, device, dtype, self.np_inputs, axis) self.check_output(out, pd_out, "out")
for x_grad, pd_x_grad in zip(grad_inputs, pd_grad_inputs):
self.assertTrue( self.check_output(x_grad, pd_x_grad, "x_grad")
np.array_equal(out, pd_out),
"custom op out: {},\n paddle api out: {}".format(
out, pd_out))
for x_grad, pd_x_grad in zip(grad_inputs, pd_grad_inputs):
self.assertTrue(
np.array_equal(x_grad, pd_x_grad),
"custom op x grad: {},\n paddle api x grad: {}".
format(x_grad, pd_x_grad))
def test_static(self): def test_static(self):
for device in self.devices: for dtype in self.dtypes:
for dtype in self.dtypes: for axis in self.axises:
for axis in self.axises: out, x1_grad, x2_grad = concat_static(
out, x1_grad, x2_grad = concat_static( custom_ops.custom_concat, dtype, self.np_inputs, axis)
custom_ops.custom_concat, device, dtype, self.np_inputs, pd_out, pd_x1_grad, pd_x2_grad = concat_static(
axis) paddle.concat, dtype, self.np_inputs, axis)
pd_out, pd_x1_grad, pd_x2_grad = concat_static(
paddle.concat, device, dtype, self.np_inputs, axis) self.check_output(out, pd_out, "out")
self.check_output(x1_grad, pd_x1_grad, "x1_grad")
self.assertTrue( self.check_output(x2_grad, pd_x2_grad, "x2_grad")
np.array_equal(out, pd_out),
"custom op out: {},\n paddle api out: {}".format( def test_dynamic_with_attr(self):
out, pd_out)) for dtype in self.dtypes:
self.assertTrue( for axis in self.axises:
np.array_equal(x1_grad, pd_x1_grad), out, grad_inputs = concat_dynamic(
"custom op x1_grad: {},\n paddle api x1_grad: {}". custom_ops.custom_concat_with_attr, dtype, self.np_inputs,
format(x1_grad, pd_x1_grad)) axis, True)
self.assertTrue( pd_out, pd_grad_inputs = concat_dynamic(
np.array_equal(x2_grad, pd_x2_grad), paddle.concat, dtype, self.np_inputs, axis, True)
"custom op x2_grad: {},\n paddle api x2_grad: {}".
format(x2_grad, pd_x2_grad)) self.check_output(out, pd_out, "out")
for x_grad, pd_x_grad in zip(grad_inputs, pd_grad_inputs):
self.check_output(x_grad, pd_x_grad, "x_grad")
def test_static_with_attr(self):
for dtype in self.dtypes:
for axis in self.axises:
out, x1_grad, x2_grad = concat_static(
custom_ops.custom_concat_with_attr, dtype, self.np_inputs,
axis, True)
pd_out, pd_x1_grad, pd_x2_grad = concat_static(
paddle.concat, dtype, self.np_inputs, axis, True)
self.check_output(out, pd_out, "out")
self.check_output(x1_grad, pd_x1_grad, "x1_grad")
self.check_output(x2_grad, pd_x2_grad, "x2_grad")
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册