提交 c5277630 编写于 作者: M mir-of

fix conv2d_layer

上级 49dc6b01
......@@ -17,7 +17,14 @@ def conv2d_layer(
weight_initializer=flow.random_uniform_initializer(),
bias_initializer=flow.constant_initializer(),
):
weight_shape = (filters, input.shape[1], kernel_size, kernel_size)
if isinstance(kernel_size, int):
kernel_size_1 = kernel_size
kernel_size_2 = kernel_size
if isinstance(kernel_size, list):
kernel_size_1 = kernel_size[0]
kernel_size_2 = kernel_size[1]
weight_shape = (filters, input.shape[1], kernel_size_1, kernel_size_2)
weight = flow.get_variable(
name + "-weight",
shape=weight_shape,
......@@ -43,3 +50,43 @@ def conv2d_layer(
raise NotImplementedError
return output
def conv2d_layer_with_bn(
name,
input,
filters,
kernel_size=3,
strides=1,
padding="SAME",
data_format="NCHW",
dilation_rate=1,
activation="Relu",
use_bias=True,
weight_initializer=flow.random_uniform_initializer(),
bias_initializer=flow.constant_initializer(),
use_bn=True,
):
output = conv2d_layer(name=name,
input=input,
filters=filters,
kernel_size=kernel_size,
strides=strides,
padding=padding,
data_format=data_format,
dilation_rate=dilation_rate,
activation=activation,
use_bias=use_bias,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer)
if use_bn:
output = flow.layers.batch_normalization(inputs=output,
axis=1,
momentum=0.997,
epsilon=1.001e-5,
center=True,
scale=True,
trainable=True,
name=name + "_bn")
return output
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册