未验证 提交 56dbe426 编写于 作者: W Weilong Wu 提交者: GitHub

[PHI] channel_shuffle add yaml (#49808)

上级 65b0181e
...@@ -194,6 +194,15 @@ ...@@ -194,6 +194,15 @@
invoke : cast (out_grad, x.dtype()) invoke : cast (out_grad, x.dtype())
no_need_buffer : x no_need_buffer : x
- backward_op : channel_shuffle_grad
forward : channel_shuffle (Tensor x, int groups, str data_format="NCHW") -> Tensor(out)
args : (Tensor out_grad, int groups, str data_format="NCHW")
output : Tensor(x_grad)
infer_meta :
func : ChannelShuffleGradInferMeta
kernel :
func : channel_shuffle_grad
- backward_op : concat_double_grad - backward_op : concat_double_grad
forward : concat_grad (Tensor[] x, Tensor grad_out, Scalar axis) -> Tensor[](grad_x) forward : concat_grad (Tensor[] x, Tensor grad_out, Scalar axis) -> Tensor[](grad_x)
args : (Tensor[] grad_x_grad, Scalar axis = 0) args : (Tensor[] grad_x_grad, Scalar axis = 0)
......
...@@ -338,6 +338,15 @@ ...@@ -338,6 +338,15 @@
data_type : x data_type : x
backward : cast_grad backward : cast_grad
- op : channel_shuffle
args : (Tensor x, int groups, str data_format="NCHW")
output : Tensor(out)
infer_meta :
func : ChannelShuffleInferMeta
kernel :
func : channel_shuffle
backward : channel_shuffle_grad
- op : check_finite_and_unscale_ - op : check_finite_and_unscale_
args : (Tensor[] x, Tensor scale, Tensor input_found_infinite) args : (Tensor[] x, Tensor scale, Tensor input_found_infinite)
output : Tensor[](out){x.size()}, Tensor(output_found_infinite) output : Tensor[](out){x.size()}, Tensor(output_found_infinite)
......
...@@ -47,6 +47,7 @@ class TestChannelShuffleOp(OpTest): ...@@ -47,6 +47,7 @@ class TestChannelShuffleOp(OpTest):
self.op_type = "channel_shuffle" self.op_type = "channel_shuffle"
self.init_data_format() self.init_data_format()
n, c, h, w = 2, 9, 4, 4 n, c, h, w = 2, 9, 4, 4
self.python_api = paddle.nn.functional.channel_shuffle
if self.format == "NCHW": if self.format == "NCHW":
shape = [n, c, h, w] shape = [n, c, h, w]
...@@ -66,10 +67,10 @@ class TestChannelShuffleOp(OpTest): ...@@ -66,10 +67,10 @@ class TestChannelShuffleOp(OpTest):
self.format = "NCHW" self.format = "NCHW"
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output(check_eager=True)
def test_check_grad(self): def test_check_grad(self):
self.check_grad(['X'], 'Out') self.check_grad(['X'], 'Out', check_eager=True)
class TestChannelLast(TestChannelShuffleOp): class TestChannelLast(TestChannelShuffleOp):
......
...@@ -516,9 +516,7 @@ def channel_shuffle(x, groups, data_format="NCHW", name=None): ...@@ -516,9 +516,7 @@ def channel_shuffle(x, groups, data_format="NCHW", name=None):
) )
if in_dygraph_mode(): if in_dygraph_mode():
return _legacy_C_ops.channel_shuffle( return _C_ops.channel_shuffle(x, groups, data_format)
x, "groups", groups, "data_format", data_format
)
helper = LayerHelper("channel_shuffle", **locals()) helper = LayerHelper("channel_shuffle", **locals())
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'channel_shuffle') check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'channel_shuffle')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册