Unverified commit bb6bd223, authored by Z zhaoyingli, committed by GitHub

[AutoParallel] support ClipGradByGlobalNorm (#45205)

* add clip_grad

* fix comments

* add unittest

* update logger
Parent d257acc6
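The user-facing behavior this commit enables is exercised by the new test_lr_grad_clip test at the bottom of this diff: pass a ClipGradByGlobalNorm to the optimizer used by the auto-parallel Engine. A minimal sketch of that usage, based on the test code below (the Linear model and the commented-out dataset are illustrative stand-ins, not part of this commit):

import paddle
from paddle.static import InputSpec
from paddle.distributed.auto_parallel.engine import Engine

paddle.enable_static()

# Clip all gradients jointly so their global L2 norm stays <= 1.0.
clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
optimizer = paddle.optimizer.SGD(learning_rate=0.00001, grad_clip=clip)

model = paddle.nn.Linear(1024, 2)  # stand-in for any paddle.nn.Layer
inputs = InputSpec([4, 1024], 'float32', 'x')
labels = InputSpec([4], 'int64', 'label')

engine = Engine(model=model, inputs_spec=inputs, labels_spec=labels)
engine.prepare(optimizer=optimizer, loss=paddle.nn.CrossEntropyLoss())
# engine.fit(dataset, batch_size=4)  # dataset: any paddle.io.Dataset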
@@ -19,7 +19,7 @@ import time
 from paddle.fluid import core
 from paddle.fluid import framework
-from .utils import print_program_with_dist_attr
+from .utils import print_program_with_dist_attr, _is_gradient_clip_op
 from .operators import find_compatible_distributed_operator_impls
 from .dist_context import get_default_distributed_context, _node_id
 from .dist_tensor import DistributedTensor
@@ -1319,26 +1319,70 @@ class Completer:
             # TODO to add attribute for moment var
             op = ops[idx]
             if int(op.attr('op_role')) == int(OpRole.Optimize):
-                if op.type == "clip_by_norm":
-                    param_grad = vars[op.input("X")[0]]
-                    param_grad_dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
-                        param_grad)
-                    assert param_grad_dist_attr is not None
-                    ref_process_mesh = param_grad_dist_attr.process_mesh
-                    ref_dims_mapping = param_grad_dist_attr.dims_mapping
-
-                    out = vars[op.output("Out")[0]]
-                    out_dist_attr = TensorDistributedAttribute()
-                    out_dist_attr.process_mesh = ref_process_mesh
-                    out_dist_attr.dims_mapping = ref_dims_mapping
-                    self._dist_context.set_tensor_dist_attr_for_program(
-                        out, out_dist_attr)
-
-                    op_dist_attr = OperatorDistributedAttribute()
-                    op_dist_attr.process_mesh = ref_process_mesh
-                    op_dist_attr.set_input_dist_attr(param_grad.name,
-                                                     param_grad_dist_attr)
-                    op_dist_attr.set_output_dist_attr(out.name, out_dist_attr)
+                # TODO:
+                # 1. move `generate_optimizer` before `partitioner`
+                # 2. implement grad_clip completion by `dist_op`
+                # 3. allreduce dist_global_norm (mp-group) and no_dist_global_norm (pp-group, sharding-group)
+                if _is_gradient_clip_op(op):
+                    if op.type in [
+                            "sum", "sqrt", "fill_constant", "elementwise_max",
+                            "elementwise_div"
+                    ]:
+                        op_dist_attr = OperatorDistributedAttribute()
+                        op_dist_attr.process_mesh = world_ranks
+                        for in_name in op.input_arg_names:
+                            in_var = vars[in_name]
+                            in_dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
+                                in_var)
+                            op_dist_attr.set_input_dist_attr(
+                                in_name, in_dist_attr)
+                        for out_name in op.output_arg_names:
+                            out_var = vars[out_name]
+                            out_dist_attr = TensorDistributedAttribute()
+                            out_dist_attr.process_mesh = world_ranks
+                            out_dist_attr.dims_mapping = [
+                                -1 for _ in range(len(out_var.shape))
+                            ]
+                            self._dist_context.set_tensor_dist_attr_for_program(
+                                out_var, out_dist_attr)
+                            op_dist_attr.set_output_dist_attr(
+                                out_name, out_dist_attr)
+                        remove_no_need_in_op(op, self._dist_context)
+                    else:
+                        in_var = vars[op.input("X")[0]]
+                        in_dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
+                            in_var)
+                        assert in_dist_attr is not None
+                        ref_process_mesh = in_dist_attr.process_mesh
+                        ref_dims_mapping = in_dist_attr.dims_mapping
+                        if op.type == "cast" and ops[
+                                idx + 1].type == "elementwise_mul":
+                            ref_var = vars[ops[idx + 1].input("X")[0]]
+                            ref_dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
+                                ref_var)
+                            assert ref_dist_attr is not None
+                            ref_process_mesh = ref_dist_attr.process_mesh
+                        out_var = vars[op.output("Out")[0]]
+                        out_dist_attr = TensorDistributedAttribute()
+                        out_dist_attr.process_mesh = ref_process_mesh
+                        if out_var.shape == in_var.shape:
+                            out_dist_attr.dims_mapping = ref_dims_mapping
+                        else:
+                            assert len(
+                                out_var.shape) == 1 and out_var.shape[0] == 1
+                            out_dist_attr.dims_mapping = [-1]
+                        self._dist_context.set_tensor_dist_attr_for_program(
+                            out_var, out_dist_attr)
+                        op_dist_attr = OperatorDistributedAttribute()
+                        op_dist_attr.process_mesh = ref_process_mesh
+                        op_dist_attr.set_input_dist_attr(
+                            in_var.name, in_dist_attr)
+                        op_dist_attr.set_output_dist_attr(
+                            out_var.name, out_dist_attr)
                     self._dist_context.set_op_dist_attr_for_program(
                         op, op_dist_attr)
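For context on why exactly these five op types appear: in static graph mode, ClipGradByGlobalNorm is lowered into a small subgraph that reduces the squared gradients, takes a square root, and combines the result with the clip_norm constant into one shared scale. A rough NumPy sketch of the math those ops implement (an illustration of the computation, not the actual lowering code):

import numpy as np

def clip_by_global_norm(grads, clip_norm):
    # "sum" + "sqrt": global_norm = sqrt(sum of squared entries of all grads)
    global_norm = np.sqrt(sum((g * g).sum() for g in grads))
    # "fill_constant" materializes clip_norm; "elementwise_max" and
    # "elementwise_div" build the shared scale, which is always <= 1.0:
    scale = clip_norm / np.maximum(global_norm, clip_norm)
    # every gradient is then multiplied by the same scale
    return [g * scale for g in grads]

grads = [np.array([3.0, 4.0]), np.array([0.0, 12.0])]  # global norm = 13.0
clipped = clip_by_global_norm(grads, clip_norm=1.0)    # each grad scaled by 1/13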
@@ -1383,11 +1427,17 @@ class Completer:
                 for input_name in op.desc.input_names():
                     if input_name in [
-                            'Param', 'Grad', 'LearningRate', "SkipUpdate",
-                            "Beta1Tensor", "Beta2Tensor", "EpsilonTensor",
-                            "MasterParam"
+                            'Param',
+                            'Grad',
+                            'LearningRate',
+                            "SkipUpdate",
+                            "Beta1Tensor",
+                            "Beta2Tensor",
+                            "EpsilonTensor",
                     ]:
                         continue
+                    if len(op.desc.input(input_name)) == 0:
+                        continue
                     assert len(op.desc.input(input_name)) == 1
                     input_var = vars[op.desc.input(input_name)[0]]
@@ -1400,7 +1450,6 @@ class Completer:
                         op_dist_attr.set_output_dims_mapping(
                             input_var.name, [-1])
                     else:
-                        assert "Moment" in input_name or "Velocity" in input_name
                         input_var_attr.dims_mapping = ref_dims_mapping
                         op_dist_attr.set_input_dims_mapping(
                             input_var.name, ref_dims_mapping)
@@ -1481,3 +1530,20 @@ class Completer:
                     break
             else:
                 dist_op.dist_attr = backup_op_dist_attr
+
+
+def remove_no_need_in_op(op, dist_context):
+    if op.type == "fill_constant":
+        return
+
+    filter_vars = []
+    main_block = op.block
+    rank_id = dist_context.dist_op_context.rank_id
+    for varname in op.input("X"):
+        if rank_id in dist_context.get_tensor_dist_attr_for_program(
+                main_block.var(varname)).process_mesh.processes:
+            filter_vars.append(varname)
+
+    if not filter_vars:
+        return
+
+    op.desc.set_input('X', filter_vars)
@@ -68,7 +68,6 @@ class DistributedContext:
         self._original_serial_loss = serial_loss
         self._original_serial_feed_vars = feed_vars
         self._original_serial_fetch_vars = fetch_vars
-        self._original_serial_optimizer = serial_optimizer

         # Data members related to programs (changed)
         self._serial_main_program = None
@@ -77,6 +76,7 @@ class DistributedContext:
         self._serial_optimizer = None
         self._serial_feed_vars = {}
         self._serial_fetch_vars = {}
+        self._lr_optimizer = None  # record the optimizer holding lr_scheduler

         # Data members related to the program
         self._dist_tensors_for_program = {}
@@ -126,7 +126,7 @@ class DistributedContext:
         self._data_parallel = False

         # flag whether using `to_static`
-        self._dygraph_mode = True
+        self._dygraph_mode = False

     @property
     def serial_main_program(self):
@@ -235,31 +235,20 @@ class DistributedContext:
         if dist:
             self._backup_dist_info(dist_mode)

-    def _restore_serial_info(self, mode="to_backup"):
-        if mode == "to_backup":
-            self._serial_main_program = self._backup_serial_main_program_stack.pop(
-            )
-            self._serial_startup_program = self._backup_serial_startup_program_stack.pop(
-            )
-        elif mode == "to_original":
-            assert self._original_serial_main_program is not None
-            assert self._original_serial_startup_program is not None
-            self._serial_main_program = self._original_serial_main_program.clone(
-            )
-            self._serial_startup_program = self._original_serial_startup_program.clone(
-            )
-
-        self._serial_optimizer = self._original_serial_optimizer
+    def _restore_serial_loss(self):
         if self._original_serial_loss:
             if isinstance(self._original_serial_loss, list):
-                assert len(self._original_serial_loss) == 1
-                loss = self._original_serial_loss[0]
-                block_idx = loss.block.idx
-                var_name = loss.name
-                var = self._serial_main_program.blocks[
-                    block_idx]._var_recursive(var_name)
-                self._serial_loss = var
+                if len(self._original_serial_loss) == 1:
+                    loss = self._original_serial_loss[0]
+                    block_idx = loss.block.idx
+                    var_name = loss.name
+                    var = self._serial_main_program.blocks[
+                        block_idx]._var_recursive(var_name)
+                    self._serial_loss = var
+                elif len(self._original_serial_loss) == 0:
+                    self._serial_loss = []
+                else:
+                    raise ValueError("multi loss vars are not supported.")
             else:
                 block_idx = self._original_serial_loss.block.idx
                 var_name = self._original_serial_loss.name
@@ -267,6 +256,7 @@ class DistributedContext:
                 block_idx]._var_recursive(var_name)
             self._serial_loss = var

+    def _restore_serial_feed_vars(self):
         for key, var_list in self._original_serial_feed_vars.items():
             new_var_list = []
             for var in var_list:
@@ -277,6 +267,7 @@ class DistributedContext:
                 new_var_list.append(var)
             self._serial_feed_vars[key] = new_var_list

+    def _restore_serial_fetch_vars(self):
         for key, var_list in self._original_serial_fetch_vars.items():
             new_var_list = []
             for var in var_list:
@@ -287,6 +278,24 @@ class DistributedContext:
                 new_var_list.append(var)
             self._serial_fetch_vars[key] = new_var_list
def _restore_serial_info(self, mode="to_backup"):
if mode == "to_backup":
self._serial_main_program = self._backup_serial_main_program_stack.pop(
)
self._serial_startup_program = self._backup_serial_startup_program_stack.pop(
)
elif mode == "to_original":
assert self._original_serial_main_program is not None
assert self._original_serial_startup_program is not None
self._serial_main_program = self._original_serial_main_program.clone(
)
self._serial_startup_program = self._original_serial_startup_program.clone(
)
self._restore_serial_loss()
self._restore_serial_feed_vars()
self._restore_serial_fetch_vars()
self._serial_optimizer = self._original_serial_optimizer
self._pass_context = self._backup_pass_context_stack.pop() self._pass_context = self._backup_pass_context_stack.pop()
self._block_state = self._backup_block_state_stack.pop() self._block_state = self._backup_block_state_stack.pop()
@@ -353,25 +362,21 @@ class DistributedContext:
     def initialize(self, with_graph=True):
         if not self._is_initialized:
             if not self._serial_main_program:
-                self._serial_main_program = self._original_serial_main_program
+                if self._original_serial_main_program:
+                    self._serial_main_program = self._original_serial_main_program.clone(
+                    )
             if not self._serial_startup_program:
-                self._serial_startup_program = self._original_serial_startup_program
+                if self._original_serial_startup_program:
+                    self._serial_startup_program = self._original_serial_startup_program.clone(
+                    )
             if not self._serial_loss:
-                if isinstance(self._original_serial_loss, list):
-                    if len(self._original_serial_loss) == 1:
-                        self._serial_loss = self._original_serial_loss[0]
-                    elif len(self._original_serial_loss) == 0:
-                        self._serial_loss = self._original_serial_loss
-                    else:
-                        raise ValueError("multi loss vars are not supported.")
-                else:
-                    self._serial_loss = self._original_serial_loss
+                self._restore_serial_loss()
             if not self._serial_optimizer:
                 self._serial_optimizer = self._original_serial_optimizer
             if not self._serial_feed_vars:
-                self._serial_feed_vars = self._original_serial_feed_vars
+                self._restore_serial_feed_vars()
             if not self._serial_fetch_vars:
-                self._serial_fetch_vars = self._original_serial_fetch_vars
+                self._restore_serial_fetch_vars()
             self._init_dist_attr_for_program()
             # Backup the original distributed information for later restore
@@ -856,7 +861,11 @@ class DistributedContext:
                     "_serial_main_program", "_serial_startup_program", "_serial_graph", \
                     "_dist_main_programs", "_dist_startup_programs", \
                     "_serial_ordered_nodes", "_serial_ordered_tensor_nodes", \
-                    "_serial_ordered_op_nodes"]:
+                    "_serial_ordered_op_nodes", "_original_serial_loss", \
+                    "_original_serial_feed_vars", "_original_serial_fetch_vars", \
+                    "_serial_loss", "_serial_feed_vars", "_serial_fetch_vars", "_lr_optimizer", \
+                    "_backup_serial_main_program_stack", "_backup_serial_startup_program_stack", \
+                    "_pass_context"]:
                 setattr(result, k, v)
             else:
                 setattr(result, k, copy.deepcopy(v, memo))
...
@@ -16,7 +16,6 @@ import time
 import copy
 import logging
 from collections import defaultdict
-import socket

 import paddle
 import paddle.utils as utils
@@ -35,7 +34,6 @@ from paddle.fluid.framework import Operator, Parameter, _non_static_mode
 from paddle.fluid.framework import _current_expected_place as _get_device
 from paddle.fluid.dygraph.parallel import ParallelEnv
 from paddle.distributed import fleet
-from paddle.distributed.utils import get_logger
 from paddle.distributed.passes import new_pass, PassContext
 from .hepler import ProgramHelper
@@ -76,7 +74,18 @@ class Engine:
         self._cur_rank = paddle.distributed.get_rank()
         self._nranks = paddle.distributed.get_world_size()
         self._saver = DistributedSaver()
-        self._logger = get_logger(logging.INFO)
+
+        # TODO: add logger module
+        self._logger = logging.getLogger()
+        self._logger.propagate = False
+        if not self._logger.handlers:
+            self._logger.setLevel(logging.INFO)
+            log_handler = logging.StreamHandler()
+            log_format = logging.Formatter(
+                '[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s'
+            )
+            log_handler.setFormatter(log_format)
+            self._logger.addHandler(log_handler)

         self._orig_main_prog = static.default_main_program()
         self._orig_startup_prog = static.default_startup_program()
@@ -307,7 +316,7 @@ class Engine:
                 mode].dist_startup_programs
         self._feed_vars[mode] = self._dist_contexts[mode].serial_feed_vars
         self._fetch_vars[mode] = self._dist_contexts[mode].serial_fetch_vars
-        self._optimizer = self._dist_contexts[mode].serial_optimizer
+        self._lr_optimizer = self._dist_contexts[mode]._lr_optimizer

         if self._nranks > 1:
             # Traverse different rank programs and traverse each op of them,
@@ -429,25 +438,27 @@ class Engine:
         lr_scheduler = self.get_lr_scheduler(self.main_program)

         for epoch in range(epochs):
-            train_logs = {"epoch": epoch}
+            train_logs = {"epoch: {:d} ": epoch}
             for step, _ in enumerate(train_dataloader):
                 outs = self._executor.run(self.main_program,
                                           fetch_list=fetch_list,
                                           use_program_cache=use_cache,
                                           return_numpy=return_numpy)
+                train_logs["step: {:d} "] = step
                 if lr_scheduler is not None:
                     lr_scheduler.step()
-                    train_logs["lr"] = self._optimizer.get_lr()
-                train_logs["step"] = step
+                    train_logs["lr: {:5e} "] = self._lr_optimizer.get_lr()
                 # inner fetches
                 if fetch_loss:
-                    train_logs["train_loss"] = outs[0][0]
+                    train_logs["loss: {:9f} "] = outs[0][0]
                 # user fetches
                 user_outs = outs[len(fetch_loss):]
                 user_fetch_list = fetch_list[len(fetch_loss):]
                 for i, out in enumerate(user_outs):
-                    train_logs["train_" + fetch_map[user_fetch_list[i]]] = out
-                self._logger.info(train_logs)
+                    train_logs[fetch_map[user_fetch_list[i]] + ": {}"] = out
+                # logger
+                string = '[train] ' + ''.join(list(train_logs.keys()))
+                self._logger.info(string.format(*list(train_logs.values())))
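The refactored logging stores a format specifier in each dict key and the corresponding value alongside it, then joins the keys into a single format string. A standalone illustration of the pattern (it relies on dicts preserving insertion order, i.e. Python 3.7+):

train_logs = {"epoch: {:d} ": 2, "step: {:d} ": 10, "loss: {:9f} ": 0.125}
string = '[train] ' + ''.join(list(train_logs.keys()))
print(string.format(*list(train_logs.values())))
# -> [train] epoch: 2 step: 10 loss:  0.125000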
     def evaluate(self,
                  eval_data,
@@ -473,14 +484,14 @@ class Engine:
         fetch_list, fetch_map = self._fetch_map(inner_fetch, usr_fetch)

         for step, _ in enumerate(eval_dataloader):
-            eval_logs = {"step": step}
+            eval_logs = {"step: {:d} ": step}
             outs = self._executor.run(self.main_program,
                                       fetch_list=fetch_list,
                                       use_program_cache=use_cache,
                                       return_numpy=return_numpy)
             # inner fetches
             if fetch_loss:
-                eval_logs["eval_loss"] = outs[0][0]
+                eval_logs["loss: {:9f} "] = outs[0][0]
             # Metric
             if fetch_metrics:
                 metric_out = outs[len(fetch_loss):len(inner_fetch)]
@@ -488,14 +499,15 @@ class Engine:
                     metric.update(*metric_out)
                     results = metric.accumulate()
                     for i, res in enumerate(to_list(results)):
-                        eval_logs["eval_" + metric.name()[i]] = res
+                        eval_logs[metric.name()[i] + ": {:9f} "] = res
             # usr fetches
             usr_outs = outs[len(inner_fetch):]
             usr_fetch_list = fetch_list[len(inner_fetch):]
             for i, out in enumerate(usr_outs):
-                eval_logs["eval_" + fetch_map[usr_fetch_list[i]]] = out
+                eval_logs[fetch_map[usr_fetch_list[i]] + ": {}"] = out
             # logger
-            self._logger.info(eval_logs)
+            string = '[eval] ' + ''.join(list(eval_logs.keys()))
+            self._logger.info(string.format(*list(eval_logs.values())))
     def predict(self,
                 test_data,
@@ -520,15 +532,17 @@ class Engine:

         outputs = []
         for step, _ in enumerate(test_dataloader):
-            predict_logs = {"step": step}
+            predict_logs = {"step: {:d} ": step}
             outs = self._executor.run(self.main_program,
                                       fetch_list=fetch_list,
                                       use_program_cache=use_cache,
                                       return_numpy=return_numpy)
             outputs.append(outs[:len(fetch_outputs)])
             for i, out in enumerate(outs):
-                predict_logs["pred_" + fetch_map[fetch_list[i]]] = out
-            self._logger.info(predict_logs)
+                predict_logs[fetch_map[fetch_list[i]] + ": {}"] = out
+            # logger
+            string = '[pred] ' + ''.join(list(predict_logs.keys()))
+            self._logger.info(string.format(*list(predict_logs.values())))

         return outputs
...
@@ -20,7 +20,7 @@ from collections import defaultdict

 import paddle
 from paddle.fluid import program_guard
 from paddle.fluid.backward import append_backward
-from paddle.fluid.framework import _non_static_mode
+from paddle.fluid.framework import _non_static_mode, unique_name
 from paddle.distributed.passes import new_pass
 from paddle.distributed.utils import get_logger
@@ -143,15 +143,18 @@ class Parallelizer:

     def _generate_optimizer(self, main_program, startup_program, optimizer,
                             params_grads):
+        # NOTE: `apply_gradients` will add an Accumulator for a parameter only once,
+        # but the optimizer will be called repeatedly on re-launch, so it needs to be copied.
         if self._dist_context._dygraph_mode:
             paddle.disable_static()
             optimizer = copy.deepcopy(optimizer)
             paddle.enable_static()
         else:
             optimizer = copy.deepcopy(optimizer)
-        self._dist_context._serial_optimizer = optimizer
+        self._dist_context._lr_optimizer = optimizer
         with program_guard(main_program, startup_program):
-            optimizer_ops = optimizer.apply_gradients(params_grads)
+            with unique_name.guard("opt_"):
+                optimizer_ops = optimizer.apply_gradients(params_grads)
         self._completer.complete_update_annotation(main_program)
         return optimizer_ops
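The NOTE above explains the deepcopy; the unique_name.guard("opt_") scope additionally prefixes every name generated while the optimizer ops are built, so accumulators created by a second apply_gradients cannot collide with variables from an earlier run. A small sketch of the guard's effect (the generated names are illustrative):

from paddle.fluid import unique_name

with unique_name.guard("opt_"):
    name_inside = unique_name.generate("moment")   # e.g. "opt_moment_0"
name_outside = unique_name.generate("moment")      # e.g. "moment_0"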
...
@@ -30,10 +30,13 @@ from .cost import build_comm_desc, CommContext
 from .cost import AllgatherOpCost, SendOpCost
 from .cost import SliceOpCost, SplitOpCost, ConcatOpCost
 from .cluster import Cluster
-from .utils import print_program_with_dist_attr
+from .utils import print_program_with_dist_attr, _is_gradient_clip_op

-# NOTE: If op in _g_special_ops, it will not be resharded.
+# NOTE: If op in _g_special_ops or _g_gradient_clip_ops, it will not be resharded.
 _g_special_ops = ['check_finite_and_unscale', 'update_loss_scaling']
+_g_gradient_clip_ops = [
+    "sum", "sqrt", "fill_constant", "elementwise_max", "elementwise_div"
+]


 def get_var_with_recursion(var_name, block, program):
@@ -1076,9 +1079,11 @@ class Resharder:
         return True

     def is_special_op(self, op):
-        global _g_special_ops
+        global _g_special_ops, _g_gradient_clip_ops
         if op.type in _g_special_ops:
             return True
+        if _is_gradient_clip_op(op) and op.type in _g_gradient_clip_ops:
+            return True
         return False

     def is_condition_replicative(self, op):
...
@@ -1131,6 +1131,11 @@ def is_loss_grad_op(op):
     return op_role & int(OpRole.Backward) and op_role & int(OpRole.Loss)


+def _is_gradient_clip_op(op):
+    return op.desc.has_attr("op_namescope") \
+        and op.desc.attr("op_namescope").startswith("/gradient_clip")
+
+
 def is_prim_op(op):
     return op.type.endswith("_p")
...
@@ -64,4 +64,5 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
   py_test_modules(test_cluster_v2 MODULES test_cluster_v2)
   py_test_modules(test_process_mesh_v2 MODULES test_process_mesh_v2)
   py_test_modules(test_dist_attr_v2 MODULES test_dist_attr_v2)
+  py_test_modules(test_lr_grad_clip MODULES test_lr_grad_clip)
 endif()
@@ -108,9 +108,7 @@ def train(fetch):
                         dropout_ratio=0.1,
                         initializer_range=0.02)
     loss = paddle.nn.CrossEntropyLoss()
-    scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=0.00001,
-                                                         T_max=10)
-    optimizer = paddle.optimizer.Adam(learning_rate=scheduler,
+    optimizer = paddle.optimizer.Adam(learning_rate=0.00001,
                                       beta1=0.9,
                                       beta2=0.999,
                                       epsilon=1e-08,
...
@@ -15,6 +15,7 @@
 import unittest
 import os
 import json
+import copy

 import paddle
 import numpy as np
@@ -194,6 +195,32 @@ class TestDistributedContext(unittest.TestCase):
         dist_context._backup(serial=True, dist=True)
         dist_context._restore(serial=True, dist=True, dist_mode="to_nothing")

+    def test_deepcopy(self):
+        train_program, start_program, dataloader, loss, optimizer, feed_vars, fetch_vars = get_program(
+        )
+        dist_context = DistributedContext(train_program, start_program,
+                                          optimizer, loss, feed_vars,
+                                          fetch_vars)
+        dist_context.initialize()
+        copy_dist_context = copy.deepcopy(dist_context)
+
+        copy_list = [
+            "_original_serial_main_program", "_original_serial_startup_program", \
+            "_serial_main_program", "_serial_startup_program", "_serial_graph", \
+            "_dist_main_programs", "_dist_startup_programs", \
+            "_serial_ordered_nodes", "_serial_ordered_tensor_nodes", \
+            "_serial_ordered_op_nodes", "_original_serial_loss", \
+            "_original_serial_feed_vars", "_original_serial_fetch_vars", \
+            "_serial_loss", "_serial_feed_vars", "_serial_fetch_vars", "_lr_optimizer", \
+            "_backup_serial_main_program_stack", "_backup_serial_startup_program_stack", \
+            "_pass_context"]
+
+        for i in range(len(copy_list)):
+            copy_obj = "copy_dist_context." + copy_list[i]
+            obj = "dist_context." + copy_list[i]
+            assert id(eval(copy_obj)) == id(eval(obj))

 if __name__ == "__main__":
     unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import os
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import paddle.distributed.auto_parallel as auto
import paddle.distributed.fleet as fleet
from paddle.io import Dataset
from paddle.static import InputSpec
from paddle.fluid.framework import _non_static_mode
from paddle.distributed.auto_parallel.engine import Engine
from paddle.distributed.auto_parallel.hepler import ProgramHelper
from test_to_static import MLPLayer, MyDataset
paddle.enable_static()
class TestEngineBase(unittest.TestCase):
def setUp(self):
self.batch_size = 4
self.batch_num = 5
self.hidden_size = 1024
self.init_model()
self.init_optimizer()
self.init_dataset()
self.init_engine()
def init_model(self):
self.mlp = MLPLayer(hidden_size=self.hidden_size,
intermediate_size=4 * self.hidden_size,
dropout_ratio=0.1,
initializer_range=0.02)
self.loss = paddle.nn.CrossEntropyLoss()
def init_optimizer(self):
self.optimizer = paddle.optimizer.SGD(learning_rate=0.00001,
parameters=self.mlp.parameters())
def init_dataset(self):
self.dataset = MyDataset(self.batch_num * self.batch_size)
def init_engine(self):
inputs = InputSpec([self.batch_size, self.hidden_size], 'float32', 'x')
labels = InputSpec([self.batch_size], 'int64', 'label')
self.engine = Engine(model=self.mlp,
inputs_spec=inputs,
labels_spec=labels)
self.engine.prepare(optimizer=self.optimizer,
loss=self.loss,
metrics=paddle.metric.Accuracy())
class TestLRScheduler(TestEngineBase):
def init_optimizer(self):
scheduler = paddle.optimizer.lr.CosineAnnealingDecay(
learning_rate=0.00001, T_max=10)
self.optimizer = paddle.optimizer.SGD(learning_rate=scheduler)
def test_lr_scheduler(self):
self.init_engine()
lr = self.engine._optimizer._learning_rate
assert isinstance(lr, paddle.optimizer.lr.LRScheduler)
self.engine.fit(self.dataset, batch_size=self.batch_size)
class TestGradClip(TestEngineBase):
def init_optimizer(self):
clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
self.optimizer = paddle.optimizer.SGD(learning_rate=0.00001,
grad_clip=clip)
def test_grad_clip(self):
clip = self.engine._optimizer._grad_clip
assert isinstance(clip, paddle.nn.ClipGradByGlobalNorm)
self.engine.fit(self.dataset, batch_size=self.batch_size)
self.check_program()
def check_program(self):
ops = self.engine.main_program.global_block().ops
has_grad_clip = False
for op in ops:
if op.desc.has_attr("op_namescope") \
and op.desc.attr("op_namescope").startswith("/gradient_clip"):
has_grad_clip = True
break
assert has_grad_clip is True
if __name__ == "__main__":
unittest.main()