diff --git a/paddle/fluid/operators/bilateral_slice_op.cc b/paddle/fluid/operators/bilateral_slice_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..b742b4c0deea89dacd29a02588236b81ac13f6af
--- /dev/null
+++ b/paddle/fluid/operators/bilateral_slice_op.cc
@@ -0,0 +1,194 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+   http://www.apache.org/licenses/LICENSE-2.0
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#include "paddle/fluid/operators/bilateral_slice_op.h"
+#include <memory>
+#include <string>
+#include <vector>
+#include "paddle/fluid/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using framework::Tensor;
+using DataLayout = framework::DataLayout;
+
+class BilateralSliceOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "BilateralSlice");
+    OP_INOUT_CHECK(ctx->HasInput("Grid"), "Input", "Grid", "BilateralSlice");
+    OP_INOUT_CHECK(ctx->HasInput("Guide"), "Input", "Guide", "BilateralSlice");
+    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Output", "BilateralSlice");
+
+    auto input_dims = ctx->GetInputDim("X");  // NCHW format
+    PADDLE_ENFORCE_EQ(
+        input_dims.size(), 4,
+        platform::errors::InvalidArgument(
+            "Input(X) dimension must be 4, but got dimension = %d.",
+            input_dims.size()));
+
+    auto grid_dims = ctx->GetInputDim("Grid");
+    auto guide_dims = ctx->GetInputDim("Guide");
+    bool has_offset = ctx->Attrs().Get<bool>("has_offset");
+    int64_t h = guide_dims[1];
+    int64_t w = guide_dims[2];
+    int64_t bs = grid_dims[0];
+    int64_t coeffs_chans = grid_dims[1];
+    int64_t input_chans = input_dims[1];
+
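+    // Each grid cell stores an affine transform: output_chans * input_chans
+    // multiplicative coefficients, plus output_chans offsets when has_offset
+    // is true, so coeffs_chans must be divisible accordingly.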
+    int64_t output_chans;
+    if (has_offset) {
+      PADDLE_ENFORCE_EQ((coeffs_chans % (input_chans + 1)), 0,
+                        platform::errors::InvalidArgument(
+                            "Slicing with affine offset, coefficients grid "
+                            "should have n_out*(n_in+1) channels, but got %d",
+                            coeffs_chans));
+      output_chans = coeffs_chans / (input_chans + 1);
+    } else {
+      PADDLE_ENFORCE_EQ((coeffs_chans % input_chans), 0,
+                        platform::errors::InvalidArgument(
+                            "Slicing without affine offset, coefficients grid "
+                            "should have n_out*n_in channels, but got %d .",
+                            coeffs_chans));
+      output_chans = coeffs_chans / input_chans;
+    }
+
+    std::vector<int64_t> output_dims;
+    output_dims.push_back(bs);
+    output_dims.push_back(output_chans);
+    output_dims.push_back(h);
+    output_dims.push_back(w);
+
+    ctx->SetOutputDim("Out", framework::make_ddim(output_dims));
+  }
+
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
+  }
+
+  framework::OpKernelType GetKernelTypeForVar(
+      const std::string& var_name, const Tensor& tensor,
+      const framework::OpKernelType& expected_kernel_type) const override {
+    return framework::OpKernelType(expected_kernel_type.data_type_,
+                                   tensor.place(), tensor.layout());
+  }
+};
+
+class BilateralSliceOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X",
+             "The input tensor of bilateral_slice operator, "
+             "This is a 4-D tensor with shape of [N, C, H, W]");
+    AddInput("Grid",
+             "This is a 5-D tensor. "
+             "It should be [N, C, D, H, W].");
+    AddInput("Guide",
+             "This is a 3-D tensor "
+             "It should be [N, H, W].");
+    AddOutput("Out",
+              "The output tensor of bilateral slice operator, "
+              "This is a tensor in same rank with Input(X).");
+    AddAttr<bool>("has_offset", "an optional bool. Defaults to False. ")
+        .SetDefault(false);
+    AddComment(R"DOC(
+          This operator enhance input X according guide and grid
+          For details of bilateral slice, please refer to paper:
+          https://groups.csail.mit.edu/graphics/hdrnet/
+         )DOC");
+  }
+};
+
+class BilateralSliceOpGrad : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "BilateralSliceOpGrad");
+    OP_INOUT_CHECK(ctx->HasInput("Grid"), "Input", "Grid",
+                   "BilateralSliceOpGrad");
+    OP_INOUT_CHECK(ctx->HasInput("Guide"), "Input", "Guide",
+                   "BilateralSliceOpGrad");
+    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", "Out",
+                   "BilateralSliceOpGrad");
+
+    auto dim_x = ctx->GetInputDim("X");
+    auto dim_grid = ctx->GetInputDim("Grid");
+    auto dim_guide = ctx->GetInputDim("Guide");
+    if (ctx->HasOutput(framework::GradVarName("X"))) {
+      ctx->SetOutputDim(framework::GradVarName("X"), dim_x);
+    }
+    if (ctx->HasOutput(framework::GradVarName("Grid"))) {
+      ctx->SetOutputDim(framework::GradVarName("Grid"), dim_grid);
+    }
+    if (ctx->HasOutput(framework::GradVarName("Guide"))) {
+      ctx->SetOutputDim(framework::GradVarName("Guide"), dim_guide);
+    }
+  }
+
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
+                                       ctx, framework::GradVarName("Out")),
+                                   ctx.GetPlace());
+  }
+};
+
+template <typename T>
+class BilateralSliceGradMaker : public framework::SingleGradOpMaker<T> {
+ public:
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+ protected:
+  void Apply(GradOpPtr<T> op) const override {
+    op->SetType(this->ForwardOpType() + "_grad");
+    op->SetInput("X", this->Input("X"));
+    op->SetInput("Grid", this->Input("Grid"));
+    op->SetInput("Guide", this->Input("Guide"));
+
+    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+    op->SetOutput(framework::GradVarName("Grid"), this->InputGrad("Grid"));
+    op->SetOutput(framework::GradVarName("Guide"), this->InputGrad("Guide"));
+    op->SetAttrMap(this->Attrs());
+  }
+};
+
+template <typename T>
+class BilateralSliceKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
+                      platform::errors::Unimplemented(
+                          "BilateralSlice only supports GPU now."));
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(bilateral_slice, ops::BilateralSliceOp,
+                  ops::BilateralSliceOpMaker,
+                  ops::BilateralSliceGradMaker<paddle::framework::OpDesc>,
+                  ops::BilateralSliceGradMaker<paddle::imperative::OpBase>);
+REGISTER_OPERATOR(bilateral_slice_grad, ops::BilateralSliceOpGrad);
+REGISTER_OP_CPU_KERNEL(bilateral_slice, ops::BilateralSliceKernel<float>,
+                       ops::BilateralSliceKernel<double>);
diff --git a/paddle/fluid/operators/bilateral_slice_op.cu b/paddle/fluid/operators/bilateral_slice_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..e46950f61887dd64123135faec36ee0df11c0683
--- /dev/null
+++ b/paddle/fluid/operators/bilateral_slice_op.cu
@@ -0,0 +1,506 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+   http://www.apache.org/licenses/LICENSE-2.0
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#include <algorithm>
+#include <string>
+#include "paddle/fluid/operators/bilateral_slice_op.h"
+#include "paddle/fluid/platform/cuda_primitives.h"
+#include "paddle/fluid/platform/gpu_launch_config.h"
+
+namespace paddle {
+namespace operators {
+
+using framework::Tensor;
+using DataLayout = framework::DataLayout;
+
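+// Smoothed absolute value sqrt(x^2 + eps) and its derivative; unlike abs(),
+// this is differentiable at x = 0, which the guide gradient relies on.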
+template <typename T>
+__device__ T DiffAbs(T x) {
+  T eps = 1e-8;
+  return sqrt(x * x + eps);
+}
+
+template <typename T>
+__device__ T DdiffAbs(T x) {
+  T eps = 1e-8;
+  return x / sqrt(x * x + eps);
+}
+
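+// Tent (linear interpolation) weight along the grid depth axis, built on the
+// smoothed absolute value so that its derivative DweightZ is defined
+// everywhere.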
+template <typename T>
+__device__ T WeightZ(T x) {
+  T abx = DiffAbs(x);
+  return max(1.0f - abx, 0.0f);
+}
+
+template <typename T>
+__device__ T DweightZ(T x) {
+  T abx = DiffAbs(x);
+  if (abx > 1.0f) {
+    return 0.0f;
+  } else {
+    return DdiffAbs(x);
+  }
+}
+
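+// Forward pass: each thread handles one output element (b, out_c, y, x). It
+// trilinearly samples the bilateral grid at (gx, gy, gz), where gz comes from
+// the guide map, then applies the sampled affine coefficients (and optional
+// offset) to the input channels.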
+template <typename T>
+__global__ void BilateralSliceCudaForwardKernel(
+    T* output, const T* bilateral_grid, const T* guide, const T* input,
+    GridSizes gsz, bool has_offset, int total_count, int output_chans) {
+  int h = gsz.h;
+  int w = gsz.w;
+  int gd = gsz.gd;
+  int gh = gsz.gh;
+  int gw = gsz.gw;
+  int input_chans = gsz.input_chans;
+  int coeff_stride = input_chans;
+  int grid_chans = input_chans * output_chans;
+
+  if (has_offset) {
+    grid_chans += output_chans;
+    coeff_stride += 1;
+  }
+
+  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < total_count;
+       idx += blockDim.x * gridDim.x) {
+    int x = idx % w;
+    int y = (idx / w) % h;
+    int out_c = (idx / (h * w)) % output_chans;
+    int b = (idx / (output_chans * w * h));
+
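+    // Map the output pixel to continuous grid coordinates; the guide value
+    // (assumed to lie in [0, 1]) selects the depth slice.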
+    T gx = (x + 0.5f) * gw / (1.0f * w);
+    T gy = (y + 0.5f) * gh / (1.0f * h);
+    T gz = guide[x + w * (y + h * b)] * gd;
+
+    int fx = static_cast<int>(floor(gx - 0.5f));
+    int fy = static_cast<int>(floor(gy - 0.5f));
+    int fz = static_cast<int>(floor(gz - 0.5f));
+
+    int sy = gw;
+    int sz = gw * gh;
+    int sc = gd * gw * gh;
+    int sb = grid_chans * gd * gw * gh;
+
+    T value = 0.0f;
+    for (int in_c = 0; in_c < coeff_stride; ++in_c) {
+      T coeff_sample = 0.0f;
+
+      for (int xx = fx; xx < fx + 2; ++xx) {
+        int x_ = max(min(xx, gw - 1), 0);
+        T wx = max(1.0f - abs(xx + 0.5 - gx), 0.0f);
+
+        for (int yy = fy; yy < fy + 2; ++yy) {
+          int y_ = max(min(yy, gh - 1), 0);
+          T wy = max(1.0f - abs(yy + 0.5 - gy), 0.0f);
+
+          for (int zz = fz; zz < fz + 2; ++zz) {
+            int z_ = max(min(zz, gd - 1), 0);
+            T wz = WeightZ(zz + 0.5 - gz);
+            int c_ = coeff_stride * out_c + in_c;
+            int grid_idx = x_ + sy * y_ + sz * z_ + sc * c_ + sb * b;
+
+            coeff_sample += bilateral_grid[grid_idx] * wx * wy * wz;
+          }
+        }
+      }
+      if (in_c < input_chans) {
+        int input_idx = x + w * (y + h * (in_c + input_chans * b));
+        value += coeff_sample * input[input_idx];
+      } else {
+        value += coeff_sample;
+      }
+    }
+
+    output[idx] = value;
+  }
+}
+
+template <typename T>
+class BilateralSliceOpCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* input = ctx.Input<Tensor>("X");
+    auto* grid = ctx.Input<Tensor>("Grid");
+    auto* guide = ctx.Input<Tensor>("Guide");
+    auto* output = ctx.Output<Tensor>("Out");
+
+    auto* output_data = output->mutable_data<T>(ctx.GetPlace());
+    auto* grid_data = grid->data<T>();
+    auto* guide_data = guide->data<T>();
+    auto* input_data = input->data<T>();
+
+    bool has_offset = ctx.Attr<bool>("has_offset");
+    auto input_dims = input->dims();
+    auto output_dims = output->dims();
+    auto grid_dims = grid->dims();
+
+    int batch_size = input_dims[0];
+    int h = input_dims[2];
+    int w = input_dims[3];
+    int input_chans = input_dims[1];
+
+    int64_t coeffs_chans = grid_dims[1];
+    int64_t gd = grid_dims[2];
+    int64_t gh = grid_dims[3];
+    int64_t gw = grid_dims[4];
+
+    GridSizes grid_sizes;
+    grid_sizes.h = h;
+    grid_sizes.w = w;
+    grid_sizes.bs = batch_size;
+    grid_sizes.coeffs_chans = coeffs_chans;
+    grid_sizes.gd = gd;
+    grid_sizes.gh = gh;
+    grid_sizes.gw = gw;
+    grid_sizes.input_chans = input_chans;
+
+    int total_count = batch_size * h * w * output_dims[1];
+
+    platform::GpuLaunchConfig config =
+        platform::getGpuLaunchConfig(total_count, ctx);
+
+    BilateralSliceCudaForwardKernel<T><<<config.blocks, config.threads, 0,
+                                         ctx.cuda_device_context().stream()>>>(
+        output_data, grid_data, guide_data, input_data, grid_sizes, has_offset,
+        total_count, output_dims[1]);
+  }
+};
+
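+// Gradient w.r.t. the bilateral grid: each thread owns one grid entry and
+// accumulates upstream_grad (times the input for coefficient channels) over
+// the image pixels whose trilinear weights touch that entry.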
+template <typename T>
+__global__ void BilateralSliceCudaGridGradKernel(
+    T* out_grid_grad, const T* upstream_grad, const T* guide, const T* input,
+    GridSizes gsz, bool has_offset, int grid_count, int output_chans) {
+  int h = gsz.h;
+  int w = gsz.w;
+  int gd = gsz.gd;
+  int gh = gsz.gh;
+  int gw = gsz.gw;
+  int input_chans = gsz.input_chans;
+  int grid_chans = input_chans * output_chans;
+  int coeff_stride = input_chans;
+
+  if (has_offset) {
+    grid_chans += output_chans;
+    coeff_stride += 1;
+  }
+
+  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < grid_count;
+       idx += blockDim.x * gridDim.x) {
+    int gx = idx % gw;
+    int gy = (idx / gw) % gh;
+    int gz = (idx / (gh * gw)) % gd;
+    int c = (idx / (gd * gh * gw)) % grid_chans;
+    int b = (idx / (grid_chans * gd * gw * gh));
+
+    T scale_w = w * 1.0 / gw;
+    T scale_h = h * 1.0 / gh;
+
+    int left_x = static_cast<int>(floor(scale_w * (gx + 0.5 - 1)));
+    int right_x = static_cast<int>(ceil(scale_w * (gx + 0.5 + 1)));
+    int left_y = static_cast<int>(floor(scale_h * (gy + 0.5 - 1)));
+    int right_y = static_cast<int>(ceil(scale_h * (gy + 0.5 + 1)));
+
+    int sy = w;
+    int sc = w * h;
+    int sb = output_chans * w * h;
+
+    int isy = w;
+    int isc = h * w;
+    int isb = input_chans * h * w;
+
+    int out_c = c / coeff_stride;
+    int in_c = c % coeff_stride;
+
+    T value = 0.0f;
+    for (int x = left_x; x < right_x; ++x) {
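+      // Mirror-reflect out-of-range pixels so border grid cells accumulate
+      // over a full footprint.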
+      int x_ = x;
+
+      if (x_ < 0) {
+        x_ = -x_ - 1;
+      }
+      if (x_ >= w) {
+        x_ = 2 * w - 1 - x_;
+      }
+
+      T gx2 = (x + 0.5f) / scale_w;
+      T wx = max(1.0f - abs(gx + 0.5 - gx2), 0.0f);
+
+      for (int y = left_y; y < right_y; ++y) {
+        int y_ = y;
+
+        if (y_ < 0) {
+          y_ = -y_ - 1;
+        }
+        if (y_ >= h) {
+          y_ = 2 * h - 1 - y_;
+        }
+
+        T gy2 = (y + 0.5f) / scale_h;
+        T wy = max(1.0f - abs(gy + 0.5 - gy2), 0.0f);
+
+        int guide_idx = x_ + w * y_ + h * w * b;
+        T gz2 = guide[guide_idx] * gd;
+        T wz = WeightZ(gz + 0.5f - gz2);
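+        // Guide values that fall outside the grid depth are assigned fully
+        // to the boundary slice.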
+        if (((gz == 0) && (gz2 < 0.5f)) ||
+            ((gz == (gd - 1)) && (gz2 > (gd - 0.5f)))) {
+          wz = 1.0f;
+        }
+
+        int back_idx = x_ + sy * y_ + sc * out_c + sb * b;
+        if (in_c < input_chans) {
+          int input_idx = x_ + isy * y_ + isc * in_c + isb * b;
+          value += wz * wx * wy * upstream_grad[back_idx] * input[input_idx];
+        } else {
+          value += wz * wx * wy * upstream_grad[back_idx];
+        }
+      }
+    }
+    out_grid_grad[idx] = value;
+  }
+}
+
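+// Gradient w.r.t. the guide map: the guide only enters through the depth
+// coordinate gz, so the chain rule replaces the depth tent weight with its
+// derivative, scaled by gd (from gz = guide * gd).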
+template <typename T>
+__global__ void BilateralSliceCudaGuideGradKernel(
+    T* out_guide_grad, const T* upstream_grad, const T* bilateral_grid,
+    const T* guide, const T* input, GridSizes gsz, bool has_offset,
+    int guide_count, int output_chans) {
+  int h = gsz.h;
+  int w = gsz.w;
+  int gd = gsz.gd;
+  int gh = gsz.gh;
+  int gw = gsz.gw;
+  int input_chans = gsz.input_chans;
+  int grid_chans = input_chans * output_chans;
+  int coeff_stride = input_chans;
+
+  if (has_offset) {
+    grid_chans += output_chans;
+    coeff_stride += 1;
+  }
+
+  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < guide_count;
+       idx += blockDim.x * gridDim.x) {
+    int x = idx % w;
+    int y = (idx / w) % h;
+    int b = (idx / (w * h));
+
+    T gx = (x + 0.5f) * gw / (1.0f * w);
+    T gy = (y + 0.5f) * gh / (1.0f * h);
+    T gz = guide[x + w * (y + h * b)] * gd;
+
+    int fx = static_cast<int>(floor(gx - 0.5f));
+    int fy = static_cast<int>(floor(gy - 0.5f));
+    int fz = static_cast<int>(floor(gz - 0.5f));
+
+    int sy = gw;
+    int sz = gh * gw;
+    int sc = gd * gh * gw;
+    int sb = grid_chans * gd * gw * gh;
+
+    T out_sum = 0.0f;
+    for (int out_c = 0; out_c < output_chans; ++out_c) {
+      T in_sum = 0.0f;
+      for (int in_c = 0; in_c < coeff_stride; ++in_c) {
+        T grid_sum = 0.0f;
+        for (int xx = fx; xx < fx + 2; ++xx) {
+          int x_ = max(min(xx, gw - 1), 0);
+          T wx = max(1.0f - abs(xx + 0.5 - gx), 0.0f);
+
+          for (int yy = fy; yy < fy + 2; ++yy) {
+            int y_ = max(min(yy, gh - 1), 0);
+            T wy = max(1.0f - abs(yy + 0.5 - gy), 0.0f);
+
+            for (int zz = fz; zz < fz + 2; ++zz) {
+              int z_ = max(min(zz, gd - 1), 0);
+              T dwz = gd * DweightZ(zz + 0.5 - gz);
+
+              int c_ = coeff_stride * out_c + in_c;
+              int grid_idx = x_ + sy * y_ + sz * z_ + sc * c_ + sb * b;
+              grid_sum += bilateral_grid[grid_idx] * wx * wy * dwz;
+            }
+          }
+        }
+
+        if (in_c < input_chans) {
+          in_sum +=
+              grid_sum * input[x + w * (y + h * (in_c + input_chans * b))];
+        } else {
+          in_sum += grid_sum;
+        }
+      }
+
+      out_sum +=
+          in_sum * upstream_grad[x + w * (y + h * (out_c + output_chans * b))];
+    }
+
+    out_guide_grad[idx] = out_sum;
+  }
+}
+
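+// Gradient w.r.t. the input: d(out)/d(in_c) is just the sliced multiplicative
+// coefficient, so each thread re-slices the grid and weights the upstream
+// gradient; the offset channel does not contribute here.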
+template <typename T>
+__global__ void BilateralSliceCudaInputGradKernel(
+    T* out_input_grad, const T* upstream_grad, const T* bilateral_grid,
+    const T* guide, GridSizes gsz, bool has_offset, int input_count,
+    int output_chans) {
+  int h = gsz.h;
+  int w = gsz.w;
+  int gd = gsz.gd;
+  int gh = gsz.gh;
+  int gw = gsz.gw;
+  int input_chans = gsz.input_chans;
+  int grid_chans = input_chans * output_chans;
+  int coeff_stride = input_chans;
+
+  if (has_offset) {
+    grid_chans += output_chans;
+    coeff_stride += 1;
+  }
+
+  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < input_count;
+       idx += blockDim.x * gridDim.x) {
+    int x = idx % w;
+    int y = (idx / w) % h;
+    int in_c = (idx / (h * w)) % input_chans;
+    int b = (idx / (input_chans * w * h));
+
+    T gx = (x + 0.5f) * gw / (1.0f * w);
+    T gy = (y + 0.5f) * gh / (1.0f * h);
+    T gz = guide[x + w * (y + h * b)] * gd;
+
+    int fx = static_cast<int>(floor(gx - 0.5f));
+    int fy = static_cast<int>(floor(gy - 0.5f));
+    int fz = static_cast<int>(floor(gz - 0.5f));
+
+    int sy = gw;
+    int sz = gh * gw;
+    int sc = gd * gh * gw;
+    int sb = grid_chans * gd * gh * gw;
+
+    T value = 0.0f;
+    for (int out_c = 0; out_c < output_chans; ++out_c) {
+      T chan_val = 0.0f;
+
+      for (int xx = fx; xx < fx + 2; ++xx) {
+        int x_ = max(min(xx, gw - 1), 0);
+        T wx = max(1.0f - abs(xx + 0.5 - gx), 0.0f);
+
+        for (int yy = fy; yy < fy + 2; ++yy) {
+          int y_ = max(min(yy, gh - 1), 0);
+          T wy = max(1.0f - abs(yy + 0.5 - gy), 0.0f);
+
+          for (int zz = fz; zz < fz + 2; ++zz) {
+            int z_ = max(min(zz, gd - 1), 0);
+
+            T wz = WeightZ(zz + 0.5 - gz);
+
+            int c_ = coeff_stride * out_c + in_c;
+            int grid_idx = x_ + sy * y_ + sz * z_ + sc * c_ + sb * b;
+            chan_val += bilateral_grid[grid_idx] * wx * wy * wz;
+          }
+        }
+      }
+
+      value += chan_val *
+               upstream_grad[x + w * (y + h * (out_c + output_chans * b))];
+    }
+    out_input_grad[idx] = value;
+  }
+}
+
+template <typename T>
+class BilateralSliceGradOpCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* input = ctx.Input<Tensor>("X");
+    auto* guide = ctx.Input<Tensor>("Guide");
+    auto* grid = ctx.Input<Tensor>("Grid");
+    auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto* grid_grad = ctx.Output<Tensor>(framework::GradVarName("Grid"));
+    auto* guide_grad = ctx.Output<Tensor>(framework::GradVarName("Guide"));
+    auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
+
+    const T* input_data = input->data<T>();
+    const T* guide_data = guide->data<T>();
+    const T* grid_data = grid->data<T>();
+    const T* output_grad_data = output_grad->data<T>();
+
+    T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
+    T* guide_grad_data = guide_grad->mutable_data<T>(ctx.GetPlace());
+    T* grid_grad_data = grid_grad->mutable_data<T>(ctx.GetPlace());
+
+    bool has_offset = ctx.Attr<bool>("has_offset");
+
+    auto input_grad_dims = input_grad->dims();
+    auto grid_dims = grid_grad->dims();
+
+    int batch_size = input_grad_dims[0];
+    int h = input_grad_dims[2];
+    int w = input_grad_dims[3];
+    int input_chans = input_grad_dims[1];
+
+    int64_t coeffs_chans = grid_dims[1];
+    int64_t gd = grid_dims[2];
+    int64_t gh = grid_dims[3];
+    int64_t gw = grid_dims[4];
+
+    int output_chans = 0;
+    if (has_offset) {
+      output_chans = coeffs_chans / (input_chans + 1);
+    } else {
+      output_chans = coeffs_chans / input_chans;
+    }
+    int grid_count = batch_size * gh * gw * gd * coeffs_chans;
+    int guide_count = batch_size * h * w;
+    int input_count = batch_size * h * w * input_chans;
+
+    GridSizes grid_sizes;
+    grid_sizes.h = h;
+    grid_sizes.w = w;
+    grid_sizes.bs = batch_size;
+    grid_sizes.coeffs_chans = coeffs_chans;
+    grid_sizes.gd = gd;
+    grid_sizes.gh = gh;
+    grid_sizes.gw = gw;
+    grid_sizes.input_chans = input_chans;
+
+    platform::GpuLaunchConfig config =
+        platform::getGpuLaunchConfig(grid_count, ctx, 512);
+
+    BilateralSliceCudaGridGradKernel<T><<<config.blocks, config.threads, 0,
+                                          ctx.cuda_device_context().stream()>>>(
+        grid_grad_data, output_grad_data, guide_data, input_data, grid_sizes,
+        has_offset, grid_count, output_chans);
+
+    config = platform::getGpuLaunchConfig(guide_count, ctx, 512);
+
+    BilateralSliceCudaGuideGradKernel<T><<<
+        config.blocks, config.threads, 0, ctx.cuda_device_context().stream()>>>(
+        guide_grad_data, output_grad_data, grid_data, guide_data, input_data,
+        grid_sizes, has_offset, guide_count, output_chans);
+
+    config = platform::getGpuLaunchConfig(input_count, ctx, 512);
+
+    BilateralSliceCudaInputGradKernel<T><<<
+        config.blocks, config.threads, 0, ctx.cuda_device_context().stream()>>>(
+        input_grad_data, output_grad_data, grid_data, guide_data, grid_sizes,
+        has_offset, input_count, output_chans);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_CUDA_KERNEL(bilateral_slice, ops::BilateralSliceOpCUDAKernel<float>,
+                        ops::BilateralSliceOpCUDAKernel<double>);
+REGISTER_OP_CUDA_KERNEL(bilateral_slice_grad,
+                        ops::BilateralSliceGradOpCUDAKernel<float>,
+                        ops::BilateralSliceGradOpCUDAKernel<double>);
diff --git a/paddle/fluid/operators/bilateral_slice_op.h b/paddle/fluid/operators/bilateral_slice_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..0903fe4c71d3d7123c6f340d9e83d526c72dfccb
--- /dev/null
+++ b/paddle/fluid/operators/bilateral_slice_op.h
@@ -0,0 +1,33 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+   http://www.apache.org/licenses/LICENSE-2.0
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+#pragma once
+#include <algorithm>
+#include <string>
+#include <vector>
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/platform/hostdevice.h"
+
+namespace paddle {
+namespace operators {
+
+struct GridSizes {
+  int64_t h;
+  int64_t w;
+  int64_t bs;
+  int64_t coeffs_chans;
+  int64_t gd;
+  int64_t gh;
+  int64_t gw;
+  int64_t input_chans;
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/platform/gpu_launch_config.h b/paddle/fluid/platform/gpu_launch_config.h
index d57478b89781ed073cef0fa73e201784f73dfc6b..fd6e80527caf6d79bf61aa6c2f03fa14724f4d42 100644
--- a/paddle/fluid/platform/gpu_launch_config.h
+++ b/paddle/fluid/platform/gpu_launch_config.h
@@ -31,9 +31,10 @@ struct GpuLaunchConfig {
 };
 
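+// max_threads lets callers cap the block size below the 1024 default;
+// register-heavy kernels (e.g. the bilateral_slice gradients) pass 512.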
 inline GpuLaunchConfig getGpuLaunchConfig(
-    const int N, const framework::ExecutionContext& ctx) {
+    const int N, const framework::ExecutionContext& ctx,
+    int max_threads = 1024) {
   int threads =
-      std::min(1024, ctx.cuda_device_context().GetMaxThreadsPerBlock());
+      std::min(max_threads, ctx.cuda_device_context().GetMaxThreadsPerBlock());
   int physical_thread_count =
       std::min(ctx.cuda_device_context().GetMaxPhysicalThreadCount(), N);
   int blocks = std::min((physical_thread_count + threads - 1) / threads,
diff --git a/python/paddle/fluid/contrib/layers/nn.py b/python/paddle/fluid/contrib/layers/nn.py
index 273a669a1414e858920f6f5c2ad1fce8810eb829..50e6eaa80c135b24efa3844a6387278cc247af3a 100644
--- a/python/paddle/fluid/contrib/layers/nn.py
+++ b/python/paddle/fluid/contrib/layers/nn.py
@@ -35,7 +35,7 @@ __all__ = [
     'match_matrix_tensor', 'tree_conv', 'fused_embedding_seq_pool',
     'multiclass_nms2', 'search_pyramid_hash', 'shuffle_batch', 'partial_concat',
     'partial_sum', 'tdm_child', 'rank_attention', 'tdm_sampler', 'batch_fc',
-    '_pull_box_extended_sparse'
+    '_pull_box_extended_sparse', 'bilateral_slice'
 ]
 
 
@@ -1409,3 +1409,65 @@ def _pull_box_extended_sparse(input, size, extend_size=64, dtype='float32'):
     if len(outs) == 1:
         return outs[0], outs_extend[0]
     return outs, outs_extend
+
+
+def bilateral_slice(x, guide, grid, has_offset, name=None):
+    """
+    :alias_main: paddle.nn.functional.bilateral_slice
+    :alias: paddle.nn.functional.bilateral_slice,paddle.nn.functional.vision.bilateral_slice
+    :old_api: paddle.fluid.layers.bilateral_slice
+
+    This operation implements bilateral slicing on the input according to the guide map.
+    For more information on bilateral slicing, please refer to
+    `Deep Bilateral Learning for Real-Time Image Enhancement <https://groups.csail.mit.edu/graphics/hdrnet/data/hdrnet.pdf>`_ .
+
+    Args:
+        x(Variable): The input tensor, which is a 4-D tensor with shape
+                     [N, C, H, W]. N is the batch size, C is the channel
+                     number, and H and W are the feature height and width.
+                     The data type is float32 or float64.
+        guide(Variable): The input guide tensor of shape [N, H, W]. The
+                        data type is float32 or float64.
+        grid(Variable): The input grid tensor of shape [N, C, D, H, W]. The
+                        data type is float32 or float64.
+        has_offset(bool): Whether to slice with an affine offset.
+        name(str, optional): For detailed information, please refer
+                             to :ref:`api_guide_Name`. Usually name does not
+                             need to be set and is None by default.
+
+    Returns:
+        Variable: Output of shape [N, C, H, W]. The data type is the same as
+        the input tensor.
+
+    Examples:
+
+        .. code-block:: python
+
+            import paddle.fluid as fluid
+
+            x = fluid.data(name='x', shape=[None, 3, 101, 60], dtype='float32')
+            guide = fluid.data(name='guide', shape=[None, 101, 60], dtype='float32')
+            grid = fluid.data(name='grid', shape=[None, 12, 8, 10, 6], dtype='float32')
+
+            # without offset
+            output = fluid.layers.bilateral_slice(x, guide, grid, has_offset=False)
+
+            # has offset
+            output = fluid.layers.bilateral_slice(x, guide, grid, has_offset=True)
+
+    """
+    helper = LayerHelper("bilateral_slice", **locals())
+
+    check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'bilateral_slice')
+    check_variable_and_dtype(guide, 'guide', ['float32', 'float64'],
+                             'bilateral_slice')
+    check_variable_and_dtype(grid, 'grid', ['float32', 'float64'],
+                             'bilateral_slice')
+
+    out = helper.create_variable_for_type_inference(x.dtype)
+    inputs = {'X': x, 'Guide': guide, 'Grid': grid}
+
+    helper.append_op(
+        type='bilateral_slice',
+        inputs=inputs,
+        attrs={'has_offset': has_offset},
+        outputs={'Out': out})
+    return out
diff --git a/python/paddle/fluid/tests/unittests/test_bilateral_slice_op.py b/python/paddle/fluid/tests/unittests/test_bilateral_slice_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..51e447dba725c03ad7eea5c94c2be70cc8ea9a7a
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_bilateral_slice_op.py
@@ -0,0 +1,194 @@
+#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import numpy as np
+from op_test import OpTest
+import paddle
+import math
+
+
+class Gsz:
+    def __init__(self, h, w, gd, gh, gw, input_chans):
+        self.h = h
+        self.w = w
+        self.gd = gd
+        self.gh = gh
+        self.gw = gw
+        self.input_chans = input_chans
+
+
+def diff_abs(x):
+    eps = 1e-8
+    return math.sqrt(x * x + eps)
+
+
+def d_diff_abs(x):
+    eps = 1e-8
+    return x / math.sqrt(x * x + eps)
+
+
+def weight_z(x):
+    abx = diff_abs(x)
+    return max(1.0 - abx, 0.0)
+
+
+def d_weight_z(x):
+    abx = diff_abs(x)
+    if abx > 1.0:
+        return 0.0
+    else:
+        return d_diff_abs(x)
+
+
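+# NumPy reference that mirrors the CUDA forward kernel element for element:
+# one loop iteration per output value (b, out_c, y, x).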
+def naive_bilateral_slice_forward(output, grid, guide, input, gsz, has_offset,
+                                  total_count, output_chans):
+    h = gsz.h
+    w = gsz.w
+    gd = gsz.gd
+    gh = gsz.gh
+    gw = gsz.gw
+    input_chans = gsz.input_chans
+    coeff_stride = input_chans
+    grid_chans = input_chans * output_chans
+
+    if has_offset:
+        grid_chans += output_chans
+        coeff_stride += 1
+
+    for idx in range(total_count):
+        x = idx % w
+        y = (idx // w) % h
+        out_c = (idx // (h * w)) % output_chans
+        b = (idx // (output_chans * w * h))
+
+        gx = (x + 0.5) * gw / (1.0 * w)
+        gy = (y + 0.5) * gh / (1.0 * h)
+        gz = guide[int(b), int(y), int(x)] * gd
+
+        fx = int(np.floor(gx - 0.5))
+        fy = int(np.floor(gy - 0.5))
+        fz = int(np.floor(gz - 0.5))
+
+        value = 0.0
+        for in_c in range(0, coeff_stride):
+            coeff_sample = 0.0
+
+            for xx in range(fx, fx + 2):
+                x_ = max(min(xx, gw - 1), 0)
+                wx = max(1.0 - abs(xx + 0.5 - gx), 0.0)
+
+                for yy in range(fy, fy + 2):
+                    y_ = max(min(yy, gh - 1), 0)
+                    wy = max(1.0 - abs(yy + 0.5 - gy), 0.0)
+
+                    for zz in range(fz, fz + 2):
+                        z_ = max(min(zz, gd - 1), 0)
+                        wz = weight_z(zz + 0.5 - gz)
+                        c_ = coeff_stride * out_c + in_c
+
+                        coeff_sample += grid[int(b), int(c_), int(z_), int(y_),
+                                             int(x_)] * wx * wy * wz
+
+            if in_c < input_chans:
+                value += coeff_sample * input[int(b), int(in_c), int(y), int(x)]
+            else:
+                value += coeff_sample
+        output[int(b), int(out_c), int(y), int(x)] = value
+
+
+def naive_bilateral_slice(x, guide, grid, has_offset):
+    bs = x.shape[0]
+    h = x.shape[2]
+    w = x.shape[3]
+    input_chans = x.shape[1]
+
+    coeffs_chans = grid.shape[1]
+    if has_offset:
+        output_chans = coeffs_chans // (input_chans + 1)
+    else:
+        output_chans = coeffs_chans // input_chans
+
+    output = np.zeros([bs, int(output_chans), h, w]).astype(x.dtype)
+
+    gd = grid.shape[2]
+    gh = grid.shape[3]
+    gw = grid.shape[4]
+
+    gsz = Gsz(h, w, gd, gh, gw, input_chans)
+    total_count = bs * h * w * output.shape[1]
+    naive_bilateral_slice_forward(output, grid, guide, x, gsz, has_offset,
+                                  total_count, output.shape[1])
+    return output
+
+
+@unittest.skipIf(not paddle.fluid.is_compiled_with_cuda(),
+                 'bilateral_slice only supports GPU; CUDA is not available')
+class TestBilateralSliceOp(OpTest):
+    def setUp(self):
+        self.initTestCase()
+        self.op_type = 'bilateral_slice'
+        batch_size = 3
+        h = 50
+        w = 30
+        c = 1
+        gh = 5
+        gw = 3
+        gd = 2
+        gc = 2
+        x = np.random.rand(batch_size, c, h, w).astype(self.data_type)
+        guide = np.random.rand(batch_size, h, w).astype(self.data_type)
+        grid = np.random.rand(batch_size, gc, gd, gh, gw).astype(self.data_type)
+        output_np = naive_bilateral_slice(x, guide, grid, self.has_offset)
+
+        self.inputs = {'X': x, 'Grid': grid, 'Guide': guide}
+        self.attrs = {'has_offset': self.has_offset, }
+        self.outputs = {'Out': output_np}
+
+    def test_check_output(self):
+        place = paddle.fluid.CUDAPlace(0)
+        self.check_output_with_place(place, atol=1e-5)
+
+    def test_check_grad(self):
+        place = paddle.fluid.CUDAPlace(0)
+        self.check_grad_with_place(place, ['X'], 'Out')
+
+    def initTestCase(self):
+        self.has_offset = False
+        self.data_type = 'float64'
+
+
+@unittest.skipIf(not paddle.fluid.is_compiled_with_cuda(),
+                 'bilateral_slice only supports GPU; CUDA is not available')
+class TestBilateralSliceOp1(TestBilateralSliceOp):
+    def initTestCase(self):
+        self.has_offset = True
+        self.data_type = 'float32'
+
+
+class TestBilateralSliceApi(TestBilateralSliceOp):
+    def test_api(self):
+        x = paddle.fluid.data(
+            name='x', shape=[None, 3, 25, 15], dtype='float32')
+        guide = paddle.fluid.data(
+            name='guide', shape=[None, 25, 15], dtype='float32')
+        grid = paddle.fluid.data(
+            name='grid', shape=[None, 12, 8, 5, 3], dtype='float32')
+        paddle.fluid.contrib.layers.bilateral_slice(x, guide, grid,
+                                                    self.has_offset)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/white_list/op_accuracy_white_list.py b/python/paddle/fluid/tests/unittests/white_list/op_accuracy_white_list.py
index db5ad92ff5ead4fbc609209692268bef254d8c27..4629089e39c9489725340df2172c53ed0661708f 100644
--- a/python/paddle/fluid/tests/unittests/white_list/op_accuracy_white_list.py
+++ b/python/paddle/fluid/tests/unittests/white_list/op_accuracy_white_list.py
@@ -74,7 +74,8 @@ NO_FP64_CHECK_GRAD_OP_LIST = [
     'transpose2', \
     'trilinear_interp', \
     'var_conv_2d', \
-    'warpctc'
+    'warpctc', \
+    'bilateral_slice'
 ]
 
 NO_FP16_CHECK_GRAD_OP_LIST = [
diff --git a/python/paddle/fluid/tests/unittests/white_list/op_threshold_white_list.py b/python/paddle/fluid/tests/unittests/white_list/op_threshold_white_list.py
index fd3d5f3104f8243fdcae312620742688eb79b854..ce6868b5c70ae1218df48f899f936f57f6734582 100644
--- a/python/paddle/fluid/tests/unittests/white_list/op_threshold_white_list.py
+++ b/python/paddle/fluid/tests/unittests/white_list/op_threshold_white_list.py
@@ -40,7 +40,8 @@ NEED_FIX_FP64_CHECK_GRAD_THRESHOLD_OP_LIST = [
     'teacher_student_sigmoid_loss', \
     'unpool', \
     'yolov3_loss', \
-    'inverse'
+    'inverse', \
+    'bilateral_slice'
 ]
 
 NEED_FIX_FP64_CHECK_OUTPUT_THRESHOLD_OP_LIST = ['bilinear_interp']