Commit c6a598a2 authored by Jie Fang, committed by gongweibao

init new amp, optimize inserting cast op for batchnorm (#18596)

init new amp, optimize inserting cast op for batchnorm
Parent 2f037c31
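For context, here is a minimal usage sketch of the mixed-precision optimizer wrapper this commit reworks. The tiny network, the SGD optimizer and the init_loss_scaling keyword are illustrative assumptions, not part of the diff; only the three-value return of minimize() is taken from the decorator.py hunk below.

import paddle.fluid as fluid
from paddle.fluid.contrib import mixed_precision as mp

# Tiny fp32 network purely for illustration.
x = fluid.layers.data(name='x', shape=[16], dtype='float32')
y = fluid.layers.data(name='y', shape=[1], dtype='int64')
pred = fluid.layers.fc(input=x, size=2, act='softmax')
avg_cost = fluid.layers.mean(fluid.layers.cross_entropy(input=pred, label=y))

sgd = fluid.optimizer.SGD(learning_rate=0.001)
# decorate() wraps the plain optimizer; the keyword name init_loss_scaling is
# an assumption about the API at this commit.
mp_sgd = mp.decorate(sgd, init_loss_scaling=128.0)
# Per the decorator.py hunk below, minimize() returns the scaled loss, the
# optimize ops and the scaled (param, grad) pairs.
scaled_loss, optimize_ops, scaled_params_grads = mp_sgd.minimize(avg_cost)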
@@ -17,7 +17,6 @@ from ... import default_startup_program
 from ... import layers
 from ... import unique_name
 from . import fp16_utils
-from .fp16_utils import create_master_params_grads, master_param_to_train_param
 from .fp16_utils import update_loss_scaling, rewrite_program
 from .fp16_lists import AutoMixedPrecisionLists
@@ -128,19 +127,20 @@ class OptimizerWithMixedPrecison(object):
         self._param_grads = self._optimizer.backward(
             scaled_loss, startup_program, parameter_list, no_grad_set,
             callbacks)
-        master_params_grads = create_master_params_grads(
-            self._param_grads, self._train_program, self._startup_prog,
-            self._loss_scaling)
-        return master_params_grads, scaled_loss
+        scaled_params_grad = []
+        for p, g in self._param_grads:
+            scaled_g = g / self._loss_scaling
+            scaled_params_grad.append([p, scaled_g])
+        return scaled_params_grad, scaled_loss
-    def apply_gradients(self, master_params_grads):
+    def apply_gradients(self, scaled_params_grads):
         """
-        Update master parameters by their gradients, and cast to parameters
-        in float16.
+        Check scaled gradients to determine whether to update loss scaling and
+        update parameters by their scaled gradients.
         Args:
-            master_params_grads (list): A list of master params and grads.
+            scaled_params_grads (list): A list of params and scaled grads.
         Returns:
             A list of optimize operators.
@@ -148,7 +148,7 @@ class OptimizerWithMixedPrecison(object):
         if self._use_dynamic_loss_scaling:
-            grads = [layers.reduce_sum(g) for [_, g] in master_params_grads]
+            grads = [layers.reduce_sum(g) for [_, g] in scaled_params_grads]
             all_grads = layers.concat(grads)
             all_grads_sum = layers.reduce_sum(all_grads)
             is_overall_finite = layers.isfinite(all_grads_sum)
@@ -165,12 +165,10 @@ class OptimizerWithMixedPrecison(object):
             with switch.case(is_overall_finite):
                 pass
             with switch.default():
-                for _, g in master_params_grads:
+                for _, g in scaled_params_grads:
                     layers.assign(layers.zeros_like(g), g)
-        optimize_ops = self._optimizer.apply_gradients(master_params_grads)
-        master_param_to_train_param(master_params_grads, self._param_grads,
-                                    self._train_program)
+        optimize_ops = self._optimizer.apply_gradients(scaled_params_grads)
         return optimize_ops
@@ -183,12 +181,12 @@ class OptimizerWithMixedPrecison(object):
         Returns:
             The scaled loss by scaling factor, the list of optimize ops, and a
-            list of master parameters and gradients.
+            list of scaled parameters and gradients.
         """
-        master_params_grads, scaled_loss = self.backward(loss)
-        optimize_ops = self.apply_gradients(master_params_grads)
-        return scaled_loss, optimize_ops, master_params_grads
+        scaled_params_grads, scaled_loss = self.backward(loss)
+        optimize_ops = self.apply_gradients(scaled_params_grads)
+        return scaled_loss, optimize_ops, scaled_params_grads
 def decorate(optimizer,
......
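The backward/apply_gradients changes above drop the fp32 master parameters in favour of direct gradient unscaling: the loss is multiplied by the scaling factor before backpropagation, so each resulting gradient is divided by the same factor before the update. A minimal NumPy analogue of that arithmetic (illustrative only, not Paddle code):

import numpy as np

loss_scaling = 128.0
true_grad = np.array([0.5, -1.25], dtype=np.float32)  # gradient of the unscaled loss
scaled_grad = true_grad * loss_scaling                # what backward() on the scaled loss produces
restored_grad = scaled_grad / loss_scaling            # the division now done in backward()
assert np.allclose(restored_grad, true_grad)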
@@ -94,6 +94,7 @@ gray_list = {
     'elementwise_pow',
     'elementwise_mod',
     'elementwise_floordiv',
+    'batch_norm',
     'tanh',
     'sigmoid',
     'lookup_table',
......
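Adding batch_norm to the gray list works together with the cast-op changes below; as I understand this AMP design, gray-list ops may run in either precision and simply follow the dtype of their inputs, while white-list ops are forced to fp16 and black-list ops to fp32. A rough sketch of that classification rule (assumed semantics, hypothetical helper, not the actual rewrite_program code):

def choose_dtype(op_type, white_list, black_list, gray_list, input_dtype):
    if op_type in white_list:
        return 'float16'    # always cast to half precision
    if op_type in black_list:
        return 'float32'    # numerically sensitive, keep full precision
    if op_type in gray_list:
        return input_dtype  # e.g. batch_norm: follow whatever dtype flows in
    return 'float32'        # unknown ops stay in fp32

print(choose_dtype('batch_norm', set(), set(), {'batch_norm'}, 'float16'))  # float16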
@@ -36,92 +36,6 @@ def append_cast_op(i, o, prog):
                "out_dtype": o.dtype})
-def copy_to_master_param(p, block):
-    """
-    New a master parameter for the input parameter, and they two share the same
-    attributes except the data type.
-    Args:
-        p(Parameter): The input parameter in float16.
-        block(Program): The block in which the parameter is.
-    """
-    v = block.vars.get(p.name, None)
-    if v is None:
-        raise ValueError("no param name %s found!" % p.name)
-    new_p = framework.Parameter(
-        block=block,
-        shape=v.shape,
-        dtype=core.VarDesc.VarType.FP32,
-        type=v.type,
-        lod_level=v.lod_level,
-        stop_gradient=p.stop_gradient,
-        trainable=p.trainable,
-        optimize_attr=p.optimize_attr,
-        regularizer=p.regularizer,
-        gradient_clip_attr=p.gradient_clip_attr,
-        error_clip=p.error_clip,
-        name=v.name + ".master")
-    return new_p
-def create_master_params_grads(params_grads, main_prog, startup_prog,
-                               loss_scaling):
-    """
-    Create master parameters and gradients in float32 from params and grads
-    in float16.
-    Args:
-        params_grads (list): A list of tuple (parameter, gradient) in float32.
-        main_prog (Program): The main program for training.
-        startup_prog (Program): The startup program to initialize all parameters.
-        loss_scaling (float): The factor to scale loss and gradients.
-    Returns:
-        A list of master parameters and gradients.
-    """
-    master_params_grads = []
-    for p, g in params_grads:
-        # create master parameters
-        with main_prog._optimized_guard([p, g]):
-            # create master parameters
-            master_param = copy_to_master_param(p, main_prog.global_block())
-            startup_master_param = startup_prog.global_block()._clone_variable(
-                master_param)
-            startup_p = startup_prog.global_block().var(p.name)
-            # fp16 -> fp32
-            append_cast_op(startup_p, startup_master_param, startup_prog)
-            # cast fp16 gradients to fp32 before apply gradients
-            if g.name.find("batch_norm") > -1:
-                scaled_g = g / loss_scaling
-                master_params_grads.append([p, scaled_g])
-                continue
-            master_grad = layers.cast(x=g, dtype="float32")
-            master_grad = master_grad / loss_scaling
-            master_params_grads.append([master_param, master_grad])
-    return master_params_grads
-def master_param_to_train_param(master_params_grads, params_grads, main_prog):
-    """
-    Convert master parameters and gradients in float32 to parameters and
-    gradients in float16 for forward computation.
-    Args:
-        master_params_grads (list): A list of master parameters and gradients in
-            float32.
-        params_grads (list): A list of parameters and gradients in float16.
-        main_prog (list): The main program for execution.
-    """
-    for idx, m_p_g in enumerate(master_params_grads):
-        train_p, _ = params_grads[idx]
-        if train_p.name.find("batch_norm") > -1:
-            continue
-        with main_prog._optimized_guard([m_p_g[0], m_p_g[1]]):
-            # fp32 -> fp16
-            append_cast_op(m_p_g[0], train_p, main_prog)
 def _rename_arg(op, old_name, new_name):
     """
     If an op has old_name input and output, rename these input
@@ -172,6 +86,9 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
         core.VarDesc.VarType.LOD_TENSOR_ARRAY
     ]
     for in_name in op.input_names:
+        if src_dtype == core.VarDesc.VarType.FP32 and op.type == 'batch_norm':
+            if in_name != 'X':
+                continue
         for in_var_name in op.input(in_name):
             in_var = block.var(in_var_name)
             if in_var.type not in valid_types:
@@ -197,16 +114,18 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
             else:
                 if op.has_attr('in_dtype'):
                     op._set_attr('in_dtype', dest_dtype)
-    if src_dtype == core.VarDesc.VarType.FP16:
+    if src_dtype == core.VarDesc.VarType.FP32:
         for out_name in op.output_names:
+            if op.type == 'batch_norm' and out_name != 'Y':
+                continue
             for out_var_name in op.output(out_name):
                 out_var = block.var(out_var_name)
                 if out_var.type not in valid_types:
                     continue
-                if out_var.dtype == core.VarDesc.VarType.FP16:
-                    out_var.desc.set_dtype(core.VarDesc.VarType.FP32)
+                if out_var.dtype == core.VarDesc.VarType.FP32:
+                    out_var.desc.set_dtype(core.VarDesc.VarType.FP16)
                 if op.has_attr('out_dtype'):
-                    op._set_attr('out_dtype', core.VarDesc.VarType.FP32)
+                    op._set_attr('out_dtype', core.VarDesc.VarType.FP16)
     return num_cast_ops
......
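The _insert_cast_op changes above are the "optimize inserting cast op for batchnorm" part of the commit: when rewriting an fp32 program to fp16, only the data input 'X' and the data output 'Y' of a batch_norm op receive cast handling, while the remaining slots (Scale, Bias, the running statistics and so on) are left in fp32. A small predicate capturing that rule (hypothetical helper, not part of the patch):

def batch_norm_slot_is_cast(slot_name, is_input):
    """Return True if a batch_norm input/output slot is touched by the
    fp32 -> fp16 cast logic shown in the diff above."""
    return slot_name == ('X' if is_input else 'Y')

# Data tensors are handled ...
assert batch_norm_slot_is_cast('X', is_input=True)
assert batch_norm_slot_is_cast('Y', is_input=False)
# ... while affine parameters and statistics stay in fp32.
assert not batch_norm_slot_is_cast('Scale', is_input=True)
assert not batch_norm_slot_is_cast('SavedMean', is_input=False)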
@@ -113,13 +113,12 @@ def train(net_type, use_cuda, save_dirname, is_local):
             name='pixel', shape=data_shape, dtype='float32')
         label = fluid.layers.data(name='label', shape=[1], dtype='int64')
-        imgs = fluid.layers.cast(images, "float16")
         if net_type == "vgg":
             print("train vgg net")
-            net = vgg16_bn_drop(imgs)
+            net = vgg16_bn_drop(images)
         elif net_type == "resnet":
             print("train resnet")
-            net = resnet_cifar10(imgs, 32)
+            net = resnet_cifar10(images, 32)
         else:
             raise ValueError("%s network is not supported" % net_type)
......