Unverified · Commit 045282fa authored by huangxu96, committed by GitHub

new usage of amp training. (#564)

* new usage of amp training.

* change the usage of amp and pure fp16 training.

* modified code according to review comments
Parent 1c1c9bb8
@@ -11,21 +11,23 @@ validate: True
 valid_interval: 1
 epochs: 120
 topk: 5
-image_shape: [3, 224, 224]
-is_distributed: True
-
-# mixed precision training
-use_amp: True
-use_pure_fp16: False
-multi_precision: False
-scale_loss: 128.0
-use_dynamic_loss_scaling: True
-
-data_format: "NCHW"
-image_shape: [3, 224, 224]
+is_distributed: False
+
+use_dali: True
+use_gpu: True
+data_format: "NHWC"
+image_channel: &image_channel 4
+image_shape: [*image_channel, 224, 224]
+
 use_mix: False
 ls_epsilon: -1
 
+# mixed precision training
+AMP:
+    scale_loss: 128.0
+    use_dynamic_loss_scaling: True
+    use_pure_fp16: &use_pure_fp16 True
+
 LEARNING_RATE:
     function: 'Piecewise'
     params:
@@ -37,6 +39,7 @@ OPTIMIZER:
     function: 'Momentum'
     params:
         momentum: 0.9
+        multi_precision: *use_pure_fp16
     regularizer:
         function: 'L2'
         factor: 0.000100
@@ -61,6 +64,8 @@ TRAIN:
                 mean: [0.485, 0.456, 0.406]
                 std: [0.229, 0.224, 0.225]
                 order: ''
+                output_fp16: *use_pure_fp16
+                channel_num: *image_channel
             - ToCHWImage:
 VALID:
......
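The YAML anchors introduced above (`&use_pure_fp16`, `&image_channel`) let one value drive the AMP block, the optimizer's `multi_precision` flag, and the preprocessing settings at the same time. A minimal sketch (not part of the commit, requires PyYAML) of how the aliases resolve when such a config is loaded:

```python
# Illustrative config fragment; keys mirror the YAML diff above.
import yaml

cfg_text = """
image_channel: &image_channel 4
image_shape: [*image_channel, 224, 224]

AMP:
    scale_loss: 128.0
    use_dynamic_loss_scaling: True
    use_pure_fp16: &use_pure_fp16 True

OPTIMIZER:
    function: 'Momentum'
    params:
        momentum: 0.9
        multi_precision: *use_pure_fp16
"""

cfg = yaml.safe_load(cfg_text)
# Aliases are resolved to plain values at load time:
print(cfg["image_shape"])                             # [4, 224, 224]
print(cfg["OPTIMIZER"]["params"]["multi_precision"])  # True
```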
@@ -195,14 +195,18 @@ class NormalizeImage(object):
     """ normalize image such as substract mean, divide std
     """
 
-    def __init__(self, scale=None, mean=None, std=None, order='chw'):
+    def __init__(self, scale=None, mean=None, std=None, order='chw', output_fp16=False, channel_num=3):
         if isinstance(scale, str):
             scale = eval(scale)
+        assert channel_num in [3, 4], "channel number of input image should be set to 3 or 4."
+        self.channel_num = channel_num
+        self.output_dtype = 'float16' if output_fp16 else 'float32'
         self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
+        self.order = order
         mean = mean if mean is not None else [0.485, 0.456, 0.406]
         std = std if std is not None else [0.229, 0.224, 0.225]
 
-        shape = (3, 1, 1) if order == 'chw' else (1, 1, 3)
+        shape = (3, 1, 1) if self.order == 'chw' else (1, 1, 3)
         self.mean = np.array(mean).reshape(shape).astype('float32')
         self.std = np.array(std).reshape(shape).astype('float32')
@@ -213,7 +217,16 @@ class NormalizeImage(object):
         assert isinstance(img,
                           np.ndarray), "invalid input 'img' in NormalizeImage"
-        return (img.astype('float32') * self.scale - self.mean) / self.std
+
+        img = (img.astype('float32') * self.scale - self.mean) / self.std
+
+        if self.channel_num == 4:
+            img_h = img.shape[1] if self.order == 'chw' else img.shape[0]
+            img_w = img.shape[2] if self.order == 'chw' else img.shape[1]
+            pad_zeros = np.zeros((1, img_h, img_w)) if self.order == 'chw' else np.zeros((img_h, img_w, 1))
+            img = (np.concatenate((img, pad_zeros), axis=0) if self.order == 'chw'
+                   else np.concatenate((img, pad_zeros), axis=2))
+        return img.astype(self.output_dtype)
 
 
 class ToCHWImage(object):
......
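For reference, a standalone numpy sketch of what the new `channel_num=4` / `output_fp16=True` path in `NormalizeImage` does for an HWC image: normalize, append a zero fourth channel (so the NHWC tensor feeds fp16 convolutions with a 4-channel input, matching `image_channel: 4` in the config), then cast to float16. Values and variable names here are illustrative, not taken from the repository.

```python
import numpy as np

img = np.random.randint(0, 256, size=(224, 224, 3)).astype('float32')  # HWC input
scale = 1.0 / 255.0
mean = np.array([0.485, 0.456, 0.406]).reshape((1, 1, 3)).astype('float32')
std = np.array([0.229, 0.224, 0.225]).reshape((1, 1, 3)).astype('float32')

# Same normalization as the operator.
img = (img * scale - mean) / std

# Pad a zero fourth channel, as the channel_num == 4 branch does for HWC order.
pad = np.zeros((img.shape[0], img.shape[1], 1), dtype=img.dtype)
img = np.concatenate((img, pad), axis=2)

print(img.shape, img.astype('float16').dtype)  # (224, 224, 4) float16
```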
@@ -277,6 +277,10 @@ class ResNet(nn.Layer):
                 bias_attr=ParamAttr(name="fc_0.b_0"))
 
     def forward(self, inputs):
+        with paddle.static.amp.fp16_guard():
+            if self.data_format == "NHWC":
+                inputs = paddle.tensor.transpose(inputs, [0, 2, 3, 1])
+                inputs.stop_gradient = True
             y = self.conv(inputs)
             y = self.pool2d_max(y)
             for block in self.block_list:
......
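The `fp16_guard` above only marks the ops created inside it; they are rewritten to fp16 later, when the optimizer is decorated with `use_fp16_guard=True` (see `mixed_precision_optimizer` below). A minimal static-graph sketch of the same pattern, with illustrative names and shapes:

```python
import paddle

paddle.enable_static()
main_prog, startup_prog = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    # Data is fed as NCHW float32 and transposed to NHWC inside the guard,
    # mirroring the forward() change in this commit.
    image = paddle.static.data(name='image', shape=[None, 4, 224, 224], dtype='float32')
    with paddle.static.amp.fp16_guard():
        x = paddle.transpose(image, [0, 2, 3, 1])
        x.stop_gradient = True
        y = paddle.static.nn.conv2d(x, num_filters=64, filter_size=7, stride=2,
                                    padding=3, data_format='NHWC')
```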
@@ -42,16 +42,13 @@ class Loss(object):
         soft_target = paddle.reshape(soft_target, shape=[-1, self._class_dim])
         return soft_target
 
-    def _crossentropy(self, input, target, use_pure_fp16=False):
+    def _crossentropy(self, input, target):
         if self._label_smoothing:
             target = self._labelsmoothing(target)
             input = -F.log_softmax(input, axis=-1)
             cost = paddle.sum(target * input, axis=-1)
         else:
             cost = F.cross_entropy(input=input, label=target)
-        if use_pure_fp16:
-            avg_cost = paddle.sum(cost)
-        else:
-            avg_cost = paddle.mean(cost)
+        avg_cost = paddle.mean(cost)
         return avg_cost
@@ -81,8 +78,8 @@ class CELoss(Loss):
     def __init__(self, class_dim=1000, epsilon=None):
         super(CELoss, self).__init__(class_dim, epsilon)
 
-    def __call__(self, input, target, use_pure_fp16=False):
-        cost = self._crossentropy(input, target, use_pure_fp16)
+    def __call__(self, input, target):
+        cost = self._crossentropy(input, target)
         return cost
@@ -94,13 +91,10 @@ class MixCELoss(Loss):
     def __init__(self, class_dim=1000, epsilon=None):
         super(MixCELoss, self).__init__(class_dim, epsilon)
 
-    def __call__(self, input, target0, target1, lam, use_pure_fp16=False):
-        cost0 = self._crossentropy(input, target0, use_pure_fp16)
-        cost1 = self._crossentropy(input, target1, use_pure_fp16)
+    def __call__(self, input, target0, target1, lam):
+        cost0 = self._crossentropy(input, target0)
+        cost1 = self._crossentropy(input, target1)
         cost = lam * cost0 + (1.0 - lam) * cost1
-        if use_pure_fp16:
-            avg_cost = paddle.sum(cost)
-        else:
-            avg_cost = paddle.mean(cost)
+        avg_cost = paddle.mean(cost)
         return avg_cost
......
@@ -74,19 +74,22 @@ class Momentum(object):
                  momentum,
                  parameter_list=None,
                  regularization=None,
+                 multi_precision=False,
                  **args):
         super(Momentum, self).__init__()
         self.learning_rate = learning_rate
         self.momentum = momentum
         self.parameter_list = parameter_list
         self.regularization = regularization
+        self.multi_precision = multi_precision
 
     def __call__(self):
         opt = paddle.optimizer.Momentum(
             learning_rate=self.learning_rate,
             momentum=self.momentum,
             parameters=self.parameter_list,
-            weight_decay=self.regularization)
+            weight_decay=self.regularization,
+            multi_precision=self.multi_precision)
         return opt
......
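`multi_precision=True` asks `paddle.optimizer.Momentum` to keep fp32 master weights for fp16 parameters, so the update is accumulated in fp32. A small dygraph sketch (setup is illustrative and not from the repository; with fp32 parameters the flag is a harmless no-op):

```python
import paddle

model = paddle.nn.Linear(16, 4)
opt = paddle.optimizer.Momentum(
    learning_rate=0.1,
    momentum=0.9,
    parameters=model.parameters(),
    weight_decay=paddle.regularizer.L2Decay(1e-4),
    multi_precision=True)  # keep fp32 master copies when params are fp16

x = paddle.randn([8, 16])
loss = paddle.mean(model(x))
loss.backward()
opt.step()
opt.clear_grad()
```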
@@ -176,7 +176,11 @@ def build(config, mode='train'):
         2: types.INTERP_CUBIC,  # cv2.INTER_CUBIC
         4: types.INTERP_LANCZOS3,  # XXX use LANCZOS3 for cv2.INTER_LANCZOS4
     }
-    output_dtype = types.FLOAT16 if config.get("use_pure_fp16", False) else types.FLOAT
+
+    output_dtype = (types.FLOAT16 if 'AMP' in config and
+                    config.AMP.get("use_pure_fp16", False)
+                    else types.FLOAT)
+
     assert interp in interp_map, "interpolation method not supported by DALI"
     interp = interp_map[interp]
     pad_output = False
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import paddle
import paddle.fluid as fluid
import paddle.regularizer as regularizer
__all__ = ['OptimizerBuilder']
class L1Decay(object):
    """
    L1 Weight Decay Regularization, which encourages the weights to be sparse.
    Args:
        factor(float): regularization coeff. Default: 0.0.
    """

    def __init__(self, factor=0.0):
        super(L1Decay, self).__init__()
        self.factor = factor

    def __call__(self):
        reg = regularizer.L1Decay(self.factor)
        return reg


class L2Decay(object):
    """
    L2 Weight Decay Regularization, which helps prevent the model weights from growing too large.
    Args:
        factor(float): regularization coeff. Default: 0.0.
    """

    def __init__(self, factor=0.0):
        super(L2Decay, self).__init__()
        self.factor = factor

    def __call__(self):
        reg = regularizer.L2Decay(self.factor)
        return reg
class Momentum(object):
    """
    Simple Momentum optimizer with velocity state.
    Args:
        learning_rate (float|Variable) - The learning rate used to update parameters.
            Can be a float value or a Variable with one float value as data element.
        momentum (float) - Momentum factor.
        regularization (WeightDecayRegularizer, optional) - The strategy of regularization.
    """

    def __init__(self,
                 learning_rate,
                 momentum,
                 parameter_list=None,
                 regularization=None,
                 config=None,
                 **args):
        super(Momentum, self).__init__()
        self.learning_rate = learning_rate
        self.momentum = momentum
        self.parameter_list = parameter_list
        self.regularization = regularization
        self.multi_precision = config.get('multi_precision', False)
        self.rescale_grad = (1.0 / (config['TRAIN']['batch_size'] / len(fluid.cuda_places()))
                             if config.get('use_pure_fp16', False) else 1.0)

    def __call__(self):
        opt = fluid.contrib.optimizer.Momentum(
            learning_rate=self.learning_rate,
            momentum=self.momentum,
            regularization=self.regularization,
            multi_precision=self.multi_precision,
            rescale_grad=self.rescale_grad)
        return opt
class RMSProp(object):
    """
    Root Mean Squared Propagation (RMSProp) is an unpublished, adaptive learning rate method.
    Args:
        learning_rate (float|Variable) - The learning rate used to update parameters.
            Can be a float value or a Variable with one float value as data element.
        momentum (float) - Momentum factor.
        rho (float) - rho value in equation.
        epsilon (float) - avoid division by zero, default is 1e-6.
        regularization (WeightDecayRegularizer, optional) - The strategy of regularization.
    """

    def __init__(self,
                 learning_rate,
                 momentum,
                 rho=0.95,
                 epsilon=1e-6,
                 parameter_list=None,
                 regularization=None,
                 **args):
        super(RMSProp, self).__init__()
        self.learning_rate = learning_rate
        self.momentum = momentum
        self.rho = rho
        self.epsilon = epsilon
        self.parameter_list = parameter_list
        self.regularization = regularization

    def __call__(self):
        opt = paddle.optimizer.RMSProp(
            learning_rate=self.learning_rate,
            momentum=self.momentum,
            rho=self.rho,
            epsilon=self.epsilon,
            parameters=self.parameter_list,
            weight_decay=self.regularization)
        return opt
class OptimizerBuilder(object):
    """
    Build optimizer
    Args:
        function(str): optimizer name of learning rate
        params(dict): parameters used for init the class
        regularizer (dict): parameters used for create regularization
    """

    def __init__(self,
                 config=None,
                 function='Momentum',
                 params={'momentum': 0.9},
                 regularizer=None):
        self.function = function
        self.params = params
        self.config = config
        # create regularizer
        if regularizer is not None:
            mod = sys.modules[__name__]
            reg_func = regularizer['function'] + 'Decay'
            del regularizer['function']
            reg = getattr(mod, reg_func)(**regularizer)()
            self.params['regularization'] = reg

    def __call__(self, learning_rate, parameter_list=None):
        mod = sys.modules[__name__]
        opt = getattr(mod, self.function)
        return opt(learning_rate=learning_rate,
                   parameter_list=parameter_list,
                   config=self.config,
                   **self.params)()
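The `rescale_grad` formula in the static `Momentum` above compensates for the pure-fp16 path that summed the per-card loss instead of averaging it (that sum path is removed elsewhere in this commit, so the new flow always averages). A worked example with illustrative numbers:

```python
# Worked example of rescale_grad = 1.0 / (batch_size / num_gpus).
total_batch_size = 256   # config['TRAIN']['batch_size']
num_gpus = 8             # len(fluid.cuda_places())
per_card_batch = total_batch_size / num_gpus   # 32 samples per card
rescale_grad = 1.0 / per_card_batch            # 1/32 = 0.03125

# sum(loss_i) * rescale_grad equals mean(loss_i) over the per-card batch,
# so the effective gradient matches the averaged-loss setup.
print(rescale_grad)
```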
@@ -21,12 +21,10 @@ import time
 import numpy as np
 from collections import OrderedDict
 
-from optimizer import OptimizerBuilder
+from ppcls.optimizer import OptimizerBuilder
 
 import paddle
 import paddle.nn.functional as F
-from paddle import fluid
-from paddle.fluid.contrib.mixed_precision.fp16_utils import cast_model_to_fp16
 
 from ppcls.optimizer.learning_rate import LearningRateBuilder
 from ppcls.modeling import architectures
@@ -83,11 +81,9 @@ def create_model(architecture, image, classes_num, config, is_train):
     Returns:
         out(variable): model output variable
     """
-    use_pure_fp16 = config.get("use_pure_fp16", False)
     name = architecture["name"]
     params = architecture.get("params", {})
 
-    data_format = "NCHW"
     if "data_format" in config:
         params["data_format"] = config["data_format"]
         data_format = config["data_format"]
@@ -101,15 +97,7 @@ def create_model(architecture, image, classes_num, config, is_train):
         params['is_test'] = not is_train
     model = architectures.__dict__[name](class_dim=classes_num, **params)
 
-    if use_pure_fp16 and not config.get("use_dali", False):
-        image = image.astype('float16')
-    if data_format == "NHWC":
-        image = paddle.tensor.transpose(image, [0, 2, 3, 1])
-        image.stop_gradient = True
     out = model(image)
-    if config.get("use_pure_fp16", False):
-        cast_model_to_fp16(paddle.static.default_main_program())
-        out = out.astype('float32')
     return out
@@ -119,8 +107,7 @@ def create_loss(out,
                 classes_num=1000,
                 epsilon=None,
                 use_mix=False,
-                use_distillation=False,
-                use_pure_fp16=False):
+                use_distillation=False):
     """
     Create a loss for optimization, such as:
         1. CrossEnotry loss
@@ -137,7 +124,6 @@ def create_loss(out,
         classes_num(int): num of classes
         epsilon(float): parameter for label smoothing, 0.0 <= epsilon <= 1.0
         use_mix(bool): whether to use mix(include mixup, cutmix, fmix)
-        use_pure_fp16(bool): whether to use pure fp16 data as training parameter
     Returns:
         loss(variable): loss variable
@@ -162,10 +148,10 @@ def create_loss(out,
     if use_mix:
         loss = MixCELoss(class_dim=classes_num, epsilon=epsilon)
-        return loss(out, feed_y_a, feed_y_b, feed_lam, use_pure_fp16)
+        return loss(out, feed_y_a, feed_y_b, feed_lam)
     else:
         loss = CELoss(class_dim=classes_num, epsilon=epsilon)
-        return loss(out, target, use_pure_fp16)
+        return loss(out, target)
 
 
 def create_metric(out,
@@ -239,9 +225,8 @@ def create_fetchs(out,
         fetchs(dict): dict of model outputs(included loss and measures)
     """
     fetchs = OrderedDict()
-    use_pure_fp16 = config.get("use_pure_fp16", False)
     loss = create_loss(out, feeds, architecture, classes_num, epsilon, use_mix,
-                       use_distillation, use_pure_fp16)
+                       use_distillation)
     fetchs['loss'] = (loss, AverageMeter('loss', '7.4f', need_avg=True))
     if not use_mix:
         metric = create_metric(out, feeds, architecture, topk, classes_num,
@@ -285,7 +270,7 @@ def create_optimizer(config):
     # create optimizer instance
     opt_config = config['OPTIMIZER']
-    opt = OptimizerBuilder(config, **opt_config)
+    opt = OptimizerBuilder(**opt_config)
     return opt(lr), lr
@@ -304,11 +289,11 @@ def create_strategy(config):
     exec_strategy = paddle.static.ExecutionStrategy()
     exec_strategy.num_threads = 1
-    exec_strategy.num_iteration_per_drop_scope = 10000 if config.get(
-        'use_pure_fp16', False) else 10
+    exec_strategy.num_iteration_per_drop_scope = (10000 if 'AMP' in config and
+        config.AMP.get("use_pure_fp16", False) else 10)
 
-    fuse_op = config.get('use_amp', False) or config.get('use_pure_fp16',
-                                                         False)
+    fuse_op = True if 'AMP' in config else False
     fuse_bn_act_ops = config.get('fuse_bn_act_ops', fuse_op)
     fuse_elewise_add_act_ops = config.get('fuse_elewise_add_act_ops', fuse_op)
     fuse_bn_add_act_ops = config.get('fuse_bn_add_act_ops', fuse_op)
@@ -369,14 +354,17 @@ def dist_optimizer(config, optimizer):
 
 def mixed_precision_optimizer(config, optimizer):
-    use_amp = config.get('use_amp', False)
-    scale_loss = config.get('scale_loss', 1.0)
-    use_dynamic_loss_scaling = config.get('use_dynamic_loss_scaling', False)
-    if use_amp:
-        optimizer = fluid.contrib.mixed_precision.decorate(
+    if 'AMP' in config:
+        amp_cfg = config.AMP if config.AMP else dict()
+        scale_loss = amp_cfg.get('scale_loss', 1.0)
+        use_dynamic_loss_scaling = amp_cfg.get('use_dynamic_loss_scaling', False)
+        use_pure_fp16 = amp_cfg.get('use_pure_fp16', False)
+        optimizer = paddle.static.amp.decorate(
             optimizer,
             init_loss_scaling=scale_loss,
-            use_dynamic_loss_scaling=use_dynamic_loss_scaling)
+            use_dynamic_loss_scaling=use_dynamic_loss_scaling,
+            use_pure_fp16=use_pure_fp16,
+            use_fp16_guard=True)
 
     return optimizer
@@ -407,15 +395,11 @@ def build(config, main_prog, startup_prog, is_train=True, is_distributed=True):
         use_dali = config.get('use_dali', False)
         use_distillation = config.get('use_distillation')
 
-        image_dtype = "float32"
-        if config["ARCHITECTURE"]["name"] == "ResNet50" and config.get("use_pure_fp16", False) \
-                and config.get("use_dali", False):
-            image_dtype = "float16"
         feeds = create_feeds(
             config.image_shape,
             use_mix=use_mix,
            use_dali=use_dali,
-            dtype=image_dtype)
+            dtype="float32")
         if use_dali and use_mix:
             import dali
             feeds = dali.mix(feeds, config, is_train)
@@ -432,13 +416,14 @@ def build(config, main_prog, startup_prog, is_train=True, is_distributed=True):
             config=config,
             use_distillation=use_distillation)
 
         lr_scheduler = None
+        optimizer = None
         if is_train:
             optimizer, lr_scheduler = create_optimizer(config)
             optimizer = mixed_precision_optimizer(config, optimizer)
             if is_distributed:
                 optimizer = dist_optimizer(config, optimizer)
             optimizer.minimize(fetchs['loss'][0])
-    return fetchs, lr_scheduler, feeds
+    return fetchs, lr_scheduler, feeds, optimizer
 
 
 def compile(config, program, loss_name=None, share_prog=None):
......
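Putting the `program.py` and `train.py` changes together, the new flow is: build the network inside `fp16_guard`, wrap the optimizer with `paddle.static.amp.decorate` using the `AMP` config keys, minimize, run the startup program, then call `amp_init` to cast the initialized parameters when `use_pure_fp16` is on. A condensed, self-contained sketch (toy network, illustrative values, assumes a GPU build of Paddle):

```python
import paddle
import paddle.nn.functional as F

paddle.enable_static()
train_prog, startup_prog = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(train_prog, startup_prog):
    image = paddle.static.data(name='image', shape=[None, 4, 224, 224], dtype='float32')
    label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
    with paddle.static.amp.fp16_guard():           # only this region is cast to fp16
        x = paddle.flatten(image, 1)
        logits = paddle.static.nn.fc(x, size=10)
    loss = paddle.mean(F.cross_entropy(logits, label))

    opt = paddle.optimizer.Momentum(learning_rate=0.1, momentum=0.9, multi_precision=True)
    opt = paddle.static.amp.decorate(
        opt,
        init_loss_scaling=128.0,          # AMP.scale_loss
        use_dynamic_loss_scaling=True,    # AMP.use_dynamic_loss_scaling
        use_pure_fp16=True,               # AMP.use_pure_fp16
        use_fp16_guard=True)
    opt.minimize(loss)

place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
exe.run(startup_prog)
# With pure fp16, cast the freshly initialized fp32 parameters to fp16 before
# training, which is what train.py now does via optimizer.amp_init().
opt.amp_init(place, scope=paddle.static.global_scope())
```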
@@ -26,8 +26,6 @@ sys.path.append(os.path.abspath(os.path.join(__dir__, '../../')))
 from sys import version_info
 
 import paddle
-import paddle.fluid as fluid
-from paddle.fluid.contrib.mixed_precision.fp16_utils import cast_parameters_to_fp16
 from paddle.distributed import fleet
 from ppcls.data import Reader
@@ -67,9 +65,7 @@ def main(args):
     # assign the place
     use_gpu = config.get("use_gpu", True)
     # amp related config
-    use_amp = config.get('use_amp', False)
-    use_pure_fp16 = config.get('use_pure_fp16', False)
-    if use_amp or use_pure_fp16:
+    if 'AMP' in config:
         AMP_RELATED_FLAGS_SETTING = {
             'FLAGS_cudnn_exhaustive_search': 1,
             'FLAGS_conv_workspace_size_limit': 4000,
@@ -97,7 +93,7 @@ def main(args):
     best_top1_acc = 0.0  # best top1 acc record
 
-    train_fetchs, lr_scheduler, train_feeds = program.build(
+    train_fetchs, lr_scheduler, train_feeds, optimizer = program.build(
         config,
         train_prog,
         startup_prog,
@@ -106,7 +102,7 @@ def main(args):
     if config.validate:
         valid_prog = paddle.static.Program()
-        valid_fetchs, _, valid_feeds = program.build(
+        valid_fetchs, _, valid_feeds, _ = program.build(
             config,
             valid_prog,
             startup_prog,
@@ -119,11 +115,14 @@ def main(args):
     exe = paddle.static.Executor(place)
     # Parameter initialization
     exe.run(startup_prog)
-    if config.get("use_pure_fp16", False):
-        cast_parameters_to_fp16(place, train_prog, fluid.global_scope())
     # load pretrained models or checkpoints
     init_model(config, train_prog, exe)
 
+    if 'AMP' in config and config.AMP.get("use_pure_fp16", False):
+        optimizer.amp_init(place,
+                           scope=paddle.static.global_scope(),
+                           test_program=valid_prog if config.validate else None)
+
     if not config.get("is_distributed", True) and not use_xpu:
         compiled_train_prog = program.compile(
             config, train_prog, loss_name=train_fetchs["loss"][0].name)
......