diff --git a/cnn_e2e/model_util.py b/cnn_e2e/model_util.py index 8cc4c917d0efa351ca7e605783743f936abbcd8c..28e05ee9a4fffc03be08604b8571ffd7c903d2d0 100644 --- a/cnn_e2e/model_util.py +++ b/cnn_e2e/model_util.py @@ -17,7 +17,14 @@ def conv2d_layer( weight_initializer=flow.random_uniform_initializer(), bias_initializer=flow.constant_initializer(), ): - weight_shape = (filters, input.shape[1], kernel_size, kernel_size) + if isinstance(kernel_size, int): + kernel_size_1 = kernel_size + kernel_size_2 = kernel_size + if isinstance(kernel_size, list): + kernel_size_1 = kernel_size[0] + kernel_size_2 = kernel_size[1] + + weight_shape = (filters, input.shape[1], kernel_size_1, kernel_size_2) weight = flow.get_variable( name + "-weight", shape=weight_shape, @@ -43,3 +50,43 @@ def conv2d_layer( raise NotImplementedError return output + + +def conv2d_layer_with_bn( + name, + input, + filters, + kernel_size=3, + strides=1, + padding="SAME", + data_format="NCHW", + dilation_rate=1, + activation="Relu", + use_bias=True, + weight_initializer=flow.random_uniform_initializer(), + bias_initializer=flow.constant_initializer(), + use_bn=True, +): + output = conv2d_layer(name=name, + input=input, + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + activation=activation, + use_bias=use_bias, + weight_initializer=weight_initializer, + bias_initializer=bias_initializer) + + if use_bn: + output = flow.layers.batch_normalization(inputs=output, + axis=1, + momentum=0.997, + epsilon=1.001e-5, + center=True, + scale=True, + trainable=True, + name=name + "_bn") + return output