/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/kernels/funcs/segment_pooling.h"

#include <string>

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"

namespace phi {
namespace funcs {

using Tensor = DenseTensor;
26 27

template <typename T, typename IndexT>
28
class SegmentPoolFunctor<phi::CPUContext, T, IndexT> {
29
 public:
30 31 32 33
  void operator()(const phi::CPUContext& dev_ctx,
                  const DenseTensor& input,
                  const DenseTensor& segments,
                  DenseTensor* output,
34
                  DenseTensor* index UNUSED,
35 36 37 38 39
                  const std::string pooltype = "SUM") {
    const IndexT* segment_ids = segments.data<IndexT>();
    auto curent_id = segment_ids[0];
    int64_t last_idx = 0;
    int64_t w = input.numel() / input.dims()[0];
40
    auto& place = *dev_ctx.eigen_device();
41 42 43
    for (int64_t idx = 1; idx <= segments.numel(); ++idx) {
      if (idx < segments.numel()) {
        if (segment_ids[idx] == curent_id) continue;
44 45 46
        PADDLE_ENFORCE_GE(segment_ids[idx],
                          curent_id,
                          phi::errors::InvalidArgument(
47 48
                              "The segment ids should be sorted, but got "
                              "segment_ids[%d]:%d > segment_ids[%d]:%d.",
49 50 51 52
                              idx - 1,
                              curent_id,
                              idx,
                              segment_ids[idx]));
53 54 55 56 57 58
      }

      Tensor out_t = output->Slice(curent_id, curent_id + 1);
      Tensor in_t = input.Slice(last_idx, idx);

      int64_t h = idx - last_idx;
59 60
      auto in_e = EigenMatrix<T>::From(in_t, phi::make_ddim({h, w}));
      auto out_e = EigenVector<T>::Flatten(out_t);
61 62 63 64 65 66 67 68 69 70 71

      auto reduce_dim = Eigen::array<int, 1>({{0}});
      if (pooltype == "MEAN") {
        out_e.device(place) = in_e.mean(reduce_dim);
      } else if (pooltype == "SUM") {
        out_e.device(place) = in_e.sum(reduce_dim);
      } else if (pooltype == "MAX") {
        out_e.device(place) = in_e.maximum(reduce_dim);
      } else if (pooltype == "MIN") {
        out_e.device(place) = in_e.minimum(reduce_dim);
      } else {
72
        PADDLE_THROW(phi::errors::InvalidArgument(
73 74 75 76 77 78 79 80 81 82 83 84
            "Unsupported segment pooling type, only MEAN, SUM, MAX, MIN "
            "available, but got %s.",
            pooltype));
      }

      last_idx = idx;
      if (idx < segments.numel()) curent_id = segment_ids[idx];
    }
  }
};

template <typename T, typename IndexT>
85
class SegmentPoolGradFunctor<phi::CPUContext, T, IndexT> {
86
 public:
87 88 89 90 91 92
  void operator()(const phi::CPUContext& dev_ctx,
                  const DenseTensor& input,
                  const DenseTensor& output,
                  const DenseTensor& out_grad,
                  const DenseTensor& segments,
                  DenseTensor* in_grad,
93
                  const paddle::optional<DenseTensor>& index UNUSED,
94 95
                  const std::string pooltype = "SUM") {
    const IndexT* segment_ids = segments.data<IndexT>();
96
    auto& place = *dev_ctx.eigen_device();
97 98 99 100 101 102
    auto curent_id = segment_ids[0];
    int64_t last_idx = 0;
    int64_t w = in_grad->numel() / in_grad->dims()[0];
    for (int64_t idx = 1; idx <= segments.numel(); ++idx) {
      if (idx < segments.numel()) {
        if (segment_ids[idx] == curent_id) continue;
103 104 105
        PADDLE_ENFORCE_GE(segment_ids[idx],
                          curent_id,
                          phi::errors::InvalidArgument(
106 107
                              "The segment ids should be sorted, but got "
                              "segment_ids[%d]:%d > segment_ids[%d]:%d.",
108 109 110 111
                              idx - 1,
                              curent_id,
                              idx,
                              segment_ids[idx]));
112 113 114 115 116 117
      }

      Tensor out_g_t = out_grad.Slice(curent_id, curent_id + 1);
      Tensor in_g_t = in_grad->Slice(last_idx, idx);

      int64_t h = idx - last_idx;
118 119
      auto in_g_e = EigenMatrix<T>::From(in_g_t, {h, w});
      auto out_g_e = EigenMatrix<T>::From(out_g_t, {1, w});
120 121 122 123 124 125 126 127 128
      Eigen::DSizes<int, 2> bcast(h, 1);

      if (pooltype == "MEAN") {
        in_g_e.device(place) = (out_g_e / static_cast<T>(h)).broadcast(bcast);
      } else if (pooltype == "SUM") {
        in_g_e.device(place) = out_g_e.broadcast(bcast);
      } else if (pooltype == "MAX" || pooltype == "MIN") {
        Tensor out_t = output.Slice(curent_id, curent_id + 1);
        Tensor in_t = input.Slice(last_idx, idx);
129 130
        auto in_e = EigenMatrix<T>::From(in_t, {h, w});
        auto out_e = EigenMatrix<T>::From(out_t, {1, w});
131 132 133 134
        in_g_e.device(place) =
            (in_e == out_e.broadcast(bcast)).template cast<T>() *
            out_g_e.broadcast(bcast);
      } else {
135
        PADDLE_THROW(phi::errors::InvalidArgument(
136 137 138 139 140 141 142 143 144 145 146
            "Unsupported segment pooling type, only MEAN, SUM, MAX, MIN "
            "available, but got %s.",
            pooltype));
      }

      last_idx = idx;
      if (idx < segments.numel()) curent_id = segment_ids[idx];
    }
  }
};

147
using CPU = phi::CPUContext;
148
using float16 = phi::dtype::float16;
149 150 151 152
template class SegmentPoolFunctor<CPU, float, int>;
template class SegmentPoolFunctor<CPU, float, int64_t>;
template class SegmentPoolFunctor<CPU, double, int>;
template class SegmentPoolFunctor<CPU, double, int64_t>;
153 154 155 156
template class SegmentPoolFunctor<CPU, int, int>;
template class SegmentPoolFunctor<CPU, int, int64_t>;
template class SegmentPoolFunctor<CPU, int64_t, int>;
template class SegmentPoolFunctor<CPU, int64_t, int64_t>;
157 158
template class SegmentPoolFunctor<CPU, float16, int>;
template class SegmentPoolFunctor<CPU, float16, int64_t>;
159

160 161 162 163
template class SegmentPoolGradFunctor<CPU, float, int>;
template class SegmentPoolGradFunctor<CPU, float, int64_t>;
template class SegmentPoolGradFunctor<CPU, double, int>;
template class SegmentPoolGradFunctor<CPU, double, int64_t>;
164 165 166 167
template class SegmentPoolGradFunctor<CPU, int, int>;
template class SegmentPoolGradFunctor<CPU, int, int64_t>;
template class SegmentPoolGradFunctor<CPU, int64_t, int>;
template class SegmentPoolGradFunctor<CPU, int64_t, int64_t>;
168 169
template class SegmentPoolGradFunctor<CPU, float16, int>;
template class SegmentPoolGradFunctor<CPU, float16, int64_t>;
170

171 172
}  // namespace funcs
}  // namespace phi