diff --git a/doc/design/ops/dist_train.md b/doc/design/ops/dist_train.md
new file mode 100644
index 0000000000000000000000000000000000000000..fa3c5d7990213cf2b0d236e66e592dd2699da876
--- /dev/null
+++ b/doc/design/ops/dist_train.md
@@ -0,0 +1,106 @@
+# Design Doc: Operation Graph Based Parameter Server
+
+## Abstract
+
+We propose an approach to implement the parameter server. In this
+approach, there is no fundamental difference between the trainer and
+the parameter server: they both run subgraphs, but subgraphs of
+different purposes.
+
+## Background
+
+The previous implementations of the parameter server do not run a
+subgraph. Parameter initialization, optimizer computation, network
+communication and checkpointing are implemented twice on both the
+trainer and the parameter server.
+
+It would be great if we could write code once and use it on both the
+trainer and the parameter server: this reduces code duplication and
+improves extensibility. Given that after the current refactor we are
+representing everything as a computing graph on the trainer,
+representing everything as a computing graph on the parameter server
+becomes a natural extension.
+
+## Design
+
+### Graph Converter
+
+The *graph converter* converts the user-defined operation (OP) graph
+into subgraphs to be scheduled on different nodes with the following
+steps:
+
+1. OP placement: the OPs will be placed on different nodes according
+   to a heuristic that minimizes the estimated total computation
+   time. Currently we will use a simple heuristic that puts parameter
+   variables on the parameter server workers and everything else on
+   the trainer workers.
+
+1. Add communication OPs to enable the communication between nodes.
+
+We will need these OPs: *Send*, *Recv*, *Enqueue*, *Dequeue*.
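+
+To make the two steps concrete, below is a minimal, illustrative Python
+sketch of such a converter. The `Op` record, the `is_param_update` flag,
+and the placement function are placeholders invented for this sketch, not
+the actual framework API; a real converter would operate on the OP graph
+representation produced by the current refactor.
+
+```python
+# Illustrative sketch only: `Op` and the helpers below are placeholders,
+# not the real PaddlePaddle framework API.
+from collections import namedtuple
+
+Op = namedtuple('Op', ['name', 'inputs', 'outputs', 'is_param_update'])
+
+
+def place(op):
+    # Simple heuristic: parameter variables and their optimizer OPs go to
+    # the parameter server; everything else stays on the trainer.
+    return 'pserver' if op.is_param_update else 'trainer'
+
+
+def convert(ops):
+    subgraphs = {'trainer': [], 'pserver': []}
+    for op in ops:
+        subgraphs[place(op)].append(op)
+    # Step 2: for every variable produced on one node and consumed on the
+    # other, insert a Send on the producer side and a Recv on the consumer.
+    for producer in ops:
+        for consumer in ops:
+            crossing = set(producer.outputs) & set(consumer.inputs)
+            if crossing and place(producer) != place(consumer):
+                for var in crossing:
+                    subgraphs[place(producer)].append(
+                        Op('send_' + var, [var], [], False))
+                    subgraphs[place(consumer)].append(
+                        Op('recv_' + var, [], [var], False))
+    return subgraphs
+```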
+
+Below is an example of converting the user-defined graph to the
+subgraphs for the trainer and the parameter server:
+
+<img src="src/local-graph.png"/>
+
+After converting:
+
+<img src="src/dist-graph.png"/>
+
+1. The parameter variable W and its optimizer subgraph are placed on the parameter server.
+1. Operators are added to the subgraphs.
+   - *Send* sends data to the connected *Recv* operator. The
+     scheduler on the receiving node will only schedule the *Recv*
+     operator to run when the *Send* operator has run (the *Send* OP
+     will mark the *Recv* OP runnable automatically).
+   - *Enqueue* enqueues the input variable; it can block until space
+     becomes available in the queue.
+   - *Dequeue* outputs a configurable number of tensors from the
+     queue. It will block until the queue has the required number of
+     tensors. A minimal sketch of these queue semantics follows this list.
+
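+The sketch below illustrates the blocking behavior that *Enqueue* and
+*Dequeue* assume (a bounded queue; `dequeue` waits for `min_count` items).
+It is plain Python written only to show the semantics; the class name and
+its methods are invented for this sketch and are not the real OP
+implementation.
+
+```python
+# Illustrative sketch of the blocking semantics assumed by Enqueue/Dequeue.
+import threading
+from collections import deque
+
+
+class VarQueue(object):
+    def __init__(self, capacity):
+        self.capacity = capacity
+        self.items = deque()
+        self.cond = threading.Condition()
+
+    def enqueue(self, var):
+        # Blocks until space becomes available in the queue.
+        with self.cond:
+            while len(self.items) >= self.capacity:
+                self.cond.wait()
+            self.items.append(var)
+            self.cond.notify_all()
+
+    def dequeue(self, min_count):
+        # Blocks until the queue holds at least `min_count` variables, then
+        # returns them all at once (e.g. gradients to be summed by an Add OP).
+        with self.cond:
+            while len(self.items) < min_count:
+                self.cond.wait()
+            out = [self.items.popleft() for _ in range(min_count)]
+            self.cond.notify_all()
+            return out
+```
+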
+
+### Benefits
+
+- Model parallelism becomes easier to implement: it's an extension to
+  the trainer - parameter server approach. We already have the
+  communication OPs, but need to extend the graph converter's
+  placement functionality.
+
+- User-defined optimizers are easier to add - users can now express
+  them as subgraphs.
+
+- No more duplicated logic inside the trainer and the parameter
+  server, as mentioned in the background section.
+
+### Challenges
+
+- It might be hard for the graph converter to cut a general graph
+  (without any hint about which subgraph is the optimizer). We may need
+  to label which subgraph inside the OP graph is the optimizer.
+
+- It's important to balance the parameter shards across multiple
+  parameter servers. If a single parameter is very big (e.g., the
+  parameters of some word-embedding, fully connected, or softmax
+  layers), we need to automatically partition the single parameter
+  onto different parameter servers when possible (this works when the
+  optimizer is element-wise, so each element's update depends only on
+  the parameter element itself); a sketch follows below.
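+
+Why the element-wise restriction matters: when every element's update
+depends only on that element and its gradient, each shard of the parameter
+can be optimized independently on a different parameter server. The NumPy
+sketch below illustrates this; the shard count and the row-wise split are
+illustrative choices, not a prescribed partitioning scheme.
+
+```python
+# Illustrative only: split a large parameter into row-wise shards, one per
+# parameter server, assuming an element-wise optimizer (plain SGD here).
+import numpy as np
+
+
+def partition(param, num_pservers):
+    # Split along the first dimension into roughly equal shards.
+    return np.array_split(param, num_pservers, axis=0)
+
+
+def sgd_on_shard(shard, grad_shard, lr=0.01):
+    # Element-wise update: each shard can be updated independently, so the
+    # shards may live on different parameter servers.
+    return shard - lr * grad_shard
+
+
+W = np.random.rand(10000, 512).astype('float32')
+G = np.random.rand(10000, 512).astype('float32')
+W_shards = partition(W, num_pservers=4)
+G_shards = partition(G, num_pservers=4)
+new_W = np.concatenate(
+    [sgd_on_shard(w, g) for w, g in zip(W_shards, G_shards)], axis=0)
+assert new_W.shape == W.shape
+```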
+
+### Discussion
+
+- In the "Aync SGD" figure, the "W" variable on the parameter server
+  could be read and wrote concurrently, what is our locking strategy?
+  E.g., each variable have a lock cpp method to be invoked by every
+  OP, or, have a lock OP.
+
+- Can the *Enqueue* OP be implemented under our current tensor design
+  (i.e., putting the input tensor into the queue tensor)?
+
+- The *Dequeue* OP will have a variable number of outputs (depending on
+  the `min_count` attribute); does our current design support it? (A
+  similar question holds for the *Add* OP.)
+
+
+### References
+
+[1] [TensorFlow: Large-Scale Machine Learning on Heterogeneous Distributed Systems](https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/45166.pdf)
diff --git a/doc/design/ops/src/dist-graph.graffle b/doc/design/ops/src/dist-graph.graffle
new file mode 100644
index 0000000000000000000000000000000000000000..941399c6ced8d5f65b6c595522b770c88259df4b
Binary files /dev/null and b/doc/design/ops/src/dist-graph.graffle differ
diff --git a/doc/design/ops/src/dist-graph.png b/doc/design/ops/src/dist-graph.png
new file mode 100644
index 0000000000000000000000000000000000000000..3546b09f1c2ee3e4f60f519d5e47f823f08051a7
Binary files /dev/null and b/doc/design/ops/src/dist-graph.png differ
diff --git a/doc/design/ops/src/local-graph.graffle b/doc/design/ops/src/local-graph.graffle
new file mode 100644
index 0000000000000000000000000000000000000000..19e509bd9af3c1e9a3f5e0f16ddd281457a339c5
Binary files /dev/null and b/doc/design/ops/src/local-graph.graffle differ
diff --git a/doc/design/ops/src/local-graph.png b/doc/design/ops/src/local-graph.png
new file mode 100644
index 0000000000000000000000000000000000000000..ada51200f793a9bb18911e7d63cfdb3244b967d7
Binary files /dev/null and b/doc/design/ops/src/local-graph.png differ
diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc
index 790cfc4746b1d34da413fa3c29a266f962c6dde6..e1e122091f7759b1a68f1f982bc2a35e8241f9f0 100644
--- a/paddle/framework/operator.cc
+++ b/paddle/framework/operator.cc
@@ -123,6 +123,15 @@ OperatorBase::OperatorBase(const std::string& type,
   CheckAllInputOutputSet();
 }
 
+std::vector<std::string> OperatorBase::InputVars() const {
+  std::vector<std::string> ret_val;
+  for (auto& i : inputs_) {
+    ret_val.reserve(ret_val.size() + i.second.size());
+    ret_val.insert(ret_val.end(), i.second.begin(), i.second.end());
+  }
+  return ret_val;
+}
+
 std::vector<std::string> OperatorBase::OutputVars(bool has_intermediate) const {
   std::vector<std::string> ret_val;
   if (has_intermediate) {
diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h
index 9a98d4d3be0d1cb875d614b263f1e4365ede4113..4600b06009bcef7d0774d25b816aac4733f30795 100644
--- a/paddle/framework/operator.h
+++ b/paddle/framework/operator.h
@@ -94,11 +94,14 @@ class OperatorBase {
 
   const VariableNameMap& Inputs() const { return inputs_; }
   const VariableNameMap& Outputs() const { return outputs_; }
+
   //! Get a input with argument's name described in `op_proto`
   std::string Input(const std::string& name) const;
   //! Get a input which has multiple variables.
   const std::vector<std::string>& Inputs(const std::string& name) const;
 
+  std::vector<std::string> InputVars() const;
+
   //! Get a output with argument's name described in `op_proto`
   std::string Output(const std::string& name) const;
   //! Get an output which has multiple variables.
@@ -311,9 +314,9 @@ class InferShapeContext {
   }
 
   template <typename T>
-  std::vector<const T*> MultiOutput(const std::string& name) const {
+  std::vector<T*> MultiOutput(const std::string& name) const {
     auto names = op_.Outputs(name);
-    std::vector<const T*> res;
+    std::vector<T*> res;
     res.reserve(names.size());
     std::transform(names.begin(), names.end(), std::back_inserter(res),
                    [&](const std::string& sub_name) {
diff --git a/paddle/gserver/layers/BatchNormBaseLayer.cpp b/paddle/gserver/layers/BatchNormBaseLayer.cpp
index 1ceaaaa206ee3cbc5421238574c7f310011ccaa5..f7a80e23e1bd49549bec57b360587adc6b423794 100644
--- a/paddle/gserver/layers/BatchNormBaseLayer.cpp
+++ b/paddle/gserver/layers/BatchNormBaseLayer.cpp
@@ -62,14 +62,18 @@ void BatchNormBaseLayer::calFeatureMapSize() {
   const ImageConfig& conf = config_.inputs(0).image_conf();
   imageH_ = inputLayers_[0]->getOutput().getFrameHeight();
   imageW_ = inputLayers_[0]->getOutput().getFrameWidth();
+  imageD_ = inputLayers_[0]->getOutput().getFrameDepth();
+
+  if (0 == imageD_) imageD_ = conf.img_size_z();
   if (imageH_ == 0 && imageW_ == 0) {
     imageH_ = conf.has_img_size_y() ? conf.img_size_y() : conf.img_size();
     imageW_ = conf.img_size();
   } else {
     getOutput().setFrameHeight(imageH_);
     getOutput().setFrameWidth(imageW_);
+    getOutput().setFrameDepth(imageD_);
   }
-  imgPixels_ = imageH_ * imageW_;
+  imgPixels_ = imageH_ * imageW_ * imageD_;
 }
 
 }  // namespace paddle
diff --git a/paddle/gserver/layers/BatchNormBaseLayer.h b/paddle/gserver/layers/BatchNormBaseLayer.h
index 230bafc31d96bbd49481a7ed135be6888688627e..e721d2d267a31cae46407673b8b1281e87055608 100644
--- a/paddle/gserver/layers/BatchNormBaseLayer.h
+++ b/paddle/gserver/layers/BatchNormBaseLayer.h
@@ -80,6 +80,7 @@ protected:
 
   /// Height or width of input image feature.
   /// Both of them are 1 if the input is fully-connected layer.
+  int imageD_;
   int imageH_;
   int imageW_;
   /// Height * Width.
diff --git a/paddle/gserver/layers/CudnnBatchNormLayer.cpp b/paddle/gserver/layers/CudnnBatchNormLayer.cpp
index 44ba2c4b7d1562d2ce839b5f4b4de1af35e6925f..49a9540c0b6e36b59ed786287ff5c4569b69a6a5 100644
--- a/paddle/gserver/layers/CudnnBatchNormLayer.cpp
+++ b/paddle/gserver/layers/CudnnBatchNormLayer.cpp
@@ -37,7 +37,7 @@ bool CudnnBatchNormLayer::init(const LayerMap& layerMap,
 }
 
 void CudnnBatchNormLayer::reshape(int batchSize) {
-  hl_tensor_reshape(ioDesc_, batchSize, channels_, imageH_, imageW_);
+  hl_tensor_reshape(ioDesc_, batchSize, channels_, imageH_ * imageD_, imageW_);
 }
 
 void CudnnBatchNormLayer::forward(PassType passType) {
@@ -104,7 +104,7 @@ void CudnnBatchNormLayer::forward(PassType passType) {
                                    EPS,
                                    batchSize,
                                    channels_,
-                                   imageH_,
+                                   imageH_ * imageD_,
                                    imageW_);
     }
   }
diff --git a/paddle/gserver/layers/SwitchOrderLayer.cpp b/paddle/gserver/layers/SwitchOrderLayer.cpp
index 92cd61cdd515d5c693df086c9575a5f197c00cee..d7eee6eaf078dab8d48adc4c7ee758a433672ac6 100644
--- a/paddle/gserver/layers/SwitchOrderLayer.cpp
+++ b/paddle/gserver/layers/SwitchOrderLayer.cpp
@@ -24,10 +24,12 @@ bool SwitchOrderLayer::init(const LayerMap& layerMap,
   /* Initialize the basic parent class */
   Layer::init(layerMap, parameterMap);
   auto& img_conf = config_.inputs(0).image_conf();
+  size_t inD = img_conf.img_size_z();
   size_t inH =
       img_conf.has_img_size_y() ? img_conf.img_size_y() : img_conf.img_size();
   size_t inW = img_conf.img_size();
   size_t inC = img_conf.channels();
+  inH = inH * inD;
   inDims_ = TensorShape({0, inC, inH, inW});
   outDims_ = TensorShape(4);
 
@@ -64,9 +66,10 @@ void SwitchOrderLayer::setInDims() {
   MatrixPtr input = inputLayers_[0]->getOutputValue();
   size_t batchSize = input->getHeight();
   inDims_.setDim(0, batchSize);
-
+  int d = inputLayers_[0]->getOutput().getFrameDepth();
+  d = (d == 0 ? 1 : d);
   int h = inputLayers_[0]->getOutput().getFrameHeight();
-  if (h != 0) inDims_.setDim(2, h);
+  if (h != 0) inDims_.setDim(2, h * d);
   int w = inputLayers_[0]->getOutput().getFrameWidth();
   if (w != 0) inDims_.setDim(3, w);
   int totalCount = input->getElementCnt();
diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp
index d1f3bc241fa621cb0070125980996e8627e40fd6..0e6be2df9ef5f0fae8ed2b0c65ac6c032fe45ab1 100644
--- a/paddle/gserver/tests/test_LayerGrad.cpp
+++ b/paddle/gserver/tests/test_LayerGrad.cpp
@@ -1703,6 +1703,55 @@ TEST(Layer, BatchNormalizationLayer) {
 #endif
 }
 
+void testBatchNorm3DLayer(const string& type, bool trans, bool useGpu) {
+  TestConfig config;
+  const int CHANNELS = 10;
+  const int IMG_SIZE = 16;
+  const int IMG_SIZE_Y = 8;
+  const int IMG_SIZE_Z = 8;
+  size_t size = CHANNELS * IMG_SIZE * IMG_SIZE_Y * IMG_SIZE_Z;
+  config.layerConfig.set_type(type);
+  config.layerConfig.set_size(size);
+  config.layerConfig.set_active_type("sigmoid");
+  config.biasSize = CHANNELS;
+  config.inputDefs.push_back({INPUT_DATA,
+                              "layer_0",
+                              /* dim= */ size,
+                              /* paraSize= */ CHANNELS});
+
+  config.inputDefs.push_back({INPUT_DATA, "layer_1_running_mean", 1, CHANNELS});
+  config.inputDefs.back().isStatic = true;
+  config.inputDefs.push_back({INPUT_DATA, "layer_2_running_var", 1, CHANNELS});
+  config.inputDefs.back().isStatic = true;
+
+  LayerInputConfig* input = config.layerConfig.add_inputs();
+  config.layerConfig.add_inputs();
+  config.layerConfig.add_inputs();
+
+  ImageConfig* img_conf = input->mutable_image_conf();
+  img_conf->set_channels(CHANNELS);
+  img_conf->set_img_size(IMG_SIZE);
+  img_conf->set_img_size_y(IMG_SIZE_Y);
+  img_conf->set_img_size_z(IMG_SIZE_Z);
+
+  testLayerGrad(config,
+                "batch_norm",
+                64,
+                /* trans= */ trans,
+                useGpu,
+                /* useWeight */ true);
+}
+
+TEST(Layer, testBatchNorm3DLayer) {
+  testBatchNorm3DLayer("batch_norm", false, false);
+#ifndef PADDLE_ONLY_CPU
+  testBatchNorm3DLayer("batch_norm", false, true);
+  if (hl_get_cudnn_lib_version() >= int(4000)) {
+    testBatchNorm3DLayer("cudnn_batch_norm", false, true);
+  }
+#endif
+}
+
 void testConvOperator(bool isDeconv) {
   TestConfig config;
   const int NUM_FILTERS = 16;
diff --git a/paddle/operators/sum_op.cc b/paddle/operators/sum_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..5805826ee8a555ca6dfc1ca81feaadffea9e1012
--- /dev/null
+++ b/paddle/operators/sum_op.cc
@@ -0,0 +1,73 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/sum_op.h"
+#include <vector>
+
+namespace paddle {
+namespace operators {
+using framework::Tensor;
+
+class SumOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(const framework::InferShapeContext &ctx) const override {
+    auto ins = ctx.MultiInput<Tensor>("X");
+    auto *out = ctx.Output<Tensor>("Out");
+    int N = ins.size();
+
+    auto in_dim = ins[0]->dims();
+
+    PADDLE_ENFORCE_GT(N, 1, "Input tensor count should be greater than 1.");
+    for (int i = 1; i < N; i++) {
+      auto dim = ins[i]->dims();
+      PADDLE_ENFORCE(in_dim == dim, "Input tensors must have the same shape");
+    }
+    out->Resize(in_dim);
+  }
+};
+
+class SumOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  SumOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X", "the input tensors of sum operator.").AsDuplicable();
+    AddOutput("Out", "the output tensor of sum operator.");
+    AddComment(R"DOC(
+            Sum the input tensors.
+        )DOC");
+  }
+};
+
+class SumGradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(const framework::InferShapeContext &ctx) const override {
+    auto outputs = ctx.MultiOutput<Tensor>(framework::GradVarName("X"));
+    auto dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();
+    for (auto output : outputs) {
+      output->Resize(dims);
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP(sum, ops::SumOp, ops::SumOpMaker, sum_grad, ops::SumGradOp);
+REGISTER_OP_CPU_KERNEL(sum, ops::SumKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP_CPU_KERNEL(sum_grad,
+                       ops::SumGradKernel<paddle::platform::CPUPlace, float>);
diff --git a/paddle/operators/sum_op.cu b/paddle/operators/sum_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..a465cf3659ba7c51338abadfc62962fb6755a39d
--- /dev/null
+++ b/paddle/operators/sum_op.cu
@@ -0,0 +1,18 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#define EIGEN_USE_GPU
+#include "paddle/operators/sum_op.h"
+
+namespace ops = paddle::operators;
+REGISTER_OP_GPU_KERNEL(sum, ops::SumKernel<paddle::platform::GPUPlace, float>);
+REGISTER_OP_GPU_KERNEL(sum_grad,
+                       ops::SumGradKernel<paddle::platform::GPUPlace, float>);
diff --git a/paddle/operators/sum_op.h b/paddle/operators/sum_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..0b1e9ebaa38d455fb5e3ce8c1a39cbbcdad9a940
--- /dev/null
+++ b/paddle/operators/sum_op.h
@@ -0,0 +1,65 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
+
+template <typename Place, typename T>
+class SumKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto ins = context.MultiInput<Tensor>("X");
+    auto* out = context.Output<Tensor>("Out");
+    out->mutable_data<T>(context.GetPlace());
+
+    auto place = context.GetEigenDevice<Place>();
+    auto result = EigenVector<T>::Flatten(*out);
+
+    int N = ins.size();
+    auto in = EigenVector<T>::Flatten(*(ins[0]));
+    result.device(place) = in;
+    for (int i = 1; i < N; i++) {
+      auto in = EigenVector<T>::Flatten(*(ins[i]));
+      result.device(place) = result + in;
+    }
+  }
+};
+
+template <typename Place, typename T>
+class SumGradKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* input = context.Input<Tensor>(framework::GradVarName("Out"));
+    auto outs = context.MultiOutput<Tensor>(framework::GradVarName("X"));
+    for (auto out : outs) {
+      out->mutable_data<T>(context.GetPlace());
+    }
+
+    auto place = context.GetEigenDevice<Place>();
+    auto in = EigenVector<T>::Flatten(*input);
+    for (auto out : outs) {
+      auto result = EigenVector<T>::Flatten(*out);
+      result.device(place) = in;
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc
index 109e62e8739f4b0cb0e2ba7d3f7cf2a2f5cbb9b7..be68c0930c849f969e58d6c786842acb99806eeb 100644
--- a/paddle/pybind/pybind.cc
+++ b/paddle/pybind/pybind.cc
@@ -52,6 +52,7 @@ USE_OP(pad);
 USE_CPU_ONLY_OP(scatter);
 USE_OP(top_k);
 USE_OP(squared_l2_distance);
+USE_OP(sum);
 
 namespace paddle {
 namespace framework {
@@ -217,7 +218,10 @@ All parameter, weight, gradient are variables in Paddle.
                -> std::map<std::string, std::vector<std::string>> {
                  return op.Outputs();
                })
+      .def("output_vars",
+           [](const OperatorBase &op) { return op.OutputVars(true); })
       .def("inputs", [](const OperatorBase &op) { return op.Inputs(); })
+      .def("input_vars", [](const OperatorBase &op) { return op.InputVars(); })
       .def("__str__", &OperatorBase::DebugString)
       .def("no_intermediate_outputs",
            [](const OperatorBase &op) { return op.OutputVars(false); })
diff --git a/proto/ModelConfig.proto b/proto/ModelConfig.proto
index 7d7fc23a4691646dfce4c162a445864c748501d9..ebf0911d6ea0b39d51447859ae2aef485b50b0e6 100644
--- a/proto/ModelConfig.proto
+++ b/proto/ModelConfig.proto
@@ -271,6 +271,7 @@ message ImageConfig {
   // The size of input feature map.
   required uint32 img_size = 8;
   optional uint32 img_size_y = 9;
+  optional uint32 img_size_z = 10 [ default = 1 ];
 }
 
 message PriorBoxConfig {
@@ -519,6 +520,7 @@ message LayerConfig {
   // for HuberRegressionLoss
   optional double delta = 57 [ default = 1.0 ];
 
+  // for 3D data
   optional uint64 depth = 58 [ default = 1 ];
 
   // for switch order layer
diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py
index 11dc84ae20679bb73735f9119739fca5ea7fa673..7e9112b43bf851575a3a798886d8b1b17e7c2017 100644
--- a/python/paddle/trainer/config_parser.py
+++ b/python/paddle/trainer/config_parser.py
@@ -1332,6 +1332,12 @@ def parse_image(image, input_layer_name, image_conf):
         get_img_size(input_layer_name, image_conf.channels)
 
 
+def parse_image3d(image, input_layer_name, image_conf):
+    image_conf.channels = image.channels
+    image_conf.img_size, image_conf.img_size_y, image_conf.img_size_z = \
+        get_img3d_size(input_layer_name, image_conf.channels)
+
+
 def parse_norm(norm, input_layer_name, norm_conf):
     norm_conf.norm_type = norm.norm_type
     config_assert(
@@ -2365,9 +2371,11 @@ class BatchNormLayer(LayerBase):
                  name,
                  inputs,
                  bias=True,
+                 img3D=False,
                  use_global_stats=True,
                  moving_average_fraction=0.9,
                  batch_norm_type=None,
+                 mean_var_names=None,
                  **xargs):
         if inputs is None:
             inputs = []
@@ -2409,24 +2417,69 @@ class BatchNormLayer(LayerBase):
 
         input_layer = self.get_input_layer(0)
         image_conf = self.config.inputs[0].image_conf
-        parse_image(self.inputs[0].image, input_layer.name, image_conf)
-
-        # Only pass the width and height of input to batch_norm layer
-        # when either of it is non-zero.
-        if input_layer.width != 0 or input_layer.height != 0:
-            self.set_cnn_layer(name, image_conf.img_size_y, image_conf.img_size,
-                               image_conf.channels, False)
+        if img3D:
+            parse_image3d(self.inputs[0].image, input_layer.name, image_conf)
+            # Only pass the depth, height and width of the input to the
+            # batch_norm layer when either height or width is non-zero.
+            if input_layer.width != 0 or input_layer.height != 0:
+                self.set_cnn_layer(
+                    input_layer_name=name,
+                    depth=image_conf.img_size_z,
+                    height=image_conf.img_size_y,
+                    width=image_conf.img_size,
+                    channels=image_conf.channels,
+                    is_print=True)
+            else:
+                self.set_layer_size(input_layer.size)
         else:
-            self.set_layer_size(input_layer.size)
+            parse_image(self.inputs[0].image, input_layer.name, image_conf)
+            # Only pass the height and width of the input to the batch_norm
+            # layer when either of them is non-zero.
+            if input_layer.width != 0 or input_layer.height != 0:
+                self.set_cnn_layer(
+                    input_layer_name=name,
+                    height=image_conf.img_size_y,
+                    width=image_conf.img_size,
+                    channels=image_conf.channels,
+                    is_print=True)
+            else:
+                self.set_layer_size(input_layer.size)
 
         psize = self.calc_parameter_size(image_conf)
         dims = [1, psize]
+        if mean_var_names is not None:
+            assert len(mean_var_names) == 2
+            self.inputs[1].parameter_name = mean_var_names[0]
+            self.inputs[2].parameter_name = mean_var_names[1]
+
         self.create_input_parameter(0, psize)
         self.create_input_parameter(1, psize, dims)
         self.create_input_parameter(2, psize, dims)
 
         self.create_bias_parameter(bias, psize)
 
+    def set_cnn_layer(self,
+                      input_layer_name,
+                      depth=None,
+                      height=None,
+                      width=None,
+                      channels=None,
+                      is_print=True):
+        depthIsNone = False
+        if depth is None:
+            depth = 1
+            depthIsNone = True
+        size = depth * height * width * channels
+        self.set_layer_size(size)
+        self.set_layer_height_width(height, width)
+        self.set_layer_depth(depth)
+        if is_print and depthIsNone:
+            print("output for %s: c = %d, h = %d, w = %d, size = %d" %
+                  (input_layer_name, channels, height, width, size))
+        elif is_print:
+            print("output for %s: c = %d, d = %d, h = %d, w = %d, size = %d" %
+                  (input_layer_name, channels, depth, height, width, size))
+
     def calc_parameter_size(self, image_conf):
         return image_conf.channels
 
@@ -2688,9 +2741,20 @@ class AddToLayer(LayerBase):
         super(AddToLayer, self).__init__(
             name, 'addto', 0, inputs=inputs, **xargs)
         config_assert(len(inputs) > 0, 'inputs cannot be empty for AddToLayer')
-        for input_index in xrange(len(self.inputs)):
-            input_layer = self.get_input_layer(input_index)
-            self.set_layer_size(input_layer.size)
+
+        if len(self.inputs) > 1:
+            for input_index in xrange(len(self.inputs)):
+                assert self.get_input_layer(0).height == self.get_input_layer(
+                    input_index).height
+                assert self.get_input_layer(0).width == self.get_input_layer(
+                    input_index).width
+                assert self.get_input_layer(0).depth == self.get_input_layer(
+                    input_index).depth
+
+        self.set_layer_size(self.get_input_layer(0).size)
+        self.set_layer_height_width(self.get_input_layer(0).height, \
+                                        self.get_input_layer(0).width)
+        self.set_layer_depth(self.get_input_layer(0).depth)
         self.create_bias_parameter(bias, self.config.size)
 
 
@@ -3370,11 +3434,20 @@ class ConcatenateLayer(LayerBase):
             name, 'concat', 0, inputs=inputs, **xargs)
         size = 0
         for input_index in xrange(len(self.inputs)):
+            assert self.get_input_layer(0).height == self.get_input_layer(
+                input_index).height
+            assert self.get_input_layer(0).width == self.get_input_layer(
+                input_index).width
+            assert self.get_input_layer(0).depth == self.get_input_layer(
+                input_index).depth
             input_layer = self.get_input_layer(input_index)
             input = self.inputs[input_index]
             if self.config.size == 0:
                 size += input_layer.size
 
+        self.set_layer_height_width(self.get_input_layer(0).height, \
+                                    self.get_input_layer(0).width)
+        self.set_layer_depth(self.get_input_layer(0).depth)
         self.set_layer_size(size)
 
 
diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py
index cba45bd3afa178ab4dd3a50f0947b144e7466e53..dc68c213da66ac680e6b14266cb5038a5ba73ec2 100644
--- a/python/paddle/trainer_config_helpers/layers.py
+++ b/python/paddle/trainer_config_helpers/layers.py
@@ -354,6 +354,10 @@ class LayerOutput(object):
     def height(self):
         return cp.g_layer_map[self.full_name].height
 
+    @property
+    def depth(self):
+        return cp.g_layer_map[self.full_name].depth
+
     def set_input(self, input):
         """
         Set the input for a memory layer. Can only be used for memory layer
@@ -943,7 +947,7 @@ def data_layer(name, size, depth=None, height=None, width=None,
     if height is not None and width is not None:
         num_filters = size / (width * height * depth)
         assert num_filters * width * height * depth == size, \
-                "size=%s width=%s height=%s depth=%s"  % (size, width, height, depth)
+                "size=%s width=%s height=%s depth=%s" % (size, width, height, depth)
 
     return LayerOutput(name, LayerType.DATA, size=size, num_filters=num_filters)
 
@@ -2953,13 +2957,15 @@ def img_cmrnorm_layer(input,
 def batch_norm_layer(input,
                      act=None,
                      name=None,
+                     img3D=False,
                      num_channels=None,
                      bias_attr=None,
                      param_attr=None,
                      layer_attr=None,
                      batch_norm_type=None,
                      moving_average_fraction=0.9,
-                     use_global_stats=None):
+                     use_global_stats=None,
+                     mean_var_names=None):
     """
     Batch Normalization Layer. The notation of this layer as follow.
 
@@ -3026,6 +3032,8 @@ def batch_norm_layer(input,
                                    :math:`runningMean = newMean*(1-factor)
                                    + runningMean*factor`
     :type moving_average_fraction: float.
+    :param mean_var_names: names of the moving mean and variance parameters,
+                           given as [mean name, variance name].
+    :type mean_var_names: list of string
     :return: LayerOutput object.
     :rtype: LayerOutput
     """
@@ -3039,6 +3047,7 @@ def batch_norm_layer(input,
            (batch_norm_type == "cudnn_batch_norm")
     l = Layer(
         name=name,
+        img3D=img3D,
         inputs=Input(
             input.name, image=Image(channels=num_channels), **param_attr.attr),
         active_type=act.name,
@@ -3047,6 +3056,7 @@ def batch_norm_layer(input,
         bias=ParamAttr.to_bias(bias_attr),
         moving_average_fraction=moving_average_fraction,
         use_global_stats=use_global_stats,
+        mean_var_names=mean_var_names,
         **ExtraLayerAttribute.to_kwargs(layer_attr))
 
     return LayerOutput(
@@ -6410,7 +6420,7 @@ def gated_unit_layer(input,
 @wrap_name_default('switch_order')
 def switch_order_layer(input,
                        name=None,
-                       reshape=None,
+                       reshape_axis=None,
                        act=None,
                        layer_attr=None):
     """
@@ -6421,8 +6431,9 @@ def switch_order_layer(input,
     The example usage is:
 
     .. code-block:: python
+       reshape_axis = 3
+       switch = switch_order(input=layer, name='switch', reshape_axis=reshape_axis)
        reshape = {'height':[ 0, 1, 2], 'width':[3]}
-       switch = switch_order(input=layer, name='switch', reshape=reshape)
 
     :param input: The input layer.
     :type input: LayerOutput
@@ -6434,6 +6445,11 @@ def switch_order_layer(input,
     :rtype: LayerOutput
     """
     assert isinstance(input, LayerOutput)
+    assert reshape_axis is not None and 0 < reshape_axis < 4
+    height = [ele for ele in xrange(reshape_axis)]
+    width = [ele for ele in range(reshape_axis, 4)]
+    reshape = {'height': height, 'width': width}
+
     l = Layer(
         name=name,
         inputs=input.name,
diff --git a/python/paddle/trainer_config_helpers/tests/configs/file_list.sh b/python/paddle/trainer_config_helpers/tests/configs/file_list.sh
index df872a90ff388f0d96cef44763dbd076bc768ab9..8a204a96f3ef57673cef65306d0bf8e8c3409751 100755
--- a/python/paddle/trainer_config_helpers/tests/configs/file_list.sh
+++ b/python/paddle/trainer_config_helpers/tests/configs/file_list.sh
@@ -10,6 +10,6 @@ test_prelu_layer test_row_conv test_detection_output_layer test_multibox_loss_la
 test_recursive_topology test_gated_unit_layer test_clip_layer test_row_l2_norm_layer
 test_kmax_seq_socre_layer test_sub_nested_seq_select_layer test_scale_shift_layer
 test_seq_slice_layer test_cross_entropy_over_beam test_pooling3D_layer
-test_conv3d_layer test_deconv3d_layer)
+test_conv3d_layer test_deconv3d_layer test_BatchNorm3D)
 
 export whole_configs=(test_split_datasource)
diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/img_layers.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/img_layers.protostr
index 1a577b8d9b1e1915236ba6afcfa97040d70c707a..5ddf6052df021b055390a42c25ce6c0d650e4aee 100644
--- a/python/paddle/trainer_config_helpers/tests/configs/protostr/img_layers.protostr
+++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/img_layers.protostr
@@ -62,6 +62,7 @@ layers {
   moving_average_fraction: 0.9
   height: 227
   width: 227
+  depth: 1
 }
 layers {
   name: "__crmnorm_0__"
diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/img_trans_layers.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/img_trans_layers.protostr
index 2818389b16cca75f5030b75fc4de8c89c06c5e02..c0252b945b4c7fd6b4dad8770e3e1dccb88df28a 100644
--- a/python/paddle/trainer_config_helpers/tests/configs/protostr/img_trans_layers.protostr
+++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/img_trans_layers.protostr
@@ -62,6 +62,7 @@ layers {
   moving_average_fraction: 0.9
   height: 256
   width: 256
+  depth: 1
 }
 layers {
   name: "__crmnorm_0__"
diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_BatchNorm3D.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_BatchNorm3D.protostr
new file mode 100644
index 0000000000000000000000000000000000000000..832ed24a31dd2bedba9a4fce77d7a088d1796fdb
--- /dev/null
+++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_BatchNorm3D.protostr
@@ -0,0 +1,92 @@
+type: "nn"
+layers {
+  name: "data3D"
+  type: "data"
+  size: 360
+  active_type: ""
+  height: 6
+  width: 20
+  depth: 3
+}
+layers {
+  name: "__batch_norm_0__"
+  type: "batch_norm"
+  size: 360
+  active_type: "relu"
+  inputs {
+    input_layer_name: "data3D"
+    input_parameter_name: "___batch_norm_0__.w0"
+    image_conf {
+      channels: 1
+      img_size: 20
+      img_size_y: 6
+      img_size_z: 3
+    }
+  }
+  inputs {
+    input_layer_name: "data3D"
+    input_parameter_name: "___batch_norm_0__.w1"
+  }
+  inputs {
+    input_layer_name: "data3D"
+    input_parameter_name: "___batch_norm_0__.w2"
+  }
+  bias_parameter_name: "___batch_norm_0__.wbias"
+  moving_average_fraction: 0.9
+  height: 6
+  width: 20
+  depth: 3
+}
+parameters {
+  name: "___batch_norm_0__.w0"
+  size: 1
+  initial_mean: 1.0
+  initial_std: 0.0
+  initial_strategy: 0
+  initial_smart: false
+}
+parameters {
+  name: "___batch_norm_0__.w1"
+  size: 1
+  initial_mean: 0.0
+  initial_std: 0.0
+  dims: 1
+  dims: 1
+  initial_strategy: 0
+  initial_smart: false
+  is_static: true
+  is_shared: true
+}
+parameters {
+  name: "___batch_norm_0__.w2"
+  size: 1
+  initial_mean: 0.0
+  initial_std: 0.0
+  dims: 1
+  dims: 1
+  initial_strategy: 0
+  initial_smart: false
+  is_static: true
+  is_shared: true
+}
+parameters {
+  name: "___batch_norm_0__.wbias"
+  size: 1
+  initial_mean: 0.0
+  initial_std: 0.0
+  dims: 1
+  dims: 1
+  initial_strategy: 0
+  initial_smart: false
+}
+input_layer_names: "data3D"
+output_layer_names: "__batch_norm_0__"
+sub_models {
+  name: "root"
+  layer_names: "data3D"
+  layer_names: "__batch_norm_0__"
+  input_layer_names: "data3D"
+  output_layer_names: "__batch_norm_0__"
+  is_recurrent_layer_group: false
+}
+
diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_bi_grumemory.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_bi_grumemory.protostr
index b110e91498ce7d112987714bd769868179141c54..8a1399efad0ff339e35f69400ac654a4787a6018 100644
--- a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_bi_grumemory.protostr
+++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_bi_grumemory.protostr
@@ -74,6 +74,9 @@ layers {
   inputs {
     input_layer_name: "__bidirectional_gru_0___bw"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 parameters {
   name: "___bidirectional_gru_0___fw_transform.w0"
diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_recursive_topology.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_recursive_topology.protostr
index 8133aa9c8d3e7c6843d1b27b70e87d394a1e0e47..046037936a6d85f54095c65f206e468aa69065d7 100644
--- a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_recursive_topology.protostr
+++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_recursive_topology.protostr
@@ -16,6 +16,9 @@ layers {
   inputs {
     input_layer_name: "data"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__addto_1__"
@@ -28,6 +31,9 @@ layers {
   inputs {
     input_layer_name: "__addto_0__"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__addto_2__"
@@ -40,6 +46,9 @@ layers {
   inputs {
     input_layer_name: "__addto_1__"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__addto_3__"
@@ -52,6 +61,9 @@ layers {
   inputs {
     input_layer_name: "__addto_2__"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__addto_4__"
@@ -64,6 +76,9 @@ layers {
   inputs {
     input_layer_name: "__addto_3__"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__addto_5__"
@@ -76,6 +91,9 @@ layers {
   inputs {
     input_layer_name: "__addto_4__"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__addto_6__"
@@ -88,6 +106,9 @@ layers {
   inputs {
     input_layer_name: "__addto_5__"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__addto_7__"
@@ -100,6 +121,9 @@ layers {
   inputs {
     input_layer_name: "__addto_6__"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__addto_8__"
@@ -112,6 +136,9 @@ layers {
   inputs {
     input_layer_name: "__addto_7__"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__addto_9__"
@@ -124,6 +151,9 @@ layers {
   inputs {
     input_layer_name: "__addto_8__"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__addto_10__"
@@ -136,6 +166,9 @@ layers {
   inputs {
     input_layer_name: "__addto_9__"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__addto_11__"
@@ -148,6 +181,9 @@ layers {
   inputs {
     input_layer_name: "__addto_10__"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__addto_12__"
@@ -160,6 +196,9 @@ layers {
   inputs {
     input_layer_name: "__addto_11__"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__addto_13__"
@@ -172,6 +211,9 @@ layers {
   inputs {
     input_layer_name: "__addto_12__"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__addto_14__"
@@ -184,6 +226,9 @@ layers {
   inputs {
     input_layer_name: "__addto_13__"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__addto_15__"
@@ -196,6 +241,9 @@ layers {
   inputs {
     input_layer_name: "__addto_14__"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__addto_16__"
@@ -208,6 +256,9 @@ layers {
   inputs {
     input_layer_name: "__addto_15__"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__addto_17__"
@@ -220,6 +271,9 @@ layers {
   inputs {
     input_layer_name: "__addto_16__"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__addto_18__"
@@ -232,6 +286,9 @@ layers {
   inputs {
     input_layer_name: "__addto_17__"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__addto_19__"
@@ -244,6 +301,9 @@ layers {
   inputs {
     input_layer_name: "__addto_18__"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__addto_20__"
@@ -256,6 +316,9 @@ layers {
   inputs {
     input_layer_name: "__addto_19__"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__addto_21__"
@@ -268,6 +331,9 @@ layers {
   inputs {
     input_layer_name: "__addto_20__"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__addto_22__"
@@ -280,6 +346,9 @@ layers {
   inputs {
     input_layer_name: "__addto_21__"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__addto_23__"
@@ -292,6 +361,9 @@ layers {
   inputs {
     input_layer_name: "__addto_22__"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__addto_24__"
@@ -304,6 +376,9 @@ layers {
   inputs {
     input_layer_name: "__addto_23__"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__addto_25__"
@@ -316,6 +391,9 @@ layers {
   inputs {
     input_layer_name: "__addto_24__"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__addto_26__"
@@ -328,6 +406,9 @@ layers {
   inputs {
     input_layer_name: "__addto_25__"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__addto_27__"
@@ -340,6 +421,9 @@ layers {
   inputs {
     input_layer_name: "__addto_26__"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__addto_28__"
@@ -352,6 +436,9 @@ layers {
   inputs {
     input_layer_name: "__addto_27__"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__addto_29__"
@@ -364,6 +451,9 @@ layers {
   inputs {
     input_layer_name: "__addto_28__"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__addto_30__"
@@ -376,6 +466,9 @@ layers {
   inputs {
     input_layer_name: "__addto_29__"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__addto_31__"
@@ -388,6 +481,9 @@ layers {
   inputs {
     input_layer_name: "__addto_30__"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__fc_layer_0__"
diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/util_layers.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/util_layers.protostr
index d0ad388165007b8f96f059e5b003c52f756383e5..7a2f3eab38808a031c27cf7ab9d6273952e389eb 100644
--- a/python/paddle/trainer_config_helpers/tests/configs/protostr/util_layers.protostr
+++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/util_layers.protostr
@@ -22,6 +22,9 @@ layers {
   inputs {
     input_layer_name: "b"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__concat_0__"
@@ -34,6 +37,9 @@ layers {
   inputs {
     input_layer_name: "b"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__concat_1__"
diff --git a/python/paddle/trainer_config_helpers/tests/configs/test_BatchNorm3D.py b/python/paddle/trainer_config_helpers/tests/configs/test_BatchNorm3D.py
new file mode 100644
index 0000000000000000000000000000000000000000..a991b22252ba10eed895efd931108c2d8b0e52f1
--- /dev/null
+++ b/python/paddle/trainer_config_helpers/tests/configs/test_BatchNorm3D.py
@@ -0,0 +1,11 @@
+from paddle.trainer_config_helpers import *
+
+settings(batch_size=1000, learning_rate=1e-4)
+
+#data = data_layer(name='data', size=180, width=30, height=6)
+#batchNorm = batch_norm_layer(data, num_channels=1)
+#outputs(batchNorm)
+
+data3D = data_layer(name='data3D', size=120 * 3, width=20, height=6, depth=3)
+batchNorm3D = batch_norm_layer(data3D, num_channels=1, img3D=True)
+outputs(batchNorm3D)
diff --git a/python/paddle/v2/framework/op.py b/python/paddle/v2/framework/op.py
index 78c64e261b7eea07c902743ff86a552c4b7dc355..8109227828bafda85d3a556dda928acd7c1fc94c 100644
--- a/python/paddle/v2/framework/op.py
+++ b/python/paddle/v2/framework/op.py
@@ -142,8 +142,8 @@ def create_op_creation_method(op_proto):
     return OpInfo(
         method=__impl__,
         name=op_proto.type,
-        inputs=[var.name for var in op_proto.inputs],
-        outputs=[var.name for var in op_proto.outputs],
+        inputs=[(var.name, var.duplicable) for var in op_proto.inputs],
+        outputs=[(var.name, var.duplicable) for var in op_proto.outputs],
         attrs=[attr.name for attr in op_proto.attrs])
 
 
@@ -180,9 +180,15 @@ class OperatorFactory(object):
         return self.op_methods.get(t)
 
     def get_op_input_names(self, type):
+        return map(lambda x: x[0], self.get_op_info(type).inputs)
+
+    def get_op_inputs(self, type):
         return self.get_op_info(type).inputs
 
     def get_op_output_names(self, type):
+        return map(lambda x: x[0], self.get_op_info(type).outputs)
+
+    def get_op_outputs(self, type):
         return self.get_op_info(type).outputs
 
     def get_op_attr_names(self, type):
diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt
index ef910f939be0b9d3cb5e6d49e69a00daa191b1c6..2117fdf0d58520a008d2bd01d56d96dd248be025 100644
--- a/python/paddle/v2/framework/tests/CMakeLists.txt
+++ b/python/paddle/v2/framework/tests/CMakeLists.txt
@@ -33,5 +33,6 @@ py_test(test_sgd_op SRCS test_sgd_op.py)
 py_test(test_gradient_checker SRCS test_gradient_checker.py)
 py_test(test_lookup_table SRCS test_lookup_table.py)
 py_test(test_scale_and_identity_op SRCS test_scale_and_identity_op.py)
+py_test(test_sum_op SRCS test_sum_op.py)
 py_test(mnist SRCS mnist.py)
 py_test(test_squared_l2_distance_op SRCS test_squared_l2_distance_op.py)
diff --git a/python/paddle/v2/framework/tests/op_test.py b/python/paddle/v2/framework/tests/op_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a6a5dca4c4ddc1399d80e491e4072f24707c01e
--- /dev/null
+++ b/python/paddle/v2/framework/tests/op_test.py
@@ -0,0 +1,275 @@
+import unittest
+import numpy as np
+import itertools
+import paddle.v2.framework.core as core
+from paddle.v2.framework.op import Operator
+
+
+def grad_var_name(var_name):
+    return var_name + "@GRAD"
+
+
+def create_op(scope, op_type, inputs, outputs, attrs=None):
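+    # Build the kwargs for Operator(...) from the declared proto inputs and
+    # outputs: duplicable arguments get one scope variable per element and a
+    # list of their names; non-duplicable arguments get a single variable
+    # named after the argument itself.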
+    kwargs = dict()
+
+    for in_name, in_dup in Operator.get_op_inputs(op_type):
+        if in_name in inputs:
+            kwargs[in_name] = []
+            if in_dup:
+                sub_in = inputs[in_name]
+                for sub_in_name in sub_in:
+                    var = scope.new_var(sub_in_name)
+                    kwargs[in_name].append(sub_in_name)
+            else:
+                var = scope.new_var(in_name)
+                kwargs[in_name].append(in_name)
+
+    for out_name, out_dup in Operator.get_op_outputs(op_type):
+        if out_name in outputs:
+            kwargs[out_name] = []
+            if out_dup:
+                sub_out = outputs[out_name]
+                for sub_out_name in sub_out:
+                    var = scope.new_var(sub_out_name)
+                    kwargs[out_name].append(sub_out_name)
+            else:
+                var = scope.new_var(out_name)
+                kwargs[out_name].append(out_name)
+
+    for attr_name in Operator.get_op_attr_names(op_type):
+        if attrs is not None and attr_name in attrs:
+            kwargs[attr_name] = attrs[attr_name]
+    return Operator(op_type, **kwargs)
+
+
+def set_input(scope, op, inputs, place):
+    for in_name, in_dup in Operator.get_op_inputs(op.type()):
+        if in_name in inputs:
+            if in_dup:
+                sub_in = inputs[in_name]
+                for sub_in_name in sub_in:
+                    var = scope.find_var(sub_in_name)
+                    tensor = var.get_tensor()
+                    arr = sub_in[sub_in_name]
+                    tensor.set_dims(arr.shape)
+                    tensor.set(arr, place)
+            else:
+                var = scope.find_var(in_name)
+                tensor = var.get_tensor()
+                arr = inputs[in_name]
+                tensor.set_dims(arr.shape)
+                tensor.set(arr, place)
+
+
+def set_output_grad(scope, op, outputs, place):
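+    # For every forward output present in `outputs`, create the matching
+    # "<name>@GRAD" variable and fill it with ones, so the backward op has
+    # well-defined output gradients to consume.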
+    for out_name, out_dup in Operator.get_op_outputs(op.type()):
+        if out_name in outputs:
+            if out_dup:
+                sub_out = outputs[out_name]
+                for sub_out_name in sub_out:
+                    out_tensor = scope.find_var(sub_out_name).get_tensor()
+                    grad_tensor = scope.new_var(grad_var_name(
+                        sub_out_name)).get_tensor()
+                    grad_tensor.set_dims(out_tensor.shape())
+                    data = np.ones(out_tensor.shape(), dtype=np.float32)
+                    grad_tensor.set(data, place)
+            else:
+                out_tensor = scope.find_var(out_name).get_tensor()
+                grad_tensor = scope.new_var(grad_var_name(out_name)).get_tensor(
+                )
+                grad_tensor.set_dims(out_tensor.shape())
+                data = np.ones(out_tensor.shape(), dtype=np.float32)
+                grad_tensor.set(data, place)
+
+
+def get_numeric_gradient(scope,
+                         op,
+                         inputs,
+                         input_to_check,
+                         output_name,
+                         delta=0.005,
+                         in_place=False):
+
+    set_input(scope, op, inputs, core.CPUPlace())
+    op.infer_shape(scope)
+
+    tensor_to_check = scope.find_var(input_to_check).get_tensor()
+
+    def product(dim):
+        return reduce(lambda a, b: a * b, dim, 1)
+
+    ctx = core.DeviceContext.create(core.CPUPlace())
+
+    def get_output():
+        op.run(scope, ctx)
+        return np.array(scope.find_var(output_name).get_tensor()).sum()
+
+    tensor_to_check = scope.find_var(input_to_check).get_tensor()
+    tensor_size = product(tensor_to_check.get_dims())
+    gradient_flat = np.zeros(shape=(tensor_size, ), dtype='float32')
+    # we only compute gradient of one element each time.
+    # we use a for loop to compute the gradient of every element.
+    for i in xrange(tensor_size):
+        if in_place:
+            set_input(scope, op, inputs, core.CPUPlace())
+
+        # get one input element through its index i.
+        origin = tensor_to_check.get_float_element(i)
+        # add delta to it, run op and then get the sum of the result tensor.
+        x_pos = origin + delta
+        tensor_to_check.set_float_element(i, x_pos)
+        y_pos = get_output()
+
+        if in_place:
+            set_input(scope, op, inputs, core.CPUPlace())
+
+        x_neg = origin - delta
+        tensor_to_check.set_float_element(i, x_neg)
+        y_neg = get_output()
+
+        tensor_to_check.set_float_element(i, origin)
+        gradient_flat[i] = (y_pos - y_neg) / delta / 2
+
+    return gradient_flat.reshape(tensor_to_check.get_dims())
+
+
+def get_backward_op(scope, op, no_grad_set):
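+    # Build the backward operator from the C++ side and pre-create scope
+    # variables for all of its inputs and outputs so that infer_shape/run
+    # can find them.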
+    backward_op = core.Operator.backward(op, no_grad_set)
+    for input in backward_op.input_vars():
+        var = scope.new_var(input)
+        var.get_tensor()
+    for output in backward_op.output_vars():
+        var = scope.new_var(output)
+        var.get_tensor()
+    return backward_op
+
+
+def get_gradient(scope, op, inputs, outputs, grad_name, place,
+                 no_grad_set=None):
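+    # Run the forward op, seed the output gradients with ones, run the
+    # backward op, and return the analytic gradient stored under `grad_name`.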
+    ctx = core.DeviceContext.create(place)
+
+    set_input(scope, op, inputs, place)
+
+    op.infer_shape(scope)
+    op.run(scope, ctx)
+
+    if no_grad_set is None:
+        no_grad_set = set()
+
+    backward_op = get_backward_op(scope, op, no_grad_set)
+    set_output_grad(scope, op, outputs, place)
+
+    backward_op.infer_shape(scope)
+    backward_op.run(scope, ctx)
+
+    out = np.array(scope.find_var(grad_name).get_tensor())
+    return out
+
+
+class OpTest(unittest.TestCase):
+    def check_output_with_place(self, place):
+        self.scope = core.Scope()
+        self.op = create_op(self.scope, self.op_type, self.inputs, self.outputs)
+        if isinstance(place, core.GPUPlace) and not self.op.support_gpu():
+            return
+        set_input(self.scope, self.op, self.inputs, place)
+        self.op.infer_shape(self.scope)
+        ctx = core.DeviceContext.create(place)
+        self.op.run(self.scope, ctx)
+
+        for out_name, out_dup in Operator.get_op_outputs(self.op.type()):
+            if out_dup:
+                sub_out = self.outputs[out_name]
+                for sub_out_name in sub_out:
+                    actual = np.array(
+                        self.scope.find_var(sub_out_name).get_tensor())
+                    expect = sub_out[sub_out_name]
+                    self.assertTrue(
+                        np.allclose(
+                            actual, expect, atol=1e-05),
+                        "output name: " + out_name + "has diff")
+            else:
+                actual = np.array(self.scope.find_var(out_name).get_tensor())
+                expect = self.outputs[out_name]
+                self.assertTrue(
+                    np.allclose(
+                        actual, expect, atol=1e-05),
+                    "output name: " + out_name + "has diff")
+
+    def check_output(self):
+        places = [core.CPUPlace()]
+        if core.is_compile_gpu():
+            places.append(core.GPUPlace(0))
+        for place in places:
+            self.check_output_with_place(place)
+
+    def __assert_is_close(self, numeric_grads, analytic_grads, names,
+                          max_relative_error, msg_prefix):
+
+        for a, b, name in itertools.izip(numeric_grads, analytic_grads, names):
+            abs_a = np.abs(a)
+            abs_a[abs_a < 1e-3] = 1
+
+            diff_mat = np.abs(a - b) / abs_a
+            max_diff = np.max(diff_mat)
+
+            def err_msg():
+                offset = np.argmax(diff_mat > max_relative_error)
+                return "%s Variable %s max gradient diff %f over limit %f, the first " \
+                  "error element is %d" % (
+                   msg_prefix, name, max_diff, max_relative_error, offset)
+
+            self.assertLessEqual(max_diff, max_relative_error, err_msg())
+
+    def check_grad(self,
+                   inputs_to_check,
+                   output_name,
+                   no_grad_set=None,
+                   in_place=False,
+                   max_relative_error=0.005):
+        self.scope = core.Scope()
+        self.op = create_op(self.scope, self.op_type, self.inputs, self.outputs)
+        if no_grad_set is None:
+            no_grad_set = set()
+
+        numeric_grads = [
+            get_numeric_gradient(
+                self.scope,
+                self.op,
+                self.inputs,
+                input_to_check,
+                output_name,
+                in_place=in_place) for input_to_check in inputs_to_check
+        ]
+        grad_names = [
+            grad_var_name(input_to_check) for input_to_check in inputs_to_check
+        ]
+
+        cpu_place = core.CPUPlace()
+        cpu_analytic_grads = [
+            get_gradient(self.scope, self.op, self.inputs, self.outputs,
+                         grad_name, cpu_place, no_grad_set)
+            for grad_name in grad_names
+        ]
+
+        self.__assert_is_close(numeric_grads, cpu_analytic_grads, grad_names,
+                               max_relative_error,
+                               "Gradient Check On %s" % str(cpu_place))
+
+        if core.is_compile_gpu() and self.op.support_gpu():
+            gpu_place = core.GPUPlace(0)
+            gpu_analytic_grads = [
+                get_gradient(self.scope, self.op, self.inputs, self.outputs,
+                             grad_name, gpu_place, no_grad_set)
+                for grad_name in grad_names
+            ]
+
+            self.__assert_is_close(numeric_grads, gpu_analytic_grads,
+                                   grad_names, max_relative_error,
+                                   "Gradient Check On %s" % str(gpu_place))
+
+            for c_grad, g_grad, name in itertools.izip(
+                    cpu_analytic_grads, gpu_analytic_grads, grad_names):
+                self.assertTrue(
+                    np.allclose(
+                        c_grad, g_grad, atol=1e-4),
+                    "output name: " + name + " has diff")
diff --git a/doc/design/ops/src/dist-graph.graffle b/doc/design/ops/src/dist-graph.graffle
new file mode 100644
index 0000000000000000000000000000000000000000..941399c6ced8d5f65b6c595522b770c88259df4b
Binary files /dev/null and b/doc/design/ops/src/dist-graph.graffle differ
diff --git a/doc/design/ops/src/dist-graph.png b/doc/design/ops/src/dist-graph.png
new file mode 100644
index 0000000000000000000000000000000000000000..3546b09f1c2ee3e4f60f519d5e47f823f08051a7
Binary files /dev/null and b/doc/design/ops/src/dist-graph.png differ
diff --git a/doc/design/ops/src/local-graph.graffle b/doc/design/ops/src/local-graph.graffle
new file mode 100644
index 0000000000000000000000000000000000000000..19e509bd9af3c1e9a3f5e0f16ddd281457a339c5
Binary files /dev/null and b/doc/design/ops/src/local-graph.graffle differ
diff --git a/doc/design/ops/src/local-graph.png b/doc/design/ops/src/local-graph.png
new file mode 100644
index 0000000000000000000000000000000000000000..ada51200f793a9bb18911e7d63cfdb3244b967d7
Binary files /dev/null and b/doc/design/ops/src/local-graph.png differ
diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc
index 790cfc4746b1d34da413fa3c29a266f962c6dde6..e1e122091f7759b1a68f1f982bc2a35e8241f9f0 100644
--- a/paddle/framework/operator.cc
+++ b/paddle/framework/operator.cc
@@ -123,6 +123,15 @@ OperatorBase::OperatorBase(const std::string& type,
   CheckAllInputOutputSet();
 }
 
+std::vector<std::string> OperatorBase::InputVars() const {
+  std::vector<std::string> ret_val;
+  for (auto& o : inputs_) {
+    ret_val.reserve(ret_val.size() + o.second.size());
+    ret_val.insert(ret_val.end(), o.second.begin(), o.second.end());
+  }
+  return ret_val;
+}
+
 std::vector<std::string> OperatorBase::OutputVars(bool has_intermediate) const {
   std::vector<std::string> ret_val;
   if (has_intermediate) {
diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h
index 9a98d4d3be0d1cb875d614b263f1e4365ede4113..4600b06009bcef7d0774d25b816aac4733f30795 100644
--- a/paddle/framework/operator.h
+++ b/paddle/framework/operator.h
@@ -94,11 +94,14 @@ class OperatorBase {
 
   const VariableNameMap& Inputs() const { return inputs_; }
   const VariableNameMap& Outputs() const { return outputs_; }
+
   //! Get a input with argument's name described in `op_proto`
   std::string Input(const std::string& name) const;
   //! Get a input which has multiple variables.
   const std::vector<std::string>& Inputs(const std::string& name) const;
 
+  std::vector<std::string> InputVars() const;
+
   //! Get a output with argument's name described in `op_proto`
   std::string Output(const std::string& name) const;
   //! Get an output which has multiple variables.
@@ -311,9 +314,9 @@ class InferShapeContext {
   }
 
   template <typename T>
-  std::vector<const T*> MultiOutput(const std::string& name) const {
+  std::vector<T*> MultiOutput(const std::string& name) const {
     auto names = op_.Outputs(name);
-    std::vector<const T*> res;
+    std::vector<T*> res;
     res.reserve(names.size());
     std::transform(names.begin(), names.end(), std::back_inserter(res),
                    [&](const std::string& sub_name) {
diff --git a/paddle/gserver/layers/BatchNormBaseLayer.cpp b/paddle/gserver/layers/BatchNormBaseLayer.cpp
index 1ceaaaa206ee3cbc5421238574c7f310011ccaa5..f7a80e23e1bd49549bec57b360587adc6b423794 100644
--- a/paddle/gserver/layers/BatchNormBaseLayer.cpp
+++ b/paddle/gserver/layers/BatchNormBaseLayer.cpp
@@ -62,14 +62,18 @@ void BatchNormBaseLayer::calFeatureMapSize() {
   const ImageConfig& conf = config_.inputs(0).image_conf();
   imageH_ = inputLayers_[0]->getOutput().getFrameHeight();
   imageW_ = inputLayers_[0]->getOutput().getFrameWidth();
+  imageD_ = inputLayers_[0]->getOutput().getFrameDepth();
+
+  if (0 == imageD_) imageD_ = conf.img_size_z();
   if (imageH_ == 0 && imageW_ == 0) {
     imageH_ = conf.has_img_size_y() ? conf.img_size_y() : conf.img_size();
     imageW_ = conf.img_size();
   } else {
     getOutput().setFrameHeight(imageH_);
     getOutput().setFrameWidth(imageW_);
+    getOutput().setFrameDepth(imageD_);
   }
-  imgPixels_ = imageH_ * imageW_;
+  imgPixels_ = imageH_ * imageW_ * imageD_;
 }
 
 }  // namespace paddle
diff --git a/paddle/gserver/layers/BatchNormBaseLayer.h b/paddle/gserver/layers/BatchNormBaseLayer.h
index 230bafc31d96bbd49481a7ed135be6888688627e..e721d2d267a31cae46407673b8b1281e87055608 100644
--- a/paddle/gserver/layers/BatchNormBaseLayer.h
+++ b/paddle/gserver/layers/BatchNormBaseLayer.h
@@ -80,6 +80,7 @@ protected:
 
   /// Height or width of input image feature.
   /// Both of them are 1 if the input is fully-connected layer.
+  int imageD_;
   int imageH_;
   int imageW_;
   /// Height * Width.
diff --git a/paddle/gserver/layers/CudnnBatchNormLayer.cpp b/paddle/gserver/layers/CudnnBatchNormLayer.cpp
index 44ba2c4b7d1562d2ce839b5f4b4de1af35e6925f..49a9540c0b6e36b59ed786287ff5c4569b69a6a5 100644
--- a/paddle/gserver/layers/CudnnBatchNormLayer.cpp
+++ b/paddle/gserver/layers/CudnnBatchNormLayer.cpp
@@ -37,7 +37,7 @@ bool CudnnBatchNormLayer::init(const LayerMap& layerMap,
 }
 
 void CudnnBatchNormLayer::reshape(int batchSize) {
-  hl_tensor_reshape(ioDesc_, batchSize, channels_, imageH_, imageW_);
+  hl_tensor_reshape(ioDesc_, batchSize, channels_, imageH_ * imageD_, imageW_);
 }
 
 void CudnnBatchNormLayer::forward(PassType passType) {
@@ -104,7 +104,7 @@ void CudnnBatchNormLayer::forward(PassType passType) {
                                    EPS,
                                    batchSize,
                                    channels_,
-                                   imageH_,
+                                   imageH_ * imageD_,
                                    imageW_);
     }
   }
diff --git a/paddle/gserver/layers/SwitchOrderLayer.cpp b/paddle/gserver/layers/SwitchOrderLayer.cpp
index 92cd61cdd515d5c693df086c9575a5f197c00cee..d7eee6eaf078dab8d48adc4c7ee758a433672ac6 100644
--- a/paddle/gserver/layers/SwitchOrderLayer.cpp
+++ b/paddle/gserver/layers/SwitchOrderLayer.cpp
@@ -24,10 +24,12 @@ bool SwitchOrderLayer::init(const LayerMap& layerMap,
   /* Initialize the basic parent class */
   Layer::init(layerMap, parameterMap);
   auto& img_conf = config_.inputs(0).image_conf();
+  size_t inD = img_conf.img_size_z();
   size_t inH =
       img_conf.has_img_size_y() ? img_conf.img_size_y() : img_conf.img_size();
   size_t inW = img_conf.img_size();
   size_t inC = img_conf.channels();
+  inH = inH * inD;
   inDims_ = TensorShape({0, inC, inH, inW});
   outDims_ = TensorShape(4);
 
@@ -64,9 +66,10 @@ void SwitchOrderLayer::setInDims() {
   MatrixPtr input = inputLayers_[0]->getOutputValue();
   size_t batchSize = input->getHeight();
   inDims_.setDim(0, batchSize);
-
+  int d = inputLayers_[0]->getOutput().getFrameDepth();
+  d = (d == 0 ? 1 : d);
   int h = inputLayers_[0]->getOutput().getFrameHeight();
-  if (h != 0) inDims_.setDim(2, h);
+  if (h != 0) inDims_.setDim(2, h * d);
   int w = inputLayers_[0]->getOutput().getFrameWidth();
   if (w != 0) inDims_.setDim(3, w);
   int totalCount = input->getElementCnt();
diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp
index d1f3bc241fa621cb0070125980996e8627e40fd6..0e6be2df9ef5f0fae8ed2b0c65ac6c032fe45ab1 100644
--- a/paddle/gserver/tests/test_LayerGrad.cpp
+++ b/paddle/gserver/tests/test_LayerGrad.cpp
@@ -1703,6 +1703,55 @@ TEST(Layer, BatchNormalizationLayer) {
 #endif
 }
 
+void testBatchNorm3DLayer(const string& type, bool trans, bool useGpu) {
+  TestConfig config;
+  const int CHANNELS = 10;
+  const int IMG_SIZE = 16;
+  const int IMG_SIZE_Y = 8;
+  const int IMG_SIZE_Z = 8;
+  size_t size = CHANNELS * IMG_SIZE * IMG_SIZE_Y * IMG_SIZE_Z;
+  config.layerConfig.set_type(type);
+  config.layerConfig.set_size(size);
+  config.layerConfig.set_active_type("sigmoid");
+  config.biasSize = CHANNELS;
+  config.inputDefs.push_back({INPUT_DATA,
+                              "layer_0",
+                              /* dim= */ size,
+                              /* paraSize= */ CHANNELS});
+
+  config.inputDefs.push_back({INPUT_DATA, "layer_1_running_mean", 1, CHANNELS});
+  config.inputDefs.back().isStatic = true;
+  config.inputDefs.push_back({INPUT_DATA, "layer_2_running_var", 1, CHANNELS});
+  config.inputDefs.back().isStatic = true;
+
+  LayerInputConfig* input = config.layerConfig.add_inputs();
+  config.layerConfig.add_inputs();
+  config.layerConfig.add_inputs();
+
+  ImageConfig* img_conf = input->mutable_image_conf();
+  img_conf->set_channels(CHANNELS);
+  img_conf->set_img_size(IMG_SIZE);
+  img_conf->set_img_size_y(IMG_SIZE_Y);
+  img_conf->set_img_size_z(IMG_SIZE_Z);
+
+  testLayerGrad(config,
+                "batch_norm",
+                64,
+                /* trans= */ trans,
+                useGpu,
+                /* useWeight */ true);
+}
+
+TEST(Layer, testBatchNorm3DLayer) {
+  testBatchNorm3DLayer("batch_norm", false, false);
+#ifndef PADDLE_ONLY_CPU
+  testBatchNorm3DLayer("batch_norm", false, true);
+  if (hl_get_cudnn_lib_version() >= int(4000)) {
+    testBatchNorm3DLayer("cudnn_batch_norm", false, true);
+  }
+#endif
+}
+
 void testConvOperator(bool isDeconv) {
   TestConfig config;
   const int NUM_FILTERS = 16;
diff --git a/paddle/operators/sum_op.cc b/paddle/operators/sum_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..5805826ee8a555ca6dfc1ca81feaadffea9e1012
--- /dev/null
+++ b/paddle/operators/sum_op.cc
@@ -0,0 +1,73 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/sum_op.h"
+#include <vector>
+
+namespace paddle {
+namespace operators {
+using framework::Tensor;
+
+class SumOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(const framework::InferShapeContext &ctx) const override {
+    auto ins = ctx.MultiInput<Tensor>("X");
+    auto *out = ctx.Output<Tensor>("Out");
+    int N = ins.size();
+
+    auto in_dim = ins[0]->dims();
+
+    PADDLE_ENFORCE_GT(N, 1, "Input tensors count should > 1.");
+    for (int i = 1; i < N; i++) {
+      auto dim = ins[i]->dims();
+      PADDLE_ENFORCE(in_dim == dim, "Input tensors must have same shape");
+    }
+    out->Resize(in_dim);
+  }
+};
+
+class SumOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  SumOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X", "the input tensors of sum operator.").AsDuplicable();
+    AddOutput("Out", "the output tensor of sum operator.");
+    AddComment(R"DOC(
+            Sum the input tensors.
+        )DOC");
+  }
+};
+
+class SumGradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(const framework::InferShapeContext &ctx) const override {
+    auto outputs = ctx.MultiOutput<Tensor>(framework::GradVarName("X"));
+    auto dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();
+    for (auto output : outputs) {
+      output->Resize(dims);
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP(sum, ops::SumOp, ops::SumOpMaker, sum_grad, ops::SumGradOp);
+REGISTER_OP_CPU_KERNEL(sum, ops::SumKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP_CPU_KERNEL(sum_grad,
+                       ops::SumGradKernel<paddle::platform::CPUPlace, float>);
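For reference, the shape contract that SumOp::InferShape enforces above (more than one input, all inputs with identical dimensions, output resized to that common shape) can be sketched in plain NumPy; the helper name `infer_sum_shape` below is invented for illustration only:

```python
import numpy as np


def infer_sum_shape(inputs):
    # Mirrors SumOp::InferShape: more than one input is required and
    # every input must share the same shape; the output takes that shape.
    assert len(inputs) > 1, "Input tensors count should > 1."
    shape = inputs[0].shape
    for x in inputs[1:]:
        assert x.shape == shape, "Input tensors must have same shape"
    return shape


xs = [np.random.random((3, 4)).astype("float32") for _ in range(3)]
assert infer_sum_shape(xs) == (3, 4)
```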
diff --git a/paddle/operators/sum_op.cu b/paddle/operators/sum_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..a465cf3659ba7c51338abadfc62962fb6755a39d
--- /dev/null
+++ b/paddle/operators/sum_op.cu
@@ -0,0 +1,18 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#define EIGEN_USE_GPU
+#include "paddle/operators/sum_op.h"
+
+namespace ops = paddle::operators;
+REGISTER_OP_GPU_KERNEL(sum, ops::SumKernel<paddle::platform::GPUPlace, float>);
+REGISTER_OP_GPU_KERNEL(sum_grad,
+                       ops::SumGradKernel<paddle::platform::GPUPlace, float>);
diff --git a/paddle/operators/sum_op.h b/paddle/operators/sum_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..0b1e9ebaa38d455fb5e3ce8c1a39cbbcdad9a940
--- /dev/null
+++ b/paddle/operators/sum_op.h
@@ -0,0 +1,65 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+template <typename T, int MajorType = Eigen::RowMajor, typename IndexType = Eigen::DenseIndex>
+using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
+
+template <typename Place, typename T>
+class SumKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto ins = context.MultiInput<Tensor>("X");
+    auto* out = context.Output<Tensor>("Out");
+    out->mutable_data<T>(context.GetPlace());
+
+    auto place = context.GetEigenDevice<Place>();
+    auto result = EigenVector<T>::Flatten(*out);
+
+    int N = ins.size();
+    auto in = EigenVector<T>::Flatten(*(ins[0]));
+    result.device(place) = in;
+    for (int i = 1; i < N; i++) {
+      auto in = EigenVector<T>::Flatten(*(ins[i]));
+      result.device(place) = result + in;
+    }
+  }
+};
+
+template <typename Place, typename T>
+class SumGradKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* input = context.Input<Tensor>(framework::GradVarName("Out"));
+    auto outs = context.MultiOutput<Tensor>(framework::GradVarName("X"));
+    for (auto out : outs) {
+      out->mutable_data<T>(context.GetPlace());
+    }
+
+    auto place = context.GetEigenDevice<Place>();
+    auto in = EigenVector<T>::Flatten(*input);
+    for (auto out : outs) {
+      auto result = EigenVector<T>::Flatten(*out);
+      result.device(place) = in;
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
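The Eigen expressions in SumKernel and SumGradKernel boil down to element-wise accumulation in the forward pass and a plain copy of the output gradient to every input gradient in the backward pass. A minimal NumPy sketch of those semantics (the function names here are illustrative, not part of the framework):

```python
import numpy as np


def sum_forward(xs):
    # SumKernel: Out = x0 + x1 + ... + x_{N-1}, accumulated element-wise.
    out = xs[0].copy()
    for x in xs[1:]:
        out = out + x
    return out


def sum_backward(d_out, num_inputs):
    # SumGradKernel: d Out / d x_i is the identity, so every input
    # gradient is simply a copy of the output gradient.
    return [d_out.copy() for _ in range(num_inputs)]


xs = [np.full((3, 4), float(i), dtype="float32") for i in range(1, 4)]
np.testing.assert_allclose(sum_forward(xs), np.full((3, 4), 6.0))
assert all((g == 1.0).all() for g in sum_backward(np.ones((3, 4)), 3))
```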
diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc
index 109e62e8739f4b0cb0e2ba7d3f7cf2a2f5cbb9b7..be68c0930c849f969e58d6c786842acb99806eeb 100644
--- a/paddle/pybind/pybind.cc
+++ b/paddle/pybind/pybind.cc
@@ -52,6 +52,7 @@ USE_OP(pad);
 USE_CPU_ONLY_OP(scatter);
 USE_OP(top_k);
 USE_OP(squared_l2_distance);
+USE_OP(sum);
 
 namespace paddle {
 namespace framework {
@@ -217,7 +218,10 @@ All parameter, weight, gradient are variables in Paddle.
                -> std::map<std::string, std::vector<std::string>> {
                  return op.Outputs();
                })
+      .def("output_vars",
+           [](const OperatorBase &op) { return op.OutputVars(true); })
       .def("inputs", [](const OperatorBase &op) { return op.Inputs(); })
+      .def("input_vars", [](const OperatorBase &op) { return op.InputVars(); })
       .def("__str__", &OperatorBase::DebugString)
       .def("no_intermediate_outputs",
            [](const OperatorBase &op) { return op.OutputVars(false); })
diff --git a/proto/ModelConfig.proto b/proto/ModelConfig.proto
index 7d7fc23a4691646dfce4c162a445864c748501d9..ebf0911d6ea0b39d51447859ae2aef485b50b0e6 100644
--- a/proto/ModelConfig.proto
+++ b/proto/ModelConfig.proto
@@ -271,6 +271,7 @@ message ImageConfig {
   // The size of input feature map.
   required uint32 img_size = 8;
   optional uint32 img_size_y = 9;
+  optional uint32 img_size_z = 10 [ default = 1 ];
 }
 
 message PriorBoxConfig {
@@ -519,6 +520,7 @@ message LayerConfig {
   // for HuberRegressionLoss
   optional double delta = 57 [ default = 1.0 ];
 
+  // for 3D data
   optional uint64 depth = 58 [ default = 1 ];
 
   // for switch order layer
diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py
index 11dc84ae20679bb73735f9119739fca5ea7fa673..7e9112b43bf851575a3a798886d8b1b17e7c2017 100644
--- a/python/paddle/trainer/config_parser.py
+++ b/python/paddle/trainer/config_parser.py
@@ -1332,6 +1332,12 @@ def parse_image(image, input_layer_name, image_conf):
         get_img_size(input_layer_name, image_conf.channels)
 
 
+def parse_image3d(image, input_layer_name, image_conf):
+    image_conf.channels = image.channels
+    image_conf.img_size, image_conf.img_size_y, image_conf.img_size_z = \
+        get_img3d_size(input_layer_name, image_conf.channels)
+
+
 def parse_norm(norm, input_layer_name, norm_conf):
     norm_conf.norm_type = norm.norm_type
     config_assert(
@@ -2365,9 +2371,11 @@ class BatchNormLayer(LayerBase):
                  name,
                  inputs,
                  bias=True,
+                 img3D=False,
                  use_global_stats=True,
                  moving_average_fraction=0.9,
                  batch_norm_type=None,
+                 mean_var_names=None,
                  **xargs):
         if inputs is None:
             inputs = []
@@ -2409,24 +2417,69 @@ class BatchNormLayer(LayerBase):
 
         input_layer = self.get_input_layer(0)
         image_conf = self.config.inputs[0].image_conf
-        parse_image(self.inputs[0].image, input_layer.name, image_conf)
-
-        # Only pass the width and height of input to batch_norm layer
-        # when either of it is non-zero.
-        if input_layer.width != 0 or input_layer.height != 0:
-            self.set_cnn_layer(name, image_conf.img_size_y, image_conf.img_size,
-                               image_conf.channels, False)
+        if img3D:
+            parse_image3d(self.inputs[0].image, input_layer.name, image_conf)
+            # Only pass the width and height of input to batch_norm layer
+            # when either of them is non-zero.
+            if input_layer.width != 0 or input_layer.height != 0:
+                self.set_cnn_layer(
+                    input_layer_name=name,
+                    depth=image_conf.img_size_z,
+                    height=image_conf.img_size_y,
+                    width=image_conf.img_size,
+                    channels=image_conf.channels,
+                    is_print=True)
+            else:
+                self.set_layer_size(input_layer.size)
         else:
-            self.set_layer_size(input_layer.size)
+            parse_image(self.inputs[0].image, input_layer.name, image_conf)
+            # Only pass the width and height of input to batch_norm layer
+            # when either of them is non-zero.
+            if input_layer.width != 0 or input_layer.height != 0:
+                self.set_cnn_layer(
+                    input_layer_name=name,
+                    height=image_conf.img_size_y,
+                    width=image_conf.img_size,
+                    channels=image_conf.channels,
+                    is_print=True)
+            else:
+                self.set_layer_size(input_layer.size)
 
         psize = self.calc_parameter_size(image_conf)
         dims = [1, psize]
+        if mean_var_names is not None:
+            assert len(mean_var_names) == 2
+            self.inputs[1].parameter_name = mean_var_names[0]
+            self.inputs[2].parameter_name = mean_var_names[1]
+
         self.create_input_parameter(0, psize)
         self.create_input_parameter(1, psize, dims)
         self.create_input_parameter(2, psize, dims)
 
         self.create_bias_parameter(bias, psize)
 
+    def set_cnn_layer(self,
+                      input_layer_name,
+                      depth=None,
+                      height=None,
+                      width=None,
+                      channels=None,
+                      is_print=True):
+        depthIsNone = False
+        if depth is None:
+            depth = 1
+            depthIsNone = True
+        size = depth * height * width * channels
+        self.set_layer_size(size)
+        self.set_layer_height_width(height, width)
+        self.set_layer_depth(depth)
+        if is_print and depthIsNone:
+            print("output for %s: c = %d, h = %d, w = %d, size = %d" %
+                  (input_layer_name, channels, height, width, size))
+        elif is_print:
+            print("output for %s: c = %d, d = %d, h = %d, w = %d, size = %d" %
+                  (input_layer_name, channels, depth, height, width, size))
+
     def calc_parameter_size(self, image_conf):
         return image_conf.channels
 
@@ -2688,9 +2741,20 @@ class AddToLayer(LayerBase):
         super(AddToLayer, self).__init__(
             name, 'addto', 0, inputs=inputs, **xargs)
         config_assert(len(inputs) > 0, 'inputs cannot be empty for AddToLayer')
-        for input_index in xrange(len(self.inputs)):
-            input_layer = self.get_input_layer(input_index)
-            self.set_layer_size(input_layer.size)
+
+        if len(self.inputs) > 1:
+            for input_index in xrange(len(self.inputs)):
+                assert self.get_input_layer(0).height == self.get_input_layer(
+                    input_index).height
+                assert self.get_input_layer(0).width == self.get_input_layer(
+                    input_index).width
+                assert self.get_input_layer(0).depth == self.get_input_layer(
+                    input_index).depth
+
+        self.set_layer_size(self.get_input_layer(0).size)
+        self.set_layer_height_width(self.get_input_layer(0).height, \
+                                        self.get_input_layer(0).width)
+        self.set_layer_depth(self.get_input_layer(0).depth)
         self.create_bias_parameter(bias, self.config.size)
 
 
@@ -3370,11 +3434,20 @@ class ConcatenateLayer(LayerBase):
             name, 'concat', 0, inputs=inputs, **xargs)
         size = 0
         for input_index in xrange(len(self.inputs)):
+            assert self.get_input_layer(0).height == self.get_input_layer(
+                input_index).height
+            assert self.get_input_layer(0).width == self.get_input_layer(
+                input_index).width
+            assert self.get_input_layer(0).depth == self.get_input_layer(
+                input_index).depth
             input_layer = self.get_input_layer(input_index)
             input = self.inputs[input_index]
             if self.config.size == 0:
                 size += input_layer.size
 
+        self.set_layer_height_width(self.get_input_layer(0).height, \
+                                    self.get_input_layer(0).width)
+        self.set_layer_depth(self.get_input_layer(0).depth)
         self.set_layer_size(size)
 
 
diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py
index cba45bd3afa178ab4dd3a50f0947b144e7466e53..dc68c213da66ac680e6b14266cb5038a5ba73ec2 100644
--- a/python/paddle/trainer_config_helpers/layers.py
+++ b/python/paddle/trainer_config_helpers/layers.py
@@ -354,6 +354,10 @@ class LayerOutput(object):
     def height(self):
         return cp.g_layer_map[self.full_name].height
 
+    @property
+    def depth(self):
+        return cp.g_layer_map[self.full_name].depth
+
     def set_input(self, input):
         """
         Set the input for a memory layer. Can only be used for memory layer
@@ -943,7 +947,7 @@ def data_layer(name, size, depth=None, height=None, width=None,
     if height is not None and width is not None:
         num_filters = size / (width * height * depth)
         assert num_filters * width * height * depth == size, \
-                "size=%s width=%s height=%s depth=%s"  % (size, width, height, depth)
+                "size=%s width=%s height=%s depth=%s" % (size, width, height, depth)
 
     return LayerOutput(name, LayerType.DATA, size=size, num_filters=num_filters)
 
@@ -2953,13 +2957,15 @@ def img_cmrnorm_layer(input,
 def batch_norm_layer(input,
                      act=None,
                      name=None,
+                     img3D=False,
                      num_channels=None,
                      bias_attr=None,
                      param_attr=None,
                      layer_attr=None,
                      batch_norm_type=None,
                      moving_average_fraction=0.9,
-                     use_global_stats=None):
+                     use_global_stats=None,
+                     mean_var_names=None):
     """
     Batch Normalization Layer. The notation of this layer as follow.
 
@@ -3026,6 +3032,8 @@ def batch_norm_layer(input,
                                    :math:`runningMean = newMean*(1-factor)
                                    + runningMean*factor`
     :type moving_average_fraction: float.
+    :param mean_var_names: [mean name, variance name]
+    :type mean_var_names: string list
     :return: LayerOutput object.
     :rtype: LayerOutput
     """
@@ -3039,6 +3047,7 @@ def batch_norm_layer(input,
            (batch_norm_type == "cudnn_batch_norm")
     l = Layer(
         name=name,
+        img3D=img3D,
         inputs=Input(
             input.name, image=Image(channels=num_channels), **param_attr.attr),
         active_type=act.name,
@@ -3047,6 +3056,7 @@ def batch_norm_layer(input,
         bias=ParamAttr.to_bias(bias_attr),
         moving_average_fraction=moving_average_fraction,
         use_global_stats=use_global_stats,
+        mean_var_names=mean_var_names,
         **ExtraLayerAttribute.to_kwargs(layer_attr))
 
     return LayerOutput(
@@ -6410,7 +6420,7 @@ def gated_unit_layer(input,
 @wrap_name_default('switch_order')
 def switch_order_layer(input,
                        name=None,
-                       reshape=None,
+                       reshape_axis=None,
                        act=None,
                        layer_attr=None):
     """
@@ -6421,8 +6431,9 @@ def switch_order_layer(input,
     The example usage is:
 
     .. code-block:: python
+       reshape_axis = 3
+       switch = switch_order(input=layer, name='switch', reshape_axis=reshape_axis)
        reshape = {'height':[ 0, 1, 2], 'width':[3]}
-       switch = switch_order(input=layer, name='switch', reshape=reshape)
 
     :param input: The input layer.
     :type input: LayerOutput
@@ -6434,6 +6445,11 @@ def switch_order_layer(input,
     :rtype: LayerOutput
     """
     assert isinstance(input, LayerOutput)
+    assert reshape_axis != None and (reshape_axis > 0 and reshape_axis < 4)
+    height = [ele for ele in xrange(reshape_axis)]
+    width = [ele for ele in range(reshape_axis, 4)]
+    reshape = {'height': height, 'width': width}
+
     l = Layer(
         name=name,
         inputs=input.name,
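A hypothetical configuration snippet showing how the new `img3D` and `mean_var_names` arguments of `batch_norm_layer` could be combined; the layer geometry follows the 3D test config in this change, while the parameter names passed to `mean_var_names` are made up for the example:

```python
from paddle.trainer_config_helpers import *

settings(batch_size=1000, learning_rate=1e-4)

# 3D input: channels=1, depth=3, height=6, width=20  ->  size = 360
data3D = data_layer(name='data3D', size=120 * 3, width=20, height=6, depth=3)

# img3D switches the layer to parse_image3d; mean_var_names pins the
# running mean/variance to explicitly named parameters (illustrative names).
bn3D = batch_norm_layer(
    data3D,
    num_channels=1,
    img3D=True,
    mean_var_names=['bn3D_running_mean', 'bn3D_running_var'])
outputs(bn3D)
```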
diff --git a/python/paddle/trainer_config_helpers/tests/configs/file_list.sh b/python/paddle/trainer_config_helpers/tests/configs/file_list.sh
index df872a90ff388f0d96cef44763dbd076bc768ab9..8a204a96f3ef57673cef65306d0bf8e8c3409751 100755
--- a/python/paddle/trainer_config_helpers/tests/configs/file_list.sh
+++ b/python/paddle/trainer_config_helpers/tests/configs/file_list.sh
@@ -10,6 +10,6 @@ test_prelu_layer test_row_conv test_detection_output_layer test_multibox_loss_la
 test_recursive_topology test_gated_unit_layer test_clip_layer test_row_l2_norm_layer
 test_kmax_seq_socre_layer test_sub_nested_seq_select_layer test_scale_shift_layer
 test_seq_slice_layer test_cross_entropy_over_beam test_pooling3D_layer
-test_conv3d_layer test_deconv3d_layer)
+test_conv3d_layer test_deconv3d_layer test_BatchNorm3D)
 
 export whole_configs=(test_split_datasource)
diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/img_layers.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/img_layers.protostr
index 1a577b8d9b1e1915236ba6afcfa97040d70c707a..5ddf6052df021b055390a42c25ce6c0d650e4aee 100644
--- a/python/paddle/trainer_config_helpers/tests/configs/protostr/img_layers.protostr
+++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/img_layers.protostr
@@ -62,6 +62,7 @@ layers {
   moving_average_fraction: 0.9
   height: 227
   width: 227
+  depth: 1
 }
 layers {
   name: "__crmnorm_0__"
diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/img_trans_layers.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/img_trans_layers.protostr
index 2818389b16cca75f5030b75fc4de8c89c06c5e02..c0252b945b4c7fd6b4dad8770e3e1dccb88df28a 100644
--- a/python/paddle/trainer_config_helpers/tests/configs/protostr/img_trans_layers.protostr
+++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/img_trans_layers.protostr
@@ -62,6 +62,7 @@ layers {
   moving_average_fraction: 0.9
   height: 256
   width: 256
+  depth: 1
 }
 layers {
   name: "__crmnorm_0__"
diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_BatchNorm3D.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_BatchNorm3D.protostr
new file mode 100644
index 0000000000000000000000000000000000000000..832ed24a31dd2bedba9a4fce77d7a088d1796fdb
--- /dev/null
+++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_BatchNorm3D.protostr
@@ -0,0 +1,92 @@
+type: "nn"
+layers {
+  name: "data3D"
+  type: "data"
+  size: 360
+  active_type: ""
+  height: 6
+  width: 20
+  depth: 3
+}
+layers {
+  name: "__batch_norm_0__"
+  type: "batch_norm"
+  size: 360
+  active_type: "relu"
+  inputs {
+    input_layer_name: "data3D"
+    input_parameter_name: "___batch_norm_0__.w0"
+    image_conf {
+      channels: 1
+      img_size: 20
+      img_size_y: 6
+      img_size_z: 3
+    }
+  }
+  inputs {
+    input_layer_name: "data3D"
+    input_parameter_name: "___batch_norm_0__.w1"
+  }
+  inputs {
+    input_layer_name: "data3D"
+    input_parameter_name: "___batch_norm_0__.w2"
+  }
+  bias_parameter_name: "___batch_norm_0__.wbias"
+  moving_average_fraction: 0.9
+  height: 6
+  width: 20
+  depth: 3
+}
+parameters {
+  name: "___batch_norm_0__.w0"
+  size: 1
+  initial_mean: 1.0
+  initial_std: 0.0
+  initial_strategy: 0
+  initial_smart: false
+}
+parameters {
+  name: "___batch_norm_0__.w1"
+  size: 1
+  initial_mean: 0.0
+  initial_std: 0.0
+  dims: 1
+  dims: 1
+  initial_strategy: 0
+  initial_smart: false
+  is_static: true
+  is_shared: true
+}
+parameters {
+  name: "___batch_norm_0__.w2"
+  size: 1
+  initial_mean: 0.0
+  initial_std: 0.0
+  dims: 1
+  dims: 1
+  initial_strategy: 0
+  initial_smart: false
+  is_static: true
+  is_shared: true
+}
+parameters {
+  name: "___batch_norm_0__.wbias"
+  size: 1
+  initial_mean: 0.0
+  initial_std: 0.0
+  dims: 1
+  dims: 1
+  initial_strategy: 0
+  initial_smart: false
+}
+input_layer_names: "data3D"
+output_layer_names: "__batch_norm_0__"
+sub_models {
+  name: "root"
+  layer_names: "data3D"
+  layer_names: "__batch_norm_0__"
+  input_layer_names: "data3D"
+  output_layer_names: "__batch_norm_0__"
+  is_recurrent_layer_group: false
+}
+
diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_bi_grumemory.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_bi_grumemory.protostr
index b110e91498ce7d112987714bd769868179141c54..8a1399efad0ff339e35f69400ac654a4787a6018 100644
--- a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_bi_grumemory.protostr
+++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_bi_grumemory.protostr
@@ -74,6 +74,9 @@ layers {
   inputs {
     input_layer_name: "__bidirectional_gru_0___bw"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 parameters {
   name: "___bidirectional_gru_0___fw_transform.w0"
diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_recursive_topology.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_recursive_topology.protostr
index 8133aa9c8d3e7c6843d1b27b70e87d394a1e0e47..046037936a6d85f54095c65f206e468aa69065d7 100644
--- a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_recursive_topology.protostr
+++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_recursive_topology.protostr
@@ -16,6 +16,9 @@ layers {
   inputs {
     input_layer_name: "data"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__addto_1__"
@@ -28,6 +31,9 @@ layers {
   inputs {
     input_layer_name: "__addto_0__"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__addto_2__"
@@ -40,6 +46,9 @@ layers {
   inputs {
     input_layer_name: "__addto_1__"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__addto_3__"
@@ -52,6 +61,9 @@ layers {
   inputs {
     input_layer_name: "__addto_2__"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__addto_4__"
@@ -64,6 +76,9 @@ layers {
   inputs {
     input_layer_name: "__addto_3__"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__addto_5__"
@@ -76,6 +91,9 @@ layers {
   inputs {
     input_layer_name: "__addto_4__"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__addto_6__"
@@ -88,6 +106,9 @@ layers {
   inputs {
     input_layer_name: "__addto_5__"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__addto_7__"
@@ -100,6 +121,9 @@ layers {
   inputs {
     input_layer_name: "__addto_6__"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__addto_8__"
@@ -112,6 +136,9 @@ layers {
   inputs {
     input_layer_name: "__addto_7__"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__addto_9__"
@@ -124,6 +151,9 @@ layers {
   inputs {
     input_layer_name: "__addto_8__"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__addto_10__"
@@ -136,6 +166,9 @@ layers {
   inputs {
     input_layer_name: "__addto_9__"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__addto_11__"
@@ -148,6 +181,9 @@ layers {
   inputs {
     input_layer_name: "__addto_10__"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__addto_12__"
@@ -160,6 +196,9 @@ layers {
   inputs {
     input_layer_name: "__addto_11__"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__addto_13__"
@@ -172,6 +211,9 @@ layers {
   inputs {
     input_layer_name: "__addto_12__"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__addto_14__"
@@ -184,6 +226,9 @@ layers {
   inputs {
     input_layer_name: "__addto_13__"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__addto_15__"
@@ -196,6 +241,9 @@ layers {
   inputs {
     input_layer_name: "__addto_14__"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__addto_16__"
@@ -208,6 +256,9 @@ layers {
   inputs {
     input_layer_name: "__addto_15__"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__addto_17__"
@@ -220,6 +271,9 @@ layers {
   inputs {
     input_layer_name: "__addto_16__"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__addto_18__"
@@ -232,6 +286,9 @@ layers {
   inputs {
     input_layer_name: "__addto_17__"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__addto_19__"
@@ -244,6 +301,9 @@ layers {
   inputs {
     input_layer_name: "__addto_18__"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__addto_20__"
@@ -256,6 +316,9 @@ layers {
   inputs {
     input_layer_name: "__addto_19__"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__addto_21__"
@@ -268,6 +331,9 @@ layers {
   inputs {
     input_layer_name: "__addto_20__"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__addto_22__"
@@ -280,6 +346,9 @@ layers {
   inputs {
     input_layer_name: "__addto_21__"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__addto_23__"
@@ -292,6 +361,9 @@ layers {
   inputs {
     input_layer_name: "__addto_22__"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__addto_24__"
@@ -304,6 +376,9 @@ layers {
   inputs {
     input_layer_name: "__addto_23__"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__addto_25__"
@@ -316,6 +391,9 @@ layers {
   inputs {
     input_layer_name: "__addto_24__"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__addto_26__"
@@ -328,6 +406,9 @@ layers {
   inputs {
     input_layer_name: "__addto_25__"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__addto_27__"
@@ -340,6 +421,9 @@ layers {
   inputs {
     input_layer_name: "__addto_26__"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__addto_28__"
@@ -352,6 +436,9 @@ layers {
   inputs {
     input_layer_name: "__addto_27__"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__addto_29__"
@@ -364,6 +451,9 @@ layers {
   inputs {
     input_layer_name: "__addto_28__"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__addto_30__"
@@ -376,6 +466,9 @@ layers {
   inputs {
     input_layer_name: "__addto_29__"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__addto_31__"
@@ -388,6 +481,9 @@ layers {
   inputs {
     input_layer_name: "__addto_30__"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__fc_layer_0__"
diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/util_layers.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/util_layers.protostr
index d0ad388165007b8f96f059e5b003c52f756383e5..7a2f3eab38808a031c27cf7ab9d6273952e389eb 100644
--- a/python/paddle/trainer_config_helpers/tests/configs/protostr/util_layers.protostr
+++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/util_layers.protostr
@@ -22,6 +22,9 @@ layers {
   inputs {
     input_layer_name: "b"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__concat_0__"
@@ -34,6 +37,9 @@ layers {
   inputs {
     input_layer_name: "b"
   }
+  height: 0
+  width: 0
+  depth: 1
 }
 layers {
   name: "__concat_1__"
diff --git a/python/paddle/trainer_config_helpers/tests/configs/test_BatchNorm3D.py b/python/paddle/trainer_config_helpers/tests/configs/test_BatchNorm3D.py
new file mode 100644
index 0000000000000000000000000000000000000000..a991b22252ba10eed895efd931108c2d8b0e52f1
--- /dev/null
+++ b/python/paddle/trainer_config_helpers/tests/configs/test_BatchNorm3D.py
@@ -0,0 +1,11 @@
+from paddle.trainer_config_helpers import *
+
+settings(batch_size=1000, learning_rate=1e-4)
+
+#data = data_layer(name='data', size=180, width=30, height=6)
+#batchNorm = batch_norm_layer(data, num_channels=1)
+#outputs(batchNorm)
+
+data3D = data_layer(name='data3D', size=120 * 3, width=20, height=6, depth=3)
+batchNorm3D = batch_norm_layer(data3D, num_channels=1, img3D=True)
+outputs(batchNorm3D)
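The `size` argument in the config above is consistent with the 3D geometry that ends up in the generated protostr (`size: 360` for channels=1, depth=3, height=6, width=20); a one-line sanity check of that arithmetic:

```python
channels, depth, height, width = 1, 3, 6, 20
assert channels * depth * height * width == 120 * 3 == 360
```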
diff --git a/python/paddle/v2/framework/op.py b/python/paddle/v2/framework/op.py
index 78c64e261b7eea07c902743ff86a552c4b7dc355..8109227828bafda85d3a556dda928acd7c1fc94c 100644
--- a/python/paddle/v2/framework/op.py
+++ b/python/paddle/v2/framework/op.py
@@ -142,8 +142,8 @@ def create_op_creation_method(op_proto):
     return OpInfo(
         method=__impl__,
         name=op_proto.type,
-        inputs=[var.name for var in op_proto.inputs],
-        outputs=[var.name for var in op_proto.outputs],
+        inputs=[(var.name, var.duplicable) for var in op_proto.inputs],
+        outputs=[(var.name, var.duplicable) for var in op_proto.outputs],
         attrs=[attr.name for attr in op_proto.attrs])
 
 
@@ -180,9 +180,15 @@ class OperatorFactory(object):
         return self.op_methods.get(t)
 
     def get_op_input_names(self, type):
+        return map(lambda x: x[0], self.get_op_info(type).inputs)
+
+    def get_op_inputs(self, type):
         return self.get_op_info(type).inputs
 
     def get_op_output_names(self, type):
+        return map(lambda x: x[0], self.get_op_info(type).outputs)
+
+    def get_op_outputs(self, type):
         return self.get_op_info(type).outputs
 
     def get_op_attr_names(self, type):
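With the duplicable flag carried along, `get_op_inputs`/`get_op_outputs` return `(name, duplicable)` pairs while the existing `*_names` accessors keep returning bare names. Assuming the sum operator registered above (duplicable input `X`, single output `Out`), the expected values would be roughly:

```python
from paddle.v2.framework.op import Operator

# "X" is marked AsDuplicable() in SumOpMaker, "Out" is not.
assert Operator.get_op_input_names("sum") == ["X"]
assert Operator.get_op_inputs("sum") == [("X", True)]
assert Operator.get_op_output_names("sum") == ["Out"]
assert Operator.get_op_outputs("sum") == [("Out", False)]
```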
diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt
index ef910f939be0b9d3cb5e6d49e69a00daa191b1c6..2117fdf0d58520a008d2bd01d56d96dd248be025 100644
--- a/python/paddle/v2/framework/tests/CMakeLists.txt
+++ b/python/paddle/v2/framework/tests/CMakeLists.txt
@@ -33,5 +33,6 @@ py_test(test_sgd_op SRCS test_sgd_op.py)
 py_test(test_gradient_checker SRCS test_gradient_checker.py)
 py_test(test_lookup_table SRCS test_lookup_table.py)
 py_test(test_scale_and_identity_op SRCS test_scale_and_identity_op.py)
+py_test(test_sum_op SRCS test_sum_op.py)
 py_test(mnist SRCS mnist.py)
 py_test(test_squared_l2_distance_op SRCS test_squared_l2_distance_op.py)
diff --git a/python/paddle/v2/framework/tests/op_test.py b/python/paddle/v2/framework/tests/op_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a6a5dca4c4ddc1399d80e491e4072f24707c01e
--- /dev/null
+++ b/python/paddle/v2/framework/tests/op_test.py
@@ -0,0 +1,275 @@
+import unittest
+import numpy as np
+import itertools
+import paddle.v2.framework.core as core
+from paddle.v2.framework.op import Operator
+
+
+def grad_var_name(var_name):
+    return var_name + "@GRAD"
+
+
+def create_op(scope, op_type, inputs, outputs, attrs=None):
+    kwargs = dict()
+
+    for in_name, in_dup in Operator.get_op_inputs(op_type):
+        if in_name in inputs:
+            kwargs[in_name] = []
+            if in_dup:
+                sub_in = inputs[in_name]
+                for sub_in_name in sub_in:
+                    var = scope.new_var(sub_in_name)
+                    kwargs[in_name].append(sub_in_name)
+            else:
+                var = scope.new_var(in_name)
+                kwargs[in_name].append(in_name)
+
+    for out_name, out_dup in Operator.get_op_outputs(op_type):
+        if out_name in outputs:
+            kwargs[out_name] = []
+            if out_dup:
+                sub_in = outputs[out_name]
+                for sub_in_name in sub_in:
+                    var = scope.new_var(sub_in_name)
+                    kwargs[out_name].append(sub_in_name)
+            else:
+                var = scope.new_var(out_name)
+                kwargs[out_name].append(out_name)
+
+    for attr_name in Operator.get_op_attr_names(op_type):
+        kwargs[attr_name] = attrs[attr_name]
+    return Operator(op_type, **kwargs)
+
+
+def set_input(scope, op, inputs, place):
+    for in_name, in_dup in Operator.get_op_inputs(op.type()):
+        if in_name in inputs:
+            if in_dup:
+                sub_in = inputs[in_name]
+                for sub_in_name in sub_in:
+                    var = scope.find_var(sub_in_name)
+                    tensor = var.get_tensor()
+                    arr = sub_in[sub_in_name]
+                    tensor.set_dims(arr.shape)
+                    tensor.set(arr, place)
+            else:
+                var = scope.find_var(in_name)
+                tensor = var.get_tensor()
+                arr = inputs[in_name]
+                tensor.set_dims(arr.shape)
+                tensor.set(arr, place)
+
+
+def set_output_grad(scope, op, outputs, place):
+    for out_name, out_dup in Operator.get_op_outputs(op.type()):
+        if out_name in outputs:
+            if out_dup:
+                sub_out = outputs[out_name]
+                for sub_out_name in sub_out:
+                    out_tensor = scope.find_var(sub_out_name).get_tensor()
+                    grad_tensor = scope.new_var(grad_var_name(
+                        sub_out_name)).get_tensor()
+                    grad_tensor.set_dims(out_tensor.shape())
+                    data = np.ones(out_tensor.shape(), dtype=np.float32)
+                    grad_tensor.set(data, place)
+            else:
+                out_tensor = scope.find_var(out_name).get_tensor()
+                grad_tensor = scope.new_var(grad_var_name(out_name)).get_tensor(
+                )
+                grad_tensor.set_dims(out_tensor.shape())
+                data = np.ones(out_tensor.shape(), dtype=np.float32)
+                grad_tensor.set(data, place)
+
+
+def get_numeric_gradient(scope,
+                         op,
+                         inputs,
+                         input_to_check,
+                         output_name,
+                         delta=0.005,
+                         in_place=False):
+
+    set_input(scope, op, inputs, core.CPUPlace())
+    op.infer_shape(scope)
+
+    tensor_to_check = scope.find_var(input_to_check).get_tensor()
+
+    def product(dim):
+        return reduce(lambda a, b: a * b, dim, 1)
+
+    ctx = core.DeviceContext.create(core.CPUPlace())
+
+    def get_output():
+        op.run(scope, ctx)
+        return np.array(scope.find_var(output_name).get_tensor()).sum()
+
+    tensor_to_check = scope.find_var(input_to_check).get_tensor()
+    tensor_size = product(tensor_to_check.get_dims())
+    gradient_flat = np.zeros(shape=(tensor_size, ), dtype='float32')
+    # we only compute gradient of one element each time.
+    # we use a for loop to compute the gradient of every element.
+    for i in xrange(tensor_size):
+        if in_place:
+            set_input(scope, op, inputs, core.CPUPlace())
+
+        # get one input element through its index i.
+        origin = tensor_to_check.get_float_element(i)
+        # add delta to it, run op and then get the sum of the result tensor.
+        x_pos = origin + delta
+        tensor_to_check.set_float_element(i, x_pos)
+        y_pos = get_output()
+
+        if in_place:
+            set_input(scope, op, inputs, core.CPUPlace())
+
+        x_neg = origin - delta
+        tensor_to_check.set_float_element(i, x_neg)
+        y_neg = get_output()
+
+        tensor_to_check.set_float_element(i, origin)
+        gradient_flat[i] = (y_pos - y_neg) / delta / 2
+
+    return gradient_flat.reshape(tensor_to_check.get_dims())
+
+
+def get_backward_op(scope, op, no_grad_set):
+    backward_op = core.Operator.backward(op, no_grad_set)
+    for input in backward_op.input_vars():
+        var = scope.new_var(input)
+        var.get_tensor()
+    for output in backward_op.output_vars():
+        var = scope.new_var(output)
+        var.get_tensor()
+    return backward_op
+
+
+def get_gradient(scope, op, inputs, outputs, grad_name, place,
+                 no_grad_set=None):
+    ctx = core.DeviceContext.create(place)
+
+    set_input(scope, op, inputs, place)
+
+    op.infer_shape(scope)
+    op.run(scope, ctx)
+
+    if no_grad_set is None:
+        no_grad_set = set()
+
+    backward_op = get_backward_op(scope, op, no_grad_set)
+    set_output_grad(scope, op, outputs, place)
+
+    backward_op.infer_shape(scope)
+    backward_op.run(scope, ctx)
+
+    out = np.array(scope.find_var(grad_name).get_tensor())
+    return out
+
+
+class OpTest(unittest.TestCase):
+    def check_output_with_place(self, place):
+        self.scope = core.Scope()
+        self.op = create_op(self.scope, self.op_type, self.inputs, self.outputs)
+        if isinstance(place, core.GPUPlace) and not self.op.support_gpu():
+            return
+        set_input(self.scope, self.op, self.inputs, place)
+        self.op.infer_shape(self.scope)
+        ctx = core.DeviceContext.create(place)
+        self.op.run(self.scope, ctx)
+
+        for out_name, out_dup in Operator.get_op_outputs(self.op.type()):
+            if out_dup:
+                sub_out = self.outputs[out_name]
+                for sub_out_name in sub_out:
+                    actual = np.array(
+                        self.scope.find_var(sub_out_name).get_tensor())
+                    expect = sub_out[sub_out_name]
+                    self.assertTrue(
+                        np.allclose(
+                            actual, expect, atol=1e-05),
+                        "output name: " + out_name + "has diff")
+            else:
+                actual = np.array(self.scope.find_var(out_name).get_tensor())
+                expect = self.outputs[out_name]
+                self.assertTrue(
+                    np.allclose(
+                        actual, expect, atol=1e-05),
+                    "output name: " + out_name + "has diff")
+
+    def check_output(self):
+        places = [core.CPUPlace()]
+        if core.is_compile_gpu():
+            places.append(core.GPUPlace(0))
+        for place in places:
+            self.check_output_with_place(place)
+
+    def __assert_is_close(self, numeric_grads, analytic_grads, names,
+                          max_relative_error, msg_prefix):
+
+        for a, b, name in itertools.izip(numeric_grads, analytic_grads, names):
+            abs_a = np.abs(a)
+            abs_a[abs_a < 1e-3] = 1
+
+            diff_mat = np.abs(a - b) / abs_a
+            max_diff = np.max(diff_mat)
+
+            def err_msg():
+                offset = np.argmax(diff_mat > max_relative_error)
+                return ("%s Variable %s max gradient diff %f over limit %f, "
+                        "the first error element is %d" %
+                        (msg_prefix, name, max_diff, max_relative_error, offset))
+
+            self.assertLessEqual(max_diff, max_relative_error, err_msg())
+
+    def check_grad(self,
+                   inputs_to_check,
+                   output_name,
+                   no_grad_set=None,
+                   in_place=False,
+                   max_relative_error=0.005):
+        self.scope = core.Scope()
+        self.op = create_op(self.scope, self.op_type, self.inputs, self.outputs)
+        if no_grad_set is None:
+            no_grad_set = set()
+
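+        # Numeric gradients from finite differences serve as the reference for
+        # the analytic gradients computed by the backward op.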
+        numeric_grads = [
+            get_numeric_gradient(
+                self.scope,
+                self.op,
+                self.inputs,
+                input_to_check,
+                output_name,
+                in_place=in_place) for input_to_check in inputs_to_check
+        ]
+        grad_names = [
+            grad_var_name(input_to_check) for input_to_check in inputs_to_check
+        ]
+
+        cpu_place = core.CPUPlace()
+        cpu_analytic_grads = [
+            get_gradient(self.scope, self.op, self.inputs, self.outputs,
+                         grad_name, cpu_place, no_grad_set)
+            for grad_name in grad_names
+        ]
+
+        self.__assert_is_close(numeric_grads, cpu_analytic_grads, grad_names,
+                               max_relative_error,
+                               "Gradient Check On %s" % str(cpu_place))
+
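+        # Repeat the check on the GPU when available, and additionally require
+        # the CPU and GPU analytic gradients to agree with each other.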
+        if core.is_compile_gpu() and self.op.support_gpu():
+            gpu_place = core.GPUPlace(0)
+            gpu_analytic_grads = [
+                get_gradient(self.scope, self.op, self.inputs, self.outputs,
+                             grad_name, gpu_place, no_grad_set)
+                for grad_name in grad_names
+            ]
+
+            self.__assert_is_close(numeric_grads, gpu_analytic_grads,
+                                   grad_names, max_relative_error,
+                                   "Gradient Check On %s" % str(gpu_place))
+
+            for c_grad, g_grad, name in itertools.izip(
+                    cpu_analytic_grads, gpu_analytic_grads, grad_names):
+                self.assertTrue(
+                    np.allclose(
+                        c_grad, g_grad, atol=1e-4),
+                    "output name: " + name + " has diff")
diff --git a/python/paddle/v2/framework/tests/test_cross_entropy_op.py b/python/paddle/v2/framework/tests/test_cross_entropy_op.py
index d4277f2a42ce2e66e37405ccd3b2ee444d403d1a..fb6a440e23c26d1766bdf1fc5f24217afe1150f8 100644
--- a/python/paddle/v2/framework/tests/test_cross_entropy_op.py
+++ b/python/paddle/v2/framework/tests/test_cross_entropy_op.py
@@ -1,36 +1,27 @@
 import unittest
 import numpy
-from op_test_util import OpTestMeta
-from gradient_checker import GradientChecker, create_op
+from op_test import OpTest
 
 
-class TestCrossEntropy(unittest.TestCase):
-    __metaclass__ = OpTestMeta
-
+class TestCrossEntropy(OpTest):
     def setUp(self):
-        self.type = "onehot_cross_entropy"
+        self.op_type = "onehot_cross_entropy"
         batch_size = 30
         class_num = 10
-        X = numpy.random.random((batch_size, class_num)).astype("float32")
-        label = 5 * numpy.ones(batch_size).astype("int32")
+        X = numpy.random.uniform(0.1, 1.0,
+                                 [batch_size, class_num]).astype("float32")
+        label = (class_num / 2) * numpy.ones(batch_size).astype("int32")
         self.inputs = {'X': X, 'label': label}
         Y = []
         for i in range(0, batch_size):
             Y.append(-numpy.log(X[i][label[i]]))
         self.outputs = {'Y': numpy.array(Y).astype("float32")}
 
+    def test_check_output(self):
+        self.check_output()
 
-class CrossEntropyGradOpTest(GradientChecker):
     def test_check_grad(self):
-        op = create_op("onehot_cross_entropy")
-        batch_size = 30
-        class_num = 10
-        inputs = {
-            "X": numpy.random.uniform(
-                0.1, 1.0, [batch_size, class_num]).astype("float32"),
-            "label": (class_num / 2) * numpy.ones(batch_size).astype("int32")
-        }
-        self.check_grad(op, inputs, set("X"), "Y")
+        self.check_grad(["X"], "Y")
 
 
 if __name__ == "__main__":
diff --git a/python/paddle/v2/framework/tests/test_lookup_table.py b/python/paddle/v2/framework/tests/test_lookup_table.py
index 19eb464baa555fb67a994f3cfb4d3ed628367c73..4b7ce92c0f0492a73c158378299933a0b329948b 100644
--- a/python/paddle/v2/framework/tests/test_lookup_table.py
+++ b/python/paddle/v2/framework/tests/test_lookup_table.py
@@ -4,7 +4,7 @@ from op_test_util import OpTestMeta
 from gradient_checker import GradientChecker, create_op
 
 
-class TestSigmoidOp(unittest.TestCase):
+class TestLookupTableOp(unittest.TestCase):
     __metaclass__ = OpTestMeta
 
     def setUp(self):
@@ -15,7 +15,7 @@ class TestSigmoidOp(unittest.TestCase):
         self.outputs = {'Out': table[ids]}
 
 
-class TestSigmoidGradOp(GradientChecker):
+class TestLookupTableGradOp(GradientChecker):
     def test_grad(self):
         op = create_op('lookup_table')
         table = np.random.random((17, 31)).astype('float32')
diff --git a/python/paddle/v2/framework/tests/test_sigmoid_op.py b/python/paddle/v2/framework/tests/test_sigmoid_op.py
index 273c2e5ab1a84d12621fe9568c4cf22073b6aed4..2316e49eff7bb1cdb53acb3889a6ef05060b59f3 100644
--- a/python/paddle/v2/framework/tests/test_sigmoid_op.py
+++ b/python/paddle/v2/framework/tests/test_sigmoid_op.py
@@ -1,27 +1,21 @@
 import unittest
 import numpy as np
-from op_test_util import OpTestMeta
-from gradient_checker import GradientChecker, create_op
+from op_test import OpTest
 
 
-class TestSigmoidOp(unittest.TestCase):
-    __metaclass__ = OpTestMeta
-
+class TestSigmoid(OpTest):
     def setUp(self):
-        self.type = "sigmoid"
-        self.inputs = {'X': np.random.random((15, 31)).astype("float32")}
+        self.op_type = "sigmoid"
+        self.inputs = {
+            'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32")
+        }
         self.outputs = {'Y': 1 / (1 + np.exp(-self.inputs['X']))}
 
+    def test_check_output(self):
+        self.check_output()
 
-class TestSigmoidGradOp(GradientChecker):
-    def test_grad(self):
-        op = create_op("sigmoid")
-        inputs = {"X": np.random.uniform(0.1, 1, [11, 17]).astype("float32")}
-        # compare gpu and cpu results for backward op.
-        # this test will be skiped if only compiling CPU version.
-        self.compare_grad(op, inputs)
-        # check gradients 
-        self.check_grad(op, inputs, set("X"), "Y", max_relative_error=0.007)
+    def test_check_grad(self):
+        self.check_grad(["X"], "Y", max_relative_error=0.007)
 
 
 if __name__ == '__main__':
diff --git a/python/paddle/v2/framework/tests/test_sum_op.py b/python/paddle/v2/framework/tests/test_sum_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..66417d70e81186465e6f59a17fb62255afeddea5
--- /dev/null
+++ b/python/paddle/v2/framework/tests/test_sum_op.py
@@ -0,0 +1,24 @@
+import unittest
+import numpy as np
+from op_test import OpTest
+
+
+class TestSumOp(OpTest):
+    def setUp(self):
+        self.op_type = "sum"
+        x0 = np.random.random((3, 4)).astype('float32')
+        x1 = np.random.random((3, 4)).astype('float32')
+        x2 = np.random.random((3, 4)).astype('float32')
+        self.inputs = {"X": {"x0": x0, "x1": x1, "x2": x2}}
+        y = x0 + x1 + x2
+        self.outputs = {'Out': y}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
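+        # For sum, dOut/dx_i is the identity for every input, so checking x0
+        # is representative of the others.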
+        self.check_grad(["x0"], "Out")
+
+
+if __name__ == '__main__':
+    unittest.main()