未验证 提交 69f9cea0 编写于 作者: H hong19860320 提交者: GitHub

[XPU] Fix the deps of unit tests (#3807)

上级 db7639ca
......@@ -39,7 +39,7 @@ else()
endif()
find_library(XPU_SDK_XPU_RT_FILE NAMES xpurt
PATHS ${XPU_SDK_ROOT}/XTDK/shlib
PATHS ${XPU_SDK_ROOT}/XTDK/runtime/shlib
NO_DEFAULT_PATH)
if(NOT XPU_SDK_XPU_RT_FILE)
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "lite/core/arena/framework.h"
#include <set>
#include "lite/core/context.h"
#include "lite/operators/subgraph_op.h"
......@@ -22,7 +23,14 @@ namespace arena {
void TestCase::CreateInstruction() {
std::shared_ptr<lite::OpLite> op = nullptr;
if (place_.target == TARGET(kNPU) || place_.target == TARGET(kXPU)) {
static const std::set<TargetType> subgraph_op_supported_targets(
{TARGET(kNPU), TARGET(kXPU)});
bool enable_subgraph_op = subgraph_op_supported_targets.find(place_.target) !=
subgraph_op_supported_targets.end();
#if defined(LITE_WITH_XPU) && !defined(LITE_WITH_XTCL)
enable_subgraph_op = false; // Use XPU kernel directly if XTCL is disabled.
#endif
if (enable_subgraph_op) {
// Create a new block desc to wrap the original op desc
int sub_block_idx = 0;
auto sub_block_desc = new cpp::BlockDesc();
......@@ -91,7 +99,8 @@ void TestCase::PrepareInputsForInstruction() {
/// alloc memory and then copy data there.
if (param_type->type->IsTensor()) {
const auto* shared_tensor = scope_->FindTensor(var);
auto* target_tensor = inst_scope_->NewTensor(var);
auto* target_tensor =
inst_scope_->LocalVar(var)->GetMutable<Tensor>();
CHECK(!shared_tensor->dims().empty()) << "shared_tensor is empty yet";
target_tensor->Resize(shared_tensor->dims());
TargetCopy(param_type->type->target(),
......@@ -103,7 +112,7 @@ void TestCase::PrepareInputsForInstruction() {
const auto* shared_tensor_array =
scope_->FindVar(var)->GetMutable<std::vector<Tensor>>();
auto* target_tensor_array =
inst_scope_->Var(var)->GetMutable<std::vector<Tensor>>();
inst_scope_->LocalVar(var)->GetMutable<std::vector<Tensor>>();
CHECK(!shared_tensor_array->empty())
<< "shared_tensor_array is empty yet";
target_tensor_array->resize(shared_tensor_array->size());
......@@ -142,12 +151,23 @@ bool TestCase::CheckTensorPrecision(const Tensor* a_tensor,
b_tensor->target() == TARGET(kARM));
const T* a_data{};
Tensor a_host_tensor;
a_host_tensor.Resize(a_tensor->dims());
switch (a_tensor->target()) {
case TARGET(kX86):
case TARGET(kHost):
case TARGET(kARM):
a_data = static_cast<const T*>(a_tensor->raw_data());
break;
#ifdef LITE_WITH_XPU
case TARGET(kXPU):
CopySync<TARGET(kXPU)>(a_host_tensor.mutable_data<T>(),
a_tensor->raw_data(),
sizeof(T) * a_tensor->dims().production(),
IoDirection::DtoH);
a_data = a_host_tensor.data<T>();
break;
#endif
default:
// Before compare, need to copy data from `target` device to host.
......
......@@ -140,6 +140,11 @@ void TargetCopy(TargetType target, void* dst, const void* src, size_t size) {
dst, src, size, IoDirection::HtoD);
break;
#endif
#ifdef LITE_WITH_XPU
case TargetType::kXPU:
TargetWrapperXPU::MemcpySync(dst, src, size, IoDirection::HtoD);
break;
#endif
#ifdef LITE_WITH_OPENCL
case TargetType::kOpenCL:
TargetWrapperCL::MemcpySync(dst, src, size, IoDirection::DtoD);
......
......@@ -97,6 +97,11 @@ void CopySync(void* dst, const void* src, size_t size, IoDirection dir) {
case TARGET(kBM):
TargetWrapper<TARGET(kBM)>::MemcpySync(dst, src, size, dir);
break;
#endif
#ifdef LITE_WITH_XPU
case TARGET(kXPU):
TargetWrapperXPU::MemcpySync(dst, src, size, dir);
break;
#endif
default:
LOG(FATAL)
......
......@@ -3,17 +3,19 @@ if(LITE_WITH_XPU)
DEPS mir_passes lite_api_test_helper paddle_api_full paddle_api_light gflags utils
${ops} ${host_kernels} ${x86_kernels} ${xpu_kernels}
ARGS --model_dir=${LITE_MODEL_DIR}/resnet50)
add_dependencies(test_resnet50_lite_xpu extern_lite_download_resnet50_tar_gz)
lite_cc_test(test_ernie_lite_xpu SRCS test_ernie_lite_xpu.cc
DEPS mir_passes lite_api_test_helper paddle_api_full paddle_api_light gflags utils
${ops} ${host_kernels} ${x86_kernels} ${xpu_kernels}
ARGS --model_dir=${LITE_MODEL_DIR}/ernie)
add_dependencies(test_ernie_lite_xpu extern_lite_download_ernie_tar_gz)
lite_cc_test(test_bert_lite_xpu SRCS test_bert_lite_xpu.cc
DEPS mir_passes lite_api_test_helper paddle_api_full paddle_api_light gflags utils
${ops} ${host_kernels} ${x86_kernels} ${xpu_kernels}
ARGS --model_dir=${LITE_MODEL_DIR}/bert)
add_dependencies(test_bert_lite_xpu extern_lite_download_bert_tar_gz)
if(WITH_TESTING)
add_dependencies(test_resnet50_lite_xpu extern_lite_download_resnet50_tar_gz)
add_dependencies(test_ernie_lite_xpu extern_lite_download_ernie_tar_gz)
add_dependencies(test_bert_lite_xpu extern_lite_download_bert_tar_gz)
endif()
endif()
if(LITE_WITH_RKNPU)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册