未验证 提交 3f991128 编写于 作者: A Ainavo 提交者: GitHub

[bug fix] fix fp16 dtype checking for diff op (#51736)

* add_fp16_for_diff

* fix doc_for_fp16
上级 e0007f31
...@@ -228,6 +228,36 @@ class TestDiffOpPreAppendAxis(TestDiffOp): ...@@ -228,6 +228,36 @@ class TestDiffOpPreAppendAxis(TestDiffOp):
self.append = np.array([[2, 3, 4, 7], [1, 3, 5, 6]]).astype('float32') self.append = np.array([[2, 3, 4, 7], [1, 3, 5, 6]]).astype('float32')
class TestDiffOpFp16(TestDiffOp):
def test_fp16_with_gpu(self):
paddle.enable_static()
if paddle.fluid.core.is_compiled_with_cuda():
place = paddle.CUDAPlace(0)
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
input = np.random.random([4, 4]).astype("float16")
x = paddle.static.data(
name="input", shape=[4, 4], dtype="float16"
)
exe = paddle.static.Executor(place)
out = paddle.diff(
x,
n=self.n,
axis=self.axis,
prepend=self.prepend,
append=self.append,
)
fetches = exe.run(
paddle.static.default_main_program(),
feed={
"input": input,
},
fetch_list=[out],
)
paddle.disable_static()
if __name__ == '__main__': if __name__ == '__main__':
paddle.enable_static() paddle.enable_static()
unittest.main() unittest.main()
...@@ -4611,7 +4611,7 @@ def diff(x, n=1, axis=-1, prepend=None, append=None, name=None): ...@@ -4611,7 +4611,7 @@ def diff(x, n=1, axis=-1, prepend=None, append=None, name=None):
Only n=1 is currently supported. Only n=1 is currently supported.
Args: Args:
x (Tensor): The input tensor to compute the forward difference on x (Tensor): The input tensor to compute the forward difference on, the data type is float16(GPU), float32, float64, bool, int32, int64.
n (int, optional): The number of times to recursively compute the difference. n (int, optional): The number of times to recursively compute the difference.
Only support n=1. Default:1 Only support n=1. Default:1
axis (int, optional): The axis to compute the difference along. Default:-1 axis (int, optional): The axis to compute the difference along. Default:-1
...@@ -4706,7 +4706,10 @@ def diff(x, n=1, axis=-1, prepend=None, append=None, name=None): ...@@ -4706,7 +4706,10 @@ def diff(x, n=1, axis=-1, prepend=None, append=None, name=None):
return _C_ops.subtract(input_back, input_front) return _C_ops.subtract(input_back, input_front)
else: else:
check_variable_and_dtype( check_variable_and_dtype(
x, 'x', ['float32', 'float64', 'bool', 'int32', 'int64'], 'diff' x,
'x',
['float16', 'float32', 'float64', 'bool', 'int32', 'int64'],
'diff',
) )
check_type(axis, 'axis', (int), 'diff') check_type(axis, 'axis', (int), 'diff')
helper = LayerHelper('diff', **locals()) helper = LayerHelper('diff', **locals())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册