diff --git a/doc/design/kernel_hint_design.md b/doc/design/kernel_hint_design.md new file mode 100644 index 0000000000000000000000000000000000000000..a54b7da045e1a362626ef066f9ebb56af2c3181a --- /dev/null +++ b/doc/design/kernel_hint_design.md @@ -0,0 +1,57 @@ +## Problem +In PaddlePaddle's [Design](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/switch_kernel.md), one Operator may have multiple kernels. Users may have some personal preference to choose a certain type of kernel for an operator, such as `force_cpu` to choose a CPU kernel, `use_cudnn` to choose a CUDNN kernel, we need to provide a way for users to do this. + +In the current design, we use KernelType to describe one kernel. + +```cpp +struct KernelType { + Place place_; + DataType data_type_; + LayoutType layout_; +}; +``` + `place_` `data_type_` and `layout_` can be got from the input tensors of the operator, `GetActualKernelType(inputs)` use inputs to infer the proper kernel key that fit the incoming data, but users can not directly configure it. + +The [design](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/switch_kernel.md) also provides a virtual method `GetExpectedKernelType` that user can overload and use to choose the KernelType they want to use. + +So we should send the information user defined in proto to `GetExpectedKernelType` for choosing a kernel. + +The problem is, how should we define and send the information for `GetExpectedKernelType` to use? + +## Solution + +### Potential choice +1. Do nothing, let the user add the information they want to operator‘s attribute and get them inside `GetExpectedKernelType`, this can work properly. But there is a little problem that users may define many kinds of hints for the same purpose, such as `force_cpu`, `use_cpu`, `cpu_kernel` to choose CPU kernel, and `use_cudnn`, `force_cudnn`, `cudnn_kernel` to choose CUDNN kernel. + +2. Pre-define all the needed option and use a single attr key such as `kernel_hint` for the user, this is not so flexible if the user wants to define some more kind of hint. + +### Final choice +To provide enough flexibility while avoiding confusion definition, we can define some global constants for these attribute names, such as `force_cpu`, `use_cudnn`, `use_mkldnn` for a user to choose. + +In C++ + +```cpp +const std::string kForceCPU = "force_cpu"; +const std::string kUseCUDNN = "use_cudnn"; +const std::string kUseMKLDNN = "use_mkldnn"; + +KernelType GetExpectedKernelType() { + if (Attr(kForceCPU)) { + return KernelType(CPUPlace, ...) + } else { + ... + } +} +``` + +In Python code + +```python +FORCE_CPU = core.kForceCPU() + +def xx_layer(..., force_cpu=false): + layer_helper = LayerHelper(...) + layer_helper.append_op( + type="xx", + attr={FORCE_CPU: force_cpu}) +``` diff --git a/paddle/operators/accuracy_op.cu b/paddle/operators/accuracy_op.cu index 539a93530206c93a37791a9ccb2fb104af17f940..dd51aad105fecf4e3118f03e2f1868abb5523bc8 100644 --- a/paddle/operators/accuracy_op.cu +++ b/paddle/operators/accuracy_op.cu @@ -26,7 +26,7 @@ template __global__ void AccuracyCudaKernel(const int N, const int D, const int64_t* Xdata, const int64_t* labeldata, int* correct_data, - float* accuracy) { + float* accuracy, int* total_data) { int count = 0; __shared__ int total[BlockSize]; @@ -47,6 +47,7 @@ __global__ void AccuracyCudaKernel(const int N, const int D, if (threadIdx.x == 0) { *correct_data = result; *accuracy = static_cast(result) / static_cast(N); + *total_data = N; } } @@ -80,22 +81,11 @@ class AccuracyOpCUDAKernel : public framework::OpKernel { if (num_samples == 0) { return; } - platform::GpuMemcpyAsync(total_data, &num_samples, sizeof(int), - cudaMemcpyHostToDevice, stream); AccuracyCudaKernel< PADDLE_CUDA_NUM_THREADS><<<1, PADDLE_CUDA_NUM_THREADS, 0, stream>>>( num_samples, infer_width, indices_data, label_data, correct_data, - accuracy_data); - - int d_num_samples, d_num_correct; - float d_accuracy; - platform::GpuMemcpyAsync(&d_num_correct, correct_data, sizeof(int), - cudaMemcpyDeviceToHost, stream); - platform::GpuMemcpyAsync(&d_num_samples, total_data, sizeof(int), - cudaMemcpyDeviceToHost, stream); - platform::GpuMemcpyAsync(&d_accuracy, accuracy_data, sizeof(float), - cudaMemcpyDeviceToHost, stream); + accuracy_data, total_data); } }; diff --git a/paddle/operators/dropout_op.cu b/paddle/operators/dropout_op.cu index 10c670751d026ef92e01aad7da31a8f59b8514c0..c31d2195e95b116451b0f620f6582f65c0dae706 100644 --- a/paddle/operators/dropout_op.cu +++ b/paddle/operators/dropout_op.cu @@ -71,7 +71,7 @@ class GPUDropoutKernel : public framework::OpKernel { auto M = EigenMatrix::Reshape(*mask, 1); Y.device(place) = X * M; } else { - Y.device(place) = X * dropout_prob; + Y.device(place) = X * (1.0f - dropout_prob); } } }; diff --git a/paddle/operators/dropout_op.h b/paddle/operators/dropout_op.h index 84ad39f0bb639975365d427aa205411ef79ecd46..9f6c4212d4f834bbb2d1c65c836f3d3d0f3e0c96 100644 --- a/paddle/operators/dropout_op.h +++ b/paddle/operators/dropout_op.h @@ -57,7 +57,7 @@ class CPUDropoutKernel : public framework::OpKernel { auto Y = EigenMatrix::Reshape(*y, 1); auto& place = *context.template device_context().eigen_device(); - Y.device(place) = X * dropout_prob; + Y.device(place) = X * (1.0f - dropout_prob); } } }; diff --git a/paddle/operators/elementwise_op_function.h b/paddle/operators/elementwise_op_function.h index 7ebfc7df8c117edd7bcf14cc5ae6ba3dc1302c03..9edfacd6dfb506e12a0d82772d8de301bb8506e2 100644 --- a/paddle/operators/elementwise_op_function.h +++ b/paddle/operators/elementwise_op_function.h @@ -103,10 +103,12 @@ class MidWiseTransformIterator { MidWiseTransformIterator& operator++() { ++j_; - i_ = j_ / post_; - if (UNLIKELY(i_ == n_)) { + if (UNLIKELY(j_ == post_)) { + ++i_; j_ = 0; - i_ = 0; + if (UNLIKELY(i_ == n_)) { + i_ = 0; + } } return *this; } @@ -125,10 +127,10 @@ class MidWiseTransformIterator { private: const T* ptr_; - int i_; + int64_t i_; int64_t j_; int64_t n_; - int post_; + int64_t post_; }; #ifdef __NVCC__ diff --git a/paddle/operators/math/im2col.cc b/paddle/operators/math/im2col.cc index 707ebf05962fb65892c2adbbf41a0a3449763d31..c2633b2e16434558d16f699a701e7b8cf1de8342 100644 --- a/paddle/operators/math/im2col.cc +++ b/paddle/operators/math/im2col.cc @@ -61,14 +61,13 @@ class Im2ColFunctor(); T* col_data = col->data(); - for (int c = 0; c < channels_col; ++c) { int w_offset = c % filter_width; int h_offset = (c / filter_width) % filter_height; - int c_im = c / filter_width / filter_height; + int c_im = c / (filter_width * filter_height); for (int h = 0; h < col_height; ++h) { + int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0]; for (int w = 0; w < col_width; ++w) { - int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0]; int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1]; int col_idx = (c * col_height + h) * col_width + w; int im_idx = (im_row_idx + c_im * im_height) * im_width + im_col_idx; @@ -130,16 +129,14 @@ class Col2ImFunctor= 0 && (im_row_idx) < im_height && (im_col_idx) >= 0 && (im_col_idx) < im_width) { - im_row_idx += c_im * im_height; - im_data[im_row_idx * im_width + im_col_idx] += + im_data[(im_row_idx + c_im * im_height) * im_width + im_col_idx] += col_data[(c * col_height + h) * col_width + w]; } } @@ -199,12 +196,13 @@ class Im2ColFunctor= 0 && im_row_offset < im_height && im_col_offset >= 0 && im_col_offset < im_width) { int im_offset = diff --git a/paddle/operators/batch_norm_op.md b/paddle/operators/op_documentation/batch_norm_op.md similarity index 100% rename from paddle/operators/batch_norm_op.md rename to paddle/operators/op_documentation/batch_norm_op.md diff --git a/paddle/operators/name_convention.md b/paddle/operators/op_documentation/name_convention.md similarity index 100% rename from paddle/operators/name_convention.md rename to paddle/operators/op_documentation/name_convention.md diff --git a/paddle/operators/net_op_design.md b/paddle/operators/op_documentation/net_op_design.md similarity index 100% rename from paddle/operators/net_op_design.md rename to paddle/operators/op_documentation/net_op_design.md diff --git a/paddle/operators/op_documentation/op_markdown_format.md b/paddle/operators/op_documentation/op_markdown_format.md new file mode 100644 index 0000000000000000000000000000000000000000..0ee804d592252c727622cbe59b0644813db3c4fd --- /dev/null +++ b/paddle/operators/op_documentation/op_markdown_format.md @@ -0,0 +1,64 @@ +# Standard Markdown Format for Operators +The following should be the standard format for documentation for all the operators that will get rendered in the `html`: + +``` +Operator Name (In PaddlePaddle) + +Operator Name (Standard) + +Operator description. + +LaTeX equation of how the operator performs an update. + +The signature of the operator. +``` + +Each section mentioned above has been covered in further detail in the rest of the document. + +# PaddlePaddle Operator Name +This should be in all small letters, in case of multiple words, we separate them with an underscore. For example: +`array to lod tensor` should be written as `array_to_lod_tensor`. + +This naming convention should be standard across all PaddlePaddle operators. + +# Standard Operator Name +This is the standard name of the operator as used in the community. The general standard is usually: +- Standard abbreviations like `SGD` are written in all capital letters. +- Operator names that have multiple words inside a single word use `camelCase` (capitalize word boundaries inside of a word). +- Keep numbers inside a word as is, with no boundary delimiters. +- Follow the name of the operator with the keyword: `Activation Operator.` + +# Operator description +This section should contain the description of what the operator does, including the operation performed, the literature from where it comes and was introduced first, and other important details. The relevant paper/article including the hyperlink should be cited in this section. + +# LaTeX equation +This section should contain an overall equation of the update or operation that the operator performs. The variables used in the equation should follow the naming convention of operators as described [here](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/name_convention.md). Two words in the same word should be separated by an underscore (`_`). + +# The signature +This section describes the signature of the operator. A list of Inputs and Outputs, each of which have a small description of what the variable represents and the type of variable. The variable names follow the `CamelCase` naming convention. The proposed format for this is: +`Section : +VariableName : (VariableType) VariableDescription +... +... +` + + +The following example for an `sgd` operator covers the above mentioned sections as they would ideally look like in the `html`: + +``` +sgd + +SGD operator + +This operator implements one step of the stochastic gradient descent algorithm. + +param_out = param_learning_rate * grad + +Inputs: +Param : (Tensor) Input parameter +LearningRate : (Tensor) Learning rate of SGD +Grad : (Tensor) Input gradient + +Outputs: +ParamOut : (Tensor) Output parameter +``` diff --git a/paddle/operators/rnn_design.md b/paddle/operators/op_documentation/rnn_design.md similarity index 100% rename from paddle/operators/rnn_design.md rename to paddle/operators/op_documentation/rnn_design.md diff --git a/paddle/operators/sequence_softmax_op.cc b/paddle/operators/sequence_softmax_op.cc index fe1832a36fa7fd91cf1fccddfc09918e4698e09c..b74766f012e333cc2a317e6efe17c5b60238924a 100644 --- a/paddle/operators/sequence_softmax_op.cc +++ b/paddle/operators/sequence_softmax_op.cc @@ -50,10 +50,14 @@ input Tensor can be either [N, 1] or [N], where N is the sum of the length of all sequences. The algorithm works as follows: + for i-th sequence in a mini-batch: - $$Out(X[lod[i]:lod[i+1]], :) = - \frac{\exp(X[lod[i]:lod[i+1], :])} - {\sum(\exp(X[lod[i]:lod[i+1], :]))}$$ + +$$ +Out(X[lod[i]:lod[i+1]], :) = \ +\frac{\exp(X[lod[i]:lod[i+1], :])} \ +{\sum(\exp(X[lod[i]:lod[i+1], :]))} +$$ For example, for a mini-batch of 3 sequences with variable-length, each containing 2, 3, 2 time-steps, the lod of which is [0, 2, 5, 7], diff --git a/paddle/platform/device_context.cc b/paddle/platform/device_context.cc index 8cdc5f43403b0c54d3f1f01a3e97405fd5b2f434..dacee74fff369586c7ca2ff62cfe6aeebd8f39c7 100644 --- a/paddle/platform/device_context.cc +++ b/paddle/platform/device_context.cc @@ -19,7 +19,7 @@ CPUDeviceContext::CPUDeviceContext() { eigen_device_.reset(new Eigen::DefaultDevice()); } -CPUDeviceContext::CPUDeviceContext(CPUPlace place) { +CPUDeviceContext::CPUDeviceContext(CPUPlace place) : place_(place) { eigen_device_.reset(new Eigen::DefaultDevice()); } @@ -27,7 +27,7 @@ Eigen::DefaultDevice* CPUDeviceContext::eigen_device() const { return eigen_device_.get(); } -Place CPUDeviceContext::GetPlace() const { return CPUPlace(); } +Place CPUDeviceContext::GetPlace() const { return place_; } #ifdef PADDLE_WITH_CUDA diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index 56813a1d5b3c2a7f4ff7b4eddc6fa47ed861700c..6cc0508522a97f3097b30e3340e7413a7093714a 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -45,6 +45,7 @@ class CPUDeviceContext : public DeviceContext { Place GetPlace() const override; private: + CPUPlace place_; std::unique_ptr eigen_device_; }; diff --git a/python/paddle/v2/fluid/layers/control_flow.py b/python/paddle/v2/fluid/layers/control_flow.py index f22dfb4c85a8e5854564e974c9c9cfb225f355f8..d66c834654759d3a391eb7bdb5581b7a32a89af1 100644 --- a/python/paddle/v2/fluid/layers/control_flow.py +++ b/python/paddle/v2/fluid/layers/control_flow.py @@ -519,6 +519,24 @@ def create_array(dtype): def less_than(x, y, cond=None, **ignored): + """ + **Less than** + + This layer returns the truth value of :math:`x < y` elementwise. + + Args: + x(Variable): First operand of *less_than* + y(Variable): Second operand of *less_than* + cond(Variable|None): Optional output variable to store the result of *less_than* + + Returns: + Variable: The tensor variable storing the output of *less_than*. + + Examples: + .. code-block:: python + + less = fluid.layers.less_than(x=label, y=limit) + """ helper = LayerHelper("less_than", **locals()) if cond is None: cond = helper.create_tmp_variable(dtype='bool') diff --git a/python/paddle/v2/fluid/layers/nn.py b/python/paddle/v2/fluid/layers/nn.py index 73f68466da78021778d2860f24c1b699081fb95b..8d819de603a81de6934dd4b03a6fc49f4319a2d3 100644 --- a/python/paddle/v2/fluid/layers/nn.py +++ b/python/paddle/v2/fluid/layers/nn.py @@ -25,32 +25,48 @@ def fc(input, act=None, name=None): """ - Fully Connected Layer. + **Fully Connected Layer** + + This layer accepts multiple inputs and applies a linear transformation to each input. + If activation type is provided, the corresponding activation function is applied to the + output of the linear transformation. For each input :math:`X`, the equation is: + + .. math:: + + Out = Act(WX + b) + + In the above equation: + + * :math:`X`: Input value, a tensor with rank at least 2. + * :math:`W`: Weight, a 2-D tensor with shape [M, N]. + * :math:`b`: Bias, a 2-D tensor with shape [M, 1]. + * :math:`Act`: Activation function. + * :math:`Out`: Output value, same shape with :math:`X`. + + All the input variables are passed in as local variables to the LayerHelper + constructor. Args: - input: The input tensor to the function - size: The size of the layer - num_flatten_dims: Number of columns in input - param_attr: The parameters/weights to the FC Layer - param_initializer: Initializer used for the weight/parameter. If None, XavierInitializer() is used - bias_attr: The bias parameter for the FC layer - bias_initializer: Initializer used for the bias. If None, then ConstantInitializer() is used - act: Activation to be applied to the output of FC layer - name: Name/alias of the function - main_program: Name of the main program that calls this - startup_program: Name of the startup program - - This function can take in multiple inputs and performs the Fully Connected - function (linear transformation) on top of each of them. - So for input x, the output will be : Wx + b. Where W is the parameter, - b the bias and x is the input. - - The function also applies an activation (non-linearity) on top of the - output, if activation is passed in the input. - - All the input variables of this function are passed in as local variables - to the LayerHelper constructor. + input(Variable|list): Input tensors. Each tensor has a rank of atleast 2 + size(int): Output size + num_flatten_dims(int): Number of columns in input + param_attr(ParamAttr|list): The parameters/weights to the FC Layer + bias_attr(ParamAttr|list): Bias parameter for the FC layer + act(str): Activation type + name(str): Name/alias of the function + + Returns: + Variable: The tensor variable storing the transformation and \ + non-linearity activation result. + + Raises: + ValueError: If rank of input tensor is less than 2. + Examples: + .. code-block:: python + + data = fluid.layers.data(name='data', shape=[32, 32], dtype='float32') + fc = fluid.layers.fc(input=data, size=1000, act="tanh") """ helper = LayerHelper('fc', **locals()) @@ -91,25 +107,30 @@ def fc(input, def embedding(input, size, is_sparse=False, param_attr=None, dtype='float32'): """ - Embedding Layer. + **Embedding Layer** + + This layer is used to lookup a vector of IDs, provided by *input*, in a lookup table. + The result of this lookup is the embedding of each ID in the *input*. + + All the input variables are passed in as local variables to the LayerHelper + constructor. Args: - param_initializer: - input: The input to the function - size: The size of the layer - is_sparse: A flag that decleares whether the input is sparse - param_attr: Parameters for this layer - dtype: The type of data : float32, float_16, int etc - main_program: Name of the main program that calls this - startup_program: Name of the startup program - - This function can take in the input (which is a vector of IDs) and - performs a lookup in the lookup_table using these IDs, to result into - the embedding of each ID in the input. - - All the input variables of this function are passed in as local variables - to the LayerHelper constructor. + input(Variable): Input to the function + size(int): Output size + is_sparse(bool): Boolean flag that specifying whether the input is sparse + param_attr(ParamAttr): Parameters for this layer + dtype(np.dtype|core.DataType|str): The type of data : float32, float_16, int etc + + Returns: + Variable: The tensor variable storing the embeddings of the \ + supplied inputs. + + Examples: + .. code-block:: python + data = fluid.layers.data(name='ids', shape=[32, 32], dtype='float32') + fc = fluid.layers.embedding(input=data, size=16) """ helper = LayerHelper('embedding', **locals()) diff --git a/python/paddle/v2/fluid/layers/tensor.py b/python/paddle/v2/fluid/layers/tensor.py index bda017b141dcba5ac268c34388742c433a533337..e984a6be19f43afca41482f3cf4c52df4409ab92 100644 --- a/python/paddle/v2/fluid/layers/tensor.py +++ b/python/paddle/v2/fluid/layers/tensor.py @@ -66,9 +66,26 @@ def assign(input, output): def fill_constant(shape, dtype, value, out=None): """ - This function creates a tensor , with shape as mentioned in the input and - specified dtype and fills this up with a constant value that - comes in the input. It also sets the stop_gradient to be True. + **fill_constant** + + This function creates a tensor of specified *shape* and + *dtype*, and initializes this with a constant supplied in *value*. + + It also sets *stop_gradient* to True. + + Args: + shape(tuple|list|None): Shape of output tensor + dtype(np.dtype|core.DataType|str): Data type of output tensor + value(float): Constant value to initialize the output tensor + out(Variable): Output Variable to initialize + + Returns: + Variable: The tensor variable storing the output + + Examples: + .. code-block:: python + + data = fluid.layers.fill_constant(shape=[1], value=0, dtype='int64') """ helper = LayerHelper("fill_constant", **locals()) if out is None: @@ -90,6 +107,31 @@ def fill_constant_batch_size_like(input, value, input_dim_idx=0, output_dim_idx=0): + """ + **fill_constant_batch_size_like** + + This function creates a tensor of specified *shape*, *dtype* and batch size, + and initializes this with a constant supplied in *value*. The batch size is + obtained from the `input` tensor. + + It also sets *stop_gradient* to True. + + Args: + input(Variable): Tensor whose dimensions will be used to get batch size + shape(tuple|list|None): Shape of output tensor + dtype(np.dtype|core.DataType|str): Data type of output tensor + value(float): Constant value to initialize the output tensor + input_dim_idx(int): Index of input's batch size dimension + output_dim_idx(int): Index of output's batch size dimension + + Returns: + Variable: The tensor variable storing the output + + Examples: + .. code-block:: python + + data = fluid.layers.fill_constant(shape=[1], value=0, dtype='int64') + """ helper = LayerHelper("fill_constant_batch_size_like", **locals()) out = helper.create_tmp_variable(dtype=dtype) helper.append_op( diff --git a/python/paddle/v2/fluid/tests/test_dropout_op.py b/python/paddle/v2/fluid/tests/test_dropout_op.py index 4f5ea836b44102e5599a2302efd669291ebe920b..2483200212686caf9c46f9c1129b5d8ffdcc9145 100644 --- a/python/paddle/v2/fluid/tests/test_dropout_op.py +++ b/python/paddle/v2/fluid/tests/test_dropout_op.py @@ -47,7 +47,9 @@ class TestDropoutOp4(OpTest): self.op_type = "dropout" self.inputs = {'X': np.random.random((32, 64)).astype("float32")} self.attrs = {'dropout_prob': 0.35, 'is_test': True} - self.outputs = {'Out': self.inputs['X'] * self.attrs['dropout_prob']} + self.outputs = { + 'Out': self.inputs['X'] * (1.0 - self.attrs['dropout_prob']) + } def test_check_output(self): self.check_output() @@ -58,7 +60,9 @@ class TestDropoutOp5(OpTest): self.op_type = "dropout" self.inputs = {'X': np.random.random((32, 64, 3)).astype("float32")} self.attrs = {'dropout_prob': 0.75, 'is_test': True} - self.outputs = {'Out': self.inputs['X'] * self.attrs['dropout_prob']} + self.outputs = { + 'Out': self.inputs['X'] * (1.0 - self.attrs['dropout_prob']) + } def test_check_output(self): self.check_output()