From d714851927ee2b831606624febb8374df181c980 Mon Sep 17 00:00:00 2001 From: wuyefeilin <30919197+wuyefeilin@users.noreply.github.com> Date: Thu, 21 May 2020 17:04:22 +0800 Subject: [PATCH] update *_batch_size_like ops (#258) --- contrib/LaneNet/loss.py | 59 ++++++++------- pdseg/models/model_builder.py | 1 + pdseg/models/modeling/fast_scnn.py | 113 ++++++++++++++++++++--------- 3 files changed, 107 insertions(+), 66 deletions(-) diff --git a/contrib/LaneNet/loss.py b/contrib/LaneNet/loss.py index e8883745..e7aa1a4d 100644 --- a/contrib/LaneNet/loss.py +++ b/contrib/LaneNet/loss.py @@ -19,8 +19,9 @@ from utils.config import cfg def unsorted_segment_sum(data, segment_ids, unique_labels, feature_dims): - zeros = fluid.layers.fill_constant_batch_size_like(unique_labels, shape=[1, feature_dims], - dtype='float32', value=0) + unique_labels_shape = fluid.layers.shape(unique_labels) + zeros = fluid.layers.fill_constant( + shape=[unique_labels_shape[0], feature_dims], dtype='float32', value=0) segment_ids = fluid.layers.unsqueeze(segment_ids, axes=[1]) segment_ids.stop_gradient = True segment_sum = fluid.layers.scatter_nd_add(zeros, segment_ids, data) @@ -30,29 +31,23 @@ def unsorted_segment_sum(data, segment_ids, unique_labels, feature_dims): def norm(x, axis=-1): - distance = fluid.layers.reduce_sum(fluid.layers.abs(x), dim=axis, keep_dim=True) + distance = fluid.layers.reduce_sum( + fluid.layers.abs(x), dim=axis, keep_dim=True) return distance -def discriminative_loss_single( - prediction, - correct_label, - feature_dim, - label_shape, - delta_v, - delta_d, - param_var, - param_dist, - param_reg): - - correct_label = fluid.layers.reshape( - correct_label, [ - label_shape[1] * label_shape[0]]) + +def discriminative_loss_single(prediction, correct_label, feature_dim, + label_shape, delta_v, delta_d, param_var, + param_dist, param_reg): + + correct_label = fluid.layers.reshape(correct_label, + [label_shape[1] * label_shape[0]]) prediction = fluid.layers.transpose(prediction, [1, 2, 0]) reshaped_pred = fluid.layers.reshape( - prediction, [ - label_shape[1] * label_shape[0], feature_dim]) + prediction, [label_shape[1] * label_shape[0], feature_dim]) - unique_labels, unique_id, counts = fluid.layers.unique_with_counts(correct_label) + unique_labels, unique_id, counts = fluid.layers.unique_with_counts( + correct_label) correct_label.stop_gradient = True counts = fluid.layers.cast(counts, 'float32') num_instances = fluid.layers.shape(unique_labels) @@ -69,24 +64,29 @@ def discriminative_loss_single( distance = norm(tmp) distance = distance - delta_v - distance_pos = fluid.layers.greater_equal(distance, fluid.layers.zeros_like(distance)) + distance_pos = fluid.layers.greater_equal(distance, + fluid.layers.zeros_like(distance)) distance_pos = fluid.layers.cast(distance_pos, 'float32') distance = distance * distance_pos distance = fluid.layers.square(distance) - l_var = unsorted_segment_sum(distance, unique_id, unique_labels, feature_dims=1) + l_var = unsorted_segment_sum( + distance, unique_id, unique_labels, feature_dims=1) l_var = fluid.layers.elementwise_div(l_var, counts_rsp) l_var = fluid.layers.reduce_sum(l_var) - l_var = l_var / fluid.layers.cast(num_instances * (num_instances - 1), 'float32') + l_var = l_var / fluid.layers.cast(num_instances * (num_instances - 1), + 'float32') mu_interleaved_rep = fluid.layers.expand(mu, [num_instances, 1]) mu_band_rep = fluid.layers.expand(mu, [1, num_instances]) - mu_band_rep = fluid.layers.reshape(mu_band_rep, (num_instances * num_instances, feature_dim)) + mu_band_rep = fluid.layers.reshape( + mu_band_rep, (num_instances * num_instances, feature_dim)) mu_diff = fluid.layers.elementwise_sub(mu_band_rep, mu_interleaved_rep) - intermediate_tensor = fluid.layers.reduce_sum(fluid.layers.abs(mu_diff), dim=1) + intermediate_tensor = fluid.layers.reduce_sum( + fluid.layers.abs(mu_diff), dim=1) intermediate_tensor.stop_gradient = True zero_vector = fluid.layers.zeros([1], 'float32') bool_mask = fluid.layers.not_equal(intermediate_tensor, zero_vector) @@ -95,7 +95,8 @@ def discriminative_loss_single( mu_norm = norm(mu_diff_bool) mu_norm = 2. * delta_d - mu_norm - mu_norm_pos = fluid.layers.greater_equal(mu_norm, fluid.layers.zeros_like(mu_norm)) + mu_norm_pos = fluid.layers.greater_equal(mu_norm, + fluid.layers.zeros_like(mu_norm)) mu_norm_pos = fluid.layers.cast(mu_norm_pos, 'float32') mu_norm = mu_norm * mu_norm_pos mu_norm_pos.stop_gradient = True @@ -122,8 +123,8 @@ def discriminative_loss(prediction, correct_label, feature_dim, image_shape, output_ta_reg = 0. for i in range(batch_size): disc_loss_single, l_var_single, l_dist_single, l_reg_single = discriminative_loss_single( - prediction[i], correct_label[i], feature_dim, image_shape, delta_v, delta_d, param_var, param_dist, - param_reg) + prediction[i], correct_label[i], feature_dim, image_shape, delta_v, + delta_d, param_var, param_dist, param_reg) output_ta_loss += disc_loss_single output_ta_var += l_var_single output_ta_dist += l_dist_single @@ -134,5 +135,3 @@ def discriminative_loss(prediction, correct_label, feature_dim, image_shape, l_dist = output_ta_dist / batch_size l_reg = output_ta_reg / batch_size return disc_loss, l_var, l_dist, l_reg - - diff --git a/pdseg/models/model_builder.py b/pdseg/models/model_builder.py index 86460224..77a444c7 100644 --- a/pdseg/models/model_builder.py +++ b/pdseg/models/model_builder.py @@ -223,6 +223,7 @@ def build_model(main_prog, start_prog, phase=ModelPhase.TRAIN): raise Exception( "softmax loss or lovasz softmax loss can not combine with bce loss or dice loss or lovasz hinge loss." ) + cfg.PHASE = phase logits = seg_model(image, class_num) # 根据选择的loss函数计算相应的损失函数 diff --git a/pdseg/models/modeling/fast_scnn.py b/pdseg/models/modeling/fast_scnn.py index b1ecdffe..3d8301cd 100644 --- a/pdseg/models/modeling/fast_scnn.py +++ b/pdseg/models/modeling/fast_scnn.py @@ -25,12 +25,15 @@ from models.libs.model_libs import separate_conv from utils.config import cfg -def learning_to_downsample(x, dw_channels1=32, dw_channels2=48, out_channels=64): +def learning_to_downsample(x, dw_channels1=32, dw_channels2=48, + out_channels=64): x = relu(bn(conv(x, dw_channels1, 3, 2))) with scope('dsconv1'): - x = separate_conv(x, dw_channels2, stride=2, filter=3, act=fluid.layers.relu) + x = separate_conv( + x, dw_channels2, stride=2, filter=3, act=fluid.layers.relu) with scope('dsconv2'): - x = separate_conv(x, out_channels, stride=2, filter=3, act=fluid.layers.relu) + x = separate_conv( + x, out_channels, stride=2, filter=3, act=fluid.layers.relu) return x @@ -43,7 +46,9 @@ def dropout2d(input, prob, is_train=False): return input channels = input.shape[1] keep_prob = 1.0 - prob - random_tensor = keep_prob + fluid.layers.uniform_random_batch_size_like(input, [-1, channels, 1, 1], min=0., max=1.) + shape = fluid.layers.shape(input) + random_tensor = keep_prob + fluid.layers.uniform_random( + [shape[0], channels, 1, 1], min=0., max=1.) binary_tensor = fluid.layers.floor(random_tensor) output = input / keep_prob * binary_tensor return output @@ -136,18 +141,23 @@ def psp_module(input, out_features): for size in sizes: psp_name = "psp" + str(size) with scope(psp_name): - pool = fluid.layers.adaptive_pool2d(input, - pool_size=[size, size], - pool_type='avg', - name=psp_name + '_adapool') - data = conv(pool, out_features, - filter_size=1, - bias_attr=False, - name=psp_name + '_conv') + pool = fluid.layers.adaptive_pool2d( + input, + pool_size=[size, size], + pool_type='avg', + name=psp_name + '_adapool') + data = conv( + pool, + out_features, + filter_size=1, + bias_attr=False, + name=psp_name + '_conv') data_bn = bn(data, act='relu') - interp = fluid.layers.resize_bilinear(data_bn, - out_shape=input.shape[2:], - name=psp_name + '_interp', align_mode=0) + interp = fluid.layers.resize_bilinear( + data_bn, + out_shape=input.shape[2:], + name=psp_name + '_interp', + align_mode=0) cat_layers.append(interp) cat_layers = [input] + cat_layers out = fluid.layers.concat(cat_layers, axis=1, name='psp_cat') @@ -158,7 +168,11 @@ def psp_module(input, out_features): class FeatureFusionModule: """Feature fusion module""" - def __init__(self, higher_in_channels, lower_in_channels, out_channels, scale_factor=4): + def __init__(self, + higher_in_channels, + lower_in_channels, + out_channels, + scale_factor=4): self.higher_in_channels = higher_in_channels self.lower_in_channels = lower_in_channels self.out_channels = out_channels @@ -166,14 +180,19 @@ class FeatureFusionModule: def net(self, higher_res_feature, lower_res_feature): h, w = higher_res_feature.shape[2:] - lower_res_feature = fluid.layers.resize_bilinear(lower_res_feature, [h, w], align_mode=0) + lower_res_feature = fluid.layers.resize_bilinear( + lower_res_feature, [h, w], align_mode=0) with scope('dwconv'): - lower_res_feature = relu(bn(conv(lower_res_feature, self.out_channels, 1)))#(lower_res_feature) + lower_res_feature = relu( + bn(conv(lower_res_feature, self.out_channels, + 1))) #(lower_res_feature) with scope('conv_lower_res'): - lower_res_feature = bn(conv(lower_res_feature, self.out_channels, 1, bias_attr=True)) + lower_res_feature = bn( + conv(lower_res_feature, self.out_channels, 1, bias_attr=True)) with scope('conv_higher_res'): - higher_res_feature = bn(conv(higher_res_feature, self.out_channels, 1, bias_attr=True)) + higher_res_feature = bn( + conv(higher_res_feature, self.out_channels, 1, bias_attr=True)) out = higher_res_feature + lower_res_feature return relu(out) @@ -182,8 +201,12 @@ class FeatureFusionModule: class GlobalFeatureExtractor(): """Global feature extractor module""" - def __init__(self, in_channels=64, block_channels=(64, 96, 128), out_channels=128, - t=6, num_blocks=(3, 3, 3)): + def __init__(self, + in_channels=64, + block_channels=(64, 96, 128), + out_channels=128, + t=6, + num_blocks=(3, 3, 3)): self.in_channels = in_channels self.block_channels = block_channels self.out_channels = out_channels @@ -191,12 +214,15 @@ class GlobalFeatureExtractor(): self.num_blocks = num_blocks def net(self, x): - x, _ = inverted_blocks(x, self.in_channels, self.t, self.block_channels[0], - self.num_blocks[0], 2, 'inverted_block_1') - x, _ = inverted_blocks(x, self.block_channels[0], self.t, self.block_channels[1], - self.num_blocks[1], 2, 'inverted_block_2') - x, _ = inverted_blocks(x, self.block_channels[1], self.t, self.block_channels[2], - self.num_blocks[2], 1, 'inverted_block_3') + x, _ = inverted_blocks(x, self.in_channels, self.t, + self.block_channels[0], self.num_blocks[0], 2, + 'inverted_block_1') + x, _ = inverted_blocks(x, self.block_channels[0], self.t, + self.block_channels[1], self.num_blocks[1], 2, + 'inverted_block_2') + x, _ = inverted_blocks(x, self.block_channels[1], self.t, + self.block_channels[2], self.num_blocks[2], 1, + 'inverted_block_3') x = psp_module(x, self.block_channels[2] // 4) with scope('out'): x = relu(bn(conv(x, self.out_channels, 1))) @@ -213,10 +239,21 @@ class Classifier: def net(self, x): with scope('dsconv1'): - x = separate_conv(x, self.dw_channels, stride=self.stride, filter=3, act=fluid.layers.relu) + x = separate_conv( + x, + self.dw_channels, + stride=self.stride, + filter=3, + act=fluid.layers.relu) with scope('dsconv2'): - x = separate_conv(x, self.dw_channels, stride=self.stride, filter=3, act=fluid.layers.relu) - x = dropout2d(x, 0.1, is_train=cfg.PHASE=='train') + x = separate_conv( + x, + self.dw_channels, + stride=self.stride, + filter=3, + act=fluid.layers.relu) + + x = dropout2d(x, 0.1, is_train=cfg.PHASE == 'train') x = conv(x, self.num_classes, 1, bias_attr=True) return x @@ -233,7 +270,8 @@ def fast_scnn(img, num_classes): size = img.shape[2:] classifier = Classifier(128, num_classes) - global_feature_extractor = GlobalFeatureExtractor(64, [64, 96, 128], 128, 6, [3, 3, 3]) + global_feature_extractor = GlobalFeatureExtractor(64, [64, 96, 128], 128, 6, + [3, 3, 3]) feature_fusion = FeatureFusionModule(64, 128, 128) with scope('learning_to_downsample'): @@ -249,15 +287,18 @@ def fast_scnn(img, num_classes): if len(cfg.MODEL.MULTI_LOSS_WEIGHT) == 3: with scope('aux_layer_higher'): higher_logit = aux_layer(higher_res_features, num_classes) - higher_logit = fluid.layers.resize_bilinear(higher_logit, size, align_mode=0) + higher_logit = fluid.layers.resize_bilinear( + higher_logit, size, align_mode=0) with scope('aux_layer_lower'): lower_logit = aux_layer(lower_res_feature, num_classes) - lower_logit = fluid.layers.resize_bilinear(lower_logit, size, align_mode=0) + lower_logit = fluid.layers.resize_bilinear( + lower_logit, size, align_mode=0) return logit, higher_logit, lower_logit elif len(cfg.MODEL.MULTI_LOSS_WEIGHT) == 2: with scope('aux_layer_higher'): higher_logit = aux_layer(higher_res_features, num_classes) - higher_logit = fluid.layers.resize_bilinear(higher_logit, size, align_mode=0) + higher_logit = fluid.layers.resize_bilinear( + higher_logit, size, align_mode=0) return logit, higher_logit - return logit \ No newline at end of file + return logit -- GitLab