未验证 提交 13072661 编写于 作者: H hong19860320 提交者: GitHub

[Core] Fix the test framework for XPU (#3841)

上级 a8d75d20
...@@ -55,7 +55,7 @@ void TestCase::CreateInstruction() { ...@@ -55,7 +55,7 @@ void TestCase::CreateInstruction() {
op = LiteOpRegistry::Global().Create(op_desc().Type()); op = LiteOpRegistry::Global().Create(op_desc().Type());
} }
CHECK(op) << "no op for " << op_desc().Type(); CHECK(op) << "no op for " << op_desc().Type();
op->Attach(*op_desc_, inst_scope_); op->Attach(*op_desc_, inst_scope_.get());
auto kernels = op->CreateKernels({place_}); auto kernels = op->CreateKernels({place_});
// filter out the target kernel // filter out the target kernel
CHECK(!kernels.empty()) << "No kernel found for place " CHECK(!kernels.empty()) << "No kernel found for place "
...@@ -80,54 +80,35 @@ void TestCase::CreateInstruction() { ...@@ -80,54 +80,35 @@ void TestCase::CreateInstruction() {
void TestCase::PrepareInputsForInstruction() { void TestCase::PrepareInputsForInstruction() {
for (auto& arg : op_desc().InputArgumentNames()) { for (auto& arg : op_desc().InputArgumentNames()) {
for (auto& var : op_desc().Input(arg)) { for (auto& var : op_desc().Input(arg)) {
std::string kernel_key = instruction_->kernel()->key_with_alias(); const auto* type = instruction_->kernel()->GetInputDeclType(arg);
const auto* param_type = ParamTypeRegistry::Global().RetrieveInArgument( CHECK(base_scope_->FindVar(var));
place_, kernel_key, arg); /// Create a tensor or tensor_array in the instruction's scope,
/// alloc memory and then copy data there.
const Type* inst_type = nullptr; if (type->IsTensor() &&
if (param_type->type->IsTensor()) { !TargetCompatibleTo(*Type::GetTensorTy(TARGET(kHost)), *type)) {
inst_type = Type::GetTensorTy(TARGET(kHost)); const auto* base_tensor = base_scope_->FindTensor(var);
} else if (param_type->type->IsTensorList()) { auto* inst_tensor = inst_scope_->FindMutableTensor(var);
inst_type = Type::GetTensorListTy(TARGET(kHost)); CHECK(!base_tensor->dims().empty())
} else { << "The dims of input tensor is empty yet";
LOG(FATAL) << "unsupported param_type"; TargetCopy(type->target(),
} inst_tensor->mutable_data(type->target(),
base_tensor->memory_size()),
CHECK(scope_->FindVar(var)); base_tensor->raw_data(),
if (!TargetCompatibleTo(*inst_type, *param_type->type)) { base_tensor->memory_size());
/// Create a tensor or tensor_array in the instruction's scope, } else if (type->IsTensorList() &&
/// alloc memory and then copy data there. !TargetCompatibleTo(*Type::GetTensorListTy(TARGET(kHost)),
if (param_type->type->IsTensor()) { *type)) {
const auto* shared_tensor = scope_->FindTensor(var); const auto* base_tensor_list = base_scope_->FindTensorList(var);
auto* target_tensor = auto* inst_tensor_list = inst_scope_->FindMutableTensorList(var);
inst_scope_->LocalVar(var)->GetMutable<Tensor>(); CHECK_EQ(base_tensor_list->size(), inst_tensor_list->size());
CHECK(!shared_tensor->dims().empty()) << "shared_tensor is empty yet"; for (size_t i = 0; i < base_tensor_list->size(); i++) {
target_tensor->Resize(shared_tensor->dims()); CHECK(!base_tensor_list->at(i).dims().empty())
TargetCopy(param_type->type->target(), << "The dims of input tensor[" << i << "] is empty yet";
target_tensor->mutable_data(param_type->type->target(), TargetCopy(type->target(),
shared_tensor->memory_size()), inst_tensor_list->at(i).mutable_data(
shared_tensor->raw_data(), type->target(), base_tensor_list->at(i).memory_size()),
shared_tensor->memory_size()); inst_tensor_list->at(i).raw_data(),
} else if (param_type->type->IsTensorList()) { inst_tensor_list->at(i).memory_size());
const auto* shared_tensor_array =
scope_->FindVar(var)->GetMutable<std::vector<Tensor>>();
auto* target_tensor_array =
inst_scope_->LocalVar(var)->GetMutable<std::vector<Tensor>>();
CHECK(!shared_tensor_array->empty())
<< "shared_tensor_array is empty yet";
target_tensor_array->resize(shared_tensor_array->size());
for (size_t i = 0; i < shared_tensor_array->size(); i++) {
target_tensor_array->at(i).Resize(
shared_tensor_array->at(i).dims());
TargetCopy(param_type->type->target(),
target_tensor_array->at(i).mutable_data(
param_type->type->target(),
shared_tensor_array->at(i).memory_size()),
shared_tensor_array->at(i).raw_data(),
shared_tensor_array->at(i).memory_size());
}
} else {
LOG(FATAL) << "not support";
} }
} }
} }
...@@ -135,37 +116,36 @@ void TestCase::PrepareInputsForInstruction() { ...@@ -135,37 +116,36 @@ void TestCase::PrepareInputsForInstruction() {
} }
template <typename T> template <typename T>
bool TestCase::CheckTensorPrecision(const Tensor* a_tensor, bool TestCase::CheckTensorPrecision(const Tensor* inst_tensor,
const Tensor* b_tensor, const Tensor* base_tensor,
float abs_error) { float abs_error) {
CHECK(a_tensor); CHECK(inst_tensor);
CHECK(b_tensor); CHECK(base_tensor);
CHECK(ShapeEquals(a_tensor->dims(), b_tensor->dims())); CHECK(ShapeEquals(inst_tensor->dims(), base_tensor->dims()));
CHECK(a_tensor->lod() == b_tensor->lod()) << "lod not match"; CHECK(inst_tensor->lod() == base_tensor->lod()) << "lod not match";
// The baseline should output in host devices. // The baseline should output in host devices.
CHECK(b_tensor->target() == TARGET(kHost) || CHECK(base_tensor->target() == TARGET(kHost) ||
b_tensor->target() == TARGET(kX86) || base_tensor->target() == TARGET(kX86) ||
b_tensor->target() == TARGET(kARM)); base_tensor->target() == TARGET(kARM));
const T* inst_data{};
const T* a_data{}; Tensor inst_host_tensor;
Tensor a_host_tensor; inst_host_tensor.Resize(inst_tensor->dims());
a_host_tensor.Resize(a_tensor->dims()); switch (inst_tensor->target()) {
switch (a_tensor->target()) {
case TARGET(kX86): case TARGET(kX86):
case TARGET(kHost): case TARGET(kHost):
case TARGET(kARM): case TARGET(kARM):
a_data = static_cast<const T*>(a_tensor->raw_data()); inst_data = static_cast<const T*>(inst_tensor->raw_data());
break; break;
#ifdef LITE_WITH_XPU #ifdef LITE_WITH_XPU
case TARGET(kXPU): case TARGET(kXPU):
CopySync<TARGET(kXPU)>(a_host_tensor.mutable_data<T>(), CopySync<TARGET(kXPU)>(inst_host_tensor.mutable_data<T>(),
a_tensor->raw_data(), inst_tensor->raw_data(),
sizeof(T) * a_tensor->dims().production(), sizeof(T) * inst_tensor->dims().production(),
IoDirection::DtoH); IoDirection::DtoH);
a_data = a_host_tensor.data<T>(); inst_data = inst_host_tensor.data<T>();
break; break;
#endif #endif
...@@ -174,50 +154,50 @@ bool TestCase::CheckTensorPrecision(const Tensor* a_tensor, ...@@ -174,50 +154,50 @@ bool TestCase::CheckTensorPrecision(const Tensor* a_tensor,
LOG(FATAL) << "Not supported"; LOG(FATAL) << "Not supported";
} }
CHECK(a_data); CHECK(inst_data);
const T* b_data = static_cast<const T*>(b_tensor->raw_data()); const T* base_data = static_cast<const T*>(base_tensor->raw_data());
bool success = true; bool success = true;
for (int i = 0; i < a_tensor->dims().production(); i++) { for (int i = 0; i < inst_tensor->dims().production(); i++) {
EXPECT_NEAR(a_data[i], b_data[i], abs_error); EXPECT_NEAR(inst_data[i], base_data[i], abs_error);
if (fabsf(a_data[i] - b_data[i]) > abs_error) { if (fabsf(inst_data[i] - base_data[i]) > abs_error) {
success = false; success = false;
} }
} }
return success; return success;
} }
bool TestCase::CheckPrecision(const Tensor* a_tensor, bool TestCase::CheckPrecision(const Tensor* inst_tensor,
const Tensor* b_tensor, const Tensor* base_tensor,
float abs_error, float abs_error,
PrecisionType precision_type) { PrecisionType precision_type) {
PrecisionType precision_type_t = precision_type; PrecisionType precision_type_t = precision_type;
if (precision_type == PRECISION(kAny)) { if (precision_type == PRECISION(kAny)) {
precision_type_t = b_tensor->precision(); precision_type_t = base_tensor->precision();
} }
CHECK(precision_type_t == b_tensor->precision()) CHECK(precision_type_t == base_tensor->precision())
<< "arg precision type and base tensor precision type are not matched! " << "arg precision type and base tensor precision type are not matched! "
"arg precision type is: " "arg precision type is: "
<< PrecisionToStr(precision_type) << ", base tensor precision type is: " << PrecisionToStr(precision_type) << ", base tensor precision type is: "
<< PrecisionToStr(b_tensor->precision()); << PrecisionToStr(base_tensor->precision());
CHECK(a_tensor->precision() == b_tensor->precision()) CHECK(inst_tensor->precision() == base_tensor->precision())
<< "real tensor precision type and base tensor precision type are not " << "real tensor precision type and base tensor precision type are not "
"matched! real tensor precision type is: " "matched! real tensor precision type is: "
<< PrecisionToStr(a_tensor->precision()) << PrecisionToStr(inst_tensor->precision())
<< ", base tensor precision type is: " << ", base tensor precision type is: "
<< PrecisionToStr(b_tensor->precision()); << PrecisionToStr(base_tensor->precision());
switch (precision_type_t) { switch (precision_type_t) {
case PRECISION(kFloat): case PRECISION(kFloat):
return CheckTensorPrecision<float>(a_tensor, b_tensor, abs_error); return CheckTensorPrecision<float>(inst_tensor, base_tensor, abs_error);
case PRECISION(kInt8): case PRECISION(kInt8):
return CheckTensorPrecision<int8_t>(a_tensor, b_tensor, abs_error); return CheckTensorPrecision<int8_t>(inst_tensor, base_tensor, abs_error);
case PRECISION(kInt32): case PRECISION(kInt32):
return CheckTensorPrecision<int32_t>(a_tensor, b_tensor, abs_error); return CheckTensorPrecision<int32_t>(inst_tensor, base_tensor, abs_error);
case PRECISION(kInt64): case PRECISION(kInt64):
return CheckTensorPrecision<int64_t>(a_tensor, b_tensor, abs_error); return CheckTensorPrecision<int64_t>(inst_tensor, base_tensor, abs_error);
case PRECISION(kBool): case PRECISION(kBool):
return CheckTensorPrecision<bool>(a_tensor, b_tensor, abs_error); return CheckTensorPrecision<bool>(inst_tensor, base_tensor, abs_error);
default: default:
LOG(FATAL) << "not support type: " << PrecisionToStr(precision_type); LOG(FATAL) << "not support type: " << PrecisionToStr(precision_type);
return false; return false;
...@@ -229,24 +209,24 @@ bool TestCase::CheckPrecision(const std::string& var_name, ...@@ -229,24 +209,24 @@ bool TestCase::CheckPrecision(const std::string& var_name,
PrecisionType precision_type) { PrecisionType precision_type) {
bool success = true; bool success = true;
if (inst_scope_->FindVar(var_name)->IsType<Tensor>()) { if (inst_scope_->FindVar(var_name)->IsType<Tensor>()) {
auto a_tensor = inst_scope_->FindTensor(var_name); auto inst_tensor = inst_scope_->FindTensor(var_name);
auto b_tensor = base_scope_->FindTensor(var_name); auto base_tensor = base_scope_->FindTensor(var_name);
success = success && success =
CheckPrecision(a_tensor, b_tensor, abs_error, precision_type); success &&
CheckPrecision(inst_tensor, base_tensor, abs_error, precision_type);
} else if (inst_scope_->FindVar(var_name)->IsType<std::vector<Tensor>>()) { } else if (inst_scope_->FindVar(var_name)->IsType<std::vector<Tensor>>()) {
auto a_tensor_array = auto inst_tensor_list = inst_scope_->FindMutableTensorList(var_name);
inst_scope_->FindVar(var_name)->GetMutable<std::vector<Tensor>>(); auto base_tensor_list = base_scope_->FindMutableTensorList(var_name);
auto b_tensor_array = CHECK_EQ(inst_tensor_list->size(), base_tensor_list->size());
base_scope_->FindVar(var_name)->GetMutable<std::vector<Tensor>>(); for (size_t i = 0; i < inst_tensor_list->size(); i++) {
CHECK_EQ(a_tensor_array->size(), b_tensor_array->size()); Tensor* inst_tensor = &(inst_tensor_list->at(i));
for (size_t i = 0; i < a_tensor_array->size(); i++) { Tensor* base_tensor = &(base_tensor_list->at(i));
Tensor* a_tensor = &(a_tensor_array->at(i)); if (inst_tensor->dims().size() == 0 && base_tensor->dims().size() == 0) {
Tensor* b_tensor = &(b_tensor_array->at(i));
if (a_tensor->dims().size() == 0 && b_tensor->dims().size() == 0) {
continue; continue;
} }
success = success && success =
CheckPrecision(a_tensor, b_tensor, abs_error, precision_type); success &&
CheckPrecision(inst_tensor, base_tensor, abs_error, precision_type);
} }
} else { } else {
LOG(FATAL) << "unsupported var type"; LOG(FATAL) << "unsupported var type";
......
...@@ -40,13 +40,15 @@ namespace arena { ...@@ -40,13 +40,15 @@ namespace arena {
class TestCase { class TestCase {
public: public:
explicit TestCase(const Place& place, const std::string& alias) explicit TestCase(const Place& place, const std::string& alias)
: place_(place), scope_(new Scope), alias_(alias) { : place_(place),
alias_(alias),
inst_scope_(new Scope),
base_scope_(new Scope) {
ctx_ = ContextScheduler::Global().NewContext(place_.target); ctx_ = ContextScheduler::Global().NewContext(place_.target);
} }
virtual ~TestCase(); virtual ~TestCase();
void Prepare() { void Prepare() {
PrepareScopes();
PrepareData(); PrepareData();
op_desc_.reset(new cpp::OpDesc); op_desc_.reset(new cpp::OpDesc);
PrepareOpDesc(op_desc_.get()); PrepareOpDesc(op_desc_.get());
...@@ -91,16 +93,15 @@ class TestCase { ...@@ -91,16 +93,15 @@ class TestCase {
// kernel registry. // kernel registry.
void CheckKernelConsistWithDefinition() {} void CheckKernelConsistWithDefinition() {}
Scope& scope() { return *scope_; } Scope* baseline_scope() { return base_scope_.get(); }
Scope* inst_scope() { return inst_scope_.get(); }
Scope* baseline_scope() { return base_scope_; }
Scope* inst_scope() { return inst_scope_; }
protected: protected:
// Prepare inputs in scope() for Tester. // Prepare inputs in scope() for Tester.
virtual void PrepareData() = 0; virtual void PrepareData() = 0;
/// Prepare a tensor in host. The tensors will be created in scope_. /// Prepare a tensor in host. The tensors will be created both in base_scope_
/// and inst_scope_.
/// Need to specify the targets other than X86 or ARM. /// Need to specify the targets other than X86 or ARM.
template <typename T> template <typename T>
void SetCommonTensor(const std::string& var_name, void SetCommonTensor(const std::string& var_name,
...@@ -108,42 +109,47 @@ class TestCase { ...@@ -108,42 +109,47 @@ class TestCase {
const T* data, const T* data,
const LoD& lod = {}, const LoD& lod = {},
bool is_persistable = false) { bool is_persistable = false) {
auto* tensor = scope_->NewTensor(var_name); // Create and fill an input tensor with the given data for baseline
tensor->Resize(ddim); auto* base_tensor = base_scope_->NewTensor(var_name);
auto* d = tensor->mutable_data<T>(); base_tensor->Resize(ddim);
memcpy(d, data, ddim.production() * sizeof(T)); memcpy(base_tensor->mutable_data<T>(), data, ddim.production() * sizeof(T));
// set lod // set lod
if (!lod.empty()) *tensor->mutable_lod() = lod; if (!lod.empty()) *base_tensor->mutable_lod() = lod;
// set persistable // set persistable
tensor->set_persistable(is_persistable); base_tensor->set_persistable(is_persistable);
// Create a copy for instruction
auto* inst_tensor = inst_scope_->NewTensor(var_name);
inst_tensor->CopyDataFrom(*base_tensor);
} }
/// Prepare a tensor_array in host. The tensors will be created in scope_. /// Prepare a tensor_array in host. The tensors will be created in scope_.
/// Need to specify the targets other than X86 or ARM. /// Need to specify the targets other than X86 or ARM.
template <typename T> template <typename T>
void SetCommonTensorList(const std::string& var_name, void SetCommonTensorList(const std::string& var_name,
const std::vector<DDim>& array_tensor_dims, const std::vector<DDim>& ddims,
const std::vector<std::vector<T>>& datas, const std::vector<std::vector<T>>& datas,
const std::vector<LoD>& lods = {}) { const std::vector<LoD>& lods = {}) {
CHECK_EQ(array_tensor_dims.size(), datas.size()); // Create a tensor array for baseline, and a copy for instruction
CHECK_EQ(ddims.size(), datas.size());
if (!lods.empty()) { if (!lods.empty()) {
CHECK_EQ(array_tensor_dims.size(), lods.size()); CHECK_EQ(ddims.size(), lods.size());
} }
auto* tensor_array = auto* base_tensor_list = base_scope_->NewTensorList(var_name);
scope_->Var(var_name)->GetMutable<std::vector<Tensor>>(); auto* inst_tensor_list = inst_scope_->NewTensorList(var_name);
for (int i = 0; i < array_tensor_dims.size(); i++) { for (int i = 0; i < ddims.size(); i++) {
Tensor tmp; Tensor item;
tmp.Resize(array_tensor_dims[i]); item.Resize(ddims[i]);
auto* tmp_data = tmp.mutable_data<T>(); memcpy(item.mutable_data<T>(),
memcpy(tmp_data,
datas[i].data(), datas[i].data(),
array_tensor_dims[i].production() * sizeof(T)); ddims[i].production() * sizeof(T));
if (!lods.empty()) { if (!lods.empty()) {
tmp.set_lod(lods[i]); item.set_lod(lods[i]);
} }
tensor_array->push_back(tmp); base_tensor_list->push_back(item);
inst_tensor_list->push_back(item);
} }
} }
...@@ -157,11 +163,6 @@ class TestCase { ...@@ -157,11 +163,6 @@ class TestCase {
std::unique_ptr<KernelContext> ctx_; std::unique_ptr<KernelContext> ctx_;
void CreateInstruction(); void CreateInstruction();
void PrepareScopes() {
inst_scope_ = &scope_->NewScope();
base_scope_ = &scope_->NewScope();
}
// Check shape // Check shape
// TODO(Superjomn) Move this method to utils or DDim? // TODO(Superjomn) Move this method to utils or DDim?
bool ShapeEquals(const DDim& a, const DDim& b) { bool ShapeEquals(const DDim& a, const DDim& b) {
...@@ -172,25 +173,23 @@ class TestCase { ...@@ -172,25 +173,23 @@ class TestCase {
return true; return true;
} }
/// Copy the input tensors to target devices needed by the instruction. // Copy the host tensors to the device tensors if needed by the instruction.
void PrepareInputsForInstruction(); void PrepareInputsForInstruction();
// Create output tensors and variables. // Create output tensors and variables.
void PrepareOutputsForInstruction() { void PrepareOutputsForInstruction() {
for (auto x : op_desc().output_vars()) { for (auto x : op_desc().output_vars()) {
inst_scope_->NewTensor(x); inst_scope_->Var(x);
base_scope_->NewTensor(x);
} }
} }
private: private:
Place place_; Place place_;
std::shared_ptr<Scope> scope_;
std::string alias_; std::string alias_;
// The workspace for the Instruction. // The workspace for the Instruction.
Scope* inst_scope_{}; std::shared_ptr<Scope> inst_scope_;
// The workspace for the baseline implementation. // The workspace for the baseline implementation.
Scope* base_scope_{}; std::shared_ptr<Scope> base_scope_;
std::unique_ptr<cpp::OpDesc> op_desc_; std::unique_ptr<cpp::OpDesc> op_desc_;
std::unique_ptr<Instruction> instruction_; std::unique_ptr<Instruction> instruction_;
}; };
......
...@@ -62,19 +62,36 @@ class Scope final { ...@@ -62,19 +62,36 @@ class Scope final {
// Create a Tensor variable. This will create a new Variable called `name`. // Create a Tensor variable. This will create a new Variable called `name`.
Tensor* NewTensor(const std::string& name) { Tensor* NewTensor(const std::string& name) {
auto* var = Var(name); auto* var = Var(name);
return var->GetMutable<TensorLite>(); return var->GetMutable<Tensor>();
} }
const Tensor* FindTensor(const std::string& name) { const Tensor* FindTensor(const std::string& name) {
auto* var = FindVar(name); auto* var = FindVar(name);
if (!var) return nullptr; if (!var) return nullptr;
return &var->Get<TensorLite>(); return &var->Get<Tensor>();
} }
Tensor* FindMutableTensor(const std::string& name) { Tensor* FindMutableTensor(const std::string& name) {
auto* var = FindVar(name); auto* var = FindVar(name);
if (!var) return nullptr; if (!var) return nullptr;
return var->GetMutable<TensorLite>(); return var->GetMutable<Tensor>();
}
std::vector<Tensor>* NewTensorList(const std::string& name) {
auto* var = Var(name);
return var->GetMutable<std::vector<Tensor>>();
}
const std::vector<Tensor>* FindTensorList(const std::string& name) {
auto* var = FindVar(name);
if (!var) return nullptr;
return &var->Get<std::vector<Tensor>>();
}
std::vector<Tensor>* FindMutableTensorList(const std::string& name) {
auto* var = FindVar(name);
if (!var) return nullptr;
return var->GetMutable<std::vector<Tensor>>();
} }
private: private:
......
...@@ -70,9 +70,7 @@ class BoxClipComputeTester : public arena::TestCase { ...@@ -70,9 +70,7 @@ class BoxClipComputeTester : public arena::TestCase {
float sign = i % 3 == 0 ? -1.0f : 1.0f; float sign = i % 3 == 0 ? -1.0f : 1.0f;
input_data[i] = sign * static_cast<float>((i * 7) % 20); input_data[i] = sign * static_cast<float>((i * 7) % 20);
} }
SetCommonTensor(input_, input_dims_, input_data.data()); SetCommonTensor(input_, input_dims_, input_data.data(), input_lod_);
auto input_tensor = baseline_scope()->FindMutableTensor(input_);
input_tensor->set_lod(input_lod_);
std::vector<float> im_info_data{10, 10, 1, 15, 15, 1}; std::vector<float> im_info_data{10, 10, 1, 15, 15, 1};
SetCommonTensor(im_info_, im_info_dim_, im_info_data.data()); SetCommonTensor(im_info_, im_info_dim_, im_info_data.data());
......
...@@ -106,13 +106,11 @@ class RoiAlignComputeTester : public arena::TestCase { ...@@ -106,13 +106,11 @@ class RoiAlignComputeTester : public arena::TestCase {
} }
LOG(INFO) << "Read rois data. " << datas[0] << " " << datas.back(); LOG(INFO) << "Read rois data. " << datas[0] << " " << datas.back();
reader.close(); reader.close();
SetCommonTensor(rois_, dims, datas.data());
auto rois_tensor = baseline_scope()->FindMutableTensor(rois_);
std::vector<uint64_t> lod0({0, 152, 304}); std::vector<uint64_t> lod0({0, 152, 304});
LoD lod; LoD lod;
lod.push_back(lod0); lod.push_back(lod0);
rois_tensor->set_lod(lod); SetCommonTensor(rois_, dims, datas.data(), lod);
} }
}; };
......
...@@ -202,20 +202,15 @@ class SliceComputeTester : public arena::TestCase { ...@@ -202,20 +202,15 @@ class SliceComputeTester : public arena::TestCase {
DDim({static_cast<int64_t>(ends_.size())}), DDim({static_cast<int64_t>(ends_.size())}),
ends_.data()); ends_.data());
} else if (use_tensor_list_) { } else if (use_tensor_list_) {
Scope& scope_ = this->scope();
for (int i = 0; i < starts_.size(); ++i) { for (int i = 0; i < starts_.size(); ++i) {
auto* tensor = scope_.NewTensor("starts_tensor_list_" + SetCommonTensor("starts_tensor_list_" + paddle::lite::to_string(i),
paddle::lite::to_string(i)); DDim({1}),
tensor->Resize(DDim({1})); &starts_[i]);
auto* d = tensor->mutable_data<int>();
d[0] = starts_[i];
} }
for (int i = 0; i < ends_.size(); ++i) { for (int i = 0; i < ends_.size(); ++i) {
auto* tensor = SetCommonTensor("ends_tensor_list_" + paddle::lite::to_string(i),
scope_.NewTensor("ends_tensor_list_" + paddle::lite::to_string(i)); DDim({1}),
tensor->Resize(DDim({1})); &ends_[i]);
auto* d = tensor->mutable_data<int>();
d[0] = ends_[i];
} }
} }
} }
......
...@@ -103,7 +103,7 @@ TEST(Softmax, precision) { ...@@ -103,7 +103,7 @@ TEST(Softmax, precision) {
#if defined(LITE_WITH_NPU) #if defined(LITE_WITH_NPU)
place = TARGET(kNPU); place = TARGET(kNPU);
abs_error = 4e-3; // Using fp16 in NPU abs_error = 4e-3; // Using fp16 in NPU
#elif defined(LITE_WITH_XPU) && defined(LITE_WITH_XTCL) #elif defined(LITE_WITH_XPU)
place = TARGET(kXPU); place = TARGET(kXPU);
#else #else
return; return;
...@@ -111,8 +111,12 @@ TEST(Softmax, precision) { ...@@ -111,8 +111,12 @@ TEST(Softmax, precision) {
for (auto x_dims : for (auto x_dims :
std::vector<std::vector<int64_t>>{{1, 2, 3, 4}, {2, 3, 4}, {3, 4}}) { std::vector<std::vector<int64_t>>{{1, 2, 3, 4}, {2, 3, 4}, {3, 4}}) {
for (auto axis : {-1, 0, 1, 2, 3}) { int ndims = x_dims.size();
if (axis >= x_dims.size()) continue; for (int axis = -1; axis < ndims; axis++) {
#if defined(LITE_WITH_XPU)
if (axis != -1 && axis != ndims - 1)
continue; // XPU supports only axis == -1 and axis == dims.size() - 1
#endif
std::unique_ptr<arena::TestCase> tester( std::unique_ptr<arena::TestCase> tester(
new SoftmaxComputeTest(place, "def", DDim(x_dims), axis)); new SoftmaxComputeTest(place, "def", DDim(x_dims), axis));
arena::Arena arena(std::move(tester), place, abs_error); arena::Arena arena(std::move(tester), place, abs_error);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册