Commit d251bf70 authored by zhiqiu

enable addto with amp

Parent a00c8afb
......@@ -3,8 +3,9 @@
export FLAGS_conv_workspace_size_limit=4000 #MB
export FLAGS_cudnn_exhaustive_search=1
export FLAGS_cudnn_batchnorm_spatial_persistent=1
export FLAGS_max_inplace_grad_add=8
DATA_DIR="Your image dataset path, e.g. /work/datasets/ILSVRC2012/"
DATA_DIR="./data/ILSVRC2012/"
DATA_FORMAT="NHWC"
USE_FP16=true #whether to use float16
......@@ -17,7 +18,7 @@ fi
python train.py \
--model=ResNet50 \
--data_dir=${DATA_DIR} \
--batch_size=256 \
--batch_size=128 \
--total_images=1281167 \
--image_shape 3 224 224 \
--class_dim=1000 \
......@@ -36,5 +37,7 @@ python train.py \
--reader_thread=10 \
--reader_buf_size=4000 \
--use_dali=${USE_DALI} \
--fuse_all_optimizer_ops=true \
--enable_addto=true \
--lr=0.1
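The run script now exports FLAGS_max_inplace_grad_add=8 and forwards --enable_addto=true and --fuse_all_optimizer_ops=true to train.py. FLAGS_max_inplace_grad_add works together with the addto build strategy; it appears to cap how many gradient accumulations may be performed as in-place add operations instead of a separate sum op. The same flags can also be set from Python rather than the shell, provided that happens before paddle.fluid is imported. A minimal sketch (not part of this commit; the values simply mirror the script above):

# Minimal sketch: set Paddle FLAGS_* values from Python instead of the shell.
# Paddle picks these up from the environment when the framework initializes,
# so they must be set before paddle.fluid is imported.
import os

os.environ["FLAGS_max_inplace_grad_add"] = "8"      # cap for in-place gradient adds
os.environ["FLAGS_cudnn_exhaustive_search"] = "1"

import paddle.fluid as fluid  # imported only after the flags are in place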
......@@ -244,8 +244,16 @@ def train(args):
t1 = time.time()
for batch in train_iter:
#NOTE: this is for benchmark
if args.max_iter and total_batch_num == args.max_iter:
if total_batch_num == 200:
#if args.max_iter and total_batch_num == args.max_iter:
print("=" *20)
print("total_batch_num: ", total_batch_num, "records_num: ", len(train_batch_time_record))
avg_times = sum(train_batch_time_record[-150:]) / 150
avg_speed = args.batch_size / avg_times
print("average time: %.5f s/batch, average speed: %.5f imgs/s" % (avg_times, avg_speed))
return
#if args.max_iter and total_batch_num == args.max_iter:
# return
train_batch_metrics = exe.run(compiled_train_prog,
feed=batch,
fetch_list=train_fetch_list)
......
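The benchmark branch in train() now stops unconditionally after 200 batches and reports throughput as the batch size divided by the mean per-batch time of the last 150 recorded iterations. A self-contained sketch of that arithmetic (the function name and arguments below are illustrative, not from the repository):

# Illustrative helper mirroring the benchmark math above: average the most
# recent per-batch wall times and convert them to images per second.
def report_throughput(batch_times, batch_size, last_n=150):
    window = batch_times[-last_n:]
    avg_time = sum(window) / len(window)   # seconds per batch
    avg_speed = batch_size / avg_time      # images per second
    print("average time: %.5f s/batch, average speed: %.5f imgs/s"
          % (avg_time, avg_speed))
    return avg_time, avg_speed

For example, with batch_size=128 and an average of 0.256 s per batch this reports 500 imgs/s.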
......@@ -162,6 +162,8 @@ def parse_args():
add_arg('profiler_path', str, './profilier_files', "the profiler output file path")
add_arg('max_iter', int, 0, "the max train batch num")
add_arg('same_feed', int, 0, "whether to feed same images")
add_arg('enable_addto', bool, False, "whether to enable the addto strategy for gradient accumulation")
add_arg('fuse_all_optimizer_ops', bool, False, "whether to fuse all optimizer operators")
# yapf: enable
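The two new options reuse the repository's add_arg helper with its existing (name, type, default, help) call pattern. A hedged sketch of how such a helper is commonly built on argparse; the actual implementation lives in the repo's utility module and may differ, so everything below is an assumption:

# Hypothetical add_arg for illustration only; the repository defines its own.
# Booleans are converted from strings so that --enable_addto=true / false both
# work (plain bool() would treat any non-empty string, even "false", as True).
import argparse
from distutils.util import strtobool

parser = argparse.ArgumentParser()

def add_arg(name, dtype, default, help_text):
    converter = (lambda v: bool(strtobool(v))) if dtype == bool else dtype
    parser.add_argument("--" + name, type=converter, default=default, help=help_text)

add_arg('enable_addto', bool, False, "whether to enable the addto strategy")
add_arg('fuse_all_optimizer_ops', bool, False, "whether to fuse all optimizer operators")

args = parser.parse_args(["--enable_addto=true"])
print(args.enable_addto)   # True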
......@@ -524,6 +526,8 @@ def best_strategy_compiled(args,
try:
fluid.require_version(min_version='1.7.0')
build_strategy.fuse_bn_act_ops = args.fuse_bn_act_ops
build_strategy.fuse_all_optimizer_ops = args.fuse_all_optimizer_ops
build_strategy.enable_addto = args.enable_addto
except Exception as e:
logger.info("PaddlePaddle version 1.7.0 or higher is "
"required when you want to fuse batch_norm and activation_op.")
......
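best_strategy_compiled copies the two new switches onto the BuildStrategy next to fuse_bn_act_ops, inside the same require_version guard. A minimal sketch of how such a strategy is handed to the executor with the fluid static-graph API used elsewhere in train.py; the tiny network below is a placeholder, not the ResNet50 model from this repo:

# Minimal sketch: build a toy static-graph program, enable the new
# build-strategy options, and compile it for data-parallel execution.
import paddle.fluid as fluid

train_prog = fluid.Program()
startup_prog = fluid.Program()
with fluid.program_guard(train_prog, startup_prog):
    x = fluid.data(name='x', shape=[None, 8], dtype='float32')
    y = fluid.data(name='y', shape=[None, 1], dtype='float32')
    pred = fluid.layers.fc(input=x, size=1)
    avg_cost = fluid.layers.mean(fluid.layers.square_error_cost(pred, y))
    fluid.optimizer.SGD(learning_rate=0.01).minimize(avg_cost)

build_strategy = fluid.BuildStrategy()
build_strategy.fuse_bn_act_ops = True           # fuse batch_norm + activation
build_strategy.fuse_all_optimizer_ops = True    # fuse same-type optimizer ops
build_strategy.enable_addto = True              # in-place addto gradient accumulation

compiled_train_prog = fluid.CompiledProgram(train_prog).with_data_parallel(
    loss_name=avg_cost.name, build_strategy=build_strategy)

with_data_parallel is the same call that produces the compiled_train_prog run in the training loop above, which is how the new options take effect without further changes to that loop.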