Unverified commit fbb40281, authored by zhangbo9674, committed by GitHub

[AMP] Check call order of paddle.amp.decorate and paddle.DataParallel (#38785)

* check amp.decorate and DataParallel

* refine coverage

* fix layer dtype

* refine code
Parent 9f34a070
......@@ -145,6 +145,10 @@ def check_models(models):
            raise RuntimeError(
                "Current train mode is pure fp16, models should be paddle.nn.Layer, but received {}.".
                format(type(model)))
        if isinstance(model, paddle.DataParallel):
            raise RuntimeError(
                "For distributed AMP training, you should first use paddle.amp.decorate() to decorate the origin model, and then call paddle.DataParallel to get the distributed model."
            )


def check_optimizers(optimizers):
......
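The check above rejects models that are already wrapped in paddle.DataParallel before AMP decoration. A minimal sketch of the call order the check enforces (illustrative only, not part of this patch; a real multi-card run would also call paddle.distributed.init_parallel_env() first):

```python
import paddle

# Hypothetical single-card example of the enforced order.
model = paddle.nn.Linear(10, 10)
optimizer = paddle.optimizer.SGD(parameters=model.parameters())

# 1) Decorate the raw model/optimizer for pure fp16 (O2) training first.
model, optimizer = paddle.amp.decorate(
    models=model, optimizers=optimizer, level='O2')

# 2) Only after decoration, wrap the model for distributed training.
# model = paddle.DataParallel(model)  # the reversed order now raises RuntimeError
```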
......@@ -1569,6 +1569,8 @@ class Layer(object):
        for key, buf in self._buffers.items():
            self._buffers[key] = func(buf, device, dtype, blocking)

        self._dtype = dtype

    def _to_impl(self,
                 device=None,
                 dtype=None,
......
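This hunk keeps the layer's `_dtype` attribute in sync when its parameters and buffers are converted. A rough usage sketch (assuming a Paddle build that includes this fix):

```python
import paddle

linear = paddle.nn.Linear(4, 4)
linear.to(dtype='float16')  # converts parameters and buffers
print(linear._dtype)        # with this fix, reports 'float16'
```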
......@@ -536,6 +536,14 @@ class TestAmpDecorator(unittest.TestCase):
        self.assertRaises(TypeError, test_error_model)

        def test_error_distributed_model():
            model = fluid.dygraph.Conv2D(3, 2, 3, bias_attr=False, act=None)
            model = paddle.DataParallel(model)
            with fluid.dygraph.guard():
                model = paddle.amp.decorate(models=model, level='O2')

        self.assertRaises(RuntimeError, test_error_distributed_model)

        def test_error_optimizer():
            class MyOptimizer(object):
                def __init__(self):
......