提交 9e57a5b5 编写于 作者: littletomatodonkey's avatar littletomatodonkey

add distillation params

上级 be49cec6
......@@ -29,9 +29,20 @@ __all__ = [
class ResNet():
def __init__(self, layers=50, is_3x3=False):
def __init__(self,
layers=50,
is_3x3=False,
postfix_name="",
lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0]):
self.layers = layers
self.is_3x3 = is_3x3
self.postfix_name = "" if postfix_name is None else postfix_name
self.lr_mult_list = lr_mult_list
assert len(
self.lr_mult_list
) == 5, "lr_mult_list length in ResNet must be 5 but got {}!!".format(
len(self.lr_mult_list))
self.curr_stage = 0
def net(self, input, class_dim=1000):
is_3x3 = self.is_3x3
......@@ -90,6 +101,7 @@ class ResNet():
if layers >= 50:
for block in range(len(depth)):
self.curr_stage += 1
for i in range(depth[block]):
if layers in [101, 152, 200] and block == 2:
if i == 0:
......@@ -106,6 +118,7 @@ class ResNet():
name=conv_name)
else:
for block in range(len(depth)):
self.curr_stage += 1
for i in range(depth[block]):
conv_name = "res" + str(block + 2) + chr(97 + i)
conv = self.basic_block(
......@@ -123,9 +136,9 @@ class ResNet():
input=pool,
size=class_dim,
param_attr=fluid.param_attr.ParamAttr(
name="fc_0.w_0",
name="fc_0.w_0" + self.postfix_name,
initializer=fluid.initializer.Uniform(-stdv, stdv)),
bias_attr=ParamAttr(name="fc_0.b_0"))
bias_attr=ParamAttr(name="fc_0.b_0" + self.postfix_name))
return out
......@@ -137,6 +150,7 @@ class ResNet():
groups=1,
act=None,
name=None):
lr_mult = self.lr_mult_list[self.curr_stage]
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
......@@ -145,7 +159,7 @@ class ResNet():
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
param_attr=ParamAttr(name=name + "_weights"),
param_attr=ParamAttr(name=name + "_weights" + self.postfix_name),
bias_attr=False)
if name == "conv1":
bn_name = "bn_" + name
......@@ -154,10 +168,10 @@ class ResNet():
return fluid.layers.batch_norm(
input=conv,
act=act,
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
param_attr=ParamAttr(name=bn_name + '_scale' + self.postfix_name),
bias_attr=ParamAttr(bn_name + '_offset' + self.postfix_name),
moving_mean_name=bn_name + '_mean' + self.postfix_name,
moving_variance_name=bn_name + '_variance' + self.postfix_name)
def conv_bn_layer_new(self,
input,
......@@ -167,6 +181,7 @@ class ResNet():
groups=1,
act=None,
name=None):
lr_mult = self.lr_mult_list[self.curr_stage]
pool = fluid.layers.pool2d(
input=input,
pool_size=2,
......@@ -183,7 +198,9 @@ class ResNet():
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
param_attr=ParamAttr(name=name + "_weights"),
param_attr=ParamAttr(
name=name + "_weights" + self.postfix_name,
learning_rate=lr_mult),
bias_attr=False)
if name == "conv1":
bn_name = "bn_" + name
......@@ -192,10 +209,14 @@ class ResNet():
return fluid.layers.batch_norm(
input=conv,
act=act,
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
param_attr=ParamAttr(
name=bn_name + '_scale' + self.postfix_name,
learning_rate=lr_mult),
bias_attr=ParamAttr(
bn_name + '_offset' + self.postfix_name,
learning_rate=lr_mult),
moving_mean_name=bn_name + '_mean' + self.postfix_name,
moving_variance_name=bn_name + '_variance' + self.postfix_name)
def shortcut(self, input, ch_out, stride, name, if_first=False):
ch_in = input.shape[1]
......@@ -273,8 +294,8 @@ def ResNet34_vd():
return model
def ResNet50_vd():
model = ResNet(layers=50, is_3x3=True)
def ResNet50_vd(**args):
model = ResNet(layers=50, is_3x3=True, **args)
return model
......
......@@ -59,15 +59,18 @@ def check_architecture(architecture):
"""
check architecture and recommend similar architectures
"""
assert isinstance(architecture, str), \
("the type of architecture({}) should be str". format(architecture))
similar_names = similar_architectures(architecture, get_architectures())
assert isinstance(architecture, dict), \
("the type of architecture({}) should be dict". format(architecture))
assert "name" in architecture, \
("name must be in the architecture keys, just contains: {}". format(architecture.keys()))
similar_names = similar_architectures(architecture["name"],
get_architectures())
model_list = ', '.join(similar_names)
err = "{} is not exist! Maybe you want: [{}]" \
"".format(architecture, model_list)
"".format(architecture["name"], model_list)
try:
assert architecture in similar_names
assert architecture["name"] in similar_names
except AssertionError:
logger.error(err)
sys.exit(1)
......@@ -80,7 +83,7 @@ def check_mix(architecture, use_mix=False):
err = "Cannot use mix processing in GoogLeNet, " \
"please set use_mix = False."
try:
if architecture == "GoogLeNet": assert use_mix == False
if architecture["name"] == "GoogLeNet": assert use_mix == False
except AssertionError:
logger.error(err)
sys.exit(1)
......
......@@ -19,7 +19,7 @@ from ppcls.utils import logger
__all__ = ['get_config']
CONFIG_SECS = ['TRAIN', 'VALID', 'OPTIMIZER', 'LEARNING_RATE']
CONFIG_SECS = ['ARCHITECTURE', 'TRAIN', 'VALID', 'OPTIMIZER', 'LEARNING_RATE']
class AttrDict(dict):
......@@ -110,7 +110,7 @@ def check_config(config):
mode = config.get('mode', 'train')
check.check_gpu()
architecture = config.get('architecture')
architecture = config.get('ARCHITECTURE')
check.check_architecture(architecture)
use_mix = config.get('use_mix')
......
......@@ -88,19 +88,21 @@ def create_dataloader(feeds):
return dataloader
def create_model(name, image, classes_num):
def create_model(architecture, image, classes_num):
"""
Create a model
Args:
name(str): model name, such as ResNet50
architecture(dict): architecture information, name(such as ResNet50) is needed
image(variable): model input variable
classes_num(int): num of classes
Returns:
out(variable): model output variable
"""
model = architectures.__dict__[name]()
name = architecture["name"]
params = architecture["params"] if "params" in architecture else {}
model = architectures.__dict__[name](**params)
out = model.net(input=image, class_dim=classes_num)
return out
......@@ -122,7 +124,7 @@ def create_loss(out,
Args:
out(variable): model output variable
feeds(dict): dict of model input variables
architecture(str): model name, such as ResNet50
architecture(dict): architecture information, name(such as ResNet50) is needed
classes_num(int): num of classes
epsilon(float): parameter for label smoothing, 0.0 <= epsilon <= 1.0
mix(bool): whether to use mix(include mixup, cutmix, fmix)
......@@ -130,7 +132,7 @@ def create_loss(out,
Returns:
loss(variable): loss variable
"""
if architecture == "GoogLeNet":
if architecture["name"] == "GoogLeNet":
assert len(out) == 3, "GoogLeNet should have 3 outputs"
loss = GoogLeNetLoss(class_dim=classes_num, epsilon=epsilon)
target = feeds['label']
......@@ -188,7 +190,7 @@ def create_fetchs(out,
Args:
out(variable): model output variable
feeds(dict): dict of model input variables(included label)
architecture(str): model name, such as ResNet50
architecture(dict): architecture information, name(such as ResNet50) is needed
topk(int): usually top5
classes_num(int): num of classes
epsilon(float): parameter for label smoothing, 0.0 <= epsilon <= 1.0
......@@ -293,12 +295,12 @@ def build(config, main_prog, startup_prog, is_train=True):
use_mix = config.get('use_mix') and is_train
feeds = create_feeds(config.image_shape, mix=use_mix)
dataloader = create_dataloader(feeds.values())
out = create_model(config.architecture, feeds['image'],
out = create_model(config.ARCHITECTURE, feeds['image'],
config.classes_num)
fetchs = create_fetchs(
out,
feeds,
config.architecture,
config.ARCHITECTURE,
config.topk,
config.classes_num,
epsilon=config.get('ls_epsilon'),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册