// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cstring>
#include <memory>
#include <string>
#include <vector>

#include <gtest/gtest.h>

#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/core/arena/framework.h"
#include "lite/tests/utils/fill_data.h"

namespace paddle {
namespace lite {

template <typename T>
25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67
class LookupTableComputeTest : public arena::TestCase {
 protected:
  // common attributes for this op.
  std::string op_type_ = "lookup_table";
  std::string ids_ = "ids";
  std::string w_ = "w";
  std::string out_ = "out";
  DDim ids_dims_{{2, 1}};
  DDim w_dims_{{8, 4}};
  int64_t padding_idx_ = -1;

 public:
  LookupTableComputeTest(const Place& place,
                         const std::string& alias,
                         const DDim& ids_dims,
                         const DDim& w_dims,
                         int64_t padding_idx)
      : TestCase(place, alias),
        ids_dims_(ids_dims),
        w_dims_(w_dims),
        padding_idx_(padding_idx) {}

  void RunBaseline(Scope* scope) override {
    auto ids = scope->FindTensor(ids_);
    auto w = scope->FindTensor(w_);
    auto ids_dims = ids->dims();
    auto w_dims = w->dims();

    auto out = scope->NewTensor(out_);
    CHECK(out);

    int ids_rank = ids_dims.size();
    CHECK_EQ(ids_dims[ids_rank - 1], 1);
    CHECK_EQ(w_dims.size(), 2);

    std::vector<int64_t> out_dims;
    for (int i = 0; i < ids_rank - 1; ++i) {
      out_dims.push_back(ids_dims[i]);
    }
    out_dims.push_back(w_dims[1]);
    out->Resize(out_dims);
    out->set_lod(ids->lod());

68
    auto ids_data = ids->template data<T>();
69
    auto ids_size = ids_dims.production();
70
    auto w_data = w->template data<float>();
71 72
    auto w_rows = w_dims[0];
    auto w_cols = w_dims[1];
73
    auto out_data = out->template mutable_data<float>();
74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98

    for (int64_t i = 0; i < ids_size; i++) {
      auto id = ids_data[i];
      if (padding_idx_ != -1 && id == padding_idx_) {
        memset(out_data + i * w_cols, 0, w_cols * sizeof(float));
      } else {
        CHECK_LT(id, w_rows) << "lookup_table ids[i] expected < " << w_rows
                             << " but got " << id;
        CHECK_GE(id, 0) << "lookup_table ids[i] expected >= 0 but got " << id;
        memcpy(out_data + i * w_cols,
               w_data + id * w_cols,
               w_cols * sizeof(float));
      }
    }
  }

  void PrepareOpDesc(cpp::OpDesc* op_desc) {
    op_desc->SetType(op_type_);
    op_desc->SetInput("Ids", {ids_});
    op_desc->SetInput("W", {w_});
    op_desc->SetOutput("Out", {out_});
    op_desc->SetAttr<int64_t>("padding_idx", padding_idx_);
  }

  void PrepareData() override {
99 100
    std::vector<T> ids(ids_dims_.production());
    fill_data_rand<T>(ids.data(), 0, w_dims_[0] - 1, ids_dims_.production());
101 102 103 104 105 106 107 108 109 110 111

    std::vector<float> w(w_dims_.production());
    fill_data_rand(w.data(), -1.f, 1.f, w_dims_.production());

    SetCommonTensor(ids_, ids_dims_, ids.data());
    SetCommonTensor(w_, w_dims_, w.data());
  }
};

// Runs the lookup_table precision test on the first enabled target, in the
// order NPU, then ARM, then XPU. It sweeps several Ids shapes, W shapes and
// padding_idx values. When none of those targets is enabled, the test returns
// immediately.
TEST(LookupTable, precision) {
  LOG(INFO) << "test lookup_table op";
  float abs_error = 1e-5f;
  Place place;
  // Only padding_idx == -1 is supported by the XPU and NPU kernels. The flag is
  // derived from the place actually selected, so each target gets exactly the
  // coverage it supports.
  bool test_padding = true;
#if defined(LITE_WITH_NPU)
  place = TARGET(kNPU);
  abs_error = 1e-2f;
  test_padding = false;
#elif defined(LITE_WITH_ARM)
  place = TARGET(kARM);
#elif defined(LITE_WITH_XPU)
  place = TARGET(kXPU);
  test_padding = false;
#else
  return;
#endif

#if defined(LITE_WITH_NPU)
  using ID_T = int;
#else
  using ID_T = int64_t;
#endif

  const std::vector<std::vector<int64_t>> ids_dims_list{
      {5, 2, 3, 1}, {2, 3, 1}, {3, 1}};
  const std::vector<std::vector<int64_t>> w_dims_list{
      {4, 2}, {6, 8}, {12, 15}};

  for (const auto& ids_dims : ids_dims_list) {
    for (const auto& w_dims : w_dims_list) {
      // -1 means no padding. The first and last rows of W are also tried as
      // padding ids where the target supports padding.
      std::vector<int64_t> padding_idxs{-1};
      if (test_padding) {
        padding_idxs.push_back(0);
        padding_idxs.push_back(w_dims[0] - 1);
      }
      for (int64_t padding_idx : padding_idxs) {
        std::unique_ptr<arena::TestCase> tester(
            new LookupTableComputeTest<ID_T>(
                place, "def", DDim(ids_dims), DDim(w_dims), padding_idx));
        arena::Arena arena(std::move(tester), place, abs_error);
        arena.TestPrecision();
      }
    }
  }
}

}  // namespace lite
}  // namespace paddle