Commit 832f1c74 authored by Zhen Wang

add pure fp16 training.

Parent 08f3c0be
@@ -14,7 +14,7 @@
 import paddle
 import paddle.fluid as fluid
 import utils.utility as utility
+from paddle.fluid.contrib.mixed_precision.fp16_utils import cast_net_to_fp16

 def _calc_label_smoothing_loss(softmax_out, label, class_dim, epsilon):
     """Calculate label smoothing loss
@@ -44,7 +44,12 @@ def _basic_model(data, model, args, is_train):
             data_format=args.data_format)
     else:
         net_out = model.net(input=image, class_dim=args.class_dim)
-    softmax_out = fluid.layers.softmax(net_out, use_cudnn=False)
+    if args.use_pure_fp16:
+        cast_net_to_fp16(fluid.default_main_program())
+        net_out_fp32 = fluid.layers.cast(x=net_out, dtype="float32")
+        softmax_out = fluid.layers.softmax(net_out_fp32, use_cudnn=False)
+    else:
+        softmax_out = fluid.layers.softmax(net_out, use_cudnn=False)

     if is_train and args.use_label_smoothing:
         cost = _calc_label_smoothing_loss(softmax_out, label, args.class_dim,
@@ -104,7 +109,12 @@ def _mixup_model(data, model, args, is_train):
             data_format=args.data_format)
     else:
         net_out = model.net(input=image, class_dim=args.class_dim)
-    softmax_out = fluid.layers.softmax(net_out, use_cudnn=False)
+    if args.use_pure_fp16:
+        cast_net_to_fp16(fluid.default_main_program())
+        net_out_fp32 = fluid.layers.cast(x=net_out, dtype="float32")
+        softmax_out = fluid.layers.softmax(net_out_fp32, use_cudnn=False)
+    else:
+        softmax_out = fluid.layers.softmax(net_out, use_cudnn=False)
     if not args.use_label_smoothing:
         loss_a = fluid.layers.cross_entropy(input=softmax_out, label=y_a)
         loss_b = fluid.layers.cross_entropy(input=softmax_out, label=y_b)
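Note: the hunks above (apparently build_model.py, given the import in train.py below) follow one pattern: build the network first, cast the whole default main program to float16, then route the logits back through float32 so softmax and cross-entropy stay numerically stable. A minimal standalone sketch, assuming the fluid 1.x APIs this commit uses; the fc layer is a hypothetical stand-in for model.net:

import paddle.fluid as fluid
from paddle.fluid.contrib.mixed_precision.fp16_utils import cast_net_to_fp16

image = fluid.data(name='image', shape=[None, 3, 224, 224], dtype='float32')
label = fluid.data(name='label', shape=[None, 1], dtype='int64')
# Hypothetical stand-in for model.net(input=image, class_dim=...).
net_out = fluid.layers.fc(input=image, size=1000)

# Cast every op/variable in the default main program to float16.
cast_net_to_fp16(fluid.default_main_program())

# Keep the loss path in float32 for numerical stability.
net_out_fp32 = fluid.layers.cast(x=net_out, dtype="float32")
softmax_out = fluid.layers.softmax(net_out_fp32, use_cudnn=False)
cost = fluid.layers.mean(
    fluid.layers.cross_entropy(input=softmax_out, label=label))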
@@ -43,7 +43,8 @@ class HybridTrainPipe(Pipeline):
                  num_shards=1,
                  random_shuffle=True,
                  num_threads=4,
-                 seed=42):
+                 seed=42,
+                 output_dtype=types.FLOAT):
         super(HybridTrainPipe, self).__init__(
             batch_size, num_threads, device_id, seed=seed)
         self.input = ops.FileReader(
@@ -68,7 +69,7 @@ class HybridTrainPipe(Pipeline):
             device='gpu', resize_x=crop, resize_y=crop, interp_type=interp)
         self.cmnp = ops.CropMirrorNormalize(
             device="gpu",
-            output_dtype=types.FLOAT,
+            output_dtype=output_dtype,
             output_layout=types.NCHW,
             crop=(crop, crop),
             image_type=types.RGB,
@@ -104,7 +105,8 @@ class HybridValPipe(Pipeline):
                  num_shards=1,
                  random_shuffle=False,
                  num_threads=4,
-                 seed=42):
+                 seed=42,
+                 output_dtype=types.FLOAT):
         super(HybridValPipe, self).__init__(
             batch_size, num_threads, device_id, seed=seed)
         self.input = ops.FileReader(
@@ -118,7 +120,7 @@ class HybridValPipe(Pipeline):
             device="gpu", resize_shorter=resize_shorter, interp_type=interp)
         self.cmnp = ops.CropMirrorNormalize(
             device="gpu",
-            output_dtype=types.FLOAT,
+            output_dtype=output_dtype,
             output_layout=types.NCHW,
             crop=(crop, crop),
             image_type=types.RGB,
@@ -159,6 +161,7 @@ def build(settings, mode='train'):
     min_area = settings.lower_scale
     lower = settings.lower_ratio
     upper = settings.upper_ratio
+    output_dtype = types.FLOAT16 if settings.use_pure_fp16 else types.FLOAT

     interp = settings.interpolation or 1  # default to linear
     interp_map = {
@@ -188,7 +191,8 @@ def build(settings, mode='train'):
             interp,
             mean,
             std,
-            device_id=device_id)
+            device_id=device_id,
+            output_dtype=output_dtype)
         pipe.build()
         return DALIGenericIterator(
             pipe, ['feed_image', 'feed_label'],
@@ -221,7 +225,8 @@ def build(settings, mode='train'):
             device_id,
             shard_id,
             num_shards,
-            seed=42 + shard_id)
+            seed=42 + shard_id,
+            output_dtype=output_dtype)
         pipe.build()
         pipelines = [pipe]
         sample_per_shard = len(pipe) // num_shards
@@ -248,7 +253,8 @@ def build(settings, mode='train'):
                 device_id,
                 idx,
                 num_shards,
-                seed=42 + idx)
+                seed=42 + idx,
+                output_dtype=output_dtype)
             pipe.build()
             pipelines.append(pipe)
         sample_per_shard = len(pipelines[0])
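Note: the net effect of the four reader hunks above is that both DALI pipelines can emit float16 batches directly from GPU preprocessing, so pure-FP16 training feeds the network without an extra host-side cast. A minimal sketch of the dtype selection and the normalize op, assuming the pre-1.0 DALI ops API used in this file; the crop size and ImageNet mean/std constants are illustrative, not from the diff:

from nvidia.dali import ops, types

use_pure_fp16 = True  # stands in for settings.use_pure_fp16
output_dtype = types.FLOAT16 if use_pure_fp16 else types.FLOAT

cmnp = ops.CropMirrorNormalize(
    device="gpu",
    output_dtype=output_dtype,  # FLOAT16 batches under pure fp16 training
    output_layout=types.NCHW,
    crop=(224, 224),
    image_type=types.RGB,
    mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
    std=[0.229 * 255, 0.224 * 255, 0.225 * 255])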
@@ -7,7 +7,8 @@ export FLAGS_cudnn_batchnorm_spatial_persistent=1
 DATA_DIR="Your image dataset path, e.g. /work/datasets/ILSVRC2012/"
 DATA_FORMAT="NHWC"
-USE_FP16=true  # whether to use float16
+USE_AMP=true  # whether to use automatic mixed precision
+USE_PURE_FP16=false  # whether to cast the whole network to float16
 USE_DALI=true

 if ${USE_DALI}; then
@@ -24,7 +25,8 @@ python train.py \
     --print_step=10 \
     --model_save_dir=output/ \
     --lr_strategy=piecewise_decay \
-    --use_fp16=${USE_FP16} \
+    --use_amp=${USE_AMP} \
+    --use_pure_fp16=${USE_PURE_FP16} \
     --scale_loss=128.0 \
     --use_dynamic_loss_scaling=true \
     --data_format=${DATA_FORMAT} \
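Note: an illustrative way to switch this script from AMP to pure FP16 with the renamed flags; the model and batch-size values are placeholders, not from the diff:

USE_AMP=false
USE_PURE_FP16=true

python train.py \
    --model=ResNet50 \
    --batch_size=256 \
    --use_amp=${USE_AMP} \
    --use_pure_fp16=${USE_PURE_FP16} \
    --scale_loss=128.0 \
    --use_dynamic_loss_scaling=true \
    --data_format=NHWC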
@@ -29,6 +29,7 @@ import reader
 from utils import *
 import models
 from build_model import create_model
+from paddle.fluid.contrib.mixed_precision.fp16_utils import cast_parameters_to_fp16

 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
@@ -72,7 +73,7 @@ def build_program(is_train, main_prog, startup_prog, args):
             global_lr.persistable = True
             loss_out.append(global_lr)

-            if args.use_fp16:
+            if args.use_amp:
                 optimizer = fluid.contrib.mixed_precision.decorate(
                     optimizer,
                     init_loss_scaling=args.scale_loss,
@@ -192,6 +193,8 @@ def train(args):
     place = fluid.CUDAPlace(gpu_id) if args.use_gpu else fluid.CPUPlace()
     exe = fluid.Executor(place)
     exe.run(startup_prog)
+    if args.use_pure_fp16:
+        cast_parameters_to_fp16(exe, train_prog)

     trainer_id = int(os.getenv("PADDLE_TRAINER_ID", 0))
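Note: the ordering in the train.py hunk matters: parameters are created in float32 when the startup program runs, and only then cast in place to match the fp16 program. A minimal sketch of the pure-FP16 startup sequence, assuming the fluid 1.x APIs used in this commit (network construction elided):

import paddle.fluid as fluid
from paddle.fluid.contrib.mixed_precision.fp16_utils import (
    cast_net_to_fp16, cast_parameters_to_fp16)

startup_prog = fluid.Program()
train_prog = fluid.Program()
# ... build the network into train_prog, then cast it:
# cast_net_to_fp16(train_prog)

place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
exe.run(startup_prog)  # parameters are still initialized in float32 here

# One-time in-place cast so the parameters match the fp16-cast program.
cast_parameters_to_fp16(exe, train_prog)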
@@ -139,7 +139,8 @@ def parse_args():
     # SWITCH
     add_arg('validate', bool, True, "whether to validate when training.")
-    add_arg('use_fp16', bool, False, "Whether to enable half precision training with fp16.")
+    add_arg('use_amp', bool, False, "Whether to enable mixed precision training with fp16.")
+    add_arg('use_pure_fp16', bool, False, "Whether to enable pure half precision training, i.e. casting the whole network to fp16.")
     add_arg('scale_loss', float, 1.0, "The value of scale_loss for fp16.")
     add_arg('use_dynamic_loss_scaling', bool, True, "Whether to use dynamic loss scaling.")
     add_arg('data_format', str, "NCHW", "Tensor data format when training.")
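Note: use_amp and use_pure_fp16 select alternative precision modes: the former keeps the pre-existing loss-scaling AMP decorator, the latter triggers the new whole-network cast. A hedged sketch of the dispatch implied by the hunks above; the helper name and the use_dynamic_loss_scaling argument to decorate are assumptions, not shown in the diff:

import paddle.fluid as fluid
from paddle.fluid.contrib.mixed_precision.fp16_utils import cast_net_to_fp16

def apply_precision_strategy(args, optimizer, main_prog):
    # Hypothetical helper summarizing the two modes.
    if args.use_amp:
        # AMP path: loss scaling around a decorated optimizer.
        optimizer = fluid.contrib.mixed_precision.decorate(
            optimizer,
            init_loss_scaling=args.scale_loss,
            use_dynamic_loss_scaling=args.use_dynamic_loss_scaling)
    elif args.use_pure_fp16:
        # Pure fp16 path: cast the whole program; parameters are cast
        # later, after the startup program runs (see the train.py hunk).
        cast_net_to_fp16(main_prog)
    return optimizer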