diff --git a/paddle/operators/accuracy_op.cu b/paddle/operators/accuracy_op.cu index 539a93530206c93a37791a9ccb2fb104af17f940..dd51aad105fecf4e3118f03e2f1868abb5523bc8 100644 --- a/paddle/operators/accuracy_op.cu +++ b/paddle/operators/accuracy_op.cu @@ -26,7 +26,7 @@ template __global__ void AccuracyCudaKernel(const int N, const int D, const int64_t* Xdata, const int64_t* labeldata, int* correct_data, - float* accuracy) { + float* accuracy, int* total_data) { int count = 0; __shared__ int total[BlockSize]; @@ -47,6 +47,7 @@ __global__ void AccuracyCudaKernel(const int N, const int D, if (threadIdx.x == 0) { *correct_data = result; *accuracy = static_cast(result) / static_cast(N); + *total_data = N; } } @@ -80,22 +81,11 @@ class AccuracyOpCUDAKernel : public framework::OpKernel { if (num_samples == 0) { return; } - platform::GpuMemcpyAsync(total_data, &num_samples, sizeof(int), - cudaMemcpyHostToDevice, stream); AccuracyCudaKernel< PADDLE_CUDA_NUM_THREADS><<<1, PADDLE_CUDA_NUM_THREADS, 0, stream>>>( num_samples, infer_width, indices_data, label_data, correct_data, - accuracy_data); - - int d_num_samples, d_num_correct; - float d_accuracy; - platform::GpuMemcpyAsync(&d_num_correct, correct_data, sizeof(int), - cudaMemcpyDeviceToHost, stream); - platform::GpuMemcpyAsync(&d_num_samples, total_data, sizeof(int), - cudaMemcpyDeviceToHost, stream); - platform::GpuMemcpyAsync(&d_accuracy, accuracy_data, sizeof(float), - cudaMemcpyDeviceToHost, stream); + accuracy_data, total_data); } }; diff --git a/paddle/operators/batch_norm_op.md b/paddle/operators/op_documentation/batch_norm_op.md similarity index 100% rename from paddle/operators/batch_norm_op.md rename to paddle/operators/op_documentation/batch_norm_op.md diff --git a/paddle/operators/name_convention.md b/paddle/operators/op_documentation/name_convention.md similarity index 100% rename from paddle/operators/name_convention.md rename to paddle/operators/op_documentation/name_convention.md diff --git a/paddle/operators/net_op_design.md b/paddle/operators/op_documentation/net_op_design.md similarity index 100% rename from paddle/operators/net_op_design.md rename to paddle/operators/op_documentation/net_op_design.md diff --git a/paddle/operators/op_documentation/op_markdown_format.md b/paddle/operators/op_documentation/op_markdown_format.md new file mode 100644 index 0000000000000000000000000000000000000000..0ee804d592252c727622cbe59b0644813db3c4fd --- /dev/null +++ b/paddle/operators/op_documentation/op_markdown_format.md @@ -0,0 +1,64 @@ +# Standard Markdown Format for Operators +The following should be the standard format for documentation for all the operators that will get rendered in the `html`: + +``` +Operator Name (In PaddlePaddle) + +Operator Name (Standard) + +Operator description. + +LaTeX equation of how the operator performs an update. + +The signature of the operator. +``` + +Each section mentioned above has been covered in further detail in the rest of the document. + +# PaddlePaddle Operator Name +This should be in all small letters, in case of multiple words, we separate them with an underscore. For example: +`array to lod tensor` should be written as `array_to_lod_tensor`. + +This naming convention should be standard across all PaddlePaddle operators. + +# Standard Operator Name +This is the standard name of the operator as used in the community. The general standard is usually: +- Standard abbreviations like `SGD` are written in all capital letters. +- Operator names that have multiple words inside a single word use `camelCase` (capitalize word boundaries inside of a word). +- Keep numbers inside a word as is, with no boundary delimiters. +- Follow the name of the operator with the keyword: `Activation Operator.` + +# Operator description +This section should contain the description of what the operator does, including the operation performed, the literature from where it comes and was introduced first, and other important details. The relevant paper/article including the hyperlink should be cited in this section. + +# LaTeX equation +This section should contain an overall equation of the update or operation that the operator performs. The variables used in the equation should follow the naming convention of operators as described [here](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/name_convention.md). Two words in the same word should be separated by an underscore (`_`). + +# The signature +This section describes the signature of the operator. A list of Inputs and Outputs, each of which have a small description of what the variable represents and the type of variable. The variable names follow the `CamelCase` naming convention. The proposed format for this is: +`Section : +VariableName : (VariableType) VariableDescription +... +... +` + + +The following example for an `sgd` operator covers the above mentioned sections as they would ideally look like in the `html`: + +``` +sgd + +SGD operator + +This operator implements one step of the stochastic gradient descent algorithm. + +param_out = param_learning_rate * grad + +Inputs: +Param : (Tensor) Input parameter +LearningRate : (Tensor) Learning rate of SGD +Grad : (Tensor) Input gradient + +Outputs: +ParamOut : (Tensor) Output parameter +``` diff --git a/paddle/operators/rnn_design.md b/paddle/operators/op_documentation/rnn_design.md similarity index 100% rename from paddle/operators/rnn_design.md rename to paddle/operators/op_documentation/rnn_design.md diff --git a/python/paddle/v2/fluid/layers/nn.py b/python/paddle/v2/fluid/layers/nn.py index 51da00f5658685811fa19740381eb15ae13af347..538a0e6f6ed5786cd08cc89ae9b0b6e10e8f4819 100644 --- a/python/paddle/v2/fluid/layers/nn.py +++ b/python/paddle/v2/fluid/layers/nn.py @@ -25,35 +25,7 @@ def fc(input, act=None, name=None): """ - Fully Connected Layer. - - Args: - input: The input tensor(s) to the fully connected layer. - size: The number of output units in the fully connected layer. - num_flatten_dims: The fc layer can accept an input tensor with more than - two dimensions. If this happens, the multidimensional - tensor will first be flattened into a 2-dimensional - matrix. The parameter `num_flatten_dims` determines - how the input tensor is flattened: the first - `num_flatten_dims` dimensions will be flatten to form - the first dimension of the final matrix (height of the - matrix), and the rest `rank(X) - num_col_dims` - dimensions are flattened to form the second dimension - of the final matrix (width of the matrix). For example, - suppose `X` is a 6-dimensional tensor with a shape - [2, 3, 4, 5, 6], and `x_num_col_dims` = 3. Then, the - flattened matrix will have a shape [2 x 3 x 4, 5 x 6] - = [24, 30]. By default, `x_num_col_dims` is set to 1. - param_attr: The parameter attribute for learnable parameters/weights of - the fully connected Layer. - param_initializer: The initializer used for the weight/parameter. - If set None, XavierInitializer() will be used. - bias_attr: The parameter attribute for the bias parameter for this layer. - If set None, no bias will be added to the output units. - bias_initializer: The initializer used for the bias. If set None, - then ConstantInitializer() will be used. - act: Activation to be applied to the output of the fully connected layer. - name: Name/alias of the fully connected layer. + **Fully Connected Layer** The fully connected layer can take multiple tensors as its inputs. It creates a variable (one for each input tensor) called weights for each input @@ -68,12 +40,64 @@ def fc(input, This process can be formulated as follows: .. math:: - Y = \sigma({\sum_{i=0}^{N-1}W_iX_i + b}) + Out = Act({\sum_{i=0}^{N-1}W_iX_i + b}) + + In the above equation: + + * :math:`N`: Number of the input. + * :math:`X_i`: The input tensor. + * :math:`W`: The weights created by this layer. + * :math:`b`: The bias parameter created by this layer (if needed). + * :math`Act`: The activation funtion. + * :math`Out`: The output tensor. + + Args: + input(Variable|list): The input tensor(s) to the fully connected layer. + size(int): The number of output units in the fully connected layer. + num_flatten_dims(int): The fc layer can accept an input tensor with more + than two dimensions. If this happens, the + multidimensional tensor will first be flattened + into a 2-dimensional matrix. The parameter + `num_flatten_dims` determines how the input tensor + is flattened: the first `num_flatten_dims` + dimensions will be flatten to form the first + dimension of the final matrix (height of the + matrix), and the rest `rank(X) - num_col_dims` + dimensions are flattened to form the second + dimension of the final matrix (width of the matrix). + For example, suppose `X` is a 6-dimensional tensor + with a shape [2, 3, 4, 5, 6], and + `x_num_col_dims` = 3. Then, the flattened matrix + will have a shape [2 x 3 x 4, 5 x 6] = [24, 30]. + By default, `x_num_col_dims` is set to 1. + param_attr(ParamAttr|list): The parameter attribute for learnable + parameters/weights of the fully connected + layer. + param_initializer(ParamAttr|list): The initializer used for the + weight/parameter. If set None, + XavierInitializer() will be used. + bias_attr(ParamAttr|list): The parameter attribute for the bias parameter + for this layer. If set None, no bias will be + added to the output units. + bias_initializer(ParamAttr|list): The initializer used for the bias. + If set None, then ConstantInitializer() + will be used. + act(str): Activation to be applied to the output of the fully connected + layer. + name(str): Name/alias of the fully connected layer. + + + Returns: + Variable: The output tensor variable. - where, :math:`N` is the number of input, :math:`X_i` is the input tensor, - :math:`W` is the weights created by this layer, :math:`b` is the bias - created by this layer (if needed), :math:`\sigma` is the activation funtion. + Raises: + ValueError: If rank of the input tensor is less than 2. + Examples: + .. code-block:: python + + data = fluid.layers.data(name="data", shape=[32, 32], dtype="float32") + fc = fluid.layers.fc(input=data, size=1000, act="tanh") """ helper = LayerHelper("fc", **locals()) @@ -115,23 +139,30 @@ def fc(input, def embedding(input, size, is_sparse=False, param_attr=None, dtype='float32'): """ - Embedding Layer. + **Embedding Layer** + + This layer is used to lookup a vector of IDs, provided by *input*, in a lookup table. + The result of this lookup is the embedding of each ID in the *input*. + + All the input variables are passed in as local variables to the LayerHelper + constructor. Args: - param_initializer: - input: The input to the function - size: The size of the layer - is_sparse: A flag that decleares whether the input is sparse - param_attr: Parameters for this layer - dtype: The type of data : float32, float_16, int etc + input(Variable): Input to the function + size(int): Output size + is_sparse(bool): Boolean flag that specifying whether the input is sparse + param_attr(ParamAttr): Parameters for this layer + dtype(np.dtype|core.DataType|str): The type of data : float32, float_16, int etc - This function can take in the input (which is a vector of IDs) and - performs a lookup in the lookup_table using these IDs, to result into - the embedding of each ID in the input. + Returns: + Variable: The tensor variable storing the embeddings of the \ + supplied inputs. - All the input variables of this function are passed in as local variables - to the LayerHelper constructor. + Examples: + .. code-block:: python + data = fluid.layers.data(name='ids', shape=[32, 32], dtype='float32') + fc = fluid.layers.embedding(input=data, size=16) """ helper = LayerHelper('embedding', **locals())