Unverified · Commit 7ea999fd · authored by ShenLiang, committed by GitHub

[HybridParallel] Add ClipGradByGlobalNorm & check_finite_and_unscale in Dygraph (#32354)

* add clip/check

* add amp & clip grad in dygraph

* add logging
Parent b2ee8380
@@ -74,3 +74,4 @@ state_dict = fleet.state_dict
 set_state_dict = fleet.set_state_dict
 shrink = fleet.shrink
 get_hybrid_communicate_group = fleet.get_hybrid_communicate_group
+distributed_scaler = fleet.distributed_scaler
@@ -30,6 +30,7 @@ from . import topology as tp
 from .topology import ParallelMode
 from ..meta_parallel import ModelParallel
 from ..meta_optimizers import HybridParallelOptimizer
+from ..meta_optimizers import HybridParallelGradScaler


 def _inited_runtime_handler_(func):
@@ -1333,3 +1334,7 @@ class Fleet(object):
             fleet.util._set_strategy(context["valid_strategy"])

         return optimize_ops, params_grads
+
+    @dygraph_only
+    def distributed_scaler(self, scaler):
+        return HybridParallelGradScaler(scaler, self._hcg)
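The new `distributed_scaler` entry point simply wraps an ordinary `paddle.amp.GradScaler` in the hybrid-parallel-aware `HybridParallelGradScaler` defined later in this commit. A minimal usage sketch, modeled on the `hybrid_parallel_mp_amp.py` test added below (the `model`, `optimizer`, and `batch` arguments are assumed to come from `fleet.distributed_model` / `fleet.distributed_optimizer` and a data loader, as in that test):

```python
import paddle
import paddle.distributed.fleet as fleet


def amp_train_step(model, optimizer, batch):
    # wrap the regular GradScaler so found_inf is synchronized across ranks
    scaler = fleet.distributed_scaler(
        paddle.amp.GradScaler(init_loss_scaling=5160))

    with paddle.amp.auto_cast():
        loss = model(batch).mean()

    scaled = scaler.scale(loss)          # scale the loss
    scaled.backward()                    # backward on the scaled loss
    scaler.minimize(optimizer, scaled)   # unscale, check inf/nan, step, update scale
    optimizer.clear_grad()
    return loss
```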
@@ -19,6 +19,8 @@ import collections
 import numpy as np
 from itertools import product
 from functools import reduce
+from ..utils.log_util import logger

 __all__ = ['CommunicateTopology', 'HybridCommunicateGroup']
 _HYBRID_PARALLEL_GROUP = None
@@ -129,12 +131,17 @@ class HybridCommunicateGroup(object):
         # create comm group for model parallel
         self._mp_group, self._mp_comm_group = self._set_comm_group("model")

+        # create global group for check inf_nan / clip global norm
+        self._check_group, self._check_comm_group = self._set_check_group(
+            "data")
+
         debug_str = "HybridParallelInfo: rank_id: %d, dp_degree: %d, " \
                     "mp_degree: %d, pp_degree: %d\n" % (self.global_rank, self._dp_degree,
                     self._mp_degree,self._pp_degree)
-        debug_str += "dp_group: %s, mp_group: %s" % (self._dp_group,
-                                                     self._mp_group)
-        print(debug_str, file=sys.stderr)
+        debug_str += "dp_group: %s, mp_group: %s, check/clip group: %s" % (
+            self._dp_group, self._mp_group, self._check_group)
+        logger.info(debug_str)

         global _HYBRID_PARALLEL_GROUP
         _HYBRID_PARALLEL_GROUP = self
@@ -168,6 +175,22 @@ class HybridCommunicateGroup(object):

         return parallel_group, parallel_comm_group

+    def _set_check_group(self, parallel_method="data"):
+        parallel_group = []
+        parallel_comm_group = None
+        parallel_size = self._topo.get_dim(parallel_method)
+        for idx in range(parallel_size):
+            parallel_groups = self._topo.get_axis_list(parallel_method, idx)
+            comm_group = paddle.distributed.new_group(ranks=parallel_groups)
+            if self.global_rank in parallel_groups:
+                parallel_group = parallel_groups
+                parallel_comm_group = comm_group
+
+        assert len(parallel_group) > 0
+        assert parallel_comm_group is not None
+
+        return parallel_group, parallel_comm_group
     def topology(self):
         return self._topo

@@ -205,3 +228,7 @@ class HybridCommunicateGroup(object):
     def get_model_parallel_group_src_rank(self):
         return self._mp_comm_group.ranks[0]
+
+    # check parallel group
+    def get_check_parallel_group(self):
+        return self._check_comm_group
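The check/clip group added above contains, for each data-parallel index, every rank that together holds one complete replica of the model (i.e. all model-parallel shards of that replica); the inf/nan flag and the global gradient norm are all-reduced over this group. A conceptual sketch of that partition in plain Python (illustrative only, not the Paddle topology API; the row-major rank layout is an assumption, the real mapping comes from `CommunicateTopology`):

```python
from itertools import product

# assume 2-way data parallel x 2-way model parallel, ranks assigned in
# row-major order over (data, model) coordinates
dp_degree, mp_degree = 2, 2
coord_to_rank = {coord: rank
                 for rank, coord in enumerate(product(range(dp_degree), range(mp_degree)))}

# one check/clip group per data-parallel index: all model-parallel shards of one replica
check_groups = [
    [coord_to_rank[(dp, mp)] for mp in range(mp_degree)]
    for dp in range(dp_degree)
]
print(check_groups)  # [[0, 1], [2, 3]] under this layout assumption
```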
@@ -26,3 +26,4 @@ from .lamb_optimizer import LambOptimizer
 from .fp16_allreduce_optimizer import FP16AllReduceOptimizer
 from .sharding_optimizer import ShardingOptimizer
 from .dygraph_optimizer import HybridParallelOptimizer
+from .dygraph_optimizer import HybridParallelGradScaler
@@ -11,3 +11,4 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 from .hybrid_parallel_optimizer import HybridParallelOptimizer
+from .hybrid_parallel_gradscaler import HybridParallelGradScaler
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import sys
from paddle.optimizer import Optimizer
from ...base.topology import ParallelMode
from paddle.fluid.dygraph import base as imperative_base
from paddle.fluid import framework
from paddle.fluid.framework import Variable
import types
from paddle.fluid import core
import paddle

class HybridParallelGradScaler:
    def __init__(self, scaler, hcg):
        self._scaler = scaler
        self._hcg = hcg
        self._is_mp = (
            self._hcg.get_parallel_mode() == ParallelMode.MODEL_PARALLEL)

    def scale(self, var):
        return self._scaler.scale(var)

    def minimize(self, optimizer, *args, **kwargs):
        if not self._enable:
            return optimizer.minimize(*args, **kwargs)

        # unscale the grad
        self._unscale(optimizer)

        optimize_ops, params_grads = (None, None)

        if self._found_inf:
            self._cache_founf_inf = True
        else:
            optimize_ops, params_grads = optimizer.minimize(*args, **kwargs)
            self._cache_founf_inf = False

        if self._use_dynamic_loss_scaling:
            self._update()

        return optimize_ops, params_grads

    @imperative_base.no_grad
    def _unscale(self, optimizer):
        if not self._enable:
            return
        param_grads = [
            param._grad_ivar() for param in optimizer._parameter_list
            if param._grad_ivar() is not None
        ]
        core.ops.check_finite_and_unscale(param_grads, self._scale, param_grads,
                                          self._found_inf)
        # allreduce_max found_inf in check_group
        if self._is_mp:
            self._found_inf = paddle.cast(self._found_inf, dtype="int64")
            paddle.distributed.all_reduce(
                self._found_inf,
                op=paddle.distributed.ReduceOp.MAX,
                group=self._hcg.get_check_parallel_group())
            self._found_inf = paddle.cast(self._found_inf, dtype="bool")

    def __getattr__(self, item):
        return getattr(self._scaler, item)
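`minimize` defers `_update()` to the wrapped `paddle.amp.GradScaler`, so after the cross-rank inf/nan check the usual dynamic loss-scaling policy decides the next scale. A generic sketch of such a policy (illustrative only; the parameter names and defaults here are assumptions, not Paddle's exact implementation):

```python
def update_loss_scale(scale, found_inf, state,
                      incr_every_n_steps=1000, incr_ratio=2.0, decr_ratio=0.5):
    """Generic dynamic loss scaling: shrink on overflow, grow after a clean streak."""
    if found_inf:
        state["good_steps"] = 0
        scale *= decr_ratio          # overflow: reduce the scale (the step was skipped)
    else:
        state["good_steps"] += 1
        if state["good_steps"] >= incr_every_n_steps:
            scale *= incr_ratio      # long clean streak: try a larger scale
            state["good_steps"] = 0
    return scale
```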
@@ -12,15 +12,77 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import print_function
+import sys
 from paddle.optimizer import Optimizer
+from paddle.fluid.clip import ClipGradByGlobalNorm
 from ...utils.hybrid_parallel_util import fused_allreduce_gradients
 from ...base.topology import ParallelMode
 from paddle.fluid.dygraph import base as imperative_base
 from paddle.fluid import framework
 from paddle.fluid.framework import Variable
+from ...utils.log_util import logger
+class HybridParallelClipGrad:
+    def __init__(self, clip, hcg):
+        self._clip = clip
+        self._hcg = hcg
+
+    @imperative_base.no_grad
+    def _dygraph_clip(self, params_grads):
+        params_and_grads = []
+        sum_square_list = []
+        for p, g in params_grads:
+            if g is None:
+                continue
+            if getattr(p, 'need_clip', True) is False:
+                continue
+            merge_grad = g
+            if g.type == core.VarDesc.VarType.SELECTED_ROWS:
+                merge_grad = layers.merge_selected_rows(g)
+                merge_grad = layers.get_tensor_from_selected_rows(merge_grad)
+            square = layers.square(merge_grad)
+            sum_square = layers.reduce_sum(square)
+            sum_square_list.append(sum_square)
+
+        # all parameters have been filtered out
+        if len(sum_square_list) == 0:
+            return params_grads
+
+        global_norm_var = layers.concat(sum_square_list)
+        global_norm_var = layers.reduce_sum(global_norm_var)
+        # add all reduce to get global norm in world size
+        paddle.distributed.all_reduce(global_norm_var,
+                                      group=self._hcg.get_check_parallel_group())
+        global_norm_var = layers.sqrt(global_norm_var)
+
+        max_global_norm = layers.fill_constant(
+            shape=[1], dtype=global_norm_var.dtype, value=self.clip_norm)
+        clip_var = layers.elementwise_div(
+            x=max_global_norm,
+            y=layers.elementwise_max(
+                x=global_norm_var, y=max_global_norm))
+        for p, g in params_grads:
+            if g is None:
+                continue
+            if getattr(p, 'need_clip', True) is False:
+                params_and_grads.append((p, g))
+                continue
+            new_grad = layers.elementwise_mul(x=g, y=clip_var)
+            params_and_grads.append((p, new_grad))
+        return params_and_grads
+
+    def __getattr__(self, item):
+        return getattr(self._clip, item)
+
+    def __call__(self, params_grads):
+        return self._clip(params_grads)
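With model parallelism each rank only sees its own parameter shard, so `HybridParallelClipGrad` all-reduces the local sum of squared gradients over the check group before taking the square root, and every rank then rescales its gradients by `clip_norm / max(global_norm, clip_norm)`. A small numpy sketch of the same math (illustrative; the all-reduce is replaced by a plain sum over shards):

```python
import numpy as np


def clip_by_global_norm(shard_grads_per_rank, clip_norm=2.0):
    """Illustrative re-implementation of HybridParallelClipGrad's math."""
    # each rank computes the sum of squares of its own shard ...
    local_sums = [sum(float((g ** 2).sum()) for g in shard)
                  for shard in shard_grads_per_rank]
    # ... and the all-reduce over the check group yields the true global norm
    global_norm = np.sqrt(sum(local_sums))
    scale = clip_norm / max(global_norm, clip_norm)   # <= 1, identical on every rank
    return [[g * scale for g in shard] for shard in shard_grads_per_rank]


grads = [[np.ones((2, 2))], [np.ones((3,))]]   # two ranks' gradient shards
clipped = clip_by_global_norm(grads, clip_norm=2.0)
```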
 class HybridParallelOptimizer:
+    # adapter wrapper for optimizer
     def __init__(self, optimizer, hcg, strategy):
         self._inner_opt = optimizer
         self._strategy = strategy
@@ -29,6 +91,13 @@ class HybridParallelOptimizer:
             self._hcg.get_parallel_mode() == ParallelMode.MODEL_PARALLEL)
         self._need_dp = (self._hcg.get_data_parallel_world_size() > 1)

+        if isinstance(self._inner_opt._grad_clip,
+                      ClipGradByGlobalNorm) and self._is_mp:
+            logger.warning("using ClipGradByGlobalNorm in ModelParallel, the original " \
+                           "optimizer's grad clip will be changed.")
+            self._inner_opt._grad_clip = HybridParallelClipGrad(
+                self._inner_opt._grad_clip, hcg)
     @imperative_base.no_grad
     @framework.dygraph_only
     def step(self):
...
@@ -13,7 +13,6 @@
 # limitations under the License.

 from paddle.fluid.dygraph.layers import Layer
-import logging


 class MetaParallelBase(Layer):
...
@@ -15,6 +15,7 @@
 from paddle.fluid.dygraph.layers import Layer
 from .meta_parallel_base import MetaParallelBase
 from ..utils.hybrid_parallel_util import *
+from ..utils.log_util import logger


 class ModelParallel(MetaParallelBase):
@@ -22,8 +23,14 @@ class ModelParallel(MetaParallelBase):
         super(ModelParallel, self).__init__(layers, hcg, **kwargs)

     def _prepare_for_model(self):
+        logger.info("start broadcast mp parameters")
         broadcast_mp_parameters(self._layers, self._hcg)
+
+        logger.info("start broadcast dp parameters")
         broadcast_dp_parameters(self._layers, self._hcg)
+
+        logger.info("mp's parameters is ready")

     def _pre_forward(self, *inputs, **kwargs):
+        logger.debug("mp start broadcast input data")
         return broadcast_input_data(self._hcg, *inputs, **kwargs)
@@ -19,8 +19,9 @@ import warnings
 from paddle import framework
 import paddle
 from paddle.fluid import core
-from paddle.fluid.dygraph.parallel import _split_tensors, sync_params_buffers, construct_groups
+from paddle.fluid.dygraph.parallel import _split_tensors, sync_params_buffers, build_groups
 from collections import OrderedDict
+from .log_util import logger


 def _apply_collective_grads(parameters, comm_group):
@@ -37,7 +38,7 @@ def _apply_collective_grads(parameters, comm_group):
         assert g_var not in grad_var_set
         grad_var_set.add(g_var)

-    coalesced_grads_and_vars = construct_groups(grad_vars, 128 * 1024 * 1024)
+    coalesced_grads_and_vars = build_groups(grad_vars, 128 * 1024 * 1024)

     for coalesced_grad, _, _ in coalesced_grads_and_vars:
         # need to div nranks
@@ -60,7 +61,7 @@ def broadcast_input_data(hcg, *inputs, **kwargs):
                 group=model_parallel_group,
                 use_calc_stream=True)
         else:
-            print("it doesn't support data type {}".format(type(input_)))
+            logger.error("it doesn't support data type {}".format(type(input_)))

     for k, v in kwargs.items():
         if isinstance(v, core.VarBase):
@@ -72,7 +73,7 @@ def broadcast_input_data(hcg, *inputs, **kwargs):
                 use_calc_stream=True)
             kwargs[k] = v
         else:
-            print("it doesn't support data type {}".format(type(v)))
+            logger.error("it doesn't support data type {}".format(type(v)))

     return inputs, kwargs

@@ -92,5 +93,6 @@ def broadcast_dp_parameters(model, hcg):

 def fused_allreduce_gradients(parameter_list, hcg):
     data_parallel_group = hcg.get_data_parallel_group()
+    logger.debug("dp start fuse allreduce gradients")
     with framework.no_grad():
         _apply_collective_grads(parameter_list, data_parallel_group)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import sys

class LoggerFactory:
    @staticmethod
    def build_logger(name=None, level=logging.INFO):
        assert name is not None, "name for logger should not be None"

        formatter = logging.Formatter(
            "%(asctime)s-%(levelname)s: "
            "[%(filename)s:%(lineno)d:%(funcName)s] %(message)s")

        _logger = logging.getLogger(name)
        _logger.setLevel(level)
        _logger.propagate = False
        handler = logging.StreamHandler(stream=sys.stderr)
        handler.setFormatter(formatter)
        handler.setLevel(level)
        _logger.addHandler(handler)
        return _logger


logger = LoggerFactory.build_logger(name="HybridParallel", level=logging.INFO)
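Because the handler is attached to a named logger, user code can tune the verbosity of the new messages without touching Paddle internals; this mirrors the commented-out lines in `hybrid_parallel_mp_clip_grad.py` below:

```python
import logging

# silence the hybrid-parallel INFO/DEBUG messages, keep warnings and errors
logging.getLogger("HybridParallel").setLevel(logging.WARNING)
```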
@@ -323,7 +323,7 @@ def scale_loss(loss):

 @imperative_base.no_grad
 @framework.dygraph_only
-def construct_groups(vars, group_size):
+def build_groups(vars, group_size):
     group_idx = 0
     memory_counter = 0
     var_groups = OrderedDict()
@@ -334,7 +334,7 @@ def construct_groups(vars, group_size):
         if memory_counter < group_size and dtype == var.dtype:
             memory_counter += bytes
         else:
-            memory_counter = 0
+            memory_counter = bytes
             dtype = var.dtype
             group_idx += 1
         var_groups.setdefault(group_idx, []).append(var)
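The one-line change above fixes the bucket accounting in `build_groups`: when a variable starts a new group (size limit reached or dtype changed), its own byte size now seeds the new group's counter instead of being dropped. A standalone re-implementation of the same bucketing on plain `(name, nbytes, dtype)` tuples, for illustration only:

```python
from collections import OrderedDict


def bucket_by_size(var_sizes, group_size):
    """Mirror of build_groups' bucketing, on (name, nbytes, dtype) tuples."""
    group_idx = 0
    memory_counter = 0
    var_groups = OrderedDict()
    dtype = var_sizes[0][2] if var_sizes else None
    for name, nbytes, var_dtype in var_sizes:
        if memory_counter < group_size and dtype == var_dtype:
            memory_counter += nbytes
        else:
            memory_counter = nbytes      # the fixed line: count the var that opens the group
            dtype = var_dtype
            group_idx += 1
        var_groups.setdefault(group_idx, []).append(name)
    return var_groups


demo = [("w1", 100, "fp32"), ("w2", 100, "fp32"), ("w3", 100, "fp32")]
print(list(bucket_by_size(demo, group_size=128).values()))  # [['w1', 'w2'], ['w3']]
```

With the old `memory_counter = 0`, the first member of each new group was never counted, so a group could overshoot the 128 MB target by that variable's size.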
@@ -361,7 +361,7 @@ def sync_params_buffers(model,
         return

     # group size is 128M
-    coalesced_vars = construct_groups(model_vars, 128 * 1024 * 1024)
+    coalesced_vars = build_groups(model_vars, 128 * 1024 * 1024)

     for coalesced_var, _, _ in coalesced_vars:
         paddle.distributed.broadcast(
...
@@ -852,7 +852,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL)
     set_tests_properties(test_parallel_dygraph_dataparallel PROPERTIES TIMEOUT 120)
     set_tests_properties(test_parallel_dygraph_unused_variables PROPERTIES TIMEOUT 120)
     set_tests_properties(test_parallel_dygraph_control_flow PROPERTIES TIMEOUT 120)
-    set_tests_properties(test_parallel_dygraph_hybrid_parallel PROPERTIES TIMEOUT 120 LABELS "RUN_TYPE=DIST")
+    set_tests_properties(test_parallel_dygraph_hybrid_parallel PROPERTIES TIMEOUT 200 LABELS "RUN_TYPE=DIST")
     if(${NCCL_VERSION} VERSION_GREATER_EQUAL 2212)
         set_tests_properties(test_parallel_dygraph_sparse_embedding PROPERTIES TIMEOUT 120)
         set_tests_properties(test_parallel_dygraph_transformer PROPERTIES TIMEOUT 120)
...
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
import paddle
import numpy as np
from hybrid_parallel_mp_model import TestDistMPTraning
import paddle.distributed.fleet as fleet
import unittest

class TestMPClipGrad(TestDistMPTraning):
    def build_optimizer(self, model):
        grad_clip = paddle.nn.ClipGradByGlobalNorm(2.0)
        scheduler = paddle.optimizer.lr.ExponentialDecay(
            learning_rate=0.001, gamma=0.999, verbose=True)
        optimizer = paddle.optimizer.SGD(scheduler,
                                         grad_clip=grad_clip,
                                         parameters=model.parameters())
        return optimizer

    def train_batch(self, batch, model, optimizer, is_mp):
        scaler = paddle.amp.GradScaler(init_loss_scaling=5160)
        if is_mp:
            scaler = fleet.distributed_scaler(scaler)

        with paddle.amp.auto_cast():
            output = model(batch)
            loss = output.mean()

        scaled = scaler.scale(loss)  # scale the loss
        scaled.backward()  # do backward

        scaler.minimize(optimizer, scaled)  # update parameters
        optimizer.clear_grad()
        return scaled


if __name__ == "__main__":
    unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
import paddle
import numpy as np
from hybrid_parallel_mp_model import TestDistMPTraning
import unittest
import logging
#log = logging.getLogger("HybridParallel")
#log.setLevel(logging.WARNING)

class TestMPClipGrad(TestDistMPTraning):
    def build_optimizer(self, model):
        grad_clip = paddle.nn.ClipGradByGlobalNorm(2.0)
        scheduler = paddle.optimizer.lr.ExponentialDecay(
            learning_rate=0.001, gamma=0.999, verbose=True)
        optimizer = paddle.optimizer.SGD(scheduler,
                                         grad_clip=grad_clip,
                                         parameters=model.parameters())
        return optimizer


if __name__ == "__main__":
    unittest.main()
@@ -173,9 +173,9 @@ class TestDistTraning(unittest.TestCase):
         self.word_size = self.hcg.get_model_parallel_world_size()
         self.rank_id = self.hcg.get_model_parallel_rank()

-        input_size_per_card = 17
+        input_size_per_card = 11
         input_size = input_size_per_card * self.model_parallel_size
-        output_size_per_card = 13
+        output_size_per_card = 10
         output_size = output_size_per_card * self.model_parallel_size
         batch_size = 4
...
@@ -21,7 +21,6 @@ import random
 import paddle.distributed as dist
 import paddle.fluid as fluid
 import paddle.distributed.fleet as fleet
-import paddle.fluid.generator as generator
 from paddle.io import DataLoader, Dataset
 import unittest
@@ -143,7 +142,7 @@ class TrainDataset(Dataset):
         return np_input_data


-class TestDistTraning(unittest.TestCase):
+class TestDistMPTraning(unittest.TestCase):
     def setUp(self):
         strategy = fleet.DistributedStrategy()
         self.model_parallel_size = 2
@@ -155,7 +154,20 @@ class TestDistTraning(unittest.TestCase):
         }
         fleet.init(is_collective=True, strategy=strategy)

-    def test_mp_model(self):
+    def train_batch(self, batch, model, optimizer, is_mp):
+        output = model(batch)
+        loss = output.mean()
+        loss.backward()  # do backward
+        optimizer.step()  # update parameters
+        optimizer.clear_grad()
+        return loss
+
+    def build_optimizer(self, model):
+        optimizer = paddle.optimizer.SGD(learning_rate=0.001,
+                                         parameters=model.parameters())
+        return optimizer
+
+    def build_model_optimizer(self):
         hcg = fleet.get_hybrid_communicate_group()
         word_size = hcg.get_model_parallel_world_size()
         mp_id = hcg.get_model_parallel_rank()
@@ -182,31 +194,29 @@ class TestDistTraning(unittest.TestCase):
         model_a = SimpleMPNet(vocab_size, hidden_size, inner_size, output_size,
                               np_fc1, np_fc2, mp_id)
-        optimizer_a = paddle.optimizer.SGD(learning_rate=0.001,
-                                           parameters=model_a.parameters())
+        optimizer_a = self.build_optimizer(model_a)
         model_a = fleet.distributed_model(model_a)
         optimizer_a = fleet.distributed_optimizer(optimizer_a)

         model_b = SimpleDPNet(vocab_size, hidden_size, inner_size, output_size,
                               np_fc1, np_fc2)
-        optimizer_b = paddle.optimizer.SGD(learning_rate=0.001,
-                                           parameters=model_b.parameters())
+        optimizer_b = self.build_optimizer(model_b)
+
+        return model_a, optimizer_a, model_b, optimizer_b, train_data_loader
+
+    def test_mp_model(self):
+        model_a, optimizer_a, model_b, optimizer_b, train_data_loader = self.build_model_optimizer(
+        )

         for step, batch in enumerate(train_data_loader):
             if step > 5:
                 return
-            output_a = model_a(batch)
-            loss_a = output_a.mean()
-            loss_a.backward()
-            optimizer_a.step()
-            optimizer_a.clear_grad()
-
-            output_b = model_b(batch)
-            loss_b = output_b.mean()
-            loss_b.backward()
-            optimizer_b.step()
-            optimizer_b.clear_grad()
-            np.testing.assert_allclose(loss_a.numpy(), loss_b.numpy())
+
+            loss_a = self.train_batch(batch, model_a, optimizer_a, True)
+            loss_b = self.train_batch(batch, model_b, optimizer_b, False)
+
+            np.testing.assert_allclose(
+                loss_a.numpy(), loss_b.numpy(), rtol=1e-5)

 if __name__ == "__main__":
...
@@ -30,6 +30,12 @@ class TestHybridParallel(TestMultipleGpus):
     def test_hybrid_parallel_mp_model(self):
         self.run_mnist_2gpu('hybrid_parallel_mp_model.py')

+    def test_hybrid_parallel_mp_amp(self):
+        self.run_mnist_2gpu('hybrid_parallel_mp_amp.py')
+
+    def test_hybrid_parallel_mp_clip_grad(self):
+        self.run_mnist_2gpu('hybrid_parallel_mp_clip_grad.py')
+
 if __name__ == "__main__":
     unittest.main()