From 11e62d681a882a7fc4abb7ac515dc12a24ad6b81 Mon Sep 17 00:00:00 2001 From: zhaoyingli <86812880+zhaoyinglia@users.noreply.github.com> Date: Wed, 31 Aug 2022 16:36:03 +0800 Subject: [PATCH] [AutoParallel] add grad_clip pass (#45513) * add grad_clip pass * add unittest * add notes * update func * add dist_attr for new op --- .../distributed/auto_parallel/completion.py | 30 +- .../auto_parallel/parallelizer_v2.py | 14 +- .../distributed/auto_parallel/reshard.py | 4 +- .../paddle/distributed/auto_parallel/utils.py | 2 +- python/paddle/distributed/passes/__init__.py | 1 + .../passes/auto_parallel_grad_clip.py | 344 ++++++++++++++++++ .../passes/auto_parallel_sharding.py | 5 +- .../unittests/auto_parallel/CMakeLists.txt | 3 + .../auto_parallel/clip_grad_by_global_norm.py | 140 +++++++ .../unittests/auto_parallel/get_gpt_model.py | 102 ++++++ .../unittests/auto_parallel/test_grad_clip.py | 50 +++ .../auto_parallel/test_lr_grad_clip.py | 11 +- 12 files changed, 672 insertions(+), 34 deletions(-) create mode 100644 python/paddle/distributed/passes/auto_parallel_grad_clip.py create mode 100644 python/paddle/fluid/tests/unittests/auto_parallel/clip_grad_by_global_norm.py create mode 100644 python/paddle/fluid/tests/unittests/auto_parallel/get_gpt_model.py create mode 100644 python/paddle/fluid/tests/unittests/auto_parallel/test_grad_clip.py diff --git a/python/paddle/distributed/auto_parallel/completion.py b/python/paddle/distributed/auto_parallel/completion.py index 341f4baf57..1775a823c5 100644 --- a/python/paddle/distributed/auto_parallel/completion.py +++ b/python/paddle/distributed/auto_parallel/completion.py @@ -19,7 +19,7 @@ import time from paddle.fluid import core from paddle.fluid import framework -from .utils import print_program_with_dist_attr, _is_gradient_clip_op +from .utils import print_program_with_dist_attr, is_gradient_clip_op from .operators import find_compatible_distributed_operator_impls from .dist_context import get_default_distributed_context, _node_id from .dist_tensor import DistributedTensor @@ -1325,11 +1325,7 @@ class Completer: # TODO to add attribute for moment var op = ops[idx] if int(op.attr('op_role')) == int(OpRole.Optimize): - # TODO: - # 1. move `generate_optimizer` before `partitioner` - # 2. implement grad_clip completion by `dist_op` - # 3. allreduce dist_gloabl_norm (mp-group) and no_dist_global_norm (pp-group, sharding-group) - if _is_gradient_clip_op(op): + if is_gradient_clip_op(op): if op.type in [ "sum", "sqrt", "fill_constant", "elementwise_max", "elementwise_div" @@ -1353,7 +1349,6 @@ class Completer: out_var, out_dist_attr) op_dist_attr.set_output_dist_attr( out_name, out_dist_attr) - remove_no_need_in_op(op, self._dist_context) else: in_var = vars[op.input("X")[0]] in_dist_attr = self._dist_context.get_tensor_dist_attr_for_program( @@ -1362,8 +1357,8 @@ class Completer: ref_process_mesh = in_dist_attr.process_mesh ref_dims_mapping = in_dist_attr.dims_mapping - if op.type == "cast" and ops[ - idx + 1].type == "elementwise_mul": + if op.type == "cast" and \ + ops[idx + 1].type == "elementwise_mul": ref_var = vars[ops[idx + 1].input("X")[0]] ref_dist_attr = self._dist_context.get_tensor_dist_attr_for_program( ref_var) @@ -1536,20 +1531,3 @@ class Completer: break else: dist_op.dist_attr = backup_op_dist_attr - - -def remove_no_need_in_op(op, dist_context): - if op.type == "fill_constant": - return - - filter_vars = [] - main_block = op.block - rank_id = dist_context.dist_op_context.rank_id - for varname in op.input("X"): - if rank_id in dist_context.get_tensor_dist_attr_for_program( - main_block.var(varname)).process_mesh.processes: - filter_vars.append(varname) - - if not filter_vars: - return - op.desc.set_input('X', filter_vars) diff --git a/python/paddle/distributed/auto_parallel/parallelizer_v2.py b/python/paddle/distributed/auto_parallel/parallelizer_v2.py index 93c684eecc..51eede5763 100644 --- a/python/paddle/distributed/auto_parallel/parallelizer_v2.py +++ b/python/paddle/distributed/auto_parallel/parallelizer_v2.py @@ -235,7 +235,19 @@ class Parallelizer: auto_parallel_sharding_pass.apply([main_program], [startup_program], self._pass_context) - # recompute is then train-only optimization + # GradClip is train-only optimization + + if self._mode == "train": + config = copy.deepcopy(self._strategy.sharding_configs) + config["dist_context"] = self._dist_context + config["params_grads"] = params_grads + config["rank_id"] = rank + auto_parallel_clip_pass = new_pass("auto_parallel_grad_clip", + config) + auto_parallel_clip_pass.apply([main_program], [startup_program], + self._pass_context) + + # gradient_merge is then train-only optimization if self._mode == "train" and self._strategy.gradient_merge: config = copy.deepcopy(self._strategy.gradient_merge_configs) config["dist_context"] = self._dist_context diff --git a/python/paddle/distributed/auto_parallel/reshard.py b/python/paddle/distributed/auto_parallel/reshard.py index 4f1f02f815..6da39b063e 100644 --- a/python/paddle/distributed/auto_parallel/reshard.py +++ b/python/paddle/distributed/auto_parallel/reshard.py @@ -30,7 +30,7 @@ from .cost import build_comm_desc, CommContext from .cost import AllgatherOpCost, SendOpCost from .cost import SliceOpCost, SplitOpCost, ConcatOpCost from .cluster import Cluster -from .utils import print_program_with_dist_attr, _is_gradient_clip_op +from .utils import print_program_with_dist_attr, is_gradient_clip_op # NOTE: If op in _g_special_ops or _g_gradient_clip_ops, it will not be resharded. _g_special_ops = ['check_finite_and_unscale', 'update_loss_scaling'] @@ -1088,7 +1088,7 @@ class Resharder: global _g_special_ops, _g_gradient_clip_ops if op.type in _g_special_ops: return True - if _is_gradient_clip_op(op) and op.type in _g_gradient_clip_ops: + if is_gradient_clip_op(op) and op.type in _g_gradient_clip_ops: return True return False diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py index c5897d83f1..d276df6ddb 100644 --- a/python/paddle/distributed/auto_parallel/utils.py +++ b/python/paddle/distributed/auto_parallel/utils.py @@ -1131,7 +1131,7 @@ def is_loss_grad_op(op): return op_role & int(OpRole.Backward) and op_role & int(OpRole.Loss) -def _is_gradient_clip_op(op): +def is_gradient_clip_op(op): return op.desc.has_attr("op_namescope") \ and op.desc.attr("op_namescope").startswith("/gradient_clip") diff --git a/python/paddle/distributed/passes/__init__.py b/python/paddle/distributed/passes/__init__.py index 03dd31fb9b..5f721a1df5 100644 --- a/python/paddle/distributed/passes/__init__.py +++ b/python/paddle/distributed/passes/__init__.py @@ -21,6 +21,7 @@ from .auto_parallel_fp16 import * from .auto_parallel_recompute import * from .auto_parallel_quantization import * from .auto_parallel_data_parallel_optimization import * +from .auto_parallel_grad_clip import * from .cpp_pass import * import os from .ps_trainer_pass import * diff --git a/python/paddle/distributed/passes/auto_parallel_grad_clip.py b/python/paddle/distributed/passes/auto_parallel_grad_clip.py new file mode 100644 index 0000000000..6fba98ce75 --- /dev/null +++ b/python/paddle/distributed/passes/auto_parallel_grad_clip.py @@ -0,0 +1,344 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +from functools import reduce + +import paddle + +from paddle.fluid import core +from .pass_base import PassBase, register_pass +from ..auto_parallel.reshard import Resharder +from ..auto_parallel.process_group import get_world_process_group +from ..auto_parallel.utils import is_gradient_clip_op, is_optimize_op, OP_ROLE_KEY, OpRole, _get_comm_group +from ..auto_parallel.dist_attribute import TensorDistributedAttribute, OperatorDistributedAttribute + + +def _get_params_grads(block): + params_grads = [] + for op in reversed(block.ops): + if not is_optimize_op(op): + break + if "Param" in op.input_names and "Grad" in op.input_names: + param_name = op.input("Param")[0] + grad_name = op.input("Grad")[0] + param = block.var(param_name) + grad = block.var(grad_name) + params_grads.append((param, grad)) + return params_grads + + +def _get_dpmp_topology(origin_topology, sharding_group): + """ + Get dpmp topology from origin_topology + + Example: + the parallel strategy: dp4-mp2-sharding2 + the complete process_mesh: + topology: [4, 2] + processes: [0, 1, 2, 3, 4, 5, 6, 7] + the dpmp topology: [2, 2] + the sharding axis: 1 + """ + sharding_axis = 1 + dp_sharding_topology = [ + origin_topology[0] // sharding_group.nranks, sharding_group.nranks + ] + if dp_sharding_topology[0] == 1: + sharding_axis = 0 + dp_sharding_topology = dp_sharding_topology[1:] + + product_dp_sharding = reduce(lambda x, y: x * y, dp_sharding_topology) + product_topology = reduce(lambda x, y: x * y, origin_topology) + + if product_topology == product_dp_sharding: + dpmp_topology = dp_sharding_topology + else: + assert product_topology % product_dp_sharding == 0 + mp_degree = product_topology // product_dp_sharding + dpmp_topology = dp_sharding_topology + [mp_degree] + + return dpmp_topology, sharding_axis + + +def _get_dpmp_process_mesh(rank_id, topology, processes, sharding_group): + """ + Get dpmp process_mesh from the complete process_mesh which apply sharding. + + Example: + the parallel strategy: dp4-mp2-sharding2 + the complete process_mesh: + topology: [4, 2] + processes: [0, 1, 2, 3, 4, 5, 6, 7] + the dpmp process_mesh is: + 1) topology: [2, 2], processes: [0, 1, 4, 5] + 2) topology: [2, 2], processes: [2, 3, 6, 7] + """ + if sharding_group is None: + return topology, processes + + # get dpmp_topology + dpmp_topology, sharding_axis = _get_dpmp_topology(topology, sharding_group) + + # get all sharding_groups of ranks + sharding_groups = [] + for rank in processes: + group = _get_comm_group(processes, dpmp_topology, sharding_axis, rank) + if group not in sharding_groups: + sharding_groups.append(group) + + # get dpmp_processes + sharding_groups = np.array(sharding_groups) + dpmp_processes_in_sharding = None + for i in range(sharding_groups.shape[-1]): + if rank_id in sharding_groups[:, i]: + dpmp_processes_in_sharding = sharding_groups[:, i] + + assert dpmp_processes_in_sharding is not None + return dpmp_topology, list(dpmp_processes_in_sharding) + + +def _is_about_global_norm(rank_id, tensor_shape, topology, processes, + dims_mapping, sharding_group): + # get current process_mesh where the parameter exist. + dpmp_topology, dpmp_processes = _get_dpmp_process_mesh( + rank_id, topology, processes, sharding_group) + + complete_shape = Resharder.compute_complete_shape(tensor_shape, + dpmp_topology, + dims_mapping) + + complete_partitions = [] + complete_param_ranks = [] + for process in dpmp_processes: + partition_index = Resharder.compute_partition_index( + process, complete_shape, dims_mapping, dpmp_topology, + dpmp_processes) + if partition_index not in complete_partitions: + complete_partitions.append(partition_index) + complete_param_ranks.append(process) + + return rank_id in complete_param_ranks + + +class ClipHelper(object): + + def __init__(self, params_grads, rank_id, block, dist_context): + params, _ = zip(*params_grads) + self.params = list(params) + self.params_name = [p.name for p in self.params] + self.rank_id = rank_id + self.block = block + self.dist_context = dist_context + self.sharding_group = None + self.world_ranks = get_world_process_group().ranks + if hasattr(dist_context, '_sharding_group'): + self.sharding_group = dist_context._sharding_group + + def _is_calcuate_norm(self, name): + if not self._is_local_param(name): + return False, [] + + param = self.params[self.params_name.index(name)] + dist_attr = self._get_dist_attr(name) + topology = dist_attr.process_mesh.topology + processes = dist_attr.process_mesh.processes + dims_mapping = dist_attr.dims_mapping + return _is_about_global_norm(self.rank_id, param.shape, topology, + processes, dims_mapping, + self.sharding_group) + + def _get_dist_attr(self, name): + var = self.block.vars[name] + return self.dist_context.get_tensor_dist_attr_for_program(var) + + def _is_local_param(self, name): + if name not in self.params_name: + return False + return True + + def _is_local_var(self, name): + dist_attr = self._get_dist_attr(name) + assert dist_attr is not None + return self.rank_id in dist_attr.process_mesh.processes + + def _init_dist_attr(self, op): + op_dist_attr = OperatorDistributedAttribute() + op_dist_attr.process_mesh = self.world_ranks + for in_name in op.input_arg_names: + in_var = self.block.vars[in_name] + in_dist_attr = TensorDistributedAttribute() + in_dist_attr.process_mesh = self.world_ranks + in_dist_attr.dims_mapping = [-1] + self.dist_context.set_tensor_dist_attr_for_program( + in_var, in_dist_attr) + op_dist_attr.set_input_dist_attr(in_name, in_dist_attr) + for out_name in op.output_arg_names: + out_var = self.block.vars[out_name] + out_dist_attr = TensorDistributedAttribute() + out_dist_attr.process_mesh = self.world_ranks + out_dist_attr.dims_mapping = [-1] + self.dist_context.set_tensor_dist_attr_for_program( + out_var, out_dist_attr) + op_dist_attr.set_output_dist_attr(out_name, out_dist_attr) + self.dist_context.set_op_dist_attr_for_program(op, op_dist_attr) + + +@register_pass("auto_parallel_grad_clip") +class ClipGradByGloblNormPass(PassBase): + """ + 1. Remove norm-compute op and grad-scale op when the grad is not in current rank + or is independent of the calculation of norm. + 2. Each rank computes its own norm value, then gets global_norm by allreduce_sum only once. + """ + + def __init__(self): + super(ClipGradByGloblNormPass, self).__init__() + self.set_attr("rank_id", None) + self.set_attr("dist_context", None) + + def _check_self(self): + if self.get_attr("dist_context") is None: + return False + dist_context = self.get_attr("dist_context") + if dist_context._lr_optimizer._grad_clip is None: + return False + return True + + def _check_conflict(self, other_pass): + return True + + def _apply_single_impl(self, main_program, startup_program, context): + dist_context = self.get_attr("dist_context", None) + rank_id = self.get_attr("rank_id", None) + block = main_program.global_block() + dist_params_grads = _get_params_grads(block) + + self.clip_helper = ClipHelper(dist_params_grads, rank_id, block, + dist_context) + self._remove_no_need_ops_vars(block) + + def _remove_no_need_ops_vars(self, block): + + removed_op_out_type = [ + 'clip_by_norm', 'squared_l2_norm', 'square', 'reduce_sum' + ] + + removed_op_idx = set() + removed_tmp_var = set() + for idx, op in enumerate(block.ops): + if not is_gradient_clip_op(op): + continue + + if op.type in removed_op_out_type: + input_name = op.input("X")[0] + if input_name.find("@GRAD") != -1: + #'clip_by_norm', 'squared_l2_norm', 'square' + param_name = input_name[:input_name.find("@GRAD")] + is_local = self.clip_helper._is_local_param(param_name) + is_calculate = self.clip_helper._is_calcuate_norm( + param_name) + if not is_local or (not is_calculate + and op.type != 'clip_by_norm'): + removed_op_idx.add(idx) + removed_tmp_var.update(set(op.output_arg_names)) + else: + # 'reduce_sum' + if idx - 1 in removed_op_idx: + removed_op_idx.add(idx) + removed_tmp_var.update(set(op.output_arg_names)) + + elif op.type == 'elementwise_mul': + input_name = op.input("X")[0] + if input_name.find("@GRAD") != -1: + param_name = input_name[:input_name.find("@GRAD")] + is_local = self.clip_helper._is_local_param(param_name) + if not is_local: + removed_op_idx.add(idx) + if block.ops[idx - 1].type == 'cast': + removed_op_idx.add(idx - 1) + removed_tmp_var.update( + set(block.ops[idx - 1].output_arg_names)) + + elif op.type == 'sum': + reserved_vars = [] + for input_name in op.input_arg_names: + if input_name not in removed_tmp_var and \ + self.clip_helper._is_local_var(input_name): + reserved_vars.append(input_name) + if not reserved_vars: + removed_op_idx.add(idx) + removed_tmp_var.update(set(op.output_arg_names)) + if block.ops[idx + 1].type == 'cast': + removed_op_idx.add(idx + 1) + removed_tmp_var.update( + set(block.ops[idx + 1].output_arg_names)) + else: + op.desc.set_input("X", reserved_vars) + + for idx, op in reversed(list(enumerate(block.ops))): + if not is_optimize_op(op): + break + if not is_gradient_clip_op(op): + continue + if idx in removed_op_idx: + block._remove_op(idx, sync=False) + + for idx, op in reversed(list(enumerate(block.ops))): + if not is_optimize_op(op): + break + if not is_gradient_clip_op(op): + continue + if op.type == 'sqrt': + input_name = op.input("X")[0] + input_var = block.vars[input_name] + if paddle.distributed.get_world_size() > 1: + offset = 0 + if input_name in removed_tmp_var: + removed_tmp_var.remove(input_name) + fill_constant_op = block._insert_op( + idx, + type='fill_constant', + inputs={}, + outputs={'Out': [input_var]}, + attrs={ + 'shape': [1], + 'dtype': input_var.dtype, + 'value': 0, + 'force_cpu': False, + OP_ROLE_KEY: OpRole.Optimize + }) + fill_constant_op._set_attr('op_namescope', + "/gradient_clip_pass") + offset += 1 + self.clip_helper._init_dist_attr(fill_constant_op) + + allreduce_op = block._insert_op( + idx + offset, + type='c_allreduce_sum', + inputs={'X': [input_var]}, + outputs={'Out': [input_var]}, + attrs={ + 'ring_id': 0, + 'use_calc_stream': True, + OP_ROLE_KEY: OpRole.Optimize, + }) + allreduce_op._set_attr('op_namescope', + "/gradient_clip_pass") + self.clip_helper._init_dist_attr(allreduce_op) + + for varname in removed_tmp_var: + block._remove_var(varname, sync=False) + + block._sync_with_cpp() diff --git a/python/paddle/distributed/passes/auto_parallel_sharding.py b/python/paddle/distributed/passes/auto_parallel_sharding.py index 6e07e16e97..e414a235b5 100644 --- a/python/paddle/distributed/passes/auto_parallel_sharding.py +++ b/python/paddle/distributed/passes/auto_parallel_sharding.py @@ -29,7 +29,7 @@ OpRole = core.op_proto_and_checker_maker.OpRole OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName() _skip_ops = [ 'create_py_reader', 'create_double_buffer_reader', 'read', 'slice', 'split', - 'assign' + 'assign', "send_v2" ] # update here to support new optimizers _supported_optimizer_type = [ @@ -140,6 +140,7 @@ class ShardingPass(PassBase): else: sharding_group = dp_group + self._dist_context._sharding_group = sharding_group # TODO(JZ-LIANG) when support multiple dp groups in future, should group param and bind them to corresponding dp group params_in_group = [p for p, g in params_grads] assert len(params_in_group) == len( @@ -160,7 +161,7 @@ class ShardingPass(PassBase): """ self._shard_amp_related_op_and_vars(main_block, pass_context) self._shard_weight_decay(main_block) - self._shard_gradient_clip(main_block) + # self._shard_gradient_clip(main_block) self._shard_optimizer_ops_and_states(main_block, startup_block) self._insert_optimizer_broadcasts(main_block, startup_block) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt index 8566186d76..beb1c722dd 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt @@ -37,6 +37,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU) ${dist_ENVS}) set_tests_properties(test_high_order_grad PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50) + py_test_modules(test_grad_clip MODULES test_grad_clip ENVS ${dist_ENVS}) + set_tests_properties(test_grad_clip PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" + TIMEOUT 50) py_test_modules(test_while_op_completion MODULES test_while_op_completion ENVS ${dist_ENVS}) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/clip_grad_by_global_norm.py b/python/paddle/fluid/tests/unittests/auto_parallel/clip_grad_by_global_norm.py new file mode 100644 index 0000000000..60a915c53c --- /dev/null +++ b/python/paddle/fluid/tests/unittests/auto_parallel/clip_grad_by_global_norm.py @@ -0,0 +1,140 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import sys +import random +import numpy as np +import paddle + +import paddle.distributed.fleet as fleet +import paddle.distributed.auto_parallel as auto + +from paddle.distributed.auto_parallel.engine import Engine +from get_gpt_model import generate_model, create_data_holder, FakeDataset + +paddle.enable_static() + + +def apply_pass(use_sharding=False): + strategy = fleet.DistributedStrategy() + strategy.semi_auto = True + if use_sharding: + strategy.sharding = True + strategy.sharding_configs = { + "sharding_degree": 2, + "stage": 2, + } + return strategy + + +def get_parameter_value(program): + from paddle.fluid.framework import Parameter + + def is_parameter(var): + return isinstance(var, Parameter) + + def get_tensor(var): + t = paddle.fluid.global_scope().find_var(var.name).get_tensor() + return np.array(t) + + def get_name(var): + return len(var.name) + + parameters_list = list(filter(is_parameter, program.list_vars())) + parameters_value = [] + for p in sorted(parameters_list, key=get_name): + parameters_value.append(get_tensor(p)) + return parameters_value + + +def reset_prog(): + paddle.fluid.framework.switch_main_program(paddle.static.Program()) + paddle.fluid.framework.switch_startup_program(paddle.static.Program()) + + +class TestGradientClipByGlobalNorm(unittest.TestCase): + + def setUp(self): + self.batch_size = 2 + self.batch_num = 1 + self.clip_norm = 0.2 + self.dataset = FakeDataset(self.batch_size * self.batch_num) + + def init(self, engine): + paddle.seed(2022) + np.random.seed(2022) + random.seed(2022) + engine.mode = "train" + engine._executor.run(engine.startup_program) + + def get_dp2_engine(self): + reset_prog() + + strategy = apply_pass() + clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm) + opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip) + model, loss = generate_model("dp") + inputs_spec, labels_spec = create_data_holder(self.batch_size) + + engine = Engine(model, inputs_spec, labels_spec, strategy=strategy) + engine.prepare(optimizer=opt, loss=loss) + self.init(engine) + return engine + + def get_dp2sharding2_engine(self): + reset_prog() + + strategy = apply_pass(True) + clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm) + opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip) + model, loss = generate_model("dp") + inputs_spec, labels_spec = create_data_holder(self.batch_size) + + engine = Engine(model, inputs_spec, labels_spec, strategy=strategy) + engine.prepare(optimizer=opt, loss=loss) + self.init(engine) + return engine + + def check_result(self, dp_params, sharding_params): + assert len(dp_params) == len(sharding_params) + for dp_p, sharding_p in zip(dp_params, sharding_params): + np.testing.assert_allclose( + dp_p, + sharding_p, + rtol=1e-05, + atol=1e-08, + err_msg= + 'gradient clip by global norm has wrong results!, \nu={}\nv={}\ndiff={}' + .format(dp_p, sharding_p, dp_p - sharding_p)) + + def test_grad_clip(self): + # dp2 training + dp_engine = self.get_dp2_engine() + dp_engine.fit(self.dataset, batch_size=self.batch_size, use_cache=True) + dp_param_values = get_parameter_value(dp_engine.main_program) + + # dp2sharding2 training + sharding_engine = self.get_dp2sharding2_engine() + sharding_engine.fit(self.dataset, + batch_size=self.batch_size, + use_cache=True) + sharding_param_values = get_parameter_value( + sharding_engine.main_program) + + self.check_result(dp_param_values, sharding_param_values) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/get_gpt_model.py b/python/paddle/fluid/tests/unittests/auto_parallel/get_gpt_model.py new file mode 100644 index 0000000000..0e5c6b387f --- /dev/null +++ b/python/paddle/fluid/tests/unittests/auto_parallel/get_gpt_model.py @@ -0,0 +1,102 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import numpy as np + +import paddle + +sys.path.append("..") +import auto_parallel_gpt_model as modeling +from auto_parallel_gpt_model import GPTModel, GPTForPretraining, GPTPretrainingCriterion + +sequence_len = 512 +vocab_size = 1000 + + +class FakeDataset: + + def __init__(self, num_samples): + self.num_samples = num_samples + self.sequence_len = sequence_len + self.vocab_size = vocab_size + + def __getitem__(self, idx): + tokens = np.random.randint(self.vocab_size, size=self.sequence_len) + position_ids = np.arange(self.sequence_len) + attention_mask = np.tril(np.ones(self.sequence_len)).reshape( + (1, self.sequence_len, self.sequence_len)).astype(np.float32) + labels = np.random.randint(self.vocab_size, size=self.sequence_len) + loss_mask = np.ones(self.sequence_len).astype(np.float32) + return tokens, position_ids, attention_mask, labels, loss_mask + + def __len__(self): + return self.num_samples + + +def create_data_holder(batch_size): + tokens = paddle.static.InputSpec(name="tokens", + shape=[batch_size, sequence_len], + dtype='int64') + position_ids = paddle.static.InputSpec(name="position_ids", + shape=[batch_size, sequence_len], + dtype='int64') + attention_mask = paddle.static.InputSpec( + name="attention_mask", + shape=[batch_size, 1, sequence_len, sequence_len], + dtype='float32') + labels = paddle.static.InputSpec(name="labels", + shape=[batch_size, sequence_len], + dtype='int64') + loss_mask = paddle.static.InputSpec(name="loss_mask", + shape=[batch_size, sequence_len], + dtype='float32') + return [tokens, position_ids, attention_mask], [labels, loss_mask] + + +def generate_model(strategy): + modeling.init_global() + if strategy == "serial": + modeling._global_parallel_strategy = "serial" + modeling._global_process_mesh = [0] + elif strategy == "mp": + modeling._global_parallel_strategy = "mp" + modeling._global_process_mesh = [0, 1] + elif strategy == "dp": + modeling._global_parallel_strategy = "dp" + modeling._global_process_mesh = [0, 1] + else: + raise ValueError("Only support serial, mp2 and dp2.") + + gpt = GPTModel(vocab_size=1000, + hidden_size=64, + num_hidden_layers=2, + num_attention_heads=8, + intermediate_size=256, + hidden_act="gelu", + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + max_position_embeddings=1024, + type_vocab_size=1, + initializer_range=0.02, + pad_token_id=0, + eos_token_id=7, + bos_token_id=0, + eol_token_id=3) + model = GPTForPretraining(gpt, + vocab_size=1000, + hidden_size=64, + initializer_range=0.02) + criterion = GPTPretrainingCriterion() + return model, criterion diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_grad_clip.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_grad_clip.py new file mode 100644 index 0000000000..3527589db6 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_grad_clip.py @@ -0,0 +1,50 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tempfile +import unittest +import os +import sys +import shutil +import subprocess +from paddle.distributed.fleet.launch_utils import run_with_coverage + + +class TestGradientClip(unittest.TestCase): + + def test_dp2(self): + file_dir = os.path.dirname(os.path.abspath(__file__)) + launch_model_path = os.path.join(file_dir, + "clip_grad_by_global_norm.py") + + if os.environ.get("WITH_COVERAGE", "OFF") == "ON": + coverage_args = ["-m", "coverage", "run", "--branch", "-p"] + else: + coverage_args = [] + + tmp_dir = tempfile.TemporaryDirectory() + cmd = [sys.executable, "-u"] + coverage_args + [ + "-m", "paddle.distributed.launch", "--devices", "0,1", "--log_dir", + tmp_dir.name, launch_model_path + ] + + process = subprocess.Popen(cmd) + process.wait() + self.assertEqual(process.returncode, 0) + + tmp_dir.cleanup() + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_lr_grad_clip.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_lr_grad_clip.py index ab11886dd1..e7d73921eb 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_lr_grad_clip.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_lr_grad_clip.py @@ -86,7 +86,7 @@ class TestLRScheduler(TestEngineBase): self.engine.fit(self.dataset, batch_size=self.batch_size) -class TestGradClip(TestEngineBase): +class TestGradClipByGlobalNorm(TestEngineBase): def init_optimizer(self): clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0) @@ -96,7 +96,6 @@ class TestGradClip(TestEngineBase): def test_grad_clip(self): clip = self.engine._optimizer._grad_clip - assert isinstance(clip, paddle.nn.ClipGradByGlobalNorm) self.engine.fit(self.dataset, batch_size=self.batch_size) self.check_program() @@ -112,5 +111,13 @@ class TestGradClip(TestEngineBase): assert has_grad_clip is True +class TestGradClipByNorm(TestGradClipByGlobalNorm): + + def init_optimizer(self): + clip = paddle.nn.ClipGradByNorm(clip_norm=1.0) + self.optimizer = paddle.optimizer.SGD(learning_rate=0.00001, + grad_clip=clip) + + if __name__ == "__main__": unittest.main() -- GitLab