diff --git a/doc/getstarted/build_and_install/docker_install_cn.rst b/doc/getstarted/build_and_install/docker_install_cn.rst
index bae42593ddc6f7a7eb47d603752ad6efa9820b45..98fada7bdb46f4dd2927d6f93bcbcebbe7d18604 100644
--- a/doc/getstarted/build_and_install/docker_install_cn.rst
+++ b/doc/getstarted/build_and_install/docker_install_cn.rst
@@ -25,14 +25,14 @@
 
   .. code-block:: bash
 
-     docker pull docker.paddlepaddle.org/paddle
+     docker pull docker.paddlepaddlehub.com/paddle
 
 下载GPU版本(cuda8.0_cudnn5_avx_mkl)的Docker镜像:
 
   .. code-block:: bash
 
      docker pull paddlepaddle/paddle:latest-gpu
-     docker pull docker.paddlepaddle.org/paddle:latest-gpu
+     docker pull docker.paddlepaddlehub.com/paddle:latest-gpu
 
 选择下载使用不同的BLAS库的Docker镜像:
 
@@ -49,7 +49,7 @@
 
      docker pull paddlepaddle/paddle:[tag]
      # 比如:
-     docker pull docker.paddlepaddle.org/paddle:0.10.0-gpu
+     docker pull docker.paddlepaddlehub.com/paddle:0.11.0-gpu
 
 .. _docker_run:
 
diff --git a/doc/getstarted/build_and_install/docker_install_en.rst b/doc/getstarted/build_and_install/docker_install_en.rst
index 56a7c68e4d39c45249fa55a964dc48b7081596a6..b1d0890b4cdddb77114a80276130afd07c22d270 100644
--- a/doc/getstarted/build_and_install/docker_install_en.rst
+++ b/doc/getstarted/build_and_install/docker_install_en.rst
@@ -26,14 +26,14 @@ For users in China, we provide a faster mirror:
 
   .. code-block:: bash
 
-     docker pull docker.paddlepaddle.org/paddle
+     docker pull docker.paddlepaddlehub.com/paddle
 
 Download GPU version (cuda8.0_cudnn5_avx_mkl) images:
 
   .. code-block:: bash
 
      docker pull paddlepaddle/paddle:latest-gpu
-     docker pull docker.paddlepaddle.org/paddle:latest-gpu
+     docker pull docker.paddlepaddlehub.com/paddle:latest-gpu
 
 Choose between different BLAS version:
 
@@ -53,7 +53,7 @@ and run:
 
      docker pull paddlepaddle/paddle:[tag]
      # i.e.
-     docker pull docker.paddlepaddle.org/paddle:0.10.0-gpu
+     docker pull docker.paddlepaddlehub.com/paddle:0.11.0-gpu
 
 .. _docker_run:
 
diff --git a/paddle/framework/variable_test.cc b/paddle/framework/variable_test.cc
index e4732d9718e2b46a068963d44c4c1e04024f2330..e5585c8724d712e273d086001b6cbc3d59c46ebe 100644
--- a/paddle/framework/variable_test.cc
+++ b/paddle/framework/variable_test.cc
@@ -12,19 +12,6 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-/*
-  Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
-  Licensed under the Apache License, Version 2.0 (the "License");
-  you may not use this file except in compliance with the License.
-  You may obtain a copy of the License at
-  http://www.apache.org/licenses/LICENSE-2.0
-  Unless required by applicable law or agreed to in writing, software
-  distributed under the License is distributed on an "AS IS" BASIS,
-  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-  See the License for the specific language governing permissions and
-  limitations under the License.
-*/
-
 #include <memory>
 #include <string>
 
diff --git a/paddle/operators/bipartite_match_op.cc b/paddle/operators/bipartite_match_op.cc
index b0f7376d272a66e0b01d6b3f7e546372397772f7..83c8778fe4cec4d9d80de691e117a39fdd92f494 100644
--- a/paddle/operators/bipartite_match_op.cc
+++ b/paddle/operators/bipartite_match_op.cc
@@ -21,8 +21,6 @@ namespace operators {
 using Tensor = framework::Tensor;
 using LoDTensor = framework::LoDTensor;
 
-constexpr char kEPS = 1e-6;
-
 class BipartiteMatchOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
@@ -46,6 +44,7 @@ class BipartiteMatchKernel : public framework::OpKernel<T> {
   // The match_dist must be initialized to 0 at first.
   void BipartiteMatch(const Tensor& dist, int* match_indices,
                       T* match_dist) const {
+    constexpr T kEPS = static_cast<T>(1e-6);
     PADDLE_ENFORCE_EQ(dist.dims().size(), 2, "The rank of dist must be 2.");
     int64_t row = dist.dims()[0];
     int64_t col = dist.dims()[1];
diff --git a/paddle/operators/conv_transpose_op.cc b/paddle/operators/conv_transpose_op.cc
index a2382a7e42eb9c5c6a8f13265b0e6173e6b05f76..089290a506db10f676c8d7eb92663d2cb56892af 100644
--- a/paddle/operators/conv_transpose_op.cc
+++ b/paddle/operators/conv_transpose_op.cc
@@ -160,8 +160,8 @@ Example:
        Output shape: $(N, C_{out}, H_{out}, W_{out})$
   Where
   $$
-       H_{out} = (H_{in} - 1) * strides[0] - 2 * paddings[0] + H_f \\
-       W_{out} = (W_{in} - 1) * strides[1] - 2 * paddings[1] + W_f
+       H_{out} = (H_{in} - 1) * strides[0] - 2 * paddings[0] + dilations[0] * (H_f - 1) + 1 \\
+       W_{out} = (W_{in} - 1) * strides[1] - 2 * paddings[1] + dilations[1] * (W_f - 1) + 1
   $$
 )DOC");
 }
@@ -249,9 +249,9 @@ Example:
        Output shape: $(N, C_{out}, D_{out}, H_{out}, W_{out})$
   Where
   $$
-       D_{out} = (D_{in} - 1) * strides[0] - 2 * paddings[0] + D_f \\
-       H_{out} = (H_{in} - 1) * strides[1] - 2 * paddings[1] + H_f \\
-       W_{out} = (W_{in} - 1) * strides[2] - 2 * paddings[2] + W_f
+       D_{out} = (D_{in} - 1) * strides[0] - 2 * paddings[0] + dilations[0] * (D_f - 1) + 1 \\
+       H_{out} = (H_{in} - 1) * strides[1] - 2 * paddings[1] + dilations[1] * (H_f - 1) + 1 \\
+       W_{out} = (W_{in} - 1) * strides[2] - 2 * paddings[2] + dilations[2] * (W_f - 1) + 1
   $$
 )DOC");
 }
diff --git a/paddle/operators/conv_transpose_op.h b/paddle/operators/conv_transpose_op.h
index a42ade41b165d1bfa00d2db0e45d40cf5d7b00bc..8c0d57afcd21d8622fb6316f7b988d79a45b57fe 100644
--- a/paddle/operators/conv_transpose_op.h
+++ b/paddle/operators/conv_transpose_op.h
@@ -141,9 +141,9 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
       if (data_dim == 2U) {
         // col2im: col_matrix -> dy
         // from (c * k_h * k_w, h * w) to (c, o_h, o_w)
-        col2im(dev_ctx, col, std::vector<int>{dilations[0], dilations[1]},
-               strides, std::vector<int>{paddings[0], paddings[1], paddings[0],
-                                         paddings[1]},
+        col2im(dev_ctx, col, dilations, strides,
+               std::vector<int>{paddings[0], paddings[1], paddings[0],
+                                paddings[1]},
                &output_batch);
       } else if (data_dim == 3U) {
         // col2vol: col_matrix -> dy
@@ -247,8 +247,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
         if (data_dim == 2U) {
           // im2col: dy -> col matrix
           // from (c, o_h, o_w) to (c * k_h * k_w, h * w)
-          im2col(dev_ctx, output_grad_batch,
-                 std::vector<int>{dilations[0], dilations[1]}, strides,
+          im2col(dev_ctx, output_grad_batch, dilations, strides,
                  std::vector<int>{paddings[0], paddings[1], paddings[0],
                                   paddings[1]},
                  &col);
diff --git a/paddle/operators/nce_op.cc b/paddle/operators/nce_op.cc
index 84ba3ead2b52547b989a4541f31ea31ffcce6c63..994ddf717e7a5b883d8071c6a47da0b4b4074f2e 100644
--- a/paddle/operators/nce_op.cc
+++ b/paddle/operators/nce_op.cc
@@ -124,7 +124,8 @@ class NCEOpMaker : public framework::OpProtoAndCheckerMaker {
                               "This attribute only be used in unitest. Classes "
                               "in this list wiil be used as negative classes "
                               "for every samples. Under normal conditions, "
-                              "user should avoid setting this attribute.");
+                              "user should avoid setting this attribute.")
+        .SetDefault({});
     AddComment(R"DOC(
 Compute and return the noise-contrastive estimation training loss.
 See [Noise-contrastive estimation: A new estimation principle for unnormalized statistical models](http://www.jmlr.org/proceedings/papers/v9/gutmann10a/gutmann10a.pdf).
diff --git a/paddle/operators/nce_op.h b/paddle/operators/nce_op.h
index e6b496f7896dcb412be8ff096fdccb2f0b682369..86fa13a649ce7fdcaad64e2609ceea2fb4d7e072 100644
--- a/paddle/operators/nce_op.h
+++ b/paddle/operators/nce_op.h
@@ -197,7 +197,8 @@ class NCEGradKernel : public framework::OpKernel<T> {
     // get d_x
     auto d_x = context.Output<Tensor>(framework::GradVarName("Input"));
     if (d_x != nullptr) {
-      d_x->mutable_data<T>(context.GetPlace());
+      auto* d_x_data = d_x->mutable_data<T>(context.GetPlace());
+      std::fill(d_x_data, d_x_data + d_x->numel(), 0.0);
       auto d_x_matrix = EigenMatrix<T>::From(*d_x);
       auto w_matrix = EigenMatrix<T>::From(*(context.Input<Tensor>("Weight")));
       for (int64_t i = 0; i < sample_labels->numel(); ++i) {
diff --git a/python/paddle/v2/dataset/wmt16.py b/python/paddle/v2/dataset/wmt16.py
index e2f463be2f7bcd667855f64206d78f387e92ef33..c8818f715beadd9499ae588f2c19a57fbf26f372 100644
--- a/python/paddle/v2/dataset/wmt16.py
+++ b/python/paddle/v2/dataset/wmt16.py
@@ -305,9 +305,9 @@ def get_dict(lang, dict_size, reverse=False):
 
     dict_path = os.path.join(paddle.v2.dataset.common.DATA_HOME,
                              "wmt16/%s_%d.dict" % (lang, dict_size))
-    assert (os.path.exists(dict_path), "Word dictionary does not exist. "
-            "Please invoke paddle.dataset.wmt16.train/test/validation "
-            "first to build the dictionary.")
+    assert os.path.exists(dict_path), "Word dictionary does not exist. "
+    "Please invoke paddle.dataset.wmt16.train/test/validation first "
+    "to build the dictionary."
     tar_file = os.path.join(paddle.v2.dataset.common.DATA_HOME, "wmt16.tar.gz")
     return __load_dict(tar_file, dict_size, lang, reverse)
 
diff --git a/python/paddle/v2/fluid/layers/nn.py b/python/paddle/v2/fluid/layers/nn.py
index bd0404c94d326bd1fff56fec132301c5eae0f10f..99ef6932c9f42140b71cfb63c369e8ea7eac529d 100644
--- a/python/paddle/v2/fluid/layers/nn.py
+++ b/python/paddle/v2/fluid/layers/nn.py
@@ -19,6 +19,7 @@ from ..layer_helper import LayerHelper
 from ..initializer import Normal, Constant
 from ..framework import Variable
 from ..param_attr import ParamAttr
+from layer_function_generator import autodoc
 from tensor import concat
 
 __all__ = [
@@ -58,6 +59,7 @@ __all__ = [
     'sequence_reshape',
     'transpose',
     'im2sequence',
+    'nce',
 ]
 
 
@@ -791,8 +793,8 @@ def conv2d(input,
     <http://ufldl.stanford.edu/tutorial/supervised/FeatureExtractionUsingConvolution/>`_ .
     If bias attribution and activation type are provided, bias is added to the output of the convolution,
     and the corresponding activation function is applied to the final result.
-    For each input :math:`X`, the equation is:
 
+    For each input :math:`X`, the equation is:
 
     .. math::
 
@@ -800,51 +802,54 @@ def conv2d(input,
 
     In the above equation:
 
-        * :math:`X`: Input value, a tensor with NCHW format.
-        * :math:`W`: Filter value, a tensor with MCHW format.
-        * :math:`\\ast`: Convolution operation.
-        * :math:`b`: Bias value, a 2-D tensor with shape [M, 1].
-        * :math:`\\sigma`: Activation function.
-        * :math:`Out`: Output value, the shape of :math:`Out` and :math:`X` may be different.
+    * :math:`X`: Input value, a tensor with NCHW format.
+    * :math:`W`: Filter value, a tensor with MCHW format.
+    * :math:`\\ast`: Convolution operation.
+    * :math:`b`: Bias value, a 2-D tensor with shape [M, 1].
+    * :math:`\\sigma`: Activation function.
+    * :math:`Out`: Output value, the shape of :math:`Out` and :math:`X` may be different.
 
     Example:
 
-        Input:
-            Input shape: $(N, C_{in}, H_{in}, W_{in})$
+        - Input:
+
+          Input shape: $(N, C_{in}, H_{in}, W_{in})$
+
+          Filter shape: $(C_{out}, C_{in}, H_f, W_f)$
 
-            Filter shape: $(C_{out}, C_{in}, H_f, W_f)$
+        - Output:
+          Output shape: $(N, C_{out}, H_{out}, W_{out})$
 
-        Output:
-            Output shape: $(N, C_{out}, H_{out}, W_{out})$
         Where
-    .. math::
+
+        .. math::
 
         H_{out}&= \\frac{(H_{in} + 2 * paddings[0] - (dilations[0] * (H_f - 1) + 1))}{strides[0]} + 1 \\\\
         W_{out}&= \\frac{(W_{in} + 2 * paddings[1] - (dilations[1] * (W_f - 1) + 1))}{strides[1]} + 1
 
     Args:
-        input(Variable): The input image with [N, C, H, W] format.
-        num_filters(int): The number of filter. It is as same as the output
-            image channel.
-        filter_size(int|tuple|None): The filter size. If filter_size is a tuple,
-            it must contain two integers, (filter_size_H, filter_size_W).
-            Otherwise, the filter will be a square.
-        stride(int|tuple): The stride size. If stride is a tuple, it must
-            contain two integers, (stride_H, stride_W). Otherwise, the
-            stride_H = stride_W = stride. Default: stride = 1.
-        padding(int|tuple): The padding size. If padding is a tuple, it must
-            contain two integers, (padding_H, padding_W). Otherwise, the
-            padding_H = padding_W = padding. Default: padding = 0.
-        groups(int): The groups number of the Conv2d Layer. According to grouped
-            convolution in Alex Krizhevsky's Deep CNN paper: when group=2,
-            the first half of the filters is only connected to the first half
-            of the input channels, while the second half of the filters is only
-            connected to the second half of the input channels. Default: groups=1
-        param_attr(ParamAttr): The parameters to the Conv2d Layer. Default: None
-        bias_attr(ParamAttr): Bias parameter for the Conv2d layer. Default: None
-        use_cudnn(bool): Use cudnn kernel or not, it is valid only when the cudnn
-            library is installed. Default: True
-        act(str): Activation type. Default: None
+       input(Variable): The input image with [N, C, H, W] format.
+       num_filters(int): The number of filter. It is as same as the output
+           image channel.
+       filter_size(int|tuple|None): The filter size. If filter_size is a tuple,
+           it must contain two integers, (filter_size_H, filter_size_W).
+           Otherwise, the filter will be a square.
+       stride(int|tuple): The stride size. If stride is a tuple, it must
+           contain two integers, (stride_H, stride_W). Otherwise, the
+           stride_H = stride_W = stride. Default: stride = 1.
+       padding(int|tuple): The padding size. If padding is a tuple, it must
+           contain two integers, (padding_H, padding_W). Otherwise, the
+           padding_H = padding_W = padding. Default: padding = 0.
+       groups(int): The groups number of the Conv2d Layer. According to grouped
+           convolution in Alex Krizhevsky's Deep CNN paper: when group=2,
+           the first half of the filters is only connected to the first half
+           of the input channels, while the second half of the filters is only
+           connected to the second half of the input channels. Default: groups=1
+       param_attr(ParamAttr): The parameters to the Conv2d Layer. Default: None
+       bias_attr(ParamAttr): Bias parameter for the Conv2d layer. Default: None
+       use_cudnn(bool): Use cudnn kernel or not, it is valid only when the cudnn
+           library is installed. Default: True
+       act(str): Activation type. Default: None
 
     Returns:
         Variable: The tensor variable storing the convolution and \
@@ -859,7 +864,6 @@ def conv2d(input,
           data = fluid.layers.data(name='data', shape=[3, 32, 32], dtype='float32')
           conv2d = fluid.layers.conv2d(input=data, num_filters=2, filter_size=3, act="relu")
     """
-
     if stride is None:
         stride = [1, 1]
     helper = LayerHelper('conv2d', **locals())
@@ -1213,38 +1217,85 @@ def conv2d_transpose(input,
                      use_cudnn=True,
                      name=None):
     """
-    The transpose of conv2d layer.
+    **Convlution2D transpose layer**
+
+    The convolution2D transpose layer calculates the output based on the input,
+    filter, and dilations, strides, paddings. Input(Input) and output(Output)
+    are in NCHW format. Where N is batch size, C is the number of channels,
+    H is the height of the feature, and W is the width of the feature.
+    Parameters(dilations, strides, paddings) are two elements. These two elements
+    represent height and width, respectively. The details of convolution transpose
+    layer, please refer to the following explanation and references `therein <http://www.matthewzeiler.com/wp-content/uploads/2017/07/cvpr2010.pdf>`_.
+
+    For each input :math:`X`, the equation is:
+
+    .. math::
+
+        Out = W \\ast X
+
+    In the above equation:
+
+    * :math:`X`: Input value, a tensor with NCHW format.
+    * :math:`W`: Filter value, a tensor with MCHW format.
+    * :math:`\\ast` : Convolution transpose operation.
+    * :math:`Out`: Output value, the shape of :math:`Out` and :math:`X` may be different.
+
+    Example:
+
+        - Input:
+
+          Input shape: $(N, C_{in}, H_{in}, W_{in})$
+
+          Filter shape: $(C_{in}, C_{out}, H_f, W_f)$
+
+        - Output:
+
+          Output shape: $(N, C_{out}, H_{out}, W_{out})$
 
-    This layer is also known as deconvolution layer.
+        Where
+
+        .. math::
+
+           H_{out} &= (H_{in} - 1) * strides[0] - 2 * paddings[0] + dilations[0] * (H_f - 1) + 1 \\\\
+           W_{out} &= (W_{in} - 1) * strides[1] - 2 * paddings[1] + dilations[1] * (W_f - 1) + 1
 
     Args:
-        input(Variable): The input image with [N, C, H, W] format.
-        num_filters(int): The number of filter. It is as same as the output
-            image channel.
-        output_size(int|tuple|None): The output image size. If output size is a
-            tuple, it must contain two integers, (image_H, image_W). This
-            parameter only works when filter_size is None.
-        filter_size(int|tuple|None): The filter size. If filter_size is a tuple,
-            it must contain two integers, (filter_size_H, filter_size_W).
-            Otherwise, the filter will be a square.  None if use output size to
-            calculate filter_size
-        padding(int|tuple): The padding size. If padding is a tuple, it must
-            contain two integers, (padding_H, padding_W). Otherwise, the
-            padding_H = padding_W = padding.
-        stride(int|tuple): The stride size. If stride is a tuple, it must
-            contain two integers, (stride_H, stride_W). Otherwise, the
-            stride_H = stride_W = stride.
-        dilation(int|tuple): The dilation size. If dilation is a tuple, it must
-            contain two integers, (dilation_H, dilation_W). Otherwise, the
-            dilation_H = dilation_W = dilation.
-        param_attr: Parameter Attribute.
-        use_cudnn(bool): Use cudnn kernel or not, it is valid only when the cudnn
-            library is installed. Default: True
-        name(str|None): A name for this layer(optional). If set None, the layer
-                       will be named automatically.
+       input(Variable): The input image with [N, C, H, W] format.
+       num_filters(int): The number of the filter. It is as same as the output
+           image channel.
+       output_size(int|tuple|None): The output image size. If output size is a
+           tuple, it must contain two integers, (image_H, image_W). This
+           parameter only works when filter_size is None.
+       filter_size(int|tuple|None): The filter size. If filter_size is a tuple,
+           it must contain two integers, (filter_size_H, filter_size_W).
+           Otherwise, the filter will be a square. None if use output size to
+           calculate filter_size.
+       padding(int|tuple): The padding size. If padding is a tuple, it must
+           contain two integers, (padding_H, padding_W). Otherwise, the
+           padding_H = padding_W = padding. Default: padding = 0.
+       stride(int|tuple): The stride size. If stride is a tuple, it must
+           contain two integers, (stride_H, stride_W). Otherwise, the
+           stride_H = stride_W = stride. Default: stride = 1.
+       dilation(int|tuple): The dilation size. If dilation is a tuple, it must
+           contain two integers, (dilation_H, dilation_W). Otherwise, the
+           dilation_H = dilation_W = dilation. Default: dilation = 1.
+       param_attr(ParamAttr): The parameters to the Conv2d_transpose Layer. Default: None
+       use_cudnn(bool): Use cudnn kernel or not, it is valid only when the cudnn
+           library is installed. Default: True
+       name(str|None): A name for this layer(optional). If set None, the layer
+           will be named automatically.
 
     Returns:
-        Variable: Output image.
+       Variable: The tensor variable storing the convolution transpose result.
+
+    Raises:
+       ValueError: If the shapes of input, filter_size, stride, padding and groups mismatch.
+
+    Examples:
+       .. code-block:: python
+
+          data = fluid.layers.data(name='data', shape=[3, 32, 32], dtype='float32')
+          conv2d_transpose = fluid.layers.conv2d_transpose(input=data, num_filters=2, filter_size=3)
     """
     helper = LayerHelper("conv2d_transpose", **locals())
     if not isinstance(input, Variable):
@@ -2142,6 +2193,61 @@ def sequence_reshape(input, new_dim):
     return out
 
 
+@autodoc()
+def nce(input,
+        label,
+        num_total_classes,
+        sample_weight=None,
+        param_attr=None,
+        bias_attr=None,
+        num_neg_samples=None):
+    helper = LayerHelper('nce', **locals())
+    assert isinstance(input, Variable)
+    dim = input.shape[1]
+    assert isinstance(label, Variable)
+    num_true_class = label.shape[1]
+    w = helper.create_parameter(
+        attr=helper.param_attr,
+        shape=[num_total_classes, dim],
+        is_bias=False,
+        dtype=input.dtype)
+    b = helper.create_parameter(
+        attr=helper.bias_attr,
+        shape=[num_total_classes, 1],
+        is_bias=True,
+        dtype=input.dtype)
+    cost = helper.create_tmp_variable(dtype=input.dtype)
+    sample_logits = helper.create_tmp_variable(dtype=input.dtype)
+    sample_labels = helper.create_tmp_variable(dtype=label.dtype)
+
+    if num_neg_samples is None:
+        num_neg_samples = 10
+    else:
+        num_neg_samples = int(num_neg_samples)
+
+    attrs = {
+        'num_total_classes': int(num_total_classes),
+        'num_neg_samples': num_neg_samples
+    }
+
+    helper.append_op(
+        type='nce',
+        inputs={
+            'Input': input,
+            'Label': label,
+            'Weight': w,
+            'Bias': b,
+            'SampleWeight': sample_weight if sample_weight is not None else []
+        },
+        outputs={
+            'Cost': cost,
+            'SampleLogits': sample_logits,
+            'SampleLabels': sample_labels
+        },
+        attrs=attrs)
+    return cost / (num_neg_samples + 1)
+
+
 def transpose(x, perm, name=None):
     """
     **transpose Layer**
diff --git a/python/paddle/v2/fluid/tests/test_bipartite_match_op.py b/python/paddle/v2/fluid/tests/test_bipartite_match_op.py
index 34101b1da46d46d0e7a995ba80d8644dc586065d..74138298978c7c18936f53761b313887f07aea81 100644
--- a/python/paddle/v2/fluid/tests/test_bipartite_match_op.py
+++ b/python/paddle/v2/fluid/tests/test_bipartite_match_op.py
@@ -16,13 +16,13 @@ import numpy as np
 from op_test import OpTest
 
 
-def bipartite_match(distance, match_indices, match_dis):
+def bipartite_match(distance, match_indices, match_dist):
     """Bipartite Matching algorithm.
     Arg:
         distance (numpy.array) : The distance of two entries with shape [M, N].
         match_indices (numpy.array): the matched indices from column to row
             with shape [1, N], it must be initialized to -1.
-        match_dis (numpy.array): The matched distance from column to row
+        match_dist (numpy.array): The matched distance from column to row
             with shape [1, N], it must be initialized to 0.
     """
     match_pair = []
@@ -36,13 +36,13 @@ def bipartite_match(distance, match_indices, match_dis):
     row_indices = -1 * np.ones((row, ), dtype=np.int)
 
     idx = 0
-    for i, j, dis in match_sorted:
+    for i, j, dist in match_sorted:
         if idx >= row:
             break
-        if match_indices[j] == -1 and row_indices[i] == -1 and dis > 0:
+        if match_indices[j] == -1 and row_indices[i] == -1 and dist > 0:
             match_indices[j] = i
             row_indices[i] = j
-            match_dis[j] = dis
+            match_dist[j] = dist
             idx += 1
 
 
@@ -55,24 +55,24 @@ def batch_bipartite_match(distance, lod):
     n = len(lod) - 1
     m = distance.shape[1]
     match_indices = -1 * np.ones((n, m), dtype=np.int)
-    match_dis = np.zeros((n, m), dtype=np.float32)
+    match_dist = np.zeros((n, m), dtype=np.float32)
     for i in range(len(lod) - 1):
         bipartite_match(distance[lod[i]:lod[i + 1], :], match_indices[i, :],
-                        match_dis[i, :])
-    return match_indices, match_dis
+                        match_dist[i, :])
+    return match_indices, match_dist
 
 
 class TestBipartiteMatchOpForWithLoD(OpTest):
     def setUp(self):
         self.op_type = 'bipartite_match'
         lod = [[0, 5, 11, 23]]
-        dis = np.random.random((23, 217)).astype('float32')
-        match_indices, match_dis = batch_bipartite_match(dis, lod[0])
+        dist = np.random.random((23, 217)).astype('float32')
+        match_indices, match_dist = batch_bipartite_match(dist, lod[0])
 
-        self.inputs = {'DistMat': (dis, lod)}
+        self.inputs = {'DistMat': (dist, lod)}
         self.outputs = {
             'ColToRowMatchIndices': (match_indices),
-            'ColToRowMatchDis': (match_dis),
+            'ColToRowMatchDis': (match_dist),
         }
 
     def test_check_output(self):
@@ -83,13 +83,13 @@ class TestBipartiteMatchOpWithoutLoD(OpTest):
     def setUp(self):
         self.op_type = 'bipartite_match'
         lod = [[0, 8]]
-        dis = np.random.random((8, 17)).astype('float32')
-        match_indices, match_dis = batch_bipartite_match(dis, lod[0])
+        dist = np.random.random((8, 17)).astype('float32')
+        match_indices, match_dist = batch_bipartite_match(dist, lod[0])
 
-        self.inputs = {'DistMat': dis}
+        self.inputs = {'DistMat': dist}
         self.outputs = {
-            'ColToRowMatchIndices': (match_indices),
-            'ColToRowMatchDis': (match_dis),
+            'ColToRowMatchIndices': match_indices,
+            'ColToRowMatchDis': match_dist,
         }
 
     def test_check_output(self):
diff --git a/python/paddle/v2/fluid/tests/test_layers.py b/python/paddle/v2/fluid/tests/test_layers.py
index 58544b2982519f9badbfad97cbd2cd6bf13136e6..8104599e42cc57a48db8be6d8fdb476b39ed39f8 100644
--- a/python/paddle/v2/fluid/tests/test_layers.py
+++ b/python/paddle/v2/fluid/tests/test_layers.py
@@ -17,8 +17,9 @@ import unittest
 
 import paddle.v2.fluid.layers as layers
 import paddle.v2.fluid.nets as nets
-from paddle.v2.fluid.framework import Program, program_guard
+from paddle.v2.fluid.framework import Program, program_guard, default_main_program
 from paddle.v2.fluid.param_attr import ParamAttr
+import decorators
 
 
 class TestBook(unittest.TestCase):
@@ -235,6 +236,41 @@ class TestBook(unittest.TestCase):
             self.assertIsNotNone(output)
         print(str(program))
 
+    @decorators.prog_scope()
+    def test_nce(self):
+        window_size = 5
+        words = []
+        for i in xrange(window_size):
+            words.append(
+                layers.data(
+                    name='word_{0}'.format(i), shape=[1], dtype='int64'))
+
+        dict_size = 10000
+        label_word = int(window_size / 2) + 1
+
+        embs = []
+        for i in xrange(window_size):
+            if i == label_word:
+                continue
+
+            emb = layers.embedding(
+                input=words[i],
+                size=[dict_size, 32],
+                param_attr='emb.w',
+                is_sparse=True)
+
+            embs.append(emb)
+
+        embs = layers.concat(input=embs, axis=1)
+        loss = layers.nce(input=embs,
+                          label=words[label_word],
+                          num_total_classes=dict_size,
+                          param_attr='nce.w',
+                          bias_attr='nce.b')
+        avg_loss = layers.mean(x=loss)
+        self.assertIsNotNone(avg_loss)
+        print(str(default_main_program()))
+
 
 if __name__ == '__main__':
     unittest.main()