未验证 提交 1ae03f55 编写于 作者: Z zhupengyang 提交者: GitHub

arena/framework support tensor_arra, add write_to_array and read_from_array uts (#3083)

上级 0cc72c42
......@@ -74,20 +74,50 @@ void TestCase::PrepareInputsForInstruction() {
const auto* param_type = ParamTypeRegistry::Global().RetrieveInArgument(
place_, kernel_key, arg);
const auto* inst_type = Type::GetTensorTy(TARGET(kHost));
const Type* inst_type = nullptr;
if (param_type->type->IsTensor()) {
inst_type = Type::GetTensorTy(TARGET(kHost));
} else if (param_type->type->IsTensorList()) {
inst_type = Type::GetTensorListTy(TARGET(kHost));
} else {
LOG(FATAL) << "unsupported param_type";
}
CHECK(scope_->FindVar(var));
const auto* shared_tensor = scope_->FindTensor(var);
if (!TargetCompatibleTo(*inst_type, *param_type->type)) {
/// Create a tensor in the instruction's scope, alloc memory and then
/// copy data there.
auto* target_tensor = inst_scope_->NewTensor(var);
CHECK(!shared_tensor->dims().empty()) << "shared_tensor is empty yet";
target_tensor->Resize(shared_tensor->dims());
TargetCopy(param_type->type->target(),
target_tensor->mutable_data(param_type->type->target(),
shared_tensor->memory_size()),
shared_tensor->raw_data(),
shared_tensor->memory_size());
/// Create a tensor or tensor_array in the instruction's scope,
/// alloc memory and then copy data there.
if (param_type->type->IsTensor()) {
const auto* shared_tensor = scope_->FindTensor(var);
auto* target_tensor = inst_scope_->NewTensor(var);
CHECK(!shared_tensor->dims().empty()) << "shared_tensor is empty yet";
target_tensor->Resize(shared_tensor->dims());
TargetCopy(param_type->type->target(),
target_tensor->mutable_data(param_type->type->target(),
shared_tensor->memory_size()),
shared_tensor->raw_data(),
shared_tensor->memory_size());
} else if (param_type->type->IsTensorList()) {
const auto* shared_tensor_array =
scope_->FindVar(var)->GetMutable<std::vector<Tensor>>();
auto* target_tensor_array =
inst_scope_->Var(var)->GetMutable<std::vector<Tensor>>();
CHECK(!shared_tensor_array->empty())
<< "shared_tensor_array is empty yet";
target_tensor_array->resize(shared_tensor_array->size());
for (int i = 0; i < shared_tensor_array->size(); i++) {
target_tensor_array->at(i).Resize(
shared_tensor_array->at(i).dims());
TargetCopy(param_type->type->target(),
target_tensor_array->at(i).mutable_data(
param_type->type->target(),
shared_tensor_array->at(i).memory_size()),
shared_tensor_array->at(i).raw_data(),
shared_tensor_array->at(i).memory_size());
}
} else {
LOG(FATAL) << "not support";
}
}
}
}
......
......@@ -66,9 +66,17 @@ class TestCase {
/// output.
virtual void RunBaseline(Scope* scope) = 0;
/// Check the precision of the output tensors. It will compare the same tensor
/// in two scopes, one of the instruction execution, and the other for the
/// baseline.
// checkout the precision of the two tensors. b_tensor is from the baseline
template <typename T>
bool CheckTensorPrecision(const Tensor* a_tensor,
const Tensor* b_tensor,
float abs_error);
/// Check the precision of the output variables. It will compare the same
/// tensor
/// (or all tensors of tensor_array) in two scopes, one of the instruction
/// execution,
/// and the other for the baseline.
template <typename T>
bool CheckPrecision(const std::string& var_name, float abs_error);
......@@ -120,6 +128,34 @@ class TestCase {
tensor->set_persistable(is_persistable);
}
/// Prepare a tensor_array in host. The tensors will be created in scope_.
/// Need to specify the targets other than X86 or ARM.
template <typename T>
void SetCommonTensorList(const std::string& var_name,
const std::vector<DDim>& array_tensor_dims,
const std::vector<std::vector<T>>& datas,
const std::vector<LoD>& lods = {}) {
CHECK_EQ(array_tensor_dims.size(), datas.size());
if (!lods.empty()) {
CHECK_EQ(array_tensor_dims.size(), lods.size());
}
auto* tensor_array =
scope_->Var(var_name)->GetMutable<std::vector<Tensor>>();
for (int i = 0; i < array_tensor_dims.size(); i++) {
Tensor tmp;
tmp.Resize(array_tensor_dims[i]);
auto* tmp_data = tmp.mutable_data<T>();
memcpy(tmp_data,
datas[i].data(),
array_tensor_dims[i].production() * sizeof(T));
if (!lods.empty()) {
tmp.set_lod(lods[i]);
}
tensor_array->push_back(tmp);
}
}
// Prepare for the operator.
virtual void PrepareOpDesc(cpp::OpDesc* op_desc) = 0;
......@@ -263,9 +299,9 @@ class Arena {
};
template <typename T>
bool TestCase::CheckPrecision(const std::string& var_name, float abs_error) {
auto a_tensor = inst_scope_->FindTensor(var_name);
auto b_tensor = base_scope_->FindTensor(var_name);
bool TestCase::CheckTensorPrecision(const Tensor* a_tensor,
const Tensor* b_tensor,
float abs_error) {
CHECK(a_tensor);
CHECK(b_tensor);
......@@ -305,6 +341,34 @@ bool TestCase::CheckPrecision(const std::string& var_name, float abs_error) {
return success;
}
template <typename T>
bool TestCase::CheckPrecision(const std::string& var_name, float abs_error) {
bool success = true;
if (inst_scope_->FindVar(var_name)->IsType<Tensor>()) {
auto a_tensor = inst_scope_->FindTensor(var_name);
auto b_tensor = base_scope_->FindTensor(var_name);
success = success && CheckTensorPrecision<T>(a_tensor, b_tensor, abs_error);
} else if (inst_scope_->FindVar(var_name)->IsType<std::vector<Tensor>>()) {
auto a_tensor_array =
inst_scope_->FindVar(var_name)->GetMutable<std::vector<Tensor>>();
auto b_tensor_array =
base_scope_->FindVar(var_name)->GetMutable<std::vector<Tensor>>();
CHECK_EQ(a_tensor_array->size(), b_tensor_array->size());
for (int i = 0; i < a_tensor_array->size(); i++) {
Tensor* a_tensor = &(a_tensor_array->at(i));
Tensor* b_tensor = &(b_tensor_array->at(i));
if (a_tensor->dims().size() == 0 && b_tensor->dims().size() == 0) {
continue;
}
success =
success && CheckTensorPrecision<T>(a_tensor, b_tensor, abs_error);
}
} else {
LOG(FATAL) << "unsupported var type";
}
return success;
}
} // namespace arena
} // namespace lite
} // namespace paddle
......@@ -82,6 +82,7 @@ void TensorLite::CopyDataFrom(const TensorLite &other) {
target_ = other.target_;
lod_ = other.lod_;
memory_size_ = other.memory_size_;
precision_ = other.precision_;
buffer_->CopyDataFrom(*other.buffer_, memory_size_);
}
......
......@@ -177,8 +177,9 @@ static bool TargetCompatibleTo(const Type& a, const Type& b) {
return x == TARGET(kHost) || x == TARGET(kX86) || x == TARGET(kARM);
};
if (a.IsVoid() || b.IsVoid()) return true;
if (a.IsTensor() || b.IsTensor()) {
if (a.IsTensor() && b.IsTensor()) {
if (a.IsTensor() || b.IsTensor() || a.IsTensorList() || b.IsTensorList()) {
if ((a.IsTensor() && b.IsTensor()) ||
(a.IsTensorList() && b.IsTensorList())) {
return is_host(a.target()) ? is_host(b.target())
: a.target() == b.target();
}
......
......@@ -20,23 +20,17 @@ namespace lite {
namespace kernels {
namespace arm {
void ReadFromArrayCompute::PrepareForRun() {}
void ReadFromArrayCompute::Run() {
auto& ctx = this->ctx_->template As<ARMContext>();
auto& param = this->Param<operators::ReadFromArrayParam>();
auto& param = this->Param<param_t>();
int in_num = param.X->size();
CHECK_EQ(param.I->numel(), 1) << "I should have only one element";
int id = param.I->data<float>()[0];
int id = param.I->data<int64_t>()[0];
int in_num = param.X->size();
CHECK_LE(id, in_num) << "id is not valid";
int input_size = (*param.X)[id].numel();
param.Out->Resize((*param.X)[id].dims());
param.Out->CopyDataFrom((*param.X)[id]);
auto out_lod = param.Out->mutable_lod();
*out_lod = (*param.X)[id].lod();
}
} // namespace arm
......@@ -46,11 +40,11 @@ void ReadFromArrayCompute::Run() {
REGISTER_LITE_KERNEL(read_from_array,
kARM,
kFloat,
kAny,
kNCHW,
paddle::lite::kernels::arm::ReadFromArrayCompute,
def)
.BindInput("X", {LiteType::GetTensorListTy(TARGET(kARM))})
.BindInput("I", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("X", {LiteType::GetTensorListTy(TARGET(kARM), PRECISION(kAny))})
.BindInput("I", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
.Finalize();
......@@ -23,13 +23,10 @@ namespace lite {
namespace kernels {
namespace arm {
class ReadFromArrayCompute
: public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
class ReadFromArrayCompute : public KernelLite<TARGET(kARM), PRECISION(kAny)> {
public:
using param_t = operators::ReadFromArrayParam;
void PrepareForRun() override;
void Run() override;
~ReadFromArrayCompute() {}
......
......@@ -24,33 +24,12 @@ void WriteToArrayCompute::Run() {
auto& ctx = this->ctx_->template As<ARMContext>();
auto& param = this->template Param<operators::WriteToArrayParam>();
CHECK_EQ(param.I->numel(), 1) << "input2 should have only one element";
auto precision_type = param.X->precision();
#define SOLVE_TYPE(type__, T) \
case type__: { \
const auto* x_data = param.X->data<T>(); \
int id = param.I->data<int64_t>()[0]; \
if (id >= param.Out->size()) { \
for (int i = param.Out->size(); i < id + 1; i++) { \
lite::Tensor tmp; \
param.Out->push_back(tmp); \
} \
} \
(*param.Out)[id].Resize(param.X->dims()); \
auto out_lod = (*param.Out)[id].mutable_lod(); \
*out_lod = param.X->lod(); \
auto* o_data = (*param.Out)[id].mutable_data<T>(TARGET(kHost)); \
int input_size = param.X->numel(); \
memcpy(o_data, x_data, sizeof(T) * input_size); \
} break;
switch (precision_type) {
SOLVE_TYPE(PRECISION(kFloat), float);
SOLVE_TYPE(PRECISION(kInt64), int64_t);
default:
LOG(FATAL) << "Unsupported precision type.";
int id = param.I->data<int64_t>()[0];
if (param.Out->size() < id + 1) {
param.Out->resize(id + 1);
}
#undef SOLVE_TYPE
param.Out->at(id).CopyDataFrom(*param.X);
}
} // namespace arm
......@@ -65,6 +44,7 @@ REGISTER_LITE_KERNEL(write_to_array,
paddle::lite::kernels::arm::WriteToArrayCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
.BindInput("I", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
.BindOutput("Out", {LiteType::GetTensorListTy(TARGET(kARM))})
.BindInput("I", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.BindOutput("Out",
{LiteType::GetTensorListTy(TARGET(kARM), PRECISION(kAny))})
.Finalize();
......@@ -25,8 +25,6 @@ namespace arm {
class WriteToArrayCompute : public KernelLite<TARGET(kARM), PRECISION(kAny)> {
public:
using param_t = operators::WriteToArrayParam;
void Run() override;
~WriteToArrayCompute() {}
......
......@@ -736,8 +736,8 @@ struct WriteToArrayParam {
};
struct ReadFromArrayParam {
std::vector<lite::Tensor>* X{};
lite::Tensor* I{};
const std::vector<lite::Tensor>* X{};
const lite::Tensor* I{};
lite::Tensor* Out{};
};
......
......@@ -19,11 +19,17 @@ namespace paddle {
namespace lite {
namespace operators {
bool ReadFromArrayOp::CheckShape() const { return true; }
bool ReadFromArrayOp::CheckShape() const {
CHECK(param_.X);
CHECK(param_.I);
CHECK(param_.Out);
return true;
}
bool ReadFromArrayOp::InferShape() const {
auto in_dims = (*param_.X)[0].dims();
param_.Out->Resize(in_dims);
int id = param_.I->data<int64_t>()[0];
auto out_dims = (*param_.X)[id].dims();
param_.Out->Resize(out_dims);
return true;
}
......@@ -32,11 +38,9 @@ bool ReadFromArrayOp::AttachImpl(const cpp::OpDesc &opdesc,
auto in = opdesc.Input("X").front();
param_.X = scope->FindVar(in)->GetMutable<std::vector<lite::Tensor>>();
param_.I =
scope->FindVar(opdesc.Input("I").front())->GetMutable<lite::Tensor>();
param_.I = scope->FindTensor(opdesc.Input("I").front());
param_.Out =
scope->FindVar(opdesc.Output("Out").front())->GetMutable<lite::Tensor>();
param_.Out = scope->FindMutableTensor(opdesc.Output("Out").front());
return true;
}
......
......@@ -19,25 +19,30 @@ namespace paddle {
namespace lite {
namespace operators {
bool WriteToArrayOp::CheckShape() const { return true; }
bool WriteToArrayOp::CheckShape() const {
CHECK(param_.X);
CHECK(param_.I);
CHECK(param_.Out);
return true;
}
bool WriteToArrayOp::InferShape() const {
auto in_dims = param_.X->dims();
for (auto out : *param_.Out) {
out.Resize(in_dims);
int id = param_.I->data<int64_t>()[0];
if (param_.Out->size() < id + 1) {
param_.Out->resize(id + 1);
}
return true;
}
bool WriteToArrayOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
auto inputs = opdesc.Input("X").front();
param_.X = scope->FindVar(inputs)->GetMutable<lite::Tensor>();
param_.X = scope->FindTensor(inputs);
auto id = opdesc.Input("I").front();
param_.I = scope->FindVar(id)->GetMutable<lite::Tensor>();
param_.I = scope->FindTensor(id);
auto out = opdesc.Output("Out").front();
param_.Out = scope->FindVar(out)->GetMutable<std::vector<lite::Tensor>>();
param_.Out = scope->FindVar(out)->GetMutable<std::vector<Tensor>>();
return true;
}
......
......@@ -23,8 +23,8 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA AND NOT LITE_WITH_BM) AND (LITE_
#lite_cc_test(test_kernel_logical_xor_compute SRCS logical_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_topk_compute SRCS topk_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_increment_compute SRCS increment_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_write_to_array_compute SRCS write_to_array_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_read_from_array_compute SRCS read_from_array_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_write_to_array_compute SRCS write_to_array_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_read_from_array_compute SRCS read_from_array_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_concat_compute SRCS concat_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_transpose_compute SRCS transpose_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_reshape_compute SRCS reshape_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
......
......@@ -13,11 +13,10 @@
// limitations under the License.
#include <gtest/gtest.h>
#include <stdio.h>
#include <stdlib.h>
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/core/arena/framework.h"
#include "lite/tests/utils/fill_data.h"
namespace paddle {
namespace lite {
......@@ -25,80 +24,78 @@ namespace lite {
class ReadFromArrayComputeTester : public arena::TestCase {
protected:
// common attributes for this op.
std::string input_0 = "in_0";
std::string input_1 = "in_1";
std::string input_2 = "in_2";
std::string input_i = "i";
std::string output = "out";
DDim dims_{{3, 5, 4, 4}};
int i_;
std::string x_ = "x";
std::string idn_ = "i";
std::string out_ = "out";
DDim tar_dims_{{3, 5, 4, 4}};
int x_size_ = 1;
int id_ = 0;
public:
ReadFromArrayComputeTester(const Place& place,
const std::string& alias,
const int i,
DDim dims)
: TestCase(place, alias), i_(i), dims_(dims) {}
DDim tar_dims,
int x_size = 1,
int id = 0)
: TestCase(place, alias), tar_dims_(tar_dims), x_size_(x_size), id_(id) {}
void RunBaseline(Scope* scope) override {
auto* out = scope->NewTensor(output);
CHECK(out);
auto* in_0 = scope->FindTensor(input_0);
auto* in_1 = scope->FindTensor(input_1);
auto* in_2 = scope->FindTensor(input_2);
auto* id_tensor = scope->FindTensor(input_i);
std::vector<const TensorLite*> in_vec = {in_0, in_1, in_2};
int cur_in_num = in_vec.size();
auto x = scope->FindVar(x_)->GetMutable<std::vector<Tensor>>();
auto idn = scope->FindTensor(idn_);
auto out = scope->NewTensor(out_);
int id = id_tensor->data<int>()[0];
out->Resize(dims_);
const auto* in_data = in_vec[id]->data<float>();
auto* o_data = out->mutable_data<float>();
int n = in_vec[id]->numel();
memcpy(o_data, in_data, sizeof(float) * n);
int id = idn->data<int64_t>()[0];
out->CopyDataFrom(x->at(id));
}
void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType("read_from_array");
op_desc->SetInput("X", {input_0, input_1, input_2});
op_desc->SetInput("I", {input_i});
op_desc->SetOutput("Out", {output});
op_desc->SetInput("X", {x_});
op_desc->SetInput("I", {idn_});
op_desc->SetOutput("Out", {out_});
}
void PrepareData() override {
std::vector<std::string> in_vec = {input_0, input_1, input_2};
for (auto in : in_vec) {
std::vector<float> data(dims_.production());
for (int i = 0; i < dims_.production(); i++) {
data[i] = std::rand() * 1.0f / RAND_MAX;
}
SetCommonTensor(in, dims_, data.data());
std::vector<DDim> x_dims(x_size_);
std::vector<std::vector<float>> x_data(x_size_);
for (int i = 0; i < x_size_; i++) {
x_dims[i] = tar_dims_;
x_data[i].resize(x_dims[i].production());
fill_data_rand(x_data[i].data(), -1.f, 1.f, x_dims[i].production());
}
SetCommonTensorList(x_, x_dims, x_data);
DDimLite dims_i{{1}};
int a = 1;
SetCommonTensor(input_i, dims_i, &a);
std::vector<int64_t> didn(1);
didn[0] = id_;
SetCommonTensor(idn_, DDim{{1}}, didn.data());
SetPrecisionType(out_, PRECISION(kFloat));
}
};
void test_read_from_array(Place place) {
void TestReadFromArray(Place place, float abs_error) {
DDimLite dims{{3, 5, 4, 4}};
for (int i : {1, 2}) {
std::unique_ptr<arena::TestCase> tester(
new ReadFromArrayComputeTester(place, "def", i, dims));
arena::Arena arena(std::move(tester), place, 2e-5);
arena.TestPrecision();
for (int x_size : {1, 3}) {
for (int id : {0, 2}) {
if (x_size < id + 1) continue;
std::unique_ptr<arena::TestCase> tester(
new ReadFromArrayComputeTester(place, "def", dims, x_size, id));
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision();
}
}
}
TEST(ReadFromArray, precision) {
// #ifdef LITE_WITH_X86
// Place place(TARGET(kX86));
// #endif
Place place;
float abs_error = 1e-5;
#ifdef LITE_WITH_ARM
Place place(TARGET(kARM));
test_read_from_array(place);
place = {TARGET(kARM), PRECISION(kAny)};
#else
return;
#endif
TestReadFromArray(place, abs_error);
}
} // namespace lite
......
......@@ -13,11 +13,10 @@
// limitations under the License.
#include <gtest/gtest.h>
#include <stdio.h>
#include <stdlib.h>
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/core/arena/framework.h"
#include "lite/tests/utils/fill_data.h"
namespace paddle {
namespace lite {
......@@ -25,91 +24,75 @@ namespace lite {
class WriteToArrayComputeTester : public arena::TestCase {
protected:
// common attributes for this op.
std::string input_0 = "x";
std::string input_1 = "i";
std::string output_0 = "out0";
std::string output_1 = "out1";
std::string output_2 = "out2";
DDim dims_{{3, 5, 4, 4}};
int i_;
std::string x_ = "x";
std::string idn_ = "i";
std::string out_ = "out";
DDim x_dims_{{3, 5, 4, 4}};
int out_size_ = 0;
int id_ = 0;
public:
WriteToArrayComputeTester(const Place& place,
const std::string& alias,
const int i,
DDim dims)
: TestCase(place, alias), i_(i), dims_(dims) {}
DDim x_dims,
int out_size = 0,
int id = 0)
: TestCase(place, alias), x_dims_(x_dims), out_size_(out_size), id_(id) {}
void RunBaseline(Scope* scope) override {
auto* out_0 = scope->NewTensor(output_0);
auto* out_1 = scope->NewTensor(output_1);
auto* out_2 = scope->NewTensor(output_2);
CHECK(out_0);
CHECK(out_1);
CHECK(out_2);
std::vector<TensorLite*> out_vec = {out_0, out_1, out_2};
auto out = scope->Var(out_)->GetMutable<std::vector<Tensor>>();
auto x = scope->FindTensor(x_);
auto* x = scope->FindTensor(input_0);
const auto* x_data = x->data<float>();
auto* id = scope->FindTensor(input_1);
const auto* id_data = id->data<float>();
int n = x->numel();
int cur_out_num = out_vec.size();
for (int i = cur_out_num; i < id_data[0] + 1; i++) {
char buffer[30];
snprintf(buffer, sizeof(buffer), "out%d", i);
auto out = scope->NewTensor(buffer);
out_vec.push_back(out);
if (out->size() < id_ + 1) {
out->resize(id_ + 1);
}
out_vec[id_data[0]]->Resize(dims_);
auto* out_data = out_vec[id_data[0]]->mutable_data<float>();
memcpy(out_data, x_data, sizeof(float) * n);
out->at(id_).Resize(x->dims());
auto out_data = out->at(id_).mutable_data<float>();
memcpy(out_data, x->data<float>(), sizeof(float) * x->numel());
}
void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType("write_to_array");
op_desc->SetInput("X", {input_0});
op_desc->SetInput("I", {input_1});
op_desc->SetOutput("Out", {output_0, output_1, output_2});
op_desc->SetInput("X", {x_});
op_desc->SetInput("I", {idn_});
op_desc->SetOutput("Out", {out_});
}
void PrepareData() override {
std::vector<float> data(dims_.production());
std::vector<float> dx(x_dims_.production());
fill_data_rand(dx.data(), -1.f, 1.f, x_dims_.production());
SetCommonTensor(x_, x_dims_, dx.data());
for (int i = 0; i < dims_.production(); i++) {
data[i] = i * 1.1;
}
SetCommonTensor(input_0, dims_, data.data());
std::vector<int> data_1(1);
data_1[0] = i_;
DDimLite dims_2{{1}};
SetCommonTensor(input_1, dims_2, data_1.data());
std::vector<int64_t> didn(1);
didn[0] = id_;
SetCommonTensor(idn_, DDim{{1}}, didn.data());
SetCommonTensor(output_0, dims_2, data_1.data());
SetCommonTensor(output_1, dims_2, data_1.data());
SetCommonTensor(output_2, dims_2, data_1.data());
SetPrecisionType(out_, PRECISION(kFloat));
}
};
void test_write_to_array(Place place) {
void TestWriteToArray(Place place, float abs_error) {
DDimLite dims{{3, 5, 4, 4}};
for (int i : {1, 4}) {
std::unique_ptr<arena::TestCase> tester(
new WriteToArrayComputeTester(place, "def", i, dims));
arena::Arena arena(std::move(tester), place, 2e-5);
arena.TestPrecision();
for (int out_size : {0, 3}) {
for (int id : {0, 1, 4}) {
std::unique_ptr<arena::TestCase> tester(
new WriteToArrayComputeTester(place, "def", dims, out_size, id));
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision();
}
}
}
TEST(WriteToArray, precision) {
// #ifdef LITE_WITH_X86
// Place place(TARGET(kX86));
// #endif
Place place;
float abs_error = 1e-5;
#ifdef LITE_WITH_ARM
Place place(TARGET(kARM));
test_write_to_array(place);
place = {TARGET(kARM), PRECISION(kAny)};
#else
return;
#endif
TestWriteToArray(place, abs_error);
}
} // namespace lite
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册