未验证 提交 c1016c8b 编写于 作者: R ruri 提交者: GitHub

Merge pull request #1705 from shippingwang/dev_sc

refine image classification train script
......@@ -209,6 +209,7 @@ Models are trained by starting with learning rate ```0.1``` and decaying it by `
|[VGG16](https://paddle-imagenet-models-name.bj.bcebos.com/VGG16_pretrained.zip) | 72.08%/90.63% | 71.65%/90.57% |
|[VGG19](https://paddle-imagenet-models-name.bj.bcebos.com/VGG19_pretrained.zip) | 72.56%/90.83% | 72.32%/90.98% |
|[MobileNetV1](http://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV1_pretrained.zip) | 70.91%/89.54% | 70.51%/89.35% |
|[MobileNetV2](https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV2_pretrained.zip) | 71.90%/90.55% | 71.53%/90.41% |
|[ResNet50](http://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_pretrained.zip) | 76.35%/92.80% | 76.22%/92.92% |
|[ResNet101](http://paddle-imagenet-models-name.bj.bcebos.com/ResNet101_pretrained.zip) | 77.49%/93.57% | 77.56%/93.64% |
|[ResNet152](https://paddle-imagenet-models-name.bj.bcebos.com/ResNet152_pretrained.zip) | 78.12%/93.93% | 77.92%/93.87% |
......
......@@ -204,6 +204,7 @@ Models包括两种模型:带有参数名字的模型,和不带有参数名
|[VGG16](https://paddle-imagenet-models-name.bj.bcebos.com/VGG16_pretrained.zip) | 72.08%/90.63% | 71.65%/90.57% |
|[VGG19](https://paddle-imagenet-models-name.bj.bcebos.com/VGG19_pretrained.zip) | 72.56%/90.83% | 72.32%/90.98% |
|[MobileNetV1](http://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV1_pretrained.zip) | 70.91%/89.54% | 70.51%/89.35% |
|[MobileNetV2](https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV2_pretrained.zip) | 71.90%/90.55% | 71.53%/90.41% |
|[ResNet50](http://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_pretrained.zip) | 76.35%/92.80% | 76.22%/92.92% |
|[ResNet101](http://paddle-imagenet-models-name.bj.bcebos.com/ResNet101_pretrained.zip) | 77.49%/93.57% | 77.56%/93.64% |
|[ResNet152](https://paddle-imagenet-models-name.bj.bcebos.com/ResNet152_pretrained.zip) | 78.12%/93.93% | 77.92%/93.87% |
......
......@@ -6,7 +6,7 @@ python train.py \
--class_dim=1000 \
--image_shape=3,224,224 \
--model_save_dir=output/ \
--with_mem_opt=False \
--with_mem_opt=True \
--lr_strategy=piecewise_decay \
--lr=0.1
# >log_SE_ResNeXt50_32x4d.txt 2>&1 &
......@@ -19,7 +19,7 @@ python train.py \
# --class_dim=1000 \
# --image_shape=3,224,224 \
# --model_save_dir=output/ \
# --with_mem_opt=False \
# --with_mem_opt=True \
# --lr_strategy=piecewise_decay \
# --num_epochs=120 \
# --lr=0.01
......@@ -32,7 +32,7 @@ python train.py \
# --class_dim=1000 \
# --image_shape=3,224,224 \
# --model_save_dir=output/ \
# --with_mem_opt=False \
# --with_mem_opt=True \
# --lr_strategy=piecewise_decay \
# --num_epochs=120 \
# --lr=0.1
......@@ -46,12 +46,22 @@ python train.py \
# --class_dim=1000 \
# --image_shape=3,224,224 \
# --model_save_dir=output/ \
# --with_mem_opt=False \
# --with_mem_opt=True \
# --lr_strategy=piecewise_decay \
# --num_epochs=120 \
# --lr=0.1
#python train.py \
# --model=MobileNetV2 \
# --batch_size=500 \
# --total_images=1281167 \
# --class_dim=1000 \
# --image_shape=3,224,224 \
# --model_save_dir=output/ \
# --with_mem_opt=True \
# --lr_strategy=cosine_decay \
# --num_epochs=200 \
# --lr=0.1
#ResNet50:
#python train.py \
# --model=ResNet50 \
......@@ -60,7 +70,7 @@ python train.py \
# --class_dim=1000 \
# --image_shape=3,224,224 \
# --model_save_dir=output/ \
# --with_mem_opt=False \
# --with_mem_opt=True \
# --lr_strategy=piecewise_decay \
# --num_epochs=120 \
# --lr=0.1
......
......@@ -10,7 +10,6 @@ import math
import paddle
import paddle.fluid as fluid
import paddle.dataset.flowers as flowers
import models
import reader
import argparse
import functools
......@@ -19,8 +18,8 @@ import utils
from utils.learning_rate import cosine_decay
from utils.fp16_utils import create_master_params_grads, master_param_to_train_param
from utility import add_arguments, print_arguments
import models
import models_name
IMAGENET1000 = 1281167
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
......@@ -40,25 +39,32 @@ add_arg('lr_strategy', str, "piecewise_decay", "Set the learning rate
add_arg('model', str, "SE_ResNeXt50_32x4d", "Set the network to use.")
add_arg('enable_ce', bool, False, "If set True, enable continuous evaluation job.")
add_arg('data_dir', str, "./data/ILSVRC2012", "The ImageNet dataset root dir.")
add_arg('model_category', str, "models", "Whether to use models_name or not, valid value:'models','models_name'" )
add_arg('model_category', str, "models", "Whether to use models_name or not, valid value:'models','models_name'." )
add_arg('fp16', bool, False, "Enable half precision training with fp16." )
add_arg('scale_loss', float, 1.0, "Scale loss for fp16." )
add_arg('l2_decay', float, 1e-4, "L2_decay parameter.")
add_arg('momentum_rate', float, 0.9, "momentum_rate.")
# yapf: enable
def set_models(model):
def set_models(model_category):
global models
if model == "models":
models = models
assert model_category in ["models", "models_name"
], "{} is not in lists: {}".format(
model_category, ["models", "models_name"])
if model_category == "models_name":
import models_name as models
else:
models = models_name
import models as models
def optimizer_setting(params):
ls = params["learning_strategy"]
l2_decay = params["l2_decay"]
momentum_rate = params["momentum_rate"]
if ls["name"] == "piecewise_decay":
if "total_images" not in params:
total_images = 1281167
total_images = IMAGENET1000
else:
total_images = params["total_images"]
batch_size = ls["batch_size"]
......@@ -71,16 +77,17 @@ def optimizer_setting(params):
optimizer = fluid.optimizer.Momentum(
learning_rate=fluid.layers.piecewise_decay(
boundaries=bd, values=lr),
momentum=0.9,
regularization=fluid.regularizer.L2Decay(1e-4))
momentum=momentum_rate,
regularization=fluid.regularizer.L2Decay(l2_decay))
elif ls["name"] == "cosine_decay":
if "total_images" not in params:
total_images = 1281167
total_images = IMAGENET1000
else:
total_images = params["total_images"]
batch_size = ls["batch_size"]
l2_decay = params["l2_decay"]
momentum_rate = params["momentum_rate"]
step = int(total_images / batch_size + 1)
lr = params["lr"]
......@@ -89,43 +96,42 @@ def optimizer_setting(params):
optimizer = fluid.optimizer.Momentum(
learning_rate=cosine_decay(
learning_rate=lr, step_each_epoch=step, epochs=num_epochs),
momentum=0.9,
regularization=fluid.regularizer.L2Decay(4e-5))
elif ls["name"] == "exponential_decay":
momentum=momentum_rate,
regularization=fluid.regularizer.L2Decay(l2_decay))
elif ls["name"] == "linear_decay":
if "total_images" not in params:
total_images = 1281167
total_images = IMAGENET1000
else:
total_images = params["total_images"]
batch_size = ls["batch_size"]
step = int(total_images / batch_size +1)
lr = params["lr"]
num_epochs = params["num_epochs"]
learning_decay_rate_factor=ls["learning_decay_rate_factor"]
num_epochs_per_decay = ls["num_epochs_per_decay"]
NUM_GPUS = 1
start_lr = params["lr"]
l2_decay = params["l2_decay"]
momentum_rate = params["momentum_rate"]
end_lr = 0
total_step = int((total_images / batch_size) * num_epochs)
lr = fluid.layers.polynomial_decay(
start_lr, total_step, end_lr, power=1)
optimizer = fluid.optimizer.Momentum(
learning_rate=fluid.layers.exponential_decay(
learning_rate = lr * NUM_GPUS,
decay_steps = step * num_epochs_per_decay / NUM_GPUS,
decay_rate = learning_decay_rate_factor),
momentum=0.9,
regularization = fluid.regularizer.L2Decay(4e-5))
learning_rate=lr,
momentum=momentum_rate,
regularization=fluid.regularizer.L2Decay(l2_decay))
else:
lr = params["lr"]
l2_decay = params["l2_decay"]
momentum_rate = params["momentum_rate"]
optimizer = fluid.optimizer.Momentum(
learning_rate=lr,
momentum=0.9,
regularization=fluid.regularizer.L2Decay(1e-4))
momentum=momentum_rate,
regularization=fluid.regularizer.L2Decay(l2_decay))
return optimizer
def net_config(image, label, model, args):
model_list = [m for m in dir(models) if "__" not in m]
assert args.model in model_list,"{} is not lists: {}".format(
args.model, model_list)
assert args.model in model_list, "{} is not lists: {}".format(args.model,
model_list)
class_dim = args.class_dim
model_name = args.model
......@@ -148,8 +154,9 @@ def net_config(image, label, model, args):
acc_top1 = fluid.layers.accuracy(input=out0, label=label, k=1)
acc_top5 = fluid.layers.accuracy(input=out0, label=label, k=5)
else:
out = model.net(input=image, class_dim=class_dim)
cost, pred = fluid.layers.softmax_with_cross_entropy(out, label, return_softmax=True)
out = model.net(input=image, class_dim=class_dim)
cost, pred = fluid.layers.softmax_with_cross_entropy(
out, label, return_softmax=True)
if args.scale_loss > 1:
avg_cost = fluid.layers.mean(x=cost) * float(args.scale_loss)
else:
......@@ -190,19 +197,25 @@ def build_program(is_train, main_prog, startup_prog, args):
params["num_epochs"] = args.num_epochs
params["learning_strategy"]["batch_size"] = args.batch_size
params["learning_strategy"]["name"] = args.lr_strategy
params["l2_decay"] = args.l2_decay
params["momentum_rate"] = args.momentum_rate
optimizer = optimizer_setting(params)
if args.fp16:
params_grads = optimizer.backward(avg_cost)
master_params_grads = create_master_params_grads(
params_grads, main_prog, startup_prog, args.scale_loss)
optimizer.apply_gradients(master_params_grads)
master_param_to_train_param(master_params_grads, params_grads, main_prog)
master_param_to_train_param(master_params_grads,
params_grads, main_prog)
else:
optimizer.minimize(avg_cost)
global_lr = optimizer._global_learning_rate()
return py_reader, avg_cost, acc_top1, acc_top5
if is_train:
return py_reader, avg_cost, acc_top1, acc_top5, global_lr
else:
return py_reader, avg_cost, acc_top1, acc_top5
def train(args):
......@@ -220,7 +233,7 @@ def train(args):
startup_prog.random_seed = 1000
train_prog.random_seed = 1000
train_py_reader, train_cost, train_acc1, train_acc5 = build_program(
train_py_reader, train_cost, train_acc1, train_acc5, global_lr = build_program(
is_train=True,
main_prog=train_prog,
startup_prog=startup_prog,
......@@ -255,7 +268,8 @@ def train(args):
if visible_device:
device_num = len(visible_device.split(','))
else:
device_num = subprocess.check_output(['nvidia-smi', '-L']).decode().count('\n')
device_num = subprocess.check_output(
['nvidia-smi', '-L']).decode().count('\n')
train_batch_size = args.batch_size / device_num
test_batch_size = 16
......@@ -283,11 +297,12 @@ def train(args):
use_cuda=bool(args.use_gpu),
loss_name=train_cost.name)
train_fetch_list = [train_cost.name, train_acc1.name, train_acc5.name]
train_fetch_list = [
train_cost.name, train_acc1.name, train_acc5.name, global_lr.name
]
test_fetch_list = [test_cost.name, test_acc1.name, test_acc5.name]
params = models.__dict__[args.model]().params
for pass_id in range(params["num_epochs"]):
train_py_reader.start()
......@@ -299,7 +314,9 @@ def train(args):
try:
while True:
t1 = time.time()
loss, acc1, acc5 = train_exe.run(fetch_list=train_fetch_list)
loss, acc1, acc5, lr = train_exe.run(
fetch_list=train_fetch_list)
t2 = time.time()
period = t2 - t1
loss = np.mean(np.array(loss))
......@@ -308,21 +325,27 @@ def train(args):
train_info[0].append(loss)
train_info[1].append(acc1)
train_info[2].append(acc5)
lr = np.mean(np.array(lr))
train_time.append(period)
if batch_id % 10 == 0:
print("Pass {0}, trainbatch {1}, loss {2}, \
acc1 {3}, acc5 {4} time {5}"
.format(pass_id, batch_id, loss, acc1, acc5,
"%2.2f sec" % period))
acc1 {3}, acc5 {4}, lr{5}, time {6}"
.format(pass_id, batch_id, loss, acc1, acc5, "%.5f" %
lr, "%2.2f sec" % period))
sys.stdout.flush()
batch_id += 1
if batch_id == 31:
exit(0)
except fluid.core.EOFException:
train_py_reader.reset()
train_loss = np.array(train_info[0]).mean()
train_acc1 = np.array(train_info[1]).mean()
train_acc5 = np.array(train_info[2]).mean()
train_speed = np.array(train_time).mean() / (train_batch_size * device_num)
train_speed = np.array(train_time).mean() / (train_batch_size *
device_num)
test_py_reader.start()
......@@ -394,10 +417,7 @@ def train(args):
def main():
args = parser.parse_args()
models_now = args.model_category
assert models_now in ["models", "models_name"], "{} is not in lists: {}".format(
models_now, ["models", "models_name"])
set_models(models_now)
set_models(args.model_category)
print_arguments(args)
train(args)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册