Commit c5e91094 authored by Dun, committed by qingqing01

add num_classes config to deeplabv3+ (#1562)

* add num_classes config
Parent commit: cbe656e0
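Editorial note: the `add_argument` helper used throughout this diff is a thin wrapper over argparse (its definition appears verbatim in the last file below). A minimal sketch of how the new flag behaves, with an illustrative invocation that is not part of the commit:

    import argparse

    parser = argparse.ArgumentParser()

    def add_argument(name, type, default, help):
        # Thin wrapper used by the train/eval scripts: registers --<name>
        # on the shared parser with the given type and default.
        parser.add_argument('--' + name, default=default, type=type, help=help)

    add_argument('num_classes', int, 19, "Number of classes.")

    # Overriding the Cityscapes default of 19, e.g. for a 21-class dataset:
    args = parser.parse_args(['--num_classes', '21'])
    assert args.num_classes == 21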
@@ -26,6 +26,7 @@ def add_arguments():
add_argument('dataset_path', str, None, "Cityscape dataset path.")
add_argument('verbose', bool, False, "Print mIoU for each step if verbose.")
add_argument('use_gpu', bool, True, "Whether use GPU or CPU.")
+add_argument('num_classes', int, 19, "Number of classes.")
def mean_iou(pred, label):
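The body of `mean_iou` is elided from this hunk, but the variables kept out of memory optimization below (`miou, out_wrong, out_correct`) match the three outputs of `fluid.layers.mean_iou`. A hedged sketch of the presumed wiring, now driven by the configurable class count rather than a hard-coded 19:

    import paddle.fluid as fluid

    num_classes = 19  # in eval this now comes from args.num_classes

    def mean_iou(pred, label):
        # fluid.layers.mean_iou returns the mean IoU together with per-class
        # wrong/correct pixel counts -- the three variables excluded from
        # memory optimization in the hunk below.
        miou, out_wrong, out_correct = fluid.layers.mean_iou(pred, label,
                                                             num_classes)
        return miou, out_wrong, out_correct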
@@ -69,7 +70,7 @@ tp = fluid.Program()
batch_size = 1
reader.default_config['crop_size'] = -1
reader.default_config['shuffle'] = False
-num_classes = 19
+num_classes = args.num_classes
with fluid.program_guard(tp, sp):
img = fluid.layers.data(name='img', shape=[3, 0, 0], dtype='float32')
@@ -84,7 +85,7 @@ tp = tp.clone(True)
fluid.memory_optimize(
tp,
print_log=False,
-    skip_opt_set=[pred.name, miou, out_wrong, out_correct],
+    skip_opt_set=set([pred.name, miou, out_wrong, out_correct]),
level=1)
place = fluid.CPUPlace()
@@ -20,6 +20,11 @@ op_results = {}
default_epsilon = 1e-3
default_norm_type = 'bn'
default_group_number = 32
+depthwise_use_cudnn = False
+bn_regularizer = fluid.regularizer.L2DecayRegularizer(regularization_coeff=0.0)
+depthwise_regularizer = fluid.regularizer.L2DecayRegularizer(
+    regularization_coeff=0.0)
@contextlib.contextmanager
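A note on the two zero-coefficient regularizers added above: in Fluid, a regularizer set on a `ParamAttr` takes precedence over the regularization passed to the optimizer, so `regularization_coeff=0.0` effectively exempts BN scale/shift and depthwise filters from the global weight decay used in training. A minimal sketch of the mechanism (the `fc` layer is illustrative only, not part of this model):

    import paddle.fluid as fluid

    # A per-parameter regularizer overrides optimizer-level regularization;
    # a coefficient of 0.0 therefore disables weight decay for this parameter.
    no_decay = fluid.regularizer.L2DecayRegularizer(regularization_coeff=0.0)

    x = fluid.layers.data(name='x', shape=[8], dtype='float32')
    y = fluid.layers.fc(input=x,
                        size=4,
                        param_attr=fluid.ParamAttr(name='fc_w',
                                                   regularizer=no_decay))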
@@ -52,20 +57,39 @@ def append_op_result(result, name):
def conv(*args, **kargs):
-kargs['param_attr'] = name_scope + 'weights'
+if "xception" in name_scope:
+    init_std = 0.09
+elif "logit" in name_scope:
+    init_std = 0.01
+elif name_scope.endswith('depthwise/'):
+    init_std = 0.33
+else:
+    init_std = 0.06
+if name_scope.endswith('depthwise/'):
+    regularizer = depthwise_regularizer
+else:
+    regularizer = None
+kargs['param_attr'] = fluid.ParamAttr(
+    name=name_scope + 'weights',
+    regularizer=regularizer,
+    initializer=fluid.initializer.TruncatedNormal(
+        loc=0.0, scale=init_std))
if 'bias_attr' in kargs and kargs['bias_attr']:
-kargs['bias_attr'] = name_scope + 'biases'
+kargs['bias_attr'] = fluid.ParamAttr(
+    name=name_scope + 'biases',
+    regularizer=regularizer,
+    initializer=fluid.initializer.ConstantInitializer(value=0.0))
else:
kargs['bias_attr'] = False
kargs['name'] = name_scope + 'conv'
return append_op_result(fluid.layers.conv2d(*args, **kargs), 'conv')
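`name_scope` here is a module-level prefix string; the `@contextlib.contextmanager` whose body is collapsed above is presumably the `scope` helper that extends it. A sketch of that common shape, stated as an assumption, to show how `conv` ends up choosing `init_std` and parameter names:

    import contextlib

    name_scope = ""

    @contextlib.contextmanager
    def scope(name):
        # Temporarily extend the global prefix, e.g. "" -> "logit/".
        global name_scope
        saved = name_scope
        name_scope = name_scope + name + '/'
        yield
        name_scope = saved

    # Inside `with scope('logit'):`, conv() would pick init_std = 0.01 and
    # create a weight parameter named "logit/weights".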
def group_norm(input, G, eps=1e-5, param_attr=None, bias_attr=None):
helper = fluid.layer_helper.LayerHelper('group_norm', **locals())
N, C, H, W = input.shape
if C % G != 0:
print("group can not divide channle:", C, G)
# print "group can not divide channle:", C, G
for d in range(10):
for t in [d, -d]:
if G + t <= 0: continue
@@ -73,29 +97,16 @@ def group_norm(input, G, eps=1e-5, param_attr=None, bias_attr=None):
G = G + t
break
if C % G == 0:
print("use group size:", G)
# print "use group size:", G
break
assert C % G == 0
-    param_shape = (G, )
-    x = input
-    x = fluid.layers.reshape(x, [N, G, C // G * H * W])
-    mean = fluid.layers.reduce_mean(x, dim=2, keep_dim=True)
-    x = x - mean
-    var = fluid.layers.reduce_mean(fluid.layers.square(x), dim=2, keep_dim=True)
-    x = x / fluid.layers.sqrt(var + eps)
-    scale = helper.create_parameter(
-        attr=helper.param_attr,
-        shape=param_shape,
-        dtype='float32',
-        default_initializer=fluid.initializer.Constant(1.0))
-    bias = helper.create_parameter(
-        attr=helper.bias_attr, shape=param_shape, dtype='float32', is_bias=True)
-    x = fluid.layers.elementwise_add(
-        fluid.layers.elementwise_mul(
-            x, scale, axis=1), bias, axis=1)
-    return fluid.layers.reshape(x, input.shape)
+    x = fluid.layers.group_norm(
+        input,
+        groups=G,
+        param_attr=param_attr,
+        bias_attr=bias_attr,
+        name=name_scope + 'group_norm')
+    return x
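The hand-rolled normalization removed here is replaced by the built-in `fluid.layers.group_norm`. One behavioral difference worth flagging: the removed code created scale/bias of shape `(G,)` (one pair per group), while the built-in op's affine parameters are per channel. A numpy reference of standard group normalization, mirroring the reshape/mean/variance steps of the deleted lines but with per-channel affine as the built-in op uses:

    import numpy as np

    def group_norm_ref(x, G, gamma, beta, eps=1e-5):
        # x: (N, C, H, W); gamma, beta: (C,) per-channel affine parameters.
        N, C, H, W = x.shape
        g = x.reshape(N, G, C // G * H * W)
        mean = g.mean(axis=2, keepdims=True)
        var = ((g - mean) ** 2).mean(axis=2, keepdims=True)
        g = (g - mean) / np.sqrt(var + eps)
        x = g.reshape(N, C, H, W)
        return x * gamma.reshape(1, C, 1, 1) + beta.reshape(1, C, 1, 1)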
def bn(*args, **kargs):
@@ -106,8 +117,10 @@ def bn(*args, **kargs):
*args,
epsilon=default_epsilon,
momentum=bn_momentum,
-        param_attr=name_scope + 'gamma',
-        bias_attr=name_scope + 'beta',
+        param_attr=fluid.ParamAttr(
+            name=name_scope + 'gamma', regularizer=bn_regularizer),
+        bias_attr=fluid.ParamAttr(
+            name=name_scope + 'beta', regularizer=bn_regularizer),
moving_mean_name=name_scope + 'moving_mean',
moving_variance_name=name_scope + 'moving_variance',
**kargs),
@@ -119,8 +132,10 @@ def bn(*args, **kargs):
args[0],
default_group_number,
eps=default_epsilon,
-            param_attr=name_scope + 'gamma',
-            bias_attr=name_scope + 'beta'),
+            param_attr=fluid.ParamAttr(
+                name=name_scope + 'gamma', regularizer=bn_regularizer),
+            bias_attr=fluid.ParamAttr(
+                name=name_scope + 'beta', regularizer=bn_regularizer)),
'gn')
else:
raise "Unsupport norm type:" + default_norm_type
@@ -143,7 +158,8 @@ def seq_conv(input, channel, stride, filter, dilation=1, act=None):
stride,
groups=input.shape[1],
padding=(filter // 2) * dilation,
-    dilation=dilation)
+    dilation=dilation,
+    use_cudnn=depthwise_use_cudnn)
input = bn(input)
if act: input = act(input)
with scope('pointwise'):
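`seq_conv` is a depthwise-separable convolution: a depthwise stage with `groups` equal to the input channel count, followed by a 1x1 pointwise projection (the `with scope('pointwise')` block truncated above). The new `depthwise_use_cudnn = False` routes the depthwise stage to Paddle's native depthwise kernel, presumably because cuDNN's grouped-convolution path was slower for the fully depthwise case at the time. A standalone sketch of the two stages using plain `fluid.layers.conv2d` (the function name is illustrative):

    import paddle.fluid as fluid

    def depthwise_separable(input, channel, stride, filter, dilation=1):
        # Depthwise: one filter per input channel, no cross-channel mixing.
        x = fluid.layers.conv2d(
            input,
            input.shape[1],
            filter,
            stride,
            groups=input.shape[1],
            padding=(filter // 2) * dilation,
            dilation=dilation,
            use_cudnn=False)
        # Pointwise: 1x1 convolution mixes channels and sets the output width.
        return fluid.layers.conv2d(x, channel, 1, 1)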
@@ -13,6 +13,7 @@ import reader
import models
import time
def add_argument(name, type, default, help):
parser.add_argument('--' + name, default=default, type=type, help=help)
@@ -32,15 +33,28 @@ def add_arguments():
add_argument('dataset_path', str, None, "Cityscape dataset path.")
add_argument('parallel', bool, False, "using ParallelExecutor.")
add_argument('use_gpu', bool, True, "Whether use GPU or CPU.")
+add_argument('num_classes', int, 19, "Number of classes.")
def load_model():
+    myvars = [
+        x for x in tp.list_vars()
+        if isinstance(x, fluid.framework.Parameter) and x.name.find('logit') == -1
+    ]
if args.init_weights_path.endswith('/'):
-        fluid.io.load_params(
-            exe, dirname=args.init_weights_path, main_program=tp)
+        if args.num_classes == 19:
+            fluid.io.load_params(
+                exe, dirname=args.init_weights_path, main_program=tp)
+        else:
+            fluid.io.load_vars(exe, dirname=args.init_weights_path, vars=myvars)
else:
-        fluid.io.load_params(
-            exe, dirname="", filename=args.init_weights_path, main_program=tp)
+        if args.num_classes == 19:
+            fluid.io.load_params(
+                exe, dirname="", filename=args.init_weights_path, main_program=tp)
+        else:
+            fluid.io.load_vars(
+                exe, dirname="", filename=args.init_weights_path, vars=myvars)
def save_model():
@@ -80,6 +94,7 @@ args = parser.parse_args()
models.clean()
models.bn_momentum = 0.9997
models.dropout_keep_prop = 0.9
+models.label_number = args.num_classes
deeplabv3p = models.deeplabv3p
sp = fluid.Program()
@@ -89,7 +104,7 @@ batch_size = args.batch_size
image_shape = [crop_size, crop_size]
reader.default_config['crop_size'] = crop_size
reader.default_config['shuffle'] = True
-num_classes = 19
+num_classes = args.num_classes
weight_decay = 0.00004
base_lr = args.base_lr
@@ -120,7 +135,7 @@ with fluid.program_guard(tp, sp):
retv = opt.minimize(loss_mean, startup_program=sp, no_grad_set=no_grad_set)
fluid.memory_optimize(
-    tp, print_log=False, skip_opt_set=[pred.name, loss_mean.name], level=1)
+    tp, print_log=False, skip_opt_set=set([pred.name, loss_mean.name]), level=1)
place = fluid.CPUPlace()
if args.use_gpu:
@@ -155,8 +170,8 @@ for i, imgs, labels, names in batches:
if i % 100 == 0:
print("Model is saved to", args.save_weights_path)
save_model()
print("step {:d}, loss: {:.6f}, step_time_cost: {:.3f}" .format(i,
np.mean(retv[1]), end_time - prev_start_time))
print("step {:d}, loss: {:.6f}, step_time_cost: {:.3f}".format(
i, np.mean(retv[1]), end_time - prev_start_time))
print("Training done. Model is saved to", args.save_weights_path)
save_model()