未验证 提交 0243c6ca 编写于 作者: C ccrrong 提交者: GitHub

[Phi] add temporal_shift yaml (#44409)

* add temporal_shift yaml and unittest
上级 438ca7f6
......@@ -2154,6 +2154,16 @@
func : tanh_shrink
backward : tanh_shrink_grad
# temporal_shift
- api : temporal_shift
args : (Tensor x, int seg_num, float shift_ratio, str data_format_str)
output : Tensor
infer_meta :
func : TemporalShiftInferMeta
kernel :
func : temporal_shift
backward : temporal_shift_grad
# thresholded_relu
- api : thresholded_relu
args : (Tensor x, float threshold)
......
......@@ -2173,6 +2173,16 @@
func : tanh_triple_grad
inplace : (grad_x_grad_forward -> grad_out_forward_grad)
- backward_api : temporal_shift_grad
forward : temporal_shift(Tensor x, int seg_num, float shift_ratio, str data_format_str) -> Tensor(out)
args : (Tensor out_grad, int seg_num, float shift_ratio, str data_format_str)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [out_grad]
kernel :
func : temporal_shift_grad
- backward_api : thresholded_relu_grad
forward : thresholded_relu (Tensor x, float threshold) -> Tensor(out)
args : (Tensor x, Tensor out_grad, float threshold)
......
......@@ -46,6 +46,7 @@ class TestTemporalShift(OpTest):
def setUp(self):
self.initTestCase()
self.op_type = 'temporal_shift'
self.python_api = paddle.nn.functional.temporal_shift
x = np.random.random(self.x_shape).astype(self.dtype)
self.attrs = {
......@@ -61,12 +62,13 @@ class TestTemporalShift(OpTest):
output = temporal_shift(x, self.seg_num, self.shift_ratio,
self.data_format)
self.outputs = {"Out": output}
self.python_out_sig = ["Out"]
def test_check_output(self):
self.check_output()
self.check_output(check_eager=True)
def test_check_grad_ignore_uv(self):
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', check_eager=True)
def initTestCase(self):
self.x_shape = (6, 4, 4, 4)
......
......@@ -366,6 +366,9 @@ def temporal_shift(x, seg_num, shift_ratio=0.25, name=None, data_format="NCHW"):
if data_format not in ["NCHW", "NHWC"]:
raise ValueError("Attr(data_format) should be 'NCHW' or 'NHWC'. "
"Received Attr(data_format): {}.".format(data_format))
if in_dygraph_mode():
return _C_ops.final_state_temporal_shift(x, seg_num, shift_ratio,
data_format)
if _non_static_mode():
return _C_ops.temporal_shift(x, 'seg_num', seg_num, 'shift_ratio',
shift_ratio, 'data_format', data_format)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册