diff --git a/paddle/function/CMakeLists.txt b/paddle/function/CMakeLists.txt
index 1b3068b8ffcad511f69666c77f10baf5506c3413..9b2779b42cad324253dadf27dbff20fd8e8c8e16 100644
--- a/paddle/function/CMakeLists.txt
+++ b/paddle/function/CMakeLists.txt
@@ -45,7 +45,7 @@ if(WITH_GPU)
     add_simple_unittest(BlockExpandOpTest)
     add_simple_unittest(CropOpTest)
     add_simple_unittest(SwitchOpTest)
-    add_simple_unittest(MulValueOpTest)
+    add_simple_unittest(ScaleSubRegionOpTest)
 endif()
 
 add_simple_unittest(Im2ColTest)
diff --git a/paddle/function/ScaleSubRegionOp.cpp b/paddle/function/ScaleSubRegionOp.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..a080505d7df83a6c0a9d88fbcb7863fc0e1f7b21
--- /dev/null
+++ b/paddle/function/ScaleSubRegionOp.cpp
@@ -0,0 +1,155 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "ScaleSubRegionOp.h"
+#include "paddle/function/TensorShape.h"
+
+namespace paddle {
+
+template <>
+void ScaleSubRegion<DEVICE_TYPE_CPU>(real* outputs,
+                                     const real* inputs,
+                                     const real* indices,
+                                     const TensorShape shape,
+                                     const FuncConfig& conf) {
+  real value = conf.get<real>("value");
+
+  int number = shape[0];
+  int channel = shape[1];
+  int height = shape[2];
+  int width = shape[3];
+
+  memcpy(outputs, inputs, number * channel * height * width * sizeof(real));
+
+  for (int n = 0; n < number; ++n) {
+    // indices start from 1
+    int offset = n * 6;
+    for (int c = indices[offset] - 1; c < indices[offset + 1]; ++c) {
+      for (int h = indices[offset + 2] - 1; h < indices[offset + 3]; ++h) {
+        for (int w = indices[offset + 4] - 1; w < indices[offset + 5]; ++w) {
+          int idx = ((n * channel + c) * height + h) * width + w;
+          outputs[idx] *= value;
+        }
+      }
+    }
+  }
+}
+
+template <>
+void ScaleSubRegionGrad<DEVICE_TYPE_CPU>(const real* inGrad,
+                                         real* outGrad,
+                                         const real* indices,
+                                         const TensorShape shape,
+                                         const FuncConfig& conf) {
+  real value = conf.get<real>("value");
+
+  int number = shape[0];
+  int channel = shape[1];
+  int height = shape[2];
+  int width = shape[3];
+
+  for (int n = 0; n < number; ++n) {
+    for (int c = 0; c < channel; ++c) {
+      for (int h = 0; h < height; ++h) {
+        for (int w = 0; w < width; ++w) {
+          int idx = ((n * channel + c) * height + h) * width + w;
+          int offset = n * 6;
+          if (c >= (indices[offset] - 1) && c <= (indices[offset + 1] - 1) &&
+              h >= (indices[offset + 2] - 1) &&
+              h <= (indices[offset + 3] - 1) &&
+              w >= (indices[offset + 4] - 1) &&
+              w <= (indices[offset + 5] - 1)) {
+            outGrad[idx] += inGrad[idx] * value;
+          } else {
+            outGrad[idx] += inGrad[idx];
+          }
+        }
+      }
+    }
+  }
+}
+
+/**
+ * \brief For each instance, ScaleSubRegion can be used to multiply a value to
+ *        a specified sub continuous region. By providing start index and end
+ *        index for C/H/W, you can specify the location and shape of the region.
+ *
+ * Argument in this Function:
+ * \param inputs    A 4-D tensor with shape [N, C, H, W], only one input.
+ * \param indices   A 2-D tensor with shape [N, 6], indicates the sub region.
+ * \param outputs   A 4-D tensor with same shape as inputs, output value.
+ */
+template <DeviceType Device>
+class ScaleSubRegionFunc : public FunctionBase {
+public:
+  void init(const FuncConfig& config) override { conf_ = config; }
+
+  void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
+    CHECK_EQ(2UL, inputs.size());
+    CHECK_EQ(1UL, outputs.size());
+    CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO);
+
+    TensorShape shape = inputs[0].shape();
+
+    ScaleSubRegion<Device>(outputs[0].data<real>(),
+                           inputs[0].data<real>(),
+                           inputs[1].data<real>(),
+                           shape,
+                           conf_);
+  }
+
+private:
+  FuncConfig conf_;
+};
+
+/**
+ * \brief The backward propagation of ScaleSubRegion Function.
+ *
+ * Argument in this Function:
+ * \param inputs  A 4-D tensor with shape [N, C, H, W], output gradient.
+ * \param indices A 2-D tensor with shape [N, 6], indicates the sub region.
+ * \param outputs A 4-D tensor with shape [N, C, H, W], gradient of input value.
+ */
+
+template <DeviceType Device>
+class ScaleSubRegionGradFunc : public FunctionBase {
+public:
+  void init(const FuncConfig& config) override { conf_ = config; }
+
+  void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
+    CHECK_EQ(2UL, inputs.size());
+    CHECK_EQ(1UL, outputs.size());
+    CHECK_EQ(outputs[0].getArgType(), ADD_TO);
+
+    TensorShape shape = inputs[0].shape();
+
+    ScaleSubRegionGrad<Device>(inputs[0].data<real>(),
+                               outputs[0].data<real>(),
+                               inputs[1].data<real>(),
+                               shape,
+                               conf_);
+  }
+
+private:
+  FuncConfig conf_;
+};
+
+REGISTER_TYPED_FUNC(ScaleSubRegion, CPU, ScaleSubRegionFunc);
+REGISTER_TYPED_FUNC(ScaleSubRegionGrad, CPU, ScaleSubRegionGradFunc);
+#ifdef PADDLE_WITH_CUDA
+REGISTER_TYPED_FUNC(ScaleSubRegion, GPU, ScaleSubRegionFunc);
+REGISTER_TYPED_FUNC(ScaleSubRegionGrad, GPU, ScaleSubRegionGradFunc);
+#endif
+
+}  // namespace paddle
diff --git a/paddle/function/ScaleSubRegionOp.h b/paddle/function/ScaleSubRegionOp.h
new file mode 100644
index 0000000000000000000000000000000000000000..0480c8577f3fbf3bc9e94b635df96a31b103e9e3
--- /dev/null
+++ b/paddle/function/ScaleSubRegionOp.h
@@ -0,0 +1,55 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "Function.h"
+
+namespace paddle {
+
+/**
+ * \brief Function to multiply a value to values in specified sub continuous
+ *        region. Indices must be provided to indcate the location and shape of
+ *        the region and the multiplied value is passed by configure variable.
+ *
+ *
+ * \param[out] outputs  Output value.
+ * \param[in]  inputs   Input data which contains NCHW information.
+ * \param[in]  indices  Indices data to indcate the sub region.
+ * \param[in]  shape    Tensor shape of input value.
+ * \param[in]  conf     Configure variable which contains the multiplied value.
+ */
+template <DeviceType Device>
+void ScaleSubRegion(real* outputs,
+                    const real* inputs,
+                    const real* indices,
+                    const TensorShape shape,
+                    const FuncConfig& conf);
+
+/**
+ * \brief Backward propagation function of ScaleSubRegion.
+ *
+ * \param[out] inGrad   Gradients of previous layer.
+ * \param[in]  outGrad  Output gradient.
+ * \param[in]  indices  Indices data.
+ * \param[in]  shape    The Shape of input tensor.
+ * \param[in]  conf     Configure variable.
+ */
+template <DeviceType Device>
+void ScaleSubRegionGrad(const real* inGrad,
+                        real* outGrad,
+                        const real* indices,
+                        const TensorShape shape,
+                        const FuncConfig& conf);
+}  // namespace paddle
diff --git a/paddle/function/ScaleSubRegionOpGpu.cu b/paddle/function/ScaleSubRegionOpGpu.cu
new file mode 100644
index 0000000000000000000000000000000000000000..8aae2e44c3fdc8b516e66ecfd2e04f466a17dde9
--- /dev/null
+++ b/paddle/function/ScaleSubRegionOpGpu.cu
@@ -0,0 +1,116 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "ScaleSubRegionOp.h"
+#include "hl_base.h"
+
+namespace paddle {
+
+__global__ void KeScaleSubRegion(real* outputs,
+                                 const real* inputs,
+                                 const real* indices,
+                                 real value,
+                                 int channel,
+                                 int height,
+                                 int width,
+                                 int nthreads) {
+  const int idx = threadIdx.x + blockIdx.x * blockDim.x;
+  if (idx < nthreads) {
+    const int w = idx % width;
+    const int h = (idx / width) % height;
+    const int c = (idx / width / height) % channel;
+    const int n = idx / width / height / channel;
+
+    const int offset = n * 6;
+    if (c >= (indices[offset] - 1) && c <= (indices[offset + 1] - 1) &&
+        h >= (indices[offset + 2] - 1) && h <= (indices[offset + 3] - 1) &&
+        w >= (indices[offset + 4] - 1) && w <= (indices[offset + 5] - 1)) {
+      outputs[idx] = inputs[idx] * value;
+    } else {
+      outputs[idx] = inputs[idx];
+    }
+  }
+}
+
+template <>
+void ScaleSubRegion<DEVICE_TYPE_GPU>(real* outputs,
+                                     const real* inputs,
+                                     const real* indices,
+                                     const TensorShape shape,
+                                     const FuncConfig& conf) {
+  real value = conf.get<real>("value");
+
+  int number = shape[0];
+  int channel = shape[1];
+  int height = shape[2];
+  int width = shape[3];
+
+  size_t nth = number * channel * height * width;
+  int blockSize = 1024;
+  int gridSize = (nth + blockSize - 1) / blockSize;
+
+  KeScaleSubRegion<<<gridSize, blockSize, 0, STREAM_DEFAULT>>>(
+      outputs, inputs, indices, value, channel, height, width, nth);
+  CHECK_SYNC("ScaleSubRegion");
+}
+
+__global__ void KeScaleSubRegionDiff(const real* inGrad,
+                                     real* outGrad,
+                                     const real* indices,
+                                     real value,
+                                     int channel,
+                                     int height,
+                                     int width,
+                                     int nthreads) {
+  const int idx = threadIdx.x + blockIdx.x * blockDim.x;
+  if (idx < nthreads) {
+    const int w = idx % width;
+    const int h = (idx / width) % height;
+    const int c = (idx / width / height) % channel;
+    const int n = idx / width / height / channel;
+
+    const int offset = n * 6;
+    if (c >= (indices[offset] - 1) && c <= (indices[offset + 1] - 1) &&
+        h >= (indices[offset + 2] - 1) && h <= (indices[offset + 3] - 1) &&
+        w >= (indices[offset + 4] - 1) && w <= (indices[offset + 5] - 1)) {
+      outGrad[idx] += inGrad[idx] * value;
+    } else {
+      outGrad[idx] += inGrad[idx];
+    }
+  }
+}
+
+template <>
+void ScaleSubRegionGrad<DEVICE_TYPE_GPU>(const real* inGrad,
+                                         real* outGrad,
+                                         const real* indices,
+                                         const TensorShape shape,
+                                         const FuncConfig& conf) {
+  real value = conf.get<real>("value");
+
+  int number = shape[0];
+  int channel = shape[1];
+  int height = shape[2];
+  int width = shape[3];
+
+  size_t nth = number * channel * height * width;
+  int blockSize = 1024;
+  int gridSize = (nth + blockSize - 1) / blockSize;
+
+  KeScaleSubRegionDiff<<<gridSize, blockSize, 0, STREAM_DEFAULT>>>(
+      inGrad, outGrad, indices, value, channel, height, width, nth);
+  CHECK_SYNC("ScaleSubRegionGrad");
+}
+
+}  // namespace paddle
diff --git a/paddle/function/ScaleSubRegionOpTest.cpp b/paddle/function/ScaleSubRegionOpTest.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..2cbbf9d4b33b447021ae1701eac9a9a8c155ae07
--- /dev/null
+++ b/paddle/function/ScaleSubRegionOpTest.cpp
@@ -0,0 +1,72 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <gtest/gtest.h>
+#include "FunctionTest.h"
+
+namespace paddle {
+
+TEST(ScaleSubRegion, real) {
+  for (size_t numSamples : {5, 32}) {
+    for (size_t channels : {5, 5, 32}) {
+      for (size_t imgSizeH : {5, 33, 100}) {
+        for (size_t imgSizeW : {5, 32, 96}) {
+          for (real value : {-0.5, 0.0, 0.5}) {
+            for (bool firstHalf : {false, true}) {
+              VLOG(3) << " numSamples=" << numSamples
+                      << " channels=" << channels << " imgSizeH=" << imgSizeH
+                      << " imgSizeW=" << imgSizeW;
+
+              for (bool testGrad : {false, true}) {
+                CpuGpuFuncCompare compare(
+                    testGrad ? "ScaleSubRegionGrad" : "ScaleSubRegion",
+                    FuncConfig().set<real>("value", value));
+
+                TensorShape shape{numSamples, channels, imgSizeH, imgSizeW};
+                TensorShape indicesShape{numSamples, 6};
+
+                compare.addInputs(BufferArg(VALUE_TYPE_FLOAT, shape));
+                compare.addInputs(BufferArg(VALUE_TYPE_FLOAT, indicesShape));
+
+                compare.registerInitCallback([=](BufferArg& arg, size_t index) {
+                  if (index == 1) {
+                    real* data = (real*)arg.data();
+
+                    for (size_t i = 0; i < numSamples; ++i) {
+                      size_t offset = i * 6;
+                      data[offset] = firstHalf ? 1 : channels / 2;
+                      data[offset + 1] = firstHalf ? channels / 2 : channels;
+                      data[offset + 2] = firstHalf ? 1 : imgSizeH / 2;
+                      data[offset + 3] = firstHalf ? imgSizeH / 2 : imgSizeH;
+                      data[offset + 4] = firstHalf ? 1 : imgSizeW / 2;
+                      data[offset + 5] = firstHalf ? imgSizeW / 2 : imgSizeW;
+                    }
+                  }
+                });
+
+                compare.addOutputs(
+                    BufferArg(
+                        VALUE_TYPE_FLOAT, shape, testGrad ? ADD_TO : ASSIGN_TO),
+                    testGrad ? ADD_TO : ASSIGN_TO);
+                compare.run();
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+}
+
+}  // namespace paddle
diff --git a/paddle/gserver/layers/ScaleSubRegionLayer.cpp b/paddle/gserver/layers/ScaleSubRegionLayer.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..b18bc0c1b9065e09324d8ab4ed165679f6196d00
--- /dev/null
+++ b/paddle/gserver/layers/ScaleSubRegionLayer.cpp
@@ -0,0 +1,78 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "ScaleSubRegionLayer.h"
+#include "paddle/utils/Stat.h"
+namespace paddle {
+
+REGISTER_LAYER(scale_sub_region, ScaleSubRegionLayer);
+
+bool ScaleSubRegionLayer::init(const LayerMap& layerMap,
+                               const ParameterMap& parameterMap) {
+  Layer::init(layerMap, parameterMap);
+  CHECK_EQ(static_cast<int>(inputLayers_.size()), 2);
+  auto& conf = config_.inputs(0).scale_sub_region_conf();
+  value_ = conf.value();
+
+  createFunction(forward_, "ScaleSubRegion", FuncConfig().set("value", value_));
+  createFunction(
+      backward_, "ScaleSubRegionGrad", FuncConfig().set("value", value_));
+
+  return true;
+}
+
+void ScaleSubRegionLayer::forward(PassType passType) {
+  Layer::forward(passType);
+  auto in0 = getInput(0);
+  imgH_ = in0.getFrameHeight();
+  imgW_ = in0.getFrameWidth();
+  if (imgH_ == 0 || imgW_ == 0) {
+    auto& conf = config_.inputs(0).scale_sub_region_conf();
+    imgH_ = conf.image_conf().img_size_y();
+    imgW_ = conf.image_conf().img_size();
+  }
+  MatrixPtr imgV = in0.value;
+  size_t batchSize = imgV->getHeight();
+  size_t spatialSize = imgH_ * imgW_;
+  channelsNum_ = imgV->getWidth() / spatialSize;
+  shape_ = TensorShape({batchSize, channelsNum_, imgH_, imgW_});
+
+  resetOutput(batchSize, imgV->getWidth());
+  auto out = getOutput();
+  out.setFrameHeight(imgH_);
+  out.setFrameWidth(imgW_);
+
+  MatrixPtr indicesV = getInputValue(1);
+  indicesShape_ = TensorShape({batchSize, 6});
+
+  REGISTER_TIMER_INFO("ScaleSubRegionForward", getName().c_str());
+  BufferArgs inArgs;
+  BufferArgs outArgs;
+  inArgs.addArg(*imgV, shape_);
+  inArgs.addArg(*indicesV, indicesShape_);
+  outArgs.addArg(*out.value, shape_, ASSIGN_TO);
+  forward_[0]->calc(inArgs, outArgs);
+}
+
+void ScaleSubRegionLayer::backward(const UpdateCallback& callback) {
+  REGISTER_TIMER_INFO("ScaleSubRegionBackward", getName().c_str());
+  BufferArgs inArgs;
+  BufferArgs outArgs;
+  inArgs.addArg(*getOutputGrad(), shape_);
+  inArgs.addArg(*getInputValue(1), indicesShape_);
+  outArgs.addArg(*getInputGrad(0), shape_, ADD_TO);
+  backward_[0]->calc(inArgs, outArgs);
+}
+
+}  // namespace paddle
diff --git a/paddle/gserver/layers/ScaleSubRegionLayer.h b/paddle/gserver/layers/ScaleSubRegionLayer.h
new file mode 100644
index 0000000000000000000000000000000000000000..a27c56de93bb6fdde0f95cd4c5abe5dfabe4e858
--- /dev/null
+++ b/paddle/gserver/layers/ScaleSubRegionLayer.h
@@ -0,0 +1,52 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "Layer.h"
+
+namespace paddle {
+
+/**
+ * \brief  For each instance, this layer can be used to multiply a value to a
+ *         specified sub continuous region. By providing start index and end
+ *         index for C/H/W, you can specify the location and shape of the
+ *         region.
+ *
+ *         input_0: Input value.
+ *         input_1: Indices value to specify the location an shape of the
+ *                  region.
+ */
+class ScaleSubRegionLayer : public Layer {
+public:
+  explicit ScaleSubRegionLayer(const LayerConfig& config) : Layer(config) {}
+
+  ~ScaleSubRegionLayer() {}
+
+  bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
+
+  void forward(PassType passType);
+
+  void backward(const UpdateCallback& callback = nullptr);
+
+protected:
+  TensorShape shape_;
+  TensorShape indicesShape_;
+  size_t imgH_;
+  size_t imgW_;
+  size_t channelsNum_;
+  real value_;
+};
+
+}  // namespace paddle
diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp
index 89da15839e7f7ffda058017ce69e414fca2a0741..3f7d8810513ed4b765219be722cf8ae9adc7909f 100644
--- a/paddle/gserver/tests/test_LayerGrad.cpp
+++ b/paddle/gserver/tests/test_LayerGrad.cpp
@@ -2358,11 +2358,11 @@ TEST(Layer, ScaleShiftLayer) {
   }
 }
 
-TEST(Layer, MulValueLayer) {
+TEST(Layer, ScaleSubRegionLayer) {
   const size_t batchSize = 64;
   const size_t size = 4096;
   TestConfig config;
-  config.layerConfig.set_type("mul_value");
+  config.layerConfig.set_type("scale_sub_region");
   config.inputDefs.push_back({INPUT_DATA, "input", size, 0});
   MatrixPtr indicesV = Matrix::create(batchSize, 6, false, false);
   auto* data = indicesV->getData();
@@ -2376,16 +2376,17 @@ TEST(Layer, MulValueLayer) {
   }
   config.inputDefs.push_back({INPUT_SELF_DEFINE_DATA, "indices", indicesV, {}});
   LayerInputConfig* input = config.layerConfig.add_inputs();
-  MulValueConfig* mulValueConf = input->mutable_mul_value_conf();
-  ImageConfig* imgConf = mulValueConf->mutable_image_conf();
+  ScaleSubRegionConfig* scaleSubRegionConf =
+      input->mutable_scale_sub_region_conf();
+  ImageConfig* imgConf = scaleSubRegionConf->mutable_image_conf();
   imgConf->set_img_size(32);
   imgConf->set_img_size_y(32);
   imgConf->set_channels(4);
-  mulValueConf->set_value(1.0);
+  scaleSubRegionConf->set_value(2.0);
   config.layerConfig.add_inputs();
 
   for (auto useGpu : {false, true}) {
-    testLayerGrad(config, "mul_value", batchSize, false, useGpu, false);
+    testLayerGrad(config, "scale_sub_region", batchSize, false, useGpu, false);
   }
 }
 
diff --git a/proto/ModelConfig.proto b/proto/ModelConfig.proto
index 0fecad3f7d28b5ad5e8d02f619bd00e0b977fdf8..2d7ff1df98a9a448b447890537f20dd416a9ae9d 100644
--- a/proto/ModelConfig.proto
+++ b/proto/ModelConfig.proto
@@ -321,7 +321,7 @@ message ClipConfig {
   required double max = 2;
 }
 
-message MulValueConfig {
+message ScaleSubRegionConfig {
   required ImageConfig image_conf = 1;
   required float value = 2;
 }
@@ -347,7 +347,7 @@ message LayerInputConfig {
   optional MultiBoxLossConfig multibox_loss_conf = 16;
   optional DetectionOutputConfig detection_output_conf = 17;
   optional ClipConfig clip_conf = 18;
-  optional MulValueConfig mul_value_conf = 19;
+  optional ScaleSubRegionConfig scale_sub_region_conf = 19;
 }
 
 message LayerConfig {
diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py
index 222e195efe25bdbaede5dd0d3632802c6fb1f88c..9e2c6f59bd0af1627c79a8a29bd1515ae5c9c6b5 100644
--- a/python/paddle/trainer/config_parser.py
+++ b/python/paddle/trainer/config_parser.py
@@ -3801,21 +3801,23 @@ class SwitchOrderLayer(LayerBase):
         self.config.reshape_conf.width_axis.extend(reshape['width'])
 
 
-@config_layer('mul_value')
-class MulValueLayer(LayerBase):
+@config_layer('scale_sub_region')
+class ScaleSubRegionLayer(LayerBase):
     def __init__(self, name, inputs, value, **xargs):
-        super(MulValueLayer, self).__init__(
-            name, 'mul_value', 0, inputs=inputs, **xargs)
-        mul_value_conf = self.config.inputs[0].mul_value_conf
-        mul_value_conf.value = value
+        super(ScaleSubRegionLayer, self).__init__(
+            name, 'scale_sub_region', 0, inputs=inputs, **xargs)
+        scale_sub_region_conf = self.config.inputs[0].scale_sub_region_conf
+        scale_sub_region_conf.value = value
 
         # get channel, width and height from input_0 layer
         input_layer = self.get_input_layer(0)
-        image_conf = mul_value_conf.image_conf
+        image_conf = scale_sub_region_conf.image_conf
         image_conf.img_size = input_layer.width
         image_conf.img_size_y = input_layer.height
         image_conf.channels = input_layer.size / (input_layer.width *
                                                   input_layer.height)
+        self.set_cnn_layer(name, image_conf.img_size_y, image_conf.img_size,
+                           image_conf.channels)
 
 
 # Deprecated, use a new layer specific class instead
diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py
index e6901de14b9d2e14998847c48a60ef84193db145..f6527267f9b7c3417f79adb4d0e604e0e0f00d22 100644
--- a/python/paddle/trainer_config_helpers/layers.py
+++ b/python/paddle/trainer_config_helpers/layers.py
@@ -144,7 +144,7 @@ __all__ = [
     'img_conv3d_layer',
     'resize_layer',
     'sub_seq_layer',
-    'mul_value_layer',
+    'scale_sub_region_layer',
 ]
 
 
@@ -256,7 +256,7 @@ class LayerType(object):
     RESIZE = 'resize'
     SUB_SEQ_LAYER = 'subseq'
 
-    MUL_VALUE_LAYER = 'mul_value'
+    SCALE_SUB_REGION_LAYER = 'scale_sub_region'
 
     @staticmethod
     def is_layer_type(type_name):
@@ -7042,19 +7042,21 @@ def sub_seq_layer(input, offsets, sizes, act=None, bias_attr=None, name=None):
         size=input.size)
 
 
-@wrap_name_default('mul_value')
-def mul_value_layer(input, indices, value, name=None):
+@wrap_name_default('scale_sub_region')
+def scale_sub_region_layer(input, indices, value, name=None):
     """
-    Given an image or feature map with CHW information, mul_value_layer can be
-    used to multiply a real value to values of a sub continuous region. You can
-    provide start and end indices of CHW for each instance. Please notice that
-    all start indices are counting from 1. The shape of indices should be
-    [batch_size, 6] and the layout for each row is [C_Start, C_End, H_Start,
-    H_End, W_Start, W_End].
+    Given an image or feature map with CHW information, scale_sub_region_layer
+    can be used to multiply a real value to values of a sub continuous region.
+    You can provide start and end indices of CHW for each instance.
+    Please notice that all start indices are counting from 1.
+    The shape of indices should be [batch_size, 6] and the layout for each row
+    is [C_Start, C_End, H_Start, H_End, W_Start, W_End].
 
     .. code-block:: python
 
-        mul_value = mul_value_layer(input=input, indices=indices, value=value)
+        scale_sub_region = scale_sub_region_layer(input=input,
+                                                  indices=indices,
+                                                  value=value)
 
     :param name: The name of this layer. It is optional.
     :type name: basestring
@@ -7070,7 +7072,8 @@ def mul_value_layer(input, indices, value, name=None):
     """
 
     assert isinstance(input, LayerOutput), (
-        'The first input of mul_value_layer, must be a PaddlePaddle layer.')
+        'The first input of scale_sub_region_layer, '
+        'must be a PaddlePaddle layer.')
     assert isinstance(indices, LayerOutput), (
         'The start and end indices for CHW, must be a PaddlePaddle layer.')
     assert isinstance(value, float), (
@@ -7078,12 +7081,13 @@ def mul_value_layer(input, indices, value, name=None):
 
     Layer(
         name=name,
-        type=LayerType.MUL_VALUE_LAYER,
+        type=LayerType.SCALE_SUB_REGION_LAYER,
         inputs=[input.name, indices.name],
         value=value)
 
     return LayerOutput(
         name,
-        LayerType.MUL_VALUE_LAYER,
+        LayerType.SCALE_SUB_REGION_LAYER,
         parents=[input, indices],
+        num_filters=input.num_filters,
         size=input.size)
diff --git a/python/paddle/trainer_config_helpers/tests/configs/file_list.sh b/python/paddle/trainer_config_helpers/tests/configs/file_list.sh
index 4c00400ddae5ac9183f856edd9baa7842956fd12..42aaed7a6469342086b8273eb5b80eaea905f851 100755
--- a/python/paddle/trainer_config_helpers/tests/configs/file_list.sh
+++ b/python/paddle/trainer_config_helpers/tests/configs/file_list.sh
@@ -10,6 +10,6 @@ test_prelu_layer test_row_conv test_detection_output_layer test_multibox_loss_la
 test_recursive_topology test_gated_unit_layer test_clip_layer test_row_l2_norm_layer
 test_kmax_seq_socre_layer test_sub_nested_seq_select_layer test_scale_shift_layer
 test_seq_slice_layer test_cross_entropy_over_beam test_pooling3D_layer
-test_conv3d_layer test_deconv3d_layer test_BatchNorm3D test_resize_layer test_mul_value_layer)
+test_conv3d_layer test_deconv3d_layer test_BatchNorm3D test_resize_layer test_scale_sub_region_layer)
 
 export whole_configs=(test_split_datasource)
diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_scale_sub_region_layer.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_scale_sub_region_layer.protostr
new file mode 100644
index 0000000000000000000000000000000000000000..d20133a10ec605654bd3744297673068a77020b8
--- /dev/null
+++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_scale_sub_region_layer.protostr
@@ -0,0 +1,51 @@
+type: "nn"
+layers {
+  name: "data"
+  type: "data"
+  size: 2016
+  active_type: ""
+  height: 48
+  width: 42
+}
+layers {
+  name: "indices"
+  type: "data"
+  size: 6
+  active_type: ""
+}
+layers {
+  name: "__scale_sub_region_0__"
+  type: "scale_sub_region"
+  size: 2016
+  active_type: ""
+  inputs {
+    input_layer_name: "data"
+    scale_sub_region_conf {
+      image_conf {
+        channels: 1
+        img_size: 42
+        img_size_y: 48
+      }
+      value: 0.0
+    }
+  }
+  inputs {
+    input_layer_name: "indices"
+  }
+  height: 48
+  width: 42
+}
+input_layer_names: "data"
+input_layer_names: "indices"
+output_layer_names: "__scale_sub_region_0__"
+sub_models {
+  name: "root"
+  layer_names: "data"
+  layer_names: "indices"
+  layer_names: "__scale_sub_region_0__"
+  input_layer_names: "data"
+  input_layer_names: "indices"
+  output_layer_names: "__scale_sub_region_0__"
+  is_recurrent_layer_group: false
+}
+
diff --git a/python/paddle/trainer_config_helpers/tests/configs/test_scale_sub_region_layer.py b/python/paddle/trainer_config_helpers/tests/configs/test_scale_sub_region_layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d4bf28bf1eaf58e1fd0eb62fd10efe998587edd
--- /dev/null
+++ b/python/paddle/trainer_config_helpers/tests/configs/test_scale_sub_region_layer.py
@@ -0,0 +1,11 @@
+from paddle.trainer_config_helpers import *
+
+settings(batch_size=1000, learning_rate=1e-5)
+
+data = data_layer(name='data', size=2016, height=48, width=42)
+indices = data_layer(name='indices', size=6)
+
+scale_sub_region = scale_sub_region_layer(
+    input=data, indices=indices, value=0.0)
+
+outputs(scale_sub_region)