write_to_array_compute_test.cc 2.8 KB
Newer Older
Y
Yan Chunwei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <gtest/gtest.h>
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/core/arena/framework.h"
19
#include "lite/tests/utils/fill_data.h"
Y
Yan Chunwei 已提交
20 21 22 23 24 25 26

namespace paddle {
namespace lite {

class WriteToArrayComputeTester : public arena::TestCase {
 protected:
  // common attributes for this op.
27 28 29 30 31 32
  std::string x_ = "x";
  std::string idn_ = "i";
  std::string out_ = "out";
  DDim x_dims_{{3, 5, 4, 4}};
  int out_size_ = 0;
  int id_ = 0;
Y
Yan Chunwei 已提交
33 34 35 36

 public:
  WriteToArrayComputeTester(const Place& place,
                            const std::string& alias,
37 38 39 40
                            DDim x_dims,
                            int out_size = 0,
                            int id = 0)
      : TestCase(place, alias), x_dims_(x_dims), out_size_(out_size), id_(id) {}
Y
Yan Chunwei 已提交
41 42

  void RunBaseline(Scope* scope) override {
43 44
    auto out = scope->Var(out_)->GetMutable<std::vector<Tensor>>();
    auto x = scope->FindTensor(x_);
Y
Yan Chunwei 已提交
45

46 47
    if (out->size() < id_ + 1) {
      out->resize(id_ + 1);
Y
Yan Chunwei 已提交
48
    }
49 50 51
    out->at(id_).Resize(x->dims());
    auto out_data = out->at(id_).mutable_data<float>();
    memcpy(out_data, x->data<float>(), sizeof(float) * x->numel());
Y
Yan Chunwei 已提交
52 53 54 55
  }

  void PrepareOpDesc(cpp::OpDesc* op_desc) {
    op_desc->SetType("write_to_array");
56 57 58
    op_desc->SetInput("X", {x_});
    op_desc->SetInput("I", {idn_});
    op_desc->SetOutput("Out", {out_});
Y
Yan Chunwei 已提交
59 60 61
  }

  void PrepareData() override {
62 63 64
    std::vector<float> dx(x_dims_.production());
    fill_data_rand(dx.data(), -1.f, 1.f, x_dims_.production());
    SetCommonTensor(x_, x_dims_, dx.data());
Y
Yan Chunwei 已提交
65

66 67 68
    std::vector<int64_t> didn(1);
    didn[0] = id_;
    SetCommonTensor(idn_, DDim{{1}}, didn.data());
Y
Yan Chunwei 已提交
69 70
  }
};
71 72

// Runs the arena precision check for write_to_array on `place`, once for
// every (out_size, id) pair taken from {0, 3} x {0, 1, 4}. The input dims are
// fixed at {3, 5, 4, 4}; `abs_error` is passed to arena::Arena unchanged.
void TestWriteToArray(Place place, float abs_error) {
  DDimLite input_dims{{3, 5, 4, 4}};
  const int out_sizes[] = {0, 3};
  const int ids[] = {0, 1, 4};

  for (const int out_size : out_sizes) {
    for (const int id : ids) {
      // The arena takes ownership of the freshly built test case.
      std::unique_ptr<arena::TestCase> test_case{new WriteToArrayComputeTester(
          place, "def", input_dims, out_size, id)};
      arena::Arena runner(std::move(test_case), place, abs_error);
      runner.TestPrecision();
    }
  }
}

// Precision test entry point. The check is compiled in only when
// LITE_WITH_ARM is defined; in every other build the test body is empty and
// the test finishes without running anything.
TEST(WriteToArray, precision) {
#ifdef LITE_WITH_ARM
  const float tolerance = 1e-5;
  Place target_place;
  target_place = TARGET(kARM);
  TestWriteToArray(target_place, tolerance);
#endif
}

}  // namespace lite
}  // namespace paddle