From 56dbe42628b6f0317c896327882fcb3fb78bd404 Mon Sep 17 00:00:00 2001 From: Weilong Wu Date: Mon, 16 Jan 2023 18:33:36 +0800 Subject: [PATCH] [PHI] channel_shuffle add yaml (#49808) --- paddle/phi/api/yaml/legacy_backward.yaml | 9 +++++++++ paddle/phi/api/yaml/legacy_ops.yaml | 9 +++++++++ .../paddle/fluid/tests/unittests/test_channel_shuffle.py | 5 +++-- python/paddle/nn/functional/vision.py | 4 +--- 4 files changed, 22 insertions(+), 5 deletions(-) diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index 50640d313ef..3ef30965673 100755 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -194,6 +194,15 @@ invoke : cast (out_grad, x.dtype()) no_need_buffer : x +- backward_op : channel_shuffle_grad + forward : channel_shuffle (Tensor x, int groups, str data_format="NCHW") -> Tensor(out) + args : (Tensor out_grad, int groups, str data_format="NCHW") + output : Tensor(x_grad) + infer_meta : + func : ChannelShuffleGradInferMeta + kernel : + func : channel_shuffle_grad + - backward_op : concat_double_grad forward : concat_grad (Tensor[] x, Tensor grad_out, Scalar axis) -> Tensor[](grad_x) args : (Tensor[] grad_x_grad, Scalar axis = 0) diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml index 3be9cdf371d..52db798aecc 100755 --- a/paddle/phi/api/yaml/legacy_ops.yaml +++ b/paddle/phi/api/yaml/legacy_ops.yaml @@ -338,6 +338,15 @@ data_type : x backward : cast_grad +- op : channel_shuffle + args : (Tensor x, int groups, str data_format="NCHW") + output : Tensor(out) + infer_meta : + func : ChannelShuffleInferMeta + kernel : + func : channel_shuffle + backward : channel_shuffle_grad + - op : check_finite_and_unscale_ args : (Tensor[] x, Tensor scale, Tensor input_found_infinite) output : Tensor[](out){x.size()}, Tensor(output_found_infinite) diff --git a/python/paddle/fluid/tests/unittests/test_channel_shuffle.py b/python/paddle/fluid/tests/unittests/test_channel_shuffle.py index bfad2bd94d3..99e52d18ca5 100644 --- a/python/paddle/fluid/tests/unittests/test_channel_shuffle.py +++ b/python/paddle/fluid/tests/unittests/test_channel_shuffle.py @@ -47,6 +47,7 @@ class TestChannelShuffleOp(OpTest): self.op_type = "channel_shuffle" self.init_data_format() n, c, h, w = 2, 9, 4, 4 + self.python_api = paddle.nn.functional.channel_shuffle if self.format == "NCHW": shape = [n, c, h, w] @@ -66,10 +67,10 @@ class TestChannelShuffleOp(OpTest): self.format = "NCHW" def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def test_check_grad(self): - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_eager=True) class TestChannelLast(TestChannelShuffleOp): diff --git a/python/paddle/nn/functional/vision.py b/python/paddle/nn/functional/vision.py index fc9e030c3ea..4f164e991f3 100644 --- a/python/paddle/nn/functional/vision.py +++ b/python/paddle/nn/functional/vision.py @@ -516,9 +516,7 @@ def channel_shuffle(x, groups, data_format="NCHW", name=None): ) if in_dygraph_mode(): - return _legacy_C_ops.channel_shuffle( - x, "groups", groups, "data_format", data_format - ) + return _C_ops.channel_shuffle(x, groups, data_format) helper = LayerHelper("channel_shuffle", **locals()) check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'channel_shuffle') -- GitLab