From 5aa473f4676dc7322dc6f29b18aa2ae8d0b120f4 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Tue, 19 Jun 2018 10:25:51 +0800 Subject: [PATCH] add Doc fluid.net --- python/paddle/fluid/nets.py | 227 +++++++++++++++++++++++++++++++----- 1 file changed, 198 insertions(+), 29 deletions(-) diff --git a/python/paddle/fluid/nets.py b/python/paddle/fluid/nets.py index bbedf6fde08..9b3f2aebee7 100644 --- a/python/paddle/fluid/nets.py +++ b/python/paddle/fluid/nets.py @@ -26,16 +26,87 @@ def simple_img_conv_pool(input, filter_size, pool_size, pool_stride, - act, - param_attr=None, + pool_padding=0, pool_type='max', + global_pooling=False, + conv_stride=1, + conv_padding=0, + conv_dilation=1, + conv_groups=1, + param_attr=None, + bias_attr=None, + act=None, use_cudnn=True, use_mkldnn=False): + """ + The simple_img_conv_pool is composed with one Convolution2d and one Pool2d. + + Args: + input (Variable): The input image with [N, C, H, W] format. + num_filters(int): The number of filter. It is as same as the output + feature channel. + filter_size (int|list|tuple): The filter size. If filter_size is a list or + tuple, it must contain two integers, (filter_size_H, filter_size_W). Otherwise, + the filter_size_H = filter_size_W = filter_size. + pool_size (int|list|tuple): The pooling size of Pool2d layer. If pool_size + is a list or tuple, it must contain two integers, (pool_size_H, pool_size_W). + Otherwise, the pool_size_H = pool_size_W = pool_size. + pool_stride (int|list|tuple): The pooling stride of Pool2d layer. If pool_stride + is a list or tuple, it must contain two integers, (pooling_stride_H, pooling_stride_W). + Otherwise, the pooling_stride_H = pooling_stride_W = pool_stride. + pool_padding (int|list|tuple): The padding of Pool2d layer. If pool_padding is a list or + tuple, it must contain two integers, (pool_padding_H, pool_padding_W). + Otherwise, the pool_padding_H = pool_padding_W = pool_padding. Default 0. + pool_type (str): Pooling type can be :math:`max` for max-pooling and :math:`avg` for + average-pooling. Default :math:`max`. + global_pooling (bool): Whether to use the global pooling. If global_pooling = true, + pool_size and pool_padding while be ignored. Default False + conv_stride (int|list|tuple): The stride size of the Conv2d Layer. If stride is a + list or tuple, it must contain two integers, (conv_stride_H, conv_stride_W). Otherwise, + the conv_stride_H = conv_stride_W = conv_stride. Default: conv_stride = 1. + conv_padding (int|list|tuple): The padding size of the Conv2d Layer. If padding is + a list or tuple, it must contain two integers, (conv_padding_H, conv_padding_W). + Otherwise, the conv_padding_H = conv_padding_W = conv_padding. Default: conv_padding = 0. + conv_dilation (int|list|tuple): The dilation size of the Conv2d Layer. If dilation is + a list or tuple, it must contain two integers, (conv_dilation_H, conv_dilation_W). + Otherwise, the conv_dilation_H = conv_dilation_W = conv_dilation. Default: conv_dilation = 1. + conv_groups (int): The groups number of the Conv2d Layer. According to grouped + convolution in Alex Krizhevsky's Deep CNN paper: when group=2, + the first half of the filters is only connected to the first half + of the input channels, while the second half of the filters is only + connected to the second half of the input channels. Default: groups=1 + param_attr (ParamAttr): The parameters to the Conv2d Layer. Default: None + bias_attr (ParamAttr): Bias parameter for the Conv2d layer. Default: None + act (str): Activation type for Conv2d. Default: None + use_cudnn (bool): Use cudnn kernel or not, it is valid only when the cudnn + library is installed. Default: True + use_mkldnn (bool): Use mkldnn kernels or not, it is valid only when compiled + with mkldnn library. Default: False + + Return: + Variable: The result of input after Convolution2d and Pool2d. + + Examples: + .. code-block:: python + + img = fluid.layers.data(name='img', shape=[1, 28, 28], dtype='float32') + conv_pool = fluid.nets.simple_img_conv_pool(input=img, + filter_size=5, + num_filters=20, + pool_size=2, + pool_stride=2, + act="relu") + """ conv_out = layers.conv2d( input=input, num_filters=num_filters, filter_size=filter_size, + stride=conv_stride, + padding=conv_padding, + dilation=conv_dilation, + groups=conv_groups, param_attr=param_attr, + bias_attr=bias_attr, act=act, use_cudnn=use_cudnn, use_mkldnn=use_mkldnn) @@ -45,6 +116,8 @@ def simple_img_conv_pool(input, pool_size=pool_size, pool_type=pool_type, pool_stride=pool_stride, + pool_padding=pool_padding, + global_pooling=global_pooling, use_cudnn=use_cudnn, use_mkldnn=use_mkldnn) return pool_out @@ -60,11 +133,65 @@ def img_conv_group(input, conv_with_batchnorm=False, conv_batchnorm_drop_rate=0.0, pool_stride=1, - pool_type=None, + pool_type="max", use_cudnn=True, use_mkldnn=False): """ - Image Convolution Group, Used for vgg net. + The Image Convolution Group is composed of Convolution2d, BatchNorm, DropOut, + and Pool2d. According to the input arguments, img_conv_group will do serials of + computation for Input using Convolution2d, BatchNorm, DropOut, and pass the last + result to Pool2d. + + Args: + input (Variable): The input image with [N, C, H, W] format. + conv_num_filter(list|tuple): Indicates the numbers of filter of this group. + pool_size (int|list|tuple): The pooling size of Pool2d Layer. If pool_size + is a list or tuple, it must contain two integers, (pool_size_H, pool_size_W). + Otherwise, the pool_size_H = pool_size_W = pool_size. + conv_padding (int|list|tuple): The padding size of the Conv2d Layer. If padding is + a list or tuple, its length must be equal to the length of conv_num_filter. + Otherwise the conv_padding of all Conv2d Layers are the same. Default 1. + conv_filter_size (int|list|tuple): The filter size. If filter_size is a list or + tuple, its length must be equal to the length of conv_num_filter. + Otherwise the conv_filter_size of all Conv2d Layers are the same. Default 3. + conv_act (str): Activation type for Conv2d Layer that is not followed by BatchNorm. + Default: None. + param_attr (ParamAttr): The parameters to the Conv2d Layer. Default: None + conv_with_batchnorm (bool|list): Indicates whether to use BatchNorm after Conv2d Layer. + If conv_with_batchnorm is a list, its length must be equal to the length of + conv_num_filter. Otherwise, conv_with_batchnorm indicates whether all the + Conv2d Layer follows a BatchNorm. Default False. + conv_batchnorm_drop_rate (float|list): Indicates the drop_rate of Dropout Layer + after BatchNorm. If conv_batchnorm_drop_rate is a list, its length must be + equal to the length of conv_num_filter. Otherwise, drop_rate of all Dropout + Layers is conv_batchnorm_drop_rate. Default 0.0. + pool_stride (int|list|tuple): The pooling stride of Pool2d layer. If pool_stride + is a list or tuple, it must contain two integers, (pooling_stride_H, + pooling_stride_W). Otherwise, the pooling_stride_H = pooling_stride_W = pool_stride. + Default 1. + pool_type (str): Pooling type can be :math:`max` for max-pooling and :math:`avg` for + average-pooling. Default :math:`max`. + use_cudnn (bool): Use cudnn kernel or not, it is valid only when the cudnn + library is installed. Default: True + use_mkldnn (bool): Use mkldnn kernels or not, it is valid only when compiled + with mkldnn library. Default: False + + Return: + Variable: The final result after serial computation using Convolution2d, + BatchNorm, DropOut, and Pool2d. + + Examples: + .. code-block:: python + + img = fluid.layers.data(name='img', shape=[1, 28, 28], dtype='float32') + conv_pool = fluid.nets.img_conv_group(input=img, + num_channels=3, + conv_padding=1, + conv_num_filter=[3, 3], + conv_filter_size=3, + conv_act="relu", + pool_size=2, + pool_stride=2) """ tmp = input assert isinstance(conv_num_filter, list) or \ @@ -74,6 +201,7 @@ def img_conv_group(input, if not hasattr(obj, '__len__'): return [obj] * len(conv_num_filter) else: + assert len(obj) == len(conv_num_filter) return obj conv_padding = __extend_list__(conv_padding) @@ -119,6 +247,39 @@ def sequence_conv_pool(input, param_attr=None, act="sigmoid", pool_type="max"): + """ + The sequence_conv_pool is composed with Sequence Convolution and Pooling. + + Args: + input (Variable): The input of sequence_conv, which supports variable-time + length input sequence. The underlying of input is a matrix with shape + (T, N), where T is the total time steps in this mini-batch and N is + the input_hidden_size + num_filters(int): The number of filter. + filter_size (int): The filter size. + param_attr (ParamAttr): The parameters to the Sequence_conv Layer. Default: None. + act (str): Activation type for Sequence_conv Layer. Default: "sigmoid". + pool_type (str): Pooling type can be :math:`max` for max-pooling, :math:`average` for + average-pooling, :math:`sum` for sum-pooling, :math:`sqrt` for sqrt-pooling. + Default :math:`max`. + + Return: + Variable: The final result after Sequence Convolution and Pooling. + + Examples: + .. code-block:: python + + input_dim = len(word_dict) + emb_dim = 128 + hid_dim = 512 + data = fluid.layers.data( ame="words", shape=[1], dtype="int64", lod_level=1) + emb = fluid.layers.embedding(input=data, size=[input_dim, emb_dim], is_sparse=True) + seq_conv = fluid.nets.sequence_conv_pool(input=emb, + num_filters=hid_dim, + filter_size=3, + act="tanh", + pool_type="sqrt") + """ conv_out = layers.sequence_conv( input=input, num_filters=num_filters, @@ -132,9 +293,9 @@ def sequence_conv_pool(input, def glu(input, dim=-1): """ - The gated linear unit composed by split, sigmoid activation and elementwise - multiplication. Specifically, Split the input into two equal sized parts - :math:`a` and :math:`b` along the given dimension and then compute as + The Gated Linear Units(GLU) composed by split, sigmoid activation and element-wise + multiplication. Specifically, Split the input into two equal sized parts, + :math:`a` and :math:`b`, along the given dimension and then compute as following: .. math:: @@ -147,16 +308,16 @@ def glu(input, dim=-1): Args: input (Variable): The input variable which is a Tensor or LoDTensor. dim (int): The dimension along which to split. If :math:`dim < 0`, the - dimension to split along is :math:`rank(input) + dim`. + dimension to split along is :math:`rank(input) + dim`. Default -1. Returns: - Variable: The Tensor variable with half the size of input. + Variable: Variable with half the size of input. Examples: .. code-block:: python - # x is a Tensor variable with shape [3, 6, 9] - fluid.nets.glu(input=x, dim=1) # shape of output: [3, 3, 9] + data = fluid.layers.data(name="words", shape=[3, 6, 9], dtype="float32") + output = fluid.nets.glu(input=data, dim=1) # shape of output: [3, 3, 9] """ a, b = layers.split(input, num_or_sections=2, dim=dim) @@ -189,40 +350,48 @@ def scaled_dot_product_attention(queries, `_. Args: - queries (Variable): The input variable which should be a 3-D Tensor. keys (Variable): The input variable which should be a 3-D Tensor. values (Variable): The input variable which should be a 3-D Tensor. num_heads (int): Head number to compute the scaled dot product - attention. Default value is 1. + attention. Default: 1. dropout_rate (float): The dropout rate to drop the attention weight. - Default value is 0. + Default: 0.0. Returns: - - Variable: A 3-D Tensor computed by multi-head scaled dot product \ - attention. + Variable: A 3-D Tensor computed by multi-head scaled dot product\ + attention. Raises: - ValueError: If input queries, keys, values are not 3-D Tensors. - NOTE: + NOTES: 1. When num_heads > 1, three linear projections are learned respectively - to map input queries, keys and values into queries', keys' and values'. - queries', keys' and values' have the same shapes with queries, keys - and values. - - 1. When num_heads == 1, scaled_dot_product_attention has no learnable - parameters. + to map input queries, keys and values into queries', keys' and values'. + queries', keys' and values' have the same shapes with queries, keys + and values. + 2. When num_heads == 1, scaled_dot_product_attention has no learnable + parameters. Examples: .. code-block:: python - # Suppose q, k, v are Tensors with the following shape: - # q: [3, 5, 9], k: [3, 6, 9], v: [3, 6, 10] - - contexts = fluid.nets.scaled_dot_product_attention(q, k, v) + queries = fluid.layers.data(name="queries", + shape=[3, 5, 9], + dtype="float32", + append_batch_size=False) + queries.stop_gradient = False + keys = fluid.layers.data(name="keys", + shape=[3, 6, 9], + dtype="float32", + append_batch_size=False) + keys.stop_gradient = False + values = fluid.layers.data(name="values", + shape=[3, 6, 10], + dtype="float32", + append_batch_size=False) + values.stop_gradient = False + contexts = fluid.nets.scaled_dot_product_attention(queries, keys, values) contexts.shape # [3, 5, 10] """ if not (len(queries.shape) == len(keys.shape) == len(values.shape) == 3): -- GitLab