diff --git a/paddle/framework/op_desc.cc b/paddle/framework/op_desc.cc index 781bbb4c19f1c610df485c3061ca8b510e727019..3e58e6442edfe006c8aed238f67b9524783601ee 100644 --- a/paddle/framework/op_desc.cc +++ b/paddle/framework/op_desc.cc @@ -260,7 +260,13 @@ struct SetAttrDescVisitor : public boost::static_visitor { void operator()(int v) const { attr_->set_i(v); } void operator()(float v) const { attr_->set_f(v); } void operator()(const std::string &v) const { attr_->set_s(v); } - void operator()(bool b) const { attr_->set_b(b); } + + // Please refer to https://github.com/PaddlePaddle/Paddle/issues/7162 + template ::value>::type> + void operator()(T b) const { + attr_->set_b(b); + } void operator()(const std::vector &v) const { VectorToRepeated(v, attr_->mutable_ints()); @@ -274,9 +280,7 @@ struct SetAttrDescVisitor : public boost::static_visitor { void operator()(const std::vector &v) const { VectorToRepeated(v, attr_->mutable_bools()); } - void operator()(proto::BlockDesc *desc) const { - attr_->set_block_idx(desc->idx()); - } + void operator()(BlockDesc *desc) const { attr_->set_block_idx(desc->ID()); } void operator()(boost::blank) const { PADDLE_THROW("Unexpected branch"); } }; diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index bdaa25918155caca4b64b0ed60aa3f6be03eb12f..d75c0233e8e0134ddf4edc50c07490a234b65cd0 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -37,8 +37,8 @@ class Registrar { public: // In our design, various kinds of classes, e.g., operators and kernels, // have their corresponding registry and registrar. The action of - // registration is in the constructor of a global registrar variable, which, - // however, are not used in the code that calls package framework, and would + // registration is in the constructor of a global registrar variable, which + // are not used in the code that calls package framework, and would // be removed from the generated binary file by the linker. To avoid such // removal, we add Touch to all registrar classes and make USE_OP macros to // call this method. So, as long as the callee code calls USE_OP, the global diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp index a2f07937b8834e3f3fa7a6bf2ae10f29a8d84f29..ba83667ebc9a89c37f77a7f71e6df90b54723cc0 100644 --- a/paddle/gserver/tests/test_LayerGrad.cpp +++ b/paddle/gserver/tests/test_LayerGrad.cpp @@ -1472,7 +1472,8 @@ TEST(Layer, RecurrentLayer) { for (auto reversed : {false, true}) { config.layerConfig.set_reversed(reversed); config.testState = !reversed; - testLayerGrad(config, "recurrent", 50, /* trans= */ false, useGpu); + testLayerGrad( + config, "recurrent", 50, /* trans= */ false, useGpu, false, 1.0); } } } @@ -1494,7 +1495,8 @@ TEST(Layer, LstmLayer) { for (auto reversed : {false, true}) { config.layerConfig.set_reversed(reversed); config.testState = !reversed; - testLayerGrad(config, "lstmemory", 100, /* trans= */ false, useGpu); + testLayerGrad( + config, "lstmemory", 100, /* trans= */ false, useGpu, false, 0.02); } } for (auto useGpu : {true}) { diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 9f603474de2f845822651c707174a8804ecf5aad..77b52eb1760c7b79aa47efb447a4c99b6ab5e027 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -61,106 +61,28 @@ function(op_library TARGET) ${op_common_deps}) endif() - # net_op doesn't need pybind - if ("${TARGET}" STREQUAL "net_op") - set(pybind_flag 1) - endif() - - if ("${TARGET}" STREQUAL "compare_op") - set(pybind_flag 1) - file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(equal);\n") - endif() - - # conv_op contains several operators - if ("${TARGET}" STREQUAL "conv_op") - set(pybind_flag 1) - # It's enough to just adding one operator to pybind - file(APPEND ${pybind_file} "USE_OP(conv2d);\n") - endif() - - # conv_cudnn_op contains several operators - if ("${TARGET}" STREQUAL "conv_cudnn_op") - set(pybind_flag 1) - # It's enough to just adding one operator to pybind - file(APPEND ${pybind_file} "USE_OP(conv2d_cudnn);\n") - endif() - - # pool_op contains several operators - if ("${TARGET}" STREQUAL "pool_op") - set(pybind_flag 1) - # It's enough to just adding one operator to pybind - file(APPEND ${pybind_file} "USE_OP(pool2d);\n") - endif() - - # pool_cudnn_op contains several operators - if ("${TARGET}" STREQUAL "pool_cudnn_op") - set(pybind_flag 1) - # It's enough to just adding one operator to pybind - file(APPEND ${pybind_file} "USE_OP(pool2d_cudnn);\n") - endif() - - if ("${TARGET}" STREQUAL "logical_op") - set(pybind_flag 1) - file(APPEND ${pybind_file} "USE_OP(logical_and);\n") - endif() - - # pool_with_index_op contains several operators - if ("${TARGET}" STREQUAL "pool_with_index_op") - set(pybind_flag 1) - # It's enough to just adding one operator to pybind - file(APPEND ${pybind_file} "USE_OP(max_pool2d_with_index);\n") - endif() - - # conv_transpose_op contains several operators - if ("${TARGET}" STREQUAL "conv_transpose_op") - set(pybind_flag 1) - # It's enough to just adding one operator to pybind - file(APPEND ${pybind_file} "USE_OP(conv2d_transpose);\n") - endif() - - # conv_transpose_cudnn_op contains two operators - if ("${TARGET}" STREQUAL "conv_transpose_cudnn_op") - set(pybind_flag 1) - # It's enough to just adding one operator to pybind - file(APPEND ${pybind_file} "USE_OP(conv2d_transpose_cudnn);\n") - endif() - - # save_restore_op contains several operators - if ("${TARGET}" STREQUAL "save_restore_op") - set(pybind_flag 1) - # It's enough to just adding one operator to pybind - file(APPEND ${pybind_file} "USE_NO_KERNEL_OP(save);\n") - endif() - - # activation_op contains several operators - if ("${TARGET}" STREQUAL "activation_op") - set(pybind_flag 1) - # It's enough to just adding one operator to pybind - file(APPEND ${pybind_file} "USE_OP(sigmoid);\n") - endif() - - # nccl_op contains several operators - if ("${TARGET}" STREQUAL "nccl_op") - set(pybind_flag 1) - # It's enough to just adding one operator to pybind - file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(ncclAllReduce);\n") - endif() - - # reduce_op contains several operators - if ("${TARGET}" STREQUAL "reduce_op") - set(pybind_flag 1) - # It's enough to just adding one operator to pybind - file(APPEND ${pybind_file} "USE_OP(reduce_sum);\n") - endif() + # Define operators that don't need pybind here. + foreach(manual_pybind_op "net_op" "compare_op" "logical_op" "nccl_op" "tensor_array_read_write_op") + if ("${TARGET}" STREQUAL "${manual_pybind_op}") + set(pybind_flag 1) + endif() + endforeach() - if ("${TARGET}" STREQUAL "tensor_array_read_write_op") - set(pybind_flag 1) - file(APPEND ${pybind_file} "USE_NO_KERNEL_OP(read_from_array);\nUSE_NO_KERNEL_OP(write_to_array);\n") + # The registration of USE_OP, please refer to paddle/framework/op_registry.h. + # Note that it's enough to just adding one operator to pybind in a *_op.cc file. + # And for detail pybind information, please see generated paddle/pybind/pybind.h. + file(READ ${TARGET}.cc TARGET_CONTENT) + string(REGEX MATCH "REGISTER_OP\\(.*REGISTER_OP\\(" multi_register "${TARGET_CONTENT}") + string(REGEX MATCH "REGISTER_OP\\([a-z0-9_]*," one_register "${multi_register}") + if (one_register STREQUAL "") + string(REPLACE "_op" "" TARGET "${TARGET}") + else () + string(REPLACE "REGISTER_OP(" "" TARGET "${one_register}") + string(REPLACE "," "" TARGET "${TARGET}") endif() # pybind USE_NO_KERNEL_OP # HACK: if REGISTER_OP_CPU_KERNEL presents the operator must have kernel - file(READ ${TARGET}.cc TARGET_CONTENT) string(REGEX MATCH "REGISTER_OP_CPU_KERNEL" regex_result "${TARGET_CONTENT}") string(REPLACE "_op" "" TARGET "${TARGET}") if (${pybind_flag} EQUAL 0 AND regex_result STREQUAL "") @@ -171,7 +93,6 @@ function(op_library TARGET) # pybind USE_CPU_ONLY_OP list(LENGTH cu_srcs cu_srcs_len) list(LENGTH cu_cc_srcs cu_cc_srcs_len) - if (${pybind_flag} EQUAL 0 AND ${cu_srcs_len} EQUAL 0 AND ${cu_cc_srcs_len} EQUAL 0) file(APPEND ${pybind_file} "USE_CPU_ONLY_OP(${TARGET});\n") set(pybind_flag 1) @@ -188,6 +109,7 @@ add_subdirectory(nccl) if(WITH_GPU) op_library(nccl_op DEPS nccl_common) + file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(ncclAllReduce);\n") else() set(DEPS_OPS ${DEPS_OPS} nccl_op) endif() @@ -238,6 +160,8 @@ list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS}) foreach(src ${GENERAL_OPS}) op_library(${src}) endforeach() +file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(logical_and);\nUSE_NO_KERNEL_OP(read_from_array);\n") + set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library") diff --git a/python/paddle/v2/fluid/layers/nn.py b/python/paddle/v2/fluid/layers/nn.py index 55d8bf8a8a60a832000f7119a8bc039127ab1f3a..1c1c09dd28b7563cfb340370a495d7e7c66988cc 100644 --- a/python/paddle/v2/fluid/layers/nn.py +++ b/python/paddle/v2/fluid/layers/nn.py @@ -151,7 +151,7 @@ def embedding(input, size, is_sparse=False, param_attr=None, dtype='float32'): Args: input(Variable): Input to the function - size(tuple|list|None): Shape of the look up table parameter + size(tuple|list|None): Shape of the look up table parameter is_sparse(bool): Boolean flag that specifying whether the input is sparse param_attr(ParamAttr): Parameters for this layer dtype(np.dtype|core.DataType|str): The type of data : float32, float_16, int etc @@ -366,9 +366,9 @@ def cross_entropy(input, label, **kwargs): 1) One-hot cross-entropy: `soft_label = False`, `Label[i, 0]` indicates the class index for sample i: - + .. math:: - + Y[i] = -\log(X[i, Label[i]]) 2) Soft-label cross-entropy: @@ -386,15 +386,15 @@ def cross_entropy(input, label, **kwargs): As a special case of 2), when each row of 'label' has only one non-zero element which is equal to 1, soft-label cross-entropy degenerates to a one-hot cross-entropy with one-hot label representation. - + Args: - input (Variable|list): a 2-D tensor with shape [N x D], where N is the - batch size and D is the number of classes. This input is a probability + input (Variable|list): a 2-D tensor with shape [N x D], where N is the + batch size and D is the number of classes. This input is a probability computed by the previous operator, which is almost always the result of a softmax operator. - label (Variable|list): the ground truth which is a 2-D tensor. When - `soft_label` is set to `False`, `label` is a tensor with shape - [N x 1]. When `soft_label` is set to `True`, `label` is a + label (Variable|list): the ground truth which is a 2-D tensor. When + `soft_label` is set to `False`, `label` is a tensor with shape + [N x 1]. When `soft_label` is set to `True`, `label` is a tensor with shape [N x D]. soft_label (bool, via `**kwargs`): a flag indicating whether to interpretate the given labels as soft labels, default `False`. @@ -403,7 +403,7 @@ def cross_entropy(input, label, **kwargs): A 2-D tensor with shape [N x 1], the cross entropy loss. Raises: - `ValueError`: 1) the 1st dimension of `input` and `label` are not equal; 2) when \ + `ValueError`: 1) the 1st dimension of `input` and `label` are not equal; 2) when \ `soft_label == True`, and the 2nd dimension of `input` and `label` are not \ equal; 3) when `soft_label == False`, and the 2nd dimension of `label` is not 1. @@ -727,9 +727,9 @@ def conv2d(input, def sequence_pool(input, pool_type, **kwargs): """ - This function add the operator for sequence pooling. - It pools features of all time-steps of each instance, and is applied - on top of the input using pool_type mentioned in the parameters. + This function add the operator for sequence pooling. + It pools features of all time-steps of each instance, and is applied + on top of the input using pool_type mentioned in the parameters. It supports four pool_type: @@ -758,7 +758,7 @@ def sequence_pool(input, pool_type, **kwargs): Args: input(variable): The input variable which is a LoDTensor. - pool_type (string): The pooling type of sequence_pool. + pool_type (string): The pooling type of sequence_pool. It supports average, sum, sqrt and max. Returns: @@ -768,7 +768,7 @@ def sequence_pool(input, pool_type, **kwargs): .. code-block:: python - x = fluid.layers.data(name='x', shape=[7, 1], + x = fluid.layers.data(name='x', shape=[7, 1], dtype='float32', lod_level=1) avg_x = fluid.layers.sequence_pool(input=x, pool_type='average') sum_x = fluid.layers.sequence_pool(input=x, pool_type='sum') @@ -816,7 +816,7 @@ def sequence_first_step(input, **kwargs): .. code-block:: python - x = fluid.layers.data(name='x', shape=[7, 1], + x = fluid.layers.data(name='x', shape=[7, 1], dtype='float32', lod_level=1) x_first_step = fluid.layers.sequence_first_step(input=x) """ @@ -849,7 +849,7 @@ def sequence_last_step(input, **kwargs): .. code-block:: python - x = fluid.layers.data(name='x', shape=[7, 1], + x = fluid.layers.data(name='x', shape=[7, 1], dtype='float32', lod_level=1) x_last_step = fluid.layers.sequence_last_step(input=x) """ @@ -1168,25 +1168,26 @@ def lstm_unit(x_t, .. math:: - i_t & = \sigma(W_{x_i}x_{t} + W_{h_i}h_{t-1} + W_{c_i}c_{t-1} + b_i) + i_t & = \sigma(W_{x_i}x_{t} + W_{h_i}h_{t-1} + b_i) - f_t & = \sigma(W_{x_f}x_{t} + W_{h_f}h_{t-1} + W_{c_f}c_{t-1} + b_f) + f_t & = \sigma(W_{x_f}x_{t} + W_{h_f}h_{t-1} + b_f) - c_t & = f_tc_{t-1} + i_t tanh (W_{x_c}x_t+W_{h_c}h_{t-1} + b_c) + c_t & = f_tc_{t-1} + i_t tanh (W_{x_c}x_t + W_{h_c}h_{t-1} + b_c) - o_t & = \sigma(W_{x_o}x_{t} + W_{h_o}h_{t-1} + W_{c_o}c_t + b_o) + o_t & = \sigma(W_{x_o}x_{t} + W_{h_o}h_{t-1} + b_o) h_t & = o_t tanh(c_t) - The inputs of lstm unit includes :math:`x_t`, :math:`h_{t-1}` and - :math:`c_{t-1}`. The implementation separates the linear transformation - and non-linear transformation apart. Here, we take :math:`i_t` as an - example. The linear transformation is applied by calling a `fc` layer and - the equation is: + The inputs of lstm unit include :math:`x_t`, :math:`h_{t-1}` and + :math:`c_{t-1}`. The 2nd dimensions of :math:`h_{t-1}` and :math:`c_{t-1}` + should be same. The implementation separates the linear transformation and + non-linear transformation apart. Here, we take :math:`i_t` as an example. + The linear transformation is applied by calling a `fc` layer and the + equation is: .. math:: - L_{i_t} = W_{x_i}x_{t} + W_{h_i}h_{t-1} + W_{c_i}c_{t-1} + b_i + L_{i_t} = W_{x_i}x_{t} + W_{h_i}h_{t-1} + b_i The non-linear transformation is applied by calling `lstm_unit_op` and the equation is: @@ -1198,9 +1199,12 @@ def lstm_unit(x_t, This layer has two outputs including :math:`h_t` and :math:`o_t`. Args: - x_t (Variable): The input value of current step. - hidden_t_prev (Variable): The hidden value of lstm unit. - cell_t_prev (Variable): The cell value of lstm unit. + x_t (Variable): The input value of current step, a 2-D tensor with shape + M x N, M for batch size and N for input size. + hidden_t_prev (Variable): The hidden value of lstm unit, a 2-D tensor + with shape M x S, M for batch size and S for size of lstm unit. + cell_t_prev (Variable): The cell value of lstm unit, a 2-D tensor with + shape M x S, M for batch size and S for size of lstm unit. forget_bias (float): The forget bias of lstm unit. param_attr (ParamAttr): The attributes of parameter weights, used to set initializer, name etc. @@ -1213,14 +1217,15 @@ def lstm_unit(x_t, Raises: ValueError: The ranks of **x_t**, **hidden_t_prev** and **cell_t_prev**\ not be 2 or the 1st dimensions of **x_t**, **hidden_t_prev** \ - and **cell_t_prev** not be the same. + and **cell_t_prev** not be the same or the 2nd dimensions of \ + **hidden_t_prev** and **cell_t_prev** not be the same. Examples: .. code-block:: python x_t = fluid.layers.fc(input=x_t_data, size=10) - prev_hidden = fluid.layers.fc(input=prev_hidden_data, size=20) + prev_hidden = fluid.layers.fc(input=prev_hidden_data, size=30) prev_cell = fluid.layers.fc(input=prev_cell_data, size=30) hidden_value, cell_value = fluid.layers.lstm_unit(x_t=x_t, hidden_t_prev=prev_hidden, @@ -1239,7 +1244,11 @@ def lstm_unit(x_t, if x_t.shape[0] != hidden_t_prev.shape[0] or x_t.shape[ 0] != cell_t_prev.shape[0]: - raise ValueError("The 1s dimension of x_t, hidden_t_prev and " + raise ValueError("The 1st dimensions of x_t, hidden_t_prev and " + "cell_t_prev must be the same.") + + if hidden_t_prev.shape[1] != cell_t_prev.shape[1]: + raise ValueError("The 2nd dimensions of hidden_t_prev and " "cell_t_prev must be the same.") if bias_attr is None: @@ -1268,17 +1277,17 @@ def lstm_unit(x_t, def reduce_sum(input, dim=None, keep_dim=False): """ - Computes the sum of tensor elements over the given dimension. + Computes the sum of tensor elements over the given dimension. Args: input (Variable): The input variable which is a Tensor or LoDTensor. - dim (int|None): The dimension along which the sum is performed. If - :attr:`None`, sum all elements of :attr:`input` and return a - Tensor variable with a single element, otherwise must be in the - range :math:`[-rank(input), rank(input))`. If :math:`dim < 0`, + dim (int|None): The dimension along which the sum is performed. If + :attr:`None`, sum all elements of :attr:`input` and return a + Tensor variable with a single element, otherwise must be in the + range :math:`[-rank(input), rank(input))`. If :math:`dim < 0`, the dimension to reduce is :math:`rank + dim`. - keep_dim (bool): Whether to reserve the reduced dimension in the - output Tensor. The result tensor will have one fewer dimension + keep_dim (bool): Whether to reserve the reduced dimension in the + output Tensor. The result tensor will have one fewer dimension than the :attr:`input` unless :attr:`keep_dim` is true. Returns: @@ -1312,17 +1321,17 @@ def reduce_sum(input, dim=None, keep_dim=False): def reduce_mean(input, dim=None, keep_dim=False): """ - Computes the mean of tensor elements over the given dimension. + Computes the mean of tensor elements over the given dimension. Args: input (Variable): The input variable which is a Tensor or LoDTensor. - dim (int|None): The dimension along which the mean is computed. If - :attr:`None`, compute the mean over all elements of :attr:`input` - and return a Tensor variable with a single element, otherwise - must be in the range :math:`[-rank(input), rank(input))`. If + dim (int|None): The dimension along which the mean is computed. If + :attr:`None`, compute the mean over all elements of :attr:`input` + and return a Tensor variable with a single element, otherwise + must be in the range :math:`[-rank(input), rank(input))`. If :math:`dim < 0`, the dimension to reduce is :math:`rank + dim`. - keep_dim (bool): Whether to reserve the reduced dimension in the - output Tensor. The result tensor will have one fewer dimension + keep_dim (bool): Whether to reserve the reduced dimension in the + output Tensor. The result tensor will have one fewer dimension than the :attr:`input` unless :attr:`keep_dim` is true. Returns: @@ -1356,22 +1365,22 @@ def reduce_mean(input, dim=None, keep_dim=False): def reduce_max(input, dim=None, keep_dim=False): """ - Computes the maximum of tensor elements over the given dimension. + Computes the maximum of tensor elements over the given dimension. Args: input (Variable): The input variable which is a Tensor or LoDTensor. - dim (int|None): The dimension along which the maximum is computed. - If :attr:`None`, compute the maximum over all elements of - :attr:`input` and return a Tensor variable with a single element, - otherwise must be in the range :math:`[-rank(input), rank(input))`. + dim (int|None): The dimension along which the maximum is computed. + If :attr:`None`, compute the maximum over all elements of + :attr:`input` and return a Tensor variable with a single element, + otherwise must be in the range :math:`[-rank(input), rank(input))`. If :math:`dim < 0`, the dimension to reduce is :math:`rank + dim`. - keep_dim (bool): Whether to reserve the reduced dimension in the - output Tensor. The result tensor will have one fewer dimension + keep_dim (bool): Whether to reserve the reduced dimension in the + output Tensor. The result tensor will have one fewer dimension than the :attr:`input` unless :attr:`keep_dim` is true. Returns: Variable: The reduced Tensor variable. - + Examples: .. code-block:: python @@ -1400,22 +1409,22 @@ def reduce_max(input, dim=None, keep_dim=False): def reduce_min(input, dim=None, keep_dim=False): """ - Computes the minimum of tensor elements over the given dimension. + Computes the minimum of tensor elements over the given dimension. Args: input (Variable): The input variable which is a Tensor or LoDTensor. - dim (int|None): The dimension along which the minimum is computed. - If :attr:`None`, compute the minimum over all elements of - :attr:`input` and return a Tensor variable with a single element, - otherwise must be in the range :math:`[-rank(input), rank(input))`. + dim (int|None): The dimension along which the minimum is computed. + If :attr:`None`, compute the minimum over all elements of + :attr:`input` and return a Tensor variable with a single element, + otherwise must be in the range :math:`[-rank(input), rank(input))`. If :math:`dim < 0`, the dimension to reduce is :math:`rank + dim`. - keep_dim (bool): Whether to reserve the reduced dimension in the - output Tensor. The result tensor will have one fewer dimension + keep_dim (bool): Whether to reserve the reduced dimension in the + output Tensor. The result tensor will have one fewer dimension than the :attr:`input` unless :attr:`keep_dim` is true. Returns: Variable: The reduced Tensor variable. - + Examples: .. code-block:: python diff --git a/python/paddle/v2/fluid/tests/test_layers.py b/python/paddle/v2/fluid/tests/test_layers.py index 9d2dcca56dd1361b9e2448be9f1d5403f8ee17e3..77f0f11f1bcd5fa88700a33eec5a2abc2666ed02 100644 --- a/python/paddle/v2/fluid/tests/test_layers.py +++ b/python/paddle/v2/fluid/tests/test_layers.py @@ -177,8 +177,8 @@ class TestBook(unittest.TestCase): name='x_t_data', shape=[10, 10], dtype='float32') x_t = layers.fc(input=x_t_data, size=10) prev_hidden_data = layers.data( - name='prev_hidden_data', shape=[10, 20], dtype='float32') - prev_hidden = layers.fc(input=prev_hidden_data, size=20) + name='prev_hidden_data', shape=[10, 30], dtype='float32') + prev_hidden = layers.fc(input=prev_hidden_data, size=30) prev_cell_data = layers.data( name='prev_cell', shape=[10, 30], dtype='float32') prev_cell = layers.fc(input=prev_cell_data, size=30)