/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

A
Abhinav Arora 已提交
15
#include <algorithm>
#include <cmath>
#include <cstring>
#include <string>

#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/operators/math/sequence_pooling.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h"
21 22 23 24 25

namespace paddle {
namespace operators {
namespace math {

D
dzhwinter 已提交
26 27 28 29 30 31 32 33 34
// Local shorthands for the framework tensor types and the Eigen views over
// them that the pooling functors below operate on.
using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;
template <typename Scalar, int Layout = Eigen::RowMajor,
          typename Index = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<Scalar, Layout, Index>;
template <typename Scalar, int Layout = Eigen::RowMajor,
          typename Index = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<Scalar, Layout, Index>;

J
Jacek Czaja 已提交
35
template <typename T, bool is_test>
D
dzhwinter 已提交
36
class MaxSeqPoolFunctor {
37
 public:
Q
QI JUN 已提交
38
  void operator()(const platform::CPUDeviceContext& context,
39
                  const framework::LoDTensor& input, T pad_value,
40
                  framework::LoDTensor* output, framework::Tensor* index) {
41 42 43
    auto in_dims = input.dims();
    auto out_dims = output->dims();
    auto idx_dims = index->dims();
44
    PADDLE_ENFORCE_GT(in_dims.size(), 1,
45 46 47 48
                      platform::errors::InvalidArgument(
                          "The rank of input shall be greater than 1, but got "
                          "the rank is %ld. Please check the input value",
                          in_dims.size()));
49
    PADDLE_ENFORCE_GT(out_dims.size(), 1,
50 51 52 53
                      platform::errors::InvalidArgument(
                          "The rank of output shall be greater than 1, but got "
                          "the rank is %ld. Please check the input value",
                          out_dims.size()));
D
dangqingqing 已提交
54
    for (int64_t i = 1; i < in_dims.size(); ++i) {
55 56 57 58 59 60
      PADDLE_ENFORCE_EQ(
          in_dims[i], out_dims[i],
          platform::errors::InvalidArgument(
              "The dimension of input and output shall be same. Expected %ld "
              "== %ld, but got %ld != %ld. Please check the input value.",
              in_dims[i], out_dims[i], in_dims[i], out_dims[i]));
61
    }
62 63 64 65 66 67
    PADDLE_ENFORCE_EQ(
        idx_dims, out_dims,
        platform::errors::InvalidArgument(
            "The dimension of index and output shall be same. Expected %ld == "
            "%ld, but got %ld != %ld. Please check the input value.",
            idx_dims, out_dims, idx_dims, out_dims));
68

69 70
    auto lod_level = input.lod().size();
    auto starts = input.lod()[lod_level - 1];
71 72 73 74 75 76 77
    const T* in_data = input.data<T>();
    T* out_data = output->data<T>();
    int* max_index = index->data<int>();

    int64_t num_seq = out_dims[0];
    int64_t dim = output->numel() / num_seq;
    for (int64_t i = 0; i < num_seq; ++i) {
78 79 80 81 82 83 84
      if (starts[i] == starts[i + 1]) {
        for (int64_t k = 0; k < dim; ++k) {
          out_data[i * dim + k] = pad_value;
          max_index[i * dim + k] = -1;
        }
        continue;
      }
85 86 87 88 89 90 91 92 93 94 95 96 97 98 99
      for (int64_t k = 0; k < dim; ++k) {
        out_data[i * dim + k] = in_data[starts[i] * dim + k];
        max_index[i * dim + k] = starts[i];
      }
      for (size_t j = starts[i] + 1; j < starts[i + 1]; ++j) {
        for (int64_t k = 0; k < dim; ++k) {
          if (in_data[j * dim + k] > out_data[i * dim + k]) {
            out_data[i * dim + k] = in_data[j * dim + k];
            max_index[i * dim + k] = j;
          }
        }
      }
    }
  }
};
J
Jacek Czaja 已提交
100 101 102 103 104 105
// Specialization of max sequence pooling for the test (inference) phase,
// i.e. there is no need to fill the index buffer: only the pooled values are
// written and `index` is never accessed.
template <typename T>
class MaxSeqPoolFunctor<T, true> {
 public:
  // For every sequence delimited by the last LoD level of `input`, writes the
  // element-wise maximum over its rows into one row of `output`. Empty
  // sequences produce a row filled with `pad_value`.
  void operator()(const platform::CPUDeviceContext& context,
                  const framework::LoDTensor& input, T pad_value,
                  framework::LoDTensor* output, framework::Tensor* index) {
    auto in_dims = input.dims();
    auto out_dims = output->dims();
    PADDLE_ENFORCE_GT(in_dims.size(), 1,
                      platform::errors::InvalidArgument(
                          "The rank of input shall be greater than 1, but got "
                          "%ld <= 1. Please check the input value.",
                          in_dims.size()));
    PADDLE_ENFORCE_GT(out_dims.size(), 1,
                      platform::errors::InvalidArgument(
                          "The rank of output shall be greater than 1, but got "
                          "%ld <= 1. Please check the input value.",
                          out_dims.size()));
    for (int64_t i = 1; i < in_dims.size(); ++i) {
      PADDLE_ENFORCE_EQ(
          in_dims[i], out_dims[i],
          platform::errors::InvalidArgument(
              "The dimension of input and output shall be same. Expected %ld "
              "== %ld, but got %ld != %ld. Please check the input value.",
              in_dims[i], out_dims[i], in_dims[i], out_dims[i]));
    }

    auto lod_level = input.lod().size();
    // The offsets are only read; bind by reference instead of copying them.
    const auto& starts = input.lod()[lod_level - 1];
    const T* in_data = input.data<T>();
    T* out_data = output->data<T>();

    int64_t num_seq = out_dims[0];
    // Row width = product of the trailing output dims. Unlike
    // numel() / num_seq this does not divide by zero for an empty batch.
    int64_t dim = 1;
    for (int64_t d = 1; d < out_dims.size(); ++d) dim *= out_dims[d];

    for (int64_t i = 0; i < num_seq; ++i) {
      T* out_row = out_data + i * dim;
      if (starts[i] == starts[i + 1]) {
        std::fill_n(out_row, dim, pad_value);
        continue;
      }
      // Seed with the first row, then fold in the remaining rows.
      std::memcpy(out_row, in_data + starts[i] * dim, dim * sizeof(T));
      for (size_t j = starts[i] + 1; j < starts[i + 1]; ++j) {
        const T* in_row = in_data + j * dim;
        for (int64_t k = 0; k < dim; ++k) {
          if (in_row[k] > out_row[k]) {
            out_row[k] = in_row[k];
          }
        }
      }
    }
  }
};
155
template <typename T>
D
dzhwinter 已提交
156
class MaxSeqPoolGradFunctor {
157
 public:
Q
QI JUN 已提交
158
  void operator()(const platform::CPUDeviceContext& context,
159
                  const framework::LoDTensor& out_grad,
160 161 162 163 164
                  const framework::Tensor& index,
                  framework::LoDTensor* in_grad) {
    auto og_dims = out_grad.dims();
    auto ig_dims = in_grad->dims();
    auto idx_dims = index.dims();
165
    PADDLE_ENFORCE_GT(og_dims.size(), 1,
166 167 168 169
                      platform::errors::InvalidArgument(
                          "The rank of output@Grad shall be greater than 1, "
                          "but got %ld <= 1. Please check the input value.",
                          og_dims.size()));
170
    PADDLE_ENFORCE_GT(ig_dims.size(), 1,
171 172 173 174
                      platform::errors::InvalidArgument(
                          "The rank of input@Grad shall be greater than 1, but "
                          "got %ld <= 1. Please check the input value.",
                          ig_dims.size()));
D
dangqingqing 已提交
175
    for (int64_t i = 1; i < og_dims.size(); ++i) {
176 177 178 179 180 181
      PADDLE_ENFORCE_EQ(og_dims[i], ig_dims[i],
                        platform::errors::InvalidArgument(
                            "The dimension of input@Grad and output@Grad shall "
                            "be same. Expected %ld == %ld, but got %ld != %ld. "
                            "Please check the input value.",
                            og_dims[i], ig_dims[i], og_dims[i], ig_dims[i]));
182
    }
183 184 185 186 187 188
    PADDLE_ENFORCE_EQ(
        idx_dims, og_dims,
        platform::errors::InvalidArgument(
            "The dimension of index and output@Grad shall be same. Expected "
            "%ld == %ld, but got %ld != %ld. Please check the input value.",
            idx_dims, og_dims, idx_dims, og_dims));
189 190 191 192 193

    const T* og_data = out_grad.data<T>();
    const int* max_index = index.data<int>();
    T* ig_data = in_grad->data<T>();

194
    phi::funcs::SetConstant<platform::CPUDeviceContext, T> set_zero;
195 196 197
    set_zero(context, in_grad, static_cast<T>(0.0));
    int64_t num_seq = og_dims[0];
    int64_t dim = out_grad.numel() / num_seq;
D
dangqingqing 已提交
198 199
    for (int64_t i = 0; i < num_seq; ++i) {
      for (int64_t j = 0; j < dim; ++j) {
200
        int step_id = max_index[i * dim + j];
201
        if (step_id == -1) continue;
202 203 204 205 206 207
        ig_data[step_id * dim + j] = og_data[i * dim + j];
      }
    }
  }
};

208
template <typename T>
B
bingyanghuang 已提交
209
class LastSeqPoolFunctor {
210 211
 public:
  void operator()(const platform::CPUDeviceContext& context,
212
                  const framework::LoDTensor& input, T pad_value,
213
                  framework::LoDTensor* output) {
B
bingyanghuang 已提交
214 215 216
    // Create pointers to input and output data
    auto* in_data = input.data<T>();
    auto* out_data = output->data<T>();
B
bingyanghuang 已提交
217

B
bingyanghuang 已提交
218 219
    // Calculate the size of each item in sequence
    int64_t item_size = input.numel() / input.dims()[0];
220 221
    auto lod_level = input.lod().size();
    auto lod = input.lod()[lod_level - 1];
B
bingyanghuang 已提交
222
    int seq_num = static_cast<int>(lod.size()) - 1;
B
bingyanghuang 已提交
223 224 225
    for (int i = 0; i < seq_num; ++i) {
      // Calculate the length of each sequence
      int64_t seq_len = static_cast<int64_t>(lod[i + 1] - lod[i]);
226 227 228 229 230 231 232 233 234 235
      if (seq_len == 0) {
        for (int j = 0; j < item_size; ++j) {
          out_data[j] = pad_value;
        }
      } else {
        // Point to the begin of next sequence
        in_data += seq_len * item_size;
        // Copy the last item of sequence to output
        std::memcpy(out_data, (in_data - item_size), item_size * sizeof(T));
      }
B
bingyanghuang 已提交
236
      out_data += item_size;
B
bingyanghuang 已提交
237
    }
B
bingyanghuang 已提交
238 239 240 241 242 243 244
  }
};

template <typename T>
class FirstSeqPoolFunctor {
 public:
  // Copies the first row of every sequence (delimited by the last LoD level
  // of `input`) into the matching row of `output`. Empty sequences produce a
  // row filled with `pad_value`.
  void operator()(const platform::CPUDeviceContext& context,
                  const framework::LoDTensor& input, T pad_value,
                  framework::LoDTensor* output) {
    // Create pointers to input and output data
    const T* in_data = input.data<T>();
    T* out_data = output->data<T>();

    // Size of each item in a sequence = product of the trailing dims. Unlike
    // numel() / dims()[0] this does not divide by zero when the input has no
    // rows (every sequence in the batch is empty).
    auto in_dims = input.dims();
    int64_t item_size = 1;
    for (int64_t d = 1; d < in_dims.size(); ++d) item_size *= in_dims[d];

    auto lod_level = input.lod().size();
    // The offsets are only read; bind by reference instead of copying them.
    const auto& lod = input.lod()[lod_level - 1];
    int seq_num = static_cast<int>(lod.size()) - 1;
    for (int i = 0; i < seq_num; ++i) {
      int64_t seq_len = static_cast<int64_t>(lod[i + 1] - lod[i]);
      if (seq_len == 0) {
        std::fill_n(out_data, item_size, pad_value);
      } else {
        // Copy the first item of the sequence, then skip to the next one.
        std::memcpy(out_data, in_data, item_size * sizeof(T));
        in_data += seq_len * item_size;
      }
      out_data += item_size;
    }
  }
};

M
minqiyang 已提交
274 275 276 277
template <typename T>
class SumSeqPoolGradFunctor {
 public:
  // Backward pass of sum pooling: every row of sequence i in `in_grad` (rows
  // delimited by the last LoD level of `in_grad`) receives a copy of row i of
  // `out_grad`.
  void operator()(const platform::CPUDeviceContext& context,
                  const framework::LoDTensor& out_grad,
                  framework::LoDTensor* in_grad) {
    auto lod_level = in_grad->lod().size();
    // The offsets are only read; bind by reference instead of copying them.
    const auto& lod = in_grad->lod()[lod_level - 1];

    // Row widths = products of the trailing dims. Unlike numel() / dims()[0]
    // this does not divide by zero when a tensor has no rows (e.g. in_grad
    // when every sequence is empty).
    auto og_dims = out_grad.dims();
    auto ig_dims = in_grad->dims();
    int64_t out_w = 1;
    for (int64_t d = 1; d < og_dims.size(); ++d) out_w *= og_dims[d];
    int64_t in_w = 1;
    for (int64_t d = 1; d < ig_dims.size(); ++d) in_w *= ig_dims[d];
    PADDLE_ENFORCE_EQ(in_w, out_w,
                      platform::errors::InvalidArgument(
                          "The feature size of input@Grad and output@Grad "
                          "shall be same. Expected %ld == %ld, but got %ld != "
                          "%ld. Please check the input value.",
                          in_w, out_w, in_w, out_w));

    const T* out_g_data = out_grad.data<T>();
    T* in_g_data = in_grad->mutable_data<T>(context.GetPlace());
    auto blas = phi::funcs::GetBlas<platform::CPUDeviceContext, T>(context);
    for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
      int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]);
      if (h == 0) continue;
      int64_t in_offset = static_cast<int64_t>(lod[i]) * in_w;
      const T* out_pos = out_g_data + i * out_w;
      T* in_pos = in_g_data + in_offset;
      // Broadcast the pooled gradient row to every row of the sequence.
      for (int64_t r = 0; r < h; ++r) {
        blas.VCOPY(in_w, out_pos, in_pos + r * in_w);
      }
    }
  }
};

D
dzhwinter 已提交
306 307 308 309 310
template <typename T>
class SequencePoolFunctor<platform::CPUDeviceContext, T> {
 public:
  // Pools every sequence of `input` (delimited by its last LoD level) into
  // one row of `output`. `pooltype` is one of "MAX", "LAST", "FIRST", "SUM",
  // "AVERAGE" or "SQRT"; any other value raises InvalidArgument once a
  // non-empty sequence is reached. Empty sequences yield a row filled with
  // `pad_value`. Only MAX pooling with is_test == false uses `index` (the
  // arg-max output).
  void operator()(const platform::CPUDeviceContext& context,
                  const std::string pooltype, T pad_value,
                  const framework::LoDTensor& input,
                  framework::LoDTensor* output, bool is_test,
                  framework::Tensor* index = nullptr) {
    if (pooltype == "MAX") {
      if (is_test) {
        math::MaxSeqPoolFunctor<T, true> max_pool;
        max_pool(context, input, pad_value, output, index);
      } else {
        math::MaxSeqPoolFunctor<T, false> max_pool;
        max_pool(context, input, pad_value, output, index);
      }
      return;
    }
    if (pooltype == "LAST") {
      math::LastSeqPoolFunctor<T> last_pool;
      last_pool(context, input, pad_value, output);
      return;
    }
    if (pooltype == "FIRST") {
      math::FirstSeqPoolFunctor<T> first_pool;
      first_pool(context, input, pad_value, output);
      return;
    }

    auto lod_level = input.lod().size();
    // The offsets are only read; bind by reference instead of copying them.
    const auto& lod = input.lod()[lod_level - 1];
    // Width of one input row = product of the trailing dims. Unlike
    // numel() / dims()[0] this does not divide by zero when the input has no
    // rows (every sequence empty). Hoisted: it is loop-invariant.
    auto in_dims = input.dims();
    int64_t w = 1;
    for (int64_t d = 1; d < in_dims.size(); ++d) w *= in_dims[d];

    if (pooltype == "SUM") {
      auto place = context.GetPlace();
      PADDLE_ENFORCE_EQ(
          platform::is_cpu_place(place), true,
          platform::errors::InvalidArgument(
              "Sequence_pool should run on CPU Device when pooltype is SUM"));
      const T* src = input.data<T>();
      T* dst = output->mutable_data<T>(place);
      jit::seq_pool_attr_t attr(static_cast<int>(w), jit::SeqPoolType::kSum);
      auto seqpool =
          jit::KernelFuncs<jit::SeqPoolTuple<T>, platform::CPUPlace>::Cache()
              .At(attr);
      for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
        attr.h = static_cast<int>(lod[i + 1] - lod[i]);
        if (attr.h == 0) {
          std::fill_n(dst, attr.w, pad_value);
        } else {
          seqpool(src, dst, &attr);
        }
        dst += attr.w;
        // Widen before multiplying: h * w may not fit in an int.
        src += static_cast<int64_t>(attr.h) * attr.w;
      }
      return;
    }

    auto& place = *context.eigen_device();
    for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
      Tensor out_t = output->Slice(i, i + 1);
      if (lod[i] == lod[i + 1]) {
        std::fill_n(out_t.data<T>(), w, pad_value);
        continue;
      }
      Tensor in_t =
          input.Slice(static_cast<int>(lod[i]), static_cast<int>(lod[i + 1]));
      int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]);
      auto in_e = EigenMatrix<T>::From(in_t, phi::make_ddim({h, w}));
      auto out_e = EigenVector<T>::Flatten(out_t);
      if (pooltype == "AVERAGE") {
        out_e.device(place) = in_e.mean(Eigen::array<int, 1>({{0}}));
      } else if (pooltype == "SQRT") {
        out_e.device(place) = in_e.sum(Eigen::array<int, 1>({{0}})) /
                              std::sqrt(static_cast<T>(h));
      } else {
        PADDLE_THROW(platform::errors::InvalidArgument(
            "unsupported pooling pooltype: %s. Only support \"MAX\", "
            "\"LAST\", \"FIRST\", \"SUM\", \"AVERAGE\" and \"SQRT\"",
            pooltype));
      }
    }
  }
};

template <typename T>
class SequencePoolGradFunctor<platform::CPUDeviceContext, T> {
 public:
  // Backward pass matching SequencePoolFunctor. Routes `out_grad` (one row per
  // sequence) back into `in_grad` (rows delimited by its last LoD level)
  // according to `pooltype`:
  //   MAX     - via the arg-max `index`, which is required for this type;
  //   SUM     - every row receives a copy of the pooled gradient;
  //   AVERAGE - every row receives gradient / h;
  //   SQRT    - every row receives gradient / sqrt(h);
  //   LAST / FIRST - only the last / first row; the rest stay zero.
  // Any other pooltype raises InvalidArgument once a non-empty sequence is
  // reached.
  void operator()(const platform::CPUDeviceContext& context,
                  const std::string pooltype,
                  const framework::LoDTensor& out_grad,
                  framework::LoDTensor* in_grad,
                  /* max pool has index */
                  const framework::Tensor* index = nullptr) {
    if (pooltype == "MAX") {
      // `index` defaults to nullptr, so guard before dereferencing it.
      PADDLE_ENFORCE_NOT_NULL(
          index, platform::errors::InvalidArgument(
                     "The index tensor is required to compute the gradient "
                     "of MAX sequence pooling, but got null."));
      math::MaxSeqPoolGradFunctor<T> max_pool_grad;
      max_pool_grad(context, out_grad, *index, in_grad);
      return;
    }

    if (pooltype == "LAST" || pooltype == "FIRST") {
      // Only one row per sequence receives a gradient, so zero X@Grad first.
      phi::funcs::SetConstant<platform::CPUDeviceContext, T> functor;
      functor(context, in_grad, 0);
    }

    if (pooltype == "SUM") {
      math::SumSeqPoolGradFunctor<T> sum_pool_grad;
      sum_pool_grad(context, out_grad, in_grad);
      return;
    }

    auto lod_level = in_grad->lod().size();
    // The offsets are only read; bind by reference instead of copying them.
    const auto& lod = in_grad->lod()[lod_level - 1];
    // Row width = product of the trailing dims; hoisted since it is
    // loop-invariant, and it never divides by a zero leading dim.
    auto ig_dims = in_grad->dims();
    int64_t w = 1;
    for (int64_t d = 1; d < ig_dims.size(); ++d) w *= ig_dims[d];

    auto& place = *context.eigen_device();
    for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
      if (lod[i] == lod[i + 1]) continue;  // empty sequence owns no rows
      auto in_g_t = in_grad->Slice(static_cast<int>(lod[i]),
                                   static_cast<int>(lod[i + 1]));
      auto out_g_t = out_grad.Slice(i, i + 1);
      int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]);
      auto in_g_e = EigenMatrix<T>::From(in_g_t, {h, w});
      auto out_g_e = EigenMatrix<T>::From(out_g_t, {1, w});
      auto out_g_e_v = EigenVector<T>::Flatten(out_g_t);
      Eigen::DSizes<int, 2> bcast(h, 1);

      if (pooltype == "AVERAGE") {
        in_g_e.device(place) = (out_g_e / static_cast<T>(h)).broadcast(bcast);
      } else if (pooltype == "SQRT") {
        in_g_e.device(place) =
            (out_g_e / std::sqrt(static_cast<T>(h))).broadcast(bcast);
      } else if (pooltype == "LAST") {
        in_g_e.chip(h - 1, 0).device(place) = out_g_e_v;
      } else if (pooltype == "FIRST") {
        in_g_e.chip(0, 0).device(place) = out_g_e_v;
      } else {
        PADDLE_THROW(platform::errors::InvalidArgument(
            "unsupported pooling pooltype: %s. Only support \"MAX\", "
            "\"SUM\", \"AVERAGE\", \"SQRT\", \"LAST\" and \"FIRST\"",
            pooltype));
      }
    }
  }
};

// Explicit CPU instantiations of the forward and backward pooling functors,
// grouped by element type (float, then double).
template class SequencePoolFunctor<platform::CPUDeviceContext, float>;
template class SequencePoolGradFunctor<platform::CPUDeviceContext, float>;
template class SequencePoolFunctor<platform::CPUDeviceContext, double>;
template class SequencePoolGradFunctor<platform::CPUDeviceContext, double>;
460 461 462 463

}  // namespace math
}  // namespace operators
}  // namespace paddle