diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml index 0d0fd74c17aa786c98e63c889fc2b65abf551ef8..3e2e0f70a9260ff6430d65ea1d43daa1dc2a11e2 100644 --- a/paddle/phi/api/yaml/legacy_api.yaml +++ b/paddle/phi/api/yaml/legacy_api.yaml @@ -2154,6 +2154,16 @@ func : tanh_shrink backward : tanh_shrink_grad +# temporal_shift +- api : temporal_shift + args : (Tensor x, int seg_num, float shift_ratio, str data_format_str) + output : Tensor + infer_meta : + func : TemporalShiftInferMeta + kernel : + func : temporal_shift + backward : temporal_shift_grad + # thresholded_relu - api : thresholded_relu args : (Tensor x, float threshold) diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index 6df48831456202c696ede4fc29b66834b70148cb..73749b94870a634ef76ae8aa20ccdbd3f1cfaf21 100644 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -2173,6 +2173,16 @@ func : tanh_triple_grad inplace : (grad_x_grad_forward -> grad_out_forward_grad) +- backward_api : temporal_shift_grad + forward : temporal_shift(Tensor x, int seg_num, float shift_ratio, str data_format_str) -> Tensor(out) + args : (Tensor out_grad, int seg_num, float shift_ratio, str data_format_str) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [out_grad] + kernel : + func : temporal_shift_grad + - backward_api : thresholded_relu_grad forward : thresholded_relu (Tensor x, float threshold) -> Tensor(out) args : (Tensor x, Tensor out_grad, float threshold) diff --git a/python/paddle/fluid/tests/unittests/test_temporal_shift_op.py b/python/paddle/fluid/tests/unittests/test_temporal_shift_op.py index e9561b3e0a5d46def8fff1b2ca6170402bb0260a..0a9a6804b36d90ac4c3bc025755d4fca196da00f 100644 --- a/python/paddle/fluid/tests/unittests/test_temporal_shift_op.py +++ b/python/paddle/fluid/tests/unittests/test_temporal_shift_op.py @@ -46,6 +46,7 @@ class TestTemporalShift(OpTest): def setUp(self): self.initTestCase() self.op_type = 'temporal_shift' + self.python_api = paddle.nn.functional.temporal_shift x = np.random.random(self.x_shape).astype(self.dtype) self.attrs = { @@ -61,12 +62,13 @@ class TestTemporalShift(OpTest): output = temporal_shift(x, self.seg_num, self.shift_ratio, self.data_format) self.outputs = {"Out": output} + self.python_out_sig = ["Out"] def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def test_check_grad_ignore_uv(self): - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_eager=True) def initTestCase(self): self.x_shape = (6, 4, 4, 4) diff --git a/python/paddle/nn/functional/extension.py b/python/paddle/nn/functional/extension.py index 1bfa7f148838a7c7b3348d4bc404be69c3c238a0..6191d015e2a201a6c09f3c1af2baa1e1391309eb 100644 --- a/python/paddle/nn/functional/extension.py +++ b/python/paddle/nn/functional/extension.py @@ -366,6 +366,9 @@ def temporal_shift(x, seg_num, shift_ratio=0.25, name=None, data_format="NCHW"): if data_format not in ["NCHW", "NHWC"]: raise ValueError("Attr(data_format) should be 'NCHW' or 'NHWC'. " "Received Attr(data_format): {}.".format(data_format)) + if in_dygraph_mode(): + return _C_ops.final_state_temporal_shift(x, seg_num, shift_ratio, + data_format) if _non_static_mode(): return _C_ops.temporal_shift(x, 'seg_num', seg_num, 'shift_ratio', shift_ratio, 'data_format', data_format)