topk_compute_test.cc 3.8 KB
Newer Older
Y
Yan Chunwei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <gtest/gtest.h>
#include <algorithm>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/core/arena/framework.h"
#include "lite/tests/utils/fill_data.h"

namespace paddle {
namespace lite {
// Strict-weak-ordering predicate for (value, index) pairs: `a` goes before
// `b` when its value is strictly larger, giving a descending order by value.
// The index component is ignored, so pairs with equal values are unordered.
template <typename T1, typename T2>
bool comp_func(std::pair<T1, T2> a, std::pair<T1, T2> b) {
  return b.first < a.first;
}

29
template <typename T1, typename T2>
Y
Yan Chunwei 已提交
30 31 32
class TopkComputeTester : public arena::TestCase {
 protected:
  // common attributes for this op.
33 34 35 36 37
  std::string x_ = "x";
  std::string out_ = "out";
  std::string indices_ = "indices";
  DDim x_dims_{{3, 5, 4, 4}};
  int k_ = 1;
Y
Yan Chunwei 已提交
38 39 40 41

 public:
  TopkComputeTester(const Place& place,
                    const std::string& alias,
42 43 44
                    DDim x_dims,
                    int k = 1)
      : TestCase(place, alias), x_dims_(x_dims), k_(k) {}
Y
Yan Chunwei 已提交
45 46

  void RunBaseline(Scope* scope) override {
47 48 49 50
    auto* out_val = scope->NewTensor(out_);
    auto* out_ind = scope->NewTensor(indices_);
    DDim out_dims = x_dims_;
    out_dims[out_dims.size() - 1] = k_;
Y
Yan Chunwei 已提交
51 52
    out_val->Resize(out_dims);
    out_ind->Resize(out_dims);
C
chenjiaoAngel 已提交
53 54
    auto* out_val_data = out_val->template mutable_data<T1>();
    auto* out_ind_data = out_ind->template mutable_data<T2>();
Y
Yan Chunwei 已提交
55

56
    auto* x = scope->FindTensor(x_);
C
chenjiaoAngel 已提交
57
    const auto* x_data = x->template data<T1>();
58 59
    int m = out_dims.production() / k_;
    int n = x_dims_[x_dims_.size() - 1];
Y
Yan Chunwei 已提交
60 61

    for (int i = 0; i < m; i++) {
62 63 64 65
      const T1* in_tmp = x_data + i * n;
      T1* out_val_tmp = out_val_data + i * k_;
      T2* out_ind_tmp = out_ind_data + i * k_;
      std::vector<std::pair<T1, T2>> vec;
Y
Yan Chunwei 已提交
66
      for (int j = 0; j < n; j++) {
67
        vec.push_back(std::make_pair(in_tmp[j], static_cast<T2>(j)));
Y
Yan Chunwei 已提交
68
      }
69 70 71
      std::partial_sort(
          vec.begin(), vec.begin() + k_, vec.end(), comp_func<T1, T2>);
      for (int q = 0; q < k_; q++) {
Y
Yan Chunwei 已提交
72 73 74 75 76 77 78
        out_val_tmp[q] = vec[q].first;
        out_ind_tmp[q] = vec[q].second;
      }
    }
  }

  void PrepareOpDesc(cpp::OpDesc* op_desc) {
79 80 81 82 83
    op_desc->SetType("top_k");
    op_desc->SetInput("X", {x_});
    op_desc->SetOutput("Out", {out_});
    op_desc->SetOutput("Indices", {indices_});
    op_desc->SetAttr("k", k_);
Y
Yan Chunwei 已提交
84 85 86
  }

  void PrepareData() override {
87 88 89
    std::vector<T1> dx(x_dims_.production());
    fill_data_rand<T1>(dx.data(), -1, 1, x_dims_.production());
    SetCommonTensor(x_, x_dims_, dx.data());
Y
Yan Chunwei 已提交
90 91 92
  }
};

93 94 95 96 97
template <typename T1, typename T2>
void test_topk(Place place, float abs_error) {
  for (auto x_shape : std::vector<std::vector<int64_t>>{
           {2, 3, 4, 5}, {3, 4, 5}, {4, 5}, {5}}) {
    for (int k : {2, 5}) {
Y
Yan Chunwei 已提交
98
      std::unique_ptr<arena::TestCase> tester(
99 100
          new TopkComputeTester<T1, T2>(place, "def", DDim(x_shape), k));
      arena::Arena arena(std::move(tester), place, abs_error);
Y
Yan Chunwei 已提交
101 102 103 104 105 106
      arena.TestPrecision();
    }
  }
}

// Entry point: selects the target and index type for the current build.
// Builds with neither NPU nor ARM enabled run nothing.
TEST(Topk, precision) {
#if defined(LITE_WITH_NPU)
  // Looser tolerance because the NPU computes in fp16; indices are `int`.
  const Place npu_place = TARGET(kNPU);
  const float abs_error = 1e-3;
  test_topk<float, int>(npu_place, abs_error);
#elif defined(LITE_WITH_ARM)
  const Place arm_place = TARGET(kARM);
  const float abs_error = 2e-5;
  test_topk<float, int64_t>(arm_place, abs_error);
#endif
}

}  // namespace lite
}  // namespace paddle