提交 ebe4425f 编写于 作者: C caoying03

Merge branch 'develop' into refine_doc

......@@ -26,7 +26,7 @@ template <int BlockSize>
__global__ void AccuracyCudaKernel(const int N, const int D,
const int64_t* Xdata,
const int64_t* labeldata, int* correct_data,
float* accuracy) {
float* accuracy, int* total_data) {
int count = 0;
__shared__ int total[BlockSize];
......@@ -47,6 +47,7 @@ __global__ void AccuracyCudaKernel(const int N, const int D,
if (threadIdx.x == 0) {
*correct_data = result;
*accuracy = static_cast<float>(result) / static_cast<float>(N);
*total_data = N;
}
}
......@@ -80,22 +81,11 @@ class AccuracyOpCUDAKernel : public framework::OpKernel<T> {
if (num_samples == 0) {
return;
}
platform::GpuMemcpyAsync(total_data, &num_samples, sizeof(int),
cudaMemcpyHostToDevice, stream);
AccuracyCudaKernel<
PADDLE_CUDA_NUM_THREADS><<<1, PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
num_samples, infer_width, indices_data, label_data, correct_data,
accuracy_data);
int d_num_samples, d_num_correct;
float d_accuracy;
platform::GpuMemcpyAsync(&d_num_correct, correct_data, sizeof(int),
cudaMemcpyDeviceToHost, stream);
platform::GpuMemcpyAsync(&d_num_samples, total_data, sizeof(int),
cudaMemcpyDeviceToHost, stream);
platform::GpuMemcpyAsync(&d_accuracy, accuracy_data, sizeof(float),
cudaMemcpyDeviceToHost, stream);
accuracy_data, total_data);
}
};
......
# Standard Markdown Format for Operators
The following should be the standard format for documentation for all the operators that will get rendered in the `html`:
```
Operator Name (In PaddlePaddle)
Operator Name (Standard)
Operator description.
LaTeX equation of how the operator performs an update.
The signature of the operator.
```
Each section mentioned above has been covered in further detail in the rest of the document.
# PaddlePaddle Operator Name
This should be in all small letters, in case of multiple words, we separate them with an underscore. For example:
`array to lod tensor` should be written as `array_to_lod_tensor`.
This naming convention should be standard across all PaddlePaddle operators.
# Standard Operator Name
This is the standard name of the operator as used in the community. The general standard is usually:
- Standard abbreviations like `SGD` are written in all capital letters.
- Operator names that have multiple words inside a single word use `camelCase` (capitalize word boundaries inside of a word).
- Keep numbers inside a word as is, with no boundary delimiters.
- Follow the name of the operator with the keyword: `Activation Operator.`
# Operator description
This section should contain the description of what the operator does, including the operation performed, the literature from where it comes and was introduced first, and other important details. The relevant paper/article including the hyperlink should be cited in this section.
# LaTeX equation
This section should contain an overall equation of the update or operation that the operator performs. The variables used in the equation should follow the naming convention of operators as described [here](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/name_convention.md). Two words in the same word should be separated by an underscore (`_`).
# The signature
This section describes the signature of the operator. A list of Inputs and Outputs, each of which have a small description of what the variable represents and the type of variable. The variable names follow the `CamelCase` naming convention. The proposed format for this is:
`Section :
VariableName : (VariableType) VariableDescription
...
...
`
The following example for an `sgd` operator covers the above mentioned sections as they would ideally look like in the `html`:
```
sgd
SGD operator
This operator implements one step of the stochastic gradient descent algorithm.
param_out = param_learning_rate * grad
Inputs:
Param : (Tensor) Input parameter
LearningRate : (Tensor) Learning rate of SGD
Grad : (Tensor) Input gradient
Outputs:
ParamOut : (Tensor) Output parameter
```
......@@ -25,35 +25,7 @@ def fc(input,
act=None,
name=None):
"""
Fully Connected Layer.
Args:
input: The input tensor(s) to the fully connected layer.
size: The number of output units in the fully connected layer.
num_flatten_dims: The fc layer can accept an input tensor with more than
two dimensions. If this happens, the multidimensional
tensor will first be flattened into a 2-dimensional
matrix. The parameter `num_flatten_dims` determines
how the input tensor is flattened: the first
`num_flatten_dims` dimensions will be flatten to form
the first dimension of the final matrix (height of the
matrix), and the rest `rank(X) - num_col_dims`
dimensions are flattened to form the second dimension
of the final matrix (width of the matrix). For example,
suppose `X` is a 6-dimensional tensor with a shape
[2, 3, 4, 5, 6], and `x_num_col_dims` = 3. Then, the
flattened matrix will have a shape [2 x 3 x 4, 5 x 6]
= [24, 30]. By default, `x_num_col_dims` is set to 1.
param_attr: The parameter attribute for learnable parameters/weights of
the fully connected Layer.
param_initializer: The initializer used for the weight/parameter.
If set None, XavierInitializer() will be used.
bias_attr: The parameter attribute for the bias parameter for this layer.
If set None, no bias will be added to the output units.
bias_initializer: The initializer used for the bias. If set None,
then ConstantInitializer() will be used.
act: Activation to be applied to the output of the fully connected layer.
name: Name/alias of the fully connected layer.
**Fully Connected Layer**
The fully connected layer can take multiple tensors as its inputs. It
creates a variable (one for each input tensor) called weights for each input
......@@ -68,12 +40,64 @@ def fc(input,
This process can be formulated as follows:
.. math::
Y = \sigma({\sum_{i=0}^{N-1}W_iX_i + b})
Out = Act({\sum_{i=0}^{N-1}W_iX_i + b})
In the above equation:
* :math:`N`: Number of the input.
* :math:`X_i`: The input tensor.
* :math:`W`: The weights created by this layer.
* :math:`b`: The bias parameter created by this layer (if needed).
* :math`Act`: The activation funtion.
* :math`Out`: The output tensor.
Args:
input(Variable|list): The input tensor(s) to the fully connected layer.
size(int): The number of output units in the fully connected layer.
num_flatten_dims(int): The fc layer can accept an input tensor with more
than two dimensions. If this happens, the
multidimensional tensor will first be flattened
into a 2-dimensional matrix. The parameter
`num_flatten_dims` determines how the input tensor
is flattened: the first `num_flatten_dims`
dimensions will be flatten to form the first
dimension of the final matrix (height of the
matrix), and the rest `rank(X) - num_col_dims`
dimensions are flattened to form the second
dimension of the final matrix (width of the matrix).
For example, suppose `X` is a 6-dimensional tensor
with a shape [2, 3, 4, 5, 6], and
`x_num_col_dims` = 3. Then, the flattened matrix
will have a shape [2 x 3 x 4, 5 x 6] = [24, 30].
By default, `x_num_col_dims` is set to 1.
param_attr(ParamAttr|list): The parameter attribute for learnable
parameters/weights of the fully connected
layer.
param_initializer(ParamAttr|list): The initializer used for the
weight/parameter. If set None,
XavierInitializer() will be used.
bias_attr(ParamAttr|list): The parameter attribute for the bias parameter
for this layer. If set None, no bias will be
added to the output units.
bias_initializer(ParamAttr|list): The initializer used for the bias.
If set None, then ConstantInitializer()
will be used.
act(str): Activation to be applied to the output of the fully connected
layer.
name(str): Name/alias of the fully connected layer.
Returns:
Variable: The output tensor variable.
where, :math:`N` is the number of input, :math:`X_i` is the input tensor,
:math:`W` is the weights created by this layer, :math:`b` is the bias
created by this layer (if needed), :math:`\sigma` is the activation funtion.
Raises:
ValueError: If rank of the input tensor is less than 2.
Examples:
.. code-block:: python
data = fluid.layers.data(name="data", shape=[32, 32], dtype="float32")
fc = fluid.layers.fc(input=data, size=1000, act="tanh")
"""
helper = LayerHelper("fc", **locals())
......@@ -115,23 +139,30 @@ def fc(input,
def embedding(input, size, is_sparse=False, param_attr=None, dtype='float32'):
"""
Embedding Layer.
**Embedding Layer**
This layer is used to lookup a vector of IDs, provided by *input*, in a lookup table.
The result of this lookup is the embedding of each ID in the *input*.
All the input variables are passed in as local variables to the LayerHelper
constructor.
Args:
param_initializer:
input: The input to the function
size: The size of the layer
is_sparse: A flag that decleares whether the input is sparse
param_attr: Parameters for this layer
dtype: The type of data : float32, float_16, int etc
input(Variable): Input to the function
size(int): Output size
is_sparse(bool): Boolean flag that specifying whether the input is sparse
param_attr(ParamAttr): Parameters for this layer
dtype(np.dtype|core.DataType|str): The type of data : float32, float_16, int etc
This function can take in the input (which is a vector of IDs) and
performs a lookup in the lookup_table using these IDs, to result into
the embedding of each ID in the input.
Returns:
Variable: The tensor variable storing the embeddings of the \
supplied inputs.
All the input variables of this function are passed in as local variables
to the LayerHelper constructor.
Examples:
.. code-block:: python
data = fluid.layers.data(name='ids', shape=[32, 32], dtype='float32')
fc = fluid.layers.embedding(input=data, size=16)
"""
helper = LayerHelper('embedding', **locals())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册