未验证 提交 64f769d4 编写于 作者: H Haohongxiang 提交者: GitHub

[Dygraph] Remove unrequired UT cases of DP in eager mode (#41413)

* remove unrequired ut cases

* update

* fix bugs

* update
上级 6f4bd0ea
...@@ -20,6 +20,7 @@ from paddle import framework ...@@ -20,6 +20,7 @@ from paddle import framework
import paddle import paddle
from paddle.fluid import core from paddle.fluid import core
from paddle.fluid.dygraph.parallel import _split_tensors, sync_params_buffers, build_groups from paddle.fluid.dygraph.parallel import _split_tensors, sync_params_buffers, build_groups
from paddle.fluid.framework import in_dygraph_mode, _in_legacy_dygraph
from collections import OrderedDict from collections import OrderedDict
from .log_util import logger from .log_util import logger
...@@ -58,6 +59,30 @@ def _apply_collective_grads(parameters, comm_group): ...@@ -58,6 +59,30 @@ def _apply_collective_grads(parameters, comm_group):
_split_tensors(coalesced_grads_and_vars) _split_tensors(coalesced_grads_and_vars)
def _apply_collective_grads_eager(parameters, comm_group):
grad_var_set = set()
grad_vars = []
for param in parameters:
if param.trainable and (param._grad_ivar() is not None):
g_var = param._grad_ivar()
assert not g_var.is_sparse(
), "Now, it doesn't support sparse parameters"
grad_vars.append(g_var)
assert g_var not in grad_var_set
grad_var_set.add(g_var)
coalesced_grads_and_vars = build_groups(grad_vars, 128 * 1024 * 1024)
div_factor = 1.0 / comm_group.nranks
for coalesced_grad, _, _ in coalesced_grads_and_vars:
# need to div nranks
coalesced_grad.scale_(div_factor)
paddle.distributed.all_reduce(coalesced_grad, group=comm_group)
_split_tensors(coalesced_grads_and_vars)
def _broadcast_data_help(data, shape, dtype, hcg): def _broadcast_data_help(data, shape, dtype, hcg):
model_parallel_group = hcg.get_model_parallel_group() model_parallel_group = hcg.get_model_parallel_group()
src_rank = hcg.get_model_parallel_group_src_rank() src_rank = hcg.get_model_parallel_group_src_rank()
...@@ -115,10 +140,17 @@ def broadcast_dp_parameters(model, hcg): ...@@ -115,10 +140,17 @@ def broadcast_dp_parameters(model, hcg):
def fused_allreduce_gradients(parameter_list, hcg): def fused_allreduce_gradients(parameter_list, hcg):
data_parallel_group = None if hcg is None else hcg.get_data_parallel_group() if _in_legacy_dygraph():
data_parallel_group = None if hcg is None else hcg.get_data_parallel_group(
)
logger.debug("dp start fuse allreduce gradients") logger.debug("dp start fuse allreduce gradients")
with framework.no_grad(): with framework.no_grad():
_apply_collective_grads(parameter_list, data_parallel_group) _apply_collective_grads(parameter_list, data_parallel_group)
elif in_dygraph_mode():
assert hcg is None, "It's not support to use hcg in EagerDygraph now."
data_parallel_group = paddle.distributed.collective._get_default_group()
with framework.no_grad():
_apply_collective_grads_eager(parameter_list, data_parallel_group)
def sharding_reduce_gradients(parameter_list, hcg): def sharding_reduce_gradients(parameter_list, hcg):
......
...@@ -22,6 +22,7 @@ import warnings ...@@ -22,6 +22,7 @@ import warnings
from contextlib import contextmanager from contextlib import contextmanager
import paddle import paddle
from paddle import _C_ops
from paddle.fluid import core from paddle.fluid import core
from paddle.fluid import framework from paddle.fluid import framework
from paddle.fluid.dygraph import layers from paddle.fluid.dygraph import layers
...@@ -307,6 +308,7 @@ def _reshape_inplace(x, shape): ...@@ -307,6 +308,7 @@ def _reshape_inplace(x, shape):
@framework.dygraph_only @framework.dygraph_only
def _split_tensors(coalesced_grads_and_grad_vars): def _split_tensors(coalesced_grads_and_grad_vars):
if _in_legacy_dygraph():
for coalesced_grad, origin_grad_vars, grad_shapes in coalesced_grads_and_grad_vars: for coalesced_grad, origin_grad_vars, grad_shapes in coalesced_grads_and_grad_vars:
grad_var_len = [np.prod(g_shape) for g_shape in grad_shapes] grad_var_len = [np.prod(g_shape) for g_shape in grad_shapes]
framework._dygraph_tracer().trace_op( framework._dygraph_tracer().trace_op(
...@@ -318,6 +320,16 @@ def _split_tensors(coalesced_grads_and_grad_vars): ...@@ -318,6 +320,16 @@ def _split_tensors(coalesced_grads_and_grad_vars):
for g_var, g_shape in zip(origin_grad_vars, grad_shapes): for g_var, g_shape in zip(origin_grad_vars, grad_shapes):
_reshape_inplace(x=g_var, shape=g_shape) _reshape_inplace(x=g_var, shape=g_shape)
assert g_var.shape == g_shape assert g_var.shape == g_shape
elif in_dygraph_mode():
for coalesced_grad, origin_grad_vars, grad_shapes in coalesced_grads_and_grad_vars:
grad_var_len = [np.prod(g_shape) for g_shape in grad_shapes]
attrs = ()
attrs += ('sections', grad_var_len)
attrs += ('axis', 0)
_C_ops.split(coalesced_grad, origin_grad_vars, *attrs)
for g_var, g_shape in zip(origin_grad_vars, grad_shapes):
g_var.reshape_(shape=g_shape)
assert g_var.shape == g_shape
def scale_loss(loss): def scale_loss(loss):
......
...@@ -21,7 +21,8 @@ import paddle ...@@ -21,7 +21,8 @@ import paddle
import numpy as np import numpy as np
import paddle.distributed as dist import paddle.distributed as dist
from paddle.fluid.dygraph.nn import Linear from paddle.fluid.dygraph.nn import Linear
from paddle.autograd import PyLayer from paddle.autograd import PyLayer, EagerPyLayer
from paddle.fluid.framework import in_dygraph_mode, _in_legacy_dygraph
from paddle.distributed.fleet.utils.hybrid_parallel_util import fused_allreduce_gradients from paddle.distributed.fleet.utils.hybrid_parallel_util import fused_allreduce_gradients
batch = 5 batch = 5
...@@ -43,6 +44,20 @@ class cus_tanh(PyLayer): ...@@ -43,6 +44,20 @@ class cus_tanh(PyLayer):
return grad return grad
class cus_tanh_eager(EagerPyLayer):
@staticmethod
def forward(ctx, x):
y = paddle.tanh(x)
ctx.save_for_backward(y)
return y
@staticmethod
def backward(ctx, dy):
y, = ctx.saved_tensor()
grad = dy * (1 - paddle.square(y))
return grad
class SimpleNet(paddle.nn.Layer): class SimpleNet(paddle.nn.Layer):
def __init__(self, train_id, model_id): def __init__(self, train_id, model_id):
super(SimpleNet, self).__init__() super(SimpleNet, self).__init__()
...@@ -55,6 +70,9 @@ class SimpleNet(paddle.nn.Layer): ...@@ -55,6 +70,9 @@ class SimpleNet(paddle.nn.Layer):
def forward(self, inputs): def forward(self, inputs):
if self.model_id == 0: if self.model_id == 0:
if in_dygraph_mode():
inputs = cus_tanh_eager.apply(inputs)
elif _in_legacy_dygraph():
inputs = cus_tanh.apply(inputs) inputs = cus_tanh.apply(inputs)
else: else:
inputs = self.tanh(inputs) inputs = self.tanh(inputs)
......
...@@ -23,6 +23,7 @@ import os ...@@ -23,6 +23,7 @@ import os
import subprocess import subprocess
from paddle.distributed.utils import find_free_ports, watch_local_trainers, get_cluster, TrainerProc from paddle.distributed.utils import find_free_ports, watch_local_trainers, get_cluster, TrainerProc
from paddle.fluid.framework import _test_eager_guard
def get_cluster_from_args(selected_gpus): def get_cluster_from_args(selected_gpus):
...@@ -205,6 +206,8 @@ class TestDataParallelGradientCheck(TestMultipleGpus): ...@@ -205,6 +206,8 @@ class TestDataParallelGradientCheck(TestMultipleGpus):
class TestDataParallelWithPyLayer(TestMultipleGpus): class TestDataParallelWithPyLayer(TestMultipleGpus):
def test_parallel_dygraph_dataparallel_with_pylayer(self): def test_parallel_dygraph_dataparallel_with_pylayer(self):
with _test_eager_guard():
self.run_mnist_2gpu('parallel_dygraph_dataparallel_with_pylayer.py')
self.run_mnist_2gpu('parallel_dygraph_dataparallel_with_pylayer.py') self.run_mnist_2gpu('parallel_dygraph_dataparallel_with_pylayer.py')
......
...@@ -55,35 +55,5 @@ class TestParallelDygraphSparseEmdeddingFP64_GLOO(TestDistBase): ...@@ -55,35 +55,5 @@ class TestParallelDygraphSparseEmdeddingFP64_GLOO(TestDistBase):
log_name=flag_name) log_name=flag_name)
class TestParallelDygraphSparseEmdeddingEager_GLOO(TestDistBase):
def _setup_config(self):
self._sync_mode = False
self._eager_mode = True
self._gloo_mode = True
self._dygraph = True
def test_sparse_embedding(self):
self.check_with_place(
"parallel_dygraph_sparse_embedding.py",
delta=1e-5,
check_error_log=True,
log_name=flag_name)
class TestParallelDygraphSparseEmdeddingEagerFP64_GLOO(TestDistBase):
def _setup_config(self):
self._sync_mode = False
self._eager_mode = True
self._gloo_mode = True
self._dygraph = True
def test_sparse_embedding_fp64(self):
self.check_with_place(
"parallel_dygraph_sparse_embedding_fp64.py",
delta=1e-5,
check_error_log=True,
log_name=flag_name)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -40,20 +40,5 @@ class TestParallelDygraphSparseEmdeddingOverHeight_GLOO(TestDistBase): ...@@ -40,20 +40,5 @@ class TestParallelDygraphSparseEmdeddingOverHeight_GLOO(TestDistBase):
log_name=flag_name) log_name=flag_name)
class TestParallelDygraphSparseEmdeddingOverHeightEager_GLOO(TestDistBase):
def _setup_config(self):
self._sync_mode = False
self._eager_mode = True
self._gloo_mode = True
self._dygraph = True
def test_sparse_embedding(self):
self.check_with_place(
"parallel_dygraph_sparse_embedding_over_height.py",
delta=1e-7,
check_error_log=True,
log_name=flag_name)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -57,20 +57,5 @@ class TestParallelDygraphTransformerAccGrad_GLOO(TestDistBase): ...@@ -57,20 +57,5 @@ class TestParallelDygraphTransformerAccGrad_GLOO(TestDistBase):
log_name=flag_name) log_name=flag_name)
class TestParallelDygraphTransformerEager_GLOO(TestDistBase):
def _setup_config(self):
self._sync_mode = False
self._eager_mode = True
self._gloo_mode = True
self._dygraph = True
def test_transformer(self):
self.check_with_place(
"parallel_dygraph_transformer.py",
delta=1e-5,
check_error_log=True,
log_name=flag_name)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册