Unverified commit 11e62d68, authored by zhaoyingli, committed by GitHub

[AutoParallel] add grad_clip pass (#45513)

* add grad_clip pass

* add unittest

* add notes

* update func

* add dist_attr for new op
Parent 9cbae54c
@@ -19,7 +19,7 @@ import time
from paddle.fluid import core
from paddle.fluid import framework
-from .utils import print_program_with_dist_attr, _is_gradient_clip_op
+from .utils import print_program_with_dist_attr, is_gradient_clip_op
from .operators import find_compatible_distributed_operator_impls
from .dist_context import get_default_distributed_context, _node_id
from .dist_tensor import DistributedTensor
@@ -1325,11 +1325,7 @@ class Completer:
# TODO to add attribute for moment var
op = ops[idx]
if int(op.attr('op_role')) == int(OpRole.Optimize):
-# TODO:
-# 1. move `generate_optimizer` before `partitioner`
-# 2. implement grad_clip completion by `dist_op`
-# 3. allreduce dist_gloabl_norm (mp-group) and no_dist_global_norm (pp-group, sharding-group)
-if _is_gradient_clip_op(op):
+if is_gradient_clip_op(op):
if op.type in [
"sum", "sqrt", "fill_constant", "elementwise_max",
"elementwise_div"
@@ -1353,7 +1349,6 @@ class Completer:
out_var, out_dist_attr)
op_dist_attr.set_output_dist_attr(
out_name, out_dist_attr)
-remove_no_need_in_op(op, self._dist_context)
else:
in_var = vars[op.input("X")[0]]
in_dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
@@ -1362,8 +1357,8 @@ class Completer:
ref_process_mesh = in_dist_attr.process_mesh
ref_dims_mapping = in_dist_attr.dims_mapping
-if op.type == "cast" and ops[
-idx + 1].type == "elementwise_mul":
+if op.type == "cast" and \
+ops[idx + 1].type == "elementwise_mul":
ref_var = vars[ops[idx + 1].input("X")[0]]
ref_dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
ref_var)
@@ -1536,20 +1531,3 @@ class Completer:
break
else:
dist_op.dist_attr = backup_op_dist_attr
-def remove_no_need_in_op(op, dist_context):
-if op.type == "fill_constant":
-return
-filter_vars = []
-main_block = op.block
-rank_id = dist_context.dist_op_context.rank_id
-for varname in op.input("X"):
-if rank_id in dist_context.get_tensor_dist_attr_for_program(
-main_block.var(varname)).process_mesh.processes:
-filter_vars.append(varname)
-if not filter_vars:
-return
-op.desc.set_input('X', filter_vars)
@@ -235,7 +235,19 @@ class Parallelizer:
auto_parallel_sharding_pass.apply([main_program], [startup_program],
self._pass_context)
-# recompute is then train-only optimization
+# GradClip is train-only optimization
+if self._mode == "train":
+config = copy.deepcopy(self._strategy.sharding_configs)
+config["dist_context"] = self._dist_context
+config["params_grads"] = params_grads
+config["rank_id"] = rank
+auto_parallel_clip_pass = new_pass("auto_parallel_grad_clip",
+config)
+auto_parallel_clip_pass.apply([main_program], [startup_program],
+self._pass_context)
+# gradient_merge is then train-only optimization
if self._mode == "train" and self._strategy.gradient_merge:
config = copy.deepcopy(self._strategy.gradient_merge_configs)
config["dist_context"] = self._dist_context
...
@@ -30,7 +30,7 @@ from .cost import build_comm_desc, CommContext
from .cost import AllgatherOpCost, SendOpCost
from .cost import SliceOpCost, SplitOpCost, ConcatOpCost
from .cluster import Cluster
-from .utils import print_program_with_dist_attr, _is_gradient_clip_op
+from .utils import print_program_with_dist_attr, is_gradient_clip_op
# NOTE: If op in _g_special_ops or _g_gradient_clip_ops, it will not be resharded.
_g_special_ops = ['check_finite_and_unscale', 'update_loss_scaling']
@@ -1088,7 +1088,7 @@ class Resharder:
global _g_special_ops, _g_gradient_clip_ops
if op.type in _g_special_ops:
return True
-if _is_gradient_clip_op(op) and op.type in _g_gradient_clip_ops:
+if is_gradient_clip_op(op) and op.type in _g_gradient_clip_ops:
return True
return False
...
@@ -1131,7 +1131,7 @@ def is_loss_grad_op(op):
return op_role & int(OpRole.Backward) and op_role & int(OpRole.Loss)
-def _is_gradient_clip_op(op):
+def is_gradient_clip_op(op):
return op.desc.has_attr("op_namescope") \
and op.desc.attr("op_namescope").startswith("/gradient_clip")
...
@@ -21,6 +21,7 @@ from .auto_parallel_fp16 import *
from .auto_parallel_recompute import *
from .auto_parallel_quantization import *
from .auto_parallel_data_parallel_optimization import *
+from .auto_parallel_grad_clip import *
from .cpp_pass import *
import os
from .ps_trainer_pass import *
...
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from functools import reduce
import paddle
from paddle.fluid import core
from .pass_base import PassBase, register_pass
from ..auto_parallel.reshard import Resharder
from ..auto_parallel.process_group import get_world_process_group
from ..auto_parallel.utils import is_gradient_clip_op, is_optimize_op, OP_ROLE_KEY, OpRole, _get_comm_group
from ..auto_parallel.dist_attribute import TensorDistributedAttribute, OperatorDistributedAttribute
def _get_params_grads(block):
params_grads = []
for op in reversed(block.ops):
if not is_optimize_op(op):
break
if "Param" in op.input_names and "Grad" in op.input_names:
param_name = op.input("Param")[0]
grad_name = op.input("Grad")[0]
param = block.var(param_name)
grad = block.var(grad_name)
params_grads.append((param, grad))
return params_grads
def _get_dpmp_topology(origin_topology, sharding_group):
"""
Get the dpmp topology from origin_topology.
Example:
the parallel strategy: dp4-mp2-sharding2
the complete process_mesh:
topology: [4, 2]
processes: [0, 1, 2, 3, 4, 5, 6, 7]
the dpmp topology: [2, 2]
the sharding axis: 1
"""
sharding_axis = 1
dp_sharding_topology = [
origin_topology[0] // sharding_group.nranks, sharding_group.nranks
]
if dp_sharding_topology[0] == 1:
sharding_axis = 0
dp_sharding_topology = dp_sharding_topology[1:]
product_dp_sharding = reduce(lambda x, y: x * y, dp_sharding_topology)
product_topology = reduce(lambda x, y: x * y, origin_topology)
if product_topology == product_dp_sharding:
dpmp_topology = dp_sharding_topology
else:
assert product_topology % product_dp_sharding == 0
mp_degree = product_topology // product_dp_sharding
dpmp_topology = dp_sharding_topology + [mp_degree]
return dpmp_topology, sharding_axis
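# Illustrative sketch (reviewer note, not part of the patch): `sharding_group`
# only needs an `.nranks` attribute here, so two pure-dp cases trace as follows:
#
#   from types import SimpleNamespace
#   _get_dpmp_topology([4], SimpleNamespace(nranks=2))  # -> ([2, 2], 1)
#   _get_dpmp_topology([2], SimpleNamespace(nranks=2))  # -> ([2], 0)
#
# dp4-sharding2 splits the dp axis into [2, 2] and keeps the sharding axis at 1;
# dp2-sharding2 drops the leading dp dim of size 1 and moves the sharding axis to 0.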
def _get_dpmp_process_mesh(rank_id, topology, processes, sharding_group):
"""
Get the dpmp process_mesh from the complete process_mesh to which sharding is applied.
Example:
the parallel strategy: dp4-mp2-sharding2
the complete process_mesh:
topology: [4, 2]
processes: [0, 1, 2, 3, 4, 5, 6, 7]
the dpmp process_mesh is:
1) topology: [2, 2], processes: [0, 1, 4, 5]
2) topology: [2, 2], processes: [2, 3, 6, 7]
"""
if sharding_group is None:
return topology, processes
# get dpmp_topology
dpmp_topology, sharding_axis = _get_dpmp_topology(topology, sharding_group)
# get all sharding_groups of ranks
sharding_groups = []
for rank in processes:
group = _get_comm_group(processes, dpmp_topology, sharding_axis, rank)
if group not in sharding_groups:
sharding_groups.append(group)
# get dpmp_processes
sharding_groups = np.array(sharding_groups)
dpmp_processes_in_sharding = None
for i in range(sharding_groups.shape[-1]):
if rank_id in sharding_groups[:, i]:
dpmp_processes_in_sharding = sharding_groups[:, i]
assert dpmp_processes_in_sharding is not None
return dpmp_topology, list(dpmp_processes_in_sharding)
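# Illustrative sketch (reviewer note, not part of the patch): assuming the comm
# groups along the sharding axis for the dp4-mp2-sharding2 example above come out
# as [[0, 2], [1, 3], [4, 6], [5, 7]], stacking them and taking the column that
# contains rank_id yields the dpmp processes listed in the docstring:
#
#   import numpy as np
#   groups = np.array([[0, 2], [1, 3], [4, 6], [5, 7]])
#   [groups[:, i].tolist() for i in range(groups.shape[-1])]
#   # -> [[0, 1, 4, 5], [2, 3, 6, 7]]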
def _is_about_global_norm(rank_id, tensor_shape, topology, processes,
dims_mapping, sharding_group):
# get the current process_mesh where the parameter exists.
dpmp_topology, dpmp_processes = _get_dpmp_process_mesh(
rank_id, topology, processes, sharding_group)
complete_shape = Resharder.compute_complete_shape(tensor_shape,
dpmp_topology,
dims_mapping)
complete_partitions = []
complete_param_ranks = []
for process in dpmp_processes:
partition_index = Resharder.compute_partition_index(
process, complete_shape, dims_mapping, dpmp_topology,
dpmp_processes)
if partition_index not in complete_partitions:
complete_partitions.append(partition_index)
complete_param_ranks.append(process)
return rank_id in complete_param_ranks
class ClipHelper(object):
def __init__(self, params_grads, rank_id, block, dist_context):
params, _ = zip(*params_grads)
self.params = list(params)
self.params_name = [p.name for p in self.params]
self.rank_id = rank_id
self.block = block
self.dist_context = dist_context
self.sharding_group = None
self.world_ranks = get_world_process_group().ranks
if hasattr(dist_context, '_sharding_group'):
self.sharding_group = dist_context._sharding_group
def _is_calcuate_norm(self, name):
if not self._is_local_param(name):
return False, []
param = self.params[self.params_name.index(name)]
dist_attr = self._get_dist_attr(name)
topology = dist_attr.process_mesh.topology
processes = dist_attr.process_mesh.processes
dims_mapping = dist_attr.dims_mapping
return _is_about_global_norm(self.rank_id, param.shape, topology,
processes, dims_mapping,
self.sharding_group)
def _get_dist_attr(self, name):
var = self.block.vars[name]
return self.dist_context.get_tensor_dist_attr_for_program(var)
def _is_local_param(self, name):
if name not in self.params_name:
return False
return True
def _is_local_var(self, name):
dist_attr = self._get_dist_attr(name)
assert dist_attr is not None
return self.rank_id in dist_attr.process_mesh.processes
def _init_dist_attr(self, op):
op_dist_attr = OperatorDistributedAttribute()
op_dist_attr.process_mesh = self.world_ranks
for in_name in op.input_arg_names:
in_var = self.block.vars[in_name]
in_dist_attr = TensorDistributedAttribute()
in_dist_attr.process_mesh = self.world_ranks
in_dist_attr.dims_mapping = [-1]
self.dist_context.set_tensor_dist_attr_for_program(
in_var, in_dist_attr)
op_dist_attr.set_input_dist_attr(in_name, in_dist_attr)
for out_name in op.output_arg_names:
out_var = self.block.vars[out_name]
out_dist_attr = TensorDistributedAttribute()
out_dist_attr.process_mesh = self.world_ranks
out_dist_attr.dims_mapping = [-1]
self.dist_context.set_tensor_dist_attr_for_program(
out_var, out_dist_attr)
op_dist_attr.set_output_dist_attr(out_name, out_dist_attr)
self.dist_context.set_op_dist_attr_for_program(op, op_dist_attr)
@register_pass("auto_parallel_grad_clip")
class ClipGradByGloblNormPass(PassBase):
"""
1. Remove the norm-computation and grad-scaling ops when the grad is not on the
current rank or does not contribute to the norm calculation.
2. Each rank computes its own local norm value, then obtains global_norm with a single allreduce_sum.
"""
def __init__(self):
super(ClipGradByGloblNormPass, self).__init__()
self.set_attr("rank_id", None)
self.set_attr("dist_context", None)
def _check_self(self):
if self.get_attr("dist_context") is None:
return False
dist_context = self.get_attr("dist_context")
if dist_context._lr_optimizer._grad_clip is None:
return False
return True
def _check_conflict(self, other_pass):
return True
def _apply_single_impl(self, main_program, startup_program, context):
dist_context = self.get_attr("dist_context", None)
rank_id = self.get_attr("rank_id", None)
block = main_program.global_block()
dist_params_grads = _get_params_grads(block)
self.clip_helper = ClipHelper(dist_params_grads, rank_id, block,
dist_context)
self._remove_no_need_ops_vars(block)
def _remove_no_need_ops_vars(self, block):
removed_op_out_type = [
'clip_by_norm', 'squared_l2_norm', 'square', 'reduce_sum'
]
removed_op_idx = set()
removed_tmp_var = set()
for idx, op in enumerate(block.ops):
if not is_gradient_clip_op(op):
continue
if op.type in removed_op_out_type:
input_name = op.input("X")[0]
if input_name.find("@GRAD") != -1:
#'clip_by_norm', 'squared_l2_norm', 'square'
param_name = input_name[:input_name.find("@GRAD")]
is_local = self.clip_helper._is_local_param(param_name)
is_calculate = self.clip_helper._is_calcuate_norm(
param_name)
if not is_local or (not is_calculate
and op.type != 'clip_by_norm'):
removed_op_idx.add(idx)
removed_tmp_var.update(set(op.output_arg_names))
else:
# 'reduce_sum'
if idx - 1 in removed_op_idx:
removed_op_idx.add(idx)
removed_tmp_var.update(set(op.output_arg_names))
elif op.type == 'elementwise_mul':
input_name = op.input("X")[0]
if input_name.find("@GRAD") != -1:
param_name = input_name[:input_name.find("@GRAD")]
is_local = self.clip_helper._is_local_param(param_name)
if not is_local:
removed_op_idx.add(idx)
if block.ops[idx - 1].type == 'cast':
removed_op_idx.add(idx - 1)
removed_tmp_var.update(
set(block.ops[idx - 1].output_arg_names))
elif op.type == 'sum':
reserved_vars = []
for input_name in op.input_arg_names:
if input_name not in removed_tmp_var and \
self.clip_helper._is_local_var(input_name):
reserved_vars.append(input_name)
if not reserved_vars:
removed_op_idx.add(idx)
removed_tmp_var.update(set(op.output_arg_names))
if block.ops[idx + 1].type == 'cast':
removed_op_idx.add(idx + 1)
removed_tmp_var.update(
set(block.ops[idx + 1].output_arg_names))
else:
op.desc.set_input("X", reserved_vars)
for idx, op in reversed(list(enumerate(block.ops))):
if not is_optimize_op(op):
break
if not is_gradient_clip_op(op):
continue
if idx in removed_op_idx:
block._remove_op(idx, sync=False)
for idx, op in reversed(list(enumerate(block.ops))):
if not is_optimize_op(op):
break
if not is_gradient_clip_op(op):
continue
if op.type == 'sqrt':
input_name = op.input("X")[0]
input_var = block.vars[input_name]
if paddle.distributed.get_world_size() > 1:
offset = 0
if input_name in removed_tmp_var:
removed_tmp_var.remove(input_name)
fill_constant_op = block._insert_op(
idx,
type='fill_constant',
inputs={},
outputs={'Out': [input_var]},
attrs={
'shape': [1],
'dtype': input_var.dtype,
'value': 0,
'force_cpu': False,
OP_ROLE_KEY: OpRole.Optimize
})
fill_constant_op._set_attr('op_namescope',
"/gradient_clip_pass")
offset += 1
self.clip_helper._init_dist_attr(fill_constant_op)
allreduce_op = block._insert_op(
idx + offset,
type='c_allreduce_sum',
inputs={'X': [input_var]},
outputs={'Out': [input_var]},
attrs={
'ring_id': 0,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Optimize,
})
allreduce_op._set_attr('op_namescope',
"/gradient_clip_pass")
self.clip_helper._init_dist_attr(allreduce_op)
for varname in removed_tmp_var:
block._remove_var(varname, sync=False)
block._sync_with_cpp()
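A minimal standalone sketch (reviewer note, not part of the patch) of the idea behind point 2 of the pass docstring: if every parameter element is counted on exactly one rank, summing the per-rank squared norms and taking a single sqrt reproduces the full global norm. The c_allreduce_sum inserted by the pass is emulated below by a plain Python sum.

import numpy as np

np.random.seed(0)
grads = [np.random.randn(4, 3), np.random.randn(5)]        # full (global) gradients
global_norm = np.sqrt(sum((g ** 2).sum() for g in grads))  # reference global norm

# each "rank" owns a disjoint slice of every gradient, so no element is counted twice
rank0 = [grads[0][:2], grads[1][:3]]
rank1 = [grads[0][2:], grads[1][3:]]
local_sq = [sum((g ** 2).sum() for g in shards) for shards in (rank0, rank1)]

# one allreduce_sum of the local squared norms, then a single sqrt
assert np.isclose(np.sqrt(sum(local_sq)), global_norm)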
@@ -29,7 +29,7 @@ OpRole = core.op_proto_and_checker_maker.OpRole
OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
_skip_ops = [
'create_py_reader', 'create_double_buffer_reader', 'read', 'slice', 'split',
-'assign'
+'assign', "send_v2"
]
# update here to support new optimizers
_supported_optimizer_type = [
@@ -140,6 +140,7 @@ class ShardingPass(PassBase):
else:
sharding_group = dp_group
+self._dist_context._sharding_group = sharding_group
# TODO(JZ-LIANG) when support multiple dp groups in future, should group param and bind them to corresponding dp group
params_in_group = [p for p, g in params_grads]
assert len(params_in_group) == len(
@@ -160,7 +161,7 @@ class ShardingPass(PassBase):
"""
self._shard_amp_related_op_and_vars(main_block, pass_context)
self._shard_weight_decay(main_block)
-self._shard_gradient_clip(main_block)
+# self._shard_gradient_clip(main_block)
self._shard_optimizer_ops_and_states(main_block, startup_block)
self._insert_optimizer_broadcasts(main_block, startup_block)
...
@@ -37,6 +37,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
${dist_ENVS})
set_tests_properties(test_high_order_grad
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
+py_test_modules(test_grad_clip MODULES test_grad_clip ENVS ${dist_ENVS})
+set_tests_properties(test_grad_clip PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE"
+TIMEOUT 50)
py_test_modules(test_while_op_completion MODULES test_while_op_completion
ENVS ${dist_ENVS})
...
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import sys
import random
import numpy as np
import paddle
import paddle.distributed.fleet as fleet
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.engine import Engine
from get_gpt_model import generate_model, create_data_holder, FakeDataset
paddle.enable_static()
def apply_pass(use_sharding=False):
strategy = fleet.DistributedStrategy()
strategy.semi_auto = True
if use_sharding:
strategy.sharding = True
strategy.sharding_configs = {
"sharding_degree": 2,
"stage": 2,
}
return strategy
def get_parameter_value(program):
from paddle.fluid.framework import Parameter
def is_parameter(var):
return isinstance(var, Parameter)
def get_tensor(var):
t = paddle.fluid.global_scope().find_var(var.name).get_tensor()
return np.array(t)
def get_name(var):
return len(var.name)
parameters_list = list(filter(is_parameter, program.list_vars()))
parameters_value = []
for p in sorted(parameters_list, key=get_name):
parameters_value.append(get_tensor(p))
return parameters_value
def reset_prog():
paddle.fluid.framework.switch_main_program(paddle.static.Program())
paddle.fluid.framework.switch_startup_program(paddle.static.Program())
class TestGradientClipByGlobalNorm(unittest.TestCase):
def setUp(self):
self.batch_size = 2
self.batch_num = 1
self.clip_norm = 0.2
self.dataset = FakeDataset(self.batch_size * self.batch_num)
def init(self, engine):
paddle.seed(2022)
np.random.seed(2022)
random.seed(2022)
engine.mode = "train"
engine._executor.run(engine.startup_program)
def get_dp2_engine(self):
reset_prog()
strategy = apply_pass()
clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
model, loss = generate_model("dp")
inputs_spec, labels_spec = create_data_holder(self.batch_size)
engine = Engine(model, inputs_spec, labels_spec, strategy=strategy)
engine.prepare(optimizer=opt, loss=loss)
self.init(engine)
return engine
def get_dp2sharding2_engine(self):
reset_prog()
strategy = apply_pass(True)
clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
model, loss = generate_model("dp")
inputs_spec, labels_spec = create_data_holder(self.batch_size)
engine = Engine(model, inputs_spec, labels_spec, strategy=strategy)
engine.prepare(optimizer=opt, loss=loss)
self.init(engine)
return engine
def check_result(self, dp_params, sharding_params):
assert len(dp_params) == len(sharding_params)
for dp_p, sharding_p in zip(dp_params, sharding_params):
np.testing.assert_allclose(
dp_p,
sharding_p,
rtol=1e-05,
atol=1e-08,
err_msg=
'gradient clip by global norm has wrong results!, \nu={}\nv={}\ndiff={}'
.format(dp_p, sharding_p, dp_p - sharding_p))
def test_grad_clip(self):
# dp2 training
dp_engine = self.get_dp2_engine()
dp_engine.fit(self.dataset, batch_size=self.batch_size, use_cache=True)
dp_param_values = get_parameter_value(dp_engine.main_program)
# dp2sharding2 training
sharding_engine = self.get_dp2sharding2_engine()
sharding_engine.fit(self.dataset,
batch_size=self.batch_size,
use_cache=True)
sharding_param_values = get_parameter_value(
sharding_engine.main_program)
self.check_result(dp_param_values, sharding_param_values)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import numpy as np
import paddle
sys.path.append("..")
import auto_parallel_gpt_model as modeling
from auto_parallel_gpt_model import GPTModel, GPTForPretraining, GPTPretrainingCriterion
sequence_len = 512
vocab_size = 1000
class FakeDataset:
def __init__(self, num_samples):
self.num_samples = num_samples
self.sequence_len = sequence_len
self.vocab_size = vocab_size
def __getitem__(self, idx):
tokens = np.random.randint(self.vocab_size, size=self.sequence_len)
position_ids = np.arange(self.sequence_len)
attention_mask = np.tril(np.ones(self.sequence_len)).reshape(
(1, self.sequence_len, self.sequence_len)).astype(np.float32)
labels = np.random.randint(self.vocab_size, size=self.sequence_len)
loss_mask = np.ones(self.sequence_len).astype(np.float32)
return tokens, position_ids, attention_mask, labels, loss_mask
def __len__(self):
return self.num_samples
def create_data_holder(batch_size):
tokens = paddle.static.InputSpec(name="tokens",
shape=[batch_size, sequence_len],
dtype='int64')
position_ids = paddle.static.InputSpec(name="position_ids",
shape=[batch_size, sequence_len],
dtype='int64')
attention_mask = paddle.static.InputSpec(
name="attention_mask",
shape=[batch_size, 1, sequence_len, sequence_len],
dtype='float32')
labels = paddle.static.InputSpec(name="labels",
shape=[batch_size, sequence_len],
dtype='int64')
loss_mask = paddle.static.InputSpec(name="loss_mask",
shape=[batch_size, sequence_len],
dtype='float32')
return [tokens, position_ids, attention_mask], [labels, loss_mask]
def generate_model(strategy):
modeling.init_global()
if strategy == "serial":
modeling._global_parallel_strategy = "serial"
modeling._global_process_mesh = [0]
elif strategy == "mp":
modeling._global_parallel_strategy = "mp"
modeling._global_process_mesh = [0, 1]
elif strategy == "dp":
modeling._global_parallel_strategy = "dp"
modeling._global_process_mesh = [0, 1]
else:
raise ValueError("Only support serial, mp2 and dp2.")
gpt = GPTModel(vocab_size=1000,
hidden_size=64,
num_hidden_layers=2,
num_attention_heads=8,
intermediate_size=256,
hidden_act="gelu",
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
max_position_embeddings=1024,
type_vocab_size=1,
initializer_range=0.02,
pad_token_id=0,
eos_token_id=7,
bos_token_id=0,
eol_token_id=3)
model = GPTForPretraining(gpt,
vocab_size=1000,
hidden_size=64,
initializer_range=0.02)
criterion = GPTPretrainingCriterion()
return model, criterion
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest
import os
import sys
import shutil
import subprocess
from paddle.distributed.fleet.launch_utils import run_with_coverage
class TestGradientClip(unittest.TestCase):
def test_dp2(self):
file_dir = os.path.dirname(os.path.abspath(__file__))
launch_model_path = os.path.join(file_dir,
"clip_grad_by_global_norm.py")
if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
coverage_args = ["-m", "coverage", "run", "--branch", "-p"]
else:
coverage_args = []
tmp_dir = tempfile.TemporaryDirectory()
cmd = [sys.executable, "-u"] + coverage_args + [
"-m", "paddle.distributed.launch", "--devices", "0,1", "--log_dir",
tmp_dir.name, launch_model_path
]
process = subprocess.Popen(cmd)
process.wait()
self.assertEqual(process.returncode, 0)
tmp_dir.cleanup()
if __name__ == "__main__":
unittest.main()
@@ -86,7 +86,7 @@ class TestLRScheduler(TestEngineBase):
self.engine.fit(self.dataset, batch_size=self.batch_size)
-class TestGradClip(TestEngineBase):
+class TestGradClipByGlobalNorm(TestEngineBase):
def init_optimizer(self):
clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
@@ -96,7 +96,6 @@ class TestGradClip(TestEngineBase):
def test_grad_clip(self):
clip = self.engine._optimizer._grad_clip
-assert isinstance(clip, paddle.nn.ClipGradByGlobalNorm)
self.engine.fit(self.dataset, batch_size=self.batch_size)
self.check_program()
@@ -112,5 +111,13 @@ class TestGradClip(TestEngineBase):
assert has_grad_clip is True
+class TestGradClipByNorm(TestGradClipByGlobalNorm):
+def init_optimizer(self):
+clip = paddle.nn.ClipGradByNorm(clip_norm=1.0)
+self.optimizer = paddle.optimizer.SGD(learning_rate=0.00001,
+grad_clip=clip)
if __name__ == "__main__":
unittest.main()