/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/transform.h"

namespace paddle {
namespace operators {

using framework::Tensor;
using platform::Transform;

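// One-thread-per-element CUDA kernel that applies a unary clip functor to the
// input; only compiled when this header is included from a CUDA build.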
#ifdef __NVCC__
template <typename T, typename UnaryOperation>
__global__ void ClipCudaKernel(const T* input, T* out, int num,
                               UnaryOperation op) {
  int idx = threadIdx.x + blockDim.x * blockIdx.x;
  if (idx < num) {
    out[idx] = op(input[idx]);
  }
}
#endif

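// Clamps a single value into the closed interval [min_, max_].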
template <typename T>
class ClipFunctor {
 public:
  explicit ClipFunctor(const T min, const T max) : min_(min), max_(max) {}
  HOSTDEVICE T operator()(const T& x) const {
    return x < min_ ? min_ : x > max_ ? max_ : x;
  }

 private:
  T min_;
  T max_;
};

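// Gradient of clip: the upstream gradient x passes through only where the
// forward input y lies strictly inside (min_, max_); otherwise it is zero.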
template <typename T>
class ClipGradFunctor {
 public:
  explicit ClipGradFunctor(const T min, const T max) : min_(min), max_(max) {}
  HOSTDEVICE T operator()(const T& x, const T& y) const {
    return (y > min_ && y < max_) ? x : 0;
  }

 private:
  T min_;
  T max_;
};

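// Forward kernel: clips every element of "X" into [min, max]. The bounds come
// from the "min"/"max" attributes unless the optional "Min"/"Max" input
// tensors are provided. Supports both LoDTensor and SelectedRows inputs.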
template <typename DeviceContext, typename T>
class ClipKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
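    // The "max" attribute is the default; an optional "Max" input tensor
    // overrides it (copied to CPU first if it resides on the GPU).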
    auto max = static_cast<T>(context.Attr<float>("max"));
    Tensor max_cpu;
    if (context.HasInput("Max")) {
      auto* max_t = context.Input<Tensor>("Max");
      auto* max_data = max_t->data<T>();
      if (platform::is_gpu_place(max_t->place())) {
        TensorCopySync(*max_t, platform::CPUPlace(), &max_cpu);
        max_data = max_cpu.data<T>();
      }
      max = max_data[0];
    }
    max = static_cast<T>(max);

    auto min = static_cast<T>(context.Attr<float>("min"));
    Tensor min_cpu;
    if (context.HasInput("Min")) {
      auto* min_t = context.Input<Tensor>("Min");
      auto* min_data = min_t->data<T>();
      if (platform::is_gpu_place(min_t->place())) {
        TensorCopySync(*min_t, platform::CPUPlace(), &min_cpu);
        min_data = min_cpu.data<T>();
      }
      min = min_data[0];
    }
    min = static_cast<T>(min);
    PADDLE_ENFORCE_LT(min, max, platform::errors::InvalidArgument(
                                    "max should be greater than min. "
                                    "But received min = %f, max = %f",
                                    min, max));

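    // "X" may be either a LoDTensor or a SelectedRows variable.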
    auto* x_var = context.InputVar("X");
    if (x_var->IsType<framework::LoDTensor>()) {
      auto* x = context.Input<framework::LoDTensor>("X");
      auto* out = context.Output<framework::LoDTensor>("Out");
      T* out_data = out->mutable_data<T>(context.GetPlace());
      const T* x_data = x->data<T>();
      int64_t numel = x->numel();
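      // On GPU, launch the element-wise CUDA kernel; otherwise use the
      // generic Transform on the device context.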
      if (platform::is_gpu_place(context.GetPlace())) {
#ifdef __NVCC__
        int threads = 256;
        int blocks = (numel + threads - 1) / threads;
        ClipCudaKernel<T, ClipFunctor<T>><<<
            blocks, threads, 0,
            context.template device_context<platform::CUDADeviceContext>()
                .stream()>>>(x_data, out_data, numel, ClipFunctor<T>(min, max));
#endif
      } else {
        Transform<DeviceContext> trans;
        trans(context.template device_context<DeviceContext>(), x_data,
              x_data + numel, out_data, ClipFunctor<T>(min, max));
      }
    } else if (x_var->IsType<framework::SelectedRows>()) {
      auto* x = context.Input<framework::SelectedRows>("X");
      auto* out = context.Output<framework::SelectedRows>("Out");
      PADDLE_ENFORCE_NE(x, out, platform::errors::InvalidArgument(
                                    "Inplace clip is not allowed "
                                    "when x is SelectedRows"));
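      // Merge duplicate rows of the SelectedRows input, then clip the merged
      // values in place.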
      math::scatter::MergeAdd<DeviceContext, T> merge_func;
      merge_func(context.template device_context<DeviceContext>(), *x, out);
      auto* out_tensor = out->mutable_value();
      auto* out_data = out_tensor->data<T>();
      int64_t numel = out_tensor->numel();
      Transform<DeviceContext> trans;
      trans(context.template device_context<DeviceContext>(), out_data,
            out_data + numel, out_data, ClipFunctor<T>(min, max));
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "ClipOp only supports LoDTensor and SelectedRows."));
    }
  }
};

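// Backward kernel: propagates the gradient of "Out" to "X" only where the
// forward input was strictly inside (min, max).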
template <typename DeviceContext, typename T>
class ClipGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
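    // Resolve min/max the same way as in the forward kernel.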
    auto max = static_cast<T>(context.Attr<float>("max"));
    Tensor max_cpu;
    if (context.HasInput("Max")) {
      auto* max_t = context.Input<Tensor>("Max");
      auto* max_data = max_t->data<T>();
      if (platform::is_gpu_place(max_t->place())) {
        TensorCopySync(*max_t, platform::CPUPlace(), &max_cpu);
        max_data = max_cpu.data<T>();
      }
      max = max_data[0];
    }
    max = static_cast<T>(max);

    auto min = static_cast<T>(context.Attr<float>("min"));
    Tensor min_cpu;
    if (context.HasInput("Min")) {
      auto* min_t = context.Input<Tensor>("Min");
      auto* min_data = min_t->data<T>();
      if (platform::is_gpu_place(min_t->place())) {
        TensorCopySync(*min_t, platform::CPUPlace(), &min_cpu);
        min_data = min_cpu.data<T>();
      }
      min = min_data[0];
    }
    min = static_cast<T>(min);

    auto* d_out =
        context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
    auto* d_x =
        context.Output<framework::LoDTensor>(framework::GradVarName("X"));
    if (d_x != nullptr) {
      auto* x = context.Input<framework::LoDTensor>("X");
      int64_t numel = d_out->numel();
      auto* d_x_data = d_x->mutable_data<T>(context.GetPlace());
      const T* d_out_data = d_out->data<T>();
      const T* x_data = x->data<T>();
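      // Binary transform over (dOut, X): ClipGradFunctor keeps the gradient
      // where X is strictly inside (min, max) and zeroes it elsewhere.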
      Transform<DeviceContext> trans;
      trans(context.template device_context<DeviceContext>(), d_out_data,
            d_out_data + numel, x_data, d_x_data, ClipGradFunctor<T>(min, max));
    }
  }
};

}  // namespace operators
}  // namespace paddle