“74836ec7b7406bd26e2f52daa31f23478d265307”上不存在“...paddle/fluid/tests/unittests/test_adaptive_max_pool3d.py”
未验证 提交 3b8f8b6c 编写于 作者: W Weilong Wu 提交者: GitHub

[Eager] Remove redundancy code, fix fp16 case (#42169)

上级 4a16d5c6
...@@ -353,7 +353,6 @@ class NormalInitializer(Initializer): ...@@ -353,7 +353,6 @@ class NormalInitializer(Initializer):
out_var = _C_ops.final_state_gaussian_random( out_var = _C_ops.final_state_gaussian_random(
var.shape, self._mean, self._std_dev, self._seed, out_dtype, var.shape, self._mean, self._std_dev, self._seed, out_dtype,
place) place)
out_var._share_underline_tensor_to(var)
if var.dtype in [VarDesc.VarType.FP16, VarDesc.VarType.BF16]: if var.dtype in [VarDesc.VarType.FP16, VarDesc.VarType.BF16]:
var_tmp = _C_ops.final_state_cast(out_var, var.dtype) var_tmp = _C_ops.final_state_cast(out_var, var.dtype)
......
...@@ -19,6 +19,7 @@ import numpy as np ...@@ -19,6 +19,7 @@ import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear
from paddle.fluid.framework import _test_eager_guard
class SimpleImgConvPool(fluid.dygraph.Layer): class SimpleImgConvPool(fluid.dygraph.Layer):
...@@ -117,7 +118,7 @@ class MNIST(fluid.dygraph.Layer): ...@@ -117,7 +118,7 @@ class MNIST(fluid.dygraph.Layer):
class TestMnist(unittest.TestCase): class TestMnist(unittest.TestCase):
def test_mnist_fp16(self): def func_mnist_fp16(self):
if not fluid.is_compiled_with_cuda(): if not fluid.is_compiled_with_cuda():
return return
x = np.random.randn(1, 3, 224, 224).astype("float16") x = np.random.randn(1, 3, 224, 224).astype("float16")
...@@ -129,6 +130,11 @@ class TestMnist(unittest.TestCase): ...@@ -129,6 +130,11 @@ class TestMnist(unittest.TestCase):
loss = model(x, y) loss = model(x, y)
print(loss.numpy()) print(loss.numpy())
def test_mnist_fp16(self):
with _test_eager_guard():
self.func_mnist_fp16()
self.func_mnist_fp16()
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册