From 0243c6cad55d33d945a5500eff523760604cda72 Mon Sep 17 00:00:00 2001 From: ccrrong <101700995+ccrrong@users.noreply.github.com> Date: Thu, 21 Jul 2022 13:23:40 +0800 Subject: [PATCH] [Phi] add temporal_shift yaml (#44409) * add temporal_shift yaml and unittest --- paddle/phi/api/yaml/legacy_api.yaml | 10 ++++++++++ paddle/phi/api/yaml/legacy_backward.yaml | 10 ++++++++++ .../fluid/tests/unittests/test_temporal_shift_op.py | 6 ++++-- python/paddle/nn/functional/extension.py | 3 +++ 4 files changed, 27 insertions(+), 2 deletions(-) diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml index 0d0fd74c17..3e2e0f70a9 100644 --- a/paddle/phi/api/yaml/legacy_api.yaml +++ b/paddle/phi/api/yaml/legacy_api.yaml @@ -2154,6 +2154,16 @@ func : tanh_shrink backward : tanh_shrink_grad +# temporal_shift +- api : temporal_shift + args : (Tensor x, int seg_num, float shift_ratio, str data_format_str) + output : Tensor + infer_meta : + func : TemporalShiftInferMeta + kernel : + func : temporal_shift + backward : temporal_shift_grad + # thresholded_relu - api : thresholded_relu args : (Tensor x, float threshold) diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index 6df4883145..73749b9487 100644 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -2173,6 +2173,16 @@ func : tanh_triple_grad inplace : (grad_x_grad_forward -> grad_out_forward_grad) +- backward_api : temporal_shift_grad + forward : temporal_shift(Tensor x, int seg_num, float shift_ratio, str data_format_str) -> Tensor(out) + args : (Tensor out_grad, int seg_num, float shift_ratio, str data_format_str) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [out_grad] + kernel : + func : temporal_shift_grad + - backward_api : thresholded_relu_grad forward : thresholded_relu (Tensor x, float threshold) -> Tensor(out) args : (Tensor x, Tensor out_grad, float threshold) diff --git a/python/paddle/fluid/tests/unittests/test_temporal_shift_op.py b/python/paddle/fluid/tests/unittests/test_temporal_shift_op.py index e9561b3e0a..0a9a6804b3 100644 --- a/python/paddle/fluid/tests/unittests/test_temporal_shift_op.py +++ b/python/paddle/fluid/tests/unittests/test_temporal_shift_op.py @@ -46,6 +46,7 @@ class TestTemporalShift(OpTest): def setUp(self): self.initTestCase() self.op_type = 'temporal_shift' + self.python_api = paddle.nn.functional.temporal_shift x = np.random.random(self.x_shape).astype(self.dtype) self.attrs = { @@ -61,12 +62,13 @@ class TestTemporalShift(OpTest): output = temporal_shift(x, self.seg_num, self.shift_ratio, self.data_format) self.outputs = {"Out": output} + self.python_out_sig = ["Out"] def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def test_check_grad_ignore_uv(self): - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_eager=True) def initTestCase(self): self.x_shape = (6, 4, 4, 4) diff --git a/python/paddle/nn/functional/extension.py b/python/paddle/nn/functional/extension.py index 1bfa7f1488..6191d015e2 100644 --- a/python/paddle/nn/functional/extension.py +++ b/python/paddle/nn/functional/extension.py @@ -366,6 +366,9 @@ def temporal_shift(x, seg_num, shift_ratio=0.25, name=None, data_format="NCHW"): if data_format not in ["NCHW", "NHWC"]: raise ValueError("Attr(data_format) should be 'NCHW' or 'NHWC'. " "Received Attr(data_format): {}.".format(data_format)) + if in_dygraph_mode(): + return _C_ops.final_state_temporal_shift(x, seg_num, shift_ratio, + data_format) if _non_static_mode(): return _C_ops.temporal_shift(x, 'seg_num', seg_num, 'shift_ratio', shift_ratio, 'data_format', data_format) -- GitLab