提交 8823a12e 编写于 作者: Y Yang Yu

Merge branch 'develop' of github.com:baidu/Paddle into feature/add_reorder_lod_tensor

## Problem
In PaddlePaddle's [Design](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/switch_kernel.md), one Operator may have multiple kernels. Users may have some personal preference to choose a certain type of kernel for an operator, such as `force_cpu` to choose a CPU kernel, `use_cudnn` to choose a CUDNN kernel, we need to provide a way for users to do this.
In the current design, we use KernelType to describe one kernel.
```cpp
struct KernelType {
Place place_;
DataType data_type_;
LayoutType layout_;
};
```
`place_` `data_type_` and `layout_` can be got from the input tensors of the operator, `GetActualKernelType(inputs)` use inputs to infer the proper kernel key that fit the incoming data, but users can not directly configure it.
The [design](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/switch_kernel.md) also provides a virtual method `GetExpectedKernelType` that user can overload and use to choose the KernelType they want to use.
So we should send the information user defined in proto to `GetExpectedKernelType` for choosing a kernel.
The problem is, how should we define and send the information for `GetExpectedKernelType` to use?
## Solution
### Potential choice
1. Do nothing, let the user add the information they want to operator‘s attribute and get them inside `GetExpectedKernelType`, this can work properly. But there is a little problem that users may define many kinds of hints for the same purpose, such as `force_cpu`, `use_cpu`, `cpu_kernel` to choose CPU kernel, and `use_cudnn`, `force_cudnn`, `cudnn_kernel` to choose CUDNN kernel.
2. Pre-define all the needed option and use a single attr key such as `kernel_hint` for the user, this is not so flexible if the user wants to define some more kind of hint.
### Final choice
To provide enough flexibility while avoiding confusion definition, we can define some global constants for these attribute names, such as `force_cpu`, `use_cudnn`, `use_mkldnn` for a user to choose.
In C++
```cpp
const std::string kForceCPU = "force_cpu";
const std::string kUseCUDNN = "use_cudnn";
const std::string kUseMKLDNN = "use_mkldnn";
KernelType GetExpectedKernelType() {
if (Attr<bool>(kForceCPU)) {
return KernelType(CPUPlace, ...)
} else {
...
}
}
```
In Python code
```python
FORCE_CPU = core.kForceCPU()
def xx_layer(..., force_cpu=false):
layer_helper = LayerHelper(...)
layer_helper.append_op(
type="xx",
attr={FORCE_CPU: force_cpu})
```
......@@ -26,7 +26,7 @@ template <int BlockSize>
__global__ void AccuracyCudaKernel(const int N, const int D,
const int64_t* Xdata,
const int64_t* labeldata, int* correct_data,
float* accuracy) {
float* accuracy, int* total_data) {
int count = 0;
__shared__ int total[BlockSize];
......@@ -47,6 +47,7 @@ __global__ void AccuracyCudaKernel(const int N, const int D,
if (threadIdx.x == 0) {
*correct_data = result;
*accuracy = static_cast<float>(result) / static_cast<float>(N);
*total_data = N;
}
}
......@@ -80,22 +81,11 @@ class AccuracyOpCUDAKernel : public framework::OpKernel<T> {
if (num_samples == 0) {
return;
}
platform::GpuMemcpyAsync(total_data, &num_samples, sizeof(int),
cudaMemcpyHostToDevice, stream);
AccuracyCudaKernel<
PADDLE_CUDA_NUM_THREADS><<<1, PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
num_samples, infer_width, indices_data, label_data, correct_data,
accuracy_data);
int d_num_samples, d_num_correct;
float d_accuracy;
platform::GpuMemcpyAsync(&d_num_correct, correct_data, sizeof(int),
cudaMemcpyDeviceToHost, stream);
platform::GpuMemcpyAsync(&d_num_samples, total_data, sizeof(int),
cudaMemcpyDeviceToHost, stream);
platform::GpuMemcpyAsync(&d_accuracy, accuracy_data, sizeof(float),
cudaMemcpyDeviceToHost, stream);
accuracy_data, total_data);
}
};
......
......@@ -71,7 +71,7 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
auto M = EigenMatrix<T>::Reshape(*mask, 1);
Y.device(place) = X * M;
} else {
Y.device(place) = X * dropout_prob;
Y.device(place) = X * (1.0f - dropout_prob);
}
}
};
......
......@@ -57,7 +57,7 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
auto Y = EigenMatrix<T>::Reshape(*y, 1);
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
Y.device(place) = X * dropout_prob;
Y.device(place) = X * (1.0f - dropout_prob);
}
}
};
......
......@@ -103,10 +103,12 @@ class MidWiseTransformIterator<T, platform::CPUDeviceContext> {
MidWiseTransformIterator<T, platform::CPUDeviceContext>& operator++() {
++j_;
i_ = j_ / post_;
if (UNLIKELY(i_ == n_)) {
if (UNLIKELY(j_ == post_)) {
++i_;
j_ = 0;
i_ = 0;
if (UNLIKELY(i_ == n_)) {
i_ = 0;
}
}
return *this;
}
......@@ -125,10 +127,10 @@ class MidWiseTransformIterator<T, platform::CPUDeviceContext> {
private:
const T* ptr_;
int i_;
int64_t i_;
int64_t j_;
int64_t n_;
int post_;
int64_t post_;
};
#ifdef __NVCC__
......
......@@ -61,14 +61,13 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
const T* im_data = im.data<T>();
T* col_data = col->data<T>();
for (int c = 0; c < channels_col; ++c) {
int w_offset = c % filter_width;
int h_offset = (c / filter_width) % filter_height;
int c_im = c / filter_width / filter_height;
int c_im = c / (filter_width * filter_height);
for (int h = 0; h < col_height; ++h) {
int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0];
for (int w = 0; w < col_width; ++w) {
int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0];
int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1];
int col_idx = (c * col_height + h) * col_width + w;
int im_idx = (im_row_idx + c_im * im_height) * im_width + im_col_idx;
......@@ -130,16 +129,14 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
for (int c = 0; c < channels_col; ++c) {
int w_offset = c % filter_width;
int h_offset = (c / filter_width) % filter_height;
int c_im = c / filter_width / filter_height;
int c_im = c / (filter_width * filter_height);
for (int h = 0; h < col_height; ++h) {
int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0];
for (int w = 0; w < col_width; ++w) {
int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0];
int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1];
if ((im_row_idx) >= 0 && (im_row_idx) < im_height &&
(im_col_idx) >= 0 && (im_col_idx) < im_width) {
im_row_idx += c_im * im_height;
im_data[im_row_idx * im_width + im_col_idx] +=
im_data[(im_row_idx + c_im * im_height) * im_width + im_col_idx] +=
col_data[(c * col_height + h) * col_width + w];
}
}
......@@ -199,12 +196,13 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
for (int channel = 0; channel < im_channels; ++channel) {
for (int filter_row_idx = 0; filter_row_idx < filter_height;
++filter_row_idx) {
int im_row_offset =
col_row_idx * stride[0] + filter_row_idx - padding[0];
for (int filter_col_idx = 0; filter_col_idx < filter_width;
++filter_col_idx) {
int im_row_offset =
col_row_idx * stride[0] + filter_row_idx - padding[0];
int im_col_offset =
col_col_idx * stride[1] + filter_col_idx - padding[1];
int col_offset =
((((col_row_idx)*col_width + col_col_idx) * im_channels +
channel) *
......@@ -271,12 +269,13 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
for (int channel = 0; channel < im_channels; ++channel) {
for (int filter_row_idx = 0; filter_row_idx < filter_height;
++filter_row_idx) {
int im_row_offset =
col_row_idx * stride[0] + filter_row_idx - padding[0];
for (int filter_col_idx = 0; filter_col_idx < filter_width;
++filter_col_idx) {
int im_row_offset =
col_row_idx * stride[0] + filter_row_idx - padding[0];
int im_col_offset =
col_col_idx * stride[1] + filter_col_idx - padding[1];
int col_offset =
(((col_row_idx * col_width + col_col_idx) * im_channels +
channel) *
......@@ -284,6 +283,7 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
filter_row_idx) *
filter_width +
filter_col_idx;
if (im_row_offset >= 0 && im_row_offset < im_height &&
im_col_offset >= 0 && im_col_offset < im_width) {
int im_offset =
......
# Standard Markdown Format for Operators
The following should be the standard format for documentation for all the operators that will get rendered in the `html`:
```
Operator Name (In PaddlePaddle)
Operator Name (Standard)
Operator description.
LaTeX equation of how the operator performs an update.
The signature of the operator.
```
Each section mentioned above has been covered in further detail in the rest of the document.
# PaddlePaddle Operator Name
This should be in all small letters, in case of multiple words, we separate them with an underscore. For example:
`array to lod tensor` should be written as `array_to_lod_tensor`.
This naming convention should be standard across all PaddlePaddle operators.
# Standard Operator Name
This is the standard name of the operator as used in the community. The general standard is usually:
- Standard abbreviations like `SGD` are written in all capital letters.
- Operator names that have multiple words inside a single word use `camelCase` (capitalize word boundaries inside of a word).
- Keep numbers inside a word as is, with no boundary delimiters.
- Follow the name of the operator with the keyword: `Activation Operator.`
# Operator description
This section should contain the description of what the operator does, including the operation performed, the literature from where it comes and was introduced first, and other important details. The relevant paper/article including the hyperlink should be cited in this section.
# LaTeX equation
This section should contain an overall equation of the update or operation that the operator performs. The variables used in the equation should follow the naming convention of operators as described [here](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/name_convention.md). Two words in the same word should be separated by an underscore (`_`).
# The signature
This section describes the signature of the operator. A list of Inputs and Outputs, each of which have a small description of what the variable represents and the type of variable. The variable names follow the `CamelCase` naming convention. The proposed format for this is:
`Section :
VariableName : (VariableType) VariableDescription
...
...
`
The following example for an `sgd` operator covers the above mentioned sections as they would ideally look like in the `html`:
```
sgd
SGD operator
This operator implements one step of the stochastic gradient descent algorithm.
param_out = param_learning_rate * grad
Inputs:
Param : (Tensor) Input parameter
LearningRate : (Tensor) Learning rate of SGD
Grad : (Tensor) Input gradient
Outputs:
ParamOut : (Tensor) Output parameter
```
......@@ -50,10 +50,14 @@ input Tensor can be either [N, 1] or [N], where N is the sum of the length
of all sequences.
The algorithm works as follows:
for i-th sequence in a mini-batch:
$$Out(X[lod[i]:lod[i+1]], :) =
\frac{\exp(X[lod[i]:lod[i+1], :])}
{\sum(\exp(X[lod[i]:lod[i+1], :]))}$$
$$
Out(X[lod[i]:lod[i+1]], :) = \
\frac{\exp(X[lod[i]:lod[i+1], :])} \
{\sum(\exp(X[lod[i]:lod[i+1], :]))}
$$
For example, for a mini-batch of 3 sequences with variable-length,
each containing 2, 3, 2 time-steps, the lod of which is [0, 2, 5, 7],
......
......@@ -19,7 +19,7 @@ CPUDeviceContext::CPUDeviceContext() {
eigen_device_.reset(new Eigen::DefaultDevice());
}
CPUDeviceContext::CPUDeviceContext(CPUPlace place) {
CPUDeviceContext::CPUDeviceContext(CPUPlace place) : place_(place) {
eigen_device_.reset(new Eigen::DefaultDevice());
}
......@@ -27,7 +27,7 @@ Eigen::DefaultDevice* CPUDeviceContext::eigen_device() const {
return eigen_device_.get();
}
Place CPUDeviceContext::GetPlace() const { return CPUPlace(); }
Place CPUDeviceContext::GetPlace() const { return place_; }
#ifdef PADDLE_WITH_CUDA
......
......@@ -45,6 +45,7 @@ class CPUDeviceContext : public DeviceContext {
Place GetPlace() const override;
private:
CPUPlace place_;
std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
};
......
......@@ -519,6 +519,24 @@ def create_array(dtype):
def less_than(x, y, cond=None, **ignored):
"""
**Less than**
This layer returns the truth value of :math:`x < y` elementwise.
Args:
x(Variable): First operand of *less_than*
y(Variable): Second operand of *less_than*
cond(Variable|None): Optional output variable to store the result of *less_than*
Returns:
Variable: The tensor variable storing the output of *less_than*.
Examples:
.. code-block:: python
less = fluid.layers.less_than(x=label, y=limit)
"""
helper = LayerHelper("less_than", **locals())
if cond is None:
cond = helper.create_tmp_variable(dtype='bool')
......
......@@ -25,32 +25,48 @@ def fc(input,
act=None,
name=None):
"""
Fully Connected Layer.
**Fully Connected Layer**
This layer accepts multiple inputs and applies a linear transformation to each input.
If activation type is provided, the corresponding activation function is applied to the
output of the linear transformation. For each input :math:`X`, the equation is:
.. math::
Out = Act(WX + b)
In the above equation:
* :math:`X`: Input value, a tensor with rank at least 2.
* :math:`W`: Weight, a 2-D tensor with shape [M, N].
* :math:`b`: Bias, a 2-D tensor with shape [M, 1].
* :math:`Act`: Activation function.
* :math:`Out`: Output value, same shape with :math:`X`.
All the input variables are passed in as local variables to the LayerHelper
constructor.
Args:
input: The input tensor to the function
size: The size of the layer
num_flatten_dims: Number of columns in input
param_attr: The parameters/weights to the FC Layer
param_initializer: Initializer used for the weight/parameter. If None, XavierInitializer() is used
bias_attr: The bias parameter for the FC layer
bias_initializer: Initializer used for the bias. If None, then ConstantInitializer() is used
act: Activation to be applied to the output of FC layer
name: Name/alias of the function
main_program: Name of the main program that calls this
startup_program: Name of the startup program
This function can take in multiple inputs and performs the Fully Connected
function (linear transformation) on top of each of them.
So for input x, the output will be : Wx + b. Where W is the parameter,
b the bias and x is the input.
The function also applies an activation (non-linearity) on top of the
output, if activation is passed in the input.
All the input variables of this function are passed in as local variables
to the LayerHelper constructor.
input(Variable|list): Input tensors. Each tensor has a rank of atleast 2
size(int): Output size
num_flatten_dims(int): Number of columns in input
param_attr(ParamAttr|list): The parameters/weights to the FC Layer
bias_attr(ParamAttr|list): Bias parameter for the FC layer
act(str): Activation type
name(str): Name/alias of the function
Returns:
Variable: The tensor variable storing the transformation and \
non-linearity activation result.
Raises:
ValueError: If rank of input tensor is less than 2.
Examples:
.. code-block:: python
data = fluid.layers.data(name='data', shape=[32, 32], dtype='float32')
fc = fluid.layers.fc(input=data, size=1000, act="tanh")
"""
helper = LayerHelper('fc', **locals())
......@@ -91,25 +107,30 @@ def fc(input,
def embedding(input, size, is_sparse=False, param_attr=None, dtype='float32'):
"""
Embedding Layer.
**Embedding Layer**
This layer is used to lookup a vector of IDs, provided by *input*, in a lookup table.
The result of this lookup is the embedding of each ID in the *input*.
All the input variables are passed in as local variables to the LayerHelper
constructor.
Args:
param_initializer:
input: The input to the function
size: The size of the layer
is_sparse: A flag that decleares whether the input is sparse
param_attr: Parameters for this layer
dtype: The type of data : float32, float_16, int etc
main_program: Name of the main program that calls this
startup_program: Name of the startup program
This function can take in the input (which is a vector of IDs) and
performs a lookup in the lookup_table using these IDs, to result into
the embedding of each ID in the input.
All the input variables of this function are passed in as local variables
to the LayerHelper constructor.
input(Variable): Input to the function
size(int): Output size
is_sparse(bool): Boolean flag that specifying whether the input is sparse
param_attr(ParamAttr): Parameters for this layer
dtype(np.dtype|core.DataType|str): The type of data : float32, float_16, int etc
Returns:
Variable: The tensor variable storing the embeddings of the \
supplied inputs.
Examples:
.. code-block:: python
data = fluid.layers.data(name='ids', shape=[32, 32], dtype='float32')
fc = fluid.layers.embedding(input=data, size=16)
"""
helper = LayerHelper('embedding', **locals())
......
......@@ -66,9 +66,26 @@ def assign(input, output):
def fill_constant(shape, dtype, value, out=None):
"""
This function creates a tensor , with shape as mentioned in the input and
specified dtype and fills this up with a constant value that
comes in the input. It also sets the stop_gradient to be True.
**fill_constant**
This function creates a tensor of specified *shape* and
*dtype*, and initializes this with a constant supplied in *value*.
It also sets *stop_gradient* to True.
Args:
shape(tuple|list|None): Shape of output tensor
dtype(np.dtype|core.DataType|str): Data type of output tensor
value(float): Constant value to initialize the output tensor
out(Variable): Output Variable to initialize
Returns:
Variable: The tensor variable storing the output
Examples:
.. code-block:: python
data = fluid.layers.fill_constant(shape=[1], value=0, dtype='int64')
"""
helper = LayerHelper("fill_constant", **locals())
if out is None:
......@@ -90,6 +107,31 @@ def fill_constant_batch_size_like(input,
value,
input_dim_idx=0,
output_dim_idx=0):
"""
**fill_constant_batch_size_like**
This function creates a tensor of specified *shape*, *dtype* and batch size,
and initializes this with a constant supplied in *value*. The batch size is
obtained from the `input` tensor.
It also sets *stop_gradient* to True.
Args:
input(Variable): Tensor whose dimensions will be used to get batch size
shape(tuple|list|None): Shape of output tensor
dtype(np.dtype|core.DataType|str): Data type of output tensor
value(float): Constant value to initialize the output tensor
input_dim_idx(int): Index of input's batch size dimension
output_dim_idx(int): Index of output's batch size dimension
Returns:
Variable: The tensor variable storing the output
Examples:
.. code-block:: python
data = fluid.layers.fill_constant(shape=[1], value=0, dtype='int64')
"""
helper = LayerHelper("fill_constant_batch_size_like", **locals())
out = helper.create_tmp_variable(dtype=dtype)
helper.append_op(
......
......@@ -47,7 +47,9 @@ class TestDropoutOp4(OpTest):
self.op_type = "dropout"
self.inputs = {'X': np.random.random((32, 64)).astype("float32")}
self.attrs = {'dropout_prob': 0.35, 'is_test': True}
self.outputs = {'Out': self.inputs['X'] * self.attrs['dropout_prob']}
self.outputs = {
'Out': self.inputs['X'] * (1.0 - self.attrs['dropout_prob'])
}
def test_check_output(self):
self.check_output()
......@@ -58,7 +60,9 @@ class TestDropoutOp5(OpTest):
self.op_type = "dropout"
self.inputs = {'X': np.random.random((32, 64, 3)).astype("float32")}
self.attrs = {'dropout_prob': 0.75, 'is_test': True}
self.outputs = {'Out': self.inputs['X'] * self.attrs['dropout_prob']}
self.outputs = {
'Out': self.inputs['X'] * (1.0 - self.attrs['dropout_prob'])
}
def test_check_output(self):
self.check_output()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册