提交 3255fe69 编写于 作者: G gongweibao 提交者: Yi Liu

Add custom black variable name set in amp interface. (#20875)

* add custom black varname test=develop

* fix dtype test=develop

* fix num test=develop

* fix ut test=develop

* fix coverage test=develop

* fix blackvar names test=develop
上级 aadd81b6
...@@ -29,12 +29,16 @@ class AutoMixedPrecisionLists(object): ...@@ -29,12 +29,16 @@ class AutoMixedPrecisionLists(object):
custom_black_list (set): Users' custom black list. custom_black_list (set): Users' custom black list.
""" """
def __init__(self, custom_white_list=None, custom_black_list=None): def __init__(self,
custom_white_list=None,
custom_black_list=None,
custom_black_varnames=None):
self._custom_white_list = custom_white_list self._custom_white_list = custom_white_list
self._custom_black_list = custom_black_list self._custom_black_list = custom_black_list
self.white_list = copy.copy(white_list) self.white_list = copy.copy(white_list)
self.black_list = copy.copy(black_list) self.black_list = copy.copy(black_list)
self.gray_list = copy.copy(gray_list) self.gray_list = copy.copy(gray_list)
self.black_varnames = copy.copy(custom_black_varnames)
self._update_list() self._update_list()
def _update_list(self): def _update_list(self):
......
...@@ -85,6 +85,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype): ...@@ -85,6 +85,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
core.VarDesc.VarType.LOD_TENSOR, core.VarDesc.VarType.SELECTED_ROWS, core.VarDesc.VarType.LOD_TENSOR, core.VarDesc.VarType.SELECTED_ROWS,
core.VarDesc.VarType.LOD_TENSOR_ARRAY core.VarDesc.VarType.LOD_TENSOR_ARRAY
] ]
for in_name in op.input_names: for in_name in op.input_names:
if src_dtype == core.VarDesc.VarType.FP32 and op.type == 'batch_norm': if src_dtype == core.VarDesc.VarType.FP32 and op.type == 'batch_norm':
if in_name != 'X': if in_name != 'X':
...@@ -94,22 +95,25 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype): ...@@ -94,22 +95,25 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
if in_var.type not in valid_types: if in_var.type not in valid_types:
continue continue
if in_var.dtype == src_dtype: if in_var.dtype == src_dtype:
out_var = block.create_var( cast_name = in_var.name + '.cast_' + _dtype_to_str(dest_dtype)
name=in_var.name + \ out_var = block.vars.get(cast_name)
'.cast_' + _dtype_to_str(dest_dtype), if out_var is None or out_var.dtype != dest_dtype:
dtype=dest_dtype, out_var = block.create_var(
persistable=False, name=cast_name,
stop_gradient=False) dtype=dest_dtype,
block._insert_op( persistable=False,
idx, stop_gradient=False)
type="cast",
inputs={"X": in_var}, block._insert_op(
outputs={"Out": out_var}, idx,
attrs={ type="cast",
"in_dtype": in_var.dtype, inputs={"X": in_var},
"out_dtype": out_var.dtype outputs={"Out": out_var},
}) attrs={
num_cast_ops += 1 "in_dtype": in_var.dtype,
"out_dtype": out_var.dtype
})
num_cast_ops += 1
_rename_arg(op, in_var.name, out_var.name) _rename_arg(op, in_var.name, out_var.name)
else: else:
if op.has_attr('in_dtype'): if op.has_attr('in_dtype'):
...@@ -155,6 +159,18 @@ def find_true_prev_op(ops, cur_op, var_name): ...@@ -155,6 +159,18 @@ def find_true_prev_op(ops, cur_op, var_name):
return None return None
def _is_in_black_varnames(op, amp_lists):
for in_name in op.input_arg_names:
if in_name in amp_lists.black_varnames:
return True
for out_name in op.output_arg_names:
if out_name in amp_lists.black_varnames:
return True
return False
def rewrite_program(main_prog, amp_lists): def rewrite_program(main_prog, amp_lists):
""" """
Traverse all ops in current block and insert cast op according to Traverse all ops in current block and insert cast op according to
...@@ -180,6 +196,11 @@ def rewrite_program(main_prog, amp_lists): ...@@ -180,6 +196,11 @@ def rewrite_program(main_prog, amp_lists):
white_op_set = set() white_op_set = set()
black_op_set = set() black_op_set = set()
for op in ops: for op in ops:
if amp_lists.black_varnames is not None and _is_in_black_varnames(
op, amp_lists):
black_op_set.add(op)
continue
if op.type in amp_lists.black_list: if op.type in amp_lists.black_list:
black_op_set.add(op) black_op_set.add(op)
elif op.type in amp_lists.white_list: elif op.type in amp_lists.white_list:
......
...@@ -134,8 +134,11 @@ def train(net_type, use_cuda, save_dirname, is_local): ...@@ -134,8 +134,11 @@ def train(net_type, use_cuda, save_dirname, is_local):
optimizer = fluid.optimizer.Lamb(learning_rate=0.001) optimizer = fluid.optimizer.Lamb(learning_rate=0.001)
amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionLists(
custom_black_varnames={"loss", "conv2d_0.w_0"})
mp_optimizer = fluid.contrib.mixed_precision.decorate( mp_optimizer = fluid.contrib.mixed_precision.decorate(
optimizer=optimizer, optimizer=optimizer,
amp_lists=amp_lists,
init_loss_scaling=8.0, init_loss_scaling=8.0,
use_dynamic_loss_scaling=True) use_dynamic_loss_scaling=True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册