// segment_pool_op.h
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <cfloat>
#include <cstdint>
#include <string>

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/segment_pooling.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/kernels/funcs/math_function.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

// Validates the inputs of a segment_pool op, sizes and initialises its
// output(s), and then runs SegmentPoolFunctor<DeviceContext, T, IndexT>.
//
// Reads from `context`:
//   - "X": values to pool;
//   - "SegmentIds": IndexT ids, one per row of X;
//   - the "pooltype" attribute.
// Writes "Out", plus "SummedIds" for MEAN pooling on the GPU path.
//
// The number of output rows is derived from the LAST segment id plus one.
// NOTE(review): that presumes SegmentIds is sorted in non-decreasing order;
// the ordering is not checked here -- confirm SegmentPoolFunctor validates it.
//
// Raises InvalidArgument (via PADDLE_ENFORCE_*) when:
//   - SegmentIds' length differs from dim 0 of X;
//   - SegmentIds is not effectively 1-D;
//   - the last segment id is negative.
// If X or SegmentIds is empty, the function returns without touching "Out".
template <typename DeviceContext, typename T, typename IndexT>
void SegmentKernelLaunchHelper(const framework::ExecutionContext& context) {
  auto* input = context.Input<Tensor>("X");
  auto* segment = context.Input<Tensor>("SegmentIds");
  auto* output = context.Output<Tensor>("Out");
  std::string pooltype = context.Attr<std::string>("pooltype");
  // Only allocated for MEAN pooling on the GPU path; stays null otherwise.
  Tensor* summed_ids = nullptr;

  int64_t num_indices = segment->numel();
  PADDLE_ENFORCE_EQ(
      num_indices, input->dims()[0],
      platform::errors::InvalidArgument(
          "Segment_ids should be the same size as dimension 0 of input X."));
  // numel == dims[0] holds only when every other dimension has size 1.
  PADDLE_ENFORCE_EQ(num_indices, segment->dims()[0],
                    platform::errors::InvalidArgument(
                        "Segment_ids should be 1-D tensor, or it's other "
                        "dimension size is 1. Segment_ids's shape is: [%s].",
                        segment->dims()));

  // Nothing to pool.
  if (input->numel() == 0 || segment->numel() == 0) {
    return;
  }

  const bool cpu_place =
      context.GetPlace().GetType() == phi::AllocationType::CPU;

  if (cpu_place) {
    // Segment ids live in host memory, so the last id can be read directly.
    const IndexT* segment_ids = segment->data<IndexT>();
    const IndexT last_id = segment_ids[num_indices - 1];
    auto dims = input->dims();
    // Widen before adding one so a last id equal to the largest IndexT value
    // cannot overflow.
    dims[0] = static_cast<int64_t>(last_id) + 1;
    PADDLE_ENFORCE_GT(
        dims[0], 0,
        platform::errors::InvalidArgument(
            "Segment ids must be >= 0, but got last id %d", last_id));
    output->Resize({dims});
    output->mutable_data<T>(context.GetPlace());

    phi::funcs::SetConstant<DeviceContext, T> set_zero;
    auto& dev_ctx = context.template device_context<DeviceContext>();
    set_zero(dev_ctx, output, static_cast<T>(0));
  }

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  if (!cpu_place) {
    // Segment ids live in device memory: copy the last one into a
    // one-element host tensor to learn the number of output rows.
    Tensor length;
    length.mutable_data<IndexT>(phi::make_ddim({1}), platform::CPUPlace());
    IndexT* length_data = length.data<IndexT>();
    const IndexT* segment_ids = segment->data<IndexT>();

#ifdef PADDLE_WITH_HIP
    PADDLE_ENFORCE_GPU_SUCCESS(
        hipMemcpy(length_data, segment_ids + num_indices - 1, sizeof(IndexT),
                  hipMemcpyDeviceToHost));
#else
    PADDLE_ENFORCE_GPU_SUCCESS(
        cudaMemcpy(length_data, segment_ids + num_indices - 1, sizeof(IndexT),
                   cudaMemcpyDeviceToHost));
#endif

    const IndexT last_id = length_data[0];
    // Widen before adding one so a last id equal to the largest IndexT value
    // cannot overflow.
    const int64_t length_host = static_cast<int64_t>(last_id) + 1;
    PADDLE_ENFORCE_GT(
        length_host, 0,
        platform::errors::InvalidArgument(
            "Segment ids must be >= 0, but got last id %d", last_id));
    auto dims = input->dims();
    dims[0] = length_host;
    output->Resize({dims});
    output->mutable_data<T>(context.GetPlace());

    // MAX/MIN outputs are seeded with -FLT_MAX / FLT_MAX; all other pool
    // types start at 0. NOTE(review): presumably the GPU reduction in
    // SegmentPoolFunctor relies on these seeds -- confirm.
    T init_value = 0;
    if (pooltype == "MAX") {
      init_value = static_cast<T>(-FLT_MAX);
    } else if (pooltype == "MIN") {
      init_value = static_cast<T>(FLT_MAX);
    }
    phi::funcs::SetConstant<DeviceContext, T> setconst;
    auto& dev_ctx = context.template device_context<DeviceContext>();
    setconst(dev_ctx, output, init_value);

    // The GPU MEAN pooling kernel records per-segment id counts in
    // SummedIds. The tiny non-zero seed presumably guards a later division
    // for empty segments -- confirm against SegmentPoolFunctor.
    if (pooltype == "MEAN") {
      summed_ids = context.Output<Tensor>("SummedIds");
      summed_ids->Resize({dims[0], 1});
      summed_ids->mutable_data<T>(context.GetPlace());
      setconst(dev_ctx, summed_ids, static_cast<T>(1e-12));
    }
  }
#endif

  SegmentPoolFunctor<DeviceContext, T, IndexT> pool;
  pool(context.template device_context<DeviceContext>(), *input, *segment,
       output, summed_ids, pooltype);
}

// Forward kernel of the segment_pool op.
//
// Picks the index type of "SegmentIds" (int32 or int64) at run time and
// forwards to SegmentKernelLaunchHelper with the matching IndexT. Any other
// index dtype raises InvalidArgument.
template <typename DeviceContext, typename T>
class SegmentPoolKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* segment = context.Input<Tensor>("SegmentIds");
    auto index_type = framework::TransToProtoVarType(segment->dtype());
    if (index_type == framework::proto::VarType::INT32) {
      SegmentKernelLaunchHelper<DeviceContext, T, int>(context);
    } else if (index_type == framework::proto::VarType::INT64) {
      SegmentKernelLaunchHelper<DeviceContext, T, int64_t>(context);
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Unsupported index type, Expected int, int64, but got %s.",
          index_type));
    }
  }
};

// Backward kernel of the segment_pool op.
//
// Allocates X@GRAD on the kernel's place and zero-fills it. It then runs
// SegmentPoolGradFunctor with the index type of "SegmentIds" (int32 or
// int64). "SummedIds" is read only for MEAN pooling; otherwise a null
// pointer is passed to the functor. Any other index dtype raises
// InvalidArgument.
template <typename DeviceContext, typename T>
class SegmentPoolGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* input = context.Input<Tensor>("X");
    auto* output = context.Input<Tensor>("Out");
    auto* segment = context.Input<Tensor>("SegmentIds");
    auto* out_g = context.Input<Tensor>(framework::GradVarName("Out"));
    auto* in_g = context.Output<Tensor>(framework::GradVarName("X"));
    std::string pooltype = context.Attr<std::string>("pooltype");

    const Tensor* summed_ids = nullptr;
    if (pooltype == "MEAN") {
      summed_ids = context.Input<Tensor>("SummedIds");
    }

    in_g->mutable_data<T>(context.GetPlace());
    phi::funcs::SetConstant<DeviceContext, T> set_zero;
    auto& dev_ctx = context.template device_context<DeviceContext>();
    set_zero(dev_ctx, in_g, static_cast<T>(0));

    auto index_type = framework::TransToProtoVarType(segment->dtype());
    if (index_type == framework::proto::VarType::INT32) {
      RunGradFunctor<int>(dev_ctx, *input, *output, *out_g, *segment, in_g,
                          summed_ids, pooltype);
    } else if (index_type == framework::proto::VarType::INT64) {
      RunGradFunctor<int64_t>(dev_ctx, *input, *output, *out_g, *segment,
                              in_g, summed_ids, pooltype);
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Unsupported index type, Expected int, int64, but got %s.",
          index_type));
    }
  }

 private:
  // Invokes SegmentPoolGradFunctor for one concrete segment-id type IndexT.
  template <typename IndexT>
  static void RunGradFunctor(const DeviceContext& dev_ctx, const Tensor& input,
                             const Tensor& output, const Tensor& out_g,
                             const Tensor& segment, Tensor* in_g,
                             const Tensor* summed_ids,
                             const std::string& pooltype) {
    SegmentPoolGradFunctor<DeviceContext, T, IndexT> pool;
    pool(dev_ctx, input, output, out_g, segment, in_g, summed_ids, pooltype);
  }
};

}  // namespace operators
}  // namespace paddle