未验证 提交 3b8f8b6c 编写于 作者: W Weilong Wu 提交者: GitHub

[Eager] Remove redundancy code, fix fp16 case (#42169)

上级 4a16d5c6
......@@ -353,7 +353,6 @@ class NormalInitializer(Initializer):
out_var = _C_ops.final_state_gaussian_random(
var.shape, self._mean, self._std_dev, self._seed, out_dtype,
place)
out_var._share_underline_tensor_to(var)
if var.dtype in [VarDesc.VarType.FP16, VarDesc.VarType.BF16]:
var_tmp = _C_ops.final_state_cast(out_var, var.dtype)
......
......@@ -19,6 +19,7 @@ import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear
from paddle.fluid.framework import _test_eager_guard
class SimpleImgConvPool(fluid.dygraph.Layer):
......@@ -117,7 +118,7 @@ class MNIST(fluid.dygraph.Layer):
class TestMnist(unittest.TestCase):
def test_mnist_fp16(self):
def func_mnist_fp16(self):
if not fluid.is_compiled_with_cuda():
return
x = np.random.randn(1, 3, 224, 224).astype("float16")
......@@ -129,6 +130,11 @@ class TestMnist(unittest.TestCase):
loss = model(x, y)
print(loss.numpy())
def test_mnist_fp16(self):
with _test_eager_guard():
self.func_mnist_fp16()
self.func_mnist_fp16()
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册