From 3255fe69bba8c89f158564388577ed90a45d3d69 Mon Sep 17 00:00:00 2001 From: gongweibao Date: Wed, 30 Oct 2019 20:47:22 +0800 Subject: [PATCH] Add custom black variable name set in amp interface. (#20875) * add custom black varname test=develop * fix dtype test=develop * fix num test=develop * fix ut test=develop * fix coverage test=develop * fix blackvar names test=develop --- .../contrib/mixed_precision/fp16_lists.py | 6 ++- .../contrib/mixed_precision/fp16_utils.py | 53 +++++++++++++------ .../tests/test_image_classification_fp16.py | 3 ++ 3 files changed, 45 insertions(+), 17 deletions(-) diff --git a/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py b/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py index 75f90cabfff..0c7e623d469 100644 --- a/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py +++ b/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py @@ -29,12 +29,16 @@ class AutoMixedPrecisionLists(object): custom_black_list (set): Users' custom black list. """ - def __init__(self, custom_white_list=None, custom_black_list=None): + def __init__(self, + custom_white_list=None, + custom_black_list=None, + custom_black_varnames=None): self._custom_white_list = custom_white_list self._custom_black_list = custom_black_list self.white_list = copy.copy(white_list) self.black_list = copy.copy(black_list) self.gray_list = copy.copy(gray_list) + self.black_varnames = copy.copy(custom_black_varnames) self._update_list() def _update_list(self): diff --git a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py index 05dfe273035..1a4eae3e610 100644 --- a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py +++ b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py @@ -85,6 +85,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype): core.VarDesc.VarType.LOD_TENSOR, core.VarDesc.VarType.SELECTED_ROWS, core.VarDesc.VarType.LOD_TENSOR_ARRAY ] + for in_name in op.input_names: if src_dtype == core.VarDesc.VarType.FP32 and op.type == 'batch_norm': if in_name != 'X': @@ -94,22 +95,25 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype): if in_var.type not in valid_types: continue if in_var.dtype == src_dtype: - out_var = block.create_var( - name=in_var.name + \ - '.cast_' + _dtype_to_str(dest_dtype), - dtype=dest_dtype, - persistable=False, - stop_gradient=False) - block._insert_op( - idx, - type="cast", - inputs={"X": in_var}, - outputs={"Out": out_var}, - attrs={ - "in_dtype": in_var.dtype, - "out_dtype": out_var.dtype - }) - num_cast_ops += 1 + cast_name = in_var.name + '.cast_' + _dtype_to_str(dest_dtype) + out_var = block.vars.get(cast_name) + if out_var is None or out_var.dtype != dest_dtype: + out_var = block.create_var( + name=cast_name, + dtype=dest_dtype, + persistable=False, + stop_gradient=False) + + block._insert_op( + idx, + type="cast", + inputs={"X": in_var}, + outputs={"Out": out_var}, + attrs={ + "in_dtype": in_var.dtype, + "out_dtype": out_var.dtype + }) + num_cast_ops += 1 _rename_arg(op, in_var.name, out_var.name) else: if op.has_attr('in_dtype'): @@ -155,6 +159,18 @@ def find_true_prev_op(ops, cur_op, var_name): return None +def _is_in_black_varnames(op, amp_lists): + for in_name in op.input_arg_names: + if in_name in amp_lists.black_varnames: + return True + + for out_name in op.output_arg_names: + if out_name in amp_lists.black_varnames: + return True + + return False + + def rewrite_program(main_prog, amp_lists): """ Traverse all ops in current block and insert cast op according to @@ -180,6 +196,11 @@ def rewrite_program(main_prog, amp_lists): white_op_set = set() black_op_set = set() for op in ops: + if amp_lists.black_varnames is not None and _is_in_black_varnames( + op, amp_lists): + black_op_set.add(op) + continue + if op.type in amp_lists.black_list: black_op_set.add(op) elif op.type in amp_lists.white_list: diff --git a/python/paddle/fluid/contrib/tests/test_image_classification_fp16.py b/python/paddle/fluid/contrib/tests/test_image_classification_fp16.py index 2e3e364df30..918544e1c99 100644 --- a/python/paddle/fluid/contrib/tests/test_image_classification_fp16.py +++ b/python/paddle/fluid/contrib/tests/test_image_classification_fp16.py @@ -134,8 +134,11 @@ def train(net_type, use_cuda, save_dirname, is_local): optimizer = fluid.optimizer.Lamb(learning_rate=0.001) + amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionLists( + custom_black_varnames={"loss", "conv2d_0.w_0"}) mp_optimizer = fluid.contrib.mixed_precision.decorate( optimizer=optimizer, + amp_lists=amp_lists, init_loss_scaling=8.0, use_dynamic_loss_scaling=True) -- GitLab