diff --git a/doc/api/v2/fluid/layers.rst b/doc/api/v2/fluid/layers.rst
index 550b0e5b82609750ccd318eee889313cb2d7925a..f873c93d9a3424497c089fa7ee44122856090610 100644
--- a/doc/api/v2/fluid/layers.rst
+++ b/doc/api/v2/fluid/layers.rst
@@ -18,6 +18,11 @@ dynamic_lstm
 .. autofunction:: paddle.v2.fluid.layers.dynamic_lstm
     :noindex:
 
+dynamic_gru
+-----------
+.. autofunction:: paddle.v2.fluid.layers.dynamic_gru
+    :noindex:
+
 data
 ----
 .. autofunction:: paddle.v2.fluid.layers.data
diff --git a/doc/getstarted/build_and_install/docker_install_cn.rst b/doc/getstarted/build_and_install/docker_install_cn.rst
index bae42593ddc6f7a7eb47d603752ad6efa9820b45..98fada7bdb46f4dd2927d6f93bcbcebbe7d18604 100644
--- a/doc/getstarted/build_and_install/docker_install_cn.rst
+++ b/doc/getstarted/build_and_install/docker_install_cn.rst
@@ -25,14 +25,14 @@
 
   .. code-block:: bash
 
-     docker pull docker.paddlepaddle.org/paddle
+     docker pull docker.paddlepaddlehub.com/paddle
 
 Download the GPU version (cuda8.0_cudnn5_avx_mkl) Docker image:
 
   .. code-block:: bash
 
     docker pull paddlepaddle/paddle:latest-gpu
-    docker pull docker.paddlepaddle.org/paddle:latest-gpu
+    docker pull docker.paddlepaddlehub.com/paddle:latest-gpu
 
 Choose a Docker image built with a different BLAS library:
 
@@ -49,7 +49,7 @@
 
     docker pull paddlepaddle/paddle:[tag]
     # For example:
-    docker pull docker.paddlepaddle.org/paddle:0.10.0-gpu
+    docker pull docker.paddlepaddlehub.com/paddle:0.11.0-gpu
 
 .. _docker_run:
diff --git a/doc/getstarted/build_and_install/docker_install_en.rst b/doc/getstarted/build_and_install/docker_install_en.rst
index 56a7c68e4d39c45249fa55a964dc48b7081596a6..b1d0890b4cdddb77114a80276130afd07c22d270 100644
--- a/doc/getstarted/build_and_install/docker_install_en.rst
+++ b/doc/getstarted/build_and_install/docker_install_en.rst
@@ -26,14 +26,14 @@ For users in China, we provide a faster mirror:
 
   .. code-block:: bash
 
-    docker pull docker.paddlepaddle.org/paddle
+    docker pull docker.paddlepaddlehub.com/paddle
 
 Download GPU version (cuda8.0_cudnn5_avx_mkl) images:
 
   .. code-block:: bash
 
     docker pull paddlepaddle/paddle:latest-gpu
-    docker pull docker.paddlepaddle.org/paddle:latest-gpu
+    docker pull docker.paddlepaddlehub.com/paddle:latest-gpu
 
 Choose between different BLAS versions:
 
@@ -53,7 +53,7 @@ and run:
 
     docker pull paddlepaddle/paddle:[tag]
     # e.g.
-    docker pull docker.paddlepaddle.org/paddle:0.10.0-gpu
+    docker pull docker.paddlepaddlehub.com/paddle:0.11.0-gpu
 
 .. _docker_run:
diff --git a/paddle/framework/variable_test.cc b/paddle/framework/variable_test.cc
index e4732d9718e2b46a068963d44c4c1e04024f2330..e5585c8724d712e273d086001b6cbc3d59c46ebe 100644
--- a/paddle/framework/variable_test.cc
+++ b/paddle/framework/variable_test.cc
@@ -12,19 +12,6 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-/*
-  Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
-  Licensed under the Apache License, Version 2.0 (the "License");
-  you may not use this file except in compliance with the License.
-  You may obtain a copy of the License at
-  http://www.apache.org/licenses/LICENSE-2.0
-  Unless required by applicable law or agreed to in writing, software
-  distributed under the License is distributed on an "AS IS" BASIS,
-  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-  See the License for the specific language governing permissions and
-  limitations under the License.
-*/
-
 #include <memory>
 #include <string>
 
diff --git a/paddle/operators/bipartite_match_op.cc b/paddle/operators/bipartite_match_op.cc
index b0f7376d272a66e0b01d6b3f7e546372397772f7..83c8778fe4cec4d9d80de691e117a39fdd92f494 100644
--- a/paddle/operators/bipartite_match_op.cc
+++ b/paddle/operators/bipartite_match_op.cc
@@ -21,8 +21,6 @@ namespace operators {
 using Tensor = framework::Tensor;
 using LoDTensor = framework::LoDTensor;
 
-constexpr char kEPS = 1e-6;
-
 class BipartiteMatchOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
@@ -46,6 +44,7 @@ class BipartiteMatchKernel : public framework::OpKernel<T> {
   // The match_dist must be initialized to 0 at first.
   void BipartiteMatch(const Tensor& dist, int* match_indices,
                       T* match_dist) const {
+    constexpr T kEPS = static_cast<T>(1e-6);
     PADDLE_ENFORCE_EQ(dist.dims().size(), 2, "The rank of dist must be 2.");
     int64_t row = dist.dims()[0];
     int64_t col = dist.dims()[1];
diff --git a/paddle/operators/nce_op.cc b/paddle/operators/nce_op.cc
index 84ba3ead2b52547b989a4541f31ea31ffcce6c63..994ddf717e7a5b883d8071c6a47da0b4b4074f2e 100644
--- a/paddle/operators/nce_op.cc
+++ b/paddle/operators/nce_op.cc
@@ -124,7 +124,8 @@ class NCEOpMaker : public framework::OpProtoAndCheckerMaker {
              "This attribute is only used in unit tests. Classes "
              "in this list will be used as negative classes "
              "for every sample. Under normal conditions, "
-             "user should avoid setting this attribute.");
+             "user should avoid setting this attribute.")
+        .SetDefault({});
   AddComment(R"DOC(
 Compute and return the noise-contrastive estimation training loss.
 See [Noise-contrastive estimation: A new estimation principle for unnormalized statistical models](http://www.jmlr.org/proceedings/papers/v9/gutmann10a/gutmann10a.pdf).
diff --git a/paddle/operators/nce_op.h b/paddle/operators/nce_op.h
index e6b496f7896dcb412be8ff096fdccb2f0b682369..86fa13a649ce7fdcaad64e2609ceea2fb4d7e072 100644
--- a/paddle/operators/nce_op.h
+++ b/paddle/operators/nce_op.h
@@ -197,7 +197,8 @@ class NCEGradKernel : public framework::OpKernel<T> {
     // get d_x
     auto d_x = context.Output<Tensor>(framework::GradVarName("Input"));
     if (d_x != nullptr) {
-      d_x->mutable_data<T>(context.GetPlace());
+      auto* d_x_data = d_x->mutable_data<T>(context.GetPlace());
+      std::fill(d_x_data, d_x_data + d_x->numel(), 0.0);
       auto d_x_matrix = EigenMatrix<T>::From(*d_x);
       auto w_matrix = EigenMatrix<T>::From(*(context.Input<Tensor>("Weight")));
       for (int64_t i = 0; i < sample_labels->numel(); ++i) {
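The `std::fill` added above matters because the NCE gradient kernel accumulates into `d_x` row by row, while `mutable_data` only allocates memory and leaves it uninitialized. A minimal numpy sketch of the same pattern (names here are illustrative, not the operator's API):

.. code-block:: python

    import numpy as np

    batch, dim = 4, 3
    sample_grads = np.random.rand(batch, dim).astype(np.float32)

    # Like mutable_data: allocated but uninitialized memory.
    d_x = np.empty((batch, dim), dtype=np.float32)
    # Without this zero-fill, the "+=" below folds garbage into the result.
    d_x.fill(0.0)
    for i in range(batch):
        d_x[i] += sample_grads[i]  # accumulate; relies on the zero init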
diff --git a/python/paddle/v2/dataset/wmt16.py b/python/paddle/v2/dataset/wmt16.py
index e2f463be2f7bcd667855f64206d78f387e92ef33..c8818f715beadd9499ae588f2c19a57fbf26f372 100644
--- a/python/paddle/v2/dataset/wmt16.py
+++ b/python/paddle/v2/dataset/wmt16.py
@@ -305,9 +305,9 @@ def get_dict(lang, dict_size, reverse=False):
     dict_path = os.path.join(paddle.v2.dataset.common.DATA_HOME,
                              "wmt16/%s_%d.dict" % (lang, dict_size))
 
-    assert (os.path.exists(dict_path), "Word dictionary does not exist. "
-            "Please invoke paddle.dataset.wmt16.train/test/validation "
-            "first to build the dictionary.")
+    assert os.path.exists(dict_path), "Word dictionary does not exist. " \
+        "Please invoke paddle.dataset.wmt16.train/test/validation first " \
+        "to build the dictionary."
 
     tar_file = os.path.join(paddle.v2.dataset.common.DATA_HOME, "wmt16.tar.gz")
     return __load_dict(tar_file, dict_size, lang, reverse)
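The bug fixed above is the classic parenthesized-assert pitfall: `assert (cond, msg)` tests a two-element tuple, and a non-empty tuple is always truthy, so the assertion can never fire. A minimal demonstration:

.. code-block:: python

    # Always passes: the tuple (False, "...") is truthy.
    assert (False, "this message is never shown")

    # Correct form: the condition is evaluated on its own.
    # assert False, "this raises AssertionError"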
diff --git a/python/paddle/v2/fluid/layers/nn.py b/python/paddle/v2/fluid/layers/nn.py
index 072119881644c650c3430c70bdab42f8d17df7ba..930cd742bbdfdf193e88af713647778efe8c4de5 100644
--- a/python/paddle/v2/fluid/layers/nn.py
+++ b/python/paddle/v2/fluid/layers/nn.py
@@ -19,12 +19,14 @@ from ..layer_helper import LayerHelper
 from ..initializer import Normal, Constant
 from ..framework import Variable
 from ..param_attr import ParamAttr
+from layer_function_generator import autodoc
 from tensor import concat
 
 __all__ = [
     'fc',
     'embedding',
     'dynamic_lstm',
+    'dynamic_gru',
     'gru_unit',
     'linear_chain_crf',
     'crf_decoding',
@@ -57,6 +59,7 @@ __all__ = [
     'warpctc',
     'sequence_reshape',
     'transpose',
+    'nce',
 ]
 
 
@@ -366,6 +369,113 @@ def dynamic_lstm(input,
     return hidden, cell
 
 
+def dynamic_gru(input,
+                size,
+                param_attr=None,
+                bias_attr=None,
+                is_reverse=False,
+                gate_activation='sigmoid',
+                candidate_activation='tanh',
+                h_0=None):
+    """
+    **Dynamic GRU Layer**
+
+    Refer to `Empirical Evaluation of Gated Recurrent Neural Networks on
+    Sequence Modeling <https://arxiv.org/abs/1412.3555>`_ .
+
+    The formula is as follows:
+
+    .. math::
+
+        u_t & = act_g(W_{ux}x_{t} + W_{uh}h_{t-1} + b_u)
+
+        r_t & = act_g(W_{rx}x_{t} + W_{rh}h_{t-1} + b_r)
+
+        \\tilde{h_t} & = act_c(W_{cx}x_{t} + W_{ch}(r_t \\odot h_{t-1}) + b_c)
+
+        h_t & = (1-u_t) \\odot h_{t-1} + u_t \\odot \\tilde{h_t}
+
+    :math:`\\odot` denotes the element-wise product of vectors. :math:`act_g`
+    is the activation function of the update gate and reset gate, for which
+    :math:`sigmoid` is usually used. :math:`act_c` is the activation function
+    of the candidate hidden state, for which :math:`tanh` is usually used.
+
+    Note that the :math:`W_{ux}x_{t}, W_{rx}x_{t}, W_{cx}x_{t}` operations on
+    the input :math:`x_{t}` are NOT included in this operator. Users can
+    apply a fully-connected layer before the GRU layer instead.
+
+    Args:
+        input(Variable): The input of the dynamic_gru layer, which supports
+            variable-length input sequences. The underlying tensor in this
+            Variable is a matrix with shape :math:`(T \\times 3D)`, where
+            :math:`T` is the total number of time steps in this mini-batch
+            and :math:`D` is the hidden size.
+        size(int): The dimension of the GRU cell.
+        param_attr(ParamAttr|None): The parameter attribute for the learnable
+            hidden-hidden weight matrix. Note:
+
+            - The shape of the weight matrix is :math:`(D \\times 3D)`, where
+              :math:`D` is the hidden size.
+            - All elements in the weight matrix can be divided into two parts.
+              The first part contains the weights of the update gate and reset
+              gate with shape :math:`(D \\times 2D)`, and the second part
+              contains the weights of the candidate hidden state with shape
+              :math:`(D \\times D)`.
+        bias_attr(ParamAttr): The parameter attribute for the learnable
+            hidden-hidden bias.
+        is_reverse(bool): Whether to compute the reversed GRU, default
+            :attr:`False`.
+        gate_activation(str): The activation for the update gate and reset
+            gate. Choices = ["sigmoid", "tanh", "relu", "identity"], default
+            "sigmoid".
+        candidate_activation(str): The activation for the candidate hidden
+            state. Choices = ["sigmoid", "tanh", "relu", "identity"], default
+            "tanh".
+        h_0(Variable|None): The initial hidden state of the GRU, optional.
+
+    Returns:
+        Variable: The hidden state of the GRU. The shape is
+            :math:`(T \\times D)`, and the lod is the same as the input's.
+
+    Examples:
+        .. code-block:: python
+
+            hidden_dim = 512
+            x = fluid.layers.fc(input=data, size=hidden_dim * 3)
+            hidden = fluid.layers.dynamic_gru(input=x, size=hidden_dim)
+    """
+
+    helper = LayerHelper('gru', **locals())
+    dtype = helper.input_dtype()
+
+    weight = helper.create_parameter(
+        attr=helper.param_attr, shape=[size, 3 * size], dtype=dtype)
+    bias = helper.create_parameter(
+        attr=helper.bias_attr, shape=[1, 3 * size], dtype=dtype, is_bias=True)
+    inputs = {'Input': input, 'Weight': weight, 'Bias': bias}
+    if h_0 is not None:
+        assert h_0.shape == (
+            size, size), 'The shape of h0 should be (%d, %d)' % (size, size)
+        inputs['h0'] = h_0
+
+    hidden = helper.create_tmp_variable(dtype)
+    batch_gate = helper.create_tmp_variable(dtype)
+    batch_reset_hidden_prev = helper.create_tmp_variable(dtype)
+    batch_hidden = helper.create_tmp_variable(dtype)
+
+    helper.append_op(
+        type='gru',
+        inputs=inputs,
+        outputs={
+            'Hidden': hidden,
+            'BatchGate': batch_gate,
+            'BatchResetHiddenPrev': batch_reset_hidden_prev,
+            'BatchHidden': batch_hidden
+        },
+        attrs={
+            'is_reverse': is_reverse,
+            'gate_activation': gate_activation,
+            'activation': candidate_activation
+        })
+    return hidden
+
+
 def gru_unit(input,
              hidden,
              size,
@@ -2190,6 +2300,61 @@ def sequence_reshape(input, new_dim):
     return out
 
 
+@autodoc()
+def nce(input,
+        label,
+        num_total_classes,
+        sample_weight=None,
+        param_attr=None,
+        bias_attr=None,
+        num_neg_samples=None):
+    helper = LayerHelper('nce', **locals())
+    assert isinstance(input, Variable)
+    dim = input.shape[1]
+    assert isinstance(label, Variable)
+    num_true_class = label.shape[1]
+    w = helper.create_parameter(
+        attr=helper.param_attr,
+        shape=[num_total_classes, dim],
+        is_bias=False,
+        dtype=input.dtype)
+    b = helper.create_parameter(
+        attr=helper.bias_attr,
+        shape=[num_total_classes, 1],
+        is_bias=True,
+        dtype=input.dtype)
+    cost = helper.create_tmp_variable(dtype=input.dtype)
+    sample_logits = helper.create_tmp_variable(dtype=input.dtype)
+    sample_labels = helper.create_tmp_variable(dtype=label.dtype)
+
+    if num_neg_samples is None:
+        num_neg_samples = 10
+    else:
+        num_neg_samples = int(num_neg_samples)
+
+    attrs = {
+        'num_total_classes': int(num_total_classes),
+        'num_neg_samples': num_neg_samples
+    }
+
+    helper.append_op(
+        type='nce',
+        inputs={
+            'Input': input,
+            'Label': label,
+            'Weight': w,
+            'Bias': b,
+            'SampleWeight': sample_weight if sample_weight is not None else []
+        },
+        outputs={
+            'Cost': cost,
+            'SampleLogits': sample_logits,
+            'SampleLabels': sample_labels
+        },
+        attrs=attrs)
+    return cost / (num_neg_samples + 1)
+
+
 def transpose(x, perm, name=None):
     """
     **transpose Layer**
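As a sanity check on the recurrence documented in `dynamic_gru`, a single GRU step can be replayed in numpy. This is only an illustrative sketch: the weight layout follows the docstring's note (gate weights in the first D x 2D block, candidate weights in the last D x D block), and all names below are made up rather than taken from the operator:

.. code-block:: python

    import numpy as np

    def sigmoid(z):
        return 1.0 / (1.0 + np.exp(-z))

    D = 4
    # The input projections are done outside the op, so x already holds the
    # three D-wide pieces W_ux*x_t, W_rx*x_t and W_cx*x_t concatenated.
    x_u, x_r, x_c = np.split(np.random.randn(3 * D), 3)
    h_prev = np.random.randn(D)

    # Hidden-hidden weight, shape (D, 3D): [update | reset | candidate].
    W = np.random.randn(D, 3 * D)
    W_uh, W_rh, W_ch = W[:, :D], W[:, D:2 * D], W[:, 2 * D:]
    b_u, b_r, b_c = np.zeros(D), np.zeros(D), np.zeros(D)

    u = sigmoid(x_u + np.dot(h_prev, W_uh) + b_u)               # update gate
    r = sigmoid(x_r + np.dot(h_prev, W_rh) + b_r)               # reset gate
    h_tilde = np.tanh(x_c + np.dot(r * h_prev, W_ch) + b_c)     # candidate
    h = (1 - u) * h_prev + u * h_tilde                          # next state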
diff --git a/python/paddle/v2/fluid/tests/test_bipartite_match_op.py b/python/paddle/v2/fluid/tests/test_bipartite_match_op.py
index 34101b1da46d46d0e7a995ba80d8644dc586065d..74138298978c7c18936f53761b313887f07aea81 100644
--- a/python/paddle/v2/fluid/tests/test_bipartite_match_op.py
+++ b/python/paddle/v2/fluid/tests/test_bipartite_match_op.py
@@ -16,13 +16,13 @@
 import numpy as np
 from op_test import OpTest
 
-def bipartite_match(distance, match_indices, match_dis):
+def bipartite_match(distance, match_indices, match_dist):
     """Bipartite Matching algorithm.
     Arg:
         distance (numpy.array) : The distance of two entries with shape [M, N].
         match_indices (numpy.array): the matched indices from column to row
             with shape [1, N], it must be initialized to -1.
-        match_dis (numpy.array): The matched distance from column to row
+        match_dist (numpy.array): The matched distance from column to row
             with shape [1, N], it must be initialized to 0.
     """
     match_pair = []
@@ -36,13 +36,13 @@ def bipartite_match(distance, match_indices, match_dis):
     row_indices = -1 * np.ones((row, ), dtype=np.int)
 
     idx = 0
-    for i, j, dis in match_sorted:
+    for i, j, dist in match_sorted:
         if idx >= row:
             break
-        if match_indices[j] == -1 and row_indices[i] == -1 and dis > 0:
+        if match_indices[j] == -1 and row_indices[i] == -1 and dist > 0:
             match_indices[j] = i
             row_indices[i] = j
-            match_dis[j] = dis
+            match_dist[j] = dist
             idx += 1
 
@@ -55,24 +55,24 @@ def batch_bipartite_match(distance, lod):
     n = len(lod) - 1
     m = distance.shape[1]
     match_indices = -1 * np.ones((n, m), dtype=np.int)
-    match_dis = np.zeros((n, m), dtype=np.float32)
+    match_dist = np.zeros((n, m), dtype=np.float32)
     for i in range(len(lod) - 1):
         bipartite_match(distance[lod[i]:lod[i + 1], :], match_indices[i, :],
-                        match_dis[i, :])
-    return match_indices, match_dis
+                        match_dist[i, :])
+    return match_indices, match_dist
 
 
 class TestBipartiteMatchOpForWithLoD(OpTest):
     def setUp(self):
         self.op_type = 'bipartite_match'
         lod = [[0, 5, 11, 23]]
-        dis = np.random.random((23, 217)).astype('float32')
-        match_indices, match_dis = batch_bipartite_match(dis, lod[0])
+        dist = np.random.random((23, 217)).astype('float32')
+        match_indices, match_dist = batch_bipartite_match(dist, lod[0])
 
-        self.inputs = {'DistMat': (dis, lod)}
+        self.inputs = {'DistMat': (dist, lod)}
         self.outputs = {
             'ColToRowMatchIndices': (match_indices),
-            'ColToRowMatchDis': (match_dis),
+            'ColToRowMatchDis': (match_dist),
         }
 
     def test_check_output(self):
@@ -83,13 +83,13 @@ class TestBipartiteMatchOpWithoutLoD(OpTest):
     def setUp(self):
         self.op_type = 'bipartite_match'
         lod = [[0, 8]]
-        dis = np.random.random((8, 17)).astype('float32')
-        match_indices, match_dis = batch_bipartite_match(dis, lod[0])
+        dist = np.random.random((8, 17)).astype('float32')
+        match_indices, match_dist = batch_bipartite_match(dist, lod[0])
 
-        self.inputs = {'DistMat': dis}
+        self.inputs = {'DistMat': dist}
         self.outputs = {
-            'ColToRowMatchIndices': (match_indices),
-            'ColToRowMatchDis': (match_dis),
+            'ColToRowMatchIndices': match_indices,
+            'ColToRowMatchDis': match_dist,
         }
 
     def test_check_output(self):
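A quick worked example of the greedy matching the test above reimplements, assuming the `bipartite_match` reference helper from that file is in scope (the numbers are made up): entries are visited in decreasing order of distance, and each row and column is matched at most once.

.. code-block:: python

    import numpy as np

    distance = np.array([[0.1, 0.9, 0.3],
                         [0.8, 0.2, 0.4]], dtype=np.float32)  # 2 rows, 3 cols
    match_indices = -1 * np.ones((3, ), dtype=np.int)
    match_dist = np.zeros((3, ), dtype=np.float32)
    bipartite_match(distance, match_indices, match_dist)
    # 0.9 pairs column 1 with row 0, then 0.8 pairs column 0 with row 1;
    # column 2 stays unmatched:
    # match_indices -> [1, 0, -1], match_dist -> [0.8, 0.9, 0.0]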
diff --git a/python/paddle/v2/fluid/tests/test_layers.py b/python/paddle/v2/fluid/tests/test_layers.py
index 709abd6c6a4e0c2aa1b38a135d7424cd6886c966..b14198b231372c6e75434162e3a84be4c9890ece 100644
--- a/python/paddle/v2/fluid/tests/test_layers.py
+++ b/python/paddle/v2/fluid/tests/test_layers.py
@@ -17,8 +17,9 @@
 import unittest
 
 import paddle.v2.fluid.layers as layers
 import paddle.v2.fluid.nets as nets
-from paddle.v2.fluid.framework import Program, program_guard
+from paddle.v2.fluid.framework import Program, program_guard, default_main_program
 from paddle.v2.fluid.param_attr import ParamAttr
+import decorators
 
 
 class TestBook(unittest.TestCase):
@@ -225,6 +226,41 @@ class TestBook(unittest.TestCase):
         self.assertIsNotNone(out)
         print(str(program))
 
+    @decorators.prog_scope()
+    def test_nce(self):
+        window_size = 5
+        words = []
+        for i in xrange(window_size):
+            words.append(
+                layers.data(
+                    name='word_{0}'.format(i), shape=[1], dtype='int64'))
+
+        dict_size = 10000
+        label_word = int(window_size / 2) + 1
+
+        embs = []
+        for i in xrange(window_size):
+            if i == label_word:
+                continue
+
+            emb = layers.embedding(
+                input=words[i],
+                size=[dict_size, 32],
+                param_attr='emb.w',
+                is_sparse=True)
+
+            embs.append(emb)
+
+        embs = layers.concat(input=embs, axis=1)
+        loss = layers.nce(input=embs,
+                          label=words[label_word],
+                          num_total_classes=dict_size,
+                          param_attr='nce.w',
+                          bias_attr='nce.b')
+        avg_loss = layers.mean(x=loss)
+        self.assertIsNotNone(avg_loss)
+        print(str(default_main_program()))
+
 
 if __name__ == '__main__':
     unittest.main()