提交 25a01932 编写于 作者: W wanghaox

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into iou_sim

......@@ -18,6 +18,11 @@ dynamic_lstm
.. autofunction:: paddle.v2.fluid.layers.dynamic_lstm
:noindex:
dynamic_gru
-----------
.. autofunction:: paddle.v2.fluid.layers.dynamic_gru
:noindex:
data
----
.. autofunction:: paddle.v2.fluid.layers.data
......
......@@ -25,14 +25,14 @@
.. code-block:: bash
docker pull docker.paddlepaddle.org/paddle
docker pull docker.paddlepaddlehub.com/paddle
下载GPU版本(cuda8.0_cudnn5_avx_mkl)的Docker镜像:
.. code-block:: bash
docker pull paddlepaddle/paddle:latest-gpu
docker pull docker.paddlepaddle.org/paddle:latest-gpu
docker pull docker.paddlepaddlehub.com/paddle:latest-gpu
选择下载使用不同的BLAS库的Docker镜像:
......@@ -49,7 +49,7 @@
docker pull paddlepaddle/paddle:[tag]
# 比如:
docker pull docker.paddlepaddle.org/paddle:0.10.0-gpu
docker pull docker.paddlepaddlehub.com/paddle:0.11.0-gpu
.. _docker_run:
......
......@@ -26,14 +26,14 @@ For users in China, we provide a faster mirror:
.. code-block:: bash
docker pull docker.paddlepaddle.org/paddle
docker pull docker.paddlepaddlehub.com/paddle
Download GPU version (cuda8.0_cudnn5_avx_mkl) images:
.. code-block:: bash
docker pull paddlepaddle/paddle:latest-gpu
docker pull docker.paddlepaddle.org/paddle:latest-gpu
docker pull docker.paddlepaddlehub.com/paddle:latest-gpu
Choose between different BLAS version:
......@@ -53,7 +53,7 @@ and run:
docker pull paddlepaddle/paddle:[tag]
# i.e.
docker pull docker.paddlepaddle.org/paddle:0.10.0-gpu
docker pull docker.paddlepaddlehub.com/paddle:0.11.0-gpu
.. _docker_run:
......
......@@ -12,19 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
/*
Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <memory>
#include <string>
......
......@@ -21,8 +21,6 @@ namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
constexpr char kEPS = 1e-6;
class BipartiteMatchOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
......@@ -46,6 +44,7 @@ class BipartiteMatchKernel : public framework::OpKernel<T> {
// The match_dist must be initialized to 0 at first.
void BipartiteMatch(const Tensor& dist, int* match_indices,
T* match_dist) const {
constexpr T kEPS = static_cast<T>(1e-6);
PADDLE_ENFORCE_EQ(dist.dims().size(), 2, "The rank of dist must be 2.");
int64_t row = dist.dims()[0];
int64_t col = dist.dims()[1];
......
......@@ -124,7 +124,8 @@ class NCEOpMaker : public framework::OpProtoAndCheckerMaker {
"This attribute only be used in unitest. Classes "
"in this list wiil be used as negative classes "
"for every samples. Under normal conditions, "
"user should avoid setting this attribute.");
"user should avoid setting this attribute.")
.SetDefault({});
AddComment(R"DOC(
Compute and return the noise-contrastive estimation training loss.
See [Noise-contrastive estimation: A new estimation principle for unnormalized statistical models](http://www.jmlr.org/proceedings/papers/v9/gutmann10a/gutmann10a.pdf).
......
......@@ -197,7 +197,8 @@ class NCEGradKernel : public framework::OpKernel<T> {
// get d_x
auto d_x = context.Output<Tensor>(framework::GradVarName("Input"));
if (d_x != nullptr) {
d_x->mutable_data<T>(context.GetPlace());
auto* d_x_data = d_x->mutable_data<T>(context.GetPlace());
std::fill(d_x_data, d_x_data + d_x->numel(), 0.0);
auto d_x_matrix = EigenMatrix<T>::From(*d_x);
auto w_matrix = EigenMatrix<T>::From(*(context.Input<Tensor>("Weight")));
for (int64_t i = 0; i < sample_labels->numel(); ++i) {
......
......@@ -305,9 +305,9 @@ def get_dict(lang, dict_size, reverse=False):
dict_path = os.path.join(paddle.v2.dataset.common.DATA_HOME,
"wmt16/%s_%d.dict" % (lang, dict_size))
assert (os.path.exists(dict_path), "Word dictionary does not exist. "
"Please invoke paddle.dataset.wmt16.train/test/validation "
"first to build the dictionary.")
assert os.path.exists(dict_path), "Word dictionary does not exist. "
"Please invoke paddle.dataset.wmt16.train/test/validation first "
"to build the dictionary."
tar_file = os.path.join(paddle.v2.dataset.common.DATA_HOME, "wmt16.tar.gz")
return __load_dict(tar_file, dict_size, lang, reverse)
......
......@@ -19,12 +19,14 @@ from ..layer_helper import LayerHelper
from ..initializer import Normal, Constant
from ..framework import Variable
from ..param_attr import ParamAttr
from layer_function_generator import autodoc
from tensor import concat
__all__ = [
'fc',
'embedding',
'dynamic_lstm',
'dynamic_gru',
'gru_unit',
'linear_chain_crf',
'crf_decoding',
......@@ -57,6 +59,7 @@ __all__ = [
'warpctc',
'sequence_reshape',
'transpose',
'nce',
]
......@@ -366,6 +369,113 @@ def dynamic_lstm(input,
return hidden, cell
def dynamic_gru(input,
size,
param_attr=None,
bias_attr=None,
is_reverse=False,
gate_activation='sigmoid',
candidate_activation='tanh',
h_0=None):
"""
**Dynamic GRU Layer**
Refer to `Empirical Evaluation of Gated Recurrent Neural Networks on
Sequence Modeling <https://arxiv.org/abs/1412.3555>`_
The formula is as follows:
.. math::
u_t & = act_g(W_{ux}x_{t} + W_{uh}h_{t-1} + b_u)
r_t & = act_g(W_{rx}x_{t} + W_{rh}h_{t-1} + b_r)
\\tilde{h_t} & = act_c(W_{cx}x_{t} + W_{ch}(r_t \odot h_{t-1}) + b_c)
h_t & = (1-u_t) \odot h_{t-1} + u_t \odot \\tilde{h_t}
The :math:`\odot` is the element-wise product of the vectors. :math:`act_g`
is the update gate and reset gate activation function and :math:`sigmoid`
is usually used for it. :math:`act_c` is the activation function for
candidate hidden state and :math:`tanh` is usually used for it.
Note that these :math:`W_{ux}x_{t}, W_{rx}x_{t}, W_{cx}x_{t}` operations on
the input :math:`x_{t}` are NOT included in this operator. Users can choose
to use fully-connect layer before GRU layer.
Args:
input(Variable): The input of dynamic_gru layer, which supports
variable-time length input sequence. The underlying tensor in this
Variable is a matrix with shape :math:`(T \\times 3D)`, where
:math:`T` is the total time steps in this mini-batch, :math:`D`
is the hidden size.
size(int): The dimension of the gru cell.
param_attr(ParamAttr|None): The parameter attribute for the learnable
hidden-hidden weight matrix. Note:
- The shape of the weight matrix is :math:`(T \\times 3D)`, where
:math:`D` is the hidden size.
- All elements in the weight matrix can be divided into two parts.
The first part are weights of the update gate and reset gate with
shape :math:`(D \\times 2D)`, and the second part are weights for
candidate hidden state with shape :math:`(D \\times D)`.
bias_attr(ParamAttr): The parameter attribute for learnable the
hidden-hidden bias.
is_reverse(bool): Whether to compute reversed GRU, default
:attr:`False`.
gate_activation(str): The activation for update gate and reset gate.
Choices = ["sigmoid", "tanh", "relu", "identity"], default "sigmoid".
activation(str): The activation for candidate hidden state.
Choices = ["sigmoid", "tanh", "relu", "identity"], default "tanh".
Returns:
Variable: The hidden state of GRU. The shape is (T \\times D), and lod \
is the same with the input.
Examples:
.. code-block:: python
hidden_dim = 512
x = fluid.layers.fc(input=data, size=hidden_dim * 3)
hidden = fluid.layers.dynamic_gru(input=x, dim=hidden_dim)
"""
helper = LayerHelper('gru', **locals())
dtype = helper.input_dtype()
weight = helper.create_parameter(
attr=helper.param_attr, shape=[size, 3 * size], dtype=dtype)
bias = helper.create_parameter(
attr=helper.bias_attr, shape=[1, 3 * size], dtype=dtype, is_bias=True)
inputs = {'Input': input, 'Weight': weight, 'Bias': bias}
if h_0 != None:
assert h_0.shape == (
size, size), 'The shape of h0 should be(%d, %d)' % (size, size)
inputs['h0'] = h_0
hidden = helper.create_tmp_variable(dtype)
batch_gate = helper.create_tmp_variable(dtype)
batch_reset_hidden_prev = helper.create_tmp_variable(dtype)
batch_hidden = helper.create_tmp_variable(dtype)
helper.append_op(
type='gru',
inputs=inputs,
outputs={
'Hidden': hidden,
'BatchGate': batch_gate,
'BatchResetHiddenPrev': batch_reset_hidden_prev,
'BatchHidden': batch_hidden
},
attrs={
'is_reverse': is_reverse,
'gate_activation': gate_activation,
'activation': candidate_activation
})
return hidden
def gru_unit(input,
hidden,
size,
......@@ -2190,6 +2300,61 @@ def sequence_reshape(input, new_dim):
return out
@autodoc()
def nce(input,
label,
num_total_classes,
sample_weight=None,
param_attr=None,
bias_attr=None,
num_neg_samples=None):
helper = LayerHelper('nce', **locals())
assert isinstance(input, Variable)
dim = input.shape[1]
assert isinstance(label, Variable)
num_true_class = label.shape[1]
w = helper.create_parameter(
attr=helper.param_attr,
shape=[num_total_classes, dim],
is_bias=False,
dtype=input.dtype)
b = helper.create_parameter(
attr=helper.bias_attr,
shape=[num_total_classes, 1],
is_bias=True,
dtype=input.dtype)
cost = helper.create_tmp_variable(dtype=input.dtype)
sample_logits = helper.create_tmp_variable(dtype=input.dtype)
sample_labels = helper.create_tmp_variable(dtype=label.dtype)
if num_neg_samples is None:
num_neg_samples = 10
else:
num_neg_samples = int(num_neg_samples)
attrs = {
'num_total_classes': int(num_total_classes),
'num_neg_samples': num_neg_samples
}
helper.append_op(
type='nce',
inputs={
'Input': input,
'Label': label,
'Weight': w,
'Bias': b,
'SampleWeight': sample_weight if sample_weight is not None else []
},
outputs={
'Cost': cost,
'SampleLogits': sample_logits,
'SampleLabels': sample_labels
},
attrs=attrs)
return cost / (num_neg_samples + 1)
def transpose(x, perm, name=None):
"""
**transpose Layer**
......
......@@ -16,13 +16,13 @@ import numpy as np
from op_test import OpTest
def bipartite_match(distance, match_indices, match_dis):
def bipartite_match(distance, match_indices, match_dist):
"""Bipartite Matching algorithm.
Arg:
distance (numpy.array) : The distance of two entries with shape [M, N].
match_indices (numpy.array): the matched indices from column to row
with shape [1, N], it must be initialized to -1.
match_dis (numpy.array): The matched distance from column to row
match_dist (numpy.array): The matched distance from column to row
with shape [1, N], it must be initialized to 0.
"""
match_pair = []
......@@ -36,13 +36,13 @@ def bipartite_match(distance, match_indices, match_dis):
row_indices = -1 * np.ones((row, ), dtype=np.int)
idx = 0
for i, j, dis in match_sorted:
for i, j, dist in match_sorted:
if idx >= row:
break
if match_indices[j] == -1 and row_indices[i] == -1 and dis > 0:
if match_indices[j] == -1 and row_indices[i] == -1 and dist > 0:
match_indices[j] = i
row_indices[i] = j
match_dis[j] = dis
match_dist[j] = dist
idx += 1
......@@ -55,24 +55,24 @@ def batch_bipartite_match(distance, lod):
n = len(lod) - 1
m = distance.shape[1]
match_indices = -1 * np.ones((n, m), dtype=np.int)
match_dis = np.zeros((n, m), dtype=np.float32)
match_dist = np.zeros((n, m), dtype=np.float32)
for i in range(len(lod) - 1):
bipartite_match(distance[lod[i]:lod[i + 1], :], match_indices[i, :],
match_dis[i, :])
return match_indices, match_dis
match_dist[i, :])
return match_indices, match_dist
class TestBipartiteMatchOpForWithLoD(OpTest):
def setUp(self):
self.op_type = 'bipartite_match'
lod = [[0, 5, 11, 23]]
dis = np.random.random((23, 217)).astype('float32')
match_indices, match_dis = batch_bipartite_match(dis, lod[0])
dist = np.random.random((23, 217)).astype('float32')
match_indices, match_dist = batch_bipartite_match(dist, lod[0])
self.inputs = {'DistMat': (dis, lod)}
self.inputs = {'DistMat': (dist, lod)}
self.outputs = {
'ColToRowMatchIndices': (match_indices),
'ColToRowMatchDis': (match_dis),
'ColToRowMatchDis': (match_dist),
}
def test_check_output(self):
......@@ -83,13 +83,13 @@ class TestBipartiteMatchOpWithoutLoD(OpTest):
def setUp(self):
self.op_type = 'bipartite_match'
lod = [[0, 8]]
dis = np.random.random((8, 17)).astype('float32')
match_indices, match_dis = batch_bipartite_match(dis, lod[0])
dist = np.random.random((8, 17)).astype('float32')
match_indices, match_dist = batch_bipartite_match(dist, lod[0])
self.inputs = {'DistMat': dis}
self.inputs = {'DistMat': dist}
self.outputs = {
'ColToRowMatchIndices': (match_indices),
'ColToRowMatchDis': (match_dis),
'ColToRowMatchIndices': match_indices,
'ColToRowMatchDis': match_dist,
}
def test_check_output(self):
......
......@@ -17,8 +17,9 @@ import unittest
import paddle.v2.fluid.layers as layers
import paddle.v2.fluid.nets as nets
from paddle.v2.fluid.framework import Program, program_guard
from paddle.v2.fluid.framework import Program, program_guard, default_main_program
from paddle.v2.fluid.param_attr import ParamAttr
import decorators
class TestBook(unittest.TestCase):
......@@ -225,6 +226,41 @@ class TestBook(unittest.TestCase):
self.assertIsNotNone(out)
print(str(program))
@decorators.prog_scope()
def test_nce(self):
window_size = 5
words = []
for i in xrange(window_size):
words.append(
layers.data(
name='word_{0}'.format(i), shape=[1], dtype='int64'))
dict_size = 10000
label_word = int(window_size / 2) + 1
embs = []
for i in xrange(window_size):
if i == label_word:
continue
emb = layers.embedding(
input=words[i],
size=[dict_size, 32],
param_attr='emb.w',
is_sparse=True)
embs.append(emb)
embs = layers.concat(input=embs, axis=1)
loss = layers.nce(input=embs,
label=words[label_word],
num_total_classes=dict_size,
param_attr='nce.w',
bias_attr='nce.b')
avg_loss = layers.mean(x=loss)
self.assertIsNotNone(avg_loss)
print(str(default_main_program()))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册