diff --git a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py
index f9c3a613c4053a79cb467d752b20f6f4ed3ea4ec..67e83a2ec4617c0c59bdb1f92c983e3b5ae471a3 100644
--- a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py
+++ b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py
@@ -123,7 +123,8 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
                 outputs={"Out": out_var},
                 attrs={
                     "in_dtype": in_var.dtype,
-                    "out_dtype": out_var.dtype
+                    "out_dtype": out_var.dtype,
+                    "op_device": op.attr("op_device")
                 })
             num_cast_ops += 1
         _rename_arg(op, in_var.name, out_var.name)
@@ -171,8 +172,11 @@ def _insert_cast_post_op(block, op, idx, src_dtype, dest_dtype, target_name,
         type="cast",
         inputs={"X": target_var},
         outputs={"Out": cast_var},
-        attrs={"in_dtype": target_var.dtype,
-               "out_dtype": cast_var.dtype})
+        attrs={
+            "in_dtype": target_var.dtype,
+            "out_dtype": cast_var.dtype,
+            "op_device": op.attr("op_device")
+        })
     num_cast_ops += 1
     op_var_rename_map[block.idx][target_var.name] = cast_var.name
diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py
index 3c560689e1210fcb312a2311da72c720afb2fe0a..9612c87d870752086ec900632a2604243145b008 100644
--- a/python/paddle/fluid/optimizer.py
+++ b/python/paddle/fluid/optimizer.py
@@ -19,6 +19,7 @@ import six
 import os
 import logging
 from collections import defaultdict
+import time
 
 import paddle
 from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table
@@ -3759,15 +3760,21 @@ class PipelineOptimizer(object):
     def __init__(self, optimizer, num_microbatches=1, start_cpu_core_id=0):
         if framework.in_dygraph_mode():
             raise Exception("In dygraph, don't support PipelineOptimizer.")
-        if not isinstance(optimizer, Optimizer) and not isinstance(
-                optimizer, paddle.optimizer.Optimizer) and not isinstance(
-                    optimizer, paddle.fluid.contrib.mixed_precision.decorator.
-                    OptimizerWithMixedPrecision):
+        supported_opt_types = (Optimizer, paddle.fluid.contrib.mixed_precision.
+                               decorator.OptimizerWithMixedPrecision)
+        if not isinstance(optimizer, supported_opt_types):
             raise ValueError("The 'optimizer' parameter for "
-                             "PipelineOptimizer must be an instance of "
-                             "Optimizer, but the given type is {}.".format(
-                                 type(optimizer)))
+                             "PipelineOptimizer must be an instance of one of "
+                             "{}, but the given type is {}.".format(
+                                 supported_opt_types, type(optimizer)))
+        self._optimizer = optimizer
+
+        # Get the original optimizer defined by users, such as SGD
+        self._origin_optimizer = self._optimizer
+        while hasattr(self._origin_optimizer, "inner_opt"):
+            self._origin_optimizer = self._origin_optimizer.inner_opt
+
         assert num_microbatches >= 1, (
             "num_microbatches must be a positive value.")
         self._num_microbatches = num_microbatches
@@ -3783,50 +3790,141 @@ class PipelineOptimizer(object):
         self._param_device_map = None
 
     def _create_vars(self, block, ori_block):
-        # Create vars for block, copied from main_program's global block
+        # Create vars for block, copied from ori_block
         used_var_set = set()
         for op_idx in range(block.desc.op_size()):
-            op_desc = block.desc.op(op_idx)
-            vars = op_desc.input_arg_names() + op_desc.output_arg_names()
+            # Decide whether to insert an allreduce_sum or allreduce_max op.
+            # For the amp and global gradient clip strategies, we need the
+            # global information, so an allreduce op is needed.
+            should_insert = False
+
+            op = block.ops[op_idx]
+            # For ops that process vars on all devices, remove their input
+            # vars that are not in this block
+            reserved_x = []
+
+            if op.type == 'reduce_any' and self._is_optimize_op(op):
+                should_insert = True
+            if op.type == 'concat' and self._is_optimize_op(op):
+                for input_name in op.desc.input("X"):
+                    if block._find_var_recursive(input_name):
+                        reserved_x.append(input_name)
+                op.desc.set_input('X', reserved_x)
+            if op.type == 'update_loss_scaling':
+                for input_name in op.desc.input("X"):
+                    if block._find_var_recursive(input_name):
+                        reserved_x.append(input_name)
+                op.desc.set_input('X', reserved_x)
+                op.desc.set_output('Out', reserved_x)
+            if op.type == 'sum' and self._is_gradient_clip_op(op):
+                for input_name in op.desc.input("X"):
+                    if block._find_var_recursive(input_name):
+                        reserved_x.append(input_name)
+                op.desc.set_input('X', reserved_x)
+                should_insert = True
+
+            vars = op.desc.input_arg_names() + op.desc.output_arg_names()
             for var in vars:
                 # a var whose name contains "blocking_queue"
                 # only exists in startup program
-                if var in used_var_set or "_blocking_queue" in var:
-                    continue
+                if var in used_var_set or "_blocking_queue" in var: continue
                 used_var_set.add(var)
                 if block._find_var_recursive(str(var)): continue
                 source_var = ori_block._var_recursive(str(var))
                 if source_var.type == core.VarDesc.VarType.READER:
-                    block.create_var(
+                    dest_var = block.create_var(
                         name=var,
                         type=core.VarDesc.VarType.READER,
                         persistable=source_var.persistable)
                 else:
-                    block._clone_variable(source_var, False)
+                    dest_var = block._clone_variable(source_var, False)
+                dest_var.stop_gradient = source_var.stop_gradient
+
+            if not should_insert: continue
+
+            out_name = op.desc.output_arg_names()[0]
+            out_var = block.var(out_name)
+            offset = 0
+            if op.type == "reduce_any":
+                # cast the bool var to int32 to use the allreduce op
+                temp_var_name = unique_name.generate(out_name + "_cast_int32")
+                temp_var = block.create_var(
+                    name=temp_var_name, shape=[1], dtype="int32")
+                block._insert_op(
+                    op_idx + 1 + offset,
+                    type='cast',
+                    inputs={'X': out_var},
+                    outputs={'Out': temp_var},
+                    attrs={
+                        'in_dtype': out_var.dtype,
+                        'out_dtype': temp_var.dtype,
+                        self._op_role_key:
+                        core.op_proto_and_checker_maker.OpRole.Optimize
+                    })
+                offset += 1
+            block._insert_op(
+                op_idx + 1 + offset,
+                type='c_allreduce_max'
+                if op.type == "reduce_any" else 'c_allreduce_sum',
+                inputs={'X': temp_var if op.type == "reduce_any" else out_var},
+                outputs={
+                    'Out': temp_var if op.type == "reduce_any" else out_var
+                },
+                attrs={
+                    'ring_id': self.ring_id,
+                    self._op_role_key:
+                    core.op_proto_and_checker_maker.OpRole.Optimize,
+                    'use_calc_stream': True
+                })
+            offset += 1
+            if op.type == "reduce_any":
+                # cast the int32 allreduce result back to bool
+                block._insert_op(
+                    op_idx + 1 + offset,
+                    type='cast',
+                    inputs={'X': temp_var},
+                    outputs={'Out': out_var},
+                    attrs={
+                        'in_dtype': temp_var.dtype,
+                        'out_dtype': out_var.dtype,
+                        self._op_role_key:
+                        core.op_proto_and_checker_maker.OpRole.Optimize
+                    })
 
     def _is_loss_grad_op(self, op):
-        if self._op_role_key not in op.attr_names:
-            return False
-        op_role = int(op.all_attrs()[self._op_role_key])
+        assert self._op_role_key in op.attr_names
+        op_role = int(op.attr(self._op_role_key))
         return op_role & int(self._op_role.Backward) and op_role & int(
             self._op_role.Loss)
 
-    def _is_backward_op(self, op):
-        return self._op_role_key in op.attr_names and int(op.all_attrs()[
-            self._op_role_key]) & int(self._op_role.Backward)
-
-    def _is_optimize_op(self, op):
-        return self._op_role_key in op.attr_names and int(op.all_attrs()[
-            self._op_role_key]) & int(self._op_role.Optimize)
-
-    def _is_update_op(self, op):
-        return 'Param' in op.input_names and 'Grad' in op.input_names and (
-            "LearningRate" in op.input_names)
-
     def _split_program(self, main_program, devices):
         """
         Split a program into sections according to devices that ops run on.
-        The ops of the role LRSched are copied to all sections.
+        Ops whose op_device attr is "gpu:all" are copied to all sections.
 
         Args:
             main_program (Program): the main program
@@ -3842,27 +3940,20 @@ class PipelineOptimizer(object):
         block = main_program.block(0)
         for op in block.ops:
             device = op.attr(self._op_device_key)
-            op_role = op.attr(self._op_role_key)
-            if int(op_role) & int(self._op_role.LRSched):
-                # Copy ops of the role LRSched to all sections.
-                for device in device_program_map.keys():
-                    program = device_program_map[device]
-                    op_desc = op.desc
-                    ap_op = program["program"].block(0).desc.append_op()
-                    ap_op.copy_from(op_desc)
-                    # ap_op._set_attr(self._op_device_key, "")
-            elif op.type == "create_py_reader" or op.type == "read" or op.type == "create_double_buffer_reader":
-                # Copy read related ops to all section to make them exit after each epoch.
+            # Copy ops whose op_device attr is "gpu:all" to all sections.
+            if device == "gpu:all":
                 for device in device_program_map.keys():
                     program = device_program_map[device]
                     op_desc = op.desc
                     ap_op = program["program"].block(0).desc.append_op()
                     ap_op.copy_from(op_desc)
+                    ap_op._set_attr(self._op_device_key, "")
             else:
                 program = device_program_map[device]
                 op_desc = op.desc
                 ap_op = program["program"].block(0).desc.append_op()
                 ap_op.copy_from(op_desc)
+                ap_op._set_attr(self._op_device_key, "")
 
         for key in devices:
             program = device_program_map[key]
@@ -3921,6 +4012,11 @@ class PipelineOptimizer(object):
             var_name as output.
             var_name (string): Variable name.
""" + # To skip the cast op added by amp which has no op_device set + if '.cast_fp32' in var_name: + var_name = var_name.replace('.cast_fp32', '') + if '.cast_fp16' in var_name: + var_name = var_name.replace('.cast_fp16', '') post_op = [] before = True for op in ops: @@ -3982,9 +4078,10 @@ class PipelineOptimizer(object): dtype=ref_var.dtype, type=ref_var.type, lod_level=ref_var.lod_level, - persistable=False, - is_data=False, + persistable=ref_var.persistable, + is_data=ref_var.is_data, need_check_feed=ref_var.desc.need_check_feed()) + new_var.stop_gradient = ref_var.stop_gradient return new_var def _get_data_var_info(self, block): @@ -4046,6 +4143,7 @@ class PipelineOptimizer(object): self._op_role_key: self._op_role.Forward, 'use_calc_stream': True, 'peer': dev_index, + 'ring_id': self.ring_id, }) # Get the device that that data on assert device in devices @@ -4070,6 +4168,7 @@ class PipelineOptimizer(object): self._op_role_key: self._op_role.Forward, 'peer': first_dev_index, 'use_calc_stream': True, + 'ring_id': self.ring_id, }) def _strip_grad_suffix(self, name): @@ -4085,79 +4184,178 @@ class PipelineOptimizer(object): """ return name + core.grad_var_suffix() - def _add_opdevice_attr_for_regularization_clip(self, block): + def _is_forward_op(self, op): """ - Add op_device attribute for regulization and clip ops. + Is the op_role attribute of a op is Forward. """ - for op in block.ops: - # role for regularization and clip ops is optimize - if int(op.attr(self._op_role_key)) != int(self._op_role.Optimize): - continue - if op.has_attr(self._op_device_key) and ( - op.attr(self._op_device_key) != ""): - continue - assert self._op_role_var_key in op.attr_names - op_role_var = op.all_attrs()[self._op_role_var_key] - assert len(op_role_var) == 2 - param_name = op_role_var[0] - device = self._param_device_map[param_name] - op._set_attr(self._op_device_key, device) + assert self._op_role_key in op.attr_names + return int(op.attr(self._op_role_key)) == int(self._op_role.Forward) - def _add_default_opdevice_attr(self, block): + def _is_backward_op(self, op): """ - 1. Add default op_device attribute for lr-related ops. - The default value is the one that of the first place. - 2. Add default op_device attribute for sum ops added during - backward. For these ops, we set the op_device attribute - as the one of its post op, i.e, which op has the output of the - sum op as an input. + Is the op_role attribute of a op is Backward. """ - first_devcie = "" + assert self._op_role_key in op.attr_names + return int(op.attr(self._op_role_key)) == int(self._op_role.Backward) - # Get the device spec of the first place. - # device_spec: 'cpu' for cpu device and 'gpu:id' for gpu device, - # e.g. 'gpu:0', 'gpu:1', etc. - for op in block.ops: - if op.has_attr(self._op_device_key) and ( - op.attr(self._op_device_key) != ""): - first_device = op.attr(self._op_device_key) - break - assert first_device - first_device_type = first_device.split(":")[0] - assert first_device_type == "gpu" + def _is_loss_op(self, op): + """ + Is the op_role attribute of a op is Loss. 
+ """ + assert self._op_role_key in op.attr_names + return int(op.attr(self._op_role_key)) == int(self._op_role.Loss) - # set op_device attr for lr-related ops - lrsched_role = int(self._op_role.LRSched) - for op in block.ops: - if not op.has_attr(self._op_device_key) or ( - op.attr(self._op_device_key) == ""): - if op.type == "sum": - # For sum ops that compute the sum of @RENAMED@ vars - for name in op.desc.input_arg_names(): - assert '@RENAME@' in name - assert len(op.desc.output_arg_names()) == 1 - out_name = op.desc.output_arg_names()[0] - post_op = self._find_post_op(block.ops, op, out_name) - device = post_op.attr(self._op_device_key) - assert device - op._set_attr(self._op_device_key, device) - continue + def _is_optimize_op(self, op): + """ + Is the op_role attribute of a op is Optimize. + """ + assert self._op_role_key in op.attr_names + return int(op.attr(self._op_role_key)) == int(self._op_role.Optimize) + + def _is_lrsched_op(self, op): + """ + Is the op_role attribute of a op is LRSched. + """ + assert self._op_role_key in op.attr_names + return int(op.attr(self._op_role_key)) == int(self._op_role.LRSched) + + def _is_update_op(self, op): + """ + Is the op updates the parameter using gradient. + """ + return 'Param' in op.input_names and 'Grad' in op.input_names and ( + "LearningRate" in op.input_names) + + def _get_op_device_attr(self, op): + """ + Get the op_device attribute of a op. + """ + device = op.attr(self._op_device_key) \ + if op.has_attr(self._op_device_key) else None + if device: + assert device[0:3] == 'gpu', "Now, only gpu devices are " \ + "supported in pipeline parallemism." + return device + + def _add_op_device_attr_for_op(self, op, idx, block): + """ + Add op_device attrribute for ops that have not that attribute set. - assert op.attr(self._op_role_key) == lrsched_role, ( - "Op whose op_device attr has not been set for pipeline" - " must be of the role LRSched.") - op._set_attr(self._op_device_key, first_device) + We use "gpu:all" to represent the op should be put on all + sub-programs, such as lr-related ops. Note that: "gpu:all" + is only used by pipeline as an indicator. + """ + lrsched_role = int(self._op_role.LRSched) + if op.attr(self._op_role_key) == lrsched_role: + # For LRSched ops, we should put them on all sub-programs to + # make sure each sub-program update the lr correctly + op._set_attr(self._op_device_key, "gpu:all") + elif op.type == "sum" and self._is_backward_op(op): + # For sum ops that compute the sum of @RENAMED@ vars + for name in op.desc.input_arg_names(): + assert '@RENAME@' in name, \ + "The op must be sum used to accumulate renamed vars." + assert len(op.desc.output_arg_names()) == 1 + out_name = op.desc.output_arg_names()[0] + post_op = self._find_post_op(block.ops, op, out_name) + assert post_op.has_attr( + 'op_device'), "{} has no op_device attr for var {}".format( + post_op.type, out_name) + device = post_op.attr(self._op_device_key) + assert device, "The post op must have op_device set." 
+            op._set_attr(self._op_device_key, device)
+        elif (op.type == "cast" or
+              op.type == "scale") and self._is_backward_op(op):
+            prev_op = self._find_real_prev_op(block.ops, op,
+                                              op.desc.input("X")[0])
+            op._set_attr('op_device', prev_op.attr('op_device'))
+        elif self._is_loss_op(op):
+            # For the loss * loss_scaling op added by AMP, use the device
+            # of the first following op that has op_device set
+            offset = 1
+            while (not block.ops[idx + offset].has_attr(self._op_device_key) or
+                   not block.ops[idx + offset].attr(self._op_device_key)):
+                offset += 1
+            device = block.ops[idx + offset].attr(self._op_device_key)
+            assert device, "Please put your program within device_guard scope."
+            for i in range(offset):
+                block.ops[idx + i]._set_attr(self._op_device_key, device)
+        elif self._is_optimize_op(op) and op.type == "check_finite_and_unscale":
+            op_role_var = op.attr(self._op_role_var_key)
+            param_name = op_role_var[0]
+            device = self._param_device_map[param_name]
+            op._set_attr(self._op_device_key, device)
+        elif self._is_optimize_op(op) and op.type == "cast":
+            # For fp16-->fp32 cast added by AMP
+            grad_name = op.output('Out')
+            assert len(grad_name) == 1
+            # remove the @GRAD suffix to get the parameter name
+            param_name = grad_name[0][0:grad_name[0].find(
+                core.grad_var_suffix())]
+            device = self._param_device_map[param_name]
+            op._set_attr(self._op_device_key, device)
+        elif self._is_gradient_clip_op(op) or self._is_regularization_op(op):
+            # For gradient clip and regularization ops, we set their op_device
+            # attribute to the device where their corresponding parameters are.
+            assert self._op_role_var_key in op.attr_names, "gradient_clip " \
+                "and regularization ops must have op_role_var attribute."
+            op_role_var = op.attr(self._op_role_var_key)
+            assert len(op_role_var) == 2, "op_role_var for gradient_clip " \
+                "and regularization ops must have two elements."
+            param_name = op_role_var[0]
+            device = self._param_device_map[param_name]
+            # For sum op added by global gradient clip, it must be
+            # put on all devices
+            if (op.type == 'sum' or op.type == 'sqrt' or
+                    op.type == 'fill_constant' or
+                    op.type == 'elementwise_max' or
+                    op.type == 'elementwise_div'):
+                device = "gpu:all"
+            op._set_attr(self._op_device_key, device)
+        else:
+            other_known_ops = [
+                'update_loss_scaling', 'reduce_any', 'concat', 'sum'
+            ]
+            assert op.type in other_known_ops, "For other ops without " \
+                "op_device set, they must be one of {}, but it " \
+                "is {}".format(other_known_ops, op.type)
+            assert self._is_optimize_op(op)
+            op._set_attr(self._op_device_key, "gpu:all")
+
+    def _add_op_device_attr(self, block):
+        """
+        Add the op_device attribute for ops in block that do not
+        have it set.
+        """
+        for idx, op in enumerate(list(block.ops)):
+            if (op.type == "create_py_reader" or op.type == "read" or
+                    op.type == "create_double_buffer_reader"):
+                # Copy read-related ops to all sections so that they exit
+                # after each epoch; "gpu:all" marks such ops.
+                op._set_attr(self._op_device_key, "gpu:all")
+                continue
+            # op_device attribute has been set
+            if self._get_op_device_attr(op): continue
+            self._add_op_device_attr_for_op(op, idx, block)
 
     def _check_validation(self, block):
         """
-        Check whether ops in a block are all validate (i.e., the
-        op_device attribute has been set).
-        Then, return all device specifications in order.
+        Check whether ops in a block have the op_device attribute set.
+        Then, return all devices in order.
         """
-        device_specs = []
+        device_list = []
         for op in block.ops:
-            type = op.type
-            if not op._has_kernel(type):
+            if not op._has_kernel(op.type):
                 assert op.type == "conditional_block" and (
                     op.attr(self._op_role_key) == int(
                         self._op_role.LRSched)), (
                         "Now, the only supported op without kernel is "
@@ -4165,15 +4363,16 @@ class PipelineOptimizer(object):
             assert op.has_attr(self._op_device_key), (
                 "op ({}) has no {} attribute.".format(op.type,
                                                       self._op_device_key))
-            dev_spec = op.attr(self._op_device_key)
-            assert dev_spec, ("op_device attribute for op "
-                              "{} has not been set.".format(op.type))
-            dev_type = dev_spec.split(':')[0]
+            device = op.attr(self._op_device_key)
+            assert device, ("op_device attribute for op "
+                            "{} has not been set.".format(op.type))
+            if device == "gpu:all": continue
+            dev_type = device.split(':')[0]
             assert dev_type == "gpu", ("Now only gpu devices are supported "
                                        "for pipeline parallelism.")
-            if not dev_spec in device_specs:
-                device_specs.append(dev_spec)
-        return device_specs
+            if not device in device_list:
+                device_list.append(device)
+        return device_list
 
     def _insert_sendrecv_ops_for_boundaries(self, block):
         """
         Insert a pair of send and recv ops for every two
@@ -4182,49 +4381,44 @@ class PipelineOptimizer(object):
         """
         extra_index = 0
 
-        # A map from var to device spec where op takes it as input,
+        # A map from var to device where op takes it as input,
         # avoiding multiple send and recv ops.
-        var_devspec = dict()
+        var_dev_map = dict()
 
         for index, op in enumerate(list(block.ops)):
-            # skips lr-related ops and vars, as we will process them later.
-            if int(op.attr(self._op_role_key)) & int(self._op_role.LRSched):
-                continue
-            # skips update ops and vars, as we will process them later.
-            if self._is_update_op(op): continue
-
-            cur_device_spec = op.attr(self._op_device_key)
+            cur_device = op.attr(self._op_device_key)
+            if cur_device == "gpu:all": continue
             for var_name in op.input_arg_names:
                 # i.e., lod_tensor_blocking_queue created by DataLoader,
                 # which only exists in startup program.
-                if not var_name in block.vars: continue
                 var = block.var(var_name)
                 # skip data, because we will process it later
                 if var.is_data: continue
                 prev_op = self._find_real_prev_op(block.ops, op, var_name)
-                if prev_op is None:
-                    continue
-                prev_device_spec = prev_op.attr(self._op_device_key)
+                prev_device = prev_op.attr(self._op_device_key) \
+                    if prev_op else None
+                if not prev_device or prev_device == 'gpu:all': continue
 
-                if prev_device_spec != cur_device_spec:
-                    if var_name not in var_devspec:
-                        var_devspec[var_name] = []
-                    if cur_device_spec in var_devspec[var_name]: continue
-                    var_devspec[var_name].append(cur_device_spec)
+                if prev_device != cur_device:
+                    if var_name not in var_dev_map: var_dev_map[var_name] = []
+                    if cur_device in var_dev_map[var_name]: continue
+                    var_dev_map[var_name].append(cur_device)
 
                     op_role = op.all_attrs()[self._op_role_key]
                     var = block.vars[var_name]
-                    prev_device_index = int(prev_device_spec.split(':')[1])
-                    cur_device_index = int(cur_device_spec.split(':')[1])
+                    prev_device_index = int(prev_device.split(':')[1])
+                    cur_device_index = int(cur_device.split(':')[1])
                     block._insert_op(
                         index=index + extra_index,
                         type='send_v2',
                         inputs={'X': var},
                         attrs={
-                            self._op_device_key: prev_device_spec,
+                            self._op_device_key: prev_device,
                             self._op_role_key: op_role,
                             'use_calc_stream': True,
                             'peer': cur_device_index,
+                            'ring_id': self.ring_id,
                         })
                     extra_index += 1
                     block._insert_op(
@@ -4234,23 +4428,28 @@ class PipelineOptimizer(object):
                         attrs={
                             'out_shape': var.shape,
                             'dtype': var.dtype,
-                            self._op_device_key: cur_device_spec,
+                            self._op_device_key: cur_device,
                             self._op_role_key: op_role,
                             'use_calc_stream': True,
                             'peer': prev_device_index,
+                            'ring_id': self.ring_id,
                         })
                     extra_index += 1
 
-    def _clear_gradients(self, main_block, dev_spec):
+    def _clear_gradients(self, main_block, param_names):
         """
-        Clear gradients at the begining of each run of a minibatch.
+        Clear gradients at the beginning of each run of a minibatch.
         """
-        for param_name in self._param_device_map:
-            device = self._param_device_map[param_name]
-            if device != dev_spec: continue
+        for param_name in param_names:
             grad_name = self._append_grad_suffix(param_name)
-            if not main_block.has_var(grad_name): continue
-            grad_var = main_block.vars[grad_name]
+            assert main_block.has_var(grad_name)
+            grad_var = main_block.var(grad_name)
+            grad_var.persistable = True
             main_block._insert_op(
                 index=0,
                 type='fill_constant',
@@ -4260,21 +4459,20 @@ class PipelineOptimizer(object):
                     'shape': grad_var.shape,
                     'dtype': grad_var.dtype,
                     'value': float(0),
-                    self._op_device_key: device,
                     # a trick to run this op once per mini-batch
                     self._op_role_key: self._op_role.Optimize.LRSched,
                 })
 
-    def _accumulate_gradients(self, block):
+    def _insert_loss_scale(self, block):
         """
-        Accumulate the gradients generated in microbatch to the one in mini-batch.
-        We also scale the loss corresponding to number of micro-batches as well.
+        Scale the loss corresponding to the number of micro-batches.
""" + if self._num_microbatches == 1: return for index, op in reversed(tuple(enumerate(list(block.ops)))): offset = index - device = op.attr(self._op_device_key) + #device = op.attr(self._op_device_key) - # Backward pass if self._is_loss_grad_op(op): loss_grad_var = block.vars[op.output_arg_names[0]] scale_factor = self._num_microbatches @@ -4285,36 +4483,151 @@ class PipelineOptimizer(object): outputs={'Out': loss_grad_var}, attrs={ 'scale': 1.0 / scale_factor, - self._op_device_key: device, + #self._op_device_key: device, self._op_role_key: self._op_role.Backward }) break + + def _rename_gradient_var_name(self, block): + for index, op in enumerate(block.ops): if self._is_backward_op(op) and ( self._op_role_var_key in op.attr_names): - op_role_var = op.all_attrs()[self._op_role_var_key] + op_role_var = op.attr(self._op_role_var_key) if len(op_role_var) == 0: continue - assert len(op_role_var) % 2 == 0 - offset = index for i in range(0, len(op_role_var), 2): grad_name = op_role_var[i + 1] grad_var = block.vars[grad_name] new_grad_var_name = unique_name.generate(grad_name) new_var = self._create_var(block, grad_var, new_grad_var_name) + new_var.persistable = False self._rename_arg(op, grad_name, new_grad_var_name) + + def _accumulate_gradients(self, block): + """ + Accumulate the gradients generated in microbatch to the one in mini-batch. + """ + first_optimize_op_index = None + for index, op in reversed(tuple(enumerate(list(block.ops)))): + # device = op.attr(self._op_device_key) + if not self._is_optimize_op(op) and not first_optimize_op_index: + first_optimize_op_index = index + 1 + if block.ops[ + first_optimize_op_index].type == 'c_sync_comm_stream': + block.ops[first_optimize_op_index]._set_attr( + self._op_role_key, self._op_role.Backward) + first_optimize_op_index += 1 + + if self._is_backward_op(op) and ( + self._op_role_var_key in op.attr_names): + op_role_var = op.attr(self._op_role_var_key) + + if len(op_role_var) == 0: + continue + assert len(op_role_var) % 2 == 0 + for i in range(0, len(op_role_var), 2): + offset = 0 + param_name = op_role_var[i] + if not block.has_var(param_name): continue + # clear gradient + param_grad_name = self._append_grad_suffix(param_name) + # if not main_block.has_var(grad_name): continue + if not block.has_var(param_grad_name): + self._create_var(block, block.vars[param_name], + param_grad_name) + assert block.has_var(param_grad_name) + param_grad_var = block.var(param_grad_name) + param_grad_var.persistable = True block._insert_op( - index=offset + 1, - type='sum', - inputs={'X': [grad_var, new_var]}, - outputs={'Out': grad_var}, + index=first_optimize_op_index + offset, + type='fill_constant', + inputs={}, + outputs={'Out': [param_grad_var]}, attrs={ - self._op_device_key: device, - self._op_role_key: self._op_role.Backward, - self._op_role_var_key: op_role_var + 'shape': param_grad_var.shape, + 'dtype': param_grad_var.dtype, + 'value': float(0), + # self._op_device_key: device, + # a trick to run this op once per mini-batch + self._op_role_key: self._op_role.Optimize.LRSched, }) offset += 1 + grad_name = op_role_var[i + 1] # with _0 suffix + grad_var = block.vars[grad_name] # without _0 suffix + real_grad_name = grad_name[0:grad_name.find( + '@GRAD')] + '@GRAD' + real_grad_var = block.vars[ + real_grad_name] # without _0 suffix + # new_grad_var_name = unique_name.generate(grad_name) + # new_var = self._create_var(block, grad_var, + # new_grad_var_name) + # new_var.persistable = False + # self._rename_arg(op, grad_name, new_grad_var_name) + if not 
+                        block._insert_op(
+                            index=first_optimize_op_index + offset,
+                            type='sum',
+                            inputs={'X': [grad_var, real_grad_var]},
+                            outputs={'Out': real_grad_var},
+                            attrs={
+                                self._op_role_key: self._op_role.Backward,
+                            })
+                        offset += 1
+                    else:
+                        # for fp16 grads, accumulate in fp32: cast the fp32
+                        # accumulator to fp16, sum, then cast back to fp32
+                        grad_name = op_role_var[i + 1]  # with the _0 suffix
+                        grad_var = block.vars[grad_name]
+                        fp32_grad_var_name = param_name + core.grad_var_suffix(
+                        )
+                        fp32_grad_var = block.vars[fp32_grad_var_name]
+                        fp32_grad_var.persistable = True
+                        cast_grad_var_name = unique_name.generate(
+                            fp32_grad_var_name)
+                        cast_var = self._create_var(block, grad_var,
+                                                    cast_grad_var_name)
+                        cast_var.persistable = False
+                        real_grad_name = grad_name[0:grad_name.find(
+                            '@GRAD')] + '@GRAD'
+                        real_grad_var = block.vars[
+                            real_grad_name]  # without the _0 suffix
+                        block._insert_op(
+                            index=first_optimize_op_index + offset,
+                            type='cast',
+                            inputs={'X': fp32_grad_var},
+                            outputs={'Out': cast_var},
+                            attrs={
+                                'in_dtype': fp32_grad_var.dtype,
+                                'out_dtype': cast_var.dtype,
+                                self._op_role_key: self._op_role.Backward,
+                            })
+                        offset += 1
+                        block._insert_op(
+                            index=first_optimize_op_index + offset,
+                            type='sum',
+                            inputs={'X': [grad_var, cast_var]},
+                            outputs={'Out': real_grad_var},
+                            attrs={
+                                self._op_role_key: self._op_role.Backward,
+                            })
+                        offset += 1
+                        block._insert_op(
+                            index=first_optimize_op_index + offset,
+                            type='cast',
+                            inputs={'X': real_grad_var},
+                            outputs={'Out': fp32_grad_var},
+                            attrs={
+                                'in_dtype': real_grad_var.dtype,
+                                'out_dtype': fp32_grad_var.dtype,
+                                self._op_role_key: self._op_role.Backward,
+                            })
 
     def _add_sub_blocks(self, main_block, program_list):
         main_program = main_block.program
@@ -4372,7 +4685,7 @@ class PipelineOptimizer(object):
         block = prog.block(0)
         for op in block.ops:
             if op.type == "recv_v2" or op.type == "create_py_reader" or \
-                    op.type == "read":
+                    op.type == "read" or op.type == "update_loss_scaling":
                 continue
             # We have processed lr related vars
             if op.attr(self._op_role_key) == int(
@@ -4412,6 +4725,7 @@ class PipelineOptimizer(object):
                     # microbatch
                     self._op_role_key: self._op_role.LRSched,
                     'peer': read_dev_index,
+                    'ring_id': self.ring_id,
                 })
             read_block._insert_op(
                 index=0,
@@ -4425,9 +4739,18 @@ class PipelineOptimizer(object):
                     # A trick to make the role LRSched to avoid copy every
                     # microbatch
                     self._op_role_key: self._op_role.LRSched,
-                    'peer': write_dev_index
+                    'peer': write_dev_index,
+                    'ring_id': self.ring_id,
                 })
 
+    def _is_gradient_clip_op(self, op):
+        return op.desc.has_attr("op_namescope") \
+            and op.desc.attr("op_namescope").startswith("/gradient_clip")
+
+    def _is_regularization_op(self, op):
+        return op.desc.has_attr("op_namescope") \
+            and op.desc.attr("op_namescope").startswith("/regularization")
+
     def minimize(self,
                  loss,
                  startup_program=None,
@@ -4438,17 +4761,29 @@ class PipelineOptimizer(object):
             startup_program = default_startup_program()
         optimize_ops, params_grads = self._optimizer.minimize(
             loss, startup_program, parameter_list, no_grad_set)
-        self._param_device_map = self._optimizer._param_device_map
-
-        # Step1: add default op_device attribute for regulization and clip ops
-        self._add_opdevice_attr_for_regularization_clip(main_block)
-
-        # Step2: add default op_device attribute for ops whose op_device
-        # attribute have not been set yet. Then check all ops have the
-        # op_device attribute.
-        self._add_default_opdevice_attr(main_block)
-
-        device_specs = self._check_validation(main_block)
+        self._param_device_map = self._origin_optimizer._param_device_map
+        assert main_block.program._pipeline_opt \
+            and 'local_rank' in main_block.program._pipeline_opt, \
+            'Please use pipeline with fleet.'
+        local_rank = main_block.program._pipeline_opt['local_rank']
+
+        self.use_sharding = False
+        if 'use_sharding' in main_block.program._pipeline_opt:
+            self.use_sharding = main_block.program._pipeline_opt[
+                'use_sharding']
+
+        self.ring_id = 0
+        if 'ring_id' in main_block.program._pipeline_opt:
+            self.ring_id = main_block.program._pipeline_opt['ring_id']
+
+        # Step1: add default op_device attribute for ops.
+        self._add_op_device_attr(main_block)
+        device_list = self._check_validation(main_block)
 
         def device_cmp(device1, device2):
             dev1_id = int(device1.split(':')[1])
@@ -4460,66 +4795,62 @@ class PipelineOptimizer(object):
             else:
                 return 0
 
-        sorted_device_spec = sorted(device_specs, key=cmp_to_key(device_cmp))
-        assert sorted_device_spec == device_specs, (
-            "With pipeline "
-            "parallelism, you must use gpu devices one after another "
-            "in the order of their ids.")
+        sorted_device_list = sorted(device_list, key=cmp_to_key(device_cmp))
+        assert sorted_device_list == device_list, (
+            "With pipeline parallelism, you must use gpu devices one after "
+            "another in the order of their ids.")
 
-        # Step3: add send and recv ops between section boundaries
+        # Step2: add send and recv ops between section boundaries
         self._insert_sendrecv_ops_for_boundaries(main_block)
 
-        # Step4: split program into sections and add pairs of
+        # Step3: split program into sections and add pairs of
         # send and recv ops for data var.
         main_program = main_block.program
-        program_list = self._split_program(main_program, device_specs)
+        program_list = self._split_program(main_program, device_list)
         for p in program_list:
-            self._create_vars(p["program"].block(0),
-                              main_program.global_block())
+            self._create_vars(p["program"].block(0), main_block)
         self._insert_sendrecv_for_data_var(main_block, program_list,
-                                           startup_program, device_specs)
+                                           startup_program, device_list)
 
-        # Step5: Special Case: process persistable vars that exist in
+        # Step4: Special Case: process persistable vars that exist in
         # multiple sections
         self._process_persistable_vars_in_multi_sections(
             main_program, startup_program, program_list)
 
-        # Step6: Add sub blocks for section programs
+        # Step5: Add sub blocks for section programs
         self._add_sub_blocks(main_block, program_list)
 
-        assert (main_program._pipeline_opt and
-                isinstance(main_program._pipeline_opt, dict) and
-                'local_rank' in main_program._pipeline_opt), \
-            "You must use pipeline with fleet"
-        local_rank = main_program._pipeline_opt['local_rank'] % len(
-            device_specs)
+        local_rank = main_program._pipeline_opt['local_rank'] % len(
+            device_list)
 
         place_list = []
-        for dev_spec in device_specs:
-            dev_index = dev_spec.split(":")[1]
-            place_list.append(core.CUDAPlace(local_rank))
+        for dev in device_list:
+            dev_index = int(dev.split(":")[1])
+            place_list.append(core.CUDAPlace(dev_index))
 
-        # Step7: Split startup program
+        # Step6: Split startup program
         new_startup_program = self._split_startup_program(startup_program,
                                                           local_rank)
 
-        # Step8: clear gradients before each mini-batch and
-        # accumulate gradients during backward
-        self._clear_gradients(
-            program_list[local_rank]['program'].global_block(),
-            dev_spec=device_specs[local_rank])
-        self._accumulate_gradients(program_list[local_rank]['program']
-                                   .global_block())
-
         startup_program._pipeline_opt = {
             "startup_program": new_startup_program,
         }
+
+        real_block = program_list[local_rank]['program'].global_block()
+        self._insert_loss_scale(real_block)
+        if not self.use_sharding:
+            # Step7: clear gradients before each mini-batch and
+            # accumulate gradients during backward
+            self._rename_gradient_var_name(real_block)
+            self._accumulate_gradients(real_block)
+
         place_id = int(os.getenv("FLAGS_selected_gpus", "0"))
         main_program._pipeline_opt = {
             "trainer": "PipelineTrainer",
             "device_worker": "Section",
-            "inner_parallelism": len(device_specs),
+            "inner_parallelism": len(device_list),
             "section_program": program_list[local_rank],
             "place": place_list[local_rank],
            "place_id": place_id,
@@ -5487,7 +5818,7 @@ class GradientMergeOptimizer(object):
     def _is_the_backward_op(self, op):
         op_maker = core.op_proto_and_checker_maker
-        backward = core.op_proto_and_checker_maker.OpRole.Backward
+        backward = op_maker.OpRole.Backward
         if op_maker.kOpRoleVarAttrName() in op.attr_names and \
                 int(op.all_attrs()[op_maker.kOpRoleAttrName()]) == int(backward):
             return True
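Usage note (not part of the patch): a minimal sketch of how PipelineOptimizer
is driven after this change. Pipeline stages are assigned with
fluid.device_guard, which sets the op_device attribute that the code above
propagates, validates, and uses to split the program; the layer sizes, names,
and learning rate below are illustrative only, and with this patch the
program must be launched through fleet so that program._pipeline_opt carries
'local_rank' (and optionally 'ring_id' and 'use_sharding').

    import paddle.fluid as fluid
    import paddle.fluid.layers as layers

    # stage 0 on gpu:0: data and embeddings
    with fluid.device_guard("gpu:0"):
        x = layers.data(name='x', shape=[1], dtype='int64', lod_level=0)
        y = layers.data(name='y', shape=[1], dtype='int64', lod_level=0)
        emb_x = layers.embedding(input=x, size=[10, 2], is_sparse=False,
                                 param_attr=fluid.ParamAttr(name="embx"))
        emb_y = layers.embedding(input=y, size=[10, 2], is_sparse=False,
                                 param_attr=fluid.ParamAttr(name="emby"))

    # stage 1 on gpu:1: the rest of the network and the loss
    with fluid.device_guard("gpu:1"):
        concat = layers.concat([emb_x, emb_y], axis=1)
        fc = layers.fc(input=concat, size=1, num_flatten_dims=1,
                       bias_attr=False)
        loss = layers.reduce_mean(fc)

    opt = fluid.optimizer.SGD(learning_rate=0.5)
    opt = fluid.optimizer.PipelineOptimizer(opt, num_microbatches=4)
    # minimize() asserts that _pipeline_opt is populated, i.e. the job
    # was launched through fleet ("Please use pipeline with fleet.")
    opt.minimize(loss)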