From 8f7d0b18145c8310fb36ad017aa0f1d4bce98c65 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Tue, 12 Dec 2017 15:54:38 +0800 Subject: [PATCH] add param_attr for nets (#6509) --- python/paddle/v2/fluid/layers.py | 6 ++++-- python/paddle/v2/fluid/nets.py | 7 +++++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/python/paddle/v2/fluid/layers.py b/python/paddle/v2/fluid/layers.py index fd8a2ed18c9..1f45487902a 100644 --- a/python/paddle/v2/fluid/layers.py +++ b/python/paddle/v2/fluid/layers.py @@ -1732,8 +1732,10 @@ def conv2d_transpose(input, h_in = input.shape[2] w_in = input.shape[3] - filter_size_h = output_size[0] - (h_in - 1) * stride[0] + 2 * padding[0] - filter_size_w = output_size[1] - (w_in - 1) * stride[1] + 2 * padding[1] + filter_size_h = output_size[0] - \ + (h_in - 1) * stride[0] + 2 * padding[0] + filter_size_w = output_size[1] - \ + (w_in - 1) * stride[1] + 2 * padding[1] filter_size = [filter_size_h, filter_size_w] elif isinstance(filter_size, int): filter_size = [filter_size, filter_size] diff --git a/python/paddle/v2/fluid/nets.py b/python/paddle/v2/fluid/nets.py index 05728ad75a5..7ef524318e6 100644 --- a/python/paddle/v2/fluid/nets.py +++ b/python/paddle/v2/fluid/nets.py @@ -9,6 +9,7 @@ def simple_img_conv_pool(input, pool_size, pool_stride, act, + param_attr=None, pool_type='max', main_program=None, startup_program=None): @@ -16,6 +17,7 @@ def simple_img_conv_pool(input, input=input, num_filters=num_filters, filter_size=filter_size, + param_attr=param_attr, act=act, main_program=main_program, startup_program=startup_program) @@ -36,6 +38,7 @@ def img_conv_group(input, conv_padding=1, conv_filter_size=3, conv_act=None, + param_attr=None, conv_with_batchnorm=False, conv_batchnorm_drop_rate=None, pool_stride=1, @@ -57,6 +60,7 @@ def img_conv_group(input, conv_padding = __extend_list__(conv_padding) conv_filter_size = __extend_list__(conv_filter_size) + param_attr = __extend_list__(param_attr) conv_with_batchnorm = __extend_list__(conv_with_batchnorm) conv_batchnorm_drop_rate = __extend_list__(conv_batchnorm_drop_rate) @@ -70,6 +74,7 @@ def img_conv_group(input, num_filters=conv_num_filter[i], filter_size=conv_filter_size[i], padding=conv_padding[i], + param_attr=param_attr[i], act=local_conv_act, main_program=main_program, startup_program=startup_program) @@ -101,6 +106,7 @@ def img_conv_group(input, def sequence_conv_pool(input, num_filters, filter_size, + param_attr=None, act="sigmoid", pool_type="max", main_program=None, @@ -109,6 +115,7 @@ def sequence_conv_pool(input, input=input, num_filters=num_filters, filter_size=filter_size, + param_attr=param_attr, act=act, main_program=main_program, startup_program=startup_program) -- GitLab