diff --git a/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py b/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py
index 75f90cabfff434225e1b111746a28251d57b5b92..0c7e623d469f3a6d2382029e962fcedce9bd1550 100644
--- a/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py
+++ b/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py
@@ -29,12 +29,16 @@ class AutoMixedPrecisionLists(object):
         custom_black_list (set): Users' custom black list.
     """
 
-    def __init__(self, custom_white_list=None, custom_black_list=None):
+    def __init__(self,
+                 custom_white_list=None,
+                 custom_black_list=None,
+                 custom_black_varnames=None):
         self._custom_white_list = custom_white_list
         self._custom_black_list = custom_black_list
         self.white_list = copy.copy(white_list)
         self.black_list = copy.copy(black_list)
         self.gray_list = copy.copy(gray_list)
+        self.black_varnames = copy.copy(custom_black_varnames)
         self._update_list()
 
     def _update_list(self):
diff --git a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py
index 05dfe27303505903533d6404de0e6ffe51a661ad..1a4eae3e6100ef7084f3c5715da8706f9d156dcb 100644
--- a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py
+++ b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py
@@ -85,6 +85,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
         core.VarDesc.VarType.LOD_TENSOR, core.VarDesc.VarType.SELECTED_ROWS,
         core.VarDesc.VarType.LOD_TENSOR_ARRAY
     ]
+
     for in_name in op.input_names:
         if src_dtype == core.VarDesc.VarType.FP32 and op.type == 'batch_norm':
             if in_name != 'X':
@@ -94,22 +95,25 @@
             if in_var.type not in valid_types:
                 continue
             if in_var.dtype == src_dtype:
-                out_var = block.create_var(
-                    name=in_var.name + \
-                        '.cast_' + _dtype_to_str(dest_dtype),
-                    dtype=dest_dtype,
-                    persistable=False,
-                    stop_gradient=False)
-                block._insert_op(
-                    idx,
-                    type="cast",
-                    inputs={"X": in_var},
-                    outputs={"Out": out_var},
-                    attrs={
-                        "in_dtype": in_var.dtype,
-                        "out_dtype": out_var.dtype
-                    })
-                num_cast_ops += 1
+                cast_name = in_var.name + '.cast_' + _dtype_to_str(dest_dtype)
+                out_var = block.vars.get(cast_name)
+                if out_var is None or out_var.dtype != dest_dtype:
+                    out_var = block.create_var(
+                        name=cast_name,
+                        dtype=dest_dtype,
+                        persistable=False,
+                        stop_gradient=False)
+
+                    block._insert_op(
+                        idx,
+                        type="cast",
+                        inputs={"X": in_var},
+                        outputs={"Out": out_var},
+                        attrs={
+                            "in_dtype": in_var.dtype,
+                            "out_dtype": out_var.dtype
+                        })
+                    num_cast_ops += 1
                 _rename_arg(op, in_var.name, out_var.name)
             else:
                 if op.has_attr('in_dtype'):
@@ -155,6 +159,18 @@
     return None
 
 
+def _is_in_black_varnames(op, amp_lists):
+    for in_name in op.input_arg_names:
+        if in_name in amp_lists.black_varnames:
+            return True
+
+    for out_name in op.output_arg_names:
+        if out_name in amp_lists.black_varnames:
+            return True
+
+    return False
+
+
 def rewrite_program(main_prog, amp_lists):
     """
     Traverse all ops in current block and insert cast op according to
@@ -180,6 +196,11 @@
     white_op_set = set()
     black_op_set = set()
     for op in ops:
+        if amp_lists.black_varnames is not None and _is_in_black_varnames(
+                op, amp_lists):
+            black_op_set.add(op)
+            continue
+
         if op.type in amp_lists.black_list:
             black_op_set.add(op)
         elif op.type in amp_lists.white_list:
diff --git a/python/paddle/fluid/contrib/tests/test_image_classification_fp16.py b/python/paddle/fluid/contrib/tests/test_image_classification_fp16.py
index 2e3e364df3071cdcd4d5a54d0bf3619004bddd4a..918544e1c990781df891f3cfe0af1b5bc8c0c92c 100644
--- a/python/paddle/fluid/contrib/tests/test_image_classification_fp16.py
+++ b/python/paddle/fluid/contrib/tests/test_image_classification_fp16.py
@@ -134,8 +134,11 @@ def train(net_type, use_cuda, save_dirname, is_local):
 
     optimizer = fluid.optimizer.Lamb(learning_rate=0.001)
 
+    amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionLists(
+        custom_black_varnames={"loss", "conv2d_0.w_0"})
     mp_optimizer = fluid.contrib.mixed_precision.decorate(
         optimizer=optimizer,
+        amp_lists=amp_lists,
         init_loss_scaling=8.0,
         use_dynamic_loss_scaling=True)
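
Usage note: a minimal sketch of how the new custom_black_varnames option is wired together, mirroring the test change above. The optimizer settings and the variable names ("loss", "conv2d_0.w_0") are taken from that test; a real network would substitute its own loss and parameter names, and the rest of the program construction is assumed rather than part of this patch.

    import paddle.fluid as fluid

    # Any op that reads or writes one of these variables is put into the black
    # op set by rewrite_program, i.e. it keeps running in FP32 under AMP.
    amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionLists(
        custom_black_varnames={"loss", "conv2d_0.w_0"})

    optimizer = fluid.optimizer.Lamb(learning_rate=0.001)
    mp_optimizer = fluid.contrib.mixed_precision.decorate(
        optimizer=optimizer,
        amp_lists=amp_lists,
        init_loss_scaling=8.0,
        use_dynamic_loss_scaling=True)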