未验证 提交 8f7d0b18 编写于 作者: F fengjiayi 提交者: GitHub

add param_attr for nets (#6509)

上级 c8d4efb2
......@@ -1732,8 +1732,10 @@ def conv2d_transpose(input,
h_in = input.shape[2]
w_in = input.shape[3]
filter_size_h = output_size[0] - (h_in - 1) * stride[0] + 2 * padding[0]
filter_size_w = output_size[1] - (w_in - 1) * stride[1] + 2 * padding[1]
filter_size_h = output_size[0] - \
(h_in - 1) * stride[0] + 2 * padding[0]
filter_size_w = output_size[1] - \
(w_in - 1) * stride[1] + 2 * padding[1]
filter_size = [filter_size_h, filter_size_w]
elif isinstance(filter_size, int):
filter_size = [filter_size, filter_size]
......
......@@ -9,6 +9,7 @@ def simple_img_conv_pool(input,
pool_size,
pool_stride,
act,
param_attr=None,
pool_type='max',
main_program=None,
startup_program=None):
......@@ -16,6 +17,7 @@ def simple_img_conv_pool(input,
input=input,
num_filters=num_filters,
filter_size=filter_size,
param_attr=param_attr,
act=act,
main_program=main_program,
startup_program=startup_program)
......@@ -36,6 +38,7 @@ def img_conv_group(input,
conv_padding=1,
conv_filter_size=3,
conv_act=None,
param_attr=None,
conv_with_batchnorm=False,
conv_batchnorm_drop_rate=None,
pool_stride=1,
......@@ -57,6 +60,7 @@ def img_conv_group(input,
conv_padding = __extend_list__(conv_padding)
conv_filter_size = __extend_list__(conv_filter_size)
param_attr = __extend_list__(param_attr)
conv_with_batchnorm = __extend_list__(conv_with_batchnorm)
conv_batchnorm_drop_rate = __extend_list__(conv_batchnorm_drop_rate)
......@@ -70,6 +74,7 @@ def img_conv_group(input,
num_filters=conv_num_filter[i],
filter_size=conv_filter_size[i],
padding=conv_padding[i],
param_attr=param_attr[i],
act=local_conv_act,
main_program=main_program,
startup_program=startup_program)
......@@ -101,6 +106,7 @@ def img_conv_group(input,
def sequence_conv_pool(input,
num_filters,
filter_size,
param_attr=None,
act="sigmoid",
pool_type="max",
main_program=None,
......@@ -109,6 +115,7 @@ def sequence_conv_pool(input,
input=input,
num_filters=num_filters,
filter_size=filter_size,
param_attr=param_attr,
act=act,
main_program=main_program,
startup_program=startup_program)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册