// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/reduce_ops/logsumexp_op.h"
#include <algorithm>
#include <string>
#include <vector>

namespace paddle {
namespace operators {

// Forward operator: computes logsumexp of X over the dimensions listed in the
// `axis` attribute, or over every dimension when `reduce_all` is true.
class LogsumexpOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  // Validates X (rank <= 4) and `axis` (non-empty, every entry in
  // [-rank, rank)), then sets the shape of Out:
  //   reduce_all &&  keepdim -> rank(X) ones, e.g. [1, 1, 1]
  //   reduce_all && !keepdim -> [1]
  //   otherwise              -> X's shape with every reduced axis set to 1
  //                             (keepdim) or removed (!keepdim); a result
  //                             with no dims left becomes [1].
  // X's LoD is shared with Out only when the first dim is not reduced.
  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "logsumexp");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "logsumexp");
    auto x_dims = ctx->GetInputDim("X");
    auto x_rank = x_dims.size();
    PADDLE_ENFORCE_LE(x_rank, 4,
                      platform::errors::InvalidArgument(
                          "The input tensor X's dimensions of logsumexp "
                          "should be less or equal than 4. But received X's "
                          "dimensions = %d, X's shape = [%s].",
                          x_rank, x_dims));
    auto axis = ctx->Attrs().Get<std::vector<int>>("axis");
    PADDLE_ENFORCE_GT(
        axis.size(), 0,
        platform::errors::InvalidArgument(
            "The size of axis of logsumexp "
            "should be greater than 0. But received the size of axis "
            "of logsumexp is %d.",
            axis.size()));

    for (size_t i = 0; i < axis.size(); i++) {
      PADDLE_ENFORCE_LT(axis[i], x_rank,
                        platform::errors::InvalidArgument(
                            "axis[%d] should be in the "
                            "range [-D, D), where D is the dimensions of X and "
                            "D is %d. But received axis[%d] = %d.",
                            i, x_rank, i, axis[i]));
      PADDLE_ENFORCE_GE(axis[i], -x_rank,
                        platform::errors::InvalidArgument(
                            "axis[%d] should be in the "
                            "range [-D, D), where D is the dimensions of X and "
                            "D is %d. But received axis[%d] = %d.",
                            i, x_rank, i, axis[i]));
      // Normalize negative axes so every entry lies in [0, x_rank).
      if (axis[i] < 0) {
        axis[i] += x_rank;
      }
    }

    bool keepdim = ctx->Attrs().Get<bool>("keepdim");
    bool reduce_all = ctx->Attrs().Get<bool>("reduce_all");
    if (reduce_all) {
      if (keepdim) {
        ctx->SetOutputDim(
            "Out", framework::make_ddim(std::vector<int64_t>(x_rank, 1)));
      } else {
        ctx->SetOutputDim("Out", {1});
      }
      return;
    }

    // True when dimension `dim` of X is listed in the (normalized) axis.
    auto is_reduced = [&axis](int dim) {
      return std::find(axis.begin(), axis.end(), dim) != axis.end();
    };

    // Build Out's shape by walking X's dims and consulting `axis` directly.
    // Writing a sentinel such as -1 into the dims and erasing it afterwards
    // would also erase dims of X that are themselves -1 (unknown size) but
    // are not reduced.
    auto x_dims_vector = vectorize(x_dims);
    std::vector<int64_t> out_dims_vector;
    out_dims_vector.reserve(x_dims_vector.size());
    for (int i = 0; i < x_rank; ++i) {
      if (!is_reduced(i)) {
        out_dims_vector.push_back(x_dims_vector[i]);
      } else if (keepdim) {
        out_dims_vector.push_back(1);
      }
    }
    // Reducing every dim without keepdim still yields a shape of [1].
    if (out_dims_vector.empty()) {
      out_dims_vector.push_back(1);
    }
    ctx->SetOutputDim("Out", framework::make_ddim(out_dims_vector));

    // Only pass LoD when the first dim is not reduced by ANY entry of axis
    // (checking only axis[0] would miss e.g. axis = {1, 0}).
    if (!is_reduced(0)) {
      ctx->ShareLoD("X", /*->*/ "Out");
    }
  }
};

// Declares the input, output, attributes and user-facing documentation of the
// logsumexp operator.
class LogsumexpOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "(Tensor) The input tensor. Tensors with rank at most 4 are "
             "supported.");
    AddOutput("Out", "(Tensor) The result tensor.");
    AddAttr<std::vector<int>>(
        "axis",
        "(list<int>, default {0}) The dimensions to reduce. "
        "Must be in the range [-rank(input), rank(input)). "
        "If `axis[i] < 0`, the axis[i] to reduce is `rank + axis[i]`. "
        "Note that reducing on the first dim will make the LoD info lost.")
        .SetDefault({0});
    AddAttr<bool>("keepdim",
                  "(bool, default false) "
                  "If true, retain the reduced dimension with length 1.")
        .SetDefault(false);
    AddAttr<bool>("reduce_all",
                  "(bool, default false) "
                  "If true, output a scalar reduced along all dimensions.")
        .SetDefault(false);
    // The doc text previously referred to a `keep_dim` attribute, which does
    // not exist (the attribute is `keepdim`), and claimed exactly one fewer
    // dimension even when several axes are reduced.
    AddComment(string::Sprintf(R"DOC(
logsumexp Operator.

This operator computes the logsumexp of input tensor along the given axis.
The result tensor has one fewer dimension than the input for every reduced
axis, unless keepdim is true, in which case each reduced dimension is kept
with length 1.
If reduce_all is true, just reduce along all dimensions and output a scalar.

)DOC"));
  }
};

// Gradient operator of logsumexp. It requires X, Out and Out@GRAD as inputs
// and gives X@GRAD the same shape as X.
// NOTE(review): the "Grap" spelling is kept because
// REGISTER_OPERATOR(logsumexp_grad, ...) refers to the class by this name.
class LogsumexpGrapOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  // Checks that all three inputs are present, then copies X's dims to
  // X@GRAD.
  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "logsumexp");
    OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "logsumexp");
    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
                   "Out@GRAD", "logsumexp");
    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
  }
};

// Describes the logsumexp_grad op that backs a logsumexp forward op.
// The grad op receives:
//   - the forward input X;
//   - the forward output Out;
//   - the gradient of Out.
// It inherits every forward attribute and produces the gradient of X.
template <typename T>
class LogsumexpGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

  void Apply(GradOpPtr<T> op) const override {
    const auto out_grad_name = framework::GradVarName("Out");
    const auto x_grad_name = framework::GradVarName("X");

    op->SetType("logsumexp_grad");

    // Forward tensors plus the incoming gradient.
    op->SetInput("X", this->Input("X"));
    op->SetInput("Out", this->Output("Out"));
    op->SetInput(out_grad_name, this->OutputGrad("Out"));

    // Same axis / keepdim / reduce_all settings as the forward op.
    op->SetAttrMap(this->Attrs());

    op->SetOutput(x_grad_name, this->InputGrad("X"));
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

// Forward op, registered with grad makers for both the OpDesc-based and the
// imperative (OpBase) paths.
REGISTER_OPERATOR(logsumexp, ops::LogsumexpOp, ops::LogsumexpOpMaker,
                  ops::LogsumexpGradOpMaker<paddle::framework::OpDesc>,
                  ops::LogsumexpGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(logsumexp_grad, ops::LogsumexpGrapOp);

// CPU kernels for float and double.
REGISTER_OP_CPU_KERNEL(
    logsumexp, ops::LogsumexpKernel<paddle::platform::CPUDeviceContext, float>,
    ops::LogsumexpKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
    logsumexp_grad,
    ops::LogsumexpGradKernel<paddle::platform::CPUDeviceContext, float>,
    ops::LogsumexpGradKernel<paddle::platform::CPUDeviceContext, double>);