sequence_pooling.cu 14.4 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2 3 4 5 6 7 8 9 10 11 12 13 14

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15
#include <algorithm>
16
#include <string>
Y
Yi Wang 已提交
17 18
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/sequence_pooling.h"
D
dzhwinter 已提交
19
#include "paddle/fluid/platform/cuda_primitives.h"
P
peizhilin 已提交
20
#include "paddle/fluid/platform/macros.h"
21 22 23 24 25 26

namespace paddle {
namespace operators {
namespace math {

template <typename T>
struct MaxPoolFunctor {
  // Max pooling over rows [start, end) of `input`, which is laid out with
  // `item_dim` elements per row. Columns are spread over the block's threads
  // in steps of blockDim.x.
  //
  // For each column c:
  //   output[c] = largest input[r * item_dim + c] for r in [start, end)
  //   index[c]  = the row r that supplied it (first occurrence on ties)
  // An empty sequence (start == end) writes `pad_value` and index -1.
  //
  // The running maximum is seeded with the first row of the sequence rather
  // than a sentinel such as -FLT_MAX. A float sentinel is not the lowest value
  // of every T (e.g. double), and using it would report a wrong maximum and an
  // index of -1 for sequences whose values all lie below it.
  //
  // `index` may be null, in which case only `output` is written.
  // Because the comparison is `max_val < x`, a NaN is only reported when it
  // is the first element of the sequence.
  HOSTDEVICE void operator()(const T* input, const T pad_value,
                             const size_t start, const size_t end,
                             const size_t item_dim, T* output, int* index) {
    for (size_t tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
      if (start == end) {
        output[tid] = pad_value;
        if (index != nullptr) {
          index[tid] = -1;
        }
        continue;
      }
      T max_val = input[item_dim * start + tid];
      int max_index = static_cast<int>(start);
      for (size_t i = start + 1; i < end; ++i) {
        const T candidate = input[item_dim * i + tid];
        if (max_val < candidate) {
          max_val = candidate;
          max_index = static_cast<int>(i);
        }
      }
      output[tid] = max_val;
      if (index != nullptr) {
        index[tid] = max_index;
      }
    }
  }
};
50 51

template <typename T>
struct AvgPoolFunctor {
  // Mean pooling over rows [start, end) of `input` (item_dim elements per
  // row). Each thread handles columns threadIdx.x, threadIdx.x + blockDim.x,
  // and so on.
  //
  // An empty sequence (start == end) produces `pad_value`. `index` is not
  // used; it is accepted so the functor matches the call made by
  // sequence_pool_kernel.
  HOSTDEVICE void operator()(const T* input, const T pad_value,
                             const size_t start, const size_t end,
                             const size_t item_dim, T* output, int* index) {
    for (int col = threadIdx.x; col < item_dim; col += blockDim.x) {
      if (start == end) {
        output[col] = pad_value;
        continue;
      }
      T total = static_cast<T>(0);
      for (int row = start; row < end; ++row) {
        total += input[item_dim * row + col];
      }
      // The empty case was handled above, so the divisor is non-zero.
      output[col] = total / static_cast<T>(end - start);
    }
  }
};
70

D
dzhwinter 已提交
71 72
template <typename T>
struct SumPoolFunctor {
  // Sum pooling over rows [start, end) of `input` (item_dim elements per row).
  // The block's threads stride over the columns in steps of blockDim.x.
  //
  // An empty sequence (start == end) produces `pad_value`. `index` is unused.
  HOSTDEVICE void operator()(const T* input, const T pad_value,
                             const size_t start, const size_t end,
                             const size_t item_dim, T* output, int* index) {
    for (int col = threadIdx.x; col < item_dim; col += blockDim.x) {
      if (start != end) {
        T acc = static_cast<T>(0);
        for (int row = start; row < end; ++row) {
          acc += input[item_dim * row + col];
        }
        output[col] = acc;
      } else {
        output[col] = pad_value;
      }
    }
  }
};
89

D
dzhwinter 已提交
90 91
template <typename T>
struct SqrtPoolFunctor {
  // "SQRT" pooling over rows [start, end) of `input` (item_dim elements per
  // row). For each column it computes the sum divided by sqrt(end - start).
  // The block's threads stride over the columns in steps of blockDim.x.
  //
  // An empty sequence (start == end) produces `pad_value`. `index` is unused.
  //
  // The square root is taken of the length converted to T. This matches
  // SqrtPoolGradFunctor and keeps the float path in float precision; passing
  // the raw size_t to sqrt() promoted the division to double.
  HOSTDEVICE void operator()(const T* input, const T pad_value,
                             const size_t start, const size_t end,
                             const size_t item_dim, T* output, int* index) {
    for (size_t tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
      if (start == end) {
        output[tid] = pad_value;
        continue;
      }
      T val = static_cast<T>(0);
      for (size_t i = start; i < end; ++i) {
        val += input[item_dim * i + tid];
      }
      // The empty case was handled above, so the divisor is non-zero.
      output[tid] = val / sqrt(static_cast<T>(end - start));
    }
  }
};
109

D
dzhwinter 已提交
110 111
template <typename T>
struct LastPoolFunctor {
  // Copies the last row (end - 1) of the sequence [start, end) into `output`,
  // or writes `pad_value` when the sequence is empty. The block's threads
  // stride over the item_dim columns in steps of blockDim.x. `index` is unused.
  HOSTDEVICE void operator()(const T* input, const T pad_value,
                             const size_t start, const size_t end,
                             const size_t item_dim, T* output, int* index) {
    // Only dereferenced when start != end, i.e. when end >= 1.
    const size_t last_row = end - 1;
    for (int col = threadIdx.x; col < item_dim; col += blockDim.x) {
      output[col] =
          (start == end) ? pad_value : input[item_dim * last_row + col];
    }
  }
};

template <typename T>
struct FirstPoolFunctor {
  // Copies the first row (`start`) of the sequence [start, end) into
  // `output`, or writes `pad_value` when the sequence is empty. The block's
  // threads stride over the item_dim columns in steps of blockDim.x.
  // `index` is unused.
  HOSTDEVICE void operator()(const T* input, const T pad_value,
                             const size_t start, const size_t end,
                             const size_t item_dim, T* output, int* index) {
    for (int col = threadIdx.x; col < item_dim; col += blockDim.x) {
      const bool is_empty = (end == start);
      output[col] = is_empty ? pad_value : input[item_dim * start + col];
    }
  }
};

template <typename T, typename Range_OP>
// Forward sequence-pooling kernel: block b pools the input rows
// [lod[b], lod[b + 1]) with `op` and writes row b of `output` (item_dim
// elements). When `index` is non-null, row b of `index` is handed to `op`.
// Blocks with no matching sequence (b + 1 >= lod_size) exit immediately.
//
// The bound is written as `bid + 1 >= lod_size` instead of
// `bid >= lod_size - 1` so that an empty `lod` (lod_size == 0) does not wrap
// the unsigned subtraction and read lod[0] / lod[1] out of bounds.
__global__ void sequence_pool_kernel(Range_OP op, const T* input,
                                     const T pad_value, const size_t* lod,
                                     const size_t lod_size,
                                     const size_t item_dim, T* output,
                                     int* index) {
  const size_t bid = blockIdx.x;
  if (bid + 1 >= lod_size) return;
  const size_t start = lod[bid];
  const size_t end = lod[bid + 1];
  int* index_offset = nullptr;
  if (index != nullptr) {
    index_offset = &index[bid * item_dim];
  }
  op(input, pad_value, start, end, item_dim, &output[bid * item_dim],
     index_offset);
}

template <typename T>
class SequencePoolFunctor<platform::CUDADeviceContext, T> {
 public:
  // Pools each sequence of `input` into one row of `output`. Sequences are
  // delimited by the last LoD level of `input`. `pooltype` selects the
  // reduction: "MAX", "AVERAGE", "SUM", "SQRT", "LAST" or "FIRST". Empty
  // sequences are filled with `pad_value`.
  //
  // For "MAX", `index` must be non-null. It receives, per output element, the
  // input row that supplied the maximum (-1 for an empty sequence). `is_test`
  // is accepted for interface compatibility and is not used here.
  //
  // Kernels are enqueued on `context.stream()`; no synchronization is done.
  // Throws (PADDLE_THROW) when `input` has no LoD, the LoD level is empty,
  // MAX is requested without `index`, or `pooltype` is unknown.
  void operator()(const platform::CUDADeviceContext& context,
                  const std::string pooltype, T pad_value,
                  const framework::LoDTensor& input,
                  framework::LoDTensor* output, bool is_test,
                  framework::Tensor* index = nullptr) {
    auto lod_level = input.lod().size();
    if (lod_level == 0) {
      PADDLE_THROW("sequence pooling requires an input with LoD");
    }
    auto& lod = input.lod()[lod_level - 1];
    if (lod.size() == 0) {
      PADDLE_THROW("sequence pooling requires a non-empty LoD level");
    }
    // Guard the division so a zero-row output yields item_dim == 0, in which
    // case the kernels do no work.
    const auto num_rows = output->dims()[0];
    const size_t item_dim =
        num_rows > 0 ? static_cast<size_t>(output->numel() / num_rows) : 0;

    if (pooltype == "MAX") {
      if (index == nullptr) {
        PADDLE_THROW("MAX sequence pooling requires an index tensor");
      }
      Launch(context, MaxPoolFunctor<T>(), pad_value, input, lod, item_dim,
             output, index->data<int>());
    } else if (pooltype == "AVERAGE") {
      Launch(context, AvgPoolFunctor<T>(), pad_value, input, lod, item_dim,
             output, nullptr);
    } else if (pooltype == "SUM") {
      Launch(context, SumPoolFunctor<T>(), pad_value, input, lod, item_dim,
             output, nullptr);
    } else if (pooltype == "SQRT") {
      Launch(context, SqrtPoolFunctor<T>(), pad_value, input, lod, item_dim,
             output, nullptr);
    } else if (pooltype == "LAST") {
      Launch(context, LastPoolFunctor<T>(), pad_value, input, lod, item_dim,
             output, nullptr);
    } else if (pooltype == "FIRST") {
      Launch(context, FirstPoolFunctor<T>(), pad_value, input, lod, item_dim,
             output, nullptr);
    } else {
      PADDLE_THROW("unsupported pooling pooltype");
    }
  }

 private:
  // Enqueues sequence_pool_kernel<T, PoolOp> on the context's stream. The
  // launch uses 1024 threads per block and one block per sequence. At least
  // one block is always launched so the configuration stays valid when the
  // LoD describes no sequences.
  //
  // The block count is computed in size_t without the former `1UL` literal;
  // std::max(size_t, unsigned long) fails to deduce where size_t is not
  // `unsigned long` (e.g. 64-bit Windows).
  template <typename PoolOp, typename LoDVector>
  static void Launch(const platform::CUDADeviceContext& context, PoolOp op,
                     T pad_value, const framework::LoDTensor& input,
                     const LoDVector& lod, size_t item_dim,
                     framework::LoDTensor* output, int* index_data) {
    const size_t lod_size = lod.size();
    const size_t num_blocks = lod_size > 1 ? lod_size - 1 : 1;
    dim3 threads(1024, 1);
    dim3 grid(num_blocks, 1);
    sequence_pool_kernel<T, PoolOp><<<grid, threads, 0, context.stream()>>>(
        op, input.data<T>(), pad_value, lod.CUDAData(context.GetPlace()),
        lod_size, item_dim, output->mutable_data<T>(context.GetPlace()),
        index_data);
  }
};
212

D
dzhwinter 已提交
213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228
template <typename T>
struct MaxPoolGradFunctor {
  // Backward of MAX pooling for one sequence [start, end).
  //
  // For each column c (threads stride by blockDim.x), the row equal to
  // index[c] receives out_grad[c]. Every other row of the sequence in that
  // column receives zero.
  HOSTDEVICE void operator()(const T* out_grad, const size_t start,
                             const size_t end, const size_t item_dim,
                             T* in_grad, const int* index) {
    for (int col = threadIdx.x; col < item_dim; col += blockDim.x) {
      for (int row = start; row < end; ++row) {
        in_grad[item_dim * row + col] =
            (row == index[col]) ? out_grad[col] : static_cast<T>(0);
      }
    }
  }
};
229

D
dzhwinter 已提交
230 231 232 233 234 235 236 237 238 239 240 241
template <typename T>
struct AvgPoolGradFunctor {
  // Backward of AVERAGE pooling: every row of the sequence [start, end)
  // receives out_grad[c] divided by the sequence length. Threads stride over
  // the item_dim columns in steps of blockDim.x. `index` is unused.
  HOSTDEVICE void operator()(const T* out_grad, const size_t start,
                             const size_t end, const size_t item_dim,
                             T* in_grad, const int* index) {
    const size_t seq_len = end - start;
    for (int col = threadIdx.x; col < item_dim; col += blockDim.x) {
      for (int row = start; row < end; ++row) {
        in_grad[item_dim * row + col] = out_grad[col] / seq_len;
      }
    }
  }
};
242

D
dzhwinter 已提交
243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324
template <typename T>
struct SumPoolGradFunctor {
  // Backward of SUM pooling: out_grad[c] is copied unchanged into column c of
  // every row of the sequence [start, end). Threads stride over the columns
  // in steps of blockDim.x. `index` is unused.
  HOSTDEVICE void operator()(const T* out_grad, const size_t start,
                             const size_t end, const size_t item_dim,
                             T* in_grad, const int* index) {
    for (int row = start; row < end; ++row) {
      T* in_row = in_grad + item_dim * row;
      for (int col = threadIdx.x; col < item_dim; col += blockDim.x) {
        in_row[col] = out_grad[col];
      }
    }
  }
};

template <typename T>
struct SqrtPoolGradFunctor {
  // Backward of SQRT pooling: every row of the sequence [start, end)
  // receives out_grad[c] / sqrt(end - start). The root is taken in T
  // precision. Threads stride over the columns in steps of blockDim.x.
  // `index` is unused.
  HOSTDEVICE void operator()(const T* out_grad, const size_t start,
                             const size_t end, const size_t item_dim,
                             T* in_grad, const int* index) {
    // Loop-invariant; only used when the sequence is non-empty.
    const T norm = sqrt(static_cast<T>(end - start));
    for (int col = threadIdx.x; col < item_dim; col += blockDim.x) {
      for (int row = start; row < end; ++row) {
        in_grad[item_dim * row + col] = out_grad[col] / (norm);
      }
    }
  }
};

template <typename T>
struct LastPoolGradFunctor {
  // Backward of LAST pooling: row end - 1 receives out_grad[c]; the other
  // rows of [start, end) receive zero. Threads stride over the columns in
  // steps of blockDim.x. `index` is unused.
  HOSTDEVICE void operator()(const T* out_grad, const size_t start,
                             const size_t end, const size_t item_dim,
                             T* in_grad, const int* index) {
    // Only compared against when the row loop runs, i.e. when end > start.
    const size_t last_row = end - 1;
    for (int col = threadIdx.x; col < item_dim; col += blockDim.x) {
      for (int row = start; row < end; ++row) {
        in_grad[item_dim * row + col] =
            (row == last_row) ? out_grad[col] : static_cast<T>(0);
      }
    }
  }
};

template <typename T>
struct FirstPoolGradFunctor {
  // Backward of FIRST pooling: row `start` receives out_grad[c]; the other
  // rows of [start, end) receive zero. Threads stride over the columns in
  // steps of blockDim.x. `index` is unused.
  HOSTDEVICE void operator()(const T* out_grad, const size_t start,
                             const size_t end, const size_t item_dim,
                             T* in_grad, const int* index) {
    for (int col = threadIdx.x; col < item_dim; col += blockDim.x) {
      for (int row = start; row < end; ++row) {
        const bool is_first = (row == start);
        in_grad[item_dim * row + col] =
            is_first ? out_grad[col] : static_cast<T>(0);
      }
    }
  }
};

template <typename T, typename Range_OP>
// Backward sequence-pooling kernel: block b spreads row b of `out_grad`
// (item_dim elements) over the input-gradient rows [lod[b], lod[b + 1]) using
// `op`. When `index` is non-null, row b of `index` is handed to `op`. Blocks
// with no matching sequence (b + 1 >= lod_size) exit immediately.
//
// The bound is written as `bid + 1 >= lod_size` instead of
// `bid >= lod_size - 1` so that an empty `lod` (lod_size == 0) does not wrap
// the unsigned subtraction and read lod[0] / lod[1] out of bounds.
__global__ void sequence_pool_grad_kernel(Range_OP op, const T* out_grad,
                                          const size_t* lod,
                                          const size_t lod_size,
                                          const size_t item_dim, T* in_grad,
                                          const int* index) {
  const size_t bid = blockIdx.x;
  if (bid + 1 >= lod_size) return;
  const size_t start = lod[bid];
  const size_t end = lod[bid + 1];
  const int* index_offset = nullptr;
  if (index != nullptr) {
    index_offset = &index[bid * item_dim];
  }
  op(&out_grad[bid * item_dim], start, end, item_dim, in_grad, index_offset);
}

template <typename T>
class SequencePoolGradFunctor<platform::CUDADeviceContext, T> {
 public:
  // Computes the input gradient of sequence pooling. Each row of `out_grad`
  // is distributed over the rows of its sequence in `in_grad`, according to
  // `pooltype` ("MAX", "AVERAGE", "SUM", "SQRT", "LAST" or "FIRST").
  // Sequences are delimited by the last LoD level of `in_grad`.
  //
  // For "MAX", `index` must be non-null. It must hold, per output element,
  // the input row selected in the forward pass.
  //
  // Kernels are enqueued on `context.stream()`; no synchronization is done.
  // Throws (PADDLE_THROW) when `in_grad` has no LoD, the LoD level is empty,
  // MAX is requested without `index`, or `pooltype` is unknown.
  void operator()(const platform::CUDADeviceContext& context,
                  const std::string pooltype,
                  const framework::LoDTensor& out_grad,
                  framework::LoDTensor* in_grad,
                  /* max pool has index */
                  const framework::Tensor* index = nullptr) {
    auto lod_level = in_grad->lod().size();
    if (lod_level == 0) {
      PADDLE_THROW("sequence pooling gradient requires in_grad with LoD");
    }
    auto& lod = in_grad->lod()[lod_level - 1];
    if (lod.size() == 0) {
      PADDLE_THROW("sequence pooling gradient requires a non-empty LoD level");
    }
    // Guard the division so a zero-row in_grad yields item_dim == 0, in which
    // case the kernels do no work.
    const auto num_rows = in_grad->dims()[0];
    const size_t item_dim =
        num_rows > 0 ? static_cast<size_t>(in_grad->numel() / num_rows) : 0;

    if (pooltype == "MAX") {
      if (index == nullptr) {
        PADDLE_THROW("MAX sequence pooling gradient requires an index tensor");
      }
      Launch(context, MaxPoolGradFunctor<T>(), out_grad, lod, item_dim,
             in_grad, index->data<int>());
    } else if (pooltype == "AVERAGE") {
      Launch(context, AvgPoolGradFunctor<T>(), out_grad, lod, item_dim,
             in_grad, nullptr);
    } else if (pooltype == "SUM") {
      Launch(context, SumPoolGradFunctor<T>(), out_grad, lod, item_dim,
             in_grad, nullptr);
    } else if (pooltype == "SQRT") {
      Launch(context, SqrtPoolGradFunctor<T>(), out_grad, lod, item_dim,
             in_grad, nullptr);
    } else if (pooltype == "LAST") {
      Launch(context, LastPoolGradFunctor<T>(), out_grad, lod, item_dim,
             in_grad, nullptr);
    } else if (pooltype == "FIRST") {
      Launch(context, FirstPoolGradFunctor<T>(), out_grad, lod, item_dim,
             in_grad, nullptr);
    } else {
      PADDLE_THROW("unsupported pooling pooltype");
    }
  }

 private:
  // Enqueues sequence_pool_grad_kernel<T, PoolGradOp> on the context's
  // stream. The launch uses 1024 threads per block and one block per
  // sequence. At least one block is always launched so the configuration
  // stays valid when the LoD describes no sequences.
  //
  // The block count is computed in size_t without the former `1UL` literal;
  // std::max(size_t, unsigned long) fails to deduce where size_t is not
  // `unsigned long` (e.g. 64-bit Windows).
  template <typename PoolGradOp, typename LoDVector>
  static void Launch(const platform::CUDADeviceContext& context,
                     PoolGradOp op, const framework::LoDTensor& out_grad,
                     const LoDVector& lod, size_t item_dim,
                     framework::LoDTensor* in_grad, const int* index_data) {
    const size_t lod_size = lod.size();
    const size_t num_blocks = lod_size > 1 ? lod_size - 1 : 1;
    dim3 threads(1024, 1);
    dim3 grid(num_blocks, 1);
    sequence_pool_grad_kernel<T, PoolGradOp><<<grid, threads, 0,
                                               context.stream()>>>(
        op, out_grad.data<T>(), lod.CUDAData(context.GetPlace()), lod_size,
        item_dim, in_grad->mutable_data<T>(context.GetPlace()), index_data);
  }
};

D
dzhwinter 已提交
378 379 380 381 382
// Explicit instantiations of the CUDA sequence-pooling functors, forward and
// gradient, for the element types supported here (float, double).
template class SequencePoolFunctor<platform::CUDADeviceContext, float>;
template class SequencePoolGradFunctor<platform::CUDADeviceContext, float>;
template class SequencePoolFunctor<platform::CUDADeviceContext, double>;
template class SequencePoolGradFunctor<platform::CUDADeviceContext, double>;
383 384 385 386

}  // namespace math
}  // namespace operators
}  // namespace paddle