未验证 提交 1ae03f55 编写于 作者: Z zhupengyang 提交者: GitHub

arena/framework support tensor_arra, add write_to_array and read_from_array uts (#3083)

上级 0cc72c42
...@@ -74,12 +74,21 @@ void TestCase::PrepareInputsForInstruction() { ...@@ -74,12 +74,21 @@ void TestCase::PrepareInputsForInstruction() {
const auto* param_type = ParamTypeRegistry::Global().RetrieveInArgument( const auto* param_type = ParamTypeRegistry::Global().RetrieveInArgument(
place_, kernel_key, arg); place_, kernel_key, arg);
const auto* inst_type = Type::GetTensorTy(TARGET(kHost)); const Type* inst_type = nullptr;
if (param_type->type->IsTensor()) {
inst_type = Type::GetTensorTy(TARGET(kHost));
} else if (param_type->type->IsTensorList()) {
inst_type = Type::GetTensorListTy(TARGET(kHost));
} else {
LOG(FATAL) << "unsupported param_type";
}
CHECK(scope_->FindVar(var)); CHECK(scope_->FindVar(var));
const auto* shared_tensor = scope_->FindTensor(var);
if (!TargetCompatibleTo(*inst_type, *param_type->type)) { if (!TargetCompatibleTo(*inst_type, *param_type->type)) {
/// Create a tensor in the instruction's scope, alloc memory and then /// Create a tensor or tensor_array in the instruction's scope,
/// copy data there. /// alloc memory and then copy data there.
if (param_type->type->IsTensor()) {
const auto* shared_tensor = scope_->FindTensor(var);
auto* target_tensor = inst_scope_->NewTensor(var); auto* target_tensor = inst_scope_->NewTensor(var);
CHECK(!shared_tensor->dims().empty()) << "shared_tensor is empty yet"; CHECK(!shared_tensor->dims().empty()) << "shared_tensor is empty yet";
target_tensor->Resize(shared_tensor->dims()); target_tensor->Resize(shared_tensor->dims());
...@@ -88,6 +97,27 @@ void TestCase::PrepareInputsForInstruction() { ...@@ -88,6 +97,27 @@ void TestCase::PrepareInputsForInstruction() {
shared_tensor->memory_size()), shared_tensor->memory_size()),
shared_tensor->raw_data(), shared_tensor->raw_data(),
shared_tensor->memory_size()); shared_tensor->memory_size());
} else if (param_type->type->IsTensorList()) {
const auto* shared_tensor_array =
scope_->FindVar(var)->GetMutable<std::vector<Tensor>>();
auto* target_tensor_array =
inst_scope_->Var(var)->GetMutable<std::vector<Tensor>>();
CHECK(!shared_tensor_array->empty())
<< "shared_tensor_array is empty yet";
target_tensor_array->resize(shared_tensor_array->size());
for (int i = 0; i < shared_tensor_array->size(); i++) {
target_tensor_array->at(i).Resize(
shared_tensor_array->at(i).dims());
TargetCopy(param_type->type->target(),
target_tensor_array->at(i).mutable_data(
param_type->type->target(),
shared_tensor_array->at(i).memory_size()),
shared_tensor_array->at(i).raw_data(),
shared_tensor_array->at(i).memory_size());
}
} else {
LOG(FATAL) << "not support";
}
} }
} }
} }
......
...@@ -66,9 +66,17 @@ class TestCase { ...@@ -66,9 +66,17 @@ class TestCase {
/// output. /// output.
virtual void RunBaseline(Scope* scope) = 0; virtual void RunBaseline(Scope* scope) = 0;
/// Check the precision of the output tensors. It will compare the same tensor // checkout the precision of the two tensors. b_tensor is from the baseline
/// in two scopes, one of the instruction execution, and the other for the template <typename T>
/// baseline. bool CheckTensorPrecision(const Tensor* a_tensor,
const Tensor* b_tensor,
float abs_error);
/// Check the precision of the output variables. It will compare the same
/// tensor
/// (or all tensors of tensor_array) in two scopes, one of the instruction
/// execution,
/// and the other for the baseline.
template <typename T> template <typename T>
bool CheckPrecision(const std::string& var_name, float abs_error); bool CheckPrecision(const std::string& var_name, float abs_error);
...@@ -120,6 +128,34 @@ class TestCase { ...@@ -120,6 +128,34 @@ class TestCase {
tensor->set_persistable(is_persistable); tensor->set_persistable(is_persistable);
} }
/// Prepare a tensor_array in host. The tensors will be created in scope_.
/// Need to specify the targets other than X86 or ARM.
template <typename T>
void SetCommonTensorList(const std::string& var_name,
const std::vector<DDim>& array_tensor_dims,
const std::vector<std::vector<T>>& datas,
const std::vector<LoD>& lods = {}) {
CHECK_EQ(array_tensor_dims.size(), datas.size());
if (!lods.empty()) {
CHECK_EQ(array_tensor_dims.size(), lods.size());
}
auto* tensor_array =
scope_->Var(var_name)->GetMutable<std::vector<Tensor>>();
for (int i = 0; i < array_tensor_dims.size(); i++) {
Tensor tmp;
tmp.Resize(array_tensor_dims[i]);
auto* tmp_data = tmp.mutable_data<T>();
memcpy(tmp_data,
datas[i].data(),
array_tensor_dims[i].production() * sizeof(T));
if (!lods.empty()) {
tmp.set_lod(lods[i]);
}
tensor_array->push_back(tmp);
}
}
// Prepare for the operator. // Prepare for the operator.
virtual void PrepareOpDesc(cpp::OpDesc* op_desc) = 0; virtual void PrepareOpDesc(cpp::OpDesc* op_desc) = 0;
...@@ -263,9 +299,9 @@ class Arena { ...@@ -263,9 +299,9 @@ class Arena {
}; };
template <typename T> template <typename T>
bool TestCase::CheckPrecision(const std::string& var_name, float abs_error) { bool TestCase::CheckTensorPrecision(const Tensor* a_tensor,
auto a_tensor = inst_scope_->FindTensor(var_name); const Tensor* b_tensor,
auto b_tensor = base_scope_->FindTensor(var_name); float abs_error) {
CHECK(a_tensor); CHECK(a_tensor);
CHECK(b_tensor); CHECK(b_tensor);
...@@ -305,6 +341,34 @@ bool TestCase::CheckPrecision(const std::string& var_name, float abs_error) { ...@@ -305,6 +341,34 @@ bool TestCase::CheckPrecision(const std::string& var_name, float abs_error) {
return success; return success;
} }
template <typename T>
bool TestCase::CheckPrecision(const std::string& var_name, float abs_error) {
bool success = true;
if (inst_scope_->FindVar(var_name)->IsType<Tensor>()) {
auto a_tensor = inst_scope_->FindTensor(var_name);
auto b_tensor = base_scope_->FindTensor(var_name);
success = success && CheckTensorPrecision<T>(a_tensor, b_tensor, abs_error);
} else if (inst_scope_->FindVar(var_name)->IsType<std::vector<Tensor>>()) {
auto a_tensor_array =
inst_scope_->FindVar(var_name)->GetMutable<std::vector<Tensor>>();
auto b_tensor_array =
base_scope_->FindVar(var_name)->GetMutable<std::vector<Tensor>>();
CHECK_EQ(a_tensor_array->size(), b_tensor_array->size());
for (int i = 0; i < a_tensor_array->size(); i++) {
Tensor* a_tensor = &(a_tensor_array->at(i));
Tensor* b_tensor = &(b_tensor_array->at(i));
if (a_tensor->dims().size() == 0 && b_tensor->dims().size() == 0) {
continue;
}
success =
success && CheckTensorPrecision<T>(a_tensor, b_tensor, abs_error);
}
} else {
LOG(FATAL) << "unsupported var type";
}
return success;
}
} // namespace arena } // namespace arena
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
...@@ -82,6 +82,7 @@ void TensorLite::CopyDataFrom(const TensorLite &other) { ...@@ -82,6 +82,7 @@ void TensorLite::CopyDataFrom(const TensorLite &other) {
target_ = other.target_; target_ = other.target_;
lod_ = other.lod_; lod_ = other.lod_;
memory_size_ = other.memory_size_; memory_size_ = other.memory_size_;
precision_ = other.precision_;
buffer_->CopyDataFrom(*other.buffer_, memory_size_); buffer_->CopyDataFrom(*other.buffer_, memory_size_);
} }
......
...@@ -177,8 +177,9 @@ static bool TargetCompatibleTo(const Type& a, const Type& b) { ...@@ -177,8 +177,9 @@ static bool TargetCompatibleTo(const Type& a, const Type& b) {
return x == TARGET(kHost) || x == TARGET(kX86) || x == TARGET(kARM); return x == TARGET(kHost) || x == TARGET(kX86) || x == TARGET(kARM);
}; };
if (a.IsVoid() || b.IsVoid()) return true; if (a.IsVoid() || b.IsVoid()) return true;
if (a.IsTensor() || b.IsTensor()) { if (a.IsTensor() || b.IsTensor() || a.IsTensorList() || b.IsTensorList()) {
if (a.IsTensor() && b.IsTensor()) { if ((a.IsTensor() && b.IsTensor()) ||
(a.IsTensorList() && b.IsTensorList())) {
return is_host(a.target()) ? is_host(b.target()) return is_host(a.target()) ? is_host(b.target())
: a.target() == b.target(); : a.target() == b.target();
} }
......
...@@ -20,23 +20,17 @@ namespace lite { ...@@ -20,23 +20,17 @@ namespace lite {
namespace kernels { namespace kernels {
namespace arm { namespace arm {
void ReadFromArrayCompute::PrepareForRun() {}
void ReadFromArrayCompute::Run() { void ReadFromArrayCompute::Run() {
auto& ctx = this->ctx_->template As<ARMContext>(); auto& ctx = this->ctx_->template As<ARMContext>();
auto& param = this->Param<operators::ReadFromArrayParam>(); auto& param = this->Param<param_t>();
int in_num = param.X->size();
CHECK_EQ(param.I->numel(), 1) << "I should have only one element"; CHECK_EQ(param.I->numel(), 1) << "I should have only one element";
int id = param.I->data<float>()[0]; int id = param.I->data<int64_t>()[0];
int in_num = param.X->size();
CHECK_LE(id, in_num) << "id is not valid"; CHECK_LE(id, in_num) << "id is not valid";
int input_size = (*param.X)[id].numel();
param.Out->Resize((*param.X)[id].dims()); param.Out->Resize((*param.X)[id].dims());
param.Out->CopyDataFrom((*param.X)[id]); param.Out->CopyDataFrom((*param.X)[id]);
auto out_lod = param.Out->mutable_lod();
*out_lod = (*param.X)[id].lod();
} }
} // namespace arm } // namespace arm
...@@ -46,11 +40,11 @@ void ReadFromArrayCompute::Run() { ...@@ -46,11 +40,11 @@ void ReadFromArrayCompute::Run() {
REGISTER_LITE_KERNEL(read_from_array, REGISTER_LITE_KERNEL(read_from_array,
kARM, kARM,
kFloat, kAny,
kNCHW, kNCHW,
paddle::lite::kernels::arm::ReadFromArrayCompute, paddle::lite::kernels::arm::ReadFromArrayCompute,
def) def)
.BindInput("X", {LiteType::GetTensorListTy(TARGET(kARM))}) .BindInput("X", {LiteType::GetTensorListTy(TARGET(kARM), PRECISION(kAny))})
.BindInput("I", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("I", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
.Finalize(); .Finalize();
...@@ -23,13 +23,10 @@ namespace lite { ...@@ -23,13 +23,10 @@ namespace lite {
namespace kernels { namespace kernels {
namespace arm { namespace arm {
class ReadFromArrayCompute class ReadFromArrayCompute : public KernelLite<TARGET(kARM), PRECISION(kAny)> {
: public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public: public:
using param_t = operators::ReadFromArrayParam; using param_t = operators::ReadFromArrayParam;
void PrepareForRun() override;
void Run() override; void Run() override;
~ReadFromArrayCompute() {} ~ReadFromArrayCompute() {}
......
...@@ -24,33 +24,12 @@ void WriteToArrayCompute::Run() { ...@@ -24,33 +24,12 @@ void WriteToArrayCompute::Run() {
auto& ctx = this->ctx_->template As<ARMContext>(); auto& ctx = this->ctx_->template As<ARMContext>();
auto& param = this->template Param<operators::WriteToArrayParam>(); auto& param = this->template Param<operators::WriteToArrayParam>();
CHECK_EQ(param.I->numel(), 1) << "input2 should have only one element"; CHECK_EQ(param.I->numel(), 1) << "input2 should have only one element";
auto precision_type = param.X->precision();
#define SOLVE_TYPE(type__, T) \ int id = param.I->data<int64_t>()[0];
case type__: { \ if (param.Out->size() < id + 1) {
const auto* x_data = param.X->data<T>(); \ param.Out->resize(id + 1);
int id = param.I->data<int64_t>()[0]; \
if (id >= param.Out->size()) { \
for (int i = param.Out->size(); i < id + 1; i++) { \
lite::Tensor tmp; \
param.Out->push_back(tmp); \
} \
} \
(*param.Out)[id].Resize(param.X->dims()); \
auto out_lod = (*param.Out)[id].mutable_lod(); \
*out_lod = param.X->lod(); \
auto* o_data = (*param.Out)[id].mutable_data<T>(TARGET(kHost)); \
int input_size = param.X->numel(); \
memcpy(o_data, x_data, sizeof(T) * input_size); \
} break;
switch (precision_type) {
SOLVE_TYPE(PRECISION(kFloat), float);
SOLVE_TYPE(PRECISION(kInt64), int64_t);
default:
LOG(FATAL) << "Unsupported precision type.";
} }
#undef SOLVE_TYPE param.Out->at(id).CopyDataFrom(*param.X);
} }
} // namespace arm } // namespace arm
...@@ -65,6 +44,7 @@ REGISTER_LITE_KERNEL(write_to_array, ...@@ -65,6 +44,7 @@ REGISTER_LITE_KERNEL(write_to_array,
paddle::lite::kernels::arm::WriteToArrayCompute, paddle::lite::kernels::arm::WriteToArrayCompute,
def) def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))}) .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
.BindInput("I", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))}) .BindInput("I", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.BindOutput("Out", {LiteType::GetTensorListTy(TARGET(kARM))}) .BindOutput("Out",
{LiteType::GetTensorListTy(TARGET(kARM), PRECISION(kAny))})
.Finalize(); .Finalize();
...@@ -25,8 +25,6 @@ namespace arm { ...@@ -25,8 +25,6 @@ namespace arm {
class WriteToArrayCompute : public KernelLite<TARGET(kARM), PRECISION(kAny)> { class WriteToArrayCompute : public KernelLite<TARGET(kARM), PRECISION(kAny)> {
public: public:
using param_t = operators::WriteToArrayParam;
void Run() override; void Run() override;
~WriteToArrayCompute() {} ~WriteToArrayCompute() {}
......
...@@ -736,8 +736,8 @@ struct WriteToArrayParam { ...@@ -736,8 +736,8 @@ struct WriteToArrayParam {
}; };
struct ReadFromArrayParam { struct ReadFromArrayParam {
std::vector<lite::Tensor>* X{}; const std::vector<lite::Tensor>* X{};
lite::Tensor* I{}; const lite::Tensor* I{};
lite::Tensor* Out{}; lite::Tensor* Out{};
}; };
......
...@@ -19,11 +19,17 @@ namespace paddle { ...@@ -19,11 +19,17 @@ namespace paddle {
namespace lite { namespace lite {
namespace operators { namespace operators {
bool ReadFromArrayOp::CheckShape() const { return true; } bool ReadFromArrayOp::CheckShape() const {
CHECK(param_.X);
CHECK(param_.I);
CHECK(param_.Out);
return true;
}
bool ReadFromArrayOp::InferShape() const { bool ReadFromArrayOp::InferShape() const {
auto in_dims = (*param_.X)[0].dims(); int id = param_.I->data<int64_t>()[0];
param_.Out->Resize(in_dims); auto out_dims = (*param_.X)[id].dims();
param_.Out->Resize(out_dims);
return true; return true;
} }
...@@ -32,11 +38,9 @@ bool ReadFromArrayOp::AttachImpl(const cpp::OpDesc &opdesc, ...@@ -32,11 +38,9 @@ bool ReadFromArrayOp::AttachImpl(const cpp::OpDesc &opdesc,
auto in = opdesc.Input("X").front(); auto in = opdesc.Input("X").front();
param_.X = scope->FindVar(in)->GetMutable<std::vector<lite::Tensor>>(); param_.X = scope->FindVar(in)->GetMutable<std::vector<lite::Tensor>>();
param_.I = param_.I = scope->FindTensor(opdesc.Input("I").front());
scope->FindVar(opdesc.Input("I").front())->GetMutable<lite::Tensor>();
param_.Out = param_.Out = scope->FindMutableTensor(opdesc.Output("Out").front());
scope->FindVar(opdesc.Output("Out").front())->GetMutable<lite::Tensor>();
return true; return true;
} }
......
...@@ -19,25 +19,30 @@ namespace paddle { ...@@ -19,25 +19,30 @@ namespace paddle {
namespace lite { namespace lite {
namespace operators { namespace operators {
bool WriteToArrayOp::CheckShape() const { return true; } bool WriteToArrayOp::CheckShape() const {
CHECK(param_.X);
CHECK(param_.I);
CHECK(param_.Out);
return true;
}
bool WriteToArrayOp::InferShape() const { bool WriteToArrayOp::InferShape() const {
auto in_dims = param_.X->dims(); int id = param_.I->data<int64_t>()[0];
for (auto out : *param_.Out) { if (param_.Out->size() < id + 1) {
out.Resize(in_dims); param_.Out->resize(id + 1);
} }
return true; return true;
} }
bool WriteToArrayOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { bool WriteToArrayOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
auto inputs = opdesc.Input("X").front(); auto inputs = opdesc.Input("X").front();
param_.X = scope->FindVar(inputs)->GetMutable<lite::Tensor>(); param_.X = scope->FindTensor(inputs);
auto id = opdesc.Input("I").front(); auto id = opdesc.Input("I").front();
param_.I = scope->FindVar(id)->GetMutable<lite::Tensor>(); param_.I = scope->FindTensor(id);
auto out = opdesc.Output("Out").front(); auto out = opdesc.Output("Out").front();
param_.Out = scope->FindVar(out)->GetMutable<std::vector<lite::Tensor>>(); param_.Out = scope->FindVar(out)->GetMutable<std::vector<Tensor>>();
return true; return true;
} }
......
...@@ -23,8 +23,8 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA AND NOT LITE_WITH_BM) AND (LITE_ ...@@ -23,8 +23,8 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA AND NOT LITE_WITH_BM) AND (LITE_
#lite_cc_test(test_kernel_logical_xor_compute SRCS logical_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) #lite_cc_test(test_kernel_logical_xor_compute SRCS logical_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_topk_compute SRCS topk_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) #lite_cc_test(test_kernel_topk_compute SRCS topk_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_increment_compute SRCS increment_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) #lite_cc_test(test_kernel_increment_compute SRCS increment_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_write_to_array_compute SRCS write_to_array_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_write_to_array_compute SRCS write_to_array_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_read_from_array_compute SRCS read_from_array_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_read_from_array_compute SRCS read_from_array_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_concat_compute SRCS concat_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_concat_compute SRCS concat_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_transpose_compute SRCS transpose_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_transpose_compute SRCS transpose_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_reshape_compute SRCS reshape_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_reshape_compute SRCS reshape_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
......
...@@ -13,11 +13,10 @@ ...@@ -13,11 +13,10 @@
// limitations under the License. // limitations under the License.
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <stdio.h>
#include <stdlib.h>
#include "lite/api/paddle_use_kernels.h" #include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h" #include "lite/api/paddle_use_ops.h"
#include "lite/core/arena/framework.h" #include "lite/core/arena/framework.h"
#include "lite/tests/utils/fill_data.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -25,80 +24,78 @@ namespace lite { ...@@ -25,80 +24,78 @@ namespace lite {
class ReadFromArrayComputeTester : public arena::TestCase { class ReadFromArrayComputeTester : public arena::TestCase {
protected: protected:
// common attributes for this op. // common attributes for this op.
std::string input_0 = "in_0"; std::string x_ = "x";
std::string input_1 = "in_1"; std::string idn_ = "i";
std::string input_2 = "in_2"; std::string out_ = "out";
std::string input_i = "i"; DDim tar_dims_{{3, 5, 4, 4}};
std::string output = "out"; int x_size_ = 1;
DDim dims_{{3, 5, 4, 4}}; int id_ = 0;
int i_;
public: public:
ReadFromArrayComputeTester(const Place& place, ReadFromArrayComputeTester(const Place& place,
const std::string& alias, const std::string& alias,
const int i, DDim tar_dims,
DDim dims) int x_size = 1,
: TestCase(place, alias), i_(i), dims_(dims) {} int id = 0)
: TestCase(place, alias), tar_dims_(tar_dims), x_size_(x_size), id_(id) {}
void RunBaseline(Scope* scope) override { void RunBaseline(Scope* scope) override {
auto* out = scope->NewTensor(output); auto x = scope->FindVar(x_)->GetMutable<std::vector<Tensor>>();
CHECK(out); auto idn = scope->FindTensor(idn_);
auto* in_0 = scope->FindTensor(input_0); auto out = scope->NewTensor(out_);
auto* in_1 = scope->FindTensor(input_1);
auto* in_2 = scope->FindTensor(input_2);
auto* id_tensor = scope->FindTensor(input_i);
std::vector<const TensorLite*> in_vec = {in_0, in_1, in_2};
int cur_in_num = in_vec.size();
int id = id_tensor->data<int>()[0]; int id = idn->data<int64_t>()[0];
out->Resize(dims_); out->CopyDataFrom(x->at(id));
const auto* in_data = in_vec[id]->data<float>();
auto* o_data = out->mutable_data<float>();
int n = in_vec[id]->numel();
memcpy(o_data, in_data, sizeof(float) * n);
} }
void PrepareOpDesc(cpp::OpDesc* op_desc) { void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType("read_from_array"); op_desc->SetType("read_from_array");
op_desc->SetInput("X", {input_0, input_1, input_2}); op_desc->SetInput("X", {x_});
op_desc->SetInput("I", {input_i}); op_desc->SetInput("I", {idn_});
op_desc->SetOutput("Out", {output}); op_desc->SetOutput("Out", {out_});
} }
void PrepareData() override { void PrepareData() override {
std::vector<std::string> in_vec = {input_0, input_1, input_2}; std::vector<DDim> x_dims(x_size_);
for (auto in : in_vec) { std::vector<std::vector<float>> x_data(x_size_);
std::vector<float> data(dims_.production()); for (int i = 0; i < x_size_; i++) {
for (int i = 0; i < dims_.production(); i++) { x_dims[i] = tar_dims_;
data[i] = std::rand() * 1.0f / RAND_MAX; x_data[i].resize(x_dims[i].production());
} fill_data_rand(x_data[i].data(), -1.f, 1.f, x_dims[i].production());
SetCommonTensor(in, dims_, data.data());
} }
SetCommonTensorList(x_, x_dims, x_data);
std::vector<int64_t> didn(1);
didn[0] = id_;
SetCommonTensor(idn_, DDim{{1}}, didn.data());
DDimLite dims_i{{1}}; SetPrecisionType(out_, PRECISION(kFloat));
int a = 1;
SetCommonTensor(input_i, dims_i, &a);
} }
}; };
void test_read_from_array(Place place) { void TestReadFromArray(Place place, float abs_error) {
DDimLite dims{{3, 5, 4, 4}}; DDimLite dims{{3, 5, 4, 4}};
for (int i : {1, 2}) { for (int x_size : {1, 3}) {
for (int id : {0, 2}) {
if (x_size < id + 1) continue;
std::unique_ptr<arena::TestCase> tester( std::unique_ptr<arena::TestCase> tester(
new ReadFromArrayComputeTester(place, "def", i, dims)); new ReadFromArrayComputeTester(place, "def", dims, x_size, id));
arena::Arena arena(std::move(tester), place, 2e-5); arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision(); arena.TestPrecision();
} }
}
} }
TEST(ReadFromArray, precision) { TEST(ReadFromArray, precision) {
// #ifdef LITE_WITH_X86 Place place;
// Place place(TARGET(kX86)); float abs_error = 1e-5;
// #endif
#ifdef LITE_WITH_ARM #ifdef LITE_WITH_ARM
Place place(TARGET(kARM)); place = {TARGET(kARM), PRECISION(kAny)};
test_read_from_array(place); #else
return;
#endif #endif
TestReadFromArray(place, abs_error);
} }
} // namespace lite } // namespace lite
......
...@@ -13,11 +13,10 @@ ...@@ -13,11 +13,10 @@
// limitations under the License. // limitations under the License.
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <stdio.h>
#include <stdlib.h>
#include "lite/api/paddle_use_kernels.h" #include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h" #include "lite/api/paddle_use_ops.h"
#include "lite/core/arena/framework.h" #include "lite/core/arena/framework.h"
#include "lite/tests/utils/fill_data.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -25,91 +24,75 @@ namespace lite { ...@@ -25,91 +24,75 @@ namespace lite {
class WriteToArrayComputeTester : public arena::TestCase { class WriteToArrayComputeTester : public arena::TestCase {
protected: protected:
// common attributes for this op. // common attributes for this op.
std::string input_0 = "x"; std::string x_ = "x";
std::string input_1 = "i"; std::string idn_ = "i";
std::string output_0 = "out0"; std::string out_ = "out";
std::string output_1 = "out1"; DDim x_dims_{{3, 5, 4, 4}};
std::string output_2 = "out2"; int out_size_ = 0;
DDim dims_{{3, 5, 4, 4}}; int id_ = 0;
int i_;
public: public:
WriteToArrayComputeTester(const Place& place, WriteToArrayComputeTester(const Place& place,
const std::string& alias, const std::string& alias,
const int i, DDim x_dims,
DDim dims) int out_size = 0,
: TestCase(place, alias), i_(i), dims_(dims) {} int id = 0)
: TestCase(place, alias), x_dims_(x_dims), out_size_(out_size), id_(id) {}
void RunBaseline(Scope* scope) override { void RunBaseline(Scope* scope) override {
auto* out_0 = scope->NewTensor(output_0); auto out = scope->Var(out_)->GetMutable<std::vector<Tensor>>();
auto* out_1 = scope->NewTensor(output_1); auto x = scope->FindTensor(x_);
auto* out_2 = scope->NewTensor(output_2);
CHECK(out_0);
CHECK(out_1);
CHECK(out_2);
std::vector<TensorLite*> out_vec = {out_0, out_1, out_2};
auto* x = scope->FindTensor(input_0); if (out->size() < id_ + 1) {
const auto* x_data = x->data<float>(); out->resize(id_ + 1);
auto* id = scope->FindTensor(input_1);
const auto* id_data = id->data<float>();
int n = x->numel();
int cur_out_num = out_vec.size();
for (int i = cur_out_num; i < id_data[0] + 1; i++) {
char buffer[30];
snprintf(buffer, sizeof(buffer), "out%d", i);
auto out = scope->NewTensor(buffer);
out_vec.push_back(out);
} }
out_vec[id_data[0]]->Resize(dims_); out->at(id_).Resize(x->dims());
auto* out_data = out_vec[id_data[0]]->mutable_data<float>(); auto out_data = out->at(id_).mutable_data<float>();
memcpy(out_data, x_data, sizeof(float) * n); memcpy(out_data, x->data<float>(), sizeof(float) * x->numel());
} }
void PrepareOpDesc(cpp::OpDesc* op_desc) { void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType("write_to_array"); op_desc->SetType("write_to_array");
op_desc->SetInput("X", {input_0}); op_desc->SetInput("X", {x_});
op_desc->SetInput("I", {input_1}); op_desc->SetInput("I", {idn_});
op_desc->SetOutput("Out", {output_0, output_1, output_2}); op_desc->SetOutput("Out", {out_});
} }
void PrepareData() override { void PrepareData() override {
std::vector<float> data(dims_.production()); std::vector<float> dx(x_dims_.production());
fill_data_rand(dx.data(), -1.f, 1.f, x_dims_.production());
SetCommonTensor(x_, x_dims_, dx.data());
for (int i = 0; i < dims_.production(); i++) { std::vector<int64_t> didn(1);
data[i] = i * 1.1; didn[0] = id_;
} SetCommonTensor(idn_, DDim{{1}}, didn.data());
SetCommonTensor(input_0, dims_, data.data());
std::vector<int> data_1(1);
data_1[0] = i_;
DDimLite dims_2{{1}};
SetCommonTensor(input_1, dims_2, data_1.data());
SetCommonTensor(output_0, dims_2, data_1.data()); SetPrecisionType(out_, PRECISION(kFloat));
SetCommonTensor(output_1, dims_2, data_1.data());
SetCommonTensor(output_2, dims_2, data_1.data());
} }
}; };
void test_write_to_array(Place place) {
void TestWriteToArray(Place place, float abs_error) {
DDimLite dims{{3, 5, 4, 4}}; DDimLite dims{{3, 5, 4, 4}};
for (int i : {1, 4}) { for (int out_size : {0, 3}) {
for (int id : {0, 1, 4}) {
std::unique_ptr<arena::TestCase> tester( std::unique_ptr<arena::TestCase> tester(
new WriteToArrayComputeTester(place, "def", i, dims)); new WriteToArrayComputeTester(place, "def", dims, out_size, id));
arena::Arena arena(std::move(tester), place, 2e-5); arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision(); arena.TestPrecision();
} }
}
} }
TEST(WriteToArray, precision) { TEST(WriteToArray, precision) {
// #ifdef LITE_WITH_X86 Place place;
// Place place(TARGET(kX86)); float abs_error = 1e-5;
// #endif
#ifdef LITE_WITH_ARM #ifdef LITE_WITH_ARM
Place place(TARGET(kARM)); place = {TARGET(kARM), PRECISION(kAny)};
test_write_to_array(place); #else
return;
#endif #endif
TestWriteToArray(place, abs_error);
} }
} // namespace lite } // namespace lite
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册