Commit 6a3e9145 authored by wuzewu

add image classification demo

Parent 90252883
#!/bin/bash
set -o nounset
set -o errexit
script_path=$(cd `dirname $0`; pwd)
cd $script_path
model_name="ResNet50"
hub_module_save_dir="./hub_module"
while getopts "m:d:" options
do
case "$options" in
d)
hub_module_save_dir=$OPTARG;;
m)
model_name=$OPTARG;;
?)
echo "unknown options"
exit 1;;
esac
done
sh pretrained_models/download_model.sh ${model_name}
python train.py --create_module=True --pretrained_model=pretrained_models/${model_name} --model ${model_name} --use_gpu=False
#!/bin/bash
set -o nounset
set -o errexit
script_path=$(cd `dirname $0`; pwd)
cd $script_path
hub_module_path=hub_module_ResNet50
data_dir=dataset
batch_size=32
use_gpu=False
num_epochs=20
class_dim=2
learning_rate=0.001
model_save_dir=model_save/`date +%Y%m%d%H%M%S`
while getopts "b:c:d:gh:l:n:" options
do
case "$options" in
b)
batch_size=$OPTARG;;
c)
class_dim=$OPTARG;;
d)
data_dir=$OPTARG;;
g)
use_gpu=True;;
h)
hub_module_path=$OPTARG;;
l)
learning_rate=$OPTARG;;
n)
num_epochs=$OPTARG;;
s)
model_save_dir=$OPTARG;;
?)
echo "unknown options"
exit 1;;
esac
done
mkdir -p ${model_save_dir}
python retrain.py --batch_size=${batch_size} --class_dim=${class_dim} --data_dir=${data_dir} --use_gpu=${use_gpu} --hub_module_path ${hub_module_path} --lr ${learning_rate} --num_epochs=${num_epochs} --model_save_dir=${model_save_dir}
# nohup python retrain.py --batch_size=${batch_size} --class_dim=${class_dim} --data_dir=${data_dir} --use_gpu=${use_gpu} --hub_module_path ${hub_module_path} --lr ${learning_rate} --num_epochs=${num_epochs} --model_save_dir=${model_save_dir} > ${model_save_dir}/train.log 2>&1 &
#-*- coding:utf8 -*-
import paddle
import paddle.fluid as fluid
import paddle_hub as hub
import paddle_hub.module as module
import paddle_hub.logger as log
import sys
import numpy as np
import reader
import argparse
import functools
from visualdl import LogWriter
from utility import add_arguments, print_arguments
test_reader = paddle.batch(reader.test("dataset"), batch_size=1)
def infer():
model = module.Module(module_dir="hub_module_ResNet50")
feed_list, fetch_list, program = model(
sign_name="feature_map", trainable=True)
with fluid.program_guard(main_program=program):
img = feed_list[0]
feature_map = fetch_list[0]
fc = fluid.layers.fc(input=feature_map, size=2, act="softmax")
place = fluid.CPUPlace()
exe = fluid.Executor(place)
feeder = fluid.DataFeeder(feed_list=[img], place=place)
exe.run(fluid.default_startup_program())
        for batch in test_reader():
print(batch[0][0].shape)
eval_val = exe.run(fetch_list=[fc.name], feed=feeder.feed(batch))
log.logger.info(eval_val)
            input()  # pause so each prediction can be inspected
infer()
from .mobilenet_v2 import MobileNetV2
from .resnet import ResNet50, ResNet101, ResNet152
__all__ = ["MobileNetV2", "ResNet50", "ResNet101", "ResNet152"]
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle.fluid as fluid
from paddle.fluid.initializer import MSRA
from paddle.fluid.param_attr import ParamAttr
__all__ = ["MobileNetV2"]
train_parameters = {
"input_size": [3, 224, 224],
"input_mean": [0.485, 0.456, 0.406],
"input_std": [0.229, 0.224, 0.225],
"learning_strategy": {
"name": "piecewise_decay",
"batch_size": 256,
"epochs": [30, 60, 90],
"steps": [0.1, 0.01, 0.001, 0.0001]
}
}
class MobileNetV2():
def __init__(self):
self.params = train_parameters
def net(self, input, class_dim=1000, scale=1.0):
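        # Each tuple below is (t, c, n, s): expansion factor, output channels,
        # number of repeated blocks, and stride of the first block in the group.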
bottleneck_params_list = [
(1, 16, 1, 1),
(6, 24, 2, 2),
(6, 32, 3, 2),
(6, 64, 4, 2),
(6, 96, 3, 1),
(6, 160, 3, 2),
(6, 320, 1, 1),
]
input = self.conv_bn_layer(
input,
num_filters=int(32 * scale),
filter_size=3,
stride=2,
padding=1,
if_act=True)
in_c = int(32 * scale)
for layer_setting in bottleneck_params_list:
t, c, n, s = layer_setting
input = self.invresi_blocks(
input=input,
in_c=in_c,
t=t,
c=int(c * scale),
n=n,
s=s,
)
in_c = int(c * scale)
input = self.conv_bn_layer(
input=input,
num_filters=int(1280 * scale) if scale > 1.0 else 1280,
filter_size=1,
stride=1,
padding=0,
if_act=True)
input = fluid.layers.pool2d(
input=input,
pool_size=7,
pool_stride=1,
pool_type='avg',
global_pooling=True)
output = fluid.layers.fc(
input=input,
size=class_dim,
param_attr=ParamAttr(initializer=MSRA()))
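        # return both the classifier logits and the globally pooled feature map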
return output, input
def conv_bn_layer(self,
input,
filter_size,
num_filters,
stride,
padding,
channels=None,
num_groups=1,
use_cudnn=True,
if_act=True):
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=padding,
groups=num_groups,
act=None,
use_cudnn=use_cudnn,
param_attr=ParamAttr(initializer=MSRA()),
bias_attr=False)
bn = fluid.layers.batch_norm(input=conv)
if if_act:
return fluid.layers.relu6(bn)
else:
return bn
def shortcut(self, input, data_residual):
return fluid.layers.elementwise_add(input, data_residual)
def inverted_residual_unit(self, input, num_in_filter, num_filters,
ifshortcut, stride, filter_size, padding,
expansion_factor):
num_expfilter = int(round(num_in_filter * expansion_factor))
channel_expand = self.conv_bn_layer(
input=input,
num_filters=num_expfilter,
filter_size=1,
stride=1,
padding=0,
num_groups=1,
if_act=True)
bottleneck_conv = self.conv_bn_layer(
input=channel_expand,
num_filters=num_expfilter,
filter_size=filter_size,
stride=stride,
padding=padding,
num_groups=num_expfilter,
if_act=True,
use_cudnn=False)
linear_out = self.conv_bn_layer(
input=bottleneck_conv,
num_filters=num_filters,
filter_size=1,
stride=1,
padding=0,
num_groups=1,
if_act=False)
if ifshortcut:
out = self.shortcut(input=input, data_residual=linear_out)
return out
else:
return linear_out
def invresi_blocks(self, input, in_c, t, c, n, s):
first_block = self.inverted_residual_unit(
input=input,
num_in_filter=in_c,
num_filters=c,
ifshortcut=False,
stride=s,
filter_size=3,
padding=1,
expansion_factor=t)
last_residual_block = first_block
last_c = c
for i in range(1, n):
last_residual_block = self.inverted_residual_unit(
input=last_residual_block,
num_in_filter=last_c,
num_filters=c,
ifshortcut=True,
stride=1,
filter_size=3,
padding=1,
expansion_factor=t)
return last_residual_block
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.fluid as fluid
import math
__all__ = ["ResNet", "ResNet50", "ResNet101", "ResNet152"]
train_parameters = {
"input_size": [3, 224, 224],
"input_mean": [0.485, 0.456, 0.406],
"input_std": [0.229, 0.224, 0.225],
"learning_strategy": {
"name": "piecewise_decay",
"batch_size": 256,
"epochs": [30, 60, 90],
"steps": [0.1, 0.01, 0.001, 0.0001]
}
}
class ResNet():
def __init__(self, layers=50):
self.params = train_parameters
self.layers = layers
def net(self, input, class_dim=1000):
layers = self.layers
supported_layers = [50, 101, 152]
assert layers in supported_layers, \
"supported layers are {} but input layer is {}".format(supported_layers, layers)
if layers == 50:
depth = [3, 4, 6, 3]
elif layers == 101:
depth = [3, 4, 23, 3]
elif layers == 152:
depth = [3, 8, 36, 3]
num_filters = [64, 128, 256, 512]
conv = self.conv_bn_layer(
input=input, num_filters=64, filter_size=7, stride=2, act='relu')
conv = fluid.layers.pool2d(
input=conv,
pool_size=3,
pool_stride=2,
pool_padding=1,
pool_type='max')
for block in range(len(depth)):
for i in range(depth[block]):
conv = self.bottleneck_block(
input=conv,
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1)
pool = fluid.layers.pool2d(
input=conv, pool_size=7, pool_type='avg', global_pooling=True)
stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
out = fluid.layers.fc(
input=pool,
size=class_dim,
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv)))
return out, pool
def conv_bn_layer(self,
input,
num_filters,
filter_size,
stride=1,
groups=1,
act=None):
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
bias_attr=False)
return fluid.layers.batch_norm(input=conv, act=act)
def shortcut(self, input, ch_out, stride):
ch_in = input.shape[1]
if ch_in != ch_out or stride != 1:
return self.conv_bn_layer(input, ch_out, 1, stride)
else:
return input
def bottleneck_block(self, input, num_filters, stride):
conv0 = self.conv_bn_layer(
input=input, num_filters=num_filters, filter_size=1, act='relu')
conv1 = self.conv_bn_layer(
input=conv0,
num_filters=num_filters,
filter_size=3,
stride=stride,
act='relu')
conv2 = self.conv_bn_layer(
input=conv1, num_filters=num_filters * 4, filter_size=1, act=None)
short = self.shortcut(input, num_filters * 4, stride)
return fluid.layers.elementwise_add(x=short, y=conv2, act='relu')
def ResNet50():
model = ResNet(layers=50)
return model
def ResNet101():
model = ResNet(layers=101)
return model
def ResNet152():
model = ResNet(layers=152)
return model
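# For reference, a minimal sketch of wiring one of these backbones into a fluid
# program (illustrative usage; net(input, class_dim) returns (logits, feature_map)):
#
#     import paddle.fluid as fluid
#     from nets import ResNet50
#
#     image = fluid.layers.data(name="image", shape=[3, 224, 224], dtype="float32")
#     logits, feature_map = ResNet50().net(input=image, class_dim=2)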
The source diff for this file is too large to display; view the blob instead.
#!/bin/bash
set -o nounset
set -o errexit
script_path=$(cd `dirname $0`; pwd)
if [ $# -ne 1 ]
then
echo "usage: sh $0 {PRETRAINED_MODEL_NAME}"
exit 1
fi
if [ $1 != "ResNet50" -a $1 != "ResNet101" -a $1 != "ResNet152" -a $1 != "MobileNetV2" ]
then
echo "only suppory pretrained model in {ResNet50, ResNet101, ResNet152, MobileNetV2}"
exit 1
fi
model_name=${1}_pretrained
model=${model_name}.zip
cd ${script_path}
if [ -d ${model_name} ]
then
echo "model file ${model_name} is already existed"
exit 0
fi
if [ ! -f ${model} ]
then
wget http://paddle-imagenet-models-name.bj.bcebos.com/${model}
fi
unzip ${model}
# rm ${model}
rm -rf __MACOSX
import os
import math
import random
import functools
import numpy as np
import paddle
from PIL import Image, ImageEnhance
random.seed(0)
np.random.seed(0)
DATA_DIM = 224  # edge length of the network input
THREAD = 8  # worker threads for paddle.reader.xmap_readers
BUF_SIZE = 102400  # buffer size of the mapped reader
DATA_DIR = 'dataset'
img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
def resize_short(img, target_size):
percent = float(target_size) / min(img.size[0], img.size[1])
resized_width = int(round(img.size[0] * percent))
resized_height = int(round(img.size[1] * percent))
img = img.resize((resized_width, resized_height), Image.LANCZOS)
return img
def crop_image(img, target_size, center):
width, height = img.size
size = target_size
    if center:
        w_start = (width - size) // 2
        h_start = (height - size) // 2
else:
w_start = np.random.randint(0, width - size + 1)
h_start = np.random.randint(0, height - size + 1)
w_end = w_start + size
h_end = h_start + size
img = img.crop((w_start, h_start, w_end, h_end))
return img
def random_crop(img, size, scale=[0.08, 1.0], ratio=[3. / 4., 4. / 3.]):
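    # Sample an aspect ratio and an area fraction; `bound` caps the scale so
    # the sampled crop always fits inside the source image.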
aspect_ratio = math.sqrt(np.random.uniform(*ratio))
w = 1. * aspect_ratio
h = 1. / aspect_ratio
bound = min((float(img.size[0]) / img.size[1]) / (w**2),
(float(img.size[1]) / img.size[0]) / (h**2))
scale_max = min(scale[1], bound)
scale_min = min(scale[0], bound)
target_area = img.size[0] * img.size[1] * np.random.uniform(
scale_min, scale_max)
target_size = math.sqrt(target_area)
w = int(target_size * w)
h = int(target_size * h)
i = np.random.randint(0, img.size[0] - w + 1)
j = np.random.randint(0, img.size[1] - h + 1)
img = img.crop((i, j, i + w, j + h))
img = img.resize((size, size), Image.LANCZOS)
return img
def rotate_image(img):
angle = np.random.randint(-10, 11)
img = img.rotate(angle)
return img
def distort_color(img):
def random_brightness(img, lower=0.5, upper=1.5):
e = np.random.uniform(lower, upper)
return ImageEnhance.Brightness(img).enhance(e)
def random_contrast(img, lower=0.5, upper=1.5):
e = np.random.uniform(lower, upper)
return ImageEnhance.Contrast(img).enhance(e)
def random_color(img, lower=0.5, upper=1.5):
e = np.random.uniform(lower, upper)
return ImageEnhance.Color(img).enhance(e)
ops = [random_brightness, random_contrast, random_color]
np.random.shuffle(ops)
img = ops[0](img)
img = ops[1](img)
img = ops[2](img)
return img
def process_image(sample, mode, color_jitter, rotate):
img_path = sample[0]
img = Image.open(img_path)
if mode == 'train':
if rotate: img = rotate_image(img)
img = random_crop(img, DATA_DIM)
else:
img = resize_short(img, target_size=256)
img = crop_image(img, target_size=DATA_DIM, center=True)
if mode == 'train':
if color_jitter:
img = distort_color(img)
if np.random.randint(0, 2) == 1:
img = img.transpose(Image.FLIP_LEFT_RIGHT)
if img.mode != 'RGB':
img = img.convert('RGB')
img = np.array(img).astype('float32').transpose((2, 0, 1)) / 255
img -= img_mean
img /= img_std
if mode == 'train' or mode == 'val':
return img, sample[1]
elif mode == 'test':
return [img]
def _reader_creator(file_list,
mode,
shuffle=False,
color_jitter=False,
rotate=False,
data_dir=DATA_DIR):
def reader():
with open(file_list) as flist:
full_lines = [line.strip() for line in flist]
if shuffle:
np.random.shuffle(full_lines)
if mode == 'train' and os.getenv('PADDLE_TRAINING_ROLE'):
            # distributed mode if the env var `PADDLE_TRAINING_ROLE` exists
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
trainer_count = int(os.getenv("PADDLE_TRAINERS", "1"))
per_node_lines = len(full_lines) // trainer_count
lines = full_lines[trainer_id * per_node_lines:
(trainer_id + 1) * per_node_lines]
print(
"read images from %d, length: %d, lines length: %d, total: %d"
% (trainer_id * per_node_lines, per_node_lines, len(lines),
len(full_lines)))
else:
lines = full_lines
for line in lines:
if mode == 'train' or mode == 'val':
img_path, label = line.split()
# img_path = img_path.replace("JPEG", "jpeg")
img_path = os.path.join(data_dir, img_path)
yield img_path, int(label)
elif mode == 'test':
img_path = os.path.join(data_dir, line)
yield [img_path]
mapper = functools.partial(
process_image, mode=mode, color_jitter=color_jitter, rotate=rotate)
return paddle.reader.xmap_readers(mapper, reader, THREAD, BUF_SIZE)
def train(data_dir=DATA_DIR):
file_list = os.path.join(data_dir, 'train_list.txt')
return _reader_creator(
file_list,
'train',
shuffle=True,
color_jitter=False,
rotate=False,
data_dir=data_dir + "/train")
def val(data_dir=DATA_DIR):
file_list = os.path.join(data_dir, 'val_list.txt')
return _reader_creator(
file_list, 'val', shuffle=False, data_dir=data_dir + "/val")
def test(data_dir=DATA_DIR):
file_list = os.path.join(data_dir, 'val_list.txt')
return _reader_creator(file_list, 'test', shuffle=False, data_dir=data_dir)
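# A minimal sketch of consuming these readers (assuming dataset/train_list.txt
# holds `path label` lines as described above):
#
#     import paddle
#     import reader
#
#     train_batches = paddle.batch(reader.train("dataset"), batch_size=32)
#     for batch in train_batches():
#         img, label = batch[0]    # a (3, 224, 224) float32 array and an int label
#         break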
#-*- coding:utf8 -*-
import paddle
import paddle.fluid as fluid
import paddle_hub as hub
import paddle_hub.module as module
import sys
import os
import reader
import argparse
import functools
from visualdl import LogWriter
from utility import add_arguments, print_arguments
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('hub_module_path', str, "hub_module_ResNet50", "the hub module path" )
add_arg('batch_size', int, 32, "Minibatch size.")
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('num_epochs', int, 20, "number of epochs.")
add_arg('class_dim', int, 2, "Class number.")
add_arg('image_shape', str, "3,224,224", "input image size")
add_arg('lr', float, 0.1, "set learning rate.")
add_arg('data_dir', str, "./dataset", "The ImageNet dataset root dir.")
add_arg('model_save_dir', str, "./model_save", "model save dir")
# yapf: enable
def retrain(modelpath):
    module = hub.Module(module_dir=modelpath)
feed_list, fetch_list, program = module.context(
sign_name="feature_map", trainable=True)
    # load the dog-vs-cat dataset
    train_reader = paddle.batch(reader.train(args.data_dir), batch_size=args.batch_size)
    val_reader = paddle.batch(reader.val(args.data_dir), batch_size=args.batch_size)
logger = LogWriter("vdl_log", sync_cycle=5)
with logger.mode("train") as logw:
train_acc_scalar = logw.scalar("acc")
train_cost_scalar = logw.scalar("cost")
with logger.mode("val") as logw:
val_acc_scalar = logw.scalar("acc")
val_cost_scalar = logw.scalar("cost")
with fluid.program_guard(main_program=program):
img = feed_list[0]
label = fluid.layers.data(name="label", shape=[1], dtype="int64")
feature_map = fetch_list[0]
        fc = fluid.layers.fc(input=feature_map, size=args.class_dim, act="softmax")
cost = fluid.layers.cross_entropy(input=fc, label=label)
avg_cost = fluid.layers.mean(cost)
acc = fluid.layers.accuracy(input=fc, label=label)
inference_program = fluid.default_main_program().clone(for_test=True)
        optimizer = fluid.optimizer.Adam(learning_rate=args.lr)
optimizer.minimize(avg_cost)
        # run on GPU or CPU according to --use_gpu
        place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
        feeder = fluid.DataFeeder(feed_list=[img, label], place=place)
        exe = fluid.Executor(place)
        train_exe = fluid.ParallelExecutor(
            use_cuda=args.use_gpu,
            loss_name=avg_cost.name,
            main_program=fluid.default_main_program())
# init all param
exe.run(fluid.default_startup_program())
        step = 0
        epochs = args.num_epochs
# start to train
for i in range(epochs):
            train_size = 0
            train_sample_num = 0
            train_acc = 0
            train_cost = 0
for batch in train_reader():
cost, accuracy = train_exe.run(
feed=feeder.feed(batch),
fetch_list=[avg_cost.name, acc.name])
step += 1
#####################
                train_size += 1
                train_sample_num += len(batch)
                train_acc += len(batch) * accuracy
                train_cost += cost
#####################
print(
"epoch %d and step %d: train cost is %.2f, train acc is %.2f%%"
% (i, step, cost, accuracy * 100))
            train_acc = 100 * train_acc / train_sample_num
print("epoch %d: train acc is %.2f%%" % (i, train_acc))
#####################
train_acc_scalar.add_record(i, train_acc)
train_cost_scalar.add_record(i, train_cost / train_size)
#####################
            val_size = 0
            val_sample_num = 0
            val_acc = 0
            val_cost = 0
with fluid.program_guard(inference_program):
for iter, batch in enumerate(val_reader()):
cost, accuracy = train_exe.run(
feed=feeder.feed(batch),
fetch_list=[avg_cost.name, acc.name])
                    val_size += 1
                    val_sample_num += len(batch)
                    val_acc += len(batch) * accuracy
                    val_cost += cost
print("batch %d: val cost is %.2f, val acc is %.2f%%" %
(iter, cost, accuracy * 100))
            val_acc = 100 * val_acc / val_sample_num
            print("epoch %d: val acc is %.2f%%" % (i, val_acc))
            val_acc_scalar.add_record(i, val_acc)
            val_cost_scalar.add_record(i, val_cost / val_size)
fluid.io.save_inference_model(
dirname=os.path.join(args.model_save_dir, "iter%d" % i),
feeded_var_names=[img.name],
target_vars=[fc],
executor=exe)
if __name__ == "__main__":
args = parser.parse_args()
print_arguments(args)
    retrain(args.hub_module_path)
#-*- coding:utf8 -*-
import paddle
import paddle.fluid as fluid
import paddle_hub.module as module
import reader
import sys
def retrain(modelpath):
model = module.Module(module_dir=modelpath)
feed_list, fetch_list, program = model(
sign_name="feature_map", trainable=True)
    # load the dog-vs-cat dataset
train_reader = paddle.batch(reader.train("./dataset"), batch_size=32)
val_reader = paddle.batch(reader.val("./dataset"), batch_size=32)
with fluid.program_guard(main_program=program):
img = feed_list[0]
label = fluid.layers.data(name="label", shape=[1], dtype="int64")
feature_map = fetch_list[0]
fc = fluid.layers.fc(input=feature_map, size=2, act="softmax")
cost = fluid.layers.cross_entropy(input=fc, label=label)
avg_cost = fluid.layers.mean(cost)
acc = fluid.layers.accuracy(input=fc, label=label)
inference_program = fluid.default_main_program().clone(for_test=True)
optimizer = fluid.optimizer.Adam(learning_rate=0.001)
optimizer.minimize(avg_cost)
# running on gpu
place = fluid.CUDAPlace(0)
feeder = fluid.DataFeeder(feed_list=[img, label], place=place)
exe = fluid.Executor(place)
train_exe = fluid.ParallelExecutor(
use_cuda=True,
loss_name=avg_cost.name,
main_program=fluid.default_main_program())
# init all param
exe.run(fluid.default_startup_program())
step = 0
epochs = 50
# start to train
for i in range(epochs):
for batch in train_reader():
cost, accuracy = train_exe.run(
feed=feeder.feed(batch),
fetch_list=[avg_cost.name, acc.name])
step += 1
print(
"epoch %d and step %d: train cost is %.2f, train acc is %.2f%%"
% (i, step, cost, accuracy * 100))
if __name__ == "__main__":
retrain(sys.argv[1])
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import time
import sys
import functools
import math
import paddle
import paddle.fluid as fluid
import paddle.dataset.flowers as flowers
import reader
import argparse
import subprocess
import utils
import nets
import paddle_hub as hub
from utils.learning_rate import cosine_decay
from utils.fp16_utils import create_master_params_grads, master_param_to_train_param
from utility import add_arguments, print_arguments
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('create_module', bool, False, "create a hub module or not" )
add_arg('batch_size', int, 32, "Minibatch size.")
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('total_images', int, 12000, "Training image number.")
add_arg('num_epochs', int, 120, "number of epochs.")
add_arg('class_dim', int, 2, "Class number.")
add_arg('image_shape', str, "3,224,224", "input image size")
add_arg('model_save_dir', str, "output", "model save directory")
add_arg('pretrained_model', str, None, "Whether to use pretrained model.")
add_arg('lr', float, 0.1, "set learning rate.")
add_arg('lr_strategy', str, "piecewise_decay", "Set the learning rate decay strategy.")
add_arg('model', str, "ResNet50", "Set the network to use.")
add_arg('data_dir', str, "./dataset", "The ImageNet dataset root dir.")
add_arg('fp16', bool, False, "Enable half precision training with fp16." )
add_arg('scale_loss', float, 1.0, "Scale loss for fp16." )
# yapf: enable
def optimizer_setting(params):
ls = params["learning_strategy"]
if ls["name"] == "piecewise_decay":
if "total_images" not in params:
total_images = 12000
else:
total_images = params["total_images"]
batch_size = ls["batch_size"]
step = int(total_images / batch_size + 1)
bd = [step * e for e in ls["epochs"]]
base_lr = params["lr"]
        lr = [base_lr * (0.1**i) for i in range(len(bd) + 1)]
optimizer = fluid.optimizer.Momentum(
learning_rate=fluid.layers.piecewise_decay(
boundaries=bd, values=lr),
momentum=0.9,
regularization=fluid.regularizer.L2Decay(1e-4))
elif ls["name"] == "cosine_decay":
if "total_images" not in params:
total_images = 12000
else:
total_images = params["total_images"]
batch_size = ls["batch_size"]
step = int(total_images / batch_size + 1)
lr = params["lr"]
num_epochs = params["num_epochs"]
optimizer = fluid.optimizer.Momentum(
learning_rate=cosine_decay(
learning_rate=lr, step_each_epoch=step, epochs=num_epochs),
momentum=0.9,
regularization=fluid.regularizer.L2Decay(4e-5))
elif ls["name"] == "exponential_decay":
if "total_images" not in params:
total_images = 12000
else:
total_images = params["total_images"]
batch_size = ls["batch_size"]
step = int(total_images / batch_size + 1)
lr = params["lr"]
num_epochs = params["num_epochs"]
learning_decay_rate_factor = ls["learning_decay_rate_factor"]
num_epochs_per_decay = ls["num_epochs_per_decay"]
NUM_GPUS = 1
optimizer = fluid.optimizer.Momentum(
learning_rate=fluid.layers.exponential_decay(
learning_rate=lr * NUM_GPUS,
decay_steps=step * num_epochs_per_decay / NUM_GPUS,
decay_rate=learning_decay_rate_factor),
momentum=0.9,
regularization=fluid.regularizer.L2Decay(4e-5))
else:
lr = params["lr"]
optimizer = fluid.optimizer.Momentum(
learning_rate=lr,
momentum=0.9,
regularization=fluid.regularizer.L2Decay(1e-4))
return optimizer
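# For concreteness, a worked example of the piecewise_decay arithmetic above with
# the defaults (total_images=12000, batch_size=256, epochs=[30, 60, 90], lr=0.1):
#
#     step = int(12000 / 256 + 1)                # int(47.875) = 47 steps per epoch
#     bd = [47 * e for e in [30, 60, 90]]        # decay boundaries [1410, 2820, 4230]
#     lr = [0.1 * (0.1 ** i) for i in range(4)]  # [0.1, 0.01, 0.001, 0.0001]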
def net_config(image, label, model, args):
class_dim = args.class_dim
model_name = args.model
out, feature_map = model.net(input=image, class_dim=class_dim)
cost, pred = fluid.layers.softmax_with_cross_entropy(
out, label, return_softmax=True)
if args.scale_loss > 1:
avg_cost = fluid.layers.mean(x=cost) * float(args.scale_loss)
else:
avg_cost = fluid.layers.mean(x=cost)
acc_top1 = fluid.layers.accuracy(input=pred, label=label, k=1)
return avg_cost, acc_top1, out, feature_map
def build_program(is_train, main_prog, startup_prog, args):
image_shape = [int(m) for m in args.image_shape.split(",")]
model_name = args.model
model = nets.__dict__[model_name]()
with fluid.program_guard(main_prog, startup_prog):
py_reader = fluid.layers.py_reader(
capacity=16,
shapes=[[-1] + image_shape, [-1, 1]],
lod_levels=[0, 0],
dtypes=["float32", "int64"],
use_double_buffer=True)
with fluid.unique_name.guard():
image, label = fluid.layers.read_file(py_reader)
if args.fp16:
image = fluid.layers.cast(image, "float16")
            avg_cost, acc_top1, prediction, feature_map = net_config(
                image, label, model, args)
avg_cost.persistable = True
acc_top1.persistable = True
if is_train:
params = model.params
params["total_images"] = args.total_images
params["lr"] = args.lr
params["num_epochs"] = args.num_epochs
params["learning_strategy"]["batch_size"] = args.batch_size
params["learning_strategy"]["name"] = args.lr_strategy
optimizer = optimizer_setting(params)
if args.fp16:
params_grads = optimizer.backward(avg_cost)
master_params_grads = create_master_params_grads(
params_grads, main_prog, startup_prog, args.scale_loss)
optimizer.apply_gradients(master_params_grads)
master_param_to_train_param(master_params_grads,
params_grads, main_prog)
else:
optimizer.minimize(avg_cost)
    return py_reader, avg_cost, acc_top1, image, prediction, feature_map
def train(args):
# parameters from arguments
model_name = args.model
pretrained_model = args.pretrained_model
model_save_dir = args.model_save_dir
startup_prog = fluid.Program()
train_prog = fluid.Program()
test_prog = fluid.Program()
    train_py_reader, train_cost, train_acc, image, prediction, feature_map = build_program(
        is_train=True,
        main_prog=train_prog,
        startup_prog=startup_prog,
        args=args)
    test_py_reader, test_cost, test_acc, image, prediction, feature_map = build_program(
        is_train=False,
        main_prog=test_prog,
        startup_prog=startup_prog,
        args=args)
test_prog = test_prog.clone(for_test=True)
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup_prog)
if pretrained_model:
def if_exist(var):
return os.path.exists(os.path.join(pretrained_model, var.name))
fluid.io.load_vars(
exe, pretrained_model, main_program=train_prog, predicate=if_exist)
if args.create_module:
        assert pretrained_model, "need a pretrained model to create a hub module"
        sign1 = hub.create_signature(
            "classification", inputs=[image], outputs=[prediction])
sign2 = hub.create_signature(
"feature_map", inputs=[image], outputs=[feature_map])
        sign3 = hub.create_signature(inputs=[image], outputs=[prediction])
hub.create_module(
sign_arr=[sign1, sign2, sign3],
module_dir="hub_module_" + args.model)
exit()
visible_device = os.getenv('CUDA_VISIBLE_DEVICES')
if visible_device:
device_num = len(visible_device.split(','))
else:
device_num = subprocess.check_output(['nvidia-smi',
'-L']).decode().count('\n')
    train_batch_size = args.batch_size // device_num
test_batch_size = 16
    train_reader = paddle.batch(
        reader.train(args.data_dir), batch_size=train_batch_size, drop_last=True)
    test_reader = paddle.batch(reader.val(args.data_dir), batch_size=test_batch_size)
train_py_reader.decorate_paddle_reader(train_reader)
test_py_reader.decorate_paddle_reader(test_reader)
train_exe = fluid.ParallelExecutor(
main_program=train_prog,
use_cuda=bool(args.use_gpu),
loss_name=train_cost.name)
train_fetch_list = [train_cost.name, train_acc.name]
test_fetch_list = [test_cost.name, test_acc.name]
    for pass_id in range(args.num_epochs):
train_py_reader.start()
train_info = [[], [], []]
test_info = [[], [], []]
train_time = []
batch_id = 0
try:
while True:
t1 = time.time()
loss, acc = train_exe.run(fetch_list=train_fetch_list)
t2 = time.time()
period = t2 - t1
loss = np.mean(np.array(loss))
acc = np.mean(np.array(acc))
train_info[0].append(loss)
train_info[1].append(acc)
train_time.append(period)
if batch_id % 10 == 0:
print("Pass {0}, trainbatch {1}, loss {2}, \
acc {3}, time {4}".format(pass_id, batch_id, loss, acc,
"%2.2f sec" % period))
sys.stdout.flush()
batch_id += 1
except fluid.core.EOFException:
train_py_reader.reset()
train_loss = np.array(train_info[0]).mean()
train_acc = np.array(train_info[1]).mean()
train_speed = np.array(train_time).mean() / (
train_batch_size * device_num)
test_py_reader.start()
test_batch_id = 0
try:
while True:
t1 = time.time()
loss, acc = exe.run(
program=test_prog, fetch_list=test_fetch_list)
t2 = time.time()
period = t2 - t1
loss = np.mean(loss)
acc = np.mean(acc)
test_info[0].append(loss)
test_info[1].append(acc)
if test_batch_id % 10 == 0:
print("Pass {0},testbatch {1},loss {2}, \
acc {3},time {4}".format(pass_id, test_batch_id, loss,
acc, "%2.2f sec" % period))
sys.stdout.flush()
test_batch_id += 1
except fluid.core.EOFException:
test_py_reader.reset()
test_loss = np.array(test_info[0]).mean()
test_acc = np.array(test_info[1]).mean()
print("End pass {0}, train_loss {1}, train_acc {2}, "
"test_loss {3}, test_acc {4}".format(
pass_id, train_loss, train_acc, test_loss, test_acc))
sys.stdout.flush()
        model_path = os.path.join(model_save_dir, model_name, str(pass_id))
if not os.path.isdir(model_path):
os.makedirs(model_path)
fluid.io.save_persistables(exe, model_path, main_program=train_prog)
def main():
args = parser.parse_args()
assert args.model in nets.__all__, "model is not in list %s" % nets.__all__
print_arguments(args)
train(args)
if __name__ == '__main__':
main()
#!/bin/bash
set -o nounset
set -o errexit
script_path=$(cd `dirname $0`; pwd)
cd $script_path
model_name=ResNet50
batch_size=32
data_dir=./dataset
class_dim=2
use_gpu=False
while getopts "m:b:c:d:g" options
do
case "$options" in
b)
batch_size=$OPTARG;;
c)
class_dim=$OPTARG;;
d)
data_dir=$OPTARG;;
m)
model_name=$OPTARG;;
g)
use_gpu=True;;
?)
echo "unknown options"
exit 1;;
esac
done
python train.py --data_dir=${data_dir} --batch_size=${batch_size} --class_dim=${class_dim} --image_shape=3,224,224 --model_save_dir=output/ --lr_strategy=piecewise_decay --lr=0.1 --model=${model_name} --use_gpu=${use_gpu}
"""Contains common utility functions."""
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import distutils.util
import numpy as np
import six
from paddle.fluid import core
def print_arguments(args):
"""Print argparse's arguments.
Usage:
.. code-block:: python
parser = argparse.ArgumentParser()
parser.add_argument("name", default="Jonh", type=str, help="User name.")
args = parser.parse_args()
print_arguments(args)
:param args: Input argparse.Namespace for printing.
:type args: argparse.Namespace
"""
print("----------- Configuration Arguments -----------")
for arg, value in sorted(six.iteritems(vars(args))):
print("%s: %s" % (arg, value))
print("------------------------------------------------")
def add_arguments(argname, type, default, help, argparser, **kwargs):
"""Add argparse's argument.
Usage:
.. code-block:: python
parser = argparse.ArgumentParser()
add_argument("name", str, "Jonh", "User name.", parser)
args = parser.parse_args()
"""
type = distutils.util.strtobool if type == bool else type
argparser.add_argument(
"--" + argname,
default=default,
type=type,
help=help + ' Default: %(default)s.',
**kwargs)
from .learning_rate import cosine_decay, lr_warmup
from .fp16_utils import create_master_params_grads, master_param_to_train_param
from __future__ import print_function
import paddle
import paddle.fluid as fluid
def cast_fp16_to_fp32(i, o, prog):
prog.global_block().append_op(
type="cast",
inputs={"X": i},
outputs={"Out": o},
attrs={
"in_dtype": fluid.core.VarDesc.VarType.FP16,
"out_dtype": fluid.core.VarDesc.VarType.FP32
})
def cast_fp32_to_fp16(i, o, prog):
prog.global_block().append_op(
type="cast",
inputs={"X": i},
outputs={"Out": o},
attrs={
"in_dtype": fluid.core.VarDesc.VarType.FP32,
"out_dtype": fluid.core.VarDesc.VarType.FP16
})
def copy_to_master_param(p, block):
v = block.vars.get(p.name, None)
if v is None:
raise ValueError("no param name %s found!" % p.name)
new_p = fluid.framework.Parameter(
block=block,
shape=v.shape,
dtype=fluid.core.VarDesc.VarType.FP32,
type=v.type,
lod_level=v.lod_level,
stop_gradient=p.stop_gradient,
trainable=p.trainable,
optimize_attr=p.optimize_attr,
regularizer=p.regularizer,
gradient_clip_attr=p.gradient_clip_attr,
error_clip=p.error_clip,
name=v.name + ".master")
return new_p
def create_master_params_grads(params_grads, main_prog, startup_prog,
scale_loss):
master_params_grads = []
tmp_role = main_prog._current_role
OpRole = fluid.core.op_proto_and_checker_maker.OpRole
main_prog._current_role = OpRole.Backward
for p, g in params_grads:
# create master parameters
master_param = copy_to_master_param(p, main_prog.global_block())
startup_master_param = startup_prog.global_block()._clone_variable(
master_param)
startup_p = startup_prog.global_block().var(p.name)
cast_fp16_to_fp32(startup_p, startup_master_param, startup_prog)
# cast fp16 gradients to fp32 before apply gradients
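        # batch_norm parameters are kept in fp32, so they need no master copy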
if g.name.startswith("batch_norm"):
if scale_loss > 1:
scaled_g = g / float(scale_loss)
else:
scaled_g = g
master_params_grads.append([p, scaled_g])
continue
master_grad = fluid.layers.cast(g, "float32")
if scale_loss > 1:
master_grad = master_grad / float(scale_loss)
master_params_grads.append([master_param, master_grad])
main_prog._current_role = tmp_role
return master_params_grads
def master_param_to_train_param(master_params_grads, params_grads, main_prog):
for idx, m_p_g in enumerate(master_params_grads):
train_p, _ = params_grads[idx]
if train_p.name.startswith("batch_norm"):
continue
with main_prog._optimized_guard([m_p_g[0], m_p_g[1]]):
cast_fp32_to_fp16(m_p_g[0], train_p, main_prog)
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers.ops as ops
from paddle.fluid.initializer import init_on_cpu
from paddle.fluid.layers.learning_rate_scheduler import _decay_step_counter
import math
def cosine_decay(learning_rate, step_each_epoch, epochs=120):
"""Applies cosine decay to the learning rate.
lr = 0.05 * (math.cos(epoch * (math.pi / 120)) + 1)
"""
global_step = _decay_step_counter()
with init_on_cpu():
epoch = ops.floor(global_step / step_each_epoch)
decayed_lr = learning_rate * \
(ops.cos(epoch * (math.pi / epochs)) + 1)/2
return decayed_lr
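# For example, with learning_rate=0.1 and epochs=120 the schedule gives:
#
#     0.1 * (math.cos(0 * math.pi / 120) + 1) / 2    # 0.10 at epoch 0
#     0.1 * (math.cos(60 * math.pi / 120) + 1) / 2   # 0.05 at the midpoint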
def lr_warmup(learning_rate, warmup_steps, start_lr, end_lr):
""" Applies linear learning rate warmup for distributed training
Argument learning_rate can be float or a Variable
    lr = start_lr + (end_lr - start_lr) * (step / warmup_steps)
"""
assert (isinstance(end_lr, float))
assert (isinstance(start_lr, float))
linear_step = end_lr - start_lr
with fluid.default_main_program()._lr_schedule_guard():
lr = fluid.layers.tensor.create_global_var(
shape=[1],
value=0.0,
dtype='float32',
persistable=True,
name="learning_rate_warmup")
global_step = fluid.layers.learning_rate_scheduler._decay_step_counter()
with fluid.layers.control_flow.Switch() as switch:
with switch.case(global_step < warmup_steps):
decayed_lr = start_lr + linear_step * (
global_step / warmup_steps)
fluid.layers.tensor.assign(decayed_lr, lr)
with switch.default():
fluid.layers.tensor.assign(learning_rate, lr)
return lr