From e3334f3e3035bb3494e1c573abbd73cef5f61214 Mon Sep 17 00:00:00 2001 From: mapingshuo Date: Wed, 23 Sep 2020 10:51:18 +0800 Subject: [PATCH] add zero --- .../framework/distributed_strategy.proto | 10 + .../collective/c_sync_comm_stream_op.cc | 6 +- .../fleet/base/distributed_strategy.py | 33 + .../distributed/fleet/base/fleet_base.py | 3 + .../fleet/meta_optimizers/__init__.py | 1 + .../fleet/meta_optimizers/zero_optimizer.py | 1245 +++++++++++++++++ python/paddle/fluid/clip.py | 2 +- .../contrib/mixed_precision/decorator.py | 32 +- python/paddle/fluid/framework.py | 13 +- 9 files changed, 1323 insertions(+), 22 deletions(-) create mode 100644 python/paddle/distributed/fleet/meta_optimizers/zero_optimizer.py diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto index df482f4334..84fe15bd23 100644 --- a/paddle/fluid/framework/distributed_strategy.proto +++ b/paddle/fluid/framework/distributed_strategy.proto @@ -24,6 +24,14 @@ enum Mode { message RecomputeConfig { repeated string checkpoints = 1; } +message ZeROConfig { + optional bool amp = 1 [ default = true ]; + optional int32 nrings = 2 [ default = 3 ]; + optional float fuse_broadcast_MB_bytes = 3 [ default = 64.0 ]; + repeated string checkpoints = 4; + optional bool allreduce = 5 [ default = false ]; +} + message AMPConfig { optional float init_loss_scaling = 1 [ default = 32768.0 ]; optional int32 incr_every_n_steps = 2 [ default = 1000 ]; @@ -127,6 +135,7 @@ message DistributedStrategy { optional int32 conv_workspace_size_limit = 22 [ default = 4000 ]; optional bool cudnn_batchnorm_spatial_persistent = 23 [ default = true ]; optional bool adaptive_localsgd = 24 [ default = false ]; + optional bool zero = 25 [ default = false ]; optional RecomputeConfig recompute_configs = 101; optional AMPConfig amp_configs = 102; @@ -138,6 +147,7 @@ message DistributedStrategy { optional LarsConfig lars_configs = 108; optional LambConfig lamb_configs = 109; optional AdaptiveLocalSGDConfig adaptive_localsgd_configs = 110; + optional ZeROConfig zero_configs = 111; optional BuildStrategy build_strategy = 201; optional ExecutionStrategy execution_strategy = 202; } diff --git a/paddle/fluid/operators/collective/c_sync_comm_stream_op.cc b/paddle/fluid/operators/collective/c_sync_comm_stream_op.cc index 5405ea70ef..e6fea26b6b 100644 --- a/paddle/fluid/operators/collective/c_sync_comm_stream_op.cc +++ b/paddle/fluid/operators/collective/c_sync_comm_stream_op.cc @@ -55,8 +55,10 @@ class CSyncCommStreamOp : public framework::OperatorBase { class CSyncCommStreamOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() { - AddInput("X", "(Tensor) Dependency of the variable need to sync"); - AddOutput("Out", "(Tensor) Dependency of the variable need to sync"); + AddInput("X", "(Tensor) Dependency of the variable need to sync") + .AsDuplicable(); + AddOutput("Out", "(Tensor) Dependency of the variable need to sync") + .AsDuplicable(); AddAttr("ring_id", "(int default 0) ring id.").SetDefault(0); AddComment(R"DOC( CSyncCommStream Operator diff --git a/python/paddle/distributed/fleet/base/distributed_strategy.py b/python/paddle/distributed/fleet/base/distributed_strategy.py index f1c836468d..6c287473b6 100755 --- a/python/paddle/distributed/fleet/base/distributed_strategy.py +++ b/python/paddle/distributed/fleet/base/distributed_strategy.py @@ -611,6 +611,39 @@ class DistributedStrategy(object): "checkpoint_configs") assign_configs_value(self.strategy.recompute_configs, configs) + @property + def zero(self): + """ + Indicating whether we are using Zero Redundancy Optimizer for memory + optimization + Default value: False + Examples: + .. code-block:: python + import paddle.fleet as fleet + strategy = fleet.DistributedStrategy() + strategy.zero = True + """ + return self.strategy.zero + + @zero.setter + def zero(self, flag): + if isinstance(flag, bool): + self.strategy.zero = flag + else: + print("WARNING: zero should have value of bool type") + + @property + def zero_configs(self): + """ + Set zero configurations. + """ + return get_msg_dict(self.strategy.zero_configs) + + @zero_configs.setter + def zero_configs(self, configs): + check_configs_key(self.strategy.zero_configs, configs, "zero_configs") + assign_configs_value(self.strategy.zero_configs, configs) + @property def pipeline(self): """ diff --git a/python/paddle/distributed/fleet/base/fleet_base.py b/python/paddle/distributed/fleet/base/fleet_base.py index d00faac838..0ee140941c 100644 --- a/python/paddle/distributed/fleet/base/fleet_base.py +++ b/python/paddle/distributed/fleet/base/fleet_base.py @@ -1086,6 +1086,9 @@ class Fleet(object): context["program_optimize_ops"] = optimize_ops context["program_params_grads"] = params_grads + if self.user_defined_strategy.zero: + graph_optimizer = None + if graph_optimizer: optimize_ops, params_grads = graph_optimizer.minimize( loss, diff --git a/python/paddle/distributed/fleet/meta_optimizers/__init__.py b/python/paddle/distributed/fleet/meta_optimizers/__init__.py index a3a2dee703..21969709a3 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/__init__.py +++ b/python/paddle/distributed/fleet/meta_optimizers/__init__.py @@ -23,3 +23,4 @@ from .lars_optimizer import LarsOptimizer from .parameter_server_graph_optimizer import ParameterServerGraphOptimizer from .dgc_optimizer import DGCOptimizer from .lamb_optimizer import LambOptimizer +from .zero_optimizer import ZeroOptimizer diff --git a/python/paddle/distributed/fleet/meta_optimizers/zero_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/zero_optimizer.py new file mode 100644 index 0000000000..a5c865a3a5 --- /dev/null +++ b/python/paddle/distributed/fleet/meta_optimizers/zero_optimizer.py @@ -0,0 +1,1245 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY, CollectiveHelper +from .common import is_update_op, is_loss_grad_op, is_backward_op, is_optimizer_op +from .meta_optimizer_base import MetaOptimizerBase +from paddle.fluid import unique_name, core +from paddle.fluid.contrib.mixed_precision.decorator import OptimizerWithMixedPrecision +import paddle.fluid as fluid + +import math +import re + +__all__ = ["ZeroOptimizer"] + + +def _pretty_op_desc_(op_desc, prefix): + out_s = "%s\tname:[%s]\n%s \tinputs:[%s]\n%s \toutputs:[%s]" % \ + (prefix + "_op", str(op_desc.type()), prefix + "_input", " ".join(op_desc.input_arg_names()), + prefix + "_output", " ".join(op_desc.output_arg_names())) + return out_s + + +class SubProgram(object): + def __init__(self, block): + self._block = block + self._allreduce_vars = [] + # sub program start idx + self._start_idx = -1 + # sub program end idx + self._end_idx = -1 + # param name to broadcast name + self._param2broadcast = {} + self._broadcast_vars = [] + # cast op pairs, fp16 name (str) -> fp32 name (str) + self._cast_ops = {} + # fill constant vars + self._fill_constant_vars = [] + # parameter mems + self._param_mem = 0.0 + + +class ProgramDeps(object): + def __init__(self, block, start_vars, end_vars): + self._block = block + # vars where to start to build the deps + self._start_vars = start_vars + # vars where to stop to build the deps + self._end_vars = end_vars + # var name -> op idxs which depends on this var + self._var_deps = {} + # sub block deps which is a subset of this topo + self._sub_block_deps = {} + # var name -> op idxs which generate var + self._var_to_generate_op = {} + self._should_removed_var = set() + self._father_block_deps = None + self._build_deps() + + def get_sub_block_deps(self, idx): + if idx in self._sub_block_deps: + return self._sub_block_deps[idx] + else: + return None + + def get_var_deps(self, var_name): + if var_name in self._var_deps: + return self._var_deps[var_name] + else: + return None + + def _build_deps(self, ): + for var_name in self._start_vars: + self._var_deps[var_name] = [-1] + self._var_to_generate_op[var_name] = [-1] + + for idx, op in enumerate(self._block.ops): + if op.type in [ + "c_allreduce_sum", "c_sync_comm_stream", + "c_calc_comm_stream" + ]: + continue + input_vars = op.desc.input_arg_names() + output_vars = op.desc.output_arg_names() + deps_reduce = False + for input_name in input_vars: + if input_name in self._var_deps: + deps_reduce = True + if deps_reduce: + for input_name in input_vars: + if input_name in self._var_deps: + self._var_deps[input_name].append(idx) + for output_name in output_vars: + self._var_deps[output_name] = [] + if output_name not in self._var_to_generate_op: + self._var_to_generate_op[output_name] = [idx] + else: + self._var_to_generate_op[output_name].append(idx) + if op.type == "conditional_block": + # subblock + assert (op.desc.has_attr("sub_block")) + subblock_idx = op.desc.attr("sub_block").id + subblock_deps = ProgramDeps( + self._block.program.block(subblock_idx), + op.desc.input_arg_names(), op.desc.output_arg_names()) + self._sub_block_deps[subblock_idx] = subblock_deps + subblock_deps._father_block_deps = self + + def crop_input_var_from_op(self, op_idx, var_name): + if var_name in self._var_deps: + # update var -> dep_var_op + if self._var_deps[var_name] != []: + assert (op_idx in self._var_deps[var_name]) + self._var_deps[var_name].remove(op_idx) + # update _should_removed_var + if var_name in self._start_vars: + self._should_removed_var.discard(var_name) + elif self._var_deps[var_name] == []: # no more deps of this var + self._should_removed_var.add(var_name) + elif self._var_to_generate_op[var_name][-1] >= self._var_deps[ + var_name][-1]: + # there are circle in the graph + self._should_removed_var.add(var_name) + else: # input_name should not be deleted + self._should_removed_var.discard(var_name) + + def crop_output_var_from_op(self, op_idx, var_name): + if var_name in self._var_to_generate_op: + assert (op_idx in self._var_to_generate_op[var_name]) + self._var_to_generate_op[var_name].remove(op_idx) + if self._block.has_var(var_name) and self._var_to_generate_op[ + var_name] == []: + print("main_block remove var {}".format(var_name)) + self._block._remove_var(var_name) + + def remove_op(self, op_idx): + # update deps + op = self._block.ops[op_idx] + print("main_block remove op {}".format(op.type)) + for input_name in op.desc.input_arg_names(): + self.crop_input_var_from_op(op_idx, input_name) + for output_name in op.desc.output_arg_names(): + self.crop_output_var_from_op(op_idx, output_name) + self._block._remove_op(op_idx) + + def should_remove_op(self, op_idx): + op = self._block.ops[op_idx] + for output_name in op.desc.output_arg_names(): + if output_name not in self._should_removed_var: + return False + return True + + +class ZeroOptimizer(MetaOptimizerBase): + def __init__(self, optimizer): + super(ZeroOptimizer, self).__init__(optimizer) + self.inner_opt = optimizer + self._main_program = None + self._startup_program = None + # we do not allow meta optimizer to be inner optimizer currently + self.meta_optimizers_white_list = [] + # params and fp16 params is for broadcast + self._params = set([]) + self._fp16_params = set([]) + # fp16 to fp32 + self._fp16_to_params = {} + self._broadcast_vars = set([]) + # _param(str) -> device_id(int) + self._param2device = {} + # varname(str) -> param(Variable) + # reduced grads to param name + self._reduced_grads_to_param = {} + # self._nrings(int) is for nccl communicate + self._nrings = 3 + # self._sub_progs + self._sub_progs = [] + self._fuse_broadcast_MB_bytes = 64 + self._dtype_to_size = { + core.VarDesc.VarType.FP16: 2, + core.VarDesc.VarType.FP32: 4, + core.VarDesc.VarType.FP64: 8, + core.VarDesc.VarType.INT16: 2, + core.VarDesc.VarType.INT32: 4, + core.VarDesc.VarType.INT64: 8, + core.VarDesc.VarType.BOOL: 1, + core.VarDesc.VarType.UINT8: 1, + } + + def _get_var_size(self, param): + """ + input: + - param: var + return: + var size in Bytes + """ + assert -1 not in param.shape + return reduce( + lambda x, y: x * y, + param.shape) * self._dtype_to_size[param.dtype] / 1024.0 / 1024.0 + + def _can_apply(self): + return self.user_defined_strategy.zero + + def _disable_strategy(self, dist_strategy): + dist_strategy.zero = False + + def _is_fp16_cast_op(self, block, op): + if op.type != "cast": + return False + if is_optimizer_op(op): + return False + assert (len(op.desc.input_arg_names()) == 1) + assert (len(op.desc.output_arg_names()) == 1) + input_name, output_name = op.desc.input_arg_names()[ + 0], op.desc.output_arg_names()[0] + if input_name not in self._params: + return False + input_var = block.var(input_name) + output_var = block.var(output_name) + if input_var.dtype != core.VarDesc.VarType.FP32 or \ + output_var.dtype != core.VarDesc.VarType.FP16: + return False + return True + + def _split_params(self, params): + param2device = {} + total_param_mem = 0.0 + param2mem = [] + for param in params: + mem = self._get_var_size(param) + total_param_mem += mem + param2mem.append((param.name, mem)) + # print(param.name, mem) + # print("total_param_mem: ", total_param_mem) + device_num = self.role_maker.worker_num() + # print("device_num: ", device_num) + device2params = {x: [] for x in range(device_num)} + device_idx = 0 + mem_accu = 0.0 + for param_name, mem in param2mem: + if mem_accu > total_param_mem * 1.0 * (device_idx + 1) / device_num: + device_idx += 1 + device2params[device_idx].append(param_name) + param2device[param_name] = device_idx + mem_accu += mem + # for debug + print(device2params) + return param2device + + def _is_opti_var(self, var_name): + if var_name in self._params: + return True + for suffix in [ + "_moment1_0", "_moment2_0", "_beta1_pow_acc_0", + "_beta2_pow_acc_0" + ]: + base_name = re.sub(suffix, '', var_name) + if base_name in self._params: + return True + return False + + def _var_device_id(self, var_name): + if not self._is_opti_var(var_name): + return -1 + if var_name in self._param2device: + return self._param2device[var_name] + for suffix in [ + "_moment1_0", "_moment2_0", "_beta1_pow_acc_0", + "_beta2_pow_acc_0" + ]: + base_name = re.sub(suffix, '', var_name) + if base_name in self._param2device: + return self._param2device[base_name] + return -1 + + def _insert_scale_loss_grad_ops(self, block, scale=1.0): + ''' + In order to keep the learning rate consistent in different numbers of + training workers, we scale the loss grad by the number of workers + ''' + for idx, op in reversed(list(enumerate(block.ops))): + if is_loss_grad_op(op): + loss_grad_var = block.vars[op.output_arg_names[0]] + block._insert_op( + idx + 1, + type='scale', + inputs={'X': loss_grad_var}, + outputs={'Out': loss_grad_var}, + attrs={'scale': scale, + OP_ROLE_KEY: OpRole.Backward}) + + def _split_program(self, block): + for op_idx, op in reversed(list(enumerate(block.ops))): + if int(op.attr('op_role')) != int(OpRole.Optimize): + last_backward_op_idx = op_idx + 1 + break + sub_prog = SubProgram(block) + sub_prog._end_idx = last_backward_op_idx + for op_idx in reversed(range(last_backward_op_idx)): + op = block.ops[op_idx] + assert (int(op.attr('op_role')) != int(OpRole.Optimize)) + if sub_prog._param_mem >= self._fuse_broadcast_MB_bytes: + sub_prog._start_idx = op_idx + 1 + self._sub_progs.insert(0, sub_prog) + sub_prog = SubProgram(block) + sub_prog._end_idx = op_idx + 1 + + # find broadcast vars + for input_name in op.desc.input_arg_names(): + if input_name not in self._broadcast_vars: + continue + root_device = self._param2device[input_name] + if input_name in sub_prog._param2broadcast: + # skip broadcast because it reuse the old broadcast var + broadcast_name = sub_prog._param2broadcast[input_name] + if input_name != broadcast_name: + op._rename_input(input_name, broadcast_name) + continue + if root_device == self.role_maker.worker_index(): + broadcast_var_name = input_name + else: + broadcast_var_name = unique_name.generate(input_name + + "@BroadCast") + sub_prog._fill_constant_vars.append(broadcast_var_name) + sub_prog._param2broadcast[input_name] = broadcast_var_name + sub_prog._broadcast_vars.append( + (broadcast_var_name, self._param2device[input_name])) + sub_prog._param_mem += self._get_var_size( + self._main_program.global_block().var(input_name)) + + # find reduce vars + if is_backward_op(op) and \ + OP_ROLE_VAR_KEY in op.attr_names: + op_role_var = op.all_attrs()[OP_ROLE_VAR_KEY] + if len(op_role_var) != 0: + assert len(op_role_var) % 2 == 0 + for i in range(0, len(op_role_var), 2): + param, reduced_grad = op_role_var[i], op_role_var[i + 1] + sub_prog._allreduce_vars.append(reduced_grad) + assert ( + reduced_grad not in self._reduced_grads_to_param) + self._reduced_grads_to_param[reduced_grad] = param + + # find cast op + if self._is_fp16_cast_op(block, op): + fp32_param = op.desc.input_arg_names()[0] + fp16_param = op.desc.output_arg_names()[0] + if self._param2device[ + fp32_param] == self.role_maker.worker_index(): + sub_prog._cast_ops[fp16_param] = fp32_param + + if sub_prog._param_mem > 0: + sub_prog._start_idx = 0 + self._sub_progs.insert(0, sub_prog) + return + + def _is_gradient_clip_sum_op(self, op): + return op.type == "sum" and op.desc.has_attr("op_namescope") \ + and op.desc.attr("op_namescope").startswith("/gradient_clip_@CLIP") + + def _is_amp_sum_op(self, op): + return op.type == "sum" and op.desc.has_attr("op_namescope") \ + and op.desc.attr("op_namescope").startswith("/mixed_precision") + + def _is_amp_subblock(self, op): + return op.type == "conditional_block" and op.desc.has_attr("op_namescope") \ + and op.desc.attr("op_namescope").startswith("/mixed_precision") + + def _prune_main_program(self, block): + """ + calculate deps from allredce op to optimize op, + remove ops and vars not needed in this worker + """ + # build prog deps + reduced_grads = [] + var_to_reduce_var = {} + for idx, op in enumerate(block.ops): + input_names = op.desc.input_arg_names() + output_names = op.desc.output_arg_names() + if op.type == "c_allreduce_sum": + assert (len(output_names) == 1) + output_name = output_names[0] + reduced_grads.append(output_name) + var_to_reduce_var[output_name] = output_name + else: + non_persistable_input = [ + x for x in input_names if not block.var(x).persistable + ] + if len(non_persistable_input) == 1 and len( + output_names) == 1 and non_persistable_input[ + 0] in var_to_reduce_var: + var_to_reduce_var[output_names[0]] = var_to_reduce_var[ + non_persistable_input[0]] + + params = [] + for var_name, _ in block.vars.items(): + if self._is_opti_var(var_name) and \ + self._var_device_id(var_name) != self.role_maker.worker_index(): + params.append(var_name) + program_deps = ProgramDeps(block, reduced_grads, params) + + # Init + for var_name in program_deps._end_vars: + program_deps._should_removed_var.add(var_name) + + # Prune + for idx, op in reversed(list(enumerate(block.ops))): + if op.type in [ + "c_allreduce_sum", "c_sync_comm_stream", + "c_calc_comm_stream", "c_gen_nccl_id", "c_comm_init" + ]: + pass + elif self._is_gradient_clip_sum_op(op) or self._is_amp_sum_op(op): + reversed_input_vars = [] + for input_name in op.desc.input_arg_names(): + assert (input_name in var_to_reduce_var) + reduce_var = var_to_reduce_var[input_name] + param_name = self._reduced_grads_to_param[reduce_var] + if self._param2device[ + param_name] != self.role_maker.worker_index(): + program_deps.crop_input_var_from_op(idx, input_name) + else: + reversed_input_vars.append(input_name) + op.desc.set_input("X", reversed_input_vars) + assert (len(op.desc.output_arg_names()) == 1) + sum_res = op.desc.output_arg_names()[0] + block._insert_op( + idx + 1, + type='c_sync_comm_stream', + inputs={'X': sum_res}, + outputs={'Out': sum_res}, + attrs={'ring_id': 0, + OP_ROLE_KEY: OpRole.Optimize}) + block._insert_op( + idx + 1, + type='c_allreduce_sum', + inputs={'X': sum_res}, + outputs={'Out': sum_res}, + attrs={'ring_id': 0, + OP_ROLE_KEY: OpRole.Optimize}) + block._insert_op( + idx + 1, + type='c_sync_calc_stream', + inputs={'X': sum_res}, + outputs={'Out': sum_res}, + attrs={OP_ROLE_KEY: OpRole.Optimize}) + elif op.type == "conditional_block": + assert (op.desc.has_attr("sub_block")) + subblock_idx = op.desc.attr("sub_block").id + subblock_deps = program_deps.get_sub_block_deps(subblock_idx) + # only prune amp subblock + if subblock_deps is None or not self._is_amp_subblock(op): + continue + # init + reversed_output_vars = [] + for output_name in op.desc.output("Out"): + if output_name in program_deps._should_removed_var: + subblock_deps._should_removed_var.add(output_name) + program_deps.crop_output_var_from_op(idx, output_name) + else: + reversed_output_vars.append(output_name) + # prune + for sub_op_idx, _ in reversed( + list(enumerate(subblock_deps._block.ops))): + if subblock_deps.should_remove_op(sub_op_idx): + subblock_deps.remove_op(sub_op_idx) + reversed_input_vars = [] + for input_name in op.desc.input('Input'): + if input_name not in subblock_deps._should_removed_var: + reversed_input_vars.append(input_name) + else: + program_deps.crop_input_var_from_op(idx, input_name) + op.desc.set_input('Input', reversed_input_vars) + op.desc.set_output('Out', reversed_output_vars) + else: + if program_deps.should_remove_op(idx): + program_deps.remove_op(idx) + + block._sync_with_cpp() + return + + def _remove_cast_op(self, block, sub_prog, offset): + inserted_op_num = 0 + for op_idx in reversed( + range(offset + sub_prog._start_idx, offset + + sub_prog._end_idx)): + op = block.ops[op_idx] + if self._is_fp16_cast_op(block, op): + block._remove_op(op_idx) + inserted_op_num -= 1 + block._sync_with_cpp() + return inserted_op_num + + def _insert_broadcast_ops(self, block, insert_idx, broadcast2root): + """ + _add_broadcast_ops + """ + ring_id = -1 + # TODO(mapingshuo): correct OP_ROLE_KEY + for broadcast_name, root_device in broadcast2root: + ring_id = (ring_id + 1) % self._nrings + block._insert_op( + insert_idx, + type='c_broadcast', + inputs={'X': broadcast_name}, + outputs={'Out': broadcast_name}, + attrs={ + 'ring_id': ring_id, + 'root': root_device, + OP_ROLE_KEY: OpRole.Forward + }) + return + + def _insert_allreduce_ops(self, block, insert_idx, allreduce_vars): + """ + _add_allreduce_ops + """ + ring_id = -1 + for var in allreduce_vars: + ring_id = (ring_id + 1) % self._nrings + block._insert_op( + insert_idx, + type='c_allreduce_sum', + inputs={'X': var}, + outputs={'Out': var}, + attrs={'ring_id': ring_id, + OP_ROLE_KEY: OpRole.Backward}) + return + + def _insert_cast_ops(self, block, insert_idx, cast_ops): + """ + _add_cast_ops + """ + for fp16_name, fp32_name in cast_ops.items(): + block._insert_op( + insert_idx, + type="cast", + inputs={"X": fp32_name}, + outputs={"Out": fp16_name}, + attrs={ + "in_dtype": core.VarDesc.VarType.FP32, + "out_dtype": core.VarDesc.VarType.FP16 + }) + return + + def _insert_fill_constant_ops(self, block, insert_idx, fill_constant_vars): + """ + _add_fill_constant_ops + """ + for broadcast_name in fill_constant_vars: + broadcast_var = block.var(broadcast_name) + block._insert_op( + insert_idx, + type="fill_constant", + outputs={"Out": broadcast_var.name}, + attrs={ + "shape": broadcast_var.shape, + "dtype": broadcast_var.dtype, + "value": 0.0, + }) + return + + def _insert_sync_comm_ops(self, block, insert_idx, comm_dep_vars): + """ + _insert_sync_comm_ops + """ + # TODO(mapingshuo) fix OP_ROLE_KEY + for i in range(self._nrings): + block._insert_op( + insert_idx, + type='c_sync_comm_stream', + inputs={'X': comm_dep_vars}, + outputs={'Out': comm_dep_vars}, + attrs={'ring_id': i, + OP_ROLE_KEY: OpRole.Forward}) + return + + def _insert_sync_calc_op(self, block, insert_idx, calc_dep_vars): + """ + _insert_sync_calc_op + """ + # TODO(mapingshuo) fix OP_ROLE_KEY + block._insert_op( + insert_idx, + type='c_sync_calc_stream', + inputs={'X': calc_dep_vars}, + outputs={'Out': calc_dep_vars}, + attrs={OP_ROLE_KEY: OpRole.Forward}) + return + + def _add_broadcast_allreduce_v2(self, block): + """ + _add_broadcast_allreduce_v2 + """ + ring_id = -1 + + if len(self._sub_progs) < 1: + return + + if self._sub_progs[-1]._allreduce_vars: + self._insert_sync_comm_ops(block, self._sub_progs[-1]._end_idx, + self._sub_progs[-1]._allreduce_vars) + self._insert_allreduce_ops(block, self._sub_progs[-1]._end_idx, + self._sub_progs[-1]._allreduce_vars) + + for idx, subprog in reversed(list(enumerate(self._sub_progs))): + print("subprog_{}: ({}-{})".format(idx, subprog._start_idx, + subprog._end_idx)) + + allreduce_vars = self._sub_progs[ + idx - 1]._allreduce_vars if idx > 0 else [] + broadcast_vars = self._sub_progs[idx + + 1]._broadcast_vars if idx < len( + self._sub_progs) - 1 else [] + fill_constant_vars = self._sub_progs[ + idx + 2]._fill_constant_vars if idx < len( + self._sub_progs) - 2 else [] + cast_ops = self._sub_progs[idx + 2]._cast_ops if idx < len( + self._sub_progs) - 2 else {} + + # for x in fill_constant_vars: + # print("fill_constant_vars: ", x) + + # step1: modify calculate ops + # for op_idx in reversed(range(subprog._start_idx, subprog._end_idx)): + # op = block.ops[op_idx] + # print(_pretty_op_desc_(op.desc, "subprog_op")) + + for op_idx in reversed(range(subprog._start_idx, subprog._end_idx)): + op = block.ops[op_idx] + for input_name in op.desc.input_arg_names(): + if input_name in subprog._param2broadcast and \ + input_name != subprog._param2broadcast[input_name]: + op._rename_input(input_name, + subprog._param2broadcast[input_name]) + + for param_name, broadcast_name in subprog._param2broadcast.items(): + if param_name != broadcast_name: + block.create_var( + name=broadcast_name, + shape=self._main_program.global_block().var( + param_name).shape, + dtype=self._main_program.global_block().var(param_name) + .dtype, + persistable=False) + + # step2: remove cast ops + block._sync_with_cpp() + subprog._end_idx += self._remove_cast_op(block, subprog, 0) + + # step3: add Sync ops + comm_dep_vars = allreduce_vars + [x[0] for x in broadcast_vars] + if len(comm_dep_vars) > 0: + self._insert_sync_comm_ops( + block, + subprog._end_idx, + comm_dep_vars, ) + calc_dep_vars = fill_constant_vars + [ + k for k, v in cast_ops.items() + ] + if len(calc_dep_vars) > 0: + self._insert_sync_calc_op(block, subprog._end_idx, + [calc_dep_vars[-1]]) + + # step4: insert `fill_constant` ops + self._insert_fill_constant_ops(block, subprog._end_idx, + fill_constant_vars) + + # step5: add `cast` ops + self._insert_cast_ops(block, subprog._end_idx, cast_ops) + + # step6: add broadcast ops + self._insert_broadcast_ops(block, subprog._start_idx, + broadcast_vars) + + # step7: add all_reduce ops + self._insert_allreduce_ops(block, subprog._start_idx, + allreduce_vars) + + block._sync_with_cpp() + + if self._sub_progs[0]._broadcast_vars: + self._insert_sync_comm_ops( + block, self._sub_progs[0]._start_idx, + [x[0] for x in self._sub_progs[0]._broadcast_vars]) + self._insert_broadcast_ops(block, self._sub_progs[0]._start_idx, + self._sub_progs[0]._broadcast_vars) + + fill_constant_vars = reduce( + lambda x, y: x._fill_constant_vars + y._fill_constant_vars, + self._sub_progs[:2]) + + # Join + cast_ops = {} + for x in self._sub_progs[:2]: + for k, v in x._cast_ops.items(): + cast_ops[k] = v + + calc_deps_vars = fill_constant_vars + [k for k, v in cast_ops.items()] + if fill_constant_vars or cast_ops: + self._insert_sync_calc_op(block, self._sub_progs[0]._start_idx, + [calc_deps_vars[-1]]) + + if fill_constant_vars: + self._insert_fill_constant_ops(block, self._sub_progs[0]._start_idx, + fill_constant_vars) + + if cast_ops: + self._insert_cast_ops(block, self._sub_progs[0]._start_idx, + cast_ops) + + return + + def _prune_startup_program(self, block): + for idx, op in reversed(list(enumerate(block.ops))): + for output_name in op.desc.output_arg_names(): + var_device_id = self._var_device_id(output_name) + if var_device_id == -1 or var_device_id == self.role_maker.worker_index( + ): + continue + print("%d: startup_block remove op %s" % + (self.role_maker.worker_index(), op.type)) + block._remove_op(idx) + break + for var_name, _ in block.vars.items(): + var_device_id = self._var_device_id(var_name) + if var_device_id == -1 or var_device_id == self.role_maker.worker_index( + ): + continue + print("%d: startup_block remove var %s" % + (self.role_maker.worker_index(), var_name)) + block._remove_var(var_name) + block._sync_with_cpp() + + def _find_broadcast_params(self, params, param2device): + broadcast_vars = set([]) + fp16_params = set([]) + fp16_to_fp32 = {} + main_block = self._main_program.global_block() + + param_usage = {x: 0 for x in params} + for op in main_block.ops: + if is_optimizer_op(op): + continue + for input_name in op.desc.input_arg_names(): + if input_name in params: + param_usage[input_name] += 1 + + for op in main_block.ops: + if not self._is_fp16_cast_op(main_block, op): + continue + input_name = op.input_arg_names[0] + output_name = op.output_arg_names[0] + broadcast_vars.add(output_name) + fp16_params.add(output_name) + fp16_to_fp32[output_name] = input_name + param_usage[input_name] -= 1 + param2device[output_name] = param2device[input_name] + + for param, usage in param_usage.items(): + if usage > 0: + broadcast_vars.add(param) + return fp16_params, broadcast_vars, fp16_to_fp32 + + def _set_up(self, params_grads): + # step 1: initialize nccl + # TODO(mapingshuo) fix get_trainer_endpoints + print("work idx: ", self.role_maker.worker_index()) + endpoints = self.role_maker.get_trainer_endpoints() + current_endpoint = endpoints[self.role_maker.worker_index()] + collective_helper = CollectiveHelper(self.role_maker, self._nrings) + for ring_id in range(self._nrings): + collective_helper._init_communicator( + self._startup_program, current_endpoint, endpoints, + self.role_maker.worker_index(), ring_id, '6174') + startup_block = self._startup_program.global_block() + startup_block._sync_with_cpp() + + # step 2: split params + self._params = set([x[0].name for x in params_grads]) + self._param2device = self._split_params([x[0] for x in params_grads]) + + # step 3: get broadcast vars + self._fp16_params, self._broadcast_vars, self._fp16_to_params = self._find_broadcast_params( + self._params, self._param2device) + + def minimize_impl(self, + loss, + startup_program=None, + parameter_list=None, + no_grad_set=None): + + if self.user_defined_strategy.zero_configs["allreduce"]: + return self.minimize_impl_allreduce(loss, startup_program, + parameter_list, no_grad_set) + + ckpts = list(self.user_defined_strategy.zero_configs["checkpoints"]) + optimizer = self.inner_opt + if len(ckpts) > 0: + print("add recompute") + print(ckpts) + optimizer = fluid.optimizer.RecomputeOptimizer(optimizer) + optimizer._set_checkpoints(ckpts) + + if self.user_defined_strategy.zero_configs["amp"]: + optimizer = fluid.contrib.mixed_precision.decorate( + optimizer, use_dynamic_loss_scaling=True) + + self._nrings = self.user_defined_strategy.zero_configs["nrings"] + self._fuse_broadcast_MB_bytes = self.user_defined_strategy.zero_configs[ + "fuse_broadcast_MB_bytes"] + + print("doing zero optimize...") + optimize_ops, params_grads = optimizer.minimize( + loss, startup_program, parameter_list, no_grad_set) + + if startup_program is None: + startup_program = default_startup_program() + main_block = loss.block + startup_block = startup_program.global_block() + self._main_program = main_block.program + self._startup_program = startup_program + + # step1: set_up + self._set_up(params_grads) + + # step2: split_program + self._split_program(main_block) + + # step3: add broadcast and reduce ops + print("insert broadcast and allreduce") + self._add_broadcast_allreduce_v2(main_block) + main_block._sync_with_cpp() + startup_block._sync_with_cpp() + + # step4: insert reduce_sum for grad + self._insert_scale_loss_grad_ops( + main_block, scale=1.0 / self.role_maker.worker_num()) + main_block._sync_with_cpp() + + # step5: remove unneeded ops and vars from block + print("main_block remove ops and vars") + self._prune_main_program(main_block) + print("startup_block remove ops and vars") + self._prune_startup_program(startup_block) + + # check op dependecy for broadcast + self._check_broadcast(main_block) + return optimize_ops, params_grads + + def _check_broadcast(self, block): + """ + if a var is broadcasted, it should have a sync_comm before + this var is used, if not, raise error. + if the broadcasted var has a fill_constant op, the fill_constant + op should stay forward before the broadcast op, and before a + sync_calc op. Otherwise, raise error. + """ + broadcast_vars = {} + for idx, op in enumerate(block.ops): + if op.type == "c_broadcast": + var_name = op.desc.input_arg_names()[0] + if "@BroadCast" in var_name: + if var_name in broadcast_vars: + print("error: var_name areadly exist: ", var_name) + print("the old pos is ", + broadcast_vars[var_name]["broadcast_pos"]) + print("the new pos is ", idx) + assert (var_name not in broadcast_vars) + broadcast_vars[var_name] = { + "fill_constant_pos": -1, + "broadcast_pos": idx, + } + + for idx, op in enumerate(block.ops): + if op.type == "fill_constant": + var_name = op.desc.output_arg_names()[0] + if var_name in broadcast_vars: + broadcast_vars[var_name]["fill_constant_pos"] = idx + continue + + last_sync_comm_op_idx = -1 + last_sync_calc_op_idx = -1 + for idx, op in enumerate(block.ops): + if op.type == "c_sync_comm_stream": + last_sync_comm_op_idx = idx + continue + if op.type == "c_sync_calc_stream": + last_sync_calc_op_idx = idx + continue + if op.type == "c_broadcast": + var_name = op.desc.input_arg_names()[0] + if "@BroadCast" in var_name: + if broadcast_vars[var_name]["fill_constant_pos"] != -1: + assert (last_sync_calc_op_idx != -1) + assert (broadcast_vars[var_name]["fill_constant_pos"] < + last_sync_calc_op_idx) + assert (last_sync_calc_op_idx < idx) + continue + for input_name in op.desc.input_arg_names(): + if input_name in broadcast_vars: + assert (broadcast_vars[input_name]["broadcast_pos"] != -1) + assert (broadcast_vars[input_name]["broadcast_pos"] < + last_sync_comm_op_idx) + assert (last_sync_comm_op_idx < idx) + print("check done") + return + + def _add_broadcast_allreduce(self, block, sub_prog, offset): + """ + add broadcast and allreduce + """ + # insert reduce ops + inserted_op_num = 0 + ring_id = -1 + + if len(sub_prog._allreduce_vars) > 0: + for i in range(self._nrings): + block._insert_op( + offset + sub_prog._end_idx, + type='c_sync_comm_stream', + inputs={'X': sub_prog._allreduce_vars}, + outputs={'Out': sub_prog._allreduce_vars}, + attrs={'ring_id': i, + OP_ROLE_KEY: OpRole.Forward}) + inserted_op_num += self._nrings + + for var in sub_prog._allreduce_vars: + ring_id = (ring_id + 1) % self._nrings + block._insert_op( + offset + sub_prog._end_idx, + type='c_allreduce_sum', + inputs={'X': var}, + outputs={'Out': var}, + attrs={'ring_id': ring_id, + OP_ROLE_KEY: OpRole.Backward}) + inserted_op_num += 1 + + block._insert_op( + offset + sub_prog._end_idx, + type='c_sync_calc_stream', + inputs={'X': sub_prog._allreduce_vars[-1]}, + outputs={'Out': sub_prog._allreduce_vars[-1]}, + attrs={OP_ROLE_KEY: OpRole.Forward}) + inserted_op_num += 1 + + block._sync_with_cpp() + # insert broadcast ops + for op_idx in reversed( + range(offset + sub_prog._start_idx, offset + + sub_prog._end_idx)): + op = block.ops[op_idx] + for input_name in op.desc.input_arg_names(): + if input_name in sub_prog._param2broadcast and \ + input_name != sub_prog._param2broadcast[input_name]: + op._rename_input(input_name, + sub_prog._param2broadcast[input_name]) + + for param_name, broadcast_name in sub_prog._param2broadcast.items(): + if param_name != broadcast_name: + block.create_var( + name=broadcast_name, + shape=self._main_program.global_block().var( + param_name).shape, + dtype=self._main_program.global_block().var(param_name) + .dtype, + persistable=False) + + comm_dep_vars = [v for k, v in sub_prog._param2broadcast.items()] + for i in range(self._nrings): + block._insert_op( + offset + sub_prog._start_idx, + type='c_sync_comm_stream', + inputs={'X': comm_dep_vars}, + outputs={'Out': comm_dep_vars}, + attrs={'ring_id': i, + OP_ROLE_KEY: OpRole.Forward}) + inserted_op_num += self._nrings + + for param_name, broadcast_name in sub_prog._param2broadcast.items(): + broadcast_var = block.var(broadcast_name) + root_device = self._param2device[param_name] + ring_id = (ring_id + 1) % self._nrings + block._insert_op( + offset + sub_prog._start_idx, + type='c_broadcast', + inputs={'X': broadcast_var.name}, + outputs={'Out': broadcast_var.name}, + attrs={ + 'ring_id': ring_id, + 'root': root_device, + OP_ROLE_KEY: OpRole.Forward + }) + inserted_op_num += 1 + + comm_dep_vars = [ + v for k, v in sub_prog._param2broadcast.items() if k != v + ] + if comm_dep_vars != []: + block._insert_op( + offset + sub_prog._start_idx, + type='c_sync_calc_stream', + inputs={'X': comm_dep_vars[-1]}, + outputs={'Out': comm_dep_vars[-1]}, + attrs={OP_ROLE_KEY: OpRole.Forward}) + inserted_op_num += 1 + + for param_name, broadcast_name in sub_prog._param2broadcast.items(): + if param_name != broadcast_name: + broadcast_var = block.var(broadcast_name) + block._insert_op( + offset + sub_prog._start_idx, + type="fill_constant", + outputs={"Out": broadcast_var.name}, + attrs={ + "shape": broadcast_var.shape, + "dtype": broadcast_var.dtype, + "value": 0.0, + }) + inserted_op_num += 1 + + for fp16_name, fp32_name in sub_prog._cast_ops.items(): + block._insert_op( + offset + sub_prog._start_idx, + type="cast", + inputs={"X": fp32_name}, + outputs={"Out": fp16_name}, + attrs={ + "in_dtype": core.VarDesc.VarType.FP32, + "out_dtype": core.VarDesc.VarType.FP16 + }) + inserted_op_num += 1 + + block._sync_with_cpp() + return inserted_op_num + + def _broadcast_params(self, block): + ring_id = -1 + for param in block.iter_parameters(): + if param.is_distributed: + continue + ring_id = (ring_id + 1) % self._nrings + block.append_op( + type='c_broadcast', + inputs={'X': param}, + outputs={'Out': param}, + attrs={ + 'ring_id': ring_id, + 'root': 0, + OP_ROLE_KEY: OpRole.Forward + }) + for ring_id in range(self._nrings): + block.append_op( + type='c_sync_comm_stream', + inputs={'X': param}, + outputs={'Out': param}, + attrs={'ring_id': ring_id, + OP_ROLE_KEY: OpRole.Forward}) + + # def _insert_broadcast_ops(self, block, fuse_broadcast=False): + # def _insert_cache(cache, + # prepend_comm_sync=False, + # append_comm_sync=False): + # insert_idx = cache["insert_idx"] + # dummy_var_name = cache["dummy_var_name"] + # assert (len(cache["broadcast_ops"]) > 0) + + # if prepend_comm_sync: + # insert_idx += self._insert_comm_sync(block, insert_idx, + # [dummy_var_name]) + + # if len(cache["fill_constant_ops"]) > 0: + # insert_idx += self._insert_fill_constant( + # block, insert_idx, cache["fill_constant_ops"], + # [dummy_var_name]) + + # insert_idx += self._insert_broadcast_inner(block, insert_idx, + # cache["broadcast_ops"]) + + # if append_comm_sync: + # insert_idx += self._insert_comm_sync(block, insert_idx, + # [dummy_var_name]) + + # return insert_idx - cache["insert_idx"] + + # print("insert_idx: ", [x["insert_idx"] for x in self._sub_progs]) + # move_ahead = 1 + # for idx, cache in reversed(list(enumerate(self._sub_progs))): + # if idx < move_ahead: + # cache["insert_idx"] = 0 + # else: + # cache["insert_idx"] = self._sub_progs[idx - move_ahead][ + # "insert_idx"] + # print("insert_idx: ", [x["insert_idx"] for x in self._sub_progs]) + + # inserted_op_num = 0 + # for idx, cache in enumerate(self._sub_progs): + # prepend_comm_sync = True + # append_comm_sync = True + # cache["insert_idx"] += inserted_op_num + # inserted_op_num += _insert_cache( + # cache, + # prepend_comm_sync=prepend_comm_sync, + # append_comm_sync=append_comm_sync) + # return + + def _insert_allreduce_ops_tmp(self, block): + ring_id = -1 + grad = None + for idx, op in reversed(list(enumerate(block.ops))): + if is_backward_op(op) and \ + OP_ROLE_VAR_KEY in op.attr_names: + op_role_var = op.all_attrs()[OP_ROLE_VAR_KEY] + + if len(op_role_var) == 0: + continue + assert len(op_role_var) % 2 == 0 + + offset = idx + for i in range(0, len(op_role_var), 2): + # param = block.vars[op_role_var[i]] + grad = block.vars[op_role_var[i + 1]] + # TODO(mapingshuo): what is is_distributed + # if param.is_distributed: + # continue + + if offset == idx: + offset += 1 + block._insert_op( + offset, + type='c_sync_calc_stream', + inputs={'X': grad}, + outputs={'Out': grad}, + attrs={OP_ROLE_KEY: OpRole.Backward}) + offset += 1 + # As we search ops reversedly, we should insert c_allreduce_sum + # op in the same way to keep the ring_id alternate + print("add allreduce op for {}".format(grad.name)) + ring_id = (ring_id + 1) % self._nrings + block._insert_op( + offset, + type='c_allreduce_sum', + inputs={'X': grad}, + outputs={'Out': grad}, + attrs={ + 'ring_id': ring_id, + OP_ROLE_KEY: OpRole.Backward + }) + + if grad is None: + return + + for idx, op in enumerate(block.ops): + if is_optimizer_op(op): + for ring_id in range(self._nrings): + block._insert_op( + idx + ring_id, + type='c_sync_comm_stream', + inputs={'X': grad}, + outputs={'Out': grad}, + attrs={ + 'ring_id': ring_id, + OP_ROLE_KEY: OpRole.Backward + }) + break + + def minimize_impl_allreduce(self, + loss, + startup_program=None, + parameter_list=None, + no_grad_set=None): + + self._nrings = self.user_defined_strategy.zero_configs["nrings"] + + optimizer = self.inner_opt + if self.user_defined_strategy.zero_configs["amp"]: + optimizer = fluid.contrib.mixed_precision.decorate( + optimizer, use_dynamic_loss_scaling=True) + + optimize_ops, params_grads = optimizer.minimize( + loss, startup_program, parameter_list, no_grad_set) + + if startup_program is None: + startup_program = default_startup_program() + + print("work idx: ", self.role_maker.worker_index()) + endpoints = self.role_maker.get_trainer_endpoints() + current_endpoint = endpoints[self.role_maker.worker_index()] + + collective_helper = CollectiveHelper(self.role_maker, self._nrings) + for ring_id in range(self._nrings): + collective_helper._init_communicator( + startup_program, current_endpoint, endpoints, + self.role_maker.worker_index(), ring_id, '6174') + main_block = loss.block + startup_block = startup_program.global_block() + self._broadcast_params(startup_block) + + self._insert_scale_loss_grad_ops( + main_block, scale=1.0 / self.role_maker.worker_num()) + self._insert_allreduce_ops_tmp(main_block) + print("insert allreduce done") + return optimize_ops, params_grads + + # def _insert_comm_sync(self, block, insert_idx, var_names): + # for r in range(self._nrings): + # block._insert_op( + # insert_idx, + # type='c_sync_comm_stream', + # inputs={'X': var_names}, + # outputs={'Out': var_names}, + # attrs={'ring_id': r, + # OP_ROLE_KEY: OpRole.Backward}) + # insert_idx += 1 + # return self._nrings + + # def _insert_broadcast_inner(self, block, insert_idx, broadcast_attrs): + # for attr in broadcast_attrs: + # block._insert_op(insert_idx, **attr) + # insert_idx += 1 + # return len(broadcast_attrs) + + # def _insert_fill_constant(self, block, insert_idx, fill_constant_attrs, + # var_names): + # for attr in fill_constant_attrs: + # block._insert_op(insert_idx, **attr) + # insert_idx += 1 + # block._insert_op( + # insert_idx, + # type='c_sync_calc_stream', + # inputs={'X': var_names}, + # outputs={'Out': var_names}, + # attrs={OP_ROLE_KEY: OpRole.Backward}) + # return len(fill_constant_attrs) + 1 diff --git a/python/paddle/fluid/clip.py b/python/paddle/fluid/clip.py index 7b301ac19d..0ed3fb46a2 100644 --- a/python/paddle/fluid/clip.py +++ b/python/paddle/fluid/clip.py @@ -847,7 +847,7 @@ def append_gradient_clip_ops(param_grads): if g is None: continue with p.block.program._optimized_guard( - [p, g]), framework.name_scope('graident_clip_@CLIP'): + [p, g]), framework.name_scope('gradient_clip_@CLIP'): param, new_grad = clip_attr._create_operators(param=p, grad=g) param_new_grad_name_dict[param.name] = new_grad.name res.append([param, new_grad]) diff --git a/python/paddle/fluid/contrib/mixed_precision/decorator.py b/python/paddle/fluid/contrib/mixed_precision/decorator.py index c9112ac849..b8baabaf74 100644 --- a/python/paddle/fluid/contrib/mixed_precision/decorator.py +++ b/python/paddle/fluid/contrib/mixed_precision/decorator.py @@ -16,6 +16,7 @@ from ... import default_main_program from ... import default_startup_program from ... import layers from ... import unique_name +from ... import framework from . import fp16_utils from .fp16_utils import rewrite_program from .fp16_utils import update_role_var_grad @@ -132,7 +133,8 @@ class OptimizerWithMixedPrecision(object): gradient respectively, and the scaled loss. """ rewrite_program(self._train_program, self._amp_lists) - self._scaled_loss = loss * self._loss_scaling + with framework.name_scope('mixed_precision'): + self._scaled_loss = loss * self._loss_scaling self._params_grads = self._optimizer.backward( self._scaled_loss, startup_program, parameter_list, no_grad_set, callbacks) @@ -156,22 +158,24 @@ class OptimizerWithMixedPrecision(object): grads = [g for _, g in params_grads] with self._train_program._optimized_guard(grads): - grads, found_inf = check_finite_and_unscale( - grads, self._loss_scaling, name="find_infinite_scale") + with framework.name_scope('mixed_precision'): + grads, found_inf = check_finite_and_unscale( + grads, self._loss_scaling, name="find_infinite_scale") if self._use_dynamic_loss_scaling: with self._train_program._optimized_guard(grads): - grads = update_loss_scaling( - grads, - found_inf, - self._loss_scaling, - self._num_good_steps, - self._num_bad_steps, - self._incr_every_n_steps, - self._decr_every_n_nan_or_inf, - self._incr_ratio, - self._decr_ratio, - name="update_loss_scaling") + with framework.name_scope('mixed_precision'): + grads = update_loss_scaling( + grads, + found_inf, + self._loss_scaling, + self._num_good_steps, + self._num_bad_steps, + self._incr_every_n_steps, + self._decr_every_n_nan_or_inf, + self._incr_ratio, + self._decr_ratio, + name="update_loss_scaling") params_unscaled_grads = [] for pg, new_g in zip(params_grads, grads): diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 797b32f5d4..5b29cf873d 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -2063,10 +2063,16 @@ class Operator(object): % (out_proto.name, len(out_args))) out_arg_names = [] for arg in out_args: - out_arg_names.append(cpt.to_text(arg.name)) + if isinstance(arg, six.string_types): + out_arg_names.append(arg) + else: + out_arg_names.append(cpt.to_text(arg.name)) # TODO(minqiyang): could we remove variable's op in static mode? if not in_dygraph_mode(): - arg.op = self + if isinstance(arg, six.string_types): + block.var(arg).op = self + else: + arg.op = self self.desc.set_output(out_proto.name, out_arg_names) if op_attrs is not None: @@ -2801,7 +2807,6 @@ class Block(object): return var def _remove_var(self, name): - self._sync_with_cpp() self.desc._remove_var(cpt.to_bytes(name)) del self.vars[name] @@ -2893,7 +2898,6 @@ class Block(object): Returns: Operator: the insert Operator. """ - self._sync_with_cpp() op_desc = self.desc._insert_op(index) op = Operator(block=self, desc=op_desc, *args, **kwargs) self.ops.insert(index, op) @@ -2909,7 +2913,6 @@ class Block(object): Returns: None """ - self._sync_with_cpp() self.desc._remove_op(index, index + 1) del self.ops[index] -- GitLab