未验证 提交 8f7d0b18 编写于 作者: F fengjiayi 提交者: GitHub

add param_attr for nets (#6509)

上级 c8d4efb2
...@@ -1732,8 +1732,10 @@ def conv2d_transpose(input, ...@@ -1732,8 +1732,10 @@ def conv2d_transpose(input,
h_in = input.shape[2] h_in = input.shape[2]
w_in = input.shape[3] w_in = input.shape[3]
filter_size_h = output_size[0] - (h_in - 1) * stride[0] + 2 * padding[0] filter_size_h = output_size[0] - \
filter_size_w = output_size[1] - (w_in - 1) * stride[1] + 2 * padding[1] (h_in - 1) * stride[0] + 2 * padding[0]
filter_size_w = output_size[1] - \
(w_in - 1) * stride[1] + 2 * padding[1]
filter_size = [filter_size_h, filter_size_w] filter_size = [filter_size_h, filter_size_w]
elif isinstance(filter_size, int): elif isinstance(filter_size, int):
filter_size = [filter_size, filter_size] filter_size = [filter_size, filter_size]
......
...@@ -9,6 +9,7 @@ def simple_img_conv_pool(input, ...@@ -9,6 +9,7 @@ def simple_img_conv_pool(input,
pool_size, pool_size,
pool_stride, pool_stride,
act, act,
param_attr=None,
pool_type='max', pool_type='max',
main_program=None, main_program=None,
startup_program=None): startup_program=None):
...@@ -16,6 +17,7 @@ def simple_img_conv_pool(input, ...@@ -16,6 +17,7 @@ def simple_img_conv_pool(input,
input=input, input=input,
num_filters=num_filters, num_filters=num_filters,
filter_size=filter_size, filter_size=filter_size,
param_attr=param_attr,
act=act, act=act,
main_program=main_program, main_program=main_program,
startup_program=startup_program) startup_program=startup_program)
...@@ -36,6 +38,7 @@ def img_conv_group(input, ...@@ -36,6 +38,7 @@ def img_conv_group(input,
conv_padding=1, conv_padding=1,
conv_filter_size=3, conv_filter_size=3,
conv_act=None, conv_act=None,
param_attr=None,
conv_with_batchnorm=False, conv_with_batchnorm=False,
conv_batchnorm_drop_rate=None, conv_batchnorm_drop_rate=None,
pool_stride=1, pool_stride=1,
...@@ -57,6 +60,7 @@ def img_conv_group(input, ...@@ -57,6 +60,7 @@ def img_conv_group(input,
conv_padding = __extend_list__(conv_padding) conv_padding = __extend_list__(conv_padding)
conv_filter_size = __extend_list__(conv_filter_size) conv_filter_size = __extend_list__(conv_filter_size)
param_attr = __extend_list__(param_attr)
conv_with_batchnorm = __extend_list__(conv_with_batchnorm) conv_with_batchnorm = __extend_list__(conv_with_batchnorm)
conv_batchnorm_drop_rate = __extend_list__(conv_batchnorm_drop_rate) conv_batchnorm_drop_rate = __extend_list__(conv_batchnorm_drop_rate)
...@@ -70,6 +74,7 @@ def img_conv_group(input, ...@@ -70,6 +74,7 @@ def img_conv_group(input,
num_filters=conv_num_filter[i], num_filters=conv_num_filter[i],
filter_size=conv_filter_size[i], filter_size=conv_filter_size[i],
padding=conv_padding[i], padding=conv_padding[i],
param_attr=param_attr[i],
act=local_conv_act, act=local_conv_act,
main_program=main_program, main_program=main_program,
startup_program=startup_program) startup_program=startup_program)
...@@ -101,6 +106,7 @@ def img_conv_group(input, ...@@ -101,6 +106,7 @@ def img_conv_group(input,
def sequence_conv_pool(input, def sequence_conv_pool(input,
num_filters, num_filters,
filter_size, filter_size,
param_attr=None,
act="sigmoid", act="sigmoid",
pool_type="max", pool_type="max",
main_program=None, main_program=None,
...@@ -109,6 +115,7 @@ def sequence_conv_pool(input, ...@@ -109,6 +115,7 @@ def sequence_conv_pool(input,
input=input, input=input,
num_filters=num_filters, num_filters=num_filters,
filter_size=filter_size, filter_size=filter_size,
param_attr=param_attr,
act=act, act=act,
main_program=main_program, main_program=main_program,
startup_program=startup_program) startup_program=startup_program)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册