提交 371d8e21 编写于 作者: D danleifeng 提交者: lilong12

add fp16 training code (#4)

* add support for mixed precision training
上级 be5ba881
......@@ -99,6 +99,10 @@ class Entry(object):
self.fs_ugi = None
self.fs_dir = None
self.use_fp16 = False
self.init_loss_scaling = 1.0
self.fp16_user_dict = None
self.val_targets = self.config.val_targets
self.dataset_dir = self.config.dataset_dir
self.num_classes = self.config.num_classes
......@@ -145,6 +149,17 @@ class Entry(object):
self.global_train_batch_size = batch_size * self.num_trainers
logger.info("Set train batch size to {}.".format(batch_size))
def set_mixed_precision(self, use_fp16, loss_scaling):
"""
Whether to use mixed precision training.
"""
self.use_fp16 = use_fp16
self.init_loss_scaling = loss_scaling
self.fp16_user_dict = dict()
self.fp16_user_dict['init_loss_scaling'] = self.init_loss_scaling
logger.info("Use mixed precision training: {}.".format(use_fp16))
logger.info("Set init loss scaling to {}.".format(loss_scaling))
def set_test_batch_size(self, batch_size):
self.test_batch_size = batch_size
self.global_test_batch_size = batch_size * self.num_trainers
......@@ -293,8 +308,12 @@ class Entry(object):
if self.loss_type in ["dist_softmax", "dist_arcface"]:
self.optimizer = DistributedClassificationOptimizer(
self.optimizer, global_batch_size)
self.optimizer, global_batch_size, use_fp16=self.use_fp16,
loss_type=self.loss_type,
fp16_user_dict=self.fp16_user_dict)
elif self.use_fp16:
self.optimizer = fluid.contrib.mixed_precision.decorate(
optimizer=optimizer, init_loss_scaling=self.init_loss_scaling)
return self.optimizer
def build_program(self,
......@@ -358,7 +377,7 @@ class Entry(object):
dist_optimizer = self.fleet.distributed_optimizer(
optimizer, strategy=self.strategy)
dist_optimizer.minimize(loss)
if "dist" in self.loss_type:
if "dist" in self.loss_type or self.use_fp16:
optimizer = optimizer._optimizer
elif use_parallel_test:
emb = fluid.layers.collective._c_allgather(emb,
......
......@@ -15,6 +15,7 @@
from __future__ import print_function
import math
import logging
from six.moves import reduce
import paddle.fluid as fluid
from paddle.fluid.layer_helper import LayerHelper
......@@ -23,8 +24,13 @@ from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import Normal, Constant
import paddle.fluid.layers.nn as nn
import paddle.fluid.layers.ops as ops
import paddle.fluid.layers as layers
import paddle.fluid.layers.collective as collective
from paddle.fluid.optimizer import Optimizer
import paddle.fluid.unique_name as unique_name
from ..utils.fp16_utils import rewrite_program, update_role_var_grad
from ..utils.fp16_utils import update_loss_scaling, move_optimize_ops_back
from ..utils.fp16_lists import AutoMixedPrecisionLists
class DistributedClassificationOptimizer(Optimizer):
......@@ -32,53 +38,246 @@ class DistributedClassificationOptimizer(Optimizer):
A optimizer wrapper to generate backward network for distributed
classification training of model parallelism.
'''
def __init__(self, optimizer, batch_size, use_fp16=False):
def init_fp16_params(self, loss_type, fp16_user_dict):
# set default value for fp16_params_dict
fp16_params_dict = dict()
fp16_params_dict['amp_lists']= None
fp16_params_dict['init_loss_scaling'] = 1.0
fp16_params_dict['incr_every_n_steps'] = 1000
fp16_params_dict['decr_every_n_nan_or_inf'] = 2
fp16_params_dict['incr_ratio'] = 2.0
fp16_params_dict['decr_ratio'] = 0.5
fp16_params_dict['use_dynamic_loss_scaling'] = True
if fp16_user_dict is not None:
# update fp16_params_dict
for key in fp16_user_dict:
if fp16_params_dict.has_key(key):
fp16_params_dict[key] = fp16_user_dict[key]
else:
logging.warning("Can't find name '%s' in our fp16_params_dict. "
"Please check your dict key. You can set fp16 params only "
"in [amp_lists, init_loss_scaling, decr_every_n_nan_or_inf, "
"incr_ratio, decr_ratio, use_dynamic_loss_scaling]." % (key))
self._amp_lists = fp16_params_dict['amp_lists']
if self._amp_lists is None:
self._amp_lists = AutoMixedPrecisionLists()
self._loss_type = loss_type
self._loss_scaling = layers.create_global_var(
name=unique_name.generate("loss_scaling"),
shape=[1],
value=fp16_params_dict['init_loss_scaling'],
dtype='float32',
persistable=True)
self._use_dynamic_loss_scaling = fp16_params_dict['use_dynamic_loss_scaling']
if self._use_dynamic_loss_scaling:
self._incr_every_n_steps = layers.fill_constant(
shape=[1], dtype='int32', value=fp16_params_dict['incr_every_n_steps'])
self._decr_every_n_nan_or_inf = layers.fill_constant(
shape=[1], dtype='int32', value=fp16_params_dict['decr_every_n_nan_or_inf'])
self._incr_ratio = fp16_params_dict['incr_ratio']
self._decr_ratio = fp16_params_dict['decr_ratio']
self._num_good_steps = layers.create_global_var(
name=unique_name.generate("num_good_steps"),
shape=[1],
value=0,
dtype='int32',
persistable=True)
self._num_bad_steps = layers.create_global_var(
name=unique_name.generate("num_bad_steps"),
shape=[1],
value=0,
dtype='int32',
persistable=True)
# Ensure the data type of learning rate vars is float32 (same as the
# master parameter dtype)
if isinstance(self._optimizer._learning_rate, float):
self._optimizer._learning_rate_map[fluid.default_main_program()] = \
layers.create_global_var(
name=unique_name.generate("learning_rate"),
shape=[1],
value=float(self._optimizer._learning_rate),
dtype='float32',
persistable=True)
def __init__(self,
optimizer,
batch_size,
use_fp16=False,
loss_type='dist_arcface',
fp16_user_dict=None):
self._optimizer = optimizer
self._batch_size = batch_size
self._use_fp16 = use_fp16
def minimize(self,
loss,
startup_program=None,
parameter_list=None,
no_grad_set=None):
assert loss._get_info('shard_logit')
if self._use_fp16:
self.init_fp16_params(loss_type, fp16_user_dict)
def fp16_backward(self,
block,
scalar,
startup_program=None,
parameter_list=None,
no_grad_set=None,
callbacks=None):
rewrite_program(block.program, self._amp_lists)
self._params_grads = self._optimizer.backward(
scalar, startup_program, parameter_list,
no_grad_set, callbacks)
update_role_var_grad(block.program, self._params_grads)
move_optimize_ops_back(block.program.global_block())
scaled_params_grads = []
for p, g in self._params_grads:
with fluid.default_main_program()._optimized_guard([p, g]):
scaled_g = g / self._loss_scaling
scaled_params_grads.append([p, scaled_g])
return scaled_params_grads
def insert_dist_arcface_backward_op(self,
block,
index,
shard_logit,
shard_prob,
shard_label,
shard_dim,
op_role_key,
backward_role,
loss_backward_role):
'''
during mixed precision training(use_fp16=True), insert backward ops
when loss_type equals dist_arcface.
'''
shard_one_hot = block.create_var(
name=fluid.unique_name.generate('shard_one_hot'),
dtype=shard_logit.dtype)
shard_one_hot_fp32 = block.create_var(
name=fluid.unique_name.generate(shard_one_hot.name+'.cast_fp32'),
dtype=shard_logit.dtype)
# input var of elementwise_add_grad op after scale
shard_logit_grad_fp32 = block.var('tmp_3@GRAD')
shard_logit = loss._get_info('shard_logit')
shard_prob = loss._get_info('shard_prob')
shard_label = loss._get_info('shard_label')
shard_dim = loss._get_info('shard_dim')
block._insert_op(
index - 1,
type='one_hot',
inputs={'X': shard_label},
outputs={'Out': shard_one_hot},
attrs={
'depth': shard_dim,
'allow_out_of_range': True,
op_role_key: backward_role
})
block._insert_op(
index,
type="cast",
inputs={"X": shard_one_hot},
outputs={"Out": shard_one_hot_fp32},
attrs={
"in_dtype": fluid.core.VarDesc.VarType.FP16,
"out_dtype": fluid.core.VarDesc.VarType.FP32,
op_role_key: backward_role
})
block._insert_op(
index + 1,
type='elementwise_sub',
inputs={'X': shard_prob,
'Y': shard_one_hot_fp32},
outputs={'Out': shard_logit_grad_fp32},
attrs={op_role_key: backward_role})
block._insert_op(
index + 2,
type='elementwise_mul',
inputs={'X': shard_logit_grad_fp32,
'Y': self._loss_scaling},
outputs={'Out': shard_logit_grad_fp32},
attrs={op_role_key: backward_role})
block._insert_op(
index + 3,
type='scale',
inputs={'X': shard_logit_grad_fp32},
outputs={'Out': shard_logit_grad_fp32},
attrs={
'scale': 1.0 / self._batch_size,
op_role_key: loss_backward_role
})
op_maker = fluid.core.op_proto_and_checker_maker
op_role_key = op_maker.kOpRoleAttrName()
op_role_var_key = op_maker.kOpRoleVarAttrName()
backward_role = int(op_maker.OpRole.Backward)
loss_backward_role = int(op_maker.OpRole.Loss) | int(
op_maker.OpRole.Backward)
def insert_dist_softmax_backward_op(self,
block,
index,
shard_logit,
shard_prob,
shard_label,
shard_dim,
op_role_key,
backward_role,
loss_backward_role):
'''
during mixed precision training(use_fp16=True), insert backward ops
when loss_type equals dist_softmax.
'''
shard_one_hot = block.create_var(
name=fluid.unique_name.generate('shard_one_hot'),
dtype=fluid.core.VarDesc.VarType.FP32)
shard_one_hot_fp32 = block.create_var(
name=fluid.unique_name.generate(shard_one_hot.name+'.cast_fp32'),
dtype=fluid.core.VarDesc.VarType.FP32)
shard_logit_grad_fp32 = block.var(shard_logit.name + ".cast_fp32@GRAD")
# minimize a scalar of reduce_sum to generate the backward network
scalar = fluid.layers.reduce_sum(shard_logit)
ret = self._optimizer.minimize(scalar)
block._insert_op(
index - 1,
type='one_hot',
inputs={'X': shard_label},
outputs={'Out': shard_one_hot_fp32},
attrs={
'depth': shard_dim,
'allow_out_of_range': True,
op_role_key: backward_role
})
block._insert_op(
index,
type='elementwise_sub',
inputs={'X': shard_prob,
'Y': shard_one_hot_fp32},
outputs={'Out': shard_logit_grad_fp32},
attrs={op_role_key: backward_role})
block._insert_op(
index + 1,
type='elementwise_mul',
inputs={'X': shard_logit_grad_fp32,
'Y': self._loss_scaling},
outputs={'Out': shard_logit_grad_fp32},
attrs={op_role_key: backward_role})
block._insert_op(
index + 2,
type='scale',
inputs={'X': shard_logit_grad_fp32},
outputs={'Out': shard_logit_grad_fp32},
attrs={
'scale': 1.0 / self._batch_size,
op_role_key: loss_backward_role
})
block = loss.block
# remove the unnecessary ops
index = 0
for i, op in enumerate(block.ops):
if op.all_attrs()[op_role_key] == loss_backward_role:
index = i
break
assert block.ops[index - 1].type == 'reduce_sum'
assert block.ops[index].type == 'fill_constant'
assert block.ops[index + 1].type == 'reduce_sum_grad'
block._remove_op(index + 1)
block._remove_op(index)
block._remove_op(index - 1)
# insert the calculated gradient
def insert_commom_backward_op(self,
block,
index,
shard_logit,
shard_prob,
shard_label,
shard_dim,
op_role_key,
backward_role,
loss_backward_role):
'''
insert backward ops when not using mixed precision training.
common use in all lose type.
'''
# insert the calculated gradient
dtype = shard_logit.dtype
shard_one_hot = fluid.layers.create_tensor(dtype, name='shard_one_hot')
shard_one_hot = fluid.layers.create_tensor(
dtype, name='shard_one_hot')
block._insert_op(
index - 1,
type='one_hot',
......@@ -89,8 +288,8 @@ class DistributedClassificationOptimizer(Optimizer):
'allow_out_of_range': True,
op_role_key: backward_role
})
shard_logit_grad = fluid.layers.create_tensor(
dtype, name=fluid.backward._append_grad_suffix_(shard_logit.name))
shard_logit_grad = fluid.layers.create_tensor(dtype,
name=fluid.backward._append_grad_suffix_(shard_logit.name))
block._insert_op(
index,
type='elementwise_sub',
......@@ -107,7 +306,117 @@ class DistributedClassificationOptimizer(Optimizer):
'scale': 1.0 / self._batch_size,
op_role_key: loss_backward_role
})
return ret
def minimize(self,
loss,
startup_program=None,
parameter_list=None,
no_grad_set=None,
callbacks=None):
assert loss._get_info('shard_logit')
shard_logit = loss._get_info('shard_logit')
shard_prob = loss._get_info('shard_prob')
shard_label = loss._get_info('shard_label')
shard_dim = loss._get_info('shard_dim')
op_maker = fluid.core.op_proto_and_checker_maker
op_role_key = op_maker.kOpRoleAttrName()
op_role_var_key = op_maker.kOpRoleVarAttrName()
backward_role = int(op_maker.OpRole.Backward)
loss_backward_role = int(op_maker.OpRole.Loss) | int(
op_maker.OpRole.Backward)
# minimize a scalar of reduce_sum to generate the backward network
scalar = fluid.layers.reduce_sum(shard_logit)
block = loss.block
if not self._use_fp16:
ret = self._optimizer.minimize(scalar)
# remove the unnecessary ops
index = 0
for i, op in enumerate(block.ops):
if op.all_attrs()[op_role_key] == loss_backward_role:
index = i
break
assert block.ops[index - 1].type == 'reduce_sum'
assert block.ops[index].type == 'fill_constant'
assert block.ops[index + 1].type == 'reduce_sum_grad'
block._remove_op(index + 1)
block._remove_op(index)
block._remove_op(index - 1)
self.insert_commom_backward_op(block, index, shard_logit, shard_prob,
shard_label, shard_dim, op_role_key,
backward_role, loss_backward_role)
return ret
else:
scaled_params_grads = self.fp16_backward(block, scalar, startup_program,
parameter_list, no_grad_set, callbacks)
index = 0
for i, op in enumerate(block.ops):
if op.all_attrs()[op_role_key] == loss_backward_role:
index = i
break
if self._loss_type == 'dist_arcface':
assert block.ops[index - 2].type == 'fill_constant'
assert block.ops[index - 1].type == 'reduce_sum'
assert block.ops[index].type == 'fill_constant'
assert block.ops[index + 1].type == 'reduce_sum_grad'
assert block.ops[index + 2].type == 'scale'
assert block.ops[index + 3].type == 'elementwise_add_grad'
block._remove_op(index + 2)
block._remove_op(index + 1)
block._remove_op(index)
block._remove_op(index - 1)
self.insert_dist_arcface_backward_op(block, index, shard_logit, shard_prob,
shard_label, shard_dim, op_role_key,
backward_role, loss_backward_role)
elif self._loss_type == 'dist_softmax':
assert block.ops[index - 1].type == 'reduce_sum'
assert block.ops[index].type == 'fill_constant'
assert block.ops[index + 1].type == 'reduce_sum_grad'
assert block.ops[index + 2].type == 'cast'
assert block.ops[index + 3].type == 'elementwise_add_grad'
block._remove_op(index + 1)
block._remove_op(index)
block._remove_op(index - 1)
self.insert_dist_softmax_backward_op(block, index, shard_logit, shard_prob,
shard_label, shard_dim, op_role_key,
backward_role, loss_backward_role)
if self._use_dynamic_loss_scaling:
grads = [layers.reduce_sum(g) for [_, g] in scaled_params_grads]
all_grads = layers.concat(grads)
all_grads_sum = layers.reduce_sum(all_grads)
is_overall_finite = layers.isfinite(all_grads_sum)
update_loss_scaling(is_overall_finite, self._loss_scaling,
self._num_good_steps, self._num_bad_steps,
self._incr_every_n_steps,
self._decr_every_n_nan_or_inf,
self._incr_ratio,
self._decr_ratio)
with layers.Switch() as switch:
with switch.case(is_overall_finite):
pass
with switch.default():
for _, g in scaled_params_grads:
layers.assign(layers.zeros_like(g), g)
optimize_ops = self._optimizer.apply_gradients(scaled_params_grads)
ret = optimize_ops, scaled_params_grads
return ret
class DistributedClassifier(object):
......@@ -307,7 +616,7 @@ def _distributed_softmax_classify(x,
'''
Classification layer with FC, softmax and cross entropy calculation of
distibuted version in case of too large number of classes.
Args:
x (Variable): The feature representation of the input samples. This
feature will be flattened into 2-D tensor from dimension index
......@@ -330,13 +639,13 @@ def _distributed_softmax_classify(x,
import paddle.fluid as fluid
input = fluid.layers.data(name="input",
shape=[32, 1024],
dtype='float32',
append_batch_size=False)
shape=[32, 1024],
dtype='float32',
append_batch_size=False)
label = fluid.layers.data(name="label",
shape=[32, 1],
dtype='int64',
append_batch_size=False)
shape=[32, 1],
dtype='int64',
append_batch_size=False)
y = fluid.layers.collective.distributed_softmax_classify(x=input,
label=label,
class_num=1000,
......@@ -372,7 +681,7 @@ def _distributed_arcface_classify(x,
where the :math: `\theta_{y_i}` is the angle between the feature :math: `x` and
the representation of class :math: `i`. The details of ArcFace loss
could be referred to https://arxiv.org/abs/1801.07698.
Args:
x (Variable): The feature representation of the input samples. This
feature will be flattened into 2-D tensor from dimension index
......@@ -397,13 +706,13 @@ def _distributed_arcface_classify(x,
import paddle.fluid as fluid
input = fluid.layers.data(name="input",
shape=[32, 1024],
dtype='float32',
append_batch_size=False)
shape=[32, 1024],
dtype='float32',
append_batch_size=False)
label = fluid.layers.data(name="label",
shape=[32, 1],
dtype='int64',
append_batch_size=False)
shape=[32, 1],
dtype='int64',
append_batch_size=False)
y = fluid.layers.collective.distributed_arcface_classify(x=input,
label=label,
class_num=1000,
......@@ -420,4 +729,3 @@ def _distributed_arcface_classify(x,
margin=margin,
logit_scale=logit_scale,
param_attr=param_attr)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
__all__ = ["AutoMixedPrecisionLists"]
class AutoMixedPrecisionLists(object):
"""
AutoMixedPrecisionLists is a class for black/white list. It can update
pre-defined black list and white list according to users' custom black
white lists. The lists are used for an algorithm which determines op's
execution mode (fp32 or fp16).
Args:
custom_white_list (set): Users' custom white list.
custom_black_list (set): Users' custom black list.
"""
def __init__(self,
custom_white_list=None,
custom_black_list=None,
custom_black_varnames=None):
self._custom_white_list = custom_white_list
self._custom_black_list = custom_black_list
self.white_list = copy.copy(white_list)
self.black_list = copy.copy(black_list)
self.gray_list = copy.copy(gray_list)
self.black_varnames = copy.copy(custom_black_varnames)
self._update_list()
def _update_list(self):
"""
Update black and white list according to users' custom list.
"""
if self._custom_white_list and self._custom_black_list:
for op_name in self._custom_white_list:
if op_name in self._custom_black_list:
raise ValueError("Custom white list overlap "
"custom black list")
if self._custom_white_list:
for op_name in self._custom_white_list:
if op_name in self.black_list:
self.black_list.remove(op_name)
elif op_name in self.gray_list:
self.gray_list.remove(op_name)
self.white_list.add(op_name)
if self._custom_black_list:
for op_name in self._custom_black_list:
if op_name in self.white_list:
self.white_list.remove(op_name)
elif op_name in self.gray_list:
self.gray_list.remove(op_name)
self.black_list.add(op_name)
# The three sets listed below are changed dynamiclly. They don't contain all
# paddle ops currently.
# The set of ops that support fp16 calculation and are considered numerically-
# safe and performance-critical. These ops are always converted to fp16.
white_list = {
'conv2d',
'matmul',
'mul',
}
# The set of ops that support fp16 calculation and are considered numerically-
# dangerous and whose effects may also be observed in downstream ops.
black_list = {
'exp',
'square',
'log',
'mean',
'sum',
'cos_sim',
'softmax',
'softmax_with_cross_entropy',
'sigmoid_cross_entropy_with_logits',
'cross_entropy',
'cross_entropy2',
}
# This set contains two types of ops. All ops supported fp16 calculation. One
# of two types is considered numerically-safe, but may be made unsafe by an
# upstream blacklist op. Another type do not have numerically-significant
# effects, like stack, flatten2.
gray_list = {
'elementwise_add',
'elementwise_sub',
'elementwise_mul',
'elementwise_div',
'elementwise_max',
'elementwise_min',
'elementwise_pow',
'elementwise_mod',
'elementwise_floordiv',
'batch_norm',
'tanh',
'sigmoid',
'lookup_table',
'top_k',
'pool2d',
'pool3d',
'dropout',
'relu',
'relu6',
'leaky_relu',
'soft_relu',
'flatten2',
'stack',
'unstack',
'uniform_random_batch_size_like',
'gaussian_random',
'gaussian_random_batch_size_like',
'slice',
'rank',
'scale',
'transpose2',
'reshape2',
'gather',
'fill_constant',
'get_tensor_from_selected_rows',
'sign',
'cast',
}
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from paddle.fluid import core
from paddle.fluid import layers
def _rename_arg(op, old_name, new_name):
"""
If an op has old_name input and output, rename these input
args new_name.
Args:
op (Operator): Current operator.
old_name (str): The old name of input args.
new_name (str): The new name of input args.
"""
op_desc = op.desc
if isinstance(op_desc, tuple):
op_desc = op_desc[0]
op_desc._rename_input(old_name, new_name)
op_desc._rename_output(old_name, new_name)
def _dtype_to_str(dtype):
"""
Convert specific variable type to its corresponding string.
Args:
dtype (VarType): Variable type.
"""
if dtype == core.VarDesc.VarType.FP16:
return 'fp16'
else:
return 'fp32'
def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
"""
Insert cast op and rename args of input and output.
Args:
block (Program): The block in which the operator is.
op (Operator): The operator to insert cast op.
idx (int): The index of current operator.
src_dtype (VarType): The input variable dtype of cast op.
dest_dtype (VarType): The output variable dtype of cast op.
Returns:
num_cast_op (int): The number of cast ops that have been inserted.
"""
num_cast_ops = 0
valid_types = [
core.VarDesc.VarType.LOD_TENSOR, core.VarDesc.VarType.SELECTED_ROWS,
core.VarDesc.VarType.LOD_TENSOR_ARRAY
]
for in_name in op.input_names:
if src_dtype == core.VarDesc.VarType.FP32 and op.type == 'batch_norm':
if in_name != 'X':
continue
for in_var_name in op.input(in_name):
in_var = block.var(in_var_name)
if in_var.type not in valid_types:
continue
if in_var.dtype == src_dtype:
cast_name = in_var.name + '.cast_' + _dtype_to_str(dest_dtype)
out_var = block.vars.get(cast_name)
if out_var is None or out_var.dtype != dest_dtype:
out_var = block.create_var(
name=cast_name,
dtype=dest_dtype,
persistable=False,
stop_gradient=False)
block._insert_op(
idx,
type="cast",
inputs={"X": in_var},
outputs={"Out": out_var},
attrs={
"in_dtype": in_var.dtype,
"out_dtype": out_var.dtype
})
num_cast_ops += 1
_rename_arg(op, in_var.name, out_var.name)
else:
if op.has_attr('in_dtype'):
op._set_attr('in_dtype', dest_dtype)
if src_dtype == core.VarDesc.VarType.FP32:
for out_name in op.output_names:
if op.type == 'batch_norm' and out_name != 'Y':
continue
for out_var_name in op.output(out_name):
out_var = block.var(out_var_name)
if out_var.type not in valid_types:
continue
if out_var.dtype == core.VarDesc.VarType.FP32:
out_var.desc.set_dtype(core.VarDesc.VarType.FP16)
if op.has_attr('out_dtype'):
op._set_attr('out_dtype', core.VarDesc.VarType.FP16)
return num_cast_ops
def check_op_validation(ops, idx, cur_op):
"""
Check whether has op (backward_role) that inputs variable in
cur_op's outputs behind cur_op.
If yes, it means _move_optimize_ops_back will cause errors
in program order. Therefore raise valueerror.
If no, return True to continue moving in _move_optimize_ops_back.
Args:
ops (list): A list of ops.
idx (int): index of cur_op in ops.
cur_op (Operator): Current operator.
"""
OPTIMIZE = core.op_proto_and_checker_maker.OpRole.Optimize
for i in range(idx + 1, len(ops)):
op = ops[i]
if not op.attr('op_role') & int(OPTIMIZE):
for input_name in op.input_arg_names:
if input_name in cur_op.output_arg_names:
raise ValueError("There must be no next op that inputs {0} "
"variable after {1} op (optimize_role)".
format(var_name, cur_op.type))
return True
def move_optimize_ops_back(block):
"""
put optimize_role ops(cast) behind the backward_role ops
to speed up in Executor.
Args:
block: The global_block of main program. eg: block = main_prog.global_block()
"""
OPTIMIZE = core.op_proto_and_checker_maker.OpRole.Optimize
optimize_ops = []
for idx, op in reversed(list(enumerate(block.ops))):
if op.attr('op_role') & int(OPTIMIZE):
if check_op_validation(block.ops, idx, op):
optimize_ops.append([
op.type, dict(
zip(op.input_names,
[block.var(name) for name in op.input_arg_names])),
dict(
zip(op.output_names,
[block.var(name) for name in op.output_arg_names])),
op.all_attrs()
])
block._remove_op(idx)
for op in reversed(optimize_ops):
assert len(op) == 4
block.append_op(type=op[0], inputs=op[1], outputs=op[2], attrs=op[3])
def find_true_prev_op(ops, cur_op, var_name):
"""
Find the true prev op that outputs var_name variable.
Args:
ops (list): A list of ops.
cur_op (Operator): Current operator which has var_name variable.
var_name (string): Variable name.
"""
prev_op = []
for op in ops:
if op == cur_op:
break
for out_name in op.output_names:
for out_var_name in op.output(out_name):
if out_var_name == var_name:
prev_op.append(op)
if prev_op:
if not len(prev_op) == 1:
raise ValueError("There must be only one previous op "
"that outputs {0} variable".format(var_name))
else:
return prev_op[0]
return None
def _is_in_black_varnames(op, amp_lists):
for in_name in op.input_arg_names:
if in_name in amp_lists.black_varnames:
return True
for out_name in op.output_arg_names:
if out_name in amp_lists.black_varnames:
return True
return False
def rewrite_program(main_prog, amp_lists):
"""
Traverse all ops in current block and insert cast op according to
which set current op belongs to.
1. When an op belongs to the black list, add it to black set
2. When an op belongs to the white list, add it to white set
3. When an op belongs to the gray list. If one
of its inputs is the output of black set op or black list op,
add it to black set. If all of its previous ops are not black
op and one of its inputs is the output of white set op or
white list op, add it to white set.
4. When an op isn't in the lists, add it to black op set.
5. Add necessary cast ops to make sure that black set op will be
computed in fp32 mode, while white set op will be computed in
fp16 mode.
Args:
main_prog (Program): The main program for training.
"""
block = main_prog.global_block()
ops = block.ops
white_op_set = set()
black_op_set = set()
for op in ops:
if amp_lists.black_varnames is not None and _is_in_black_varnames(
op, amp_lists):
black_op_set.add(op)
continue
if op.type in amp_lists.black_list:
black_op_set.add(op)
elif op.type in amp_lists.white_list:
white_op_set.add(op)
elif op.type in amp_lists.gray_list:
is_black_op = False
is_white_op = False
for in_name in op.input_names:
# if this op has inputs
if in_name:
for in_var_name in op.input(in_name):
in_var = block.var(in_var_name)
# this in_var isn't the output of other op
if in_var.op is None:
continue
elif in_var.op is op:
prev_op = find_true_prev_op(ops, op, in_var_name)
if prev_op is None:
continue
else:
prev_op = in_var.op
# if it's one of inputs
if prev_op in black_op_set or \
prev_op.type in amp_lists.black_list:
is_black_op = True
elif prev_op in white_op_set or \
prev_op.type in amp_lists.white_list:
is_white_op = True
if is_black_op:
black_op_set.add(op)
elif is_white_op:
white_op_set.add(op)
else:
pass
else:
# For numerical safe, we apply fp32 computation on ops that
# are not determined which list they should stay.
black_op_set.add(op)
idx = 0
while idx < len(ops):
op = ops[idx]
num_cast_ops = 0
if op in black_op_set:
num_cast_ops = _insert_cast_op(block, op, idx,
core.VarDesc.VarType.FP16,
core.VarDesc.VarType.FP32)
elif op in white_op_set:
num_cast_ops = _insert_cast_op(block, op, idx,
core.VarDesc.VarType.FP32,
core.VarDesc.VarType.FP16)
else:
pass
idx += num_cast_ops + 1
def update_role_var_grad(main_prog, params_grads):
"""
Update op_role_var attr for some ops to make sure the gradients
transferred across GPUs is FP16.
1. Check whether the op that outputs gradient is cast or not.
2. If op is cast and gradient is FP32, remove the op_role_var
and find the prev op which outputs FP16 gradient
3. Update the op_role_var of the prev op.
Args:
main_prog (Program): The main program for training.
params_grads (list): A list of params and grads.
"""
block = main_prog.global_block()
BACKWARD = core.op_proto_and_checker_maker.OpRole.Backward
OPTIMIZE = core.op_proto_and_checker_maker.OpRole.Optimize
for p, g in params_grads:
op = g.op
if g.dtype == core.VarDesc.VarType.FP32 and op.type == 'cast':
role = op.attr('op_role')
if role & int(BACKWARD) and op.has_attr('op_role_var'):
op.desc.remove_attr("op_role_var")
else:
raise ValueError("The cast op {0} must be in BACKWARD role "
"and have op_role_var attr.".format(op))
fp16_grad_name = op.input(op.input_names[0])[0]
op_for_fp16_grad = find_true_prev_op(block.ops, op, fp16_grad_name)
op_role_var_attr_name = \
core.op_proto_and_checker_maker.kOpRoleVarAttrName()
attr_val = [p.name, fp16_grad_name]
if op_for_fp16_grad.has_attr(op_role_var_attr_name):
attr_val.extend(op_for_fp16_grad.attr(op_role_var_attr_name))
op_for_fp16_grad._set_attr(op_role_var_attr_name, attr_val)
# Maximize the all_reduce overlap, and perform the cast
# operation after gradients transfer.
op._set_attr('op_role', OPTIMIZE)
move_optimize_ops_back(block)
def update_loss_scaling(is_overall_finite, prev_loss_scaling, num_good_steps,
num_bad_steps, incr_every_n_steps,
decr_every_n_nan_or_inf, incr_ratio, decr_ratio):
"""
Update loss scaling according to overall gradients. If all gradients is
finite after incr_every_n_steps, loss scaling will increase by incr_ratio.
Otherwise, loss scaling will decrease by decr_ratio after
decr_every_n_nan_or_inf steps and each step some gradients are infinite.
Args:
is_overall_finite (Variable): A boolean variable indicates whether
all gradients are finite.
prev_loss_scaling (Variable): Previous loss scaling.
num_good_steps (Variable): A variable accumulates good steps in which
all gradients are finite.
num_bad_steps (Variable): A variable accumulates bad steps in which
some gradients are infinite.
incr_every_n_steps (Variable): A variable represents increasing loss
scaling every n consecutive steps with
finite gradients.
decr_every_n_nan_or_inf (Variable): A variable represents decreasing
loss scaling every n accumulated
steps with nan or inf gradients.
incr_ratio(float): The multiplier to use when increasing the loss
scaling.
decr_ratio(float): The less-than-one-multiplier to use when decreasing
loss scaling.
"""
zero_steps = layers.fill_constant(shape=[1], dtype='int32', value=0)
with layers.Switch() as switch:
with switch.case(is_overall_finite):
should_incr_loss_scaling = layers.less_than(incr_every_n_steps,
num_good_steps + 1)
with layers.Switch() as switch1:
with switch1.case(should_incr_loss_scaling):
new_loss_scaling = prev_loss_scaling * incr_ratio
loss_scaling_is_finite = layers.isfinite(new_loss_scaling)
with layers.Switch() as switch2:
with switch2.case(loss_scaling_is_finite):
layers.assign(new_loss_scaling, prev_loss_scaling)
with switch2.default():
pass
layers.assign(zero_steps, num_good_steps)
layers.assign(zero_steps, num_bad_steps)
with switch1.default():
layers.increment(num_good_steps)
layers.assign(zero_steps, num_bad_steps)
with switch.default():
should_decr_loss_scaling = layers.less_than(decr_every_n_nan_or_inf,
num_bad_steps + 1)
with layers.Switch() as switch3:
with switch3.case(should_decr_loss_scaling):
new_loss_scaling = prev_loss_scaling * decr_ratio
static_loss_scaling = \
layers.fill_constant(shape=[1],
dtype='float32',
value=1.0)
less_than_one = layers.less_than(new_loss_scaling,
static_loss_scaling)
with layers.Switch() as switch4:
with switch4.case(less_than_one):
layers.assign(static_loss_scaling,
prev_loss_scaling)
with switch4.default():
layers.assign(new_loss_scaling, prev_loss_scaling)
layers.assign(zero_steps, num_good_steps)
layers.assign(zero_steps, num_bad_steps)
with switch3.default():
layers.assign(zero_steps, num_good_steps)
layers.increment(num_bad_steps)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册