Unverified commit 7d0f694c authored by littletomatodonkey, committed by GitHub

unify model and model_name (#386)

* unify model and model_name

* rm unused archs; fix bug of se_resnext

* fix distillation model for distributed.launch

* fix export model for googlenet
Parent 7c0ef004
mode: 'train'
ARCHITECTURE:
    name: 'CSPResNet50_leaky'

pretrained_model: ""
model_save_dir: "./output/"
classes_num: 1000
total_images: 1281167
save_interval: 1
validate: True
valid_interval: 1
epochs: 120
topk: 5
image_shape: [3, 256, 256]

use_mix: False
ls_epsilon: -1

LEARNING_RATE:
    function: 'Piecewise'
    params:
        lr: 0.1
        decay_epochs: [30, 60, 90]
        gamma: 0.1

OPTIMIZER:
    function: 'Momentum'
    params:
        momentum: 0.9
    regularizer:
        function: 'L2'
        factor: 0.000100

TRAIN:
    batch_size: 256
    num_workers: 4
    file_list: "./dataset/ILSVRC2012/train_list.txt"
    data_dir: "./dataset/ILSVRC2012/"
    shuffle_seed: 0
    transforms:
        - DecodeImage:
            to_rgb: True
            to_np: False
            channel_first: False
        - RandCropImage:
            size: 256
        - RandFlipImage:
            flip_code: 1
        - NormalizeImage:
            scale: 1./255.
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
        - ToCHWImage:

VALID:
    batch_size: 64
    num_workers: 4
    file_list: "./dataset/ILSVRC2012/val_list.txt"
    data_dir: "./dataset/ILSVRC2012/"
    shuffle_seed: 0
    transforms:
        - DecodeImage:
            to_rgb: True
            to_np: False
            channel_first: False
        - ResizeImage:
            resize_short: 256
        - CropImage:
            size: 256
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
        - ToCHWImage:
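The file above is plain YAML, so it can be inspected without the training entrypoint. A minimal sketch (assuming PyYAML is installed; the local file name is illustrative) of loading it and walking the TRAIN transform list:

import yaml

# Illustrative path: save the config above locally under this name.
with open("CSPResNet50_leaky.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["ARCHITECTURE"]["name"])   # 'CSPResNet50_leaky'
print(cfg["image_shape"])            # [3, 256, 256]

# Each transform entry is a one-key mapping: {OpName: params-or-None}.
for op in cfg["TRAIN"]["transforms"]:
    name, params = list(op.items())[0]
    print(name, params)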
mode: 'train'
ARCHITECTURE:
    name: "ResNet50_ACNet"

pretrained_model: ""
model_save_dir: "./output/"
classes_num: 1000
total_images: 1281167
save_interval: 1
validate: True
valid_interval: 1
epochs: 120
topk: 5
image_shape: [3, 224, 224]

LEARNING_RATE:
    function: 'Piecewise'
    params:
        lr: 0.1
        decay_epochs: [30, 60, 90]
        gamma: 0.1

OPTIMIZER:
    function: 'Momentum'
    params:
        momentum: 0.9
    regularizer:
        function: 'L2'
        factor: 0.0001

TRAIN:
    batch_size: 256
    num_workers: 4
    file_list: "./dataset/ILSVRC2012/train_list.txt"
    data_dir: "./dataset/ILSVRC2012/"
    shuffle_seed: 0
    transforms:
        - DecodeImage:
            to_rgb: True
            to_np: False
            channel_first: False
        - RandCropImage:
            size: 224
        - RandFlipImage:
            flip_code: 1
        - NormalizeImage:
            scale: 1./255.
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
        - ToCHWImage:

VALID:
    batch_size: 64
    num_workers: 4
    file_list: "./dataset/ILSVRC2012/val_list.txt"
    data_dir: "./dataset/ILSVRC2012/"
    shuffle_seed: 0
    transforms:
        - DecodeImage:
            to_rgb: True
            to_np: False
            channel_first: False
        - ResizeImage:
            resize_short: 256
        - CropImage:
            size: 224
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
        - ToCHWImage:
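For reference, a sketch of the arithmetic the NormalizeImage step above is conventionally understood to perform, i.e. (pixel * scale - mean) / std per RGB channel; the constants come from the config, the image is a dummy array:

import numpy as np

scale = 1.0 / 255.0
mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
std = np.array([0.229, 0.224, 0.225], dtype=np.float32)

img = np.full((224, 224, 3), 128, dtype=np.uint8)       # dummy HWC RGB image
out = (img.astype(np.float32) * scale - mean) / std     # per-channel normalization
print(out[0, 0])   # normalized values of a mid-gray pixel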
@@ -20,6 +20,7 @@ import numpy as np
 import paddle
 from paddle import ParamAttr
 import paddle.nn as nn
+import paddle.nn.functional as F
 from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
 from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
 from paddle.nn.initializer import Uniform
...
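The newly imported paddle.nn.functional namespace provides functional ops such as softmax, which also appears in the export script's forward further down. A minimal sketch of that call (tensor values are illustrative):

import paddle
import paddle.nn.functional as F

logits = paddle.to_tensor([[2.0, 1.0, 0.1]])
probs = F.softmax(logits)     # normalizes each row into a probability distribution
print(probs.numpy())          # rows sum to 1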
@@ -80,15 +80,19 @@ def load_distillation_model(model, pretrained_model, load_static_weights):
         load_static_weights
     ) == 2, "load_static_weights length should be 2 but got {}".format(
         len(load_static_weights))
+    teacher = model.teacher if hasattr(model,
+                                       "teacher") else model._layers.teacher
+    student = model.student if hasattr(model,
+                                       "student") else model._layers.student
     load_dygraph_pretrain(
-        model.teacher,
+        teacher,
         path=pretrained_model[0],
         load_static_weights=load_static_weights[0])
     logger.info(
         logger.coloring("Finish initing teacher model from {}".format(
             pretrained_model), "HEADER"))
     load_dygraph_pretrain(
-        model.student,
+        student,
         path=pretrained_model[1],
         load_static_weights=load_static_weights[1])
     logger.info(
...
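The hasattr fallback above is what makes the loader work under paddle.distributed.launch: once the distillation model is wrapped by a DataParallel-style container, its teacher and student sub-layers are only reachable through the wrapper's _layers attribute. A paddle-free sketch of that access pattern (both classes are stand-ins, not PaddleClas APIs):

class DistillationModel:          # stand-in for the teacher/student container
    def __init__(self):
        self.teacher = "teacher-net"
        self.student = "student-net"


class ParallelWrapper:            # stand-in for a DataParallel-style wrapper
    def __init__(self, layers):
        self._layers = layers


for model in (DistillationModel(), ParallelWrapper(DistillationModel())):
    teacher = model.teacher if hasattr(model, "teacher") else model._layers.teacher
    student = model.student if hasattr(model, "student") else model._layers.student
    print(teacher, student)       # resolves whether or not the model is wrapped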
@@ -43,10 +43,11 @@ def parse_args():


 class Net(paddle.nn.Layer):
-    def __init__(self, net, to_static, class_dim):
+    def __init__(self, net, to_static, class_dim, model):
         super(Net, self).__init__()
         self.pre_net = net(class_dim=class_dim)
         self.to_static = to_static
+        self.model = model

     # Please modify the 'shape' according to actual needs
     @to_static(input_spec=[
@@ -55,6 +56,8 @@ class Net(paddle.nn.Layer):
     ])
     def forward(self, inputs):
         x = self.pre_net(inputs)
+        if self.model == "GoogLeNet":
+            x = x[0]
         x = F.softmax(x)
         return x

@@ -64,7 +67,7 @@ def main():


     net = architectures.__dict__[args.model]
-    model = Net(net, to_static, args.class_dim)
+    model = Net(net, to_static, args.class_dim, args.model)
     load_dygraph_pretrain(
         model.pre_net,
         path=args.pretrained_model,
...
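The GoogLeNet branch above is needed because that architecture's forward returns several outputs (the main logits first, with the auxiliary-classifier logits after them in the usual GoogLeNet design), so only the first element should be softmaxed and exported. A toy sketch of the pattern; the fake forward below is illustrative, not the PaddleClas implementation:

def fake_googlenet_forward(x):
    main_logits = x                   # main classifier output
    aux1, aux2 = x * 0.3, x * 0.3     # auxiliary classifier outputs
    return [main_logits, aux1, aux2]


out = fake_googlenet_forward(2.0)
if isinstance(out, (list, tuple)):    # mirrors the `if self.model == "GoogLeNet"` branch
    out = out[0]
print(out)                            # only the main output is kept for export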
@@ -53,7 +53,7 @@ def main(args):
         assert args.use_fp16 is False
     else:
         assert args.use_gpu is True
-        assert args.model_name is not None
+        assert args.model is not None
         # HALF precission predict only work when using tensorrt
         if args.use_fp16 is True:
             assert args.use_tensorrt is True
@@ -105,8 +105,8 @@ def main(args):
         fp_message = "FP16" if args.use_fp16 else "FP32"
         trt_msg = "using tensorrt" if args.use_tensorrt else "not using tensorrt"
         print("{0}\t{1}\t{2}\tbatch size: {3}\ttime(ms): {4}".format(
-            args.model_name, trt_msg, fp_message, args.batch_size, 1000 *
-            test_time / test_num))
+            args.model, trt_msg, fp_message, args.batch_size, 1000 * test_time
+            / test_num))


 if __name__ == "__main__":
...
@@ -40,7 +40,6 @@ def parse_args():
     parser.add_argument("--use_tensorrt", type=str2bool, default=False)
     parser.add_argument("--gpu_mem", type=int, default=8000)
     parser.add_argument("--enable_benchmark", type=str2bool, default=False)
-    parser.add_argument("--model_name", type=str)
     parser.add_argument("--top_k", type=int, default=1)
     parser.add_argument("--hubserving", type=str2bool, default=False)
...
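With --model_name gone, --model is the single flag shared by export and inference. A minimal argparse sketch of the unified interface (flag names and the top_k default follow the parser shown above; the ResNet50 value is illustrative):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str)               # replaces the removed --model_name
parser.add_argument("--top_k", type=int, default=1)

args = parser.parse_args(["--model", "ResNet50"])
assert args.model is not None                          # mirrors the check in predict.py
print(args.model, args.top_k)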
@@ -52,7 +52,7 @@ def parse_args():


 def main(args):
-    paddle.seed(123)
+    paddle.seed(12345)

     config = get_config(args.config, overrides=args.override, show=True)
     # assign the place
@@ -68,7 +68,6 @@ def main(args):
         strategy = paddle.distributed.init_parallel_env()

     net = program.create_model(config.ARCHITECTURE, config.classes_num)
-
     optimizer, lr_scheduler = program.create_optimizer(
         config, parameter_list=net.parameters())
...