提交 3255fe69 编写于 作者: G gongweibao 提交者: Yi Liu

Add custom black variable name set in amp interface. (#20875)

* add custom black varname test=develop

* fix dtype test=develop

* fix num test=develop

* fix ut test=develop

* fix coverage test=develop

* fix blackvar names test=develop
上级 aadd81b6
......@@ -29,12 +29,16 @@ class AutoMixedPrecisionLists(object):
custom_black_list (set): Users' custom black list.
"""
def __init__(self, custom_white_list=None, custom_black_list=None):
def __init__(self,
custom_white_list=None,
custom_black_list=None,
custom_black_varnames=None):
self._custom_white_list = custom_white_list
self._custom_black_list = custom_black_list
self.white_list = copy.copy(white_list)
self.black_list = copy.copy(black_list)
self.gray_list = copy.copy(gray_list)
self.black_varnames = copy.copy(custom_black_varnames)
self._update_list()
def _update_list(self):
......
......@@ -85,6 +85,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
core.VarDesc.VarType.LOD_TENSOR, core.VarDesc.VarType.SELECTED_ROWS,
core.VarDesc.VarType.LOD_TENSOR_ARRAY
]
for in_name in op.input_names:
if src_dtype == core.VarDesc.VarType.FP32 and op.type == 'batch_norm':
if in_name != 'X':
......@@ -94,22 +95,25 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
if in_var.type not in valid_types:
continue
if in_var.dtype == src_dtype:
out_var = block.create_var(
name=in_var.name + \
'.cast_' + _dtype_to_str(dest_dtype),
dtype=dest_dtype,
persistable=False,
stop_gradient=False)
block._insert_op(
idx,
type="cast",
inputs={"X": in_var},
outputs={"Out": out_var},
attrs={
"in_dtype": in_var.dtype,
"out_dtype": out_var.dtype
})
num_cast_ops += 1
cast_name = in_var.name + '.cast_' + _dtype_to_str(dest_dtype)
out_var = block.vars.get(cast_name)
if out_var is None or out_var.dtype != dest_dtype:
out_var = block.create_var(
name=cast_name,
dtype=dest_dtype,
persistable=False,
stop_gradient=False)
block._insert_op(
idx,
type="cast",
inputs={"X": in_var},
outputs={"Out": out_var},
attrs={
"in_dtype": in_var.dtype,
"out_dtype": out_var.dtype
})
num_cast_ops += 1
_rename_arg(op, in_var.name, out_var.name)
else:
if op.has_attr('in_dtype'):
......@@ -155,6 +159,18 @@ def find_true_prev_op(ops, cur_op, var_name):
return None
def _is_in_black_varnames(op, amp_lists):
for in_name in op.input_arg_names:
if in_name in amp_lists.black_varnames:
return True
for out_name in op.output_arg_names:
if out_name in amp_lists.black_varnames:
return True
return False
def rewrite_program(main_prog, amp_lists):
"""
Traverse all ops in current block and insert cast op according to
......@@ -180,6 +196,11 @@ def rewrite_program(main_prog, amp_lists):
white_op_set = set()
black_op_set = set()
for op in ops:
if amp_lists.black_varnames is not None and _is_in_black_varnames(
op, amp_lists):
black_op_set.add(op)
continue
if op.type in amp_lists.black_list:
black_op_set.add(op)
elif op.type in amp_lists.white_list:
......
......@@ -134,8 +134,11 @@ def train(net_type, use_cuda, save_dirname, is_local):
optimizer = fluid.optimizer.Lamb(learning_rate=0.001)
amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionLists(
custom_black_varnames={"loss", "conv2d_0.w_0"})
mp_optimizer = fluid.contrib.mixed_precision.decorate(
optimizer=optimizer,
amp_lists=amp_lists,
init_loss_scaling=8.0,
use_dynamic_loss_scaling=True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册