Unverified commit 0df9e4ce, authored by Charles-hit, committed by GitHub

[AMP Prim OP] Support bfloat16 dtype for dropout prim ops (#54175)

* fix dropout api and support bf16 for prim

* fix code style

* fix dropout test

* fix dropout p = 0 test
Parent d4451cb0
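The recurring pattern in the test changes below: NumPy has no native bfloat16 dtype, so for the 'bfloat16' cases the host data stays in float32, the Paddle tensor is built from that float32 array, and the forward output and gradient are cast back to float32 before results from the different modes are compared. A minimal eager-mode sketch of that pattern, assuming recent Paddle; the helper name run_dropout_case is made up for illustration, the real checks live inside TestCompositeDropout:

import numpy as np
import paddle

def run_dropout_case(x_np, p, dtype):
    # Host data stays float32; only the test's dtype flag says "bfloat16".
    input_ = paddle.to_tensor(
        data=x_np.astype("float32") if dtype == "bfloat16" else x_np.astype(dtype),
        stop_gradient=False,
    )
    out = paddle.nn.functional.dropout(input_, p, training=True)
    (grad,) = paddle.grad(out, input_)
    if dtype == "bfloat16":
        # Cast back so results can be compared as float32 numpy arrays.
        out = paddle.cast(out, "float32")
        grad = paddle.cast(grad, "float32")
    return out.numpy(), grad.numpy()

fwd, rev = run_dropout_case(np.random.rand(100000), 0.3, "bfloat16")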
@@ -311,10 +311,6 @@ class TestFP16DropoutOp(OpTest):
             'is_test': True,
         }
         self.outputs = {'Out': out}
-        # Because prim op compare res with dygraph
-        # when p = 0 dropout api return x,in dygraph mode x_grad = out_grad,
-        # but in static mode x_grad = []
-        self.enable_check_static_comp = False

     def init_test_case(self):
         self.input_size = [32, 64]
@@ -362,22 +358,10 @@ class TestBF16DropoutOp(OpTest):
         }

     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_prim=True)

     def test_check_grad_normal(self):
-        self.check_grad(['X'], 'Out')
-
-    def test_check_output_for_prim(self):
-        # greater_equal does't support bfloat16 in cpu
-        if core.is_compiled_with_cuda():
-            self.check_output_with_place(core.CUDAPlace(0))
-
-    def test_check_grad_for_prim(self):
-        # greater_equal does't support bfloat16 in cpu
-        if core.is_compiled_with_cuda():
-            self.check_grad_with_place(
-                core.CUDAPlace(0), ['X'], 'Out', only_check_prim=True
-            )
+        self.check_grad(['X'], 'Out', check_prim=True)

 class TestDropoutOpWithSeedOnCPUPlace(unittest.TestCase):
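For context, check_prim=True folds the prim (composite) comparison into the regular OpTest flow, which is why the hand-written, CUDA-only test_check_output_for_prim / test_check_grad_for_prim helpers above can be dropped. A rough sketch of the resulting test shape, assuming the OpTest base class and the bfloat16 setUp already defined in this file (imports and setUp are elided; they are unchanged by this commit):

class TestBF16DropoutOp(OpTest):
    # setUp() keeps configuring op_type = "dropout" with bfloat16 inputs/attrs.

    def test_check_output(self):
        # check_prim=True also exercises the composite (prim) forward lowering.
        self.check_output(check_prim=True)

    def test_check_grad_normal(self):
        # check_prim=True also exercises the composite backward path.
        self.check_grad(['X'], 'Out', check_prim=True)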
@@ -1451,8 +1435,9 @@ class PrimNet(paddle.nn.Layer):
         training=True,
         mode="upscale_in_train",
     ):
+        y = paddle.assign(x)
         out = paddle.nn.functional.dropout(
-            x=x, p=p, axis=axis, training=training, mode=mode
+            x=y, p=p, axis=axis, training=training, mode=mode
         )
         return out
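The added paddle.assign(x) relates to the p = 0 corner case noted in the comment removed in the first hunk: with p = 0, dropout returns its input unchanged, so in static mode the output can be the program input itself and paddle.static.gradients(output, input_) comes back empty, while dygraph returns out_grad for x_grad. Routing the input through assign guarantees at least one op between input and output, so both modes yield a gradient and the static-comp check no longer has to be skipped. A minimal static-graph sketch of the workaround, with illustrative shapes and names only:

import paddle

paddle.enable_static()
main, startup = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main, startup):
    x = paddle.static.data('x', shape=[4, 4], dtype='float32')
    x.stop_gradient = False
    # Without this copy, dropout with p=0 may hand back `x` itself and the
    # gradient list below can be empty in static mode.
    y = paddle.assign(x)
    out = paddle.nn.functional.dropout(y, p=0.0, training=True)
    grads = paddle.static.gradients(out, x)  # now yields a grad var for x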
@@ -1476,6 +1461,16 @@ def apply_to_static(net, use_cinn):
             'float32',
             places,
         ),
+        (
+            'bfp16',
+            np.random.rand(100000),
+            0.3,
+            False,
+            'upscale_in_train',
+            1002,
+            'bfloat16',
+            places,
+        ),
         (
             'fp64',
             np.random.rand(100000),
@@ -1506,6 +1501,16 @@ def apply_to_static(net, use_cinn):
             'float32',
             places,
         ),
+        (
+            'p=1.0,dtype=bfp16',
+            np.random.rand(100000),
+            1.0,
+            True,
+            'upscale_in_train',
+            1002,
+            'bfloat16',
+            places,
+        ),
         (
             'p=1.0,test=False',
             np.random.rand(100000),
@@ -1517,15 +1522,35 @@ def apply_to_static(net, use_cinn):
             places,
         ),
         (
-            'p=0.0',
+            'p=1.0,test=False,dtype=bfp16',
             np.random.rand(100000),
             1.0,
+            False,
+            'upscale_in_train',
+            1002,
+            'bfloat16',
+            places,
+        ),
+        (
+            'p=0.0',
+            np.random.rand(100000),
+            0,
             True,
             'upscale_in_train',
             1002,
             'float32',
             places,
         ),
+        (
+            'p=0.0,dtype=bfp16',
+            np.random.rand(100000),
+            0,
+            True,
+            'upscale_in_train',
+            1002,
+            'bfloat16',
+            places,
+        ),
         (
             'downgrade_train',
             np.random.rand(100000),
@@ -1536,6 +1561,16 @@ def apply_to_static(net, use_cinn):
             'float32',
             places,
         ),
+        (
+            'downgrade_train,dtype=bfp16',
+            np.random.rand(100000),
+            0.5,
+            False,
+            'downscale_in_infer',
+            1002,
+            'bfloat16',
+            places,
+        ),
         (
             'fp32_cpu',
             np.random.rand(100000),
@@ -1571,7 +1606,11 @@ class TestCompositeDropout(unittest.TestCase):
 class TestCompositeDropout(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
-        cls.x = cls.x.astype(cls.dtype)
+        cls.x = (
+            cls.x.astype(cls.dtype)
+            if cls.dtype != "bfloat16"
+            else cls.x.astype("float32")
+        )
         core._set_prim_all_enabled(True)

     @classmethod
@@ -1596,12 +1635,18 @@ class TestCompositeDropout(unittest.TestCase):
             paddle.set_device("gpu")
         core.set_prim_eager_enabled(False)
         input_ = paddle.to_tensor(
-            data=self.x, dtype=self.dtype, place=place, stop_gradient=False
+            data=self.x,
+            dtype=self.dtype if self.dtype != "bfloat16" else "float32",
+            place=place,
+            stop_gradient=False,
         )
         output = paddle.nn.functional.dropout(
             input_, self.p, training=(not self.is_test), mode=self.mode
         )
         grad = paddle.grad(output, input_)
+        if self.dtype == "bfloat16":
+            output = paddle.cast(output, "float32")
+            grad[0] = paddle.cast(grad[0], "float32")
         return output, grad[0]

     def test_static_comp(self):
@@ -1614,11 +1659,16 @@ class TestCompositeDropout(unittest.TestCase):
             mp, sp = paddle.static.Program(), paddle.static.Program()
             with paddle.static.program_guard(mp, sp):
                 input_ = paddle.static.data(
-                    'x', shape=self.x.shape, dtype=self.x.dtype
+                    'x',
+                    shape=self.x.shape,
+                    dtype=self.x.dtype
+                    if self.dtype != "bfloat16"
+                    else "float32",
                 )
                 input_.stop_gradient = False
+                y = paddle.assign(input_)
                 output = paddle.nn.functional.dropout(
-                    input_,
+                    y,
                     self.p,
                     training=(not self.is_test),
                     mode=self.mode,
@@ -1626,6 +1676,9 @@ class TestCompositeDropout(unittest.TestCase):
                 if core._is_fwd_prim_enabled():
                     primapi.to_prim(mp.blocks)
                 grad = paddle.static.gradients(output, input_)[0]
+                if self.dtype == "bfloat16":
+                    output = paddle.cast(output, "float32")
+                    grad = paddle.cast(grad, "float32")
                 exe = paddle.static.Executor(place)
                 exe.run(sp)
                 fwd, rev = exe.run(
@@ -1662,7 +1715,10 @@ class TestCompositeDropout(unittest.TestCase):
                 paddle.set_device("gpu")
             paddle.seed(self.seed)
             input_ = paddle.to_tensor(
-                data=self.x, dtype=self.dtype, place=place, stop_gradient=False
+                data=self.x,
+                dtype=self.dtype if self.dtype != "bfloat16" else "float32",
+                place=place,
+                stop_gradient=False,
             )
             net = PrimNet()
             net = apply_to_static(net, False)
@@ -1670,6 +1726,9 @@ class TestCompositeDropout(unittest.TestCase):
                 input_, self.p, training=(not self.is_test), mode=self.mode
             )
             grad = paddle.grad(output, input_)
+            if self.dtype == "bfloat16":
+                output = paddle.cast(output, "float32")
+                grad[0] = paddle.cast(grad[0], "float32")
             fwd_actual.append(output.numpy())
             rev_actual.append(grad[0].numpy())
         for i in range(len(self.places)):
@@ -1696,7 +1755,10 @@ class TestCompositeDropout(unittest.TestCase):
                 paddle.set_device("gpu")
             paddle.seed(self.seed)
             input_ = paddle.to_tensor(
-                data=self.x, dtype=self.dtype, place=place, stop_gradient=False
+                data=self.x,
+                dtype=self.dtype if self.dtype != "bfloat16" else "float32",
+                place=place,
+                stop_gradient=False,
             )
             net = PrimNet()
             net = apply_to_static(net, True)
@@ -1704,6 +1766,9 @@ class TestCompositeDropout(unittest.TestCase):
                 input_, self.p, training=(not self.is_test), mode=self.mode
             )
             grad = paddle.grad(output, input_)
+            if self.dtype == "bfloat16":
+                output = paddle.cast(output, "float32")
+                grad[0] = paddle.cast(grad[0], "float32")
             fwd_actual.append(output.numpy())
             rev_actual.append(grad[0].numpy())
         i = 0
......