/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/unique_op.h"

#include <memory>

#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"

namespace paddle {
namespace operators {

class UniqueOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "unique");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "unique");

    bool return_index = ctx->Attrs().Get<bool>("return_index");
    bool return_inverse = ctx->Attrs().Get<bool>("return_inverse");
    bool return_counts = ctx->Attrs().Get<bool>("return_counts");
    auto axis_vec = ctx->Attrs().Get<std::vector<int>>("axis");
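    // "dtype" arrives as a framework::proto::VarType::Type enum value; convert
    // it to the phi::DataType expected by the InferMeta functions below.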
    auto data_type =
        static_cast<phi::DataType>(static_cast<framework::proto::VarType::Type>(
            ctx->Attrs().Get<int>("dtype")));

    // Construct MetaTensor for InferMeta Func
    using CompatMetaTensor = framework::CompatMetaTensor;
    CompatMetaTensor x(ctx->GetInputVarPtrs("X")[0], ctx->IsRuntime());
    CompatMetaTensor out(ctx->GetOutputVarPtrs("Out")[0], ctx->IsRuntime());
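    // The remaining outputs are only needed for some attribute combinations,
    // so their MetaTensors are created lazily below.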
    std::unique_ptr<CompatMetaTensor> indices(nullptr);
    std::unique_ptr<CompatMetaTensor> index(nullptr);
    std::unique_ptr<CompatMetaTensor> counts(nullptr);

    if (return_index) {
      OP_INOUT_CHECK(ctx->HasOutput("Indices"), "Output", "Indices", "unique");
      indices = std::make_unique<CompatMetaTensor>(
          ctx->GetOutputVarPtrs("Indices")[0], ctx->IsRuntime());
    }
    if (return_inverse) {
      OP_INOUT_CHECK(ctx->HasOutput("Index"), "Output", "Index", "unique");
      index = std::make_unique<CompatMetaTensor>(
          ctx->GetOutputVarPtrs("Index")[0], ctx->IsRuntime());
    }
    if (return_counts) {
      OP_INOUT_CHECK(ctx->HasOutput("Counts"), "Output", "Counts", "unique");
      counts = std::make_unique<CompatMetaTensor>(
          ctx->GetOutputVarPtrs("Counts")[0], ctx->IsRuntime());
    }
    bool is_sorted = ctx->Attrs().Get<bool>("is_sorted");
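    // is_sorted == true corresponds to the new paddle.unique API; the legacy
    // fluid.layers.unique path (is_sorted == false) goes through the raw
    // variant of the InferMeta function.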
    if (is_sorted) {
      phi::UniqueInferMeta(x,
                           return_index,
                           return_inverse,
                           return_counts,
                           axis_vec,
                           data_type,
                           &out,
                           indices.get(),
                           index.get(),
                           counts.get());
    } else {
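      // The legacy path always writes the Index output, so make sure its
      // MetaTensor exists even when return_inverse is false.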
      OP_INOUT_CHECK(ctx->HasOutput("Index"), "Output", "Index", "unique");
      if (index == nullptr) {
        index = std::make_unique<CompatMetaTensor>(
            ctx->GetOutputVarPtrs("Index")[0], ctx->IsRuntime());
      }
      phi::UniqueRawInferMeta(x,
                              return_index,
                              return_inverse,
                              return_counts,
                              axis_vec,
                              data_type,
                              is_sorted,
                              &out,
                              indices.get(),
                              index.get(),
                              counts.get());
    }
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    // Return CPUPlace when Attr("is_sorted") is false: it means the legacy
    // fluid.layers.unique API was called, which has no CUDA kernel.
    if (!ctx.Attr<bool>("is_sorted")) {
      return framework::OpKernelType(
          OperatorWithKernel::IndicateVarDataType(ctx, "X"),
          platform::CPUPlace());
    } else {
      // Otherwise the new paddle.unique API was called; run on the place of
      // the execution context.
      return framework::OpKernelType(
          OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
    }
  }
};

class UniqueOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "Input tensor. It should be a 1-D tensor when Attr(is_sorted)"
122
             " is false or a N-D tensor when Attr(is_sorted) is true.");
    AddAttr<int>("dtype", "data type for output index");
    AddOutput("Out", "A unique subsequence for input tensor.");
    AddOutput("Index",
              "Equivalent to inverse in numpy.unique, "
              "the indices for where elements in the original input ended up "
              "in the returned unique tensor.");
    AddOutput(
        "Indices",
        "The indices of the input tensor that result in the unique tensor.")
        .AsDispensable();
    AddOutput("Counts", "The counts for each unique element.").AsDispensable();
    AddAttr<bool>("return_index",
                  "If True, also return the indices of the input"
                  " tensor that result in the unique Tensor.")
        .SetDefault(false);
    AddAttr<bool>(
        "return_inverse",
        "If True, also return the indices for where elements"
        " in the original input ended up in the returned unique tensor.")
        .SetDefault(false);
    AddAttr<bool>("return_counts",
                  "If True, also return the counts for each unique element.")
        .SetDefault(false);
    AddAttr<std::vector<int>>(
        "axis",
        "The axis to apply unique. If None, the input will be flattened.")
        .SetDefault({});
    AddAttr<bool>("is_sorted",
                  "If True, the unique elements of X are in ascending order."
                  "Otherwise, the unique elements are not sorted.")
        .SetDefault(false);
    AddComment(R"DOC(
Z
Zhang Ting 已提交
155
    1. Return a unique subsequence for 1-D input tensor, and an index tensor
156
    pointing to this unique subsequence when Attr(is_sorted) is false. This
Z
Zhang Ting 已提交
157
    means paddle.unique is called.
158

Z
Zhang Ting 已提交
159 160
    2. Returns the unique elements of X in ascending order when Attr(is_sorted)
    is true. This means fluid.layers.unique is called.
Z
zhoukunsheng 已提交
161 162 163 164 165 166 167
)DOC");
  }
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

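// No backward kernel is implemented for unique, so the op is registered
// without a gradient.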
REGISTER_OP_WITHOUT_GRADIENT(unique, ops::UniqueOp, ops::UniqueOpMaker);