Unverified commit d7148519, authored by wuyefeilin and committed by GitHub

update *_batch_size_like ops (#258)

Parent 7126bb37
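The change in this commit is mechanical: every `*_batch_size_like` op is replaced by an explicit `fluid.layers.shape` read whose leading entry feeds the ordinary op. A minimal sketch of the pattern, assuming a hypothetical placeholder named `ref` (the variable names here are illustrative, not from the commit):

```python
import paddle.fluid as fluid

# Hypothetical reference tensor with a dynamic batch dimension.
ref = fluid.layers.data(name='ref', shape=[4], dtype='float32')

# Before: the leading dimension is copied implicitly from `ref`.
# zeros = fluid.layers.fill_constant_batch_size_like(
#     ref, shape=[1, 4], dtype='float32', value=0)

# After: read the runtime shape once, then index its leading entry.
ref_shape = fluid.layers.shape(ref)
zeros = fluid.layers.fill_constant(
    shape=[ref_shape[0], 4], dtype='float32', value=0)
```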
@@ -19,8 +19,9 @@ from utils.config import cfg

 def unsorted_segment_sum(data, segment_ids, unique_labels, feature_dims):
-    zeros = fluid.layers.fill_constant_batch_size_like(unique_labels, shape=[1, feature_dims],
-                                                       dtype='float32', value=0)
+    unique_labels_shape = fluid.layers.shape(unique_labels)
+    zeros = fluid.layers.fill_constant(
+        shape=[unique_labels_shape[0], feature_dims], dtype='float32', value=0)
     segment_ids = fluid.layers.unsqueeze(segment_ids, axes=[1])
     segment_ids.stop_gradient = True
     segment_sum = fluid.layers.scatter_nd_add(zeros, segment_ids, data)
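For reference, the semantics of `unsorted_segment_sum` can be checked against NumPy; this is an illustrative sketch, not part of the commit:

```python
# Row i of the output is the sum of all rows of `data` whose segment id is i,
# i.e. a scatter-add into a zero tensor, matching scatter_nd_add above.
import numpy as np

data = np.array([[1., 2.], [3., 4.], [5., 6.]])  # 3 pixels, feature_dims = 2
segment_ids = np.array([0, 1, 0])                # pixel -> instance index
num_segments = 2                                 # rows of unique_labels

out = np.zeros((num_segments, data.shape[1]), dtype='float32')
np.add.at(out, segment_ids, data)                # same role as scatter_nd_add
print(out)  # [[6. 8.]
            #  [3. 4.]]
```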
@@ -30,29 +31,23 @@ def unsorted_segment_sum(data, segment_ids, unique_labels, feature_dims):

 def norm(x, axis=-1):
-    distance = fluid.layers.reduce_sum(fluid.layers.abs(x), dim=axis, keep_dim=True)
+    distance = fluid.layers.reduce_sum(
+        fluid.layers.abs(x), dim=axis, keep_dim=True)
     return distance


-def discriminative_loss_single(
-        prediction,
-        correct_label,
-        feature_dim,
-        label_shape,
-        delta_v,
-        delta_d,
-        param_var,
-        param_dist,
-        param_reg):
-    correct_label = fluid.layers.reshape(
-        correct_label, [
-            label_shape[1] * label_shape[0]])
+def discriminative_loss_single(prediction, correct_label, feature_dim,
+                               label_shape, delta_v, delta_d, param_var,
+                               param_dist, param_reg):
+
+    correct_label = fluid.layers.reshape(correct_label,
+                                         [label_shape[1] * label_shape[0]])
     prediction = fluid.layers.transpose(prediction, [1, 2, 0])
     reshaped_pred = fluid.layers.reshape(
-        prediction, [
-            label_shape[1] * label_shape[0], feature_dim])
-    unique_labels, unique_id, counts = fluid.layers.unique_with_counts(correct_label)
+        prediction, [label_shape[1] * label_shape[0], feature_dim])
+    unique_labels, unique_id, counts = fluid.layers.unique_with_counts(
+        correct_label)
     correct_label.stop_gradient = True
     counts = fluid.layers.cast(counts, 'float32')
     num_instances = fluid.layers.shape(unique_labels)
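The `unique_with_counts` step maps each flattened pixel label to an instance index and records each instance's pixel count. An equivalent NumPy check (an illustration, not part of the commit):

```python
import numpy as np

correct_label = np.array([5, 5, 7, 9, 7, 5])     # flattened label map
unique_labels, unique_id, counts = np.unique(
    correct_label, return_inverse=True, return_counts=True)
print(unique_labels)  # [5 7 9]       -> num_instances = 3
print(unique_id)      # [0 0 1 2 1 0] segment id of every pixel
print(counts)         # [3 2 1]       pixels per instance
```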
@@ -69,24 +64,29 @@ def discriminative_loss_single(
     distance = norm(tmp)
     distance = distance - delta_v
-    distance_pos = fluid.layers.greater_equal(distance, fluid.layers.zeros_like(distance))
+    distance_pos = fluid.layers.greater_equal(distance,
+                                              fluid.layers.zeros_like(distance))
     distance_pos = fluid.layers.cast(distance_pos, 'float32')
     distance = distance * distance_pos
     distance = fluid.layers.square(distance)

-    l_var = unsorted_segment_sum(distance, unique_id, unique_labels, feature_dims=1)
+    l_var = unsorted_segment_sum(
+        distance, unique_id, unique_labels, feature_dims=1)
     l_var = fluid.layers.elementwise_div(l_var, counts_rsp)
     l_var = fluid.layers.reduce_sum(l_var)
-    l_var = l_var / fluid.layers.cast(num_instances * (num_instances - 1), 'float32')
+    l_var = l_var / fluid.layers.cast(num_instances * (num_instances - 1),
+                                      'float32')

     mu_interleaved_rep = fluid.layers.expand(mu, [num_instances, 1])
     mu_band_rep = fluid.layers.expand(mu, [1, num_instances])
-    mu_band_rep = fluid.layers.reshape(mu_band_rep, (num_instances * num_instances, feature_dim))
+    mu_band_rep = fluid.layers.reshape(
+        mu_band_rep, (num_instances * num_instances, feature_dim))

     mu_diff = fluid.layers.elementwise_sub(mu_band_rep, mu_interleaved_rep)
-    intermediate_tensor = fluid.layers.reduce_sum(fluid.layers.abs(mu_diff), dim=1)
+    intermediate_tensor = fluid.layers.reduce_sum(
+        fluid.layers.abs(mu_diff), dim=1)
     intermediate_tensor.stop_gradient = True
     zero_vector = fluid.layers.zeros([1], 'float32')
     bool_mask = fluid.layers.not_equal(intermediate_tensor, zero_vector)
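The `expand`/`reshape` pair above enumerates every ordered (i, j) pair of instance centers; self-pairs come out exactly zero and are then filtered by `bool_mask`. A NumPy sketch of the construction (names and values illustrative, not from the commit):

```python
# Tiling mu along axis 0 yields [mu_0, mu_1, mu_0, mu_1]; tiling along axis 1
# and reshaping yields [mu_0, mu_0, mu_1, mu_1]; their difference covers all
# (i, j) pairs of centers at once.
import numpy as np

mu = np.array([[0., 0.], [3., 4.]])               # num_instances = 2 centers
n = mu.shape[0]
mu_interleaved_rep = np.tile(mu, (n, 1))          # expand(mu, [n, 1])
mu_band_rep = np.tile(mu, (1, n)).reshape(n * n, -1)
mu_diff = mu_band_rep - mu_interleaved_rep
print(np.abs(mu_diff).sum(axis=1))                # [0. 7. 7. 0.]
```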
@@ -95,7 +95,8 @@ def discriminative_loss_single(
     mu_norm = norm(mu_diff_bool)
     mu_norm = 2. * delta_d - mu_norm
-    mu_norm_pos = fluid.layers.greater_equal(mu_norm, fluid.layers.zeros_like(mu_norm))
+    mu_norm_pos = fluid.layers.greater_equal(mu_norm,
+                                             fluid.layers.zeros_like(mu_norm))
     mu_norm_pos = fluid.layers.cast(mu_norm_pos, 'float32')
     mu_norm = mu_norm * mu_norm_pos
     mu_norm_pos.stop_gradient = True
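Both `distance_pos` earlier and `mu_norm_pos` here implement the same hinge, max(margin_term, 0)^2, via a boolean compare, a cast, a multiply, and a square. In NumPy terms (values illustrative, not from the commit):

```python
import numpy as np

delta_d = 1.5
mu_norm = 2. * delta_d - np.array([0.5, 2.8, 3.4])   # margin minus distance
mu_norm_pos = (mu_norm >= 0.).astype('float32')      # greater_equal + cast
mu_norm = np.square(mu_norm * mu_norm_pos)           # only violations survive
print(mu_norm)                                       # [6.25 0.04 0.  ]
```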
@@ -122,8 +123,8 @@ def discriminative_loss(prediction, correct_label, feature_dim, image_shape,
     output_ta_reg = 0.
     for i in range(batch_size):
         disc_loss_single, l_var_single, l_dist_single, l_reg_single = discriminative_loss_single(
-            prediction[i], correct_label[i], feature_dim, image_shape, delta_v, delta_d, param_var, param_dist,
-            param_reg)
+            prediction[i], correct_label[i], feature_dim, image_shape, delta_v,
+            delta_d, param_var, param_dist, param_reg)
         output_ta_loss += disc_loss_single
         output_ta_var += l_var_single
         output_ta_dist += l_dist_single
@@ -134,5 +135,3 @@ def discriminative_loss(prediction, correct_label, feature_dim, image_shape,
     l_dist = output_ta_dist / batch_size
     l_reg = output_ta_reg / batch_size
     return disc_loss, l_var, l_dist, l_reg
-
-
@@ -223,6 +223,7 @@ def build_model(main_prog, start_prog, phase=ModelPhase.TRAIN):
             raise Exception(
                 "softmax loss or lovasz softmax loss can not combine with bce loss or dice loss or lovasz hinge loss."
             )
+    cfg.PHASE = phase
     logits = seg_model(image, class_num)
     # Compute the loss that matches the selected loss function
@@ -25,12 +25,15 @@ from models.libs.model_libs import separate_conv
 from utils.config import cfg

-def learning_to_downsample(x, dw_channels1=32, dw_channels2=48, out_channels=64):
+def learning_to_downsample(x, dw_channels1=32, dw_channels2=48,
+                           out_channels=64):
     x = relu(bn(conv(x, dw_channels1, 3, 2)))
     with scope('dsconv1'):
-        x = separate_conv(x, dw_channels2, stride=2, filter=3, act=fluid.layers.relu)
+        x = separate_conv(
+            x, dw_channels2, stride=2, filter=3, act=fluid.layers.relu)
     with scope('dsconv2'):
-        x = separate_conv(x, out_channels, stride=2, filter=3, act=fluid.layers.relu)
+        x = separate_conv(
+            x, out_channels, stride=2, filter=3, act=fluid.layers.relu)
     return x
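`separate_conv` is a depthwise 3x3 followed by a pointwise 1x1, which is why the downsampling path stays cheap. A back-of-the-envelope parameter count (illustrative numbers, not from the commit):

```python
# Compare a dense 3x3 conv against its depthwise-separable equivalent for
# dw_channels2 = 48 input channels and out_channels = 64 output channels.
c_in, c_out, k = 48, 64, 3
standard = c_in * c_out * k * k               # dense 3x3 convolution
separable = c_in * k * k + c_in * c_out       # depthwise 3x3 + pointwise 1x1
print(standard, separable, round(standard / separable, 1))  # 27648 3504 7.9
```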
@@ -43,7 +46,9 @@ def dropout2d(input, prob, is_train=False):
         return input
     channels = input.shape[1]
     keep_prob = 1.0 - prob
-    random_tensor = keep_prob + fluid.layers.uniform_random_batch_size_like(input, [-1, channels, 1, 1], min=0., max=1.)
+    shape = fluid.layers.shape(input)
+    random_tensor = keep_prob + fluid.layers.uniform_random(
+        [shape[0], channels, 1, 1], min=0., max=1.)
     binary_tensor = fluid.layers.floor(random_tensor)
     output = input / keep_prob * binary_tensor
     return output
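The rewritten `dropout2d` draws one uniform sample per (sample, channel), so `floor(keep_prob + U[0, 1))` is 1 with probability `keep_prob`, and a whole channel is either kept (rescaled by `1 / keep_prob`) or zeroed. A NumPy sketch of the same logic (an illustration, not part of the commit):

```python
import numpy as np

def dropout2d_np(x, prob, is_train=False, seed=0):
    if not is_train:
        return x
    keep_prob = 1.0 - prob
    n, c = x.shape[:2]
    rng = np.random.default_rng(seed)
    random_tensor = keep_prob + rng.uniform(0., 1., size=(n, c, 1, 1))
    binary_tensor = np.floor(random_tensor)       # per-channel 0/1 mask
    return x / keep_prob * binary_tensor          # inverted dropout scaling

x = np.ones((1, 4, 2, 2), dtype='float32')
print(dropout2d_np(x, prob=0.5, is_train=True)[0, :, 0, 0])  # channels -> 2. or 0.
```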
@@ -136,18 +141,23 @@ def psp_module(input, out_features):
     for size in sizes:
         psp_name = "psp" + str(size)
         with scope(psp_name):
-            pool = fluid.layers.adaptive_pool2d(input,
+            pool = fluid.layers.adaptive_pool2d(
+                input,
                 pool_size=[size, size],
                 pool_type='avg',
                 name=psp_name + '_adapool')
-            data = conv(pool, out_features,
+            data = conv(
+                pool,
+                out_features,
                 filter_size=1,
                 bias_attr=False,
                 name=psp_name + '_conv')
             data_bn = bn(data, act='relu')
-            interp = fluid.layers.resize_bilinear(data_bn,
+            interp = fluid.layers.resize_bilinear(
+                data_bn,
                 out_shape=input.shape[2:],
-                name=psp_name + '_interp', align_mode=0)
+                name=psp_name + '_interp',
+                align_mode=0)
             cat_layers.append(interp)
     cat_layers = [input] + cat_layers
     out = fluid.layers.concat(cat_layers, axis=1, name='psp_cat')
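Each PSP branch pools the feature map to `size x size`, projects it to `out_features` channels with a 1x1 conv, and resizes back to the input's spatial size before concatenation. A shape walk-through under assumed pyramid sizes `(1, 2, 3, 6)` and an assumed input of `[1, 128, 32, 64]` (illustrative; the actual `sizes` tuple is defined elsewhere in the file):

```python
n, c, h, w = 1, 128, 32, 64        # assumed input shape
out_features = 32                  # assumed projection width
for size in (1, 2, 3, 6):          # assumed pyramid sizes
    pool = (n, c, size, size)                 # adaptive_pool2d output
    proj = (n, out_features, size, size)      # 1x1 conv output
    interp = (n, out_features, h, w)          # resize_bilinear back to H x W
    print(size, pool, proj, interp)
# concat([input] + branches, axis=1) -> (1, 128 + 4 * 32, 32, 64)
```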
@@ -158,7 +168,11 @@ def psp_module(input, out_features):
 class FeatureFusionModule:
     """Feature fusion module"""

-    def __init__(self, higher_in_channels, lower_in_channels, out_channels, scale_factor=4):
+    def __init__(self,
+                 higher_in_channels,
+                 lower_in_channels,
+                 out_channels,
+                 scale_factor=4):
         self.higher_in_channels = higher_in_channels
         self.lower_in_channels = lower_in_channels
         self.out_channels = out_channels
@@ -166,14 +180,19 @@ class FeatureFusionModule:
     def net(self, higher_res_feature, lower_res_feature):
         h, w = higher_res_feature.shape[2:]
-        lower_res_feature = fluid.layers.resize_bilinear(lower_res_feature, [h, w], align_mode=0)
+        lower_res_feature = fluid.layers.resize_bilinear(
+            lower_res_feature, [h, w], align_mode=0)

         with scope('dwconv'):
-            lower_res_feature = relu(bn(conv(lower_res_feature, self.out_channels, 1)))#(lower_res_feature)
+            lower_res_feature = relu(
+                bn(conv(lower_res_feature, self.out_channels,
+                        1)))  #(lower_res_feature)
         with scope('conv_lower_res'):
-            lower_res_feature = bn(conv(lower_res_feature, self.out_channels, 1, bias_attr=True))
+            lower_res_feature = bn(
+                conv(lower_res_feature, self.out_channels, 1, bias_attr=True))
         with scope('conv_higher_res'):
-            higher_res_feature = bn(conv(higher_res_feature, self.out_channels, 1, bias_attr=True))
+            higher_res_feature = bn(
+                conv(higher_res_feature, self.out_channels, 1, bias_attr=True))
         out = higher_res_feature + lower_res_feature
         return relu(out)
@@ -182,8 +201,12 @@ class FeatureFusionModule:
 class GlobalFeatureExtractor():
     """Global feature extractor module"""

-    def __init__(self, in_channels=64, block_channels=(64, 96, 128), out_channels=128,
-                 t=6, num_blocks=(3, 3, 3)):
+    def __init__(self,
+                 in_channels=64,
+                 block_channels=(64, 96, 128),
+                 out_channels=128,
+                 t=6,
+                 num_blocks=(3, 3, 3)):
         self.in_channels = in_channels
         self.block_channels = block_channels
         self.out_channels = out_channels
@@ -191,12 +214,15 @@ class GlobalFeatureExtractor():
         self.num_blocks = num_blocks

     def net(self, x):
-        x, _ = inverted_blocks(x, self.in_channels, self.t, self.block_channels[0],
-                               self.num_blocks[0], 2, 'inverted_block_1')
-        x, _ = inverted_blocks(x, self.block_channels[0], self.t, self.block_channels[1],
-                               self.num_blocks[1], 2, 'inverted_block_2')
-        x, _ = inverted_blocks(x, self.block_channels[1], self.t, self.block_channels[2],
-                               self.num_blocks[2], 1, 'inverted_block_3')
+        x, _ = inverted_blocks(x, self.in_channels, self.t,
+                               self.block_channels[0], self.num_blocks[0], 2,
+                               'inverted_block_1')
+        x, _ = inverted_blocks(x, self.block_channels[0], self.t,
+                               self.block_channels[1], self.num_blocks[1], 2,
+                               'inverted_block_2')
+        x, _ = inverted_blocks(x, self.block_channels[1], self.t,
+                               self.block_channels[2], self.num_blocks[2], 1,
+                               'inverted_block_3')
         x = psp_module(x, self.block_channels[2] // 4)
         with scope('out'):
             x = relu(bn(conv(x, self.out_channels, 1)))
@@ -213,10 +239,21 @@ class Classifier:
     def net(self, x):
         with scope('dsconv1'):
-            x = separate_conv(x, self.dw_channels, stride=self.stride, filter=3, act=fluid.layers.relu)
+            x = separate_conv(
+                x,
+                self.dw_channels,
+                stride=self.stride,
+                filter=3,
+                act=fluid.layers.relu)
         with scope('dsconv2'):
-            x = separate_conv(x, self.dw_channels, stride=self.stride, filter=3, act=fluid.layers.relu)
-        x = dropout2d(x, 0.1, is_train=cfg.PHASE=='train')
+            x = separate_conv(
+                x,
+                self.dw_channels,
+                stride=self.stride,
+                filter=3,
+                act=fluid.layers.relu)
+        x = dropout2d(x, 0.1, is_train=cfg.PHASE == 'train')
         x = conv(x, self.num_classes, 1, bias_attr=True)
         return x
@@ -233,7 +270,8 @@ def fast_scnn(img, num_classes):
     size = img.shape[2:]
     classifier = Classifier(128, num_classes)
-    global_feature_extractor = GlobalFeatureExtractor(64, [64, 96, 128], 128, 6, [3, 3, 3])
+    global_feature_extractor = GlobalFeatureExtractor(64, [64, 96, 128], 128, 6,
+                                                      [3, 3, 3])
     feature_fusion = FeatureFusionModule(64, 128, 128)

     with scope('learning_to_downsample'):
@@ -249,15 +287,18 @@ def fast_scnn(img, num_classes):
     if len(cfg.MODEL.MULTI_LOSS_WEIGHT) == 3:
         with scope('aux_layer_higher'):
             higher_logit = aux_layer(higher_res_features, num_classes)
-        higher_logit = fluid.layers.resize_bilinear(higher_logit, size, align_mode=0)
+        higher_logit = fluid.layers.resize_bilinear(
+            higher_logit, size, align_mode=0)
         with scope('aux_layer_lower'):
             lower_logit = aux_layer(lower_res_feature, num_classes)
-        lower_logit = fluid.layers.resize_bilinear(lower_logit, size, align_mode=0)
+        lower_logit = fluid.layers.resize_bilinear(
+            lower_logit, size, align_mode=0)
         return logit, higher_logit, lower_logit
     elif len(cfg.MODEL.MULTI_LOSS_WEIGHT) == 2:
         with scope('aux_layer_higher'):
             higher_logit = aux_layer(higher_res_features, num_classes)
-        higher_logit = fluid.layers.resize_bilinear(higher_logit, size, align_mode=0)
+        higher_logit = fluid.layers.resize_bilinear(
+            higher_logit, size, align_mode=0)
         return logit, higher_logit
     return logit