sequence_pooling.cu 14.2 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15
#include <string>
Y
Yi Wang 已提交
16 17
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/sequence_pooling.h"
18
#include "paddle/fluid/platform/cuda_primitives.h"
P
peizhilin 已提交
19
#include "paddle/fluid/platform/macros.h"
20 21 22 23 24 25

namespace paddle {
namespace operators {
namespace math {

// Column-wise maximum over the rows [start, end) of `input`.
//
// `input` is indexed as input[item_dim * row + col]. Each calling thread
// handles the columns threadIdx.x, threadIdx.x + blockDim.x, ... below
// item_dim.
//
//   input      rows to pool, item_dim values per row.
//   pad_value  value written to `output` when the range is empty.
//   start,end  half-open row range of the sequence being pooled.
//   item_dim   number of columns per row.
//   output     item_dim results for this sequence.
//   index      item_dim row indices of the selected maxima (-1 for an empty
//              range); may be null, in which case no indices are stored.
template <typename T>
struct MaxPoolFunctor {
  HOSTDEVICE void operator()(const T* input, const T pad_value,
                             const size_t start, const size_t end,
                             const size_t item_dim, T* output, int* index) {
    for (size_t tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
      if (start == end) {
        output[tid] = pad_value;
        if (index != nullptr) index[tid] = -1;
        continue;
      }
      // Seed the running maximum with the first row instead of -FLT_MAX:
      // -FLT_MAX is not the lowest representable value when T is double, so
      // a column holding only values at or below it would otherwise report
      // index -1 and a maximum that is not present in the input.
      T max_val = input[item_dim * start + tid];
      size_t max_row = start;
      for (size_t i = start + 1; i < end; ++i) {
        const T candidate = input[item_dim * i + tid];
        // Strict '<' keeps the earliest row on ties.
        if (max_val < candidate) {
          max_val = candidate;
          max_row = i;
        }
      }
      output[tid] = max_val;
      if (index != nullptr) index[tid] = static_cast<int>(max_row);
    }
  }
};
49 50

// Column-wise arithmetic mean over the rows [start, end) of `input`, where
// element (row, col) lives at input[item_dim * row + col].
//
// Each calling thread covers the columns threadIdx.x,
// threadIdx.x + blockDim.x, ... below item_dim. An empty range
// (start == end) produces `pad_value` in every column. `index` is unused.
template <typename T>
struct AvgPoolFunctor {
  HOSTDEVICE void operator()(const T* input, const T pad_value,
                             const size_t start, const size_t end,
                             const size_t item_dim, T* output, int* index) {
    const bool empty_sequence = (start == end);
    for (int col = threadIdx.x; col < item_dim; col += blockDim.x) {
      if (empty_sequence) {
        output[col] = pad_value;
        continue;
      }
      T total = static_cast<T>(0);
      for (int row = start; row < end; ++row) {
        total += input[item_dim * row + col];
      }
      // The empty case was handled above, so the divisor is non-zero here.
      output[col] = total / static_cast<T>(end - start);
    }
  }
};
69

D
dzhwinter 已提交
70 71
// Column-wise sum over the rows [start, end) of `input`, where element
// (row, col) lives at input[item_dim * row + col].
//
// Each calling thread covers the columns threadIdx.x,
// threadIdx.x + blockDim.x, ... below item_dim. An empty range
// (start == end) produces `pad_value` in every column. `index` is unused.
template <typename T>
struct SumPoolFunctor {
  HOSTDEVICE void operator()(const T* input, const T pad_value,
                             const size_t start, const size_t end,
                             const size_t item_dim, T* output, int* index) {
    const bool empty_sequence = (start == end);
    for (int col = threadIdx.x; col < item_dim; col += blockDim.x) {
      if (empty_sequence) {
        output[col] = pad_value;
        continue;
      }
      T total = static_cast<T>(0);
      for (int row = start; row < end; ++row) {
        total += input[item_dim * row + col];
      }
      output[col] = total;
    }
  }
};
88

D
dzhwinter 已提交
89 90
// Column-wise sum over the rows [start, end) of `input`, divided by the
// square root of the number of rows. Element (row, col) lives at
// input[item_dim * row + col].
//
// Each calling thread covers the columns threadIdx.x,
// threadIdx.x + blockDim.x, ... below item_dim. An empty range
// (start == end) produces `pad_value` in every column. `index` is unused.
template <typename T>
struct SqrtPoolFunctor {
  HOSTDEVICE void operator()(const T* input, const T pad_value,
                             const size_t start, const size_t end,
                             const size_t item_dim, T* output, int* index) {
    for (size_t tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
      if (start == end) {
        output[tid] = pad_value;
      } else {
        T val = static_cast<T>(0);
        for (size_t i = start; i < end; ++i) {
          val += input[item_dim * i + tid];
        }
        // The empty case was handled above, so end - start != 0 here.
        // Cast the length to T before sqrt so the arithmetic stays in T
        // (an integer argument would select the double overload even when
        // T is float) and matches SqrtPoolGradFunctor.
        output[tid] = val / sqrt(static_cast<T>(end - start));
      }
    }
  }
};
108

D
dzhwinter 已提交
109 110
// Copies the last row (row end - 1) of the range [start, end) of `input`
// into `output`. Element (row, col) lives at input[item_dim * row + col].
//
// Each calling thread covers the columns threadIdx.x,
// threadIdx.x + blockDim.x, ... below item_dim. An empty range
// (start == end) produces `pad_value` in every column. `index` is unused.
template <typename T>
struct LastPoolFunctor {
  HOSTDEVICE void operator()(const T* input, const T pad_value,
                             const size_t start, const size_t end,
                             const size_t item_dim, T* output, int* index) {
    const bool empty_sequence = (start == end);
    for (int col = threadIdx.x; col < item_dim; col += blockDim.x) {
      // The input element is only read when the sequence is non-empty.
      output[col] =
          empty_sequence ? pad_value : input[item_dim * (end - 1) + col];
    }
  }
};

// Copies the first row (row start) of the range [start, end) of `input`
// into `output`. Element (row, col) lives at input[item_dim * row + col].
//
// Each calling thread covers the columns threadIdx.x,
// threadIdx.x + blockDim.x, ... below item_dim. An empty range
// (start == end) produces `pad_value` in every column. `index` is unused.
template <typename T>
struct FirstPoolFunctor {
  HOSTDEVICE void operator()(const T* input, const T pad_value,
                             const size_t start, const size_t end,
                             const size_t item_dim, T* output, int* index) {
    const bool empty_sequence = (start == end);
    for (int col = threadIdx.x; col < item_dim; col += blockDim.x) {
      // The input element is only read when the sequence is non-empty.
      output[col] = empty_sequence ? pad_value : input[item_dim * start + col];
    }
  }
};

// Forward sequence-pooling kernel: block `blockIdx.x` pools sequence
// number `blockIdx.x`.
//
//   op         pooling functor invoked as
//              op(input, pad_value, start, end, item_dim, out_row, index_row).
//   input      rows to pool; passed to `op` unchanged.
//   pad_value  forwarded to `op`.
//   lod        lod_size offsets; sequence b spans rows [lod[b], lod[b + 1]).
//   lod_size   number of entries in `lod` (number of sequences + 1).
//   item_dim   number of values per row.
//   output     one row of item_dim values per sequence.
//   index      optional (may be null); one row of item_dim ints per sequence.
//
// Blocks whose id does not name a sequence return immediately, so the grid
// may be larger than the number of sequences.
template <typename T, typename Range_OP>
__global__ void sequence_pool_kernel(Range_OP op, const T* input,
                                     const T pad_value, const size_t* lod,
                                     const size_t lod_size,
                                     const size_t item_dim, T* output,
                                     int* index) {
  const size_t bid = blockIdx.x;
  // Written as `bid + 1 >= lod_size` rather than `bid >= lod_size - 1`:
  // lod_size is unsigned, so `lod_size - 1` would wrap around when
  // lod_size == 0 and let every block read past the end of `lod`.
  if (bid + 1 >= lod_size) return;
  const size_t start = lod[bid];
  const size_t end = lod[bid + 1];
  int* index_offset = nullptr;
  if (index != nullptr) {
    index_offset = &index[bid * item_dim];
  }
  op(input, pad_value, start, end, item_dim, &output[bid * item_dim],
     index_offset);
}

// CUDA implementation of forward sequence pooling.
//
// Pools each sequence described by the first LoD level of `input` into one
// row of `output` by launching sequence_pool_kernel on `context.stream()`.
//
//   context    device context supplying the stream and the place.
//   pooltype   one of "MAX", "AVERAGE", "SUM", "SQRT", "LAST", "FIRST";
//              any other value raises via PADDLE_THROW.
//   pad_value  value written for empty sequences.
//   input      tensor to pool; input.lod()[0] holds the sequence offsets.
//   output     result tensor; its memory is obtained with mutable_data.
//   is_test    accepted for interface compatibility; not used here.
//   index      required for "MAX" (receives the row index of each maximum);
//              ignored by every other pool type.
template <typename T>
class SequencePoolFunctor<platform::CUDADeviceContext, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const std::string pooltype, T pad_value,
                  const framework::LoDTensor& input, framework::Tensor* output,
                  bool is_test, framework::Tensor* index = nullptr) {
    if (pooltype == "MAX") {
      // `index` defaults to nullptr; fail with a clear message instead of
      // dereferencing a null pointer.
      if (index == nullptr) {
        PADDLE_THROW("MAX sequence pooling requires a non-null index tensor");
      }
      LaunchKernel<MaxPoolFunctor<T>>(context, pad_value, input, output,
                                      index->data<int>());
    } else if (pooltype == "AVERAGE") {
      LaunchKernel<AvgPoolFunctor<T>>(context, pad_value, input, output,
                                      nullptr);
    } else if (pooltype == "SUM") {
      LaunchKernel<SumPoolFunctor<T>>(context, pad_value, input, output,
                                      nullptr);
    } else if (pooltype == "SQRT") {
      LaunchKernel<SqrtPoolFunctor<T>>(context, pad_value, input, output,
                                       nullptr);
    } else if (pooltype == "LAST") {
      LaunchKernel<LastPoolFunctor<T>>(context, pad_value, input, output,
                                       nullptr);
    } else if (pooltype == "FIRST") {
      LaunchKernel<FirstPoolFunctor<T>>(context, pad_value, input, output,
                                        nullptr);
    } else {
      PADDLE_THROW("unsupported pooling pooltype");
    }
  }

 private:
  // Launches sequence_pool_kernel with pooling functor `PoolOp`, using the
  // same launch configuration for every pool type.
  template <typename PoolOp>
  void LaunchKernel(const platform::CUDADeviceContext& context, T pad_value,
                    const framework::LoDTensor& input,
                    framework::Tensor* output, int* index_data) {
    auto& lod = input.lod()[0];
    const size_t item_dim = output->numel() / output->dims()[0];
    dim3 threads(1024, 1);
    // One block per LoD entry; the kernel itself skips the block that does
    // not correspond to a sequence.
    dim3 grid(lod.size(), 1);
    sequence_pool_kernel<T, PoolOp><<<grid, threads, 0, context.stream()>>>(
        PoolOp(), input.data<T>(), pad_value,
        lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
        output->mutable_data<T>(context.GetPlace()), index_data);
  }
};
209

D
dzhwinter 已提交
210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225
// Backward pass of max pooling for one sequence.
//
// For every row in [start, end) and every column handled by the calling
// thread (threadIdx.x, threadIdx.x + blockDim.x, ... below item_dim),
// in_grad[item_dim * row + col] receives out_grad[col] when `row` equals
// index[col] and zero otherwise.
template <typename T>
struct MaxPoolGradFunctor {
  HOSTDEVICE void operator()(const T* out_grad, const size_t start,
                             const size_t end, const size_t item_dim,
                             T* in_grad, const int* index) {
    for (int col = threadIdx.x; col < item_dim; col += blockDim.x) {
      for (int row = start; row < end; ++row) {
        const bool is_argmax = (row == index[col]);
        in_grad[item_dim * row + col] =
            is_argmax ? out_grad[col] : static_cast<T>(0);
      }
    }
  }
};
226

D
dzhwinter 已提交
227 228 229 230 231 232 233 234 235 236 237 238
// Backward pass of average pooling for one sequence.
//
// For every row in [start, end) and every column handled by the calling
// thread (threadIdx.x, threadIdx.x + blockDim.x, ... below item_dim),
// in_grad[item_dim * row + col] receives out_grad[col] divided by the
// number of rows in the sequence. `index` is unused.
template <typename T>
struct AvgPoolGradFunctor {
  HOSTDEVICE void operator()(const T* out_grad, const size_t start,
                             const size_t end, const size_t item_dim,
                             T* in_grad, const int* index) {
    const size_t seq_len = end - start;
    for (int col = threadIdx.x; col < item_dim; col += blockDim.x) {
      for (int row = start; row < end; ++row) {
        in_grad[item_dim * row + col] = out_grad[col] / seq_len;
      }
    }
  }
};
239

D
dzhwinter 已提交
240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325
// Backward pass of sum pooling for one sequence.
//
// For every row in [start, end) and every column handled by the calling
// thread (threadIdx.x, threadIdx.x + blockDim.x, ... below item_dim),
// in_grad[item_dim * row + col] receives out_grad[col]. `index` is unused.
template <typename T>
struct SumPoolGradFunctor {
  HOSTDEVICE void operator()(const T* out_grad, const size_t start,
                             const size_t end, const size_t item_dim,
                             T* in_grad, const int* index) {
    for (int col = threadIdx.x; col < item_dim; col += blockDim.x) {
      const T grad = out_grad[col];
      for (int row = start; row < end; ++row) {
        in_grad[item_dim * row + col] = grad;
      }
    }
  }
};

// Backward pass of sqrt pooling for one sequence.
//
// For every row in [start, end) and every column handled by the calling
// thread (threadIdx.x, threadIdx.x + blockDim.x, ... below item_dim),
// in_grad[item_dim * row + col] receives out_grad[col] divided by the
// square root of the number of rows in the sequence. `index` is unused.
template <typename T>
struct SqrtPoolGradFunctor {
  HOSTDEVICE void operator()(const T* out_grad, const size_t start,
                             const size_t end, const size_t item_dim,
                             T* in_grad, const int* index) {
    // The scale depends only on the sequence length, so compute it once.
    const auto scale = sqrt(static_cast<T>(end - start));
    for (int col = threadIdx.x; col < item_dim; col += blockDim.x) {
      for (int row = start; row < end; ++row) {
        in_grad[item_dim * row + col] = out_grad[col] / scale;
      }
    }
  }
};

// Backward pass of last-row pooling for one sequence.
//
// For every row in [start, end) and every column handled by the calling
// thread (threadIdx.x, threadIdx.x + blockDim.x, ... below item_dim),
// in_grad[item_dim * row + col] receives out_grad[col] on row end - 1 and
// zero on all other rows. `index` is unused.
template <typename T>
struct LastPoolGradFunctor {
  HOSTDEVICE void operator()(const T* out_grad, const size_t start,
                             const size_t end, const size_t item_dim,
                             T* in_grad, const int* index) {
    for (int col = threadIdx.x; col < item_dim; col += blockDim.x) {
      for (int row = start; row < end; ++row) {
        const bool is_last_row = (row == end - 1);
        in_grad[item_dim * row + col] =
            is_last_row ? out_grad[col] : static_cast<T>(0);
      }
    }
  }
};

// Backward pass of first-row pooling for one sequence.
//
// For every row in [start, end) and every column handled by the calling
// thread (threadIdx.x, threadIdx.x + blockDim.x, ... below item_dim),
// in_grad[item_dim * row + col] receives out_grad[col] on row `start` and
// zero on all other rows. `index` is unused.
template <typename T>
struct FirstPoolGradFunctor {
  HOSTDEVICE void operator()(const T* out_grad, const size_t start,
                             const size_t end, const size_t item_dim,
                             T* in_grad, const int* index) {
    for (int col = threadIdx.x; col < item_dim; col += blockDim.x) {
      for (int row = start; row < end; ++row) {
        const bool is_first_row = (row == start);
        in_grad[item_dim * row + col] =
            is_first_row ? out_grad[col] : static_cast<T>(0);
      }
    }
  }
};

// Backward sequence-pooling kernel: block `blockIdx.x` computes the input
// gradient of sequence number `blockIdx.x`.
//
//   op         gradient functor invoked as
//              op(out_grad_row, start, end, item_dim, in_grad, index_row).
//   out_grad   one row of item_dim values per sequence.
//   lod        lod_size offsets; sequence b spans rows [lod[b], lod[b + 1]).
//   lod_size   number of entries in `lod` (number of sequences + 1).
//   item_dim   number of values per row.
//   in_grad    gradient rows to fill; passed to `op` unchanged.
//   index      optional (may be null); one row of item_dim ints per sequence.
//
// Blocks whose id does not name a sequence return immediately, so the grid
// may be larger than the number of sequences.
template <typename T, typename Range_OP>
__global__ void sequence_pool_grad_kernel(Range_OP op, const T* out_grad,
                                          const size_t* lod,
                                          const size_t lod_size,
                                          const size_t item_dim, T* in_grad,
                                          const int* index) {
  const size_t bid = blockIdx.x;
  // Written as `bid + 1 >= lod_size` rather than `bid >= lod_size - 1`:
  // lod_size is unsigned, so `lod_size - 1` would wrap around when
  // lod_size == 0 and let every block read past the end of `lod`.
  if (bid + 1 >= lod_size) return;
  const size_t start = lod[bid];
  const size_t end = lod[bid + 1];
  const int* index_offset = nullptr;
  if (index != nullptr) {
    index_offset = &index[bid * item_dim];
  }
  op(&out_grad[bid * item_dim], start, end, item_dim, in_grad, index_offset);
}

// CUDA implementation of the backward pass of sequence pooling.
//
// Fills `in_grad` from `out_grad` for every sequence described by the first
// LoD level of `in_grad` by launching sequence_pool_grad_kernel on
// `context.stream()`.
//
//   context   device context supplying the stream and the place.
//   pooltype  one of "MAX", "AVERAGE", "SUM", "SQRT", "LAST", "FIRST";
//             any other value raises via PADDLE_THROW.
//   out_grad  gradient with respect to the pooled output.
//   in_grad   gradient with respect to the input; in_grad->lod()[0] holds
//             the sequence offsets and its memory is obtained with
//             mutable_data.
//   index     required for "MAX" (row index of each maximum); ignored by
//             every other pool type.
template <typename T>
class SequencePoolGradFunctor<platform::CUDADeviceContext, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const std::string pooltype, const framework::Tensor& out_grad,
                  framework::LoDTensor* in_grad,
                  /* max pool has index */
                  const framework::Tensor* index = nullptr) {
    if (pooltype == "MAX") {
      // `index` defaults to nullptr; fail with a clear message instead of
      // dereferencing a null pointer.
      if (index == nullptr) {
        PADDLE_THROW(
            "MAX sequence pooling gradient requires a non-null index tensor");
      }
      LaunchKernel<MaxPoolGradFunctor<T>>(context, out_grad, in_grad,
                                          index->data<int>());
    } else if (pooltype == "AVERAGE") {
      LaunchKernel<AvgPoolGradFunctor<T>>(context, out_grad, in_grad, nullptr);
    } else if (pooltype == "SUM") {
      LaunchKernel<SumPoolGradFunctor<T>>(context, out_grad, in_grad, nullptr);
    } else if (pooltype == "SQRT") {
      LaunchKernel<SqrtPoolGradFunctor<T>>(context, out_grad, in_grad,
                                           nullptr);
    } else if (pooltype == "LAST") {
      LaunchKernel<LastPoolGradFunctor<T>>(context, out_grad, in_grad,
                                           nullptr);
    } else if (pooltype == "FIRST") {
      LaunchKernel<FirstPoolGradFunctor<T>>(context, out_grad, in_grad,
                                            nullptr);
    } else {
      PADDLE_THROW("unsupported pooling pooltype");
    }
  }

 private:
  // Launches sequence_pool_grad_kernel with gradient functor `GradOp`, using
  // the same launch configuration for every pool type.
  template <typename GradOp>
  void LaunchKernel(const platform::CUDADeviceContext& context,
                    const framework::Tensor& out_grad,
                    framework::LoDTensor* in_grad, const int* index_data) {
    auto& lod = in_grad->lod()[0];
    const size_t item_dim = in_grad->numel() / in_grad->dims()[0];
    dim3 threads(1024, 1);
    // One block per LoD entry; the kernel itself skips the block that does
    // not correspond to a sequence.
    dim3 grid(lod.size(), 1);
    sequence_pool_grad_kernel<T,
                              GradOp><<<grid, threads, 0, context.stream()>>>(
        GradOp(), out_grad.data<T>(), lod.CUDAData(context.GetPlace()),
        lod.size(), item_dim, in_grad->mutable_data<T>(context.GetPlace()),
        index_data);
  }
};

D
dzhwinter 已提交
373 374 375 376 377
// Explicit instantiations of the CUDA sequence pooling functors (forward and
// gradient) for float and double.
template class SequencePoolFunctor<platform::CUDADeviceContext, float>;
template class SequencePoolFunctor<platform::CUDADeviceContext, double>;
template class SequencePoolGradFunctor<platform::CUDADeviceContext, float>;
template class SequencePoolGradFunctor<platform::CUDADeviceContext, double>;
378 379 380 381

}  // namespace math
}  // namespace operators
}  // namespace paddle
反馈
建议
客服 返回
顶部