Unverified commit 06d058fe authored by hong19860320, committed by GitHub

[LITE][XPU] initial support for XPU (#2202)

* Initial support for XPU
* Fix compiling errors of XPU
* Move XPU op kernel bridges from backends to kernels to fix deps order
* Change the namespace and directory of XPU bridges
* Add XPU SDK
* Fix header files and namespace of XPU SDK
* Add unit tests for relu and conv2d ops
* Restore the modification of paddle_api_test
* Support a simple model that contains only a relu layer
* Add compiling scripts for XPU
* Fix compiling errors of XPU
* Add comments for XPU LoadModel and BuildModel
Parent b987ee38
......@@ -59,6 +59,7 @@ lite_option(LITE_WITH_CUDA "Enable CUDA in lite mode" OFF)
lite_option(LITE_WITH_X86 "Enable X86 in lite mode" ON)
lite_option(LITE_WITH_ARM "Enable ARM in lite mode" OFF)
lite_option(LITE_WITH_NPU "Enable NPU in lite mode" OFF)
lite_option(LITE_WITH_XPU "Enable XPU in lite mode" OFF)
lite_option(LITE_WITH_OPENMP "Enable OpenMP in lite framework" ON)
lite_option(LITE_WITH_OPENCL "Enable OpenCL support in lite" OFF)
lite_option(LITE_WITH_FPGA "Enable FPGA support in lite" OFF)
......@@ -184,6 +185,10 @@ if(LITE_WITH_CUDA)
include(cuda)
endif()
if(LITE_WITH_XPU)
include(xpu)
endif()
include(generic) # simplify cmake module
include(ccache) # set ccache for compilation
include(util) # set unittest and link libs
......
......@@ -127,6 +127,10 @@ if (LITE_WITH_NPU)
add_definitions("-DLITE_WITH_NPU")
endif()
if (LITE_WITH_XPU)
add_definitions("-DLITE_WITH_XPU")
endif()
if (LITE_WITH_OPENCL)
add_definitions("-DLITE_WITH_OPENCL")
endif()
......
......@@ -22,7 +22,7 @@ endfunction()
function (lite_deps TARGET)
set(options "")
set(oneValueArgs "")
set(multiValueArgs DEPS X86_DEPS CUDA_DEPS ARM_DEPS PROFILE_DEPS LIGHT_DEPS HVY_DEPS CL_DEPS FPGA_DEPS NPU_DEPS ARGS)
set(multiValueArgs DEPS X86_DEPS CUDA_DEPS ARM_DEPS PROFILE_DEPS LIGHT_DEPS HVY_DEPS CL_DEPS FPGA_DEPS NPU_DEPS XPU_DEPS ARGS)
cmake_parse_arguments(lite_deps "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
set(deps ${lite_deps_DEPS})
......@@ -83,6 +83,12 @@ function (lite_deps TARGET)
endforeach(var)
endif()
if (LITE_WITH_XPU)
foreach(var ${lite_deps_XPU_DEPS})
set(deps ${deps} ${var})
endforeach(var)
endif()
set(${TARGET} ${deps} PARENT_SCOPE)
endfunction()
......@@ -107,7 +113,7 @@ file(WRITE ${offline_lib_registry_file} "") # clean
function(lite_cc_library TARGET)
set(options SHARED shared STATIC static MODULE module)
set(oneValueArgs "")
set(multiValueArgs SRCS DEPS X86_DEPS CUDA_DEPS CL_DEPS NPU_DEPS ARM_DEPS FPGA_DEPS PROFILE_DEPS LIGHT_DEPS
set(multiValueArgs SRCS DEPS X86_DEPS CUDA_DEPS CL_DEPS NPU_DEPS XPU_DEPS ARM_DEPS FPGA_DEPS PROFILE_DEPS LIGHT_DEPS
HVY_DEPS EXCLUDE_COMPILE_DEPS ARGS)
cmake_parse_arguments(args "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
......@@ -118,6 +124,7 @@ function(lite_cc_library TARGET)
CUDA_DEPS ${args_CUDA_DEPS}
CL_DEPS ${args_CL_DEPS}
NPU_DEPS ${args_NPU_DEPS}
XPU_DEPS ${args_XPU_DEPS}
ARM_DEPS ${args_ARM_DEPS}
FPGA_DEPS ${args_FPGA_DEPS}
PROFILE_DEPS ${args_PROFILE_DEPS}
......@@ -236,6 +243,7 @@ set(arm_kernels CACHE INTERNAL "arm kernels")
set(x86_kernels CACHE INTERNAL "x86 kernels")
set(fpga_kernels CACHE INTERNAL "fpga kernels")
set(npu_kernels CACHE INTERNAL "npu kernels")
set(xpu_kernels CACHE INTERNAL "xpu kernels")
set(opencl_kernels CACHE INTERNAL "opencl kernels")
set(host_kernels CACHE INTERNAL "host kernels")
......@@ -305,6 +313,12 @@ function(add_kernel TARGET device level)
endif()
set(npu_kernels "${npu_kernels};${TARGET}" CACHE INTERNAL "")
endif()
if ("${device}" STREQUAL "XPU")
if (NOT LITE_WITH_XPU)
return()
endif()
set(xpu_kernels "${xpu_kernels};${TARGET}" CACHE INTERNAL "")
endif()
if ("${device}" STREQUAL "FPGA")
if (NOT LITE_WITH_FPGA)
return()
......@@ -338,6 +352,7 @@ function(add_kernel TARGET device level)
lite_cc_library(${TARGET} SRCS ${args_SRCS}
DEPS ${args_DEPS}
X86_DEPS ${args_X86_DEPS}
XPU_DEPS ${args_XPU_DEPS}
CUDA_DEPS ${args_CUDA_DEPS}
CL_DEPS ${args_CL_DEPS}
ARM_DEPS ${args_ARM_DEPS}
......@@ -386,6 +401,7 @@ function(add_operator TARGET level)
lite_cc_library(${TARGET} SRCS ${args_SRCS}
DEPS ${args_DEPS}
X86_DEPS ${args_X86_DEPS}
XPU_DEPS ${args_XPU_DEPS}
CUDA_DEPS ${args_CUDA_DEPS}
CL_DEPS ${args_CL_DEPS}
ARM_DEPS ${args_ARM_DEPS}
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
if(NOT LITE_WITH_XPU)
return()
endif()
if(NOT DEFINED XPU_SDK_ROOT)
set(XPU_SDK_ROOT $ENV{XPU_SDK_ROOT})
if(NOT XPU_SDK_ROOT)
message(FATAL_ERROR "Must set XPU_SDK_ROOT or env XPU_SDK_ROOT when LITE_WITH_XPU=ON")
endif()
endif()
message(STATUS "XPU_SDK_ROOT: ${XPU_SDK_ROOT}")
find_path(XPU_SDK_INC NAMES xtcl.h
PATHS ${XPU_SDK_ROOT}/XTCL/include/xtcl NO_DEFAULT_PATH)
if(NOT XPU_SDK_INC)
message(FATAL_ERROR "Can not find xtcl.h in ${XPU_SDK_ROOT}/include")
endif()
include_directories("${XPU_SDK_ROOT}/XTCL/include")
include_directories("${XPU_SDK_ROOT}/XTDK/include")
find_library(XPU_SDK_XTCL_FILE NAMES xtcl
PATHS ${XPU_SDK_ROOT}/XTCL/so)
if(NOT XPU_SDK_XTCL_FILE)
message(FATAL_ERROR "Can not find XPU XTCL Library in ${XPU_SDK_ROOT}")
else()
message(STATUS "Found XPU XTCL Library: ${XPU_SDK_XTCL_FILE}")
add_library(xpu_sdk_xtcl SHARED IMPORTED GLOBAL)
set_property(TARGET xpu_sdk_xtcl PROPERTY IMPORTED_LOCATION ${XPU_SDK_XTCL_FILE})
endif()
find_library(XPU_SDK_TVM_FILE NAMES tvm
PATHS ${XPU_SDK_ROOT}/XTCL/so)
if(NOT XPU_SDK_TVM_FILE)
message(FATAL_ERROR "Can not find XPU TVM Library in ${XPU_SDK_ROOT}")
else()
message(STATUS "Found XPU TVM Library: ${XPU_SDK_TVM_FILE}")
add_library(xpu_sdk_tvm SHARED IMPORTED GLOBAL)
set_property(TARGET xpu_sdk_tvm PROPERTY IMPORTED_LOCATION ${XPU_SDK_TVM_FILE})
endif()
find_library(XPU_SDK_XPU_API_FILE NAMES xpuapi
PATHS ${XPU_SDK_ROOT}/XTDK/shlib)
if(NOT XPU_SDK_XPU_API_FILE)
message(FATAL_ERROR "Can not find XPU API Library in ${XPU_SDK_ROOT}")
else()
message(STATUS "Found XPU API Library: ${XPU_SDK_XPU_API_FILE}")
add_library(xpu_sdk_xpu_api SHARED IMPORTED GLOBAL)
set_property(TARGET xpu_sdk_xpu_api PROPERTY IMPORTED_LOCATION ${XPU_SDK_XPU_API_FILE})
endif()
find_library(XPU_SDK_XPU_RT_FILE NAMES xpurt
PATHS ${XPU_SDK_ROOT}/XTDK/shlib)
if(NOT XPU_SDK_XPU_RT_FILE)
message(FATAL_ERROR "Can not find XPU RT Library in ${XPU_SDK_ROOT}")
else()
message(STATUS "Found XPU RT Library: ${XPU_SDK_XPU_RT_FILE}")
add_library(xpu_sdk_xpu_rt SHARED IMPORTED GLOBAL)
set_property(TARGET xpu_sdk_xpu_rt PROPERTY IMPORTED_LOCATION ${XPU_SDK_XPU_RT_FILE})
endif()
find_library(XPU_SDK_XPU_JITC_FILE NAMES xpujitc
PATHS ${XPU_SDK_ROOT}/XTDK/shlib)
if(NOT XPU_SDK_XPU_JITC_FILE)
message(FATAL_ERROR "Can not find XPU JITC Library in ${XPU_SDK_ROOT}")
else()
message(STATUS "Found XPU JITC Library: ${XPU_SDK_XPU_JITC_FILE}")
add_library(xpu_sdk_xpu_jitc SHARED IMPORTED GLOBAL)
set_property(TARGET xpu_sdk_xpu_jitc PROPERTY IMPORTED_LOCATION ${XPU_SDK_XPU_JITC_FILE})
endif()
find_library(XPU_SDK_LLVM_FILE NAMES LLVM-8
PATHS ${XPU_SDK_ROOT}/XTDK/shlib)
if(NOT XPU_SDK_LLVM_FILE)
message(FATAL_ERROR "Can not find LLVM Library in ${XPU_SDK_ROOT}")
else()
message(STATUS "Found XPU LLVM Library: ${XPU_SDK_LLVM_FILE}")
add_library(xpu_sdk_llvm SHARED IMPORTED GLOBAL)
set_property(TARGET xpu_sdk_llvm PROPERTY IMPORTED_LOCATION ${XPU_SDK_LLVM_FILE})
endif()
set(xpu_runtime_libs xpu_sdk_xtcl xpu_sdk_tvm xpu_sdk_xpu_api xpu_sdk_xpu_rt xpu_sdk_xpu_jitc xpu_sdk_llvm CACHE INTERNAL "xpu runtime libs")
set(xpu_builder_libs xpu_sdk_xtcl xpu_sdk_tvm xpu_sdk_xpu_api xpu_sdk_xpu_rt xpu_sdk_xpu_jitc xpu_sdk_llvm CACHE INTERNAL "xpu builder libs")
......@@ -6,6 +6,7 @@ message(STATUS "LITE_WITH_CUDA:\t${LITE_WITH_CUDA}")
message(STATUS "LITE_WITH_X86:\t${LITE_WITH_X86}")
message(STATUS "LITE_WITH_ARM:\t${LITE_WITH_ARM}")
message(STATUS "LITE_WITH_NPU:\t${LITE_WITH_NPU}")
message(STATUS "LITE_WITH_XPU:\t${LITE_WITH_XPU}")
message(STATUS "LITE_WITH_FPGA:\t${LITE_WITH_FPGA}")
message(STATUS "LITE_WITH_PROFILE:\t${LITE_WITH_PROFILE}")
......
......@@ -40,7 +40,8 @@ if (WITH_TESTING)
DEPS scope optimizer target_wrapper_host model_parser program
${ops} ${host_kernels}
CUDA_DEPS ${cuda_kernels}
X86_DEPS ${x86_kernels})
X86_DEPS ${x86_kernels}
XPU_DEPS ${xpu_kernels})
endif()
if(LITE_WITH_FPGA)
set(light_api_deps ${light_api_deps} ${fpga_deps})
......@@ -52,6 +53,7 @@ message(STATUS "get X86 kernels ${x86_kernels}")
message(STATUS "get Host kernels ${host_kernels}")
message(STATUS "get ARM kernels ${arm_kernels}")
message(STATUS "get NPU kernels ${npu_kernels}")
message(STATUS "get XPU kernels ${xpu_kernels}")
message(STATUS "get FPGA kernels ${fpga_kernels}")
# for full api
......@@ -64,6 +66,7 @@ if (NOT LITE_ON_TINY_PUBLISH)
X86_DEPS ${x86_kernels}
ARM_DEPS ${arm_kernels}
NPU_DEPS ${npu_kernels} ${npu_bridges} npu_pass
XPU_DEPS ${xpu_kernels} ${xpu_bridges} xpu_pass
CL_DEPS ${opencl_kernels}
FPGA_DEPS ${fpga_kernels})
endif()
......@@ -83,6 +86,7 @@ lite_cc_library(light_api SRCS light_api.cc
X86_DEPS ${x86_kernels}
ARM_DEPS ${arm_kernels}
NPU_DEPS ${npu_kernels}
XPU_DEPS ${xpu_kernels}
CL_DEPS ${opencl_kernels}
FPGA_DEPS ${fpga_kernels})
......@@ -97,6 +101,7 @@ if(WITH_TESTING)
X86_DEPS ${x86_kernels}
ARM_DEPS ${arm_kernels}
NPU_DEPS ${npu_kernels}
XPU_DEPS ${xpu_kernels}
CL_DEPS ${opencl_kernels}
FPGA_DEPS ${fpga_kernels}
EXCLUDE_COMPILE_DEPS "ON"
......@@ -224,6 +229,7 @@ lite_cc_test(test_apis SRCS apis_test.cc
DEPS cxx_api light_api ${ops} paddle_api_light
CL_DEPS ${opencl_kernels}
X86_DEPS ${x86_kernels}
XPU_DEPS ${xpu_kernels}
FPGA_DEPS ${fpga_kernels}
ARGS --model_dir=${LITE_MODEL_DIR}/lite_naive_model
--optimized_model=${LITE_MODEL_DIR}/lite_naive_model_opt SERIAL)
......@@ -251,6 +257,7 @@ lite_cc_test(test_paddle_api SRCS paddle_api_test.cc DEPS paddle_api_full paddle
${ops}
ARM_DEPS ${arm_kernels}
NPU_DEPS ${npu_kernels}
XPU_DEPS ${xpu_kernels}
CL_DEPS ${opencl_kernels}
X86_DEPS ${x86_kernels}
FPGA_DEPS ${fpga_kernels}
......@@ -265,6 +272,7 @@ if(NOT IOS)
${ops} ${host_kernels}
ARM_DEPS ${arm_kernels}
NPU_DEPS ${npu_kernels}
XPU_DEPS ${xpu_kernels}
CL_DEPS ${opencl_kernels}
FPGA_DEPS ${fpga_kernels}
X86_DEPS ${x86_kernels})
......@@ -272,6 +280,7 @@ if(NOT IOS)
${ops} ${host_kernels}
ARM_DEPS ${arm_kernels}
NPU_DEPS ${npu_kernels}
XPU_DEPS ${xpu_kernels}
CL_DEPS ${opencl_kernels}
FPGA_DEPS ${fpga_kernels}
X86_DEPS ${x86_kernels})
......
......@@ -46,8 +46,16 @@ std::string Place::DebugString() const {
}
const std::string& TargetToStr(TargetType target) {
static const std::string target2string[] = {
"unk", "host", "x86", "cuda", "arm", "opencl", "any", "fpga", "npu"};
static const std::string target2string[] = {"unk",
"host",
"x86",
"cuda",
"arm",
"opencl",
"any",
"fpga",
"npu",
"xpu"};
auto x = static_cast<int>(target);
CHECK_LT(x, static_cast<int>(TARGET(NUM)));
return target2string[x];
......@@ -84,7 +92,8 @@ const std::string& TargetRepr(TargetType target) {
"kOpenCL",
"kAny",
"kFPGA",
"kNPU"};
"kNPU",
"kXPU"};
auto x = static_cast<int>(target);
CHECK_LT(x, static_cast<int>(TARGET(NUM)));
return target2string[x];
......
......@@ -50,8 +50,9 @@ enum class TargetType : int {
kOpenCL = 5,
kFPGA = 7,
kNPU = 8,
kXPU = 9,
kAny = 6, // any target
NUM = 9, // number of fields.
NUM = 10, // number of fields.
};
enum class PrecisionType : int {
kUnk = 0,
......
......@@ -5,3 +5,4 @@ add_subdirectory(cuda)
add_subdirectory(fpga)
add_subdirectory(host)
add_subdirectory(npu)
add_subdirectory(xpu)
if(NOT LITE_WITH_XPU)
return()
endif()
lite_cc_library(xpu_runtime SRCS runtime.cc DEPS ${xpu_runtime_libs})
lite_cc_library(xpu_builder SRCS builder.cc DEPS ${xpu_builder_libs} xpu_runtime tensor op scope)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/backends/xpu/builder.h"
#include <mutex> // NOLINT
#include <utility>
#include "lite/backends/xpu/runtime.h"
namespace paddle {
namespace lite {
namespace xpu {
bool HasInputArg(const OpInfo* op_info,
const Scope* scope,
const std::string& argname) {
auto iarg_names = op_info->input_argnames();
if (std::find(iarg_names.begin(), iarg_names.end(), argname) !=
iarg_names.end()) {
auto inputs = op_info->Input(argname);
if (inputs.empty()) {
return false;
}
auto var_name = inputs.front();
auto var = scope->FindVar(var_name);
return var != nullptr;
} else {
return false;
}
}
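// Appends an increasing per-prefix counter to generate unique names, e.g.
// successive calls to UniqueName("relu") return "relu_1", "relu_2", ...;
// the counter map is protected by a mutex so names stay unique across threads.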
std::string UniqueName(const std::string& prefix) {
static std::mutex counter_mtx;
static std::unordered_map<std::string, int> counter_map;
std::unique_lock<std::mutex> counter_lck(counter_mtx);
int counter = 1;
auto it = counter_map.find(prefix);
if (it == counter_map.end()) {
counter_map[prefix] = counter;
} else {
counter = ++(it->second);
}
return prefix + "_" + std::to_string(counter);
}
xtcl::DataType CvtPrecisionType(PrecisionType in_type) {
xtcl::DataType out_type = ::xtcl::Float(32);
switch (in_type) {
case PRECISION(kFloat):
out_type = ::xtcl::Float(32);
break;
case PRECISION(kInt8):
out_type = ::xtcl::Int(8);
break;
case PRECISION(kInt32):
out_type = ::xtcl::Int(32);
break;
default:
LOG(FATAL) << "Can not convert precision type(" << PrecisionToStr(in_type)
<< ") from Lite to XPU";
break;
}
return out_type;
}
DLDataType CvtDataType(PrecisionType in_type) {
DLDataType out_type = {kDLFloat, 32, 1};
switch (in_type) {
case PRECISION(kFloat):
out_type = {kDLFloat, 32, 1};
break;
case PRECISION(kInt8):
out_type = {kDLInt, 8, 1};
break;
case PRECISION(kInt32):
out_type = {kDLInt, 32, 1};
break;
default:
LOG(FATAL) << "Can not convert data type(" << PrecisionToStr(in_type)
<< ") from Lite to XPU";
break;
}
return out_type;
}
xtcl::Array<xtcl::xIndexExpr> CvtShape(const std::vector<int>& in_shape) {
xtcl::Array<xtcl::xIndexExpr> out_shape;
for (auto dim : in_shape) {
out_shape.push_back(dim);
}
return out_shape;
}
xtcl::Array<xtcl::xIndexExpr> CvtShape(const std::vector<int64_t>& in_shape) {
return CvtShape(std::vector<int>(in_shape.begin(), in_shape.end()));
}
xtcl::Array<xtcl::xIndexExpr> CvtShape(const DDim& in_dims) {
return CvtShape(in_dims.Vectorize());
}
std::shared_ptr<xtcl::xNDArray> CvtTensor(lite::Tensor* in_tensor,
std::vector<int64_t> out_shape,
PrecisionType in_ptype,
DataLayoutType in_ltype) {
uint8_t* in_data = nullptr;
auto in_size = in_tensor->dims().production();
auto in_shape = in_tensor->dims().Vectorize();
if (out_shape.empty()) {
out_shape = in_shape;
}
int in_bytes;
if (in_ptype == PRECISION(kFloat)) {
in_data = reinterpret_cast<uint8_t*>(in_tensor->mutable_data<float>());
in_bytes = in_size * sizeof(float);
} else if (in_ptype == PRECISION(kInt32)) {
in_data = reinterpret_cast<uint8_t*>(in_tensor->mutable_data<int32_t>());
in_bytes = in_size * sizeof(int32_t);
} else if (in_ptype == PRECISION(kInt8)) {
in_data = reinterpret_cast<uint8_t*>(in_tensor->mutable_data<int8_t>());
in_bytes = in_size * sizeof(int8_t);
} else {
LOG(FATAL) << "Unknow precision type " << PrecisionToStr(in_ptype);
}
auto out_tensor = std::make_shared<xtcl::xNDArray>(
xtcl::xNDArray::Empty(out_shape, CvtDataType(in_ptype), {kDLCPU, 0}));
auto out_data =
reinterpret_cast<uint8_t*>(out_tensor->ToDLPack()->dl_tensor.data);
std::memcpy(out_data, in_data, in_bytes);
return out_tensor;
}
// Build the XPU subgraph into an XPU model and store the model data in the
// weight tensor of the graph op; the model data is loaded again by the graph
// computing kernel when the graph op is executed for inference.
// Since the XPU SDK currently lacks APIs for building and outputting the
// model data, the compiled XPU runtime object is managed by the global
// singleton 'DeviceInfo', and only the key name for finding the runtime
// object is stored in the weight tensor of the graph op.
// TODO(hong19860320) Compile the XPU subgraph and output the compiled model
// data to the weight tensor of the graph op.
bool BuildModel(
std::shared_ptr<xtcl::network::xNetworkBuilder> builder,
std::shared_ptr<xtcl::network::xTensorCompiler::ParamNDArrayMap> params,
std::vector<std::shared_ptr<xtcl::xExpr>>* outputs,
lite::Tensor* model) {
LOG(INFO) << "[XPU] Build Model.";
CHECK(builder != nullptr);
CHECK(outputs != nullptr);
CHECK_GT(outputs->size(), 0);
CHECK(model != nullptr);
// build graph and fill all of constant params
xtcl::xNetwork network = builder->FinalizeNetwork(*((*outputs)[0]));
auto target = xtcl::Target::Create("llvm");
auto compiler = xtcl::network::xTensorCompiler(network, target);
compiler.SetParams(*params); // set the data of constant tensors
compiler.Build();
// create and register runtime
auto runtime = std::make_shared<xtcl::network::xRuntimeInstance>(
compiler.CreateRuntimeInstance());
if (runtime == nullptr) {
LOG(WARNING) << "[XPU] Build Model failed!";
return false;
}
std::string name = UniqueName("xpu");
LOG(INFO) << "[XPU] Model Name: " << name;
DeviceInfo::Global().Insert(name, runtime);
model->Resize({static_cast<int64_t>(name.length() + 1)});
memcpy(model->mutable_data<int8_t>(),
reinterpret_cast<const int8_t*>(name.c_str()),
name.length() + 1);
return true;
}
} // namespace xpu
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <xtcl/xtcl.h>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/target_wrapper.h"
#include "lite/core/tensor.h"
namespace paddle {
namespace lite {
namespace xpu {
bool HasInputArg(const OpInfo* op_info,
const Scope* scope,
const std::string& argname);
std::string UniqueName(const std::string& prefix);
xtcl::DataType CvtPrecisionType(PrecisionType in_type);
DLDataType CvtDataType(PrecisionType in_type);
xtcl::Array<xtcl::xIndexExpr> CvtShape(const std::vector<int>& in_shape);
xtcl::Array<xtcl::xIndexExpr> CvtShape(const std::vector<int64_t>& in_shape);
xtcl::Array<xtcl::xIndexExpr> CvtShape(const DDim& in_dims);
std::shared_ptr<xtcl::xNDArray> CvtTensor(
Tensor* in_tensor,
std::vector<int64_t> out_shape = {},
PrecisionType in_ptype = PRECISION(kFloat),
DataLayoutType in_ltype = DATALAYOUT(kNCHW));
bool BuildModel(
std::shared_ptr<xtcl::network::xNetworkBuilder> builder,
std::shared_ptr<xtcl::network::xTensorCompiler::ParamNDArrayMap> params,
std::vector<std::shared_ptr<xtcl::xExpr>>* outputs,
lite::Tensor* model);
} // namespace xpu
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/backends/xpu/runtime.h"
#include <vector>
#include "lite/utils/cp_logging.h"
namespace paddle {
namespace lite {
namespace xpu {
// Extract the model data and recover the XPU model for inference; this
// function is called by the graph computing kernel when the graph op is
// executed.
// Since the XPU SDK currently lacks APIs for loading and recovering an XPU
// model from memory, the key name is read from the weight tensor of the graph
// op and used to look up the runtime object in the global singleton
// 'DeviceInfo'.
// TODO(hong19860320) Recover the XPU model from the weight tensor of the graph op.
bool LoadModel(const lite::Tensor &model,
std::shared_ptr<xtcl::network::xRuntimeInstance> *runtime) {
LOG(INFO) << "[XPU] Load Model.";
CHECK_GT(model.dims().production(), 0);
std::string name(reinterpret_cast<const char *>(model.data<int8_t>()));
LOG(INFO) << "[XPU] Model Name: " << name;
CHECK(runtime != nullptr);
*runtime = DeviceInfo::Global().Find(name);
if (*runtime == nullptr) {
LOG(WARNING) << "[XPU] Load Model failed!";
return false;
}
return true;
}
} // namespace xpu
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <xtcl/xtcl.h>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include "lite/core/tensor.h"
namespace paddle {
namespace lite {
namespace xpu {
class DeviceInfo {
public:
static DeviceInfo& Global() {
static DeviceInfo x;
return x;
}
DeviceInfo() {}
void Insert(const std::string& name,
std::shared_ptr<xtcl::network::xRuntimeInstance> runtime) {
if (runtimes_.find(name) != runtimes_.end()) {
LOG(WARNING) << "[XPU] Model " << name << " already exists.";
return;
}
runtimes_.emplace(std::make_pair(name, runtime));
}
void Clear() { runtimes_.clear(); }
std::shared_ptr<xtcl::network::xRuntimeInstance> Find(
const std::string& name) const {
if (runtimes_.find(name) != runtimes_.end()) {
return runtimes_.at(name);
} else {
return nullptr;
}
}
private:
int device_id_{0};
std::string device_name_{"default"};
std::unordered_map<std::string,
std::shared_ptr<xtcl::network::xRuntimeInstance>>
runtimes_;
};
bool LoadModel(const lite::Tensor& model,
std::shared_ptr<xtcl::network::xRuntimeInstance>* runtime);
} // namespace xpu
} // namespace lite
} // namespace paddle
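A minimal usage sketch of the DeviceInfo registry above (not part of this commit): BuildModel registers the compiled runtime under a unique key and writes that key into the graph op's weight tensor, and the graph computing kernel later recovers the runtime by the same key. The variable 'runtime' below is a hypothetical std::shared_ptr<xtcl::network::xRuntimeInstance> obtained from the xTensorCompiler.
// Producer side (see BuildModel in lite/backends/xpu/builder.cc):
std::string key = paddle::lite::xpu::UniqueName("xpu");        // e.g. "xpu_1"
paddle::lite::xpu::DeviceInfo::Global().Insert(key, runtime);   // 'runtime' is assumed
// Consumer side (see LoadModel in lite/backends/xpu/runtime.cc):
auto found = paddle::lite::xpu::DeviceInfo::Global().Find(key);
CHECK(found != nullptr) << "[XPU] no runtime registered under '" << key << "'";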
......@@ -35,7 +35,7 @@ lite_cc_library(device_info SRCS device_info.cc DEPS tensor)
if (LITE_WITH_ARM)
lite_cc_library(context SRCS context.cc DEPS tensor any device_info CL_DEPS cl_context gflags NPU_DEPS npu_runtime)
else()
lite_cc_library(context SRCS context.cc DEPS tensor any device_info eigen3 CL_DEPS cl_context gflags)
lite_cc_library(context SRCS context.cc DEPS tensor any device_info eigen3 CL_DEPS cl_context gflags XPU_DEPS xpu_runtime)
endif()
#-------------------------------------------- GET CODE META INFO ------------------------------------------
......
......@@ -5,6 +5,6 @@ endif()
lite_cc_library(arena_framework SRCS framework.cc DEPS program gtest)
if(NOT LITE_WITH_OPENCL AND (LITE_WITH_X86 OR LITE_WITH_ARM))
if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_XPU) AND (LITE_WITH_X86 OR LITE_WITH_ARM))
lite_cc_test(test_arena_framework SRCS framework_test.cc DEPS arena_framework ${x86_kernels} ${fpga_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
endif()
......@@ -28,6 +28,9 @@
#ifdef LITE_WITH_NPU
#include "lite/backends/npu/runtime.h"
#endif
#ifdef LITE_WITH_XPU
#include "lite/backends/xpu/runtime.h"
#endif
#include <map>
#include <memory>
......@@ -55,6 +58,7 @@ using X86Context = Context<TargetType::kX86>;
using CUDAContext = Context<TargetType::kCUDA>;
using ARMContext = Context<TargetType::kARM>;
using NPUContext = Context<TargetType::kNPU>;
using XPUContext = Context<TargetType::kXPU>;
using OpenCLContext = Context<TargetType::kOpenCL>;
using FPGAContext = Context<TargetType::kFPGA>;
......@@ -84,6 +88,20 @@ class Context<TargetType::kNPU> {
};
#endif
#ifdef LITE_WITH_XPU
template <>
class Context<TargetType::kXPU> {
public:
Context() {}
explicit Context(const XPUContext& ctx);
// NOTE: InitOnce should only be used by ContextScheduler
void InitOnce() {}
void CopySharedTo(XPUContext* ctx) {}
std::string name() const { return "XPUContext"; }
};
#endif
#ifdef LITE_WITH_ARM
template <>
class Context<TargetType::kARM> {
......@@ -340,6 +358,12 @@ class ContextScheduler {
&ctx->As<NPUContext>());
break;
#endif
#ifdef LITE_WITH_XPU
case TARGET(kXPU):
kernel_contexts_[TargetType::kXPU].As<XPUContext>().CopySharedTo(
&ctx->As<XPUContext>());
break;
#endif
#ifdef LITE_WITH_OPENCL
case TARGET(kOpenCL):
kernel_contexts_[TargetType::kOpenCL].As<OpenCLContext>().CopySharedTo(
......@@ -386,6 +410,9 @@ class ContextScheduler {
#endif
#ifdef LITE_WITH_NPU
InitContext<TargetType::kNPU, NPUContext>();
#endif
#ifdef LITE_WITH_XPU
InitContext<TargetType::kXPU, XPUContext>();
#endif
}
......
......@@ -53,6 +53,7 @@ void ExpandPlaces(std::set<Place>* places, const Place& place) {
TARGET(kARM),
TARGET(kOpenCL),
TARGET(kNPU),
TARGET(kXPU),
TARGET(kFPGA)});
static const Types<PrecisionType> precision_set(
{PRECISION(kFloat), PRECISION(kInt8), PRECISION(kFP16), PRECISION(kAny)});
......
......@@ -30,5 +30,21 @@ if(LITE_WITH_NPU)
endif()
endif()
if(LITE_WITH_XPU)
lite_cc_library(xpu_pass SRCS generate_xpu_program_pass.cc
DEPS mir_pass types context ${mir_fusers} ${xpu_bridges} ${xpu_builder_libs} graph_op subgraph_pass)
list(APPEND subgraph_passes xpu_pass)
lite_cc_test(test_xpu_pass SRCS generate_xpu_program_pass_test.cc
DEPS xpu_pass mir_passes paddle_api_full gflags
ARGS --model_dir=${LITE_MODEL_DIR}/mobilenet_v1
--optimized_model=${LITE_MODEL_DIR}/lite_xpu_model_opt SERIAL)
if (WITH_TESTING)
add_dependencies(test_xpu_pass extern_lite_download_mobilenet_v1_tar_gz)
add_dependencies(test_subgraph_pass extern_lite_download_mobilenet_v2_relu_tar_gz)
set(LINK_FLAGS "-Wl,--version-script ${PADDLE_SOURCE_DIR}/lite/core/lite.map")
set_target_properties(test_xpu_pass PROPERTIES LINK_FLAGS "${LINK_FLAGS}")
endif()
endif()
set(subgraph_passes ${subgraph_passes} CACHE INTERNAL "subgraph_passes")
message(STATUS "----> subgraph_passes: ${subgraph_passes}")
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/subgraph/generate_xpu_program_pass.h"
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "lite/core/mir/graph_visualize_pass.h"
#include "lite/core/mir/pass_registry.h"
#include "lite/core/mir/pattern_matcher.h"
#include "lite/backends/xpu/builder.h"
#include "lite/kernels/xpu/bridges/paddle_use_xpu_bridges.h"
#include "lite/kernels/xpu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace mir {
namespace subgraph {
std::shared_ptr<xtcl::xExpr> GenerateXPUProgramPass::CvtVarNode(
lite::kernels::xpu::bridges::graph_ctx_type* graph_ctx,
lite::mir::Node* var_node,
const Scope* scope) {
CHECK(var_node->IsArg());
const auto& arg = var_node->AsArg();
auto var_name = arg.name;
VLOG(4) << "[XPU] Convert var node " << var_name;
auto* var = scope->FindVar(var_name);
CHECK(var);
auto* tensor = var->GetMutable<lite::Tensor>();
CHECK(tensor);
auto dims = tensor->dims();
auto cvted_var_node =
std::make_shared<xtcl::xExpr>(graph_ctx->builder->CreateTensor(
var_name, lite::xpu::CvtShape(dims), ::xtcl::Float(32)));
if (arg.is_weight) {
auto cvted_var_tensor = lite::xpu::CvtTensor(tensor);
graph_ctx->params->emplace(std::make_pair(var_name, *cvted_var_tensor));
}
return cvted_var_node;
}
void GenerateXPUProgramPass::CvtAllOpNodes(
const std::vector<Node*>& op_nodes,
lite::kernels::xpu::bridges::graph_ctx_type* graph_ctx,
lite::kernels::xpu::bridges::node_map_type* cvted_var_nodes) {
const auto& bridges = lite::kernels::xpu::bridges::Factory::Instance();
const auto& supported_lists = bridges.AllFunctions();
// record all converted var nodes in cvted_var_nodes; an op node's input var
// nodes are converted on demand before the op node itself is converted
for (auto& node : op_nodes) {
lite::kernels::xpu::bridges::node_map_type input_nodes;
auto& stmt = node->AsStmt();
for (auto& var_node : node->inlinks) {
auto& arg = var_node->AsArg();
// weights are handled inside each op converter, so skip them here
if (arg.is_weight) {
continue;
}
auto var_name = arg.name;
if (!cvted_var_nodes->count(var_name)) {
cvted_var_nodes->insert(std::make_pair(
var_name, CvtVarNode(graph_ctx, var_node, stmt.op()->scope())));
}
input_nodes.insert(*cvted_var_nodes->find(var_name));
}
auto output_nodes =
supported_lists.at(stmt.op_type())(stmt.op(), graph_ctx, input_nodes);
cvted_var_nodes->insert(output_nodes.begin(), output_nodes.end());
}
}
std::string GenerateXPUProgramPass::BuildXPUGraph(
const std::unordered_set<Node*>& op_nodes,
const std::unordered_set<Node*>& in_data_vars,
const std::unordered_set<Node*>& out_data_vars,
int sub_id) {
auto ordered_op_nodes = GetTopologicalOrder(op_nodes);
lite::kernels::xpu::bridges::graph_ctx_type graph_ctx;
graph_ctx.builder = std::make_shared<xtcl::network::xNetworkBuilder>();
graph_ctx.params =
std::make_shared<xtcl::network::xTensorCompiler::ParamNDArrayMap>();
lite::kernels::xpu::bridges::node_map_type cvted_var_nodes;
CvtAllOpNodes(ordered_op_nodes, &graph_ctx, &cvted_var_nodes);
std::string weight_var_name = "graph" + std::to_string(sub_id) + "_weights";
auto any_op = (*op_nodes.begin())->AsStmt().op();
auto weight = any_op->scope()->Var(weight_var_name)->GetMutable<Tensor>();
weight->set_persistable(true);
weight->set_precision(PRECISION(kInt8));
// Compile the graph into an XPU model and store the model data in a weight
// tensor with persistable=true, so that the model parser can recognize it and
// save it to the param files
std::vector<std::shared_ptr<xtcl::xExpr>> ordered_cvted_var_nodes;
for (auto out_data_var : out_data_vars) {
auto var_name = out_data_var->AsArg().name;
ordered_cvted_var_nodes.push_back(cvted_var_nodes[var_name]);
}
if (!lite::xpu::BuildModel(graph_ctx.builder,
graph_ctx.params,
&ordered_cvted_var_nodes,
weight)) {
LOG(WARNING) << "[XPU] Build XPU graph failed (subgraph=" << sub_id << ")";
throw std::runtime_error("[XPU] Build XPU graph failed.");
}
LOG(INFO) << "[XPU] Build XPU graph success (subgraph=" << sub_id << ")";
return weight_var_name;
}
void GenerateXPUProgramPass::GenXPUSubgraph(
const std::unique_ptr<SSAGraph>& graph,
const std::unordered_set<Node*>& op_nodes,
int sub_id) {
std::unordered_set<Node*> in_data_vars;
std::unordered_set<Node*> in_wgt_vars;
std::unordered_set<Node*> out_data_vars;
std::unordered_set<Node*> out_unused_vars;
FindInputOutputVars(
op_nodes, &in_data_vars, &in_wgt_vars, &out_data_vars, &out_unused_vars);
auto weight_var_name =
BuildXPUGraph(op_nodes, in_data_vars, out_data_vars, sub_id);
auto any_op = (*op_nodes.begin())->AsStmt().op();
InsertNewNode(graph,
weight_var_name,
any_op->scope(),
any_op->valid_places(),
in_data_vars,
in_wgt_vars,
out_data_vars,
out_unused_vars);
auto nodes2rm = GetNode2rm(
op_nodes, {in_data_vars, in_wgt_vars, out_data_vars, out_unused_vars});
GraphSafeRemoveNodes(graph.get(), nodes2rm);
}
void GenerateXPUProgramPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
LOG(INFO) << "[XPU] Before XPU Pass \n" << Visualize(graph.get());
const auto& bridges = lite::kernels::xpu::bridges::Factory::Instance();
const auto& op_map = bridges.AllFunctions();
std::vector<std::string> supported_op_types;
for (auto& i : op_map) {
LOG(INFO) << "[XPU] Supported type: " << i.first;
supported_op_types.push_back(i.first);
}
try {
int num_subgraph = FuseSubgraph(graph, supported_op_types);
InferOnce(graph);
auto op_nodes_all = ClassifySubgraph(graph);
CHECK_EQ(op_nodes_all.size(), num_subgraph);
int id = 1;
for (auto& op_nodes : op_nodes_all) {
LOG(INFO) << "[XPU] Converting Subgraph " << id;
GenXPUSubgraph(graph, op_nodes.second, id);
LOG(INFO) << "[XPU] After XPU Pass Subgraph " << id << "\n"
<< Visualize(graph.get());
id++;
}
} catch (...) {
LOG(WARNING) << "[XPU] Build XPU graph failed.";
throw std::runtime_error("[XPU] Build XPU graph failed.");
}
for (auto& item : graph->StmtTopologicalOrder()) {
if (item->IsStmt()) {
auto& stmt = item->AsStmt();
LOG(INFO) << stmt;
insts_.emplace_back(stmt.op(), std::move(stmt.kernels().front()));
}
}
}
std::unique_ptr<RuntimeProgram> GenerateXPUProgramPass::GenProgram() {
LOG(INFO) << "[XPU] program insts.size=" << insts_.size();
std::unique_ptr<RuntimeProgram> program(
new RuntimeProgram(std::move(insts_)));
return program;
}
} // namespace subgraph
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(generate_xpu_program_pass,
paddle::lite::mir::subgraph::GenerateXPUProgramPass)
.BindTargets({TARGET(kXPU)});
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "lite/backends/xpu/builder.h"
#include "lite/core/mir/pass.h"
#include "lite/core/mir/subgraph/subgraph_program_pass.h"
#include "lite/kernels/xpu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace mir {
namespace subgraph {
class GenerateXPUProgramPass : public SubgraphProgramPass {
public:
using key2nodes_t = std::map<std::string, Node*>;
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
std::unique_ptr<RuntimeProgram> GenProgram();
protected:
// nodes2cvt: op nodes to convert
// return cvted_vars: converted var nodes
void CvtAllOpNodes(
const std::vector<Node*>& op_nodes,
lite::kernels::xpu::bridges::graph_ctx_type* graph_ctx,
lite::kernels::xpu::bridges::node_map_type* cvted_var_nodes);
std::shared_ptr<xtcl::xExpr> CvtVarNode(
lite::kernels::xpu::bridges::graph_ctx_type* graph_ctx,
lite::mir::Node* var_node,
const Scope* scope);
std::string BuildXPUGraph(const std::unordered_set<Node*>& op_nodes,
const std::unordered_set<Node*>& in_data_vars,
const std::unordered_set<Node*>& out_data_vars,
int sub_id);
void GenXPUSubgraph(const std::unique_ptr<SSAGraph>& graph,
const std::unordered_set<Node*>& op_nodes,
int sub_id);
private:
std::vector<Instruction> insts_;
};
} // namespace subgraph
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <cmath>
#include "lite/api/paddle_api.h"
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/api/paddle_use_passes.h"
#include "lite/api/test_helper.h"
#include "lite/utils/cp_logging.h"
DEFINE_string(model_file, "", "model file path of combined protobuf model");
DEFINE_string(params_file, "", "params file path of combined protobuf model");
DEFINE_string(optimized_model_dir, "", "path of optimized naive buffer model");
DEFINE_string(input_tensor_shape, "1,3,224,224", "shapes of input tensors");
DEFINE_int32(output_tensor_num, 1, "number of output tensors");
namespace paddle {
namespace lite {
std::vector<std::vector<int64_t>> ParseShape(std::string txt) {
std::vector<std::vector<int64_t>> shape;
while (!txt.empty()) {
size_t idx = txt.find_first_of(":");
std::string dims = txt.substr(0, idx);
std::vector<int64_t> s;
while (!dims.empty()) {
size_t idx = dims.find_first_of(",");
int d = atoi(dims.substr(0, idx).c_str());
VLOG(3) << d;
s.push_back(d);
if (idx == std::string::npos) {
break;
} else {
dims = dims.substr(idx + 1);
}
}
shape.push_back(s);
if (idx == std::string::npos) {
break;
} else {
txt = txt.substr(idx + 1);
}
}
return shape;
}
int64_t ShapeProduction(std::vector<int64_t> shape) {
int64_t s = 1;
for (int64_t dim : shape) {
s *= dim;
}
return s;
}
void FillInputTensor(
const std::shared_ptr<lite_api::PaddlePredictor>& predictor,
const std::vector<std::vector<int64_t>>& input_tensor_shape,
const float value) {
for (int i = 0; i < input_tensor_shape.size(); i++) {
auto input_tensor = predictor->GetInput(i);
input_tensor->Resize(input_tensor_shape[i]);
auto input_tensor_data = input_tensor->mutable_data<float>();
auto input_tensor_size = ShapeProduction(input_tensor->shape());
for (int j = 0; j < input_tensor_size; j++) {
input_tensor_data[j] = value;
}
}
}
void CompareOutputTensor(
const std::shared_ptr<lite_api::PaddlePredictor>& tar_predictor,
const std::shared_ptr<lite_api::PaddlePredictor>& ref_predictor,
const int output_tensor_num) {
for (int i = 0; i < output_tensor_num; i++) {
auto tar_output_tensor = tar_predictor->GetOutput(i);
auto ref_output_tensor = ref_predictor->GetOutput(i);
auto tar_output_tensor_data = tar_output_tensor->data<float>();
auto ref_output_tensor_data = ref_output_tensor->data<float>();
auto tar_output_tensor_size = ShapeProduction(tar_output_tensor->shape());
auto ref_output_tensor_size = ShapeProduction(ref_output_tensor->shape());
EXPECT_EQ(tar_output_tensor_size, ref_output_tensor_size);
for (size_t j = 0; j < ref_output_tensor_size; j++) {
auto diff =
std::fabs(tar_output_tensor_data[j] - ref_output_tensor_data[j]) /
(std::fabs(ref_output_tensor_data[j]) + 1e-6);
VLOG(3) << diff;
EXPECT_LT(diff, 0.1);
}
}
}
std::shared_ptr<lite_api::PaddlePredictor> TestModel(
const std::string& model_dir,
const std::string& model_file,
const std::string& params_file,
const std::vector<lite_api::Place>& valid_places,
const std::vector<std::vector<int64_t>>& input_tensor_shape,
const std::string& optimized_model_dir) {
// generate optimized model
lite_api::CxxConfig cxx_config;
cxx_config.set_model_dir(model_dir);
cxx_config.set_model_file(model_file);
cxx_config.set_param_file(params_file);
cxx_config.set_valid_places(valid_places);
auto predictor = lite_api::CreatePaddlePredictor(cxx_config);
FillInputTensor(predictor, input_tensor_shape, -1);
predictor->SaveOptimizedModel(optimized_model_dir,
lite_api::LiteModelType::kNaiveBuffer);
#if 0 // TODO(hong19860320) support the light API for XPU
// load optimized model
lite_api::MobileConfig mobile_config;
mobile_config.set_model_dir(optimized_model_dir);
mobile_config.set_power_mode(lite_api::PowerMode::LITE_POWER_HIGH);
mobile_config.set_threads(1);
predictor = lite_api::CreatePaddlePredictor(mobile_config);
FillInputTensor(predictor, input_tensor_shape, 1);
#endif
// run optimized model
for (int i = 0; i < FLAGS_warmup; i++) {
predictor->Run();
}
for (int i = 0; i < FLAGS_repeats; i++) {
auto start = GetCurrentUS();
predictor->Run();
LOG(INFO) << i << ", " << GetCurrentUS() - start << "us";
}
return predictor;
}
TEST(XPUSubgraph, compare) {
// parsing input tensor shape, supported formats: "1,3,224,224"
// "1,3,224,224:1,80"
std::vector<std::vector<int64_t>> input_tensor_shape =
ParseShape(FLAGS_input_tensor_shape);
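// e.g. ParseShape("1,3,224,224:1,80") returns {{1, 3, 224, 224}, {1, 80}}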
// generate and run optimized CPU model
LOG(INFO) << " ================ CPU ================== ";
auto cpu_predictor =
TestModel(FLAGS_model_dir,
FLAGS_model_file,
FLAGS_params_file,
{lite_api::Place{TARGET(kX86), PRECISION(kFloat)}},
input_tensor_shape,
FLAGS_optimized_model_dir + "/CPU");
// generate and run optimized XPU model
LOG(INFO) << " ================ XPU ================== ";
auto xpu_predictor =
TestModel(FLAGS_model_dir,
FLAGS_model_file,
FLAGS_params_file,
{lite_api::Place{TARGET(kXPU), PRECISION(kFloat)},
lite_api::Place{TARGET(kX86), PRECISION(kFloat)}},
input_tensor_shape,
FLAGS_optimized_model_dir + "/XPU");
// verify results
CompareOutputTensor(xpu_predictor, cpu_predictor, FLAGS_output_tensor_num);
}
} // namespace lite
} // namespace paddle
......@@ -207,8 +207,26 @@ void SubgraphProgramPass::InferOnce(const std::unique_ptr<SSAGraph>& graph) {
if (!item->IsStmt()) continue;
auto& stmt = item->AsStmt();
auto& op = stmt.op();
auto scope = op->scope();
std::string op_type = op->op_info()->Type();
if (op_type == "feed" || op_type == "fetch") continue;
// check the dimensions of the input variables in the scope; they must not be empty
if (op_type == "feed") {
auto input_var_names = op->op_info()->output_names();
CHECK_GE(input_var_names.size(), 1);
for (auto input_var_name : input_var_names) {
auto input_var = scope->FindVar(input_var_name);
CHECK(input_var) << "No input variable '" << input_var_name
<< "' found in scope " << scope;
auto input = input_var->GetMutable<lite::Tensor>();
CHECK(!input->dims().empty()) << "The dimension of input variable '"
<< input_var_name
<< "' can not be empty.";
}
continue;
}
if (op_type == "fetch") {
continue;
}
op->CheckShape();
op->InferShape();
// TODO(xxx): remove Launch() eventually
......
......@@ -46,6 +46,9 @@ TEST(SubgraphTest, models) {
#endif
#ifdef LITE_WITH_NPU
Place{TARGET(kNPU), PRECISION(kFloat)},
#endif
#ifdef LITE_WITH_XPU
Place{TARGET(kXPU), PRECISION(kFloat)},
#endif
});
lite::Program program(program_desc, scope, valid_places);
......
......@@ -78,6 +78,9 @@ std::list<std::unique_ptr<KernelBase>> KernelRegistry::Create(
case TARGET(kNPU): {
CREATE_KERNEL(kNPU);
} break;
case TARGET(kXPU): {
CREATE_KERNEL(kXPU);
} break;
case TARGET(kFPGA): {
CREATE_KERNEL(kFPGA);
} break;
......@@ -142,6 +145,11 @@ KernelRegistry::KernelRegistry()
INIT_FOR(kNPU, kAny, kNCHW);
INIT_FOR(kNPU, kAny, kAny);
INIT_FOR(kXPU, kFloat, kNCHW);
INIT_FOR(kXPU, kInt8, kNCHW);
INIT_FOR(kXPU, kAny, kNCHW);
INIT_FOR(kXPU, kAny, kAny);
INIT_FOR(kFPGA, kFP16, kNHWC);
INIT_FOR(kFPGA, kFP16, kAny);
INIT_FOR(kFPGA, kFloat, kNHWC);
......
......@@ -178,6 +178,16 @@ class KernelRegistry final {
PRECISION(kInt8),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kXPU),
PRECISION(kAny),
DATALAYOUT(kAny)> *, //
KernelRegistryForTarget<TARGET(kXPU),
PRECISION(kFloat),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kXPU),
PRECISION(kInt8),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kFPGA),
PRECISION(kFloat),
DATALAYOUT(kNCHW)> *, //
......
......@@ -28,6 +28,9 @@
#ifdef LITE_WITH_NPU
#include "lite/core/mir/subgraph/generate_npu_program_pass.h"
#endif
#ifdef LITE_WITH_XPU
#include "lite/core/mir/subgraph/generate_xpu_program_pass.h"
#endif
namespace paddle {
namespace lite {
......@@ -106,7 +109,8 @@ class Optimizer {
"runtime_context_assign_pass",
"argument_type_display_pass", //
#if !defined(LITE_WITH_OPENCL) && !defined(LITE_WITH_NPU)
#if !defined(LITE_WITH_OPENCL) && !defined(LITE_WITH_NPU) && \
!defined(LITE_WITH_XPU)
// TODO(ysh329): cause CL_INVALID_MEM_OBJECT when setArg in kernel
"memory_optimize_pass",
#endif
......@@ -121,14 +125,27 @@ class Optimizer {
// Generate a new program based on the mir graph.
std::unique_ptr<RuntimeProgram> GenRuntimeProgram() {
#if defined(LITE_WITH_NPU) || defined(LITE_WITH_XPU)
auto target_place = Place{
#ifdef LITE_WITH_NPU
if (std::find(valid_places_.begin(),
valid_places_.end(),
Place{TARGET(kNPU), PRECISION(kFloat)}) !=
TARGET(kNPU),
#endif
#ifdef LITE_WITH_XPU
TARGET(kXPU),
#endif
PRECISION(kFloat)};
if (std::find(valid_places_.begin(), valid_places_.end(), target_place) !=
valid_places_.end()) {
#ifdef LITE_WITH_NPU
auto pass = mir::PassManager::Global()
.LookUp<mir::subgraph::GenerateNPUProgramPass>(
"generate_npu_program_pass");
#endif
#ifdef LITE_WITH_XPU
auto pass = mir::PassManager::Global()
.LookUp<mir::subgraph::GenerateXPUProgramPass>(
"generate_xpu_program_pass");
#endif
try {
pass->Apply(graph_);
auto program = pass->GenProgram();
......@@ -136,7 +153,8 @@ class Optimizer {
program->set_exec_scope(exec_scope_);
return program;
} catch (...) {
LOG(WARNING) << "Build NPU graph failed";
LOG(WARNING) << "Build " << TargetToStr(target_place.target)
<< " program failed!";
}
}
#endif
......
......@@ -15,6 +15,7 @@ lite_cc_test(test_gen_code SRCS gen_code_test.cc
X86_DEPS ${x86_kernels}
ARM_DEPS ${arm_kernels}
NPU_DEPS ${npu_kernels}
XPU_DEPS ${xpu_kernels}
CL_DEPS ${opencl_kernels}
FPGA_DEPS ${fpga_kernels}
EXCLUDE_COMPILE_DEPS "ON"
......@@ -42,6 +43,7 @@ lite_cc_test(test_generated_code SRCS generated_code_test.cc DEPS __generated_co
X86_DEPS ${x86_kernels}
ARM_DEPS ${arm_kernels}
NPU_DEPS ${npu_kernels}
XPU_DEPS ${xpu_kernels}
CL_DEPS ${opencl_kernels}
FPGA_DEPS ${fpga_kernels}
EXCLUDE_COMPILE_DEPS "ON"
......
......@@ -9,3 +9,4 @@ add_subdirectory(x86)
add_subdirectory(opencl)
add_subdirectory(fpga)
add_subdirectory(npu)
add_subdirectory(xpu)
if(NOT LITE_WITH_XPU)
return ()
endif()
add_kernel(graph_compute_xpu XPU basic SRCS graph_compute.cc DEPS ${lite_kernel_deps} xpu_runtime)
# lite_cc_test(test_graph_compute_xpu SRCS graph_compute_test.cc DEPS graph_compute_xpu)
add_subdirectory(bridges)
lite_cc_library(xpu_bridge_registry SRCS registry.cc)
set(xpu_bridge_deps xpu_bridge_registry xpu_builder op)
lite_cc_library(xpu_bridge_act_op SRCS act_op.cc DEPS ${xpu_bridge_deps})
lite_cc_library(xpu_bridge_conv_op SRCS conv_op.cc DEPS ${xpu_bridge_deps})
set(xpu_bridges
xpu_bridge_registry
xpu_bridge_act_op
xpu_bridge_conv_op
CACHE INTERNAL "xpu_bridges")
set(xpu_bridge_test_deps ${xpu_bridges} ${xpu_kernels} ${ops})
lite_cc_test(test_xpu_bridge_act_op SRCS act_op_test.cc test_helper.cc DEPS ${xpu_bridge_test_deps})
lite_cc_test(test_xpu_bridge_conv_op SRCS conv_op_test.cc test_helper.cc DEPS ${xpu_bridge_test_deps})
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/backends/xpu/builder.h"
#include "lite/kernels/xpu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
namespace bridges {
node_map_type ActConverter(const std::shared_ptr<lite::OpLite> op,
graph_ctx_type* graph_ctx,
const node_map_type& input_nodes) {
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto unique_op_type = lite::xpu::UniqueName(op_type);
LOG(INFO) << "[XPU] Converting " + op_type + "...";
// check context
CHECK(graph_ctx != nullptr);
CHECK(graph_ctx->builder != nullptr);
CHECK(graph_ctx->params != nullptr);
// create act node and set params from op
auto x_var_name = op_info->Input("X").front();
CHECK(input_nodes.count(x_var_name));
std::shared_ptr<xtcl::xExpr> act_node = nullptr;
if (op_type == "relu") {
act_node = std::make_shared<xtcl::xExpr>(
graph_ctx->builder->CreateRelu(*input_nodes.at(x_var_name)));
} else {
// TODO(hong19860320) support more activation ops
LOG(FATAL) << "[XPU] Unsupported activation type " << op_type;
}
graph_ctx->builder->SetLayer(unique_op_type);
// output converted nodes
node_map_type output_nodes;
output_nodes[op_info->Output("Out").front()] = act_node;
return output_nodes;
}
} // namespace bridges
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_XPU_BRIDGE(relu, paddle::lite::kernels::xpu::bridges::ActConverter);
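A small sketch (not part of this commit) of how the bridge registered above is consumed by the XPU subgraph pass: the converter is looked up by op type in the bridge registry and invoked with the graph context and the already-converted input nodes. 'op', 'graph_ctx' and 'input_nodes' are hypothetical placeholders prepared by the pass.
const auto& bridges = paddle::lite::kernels::xpu::bridges::Factory::Instance();
const auto& funcs = bridges.AllFunctions();
CHECK(funcs.count("relu")) << "[XPU] no bridge registered for relu";
// 'op' is the relu OpLite, 'graph_ctx' the graph_ctx_type being built, and
// 'input_nodes' the node_map_type of converted input var nodes.
auto output_nodes = funcs.at("relu")(op, &graph_ctx, input_nodes);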
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <random>
#include "lite/core/op_registry.h"
#include "lite/kernels/xpu/bridges/registry.h"
#include "lite/kernels/xpu/bridges/test_helper.h"
#include "lite/operators/activation_ops.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
namespace bridges {
void relu_ref(const std::shared_ptr<operators::ActivationOp> op) {
Scope* scope = op->scope();
const OpInfo* op_info = op->op_info();
auto x = scope->FindVar(op_info->Input("X").front())->GetMutable<Tensor>();
auto out =
scope->FindVar(op_info->Output("Out").front())->GetMutable<Tensor>();
auto x_data = x->data<float>();
auto out_data = out->mutable_data<float>();
DDim x_dims = x->dims();
DDim out_dims = out->dims();
CHECK_EQ(x_dims.production(), out_dims.production());
for (int i = 0; i < out_dims.production(); i++) {
out_data[i] = std::max(0.f, x_data[i]);
}
}
void test_relu(int bs, int ic, int ih, int iw) {
// prepare input&output variables
Scope scope;
std::string x_var_name("x");
std::string out_var_name("out");
std::string out_ref_var_name("out_ref");
auto* x = scope.Var(x_var_name)->GetMutable<Tensor>();
auto* out = scope.Var(out_var_name)->GetMutable<Tensor>();
auto* out_ref = scope.Var(out_ref_var_name)->GetMutable<Tensor>();
x->Resize({bs, ic, ih, iw});
// initialize input&output data
FillTensor<float, int>(x);
// initialize op desc
cpp::OpDesc opdesc;
opdesc.SetType("relu");
opdesc.SetInput("X", {x_var_name});
opdesc.SetOutput("Out", {out_var_name});
// create and convert op to XPU model, and run it on XPU
auto op = CreateOp<operators::ActivationOp>(opdesc, &scope);
LauchOp(op, {x_var_name}, {out_var_name});
out_ref->CopyDataFrom(*out);
// execute reference implementation and save to output tensor
relu_ref(op);
// compare results
auto* out_data = out->mutable_data<float>();
auto* out_ref_data = out_ref->mutable_data<float>();
for (int i = 0; i < out->dims().production(); i++) {
VLOG(5) << i;
EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-5);
}
}
TEST(XPUBridges, relu) {
for (auto bs : {1, 3}) {
for (auto ic : {3, 4}) {
for (auto ih : {2, 5}) {
for (auto iw : {5, 9}) {
VLOG(3) << "bs: " << bs << " ic: " << ic << " ih: " << ih
<< " iw: " << iw;
test_relu(bs, ic, ih, iw);
}
}
}
}
}
} // namespace bridges
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_OP(relu);
USE_XPU_BRIDGE(relu);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/backends/xpu/builder.h"
#include "lite/kernels/xpu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
namespace bridges {
node_map_type ConvConverter(const std::shared_ptr<lite::OpLite> op,
graph_ctx_type* graph_ctx,
const node_map_type& input_nodes) {
auto scope = op->scope();
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto unique_op_type = lite::xpu::UniqueName(op_type);
LOG(INFO) << "[XPU] Converting " << op_type << "... ";
// get input, filter and op attributes
auto input_var_name = op_info->Input("Input").front();
auto input = scope->FindVar(input_var_name)->GetMutable<lite::Tensor>();
auto input_dims = input->dims();
auto filter_var_name = op_info->Input("Filter").front();
auto filter = scope->FindVar(filter_var_name)->GetMutable<lite::Tensor>();
auto filter_dims = filter->dims();
auto bs = input_dims[0];
auto oc = filter_dims[0];
CHECK_EQ(input_dims.size(), 4);
CHECK_EQ(filter_dims.size(), 4);
auto strides = op_info->GetAttr<std::vector<int>>("strides");
auto paddings = op_info->GetAttr<std::vector<int>>("paddings");
auto groups = op_info->GetAttr<int>("groups");
auto dilations = op_info->GetAttr<std::vector<int>>("dilations");
auto fuse_relu = op_info->GetAttr<bool>("fuse_relu");
CHECK_EQ(strides.size(), 2);
CHECK_EQ(paddings.size(), 2);
CHECK_EQ(dilations.size(), 2);
std::vector<int64_t> output_shape({bs, oc});
for (size_t i = 0; i < 2; i++) {
const int dkernel = dilations[i] * (filter_dims[2 + i] - 1) + 1;
output_shape.push_back(
(input_dims[i + 2] + 2 * paddings[i] - dkernel) / strides[i] + 1);
}
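// Worked example: with input size 224, filter size 3, dilation 1, padding 1
// and stride 1, dkernel = 1 * (3 - 1) + 1 = 3 and the output size is
// (224 + 2 * 1 - 3) / 1 + 1 = 224.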
DDim output_dims(output_shape);
// check context
CHECK(graph_ctx != nullptr);
CHECK(graph_ctx->builder != nullptr);
CHECK(graph_ctx->params != nullptr);
// create filter node
CHECK(!input_nodes.count(filter_var_name));
auto filter_const_node = std::make_shared<xtcl::xExpr>(
graph_ctx->builder->CreateTensor(filter_var_name,
lite::xpu::CvtShape(filter_dims),
::xtcl::Float(32)));
auto filter_const_tensor = lite::xpu::CvtTensor(filter);
graph_ctx->params->emplace(
std::make_pair(filter_var_name, *filter_const_tensor));
// create conv node and set input, filter, bias nodes and attributes
auto conv_attrs = xtcl::make_node<xtcl::network::Conv2DAttrs>();
conv_attrs->strides = std::move(lite::xpu::CvtShape(strides));
conv_attrs->padding = std::move(lite::xpu::CvtShape(paddings));
conv_attrs->dilation = std::move(lite::xpu::CvtShape(dilations));
conv_attrs->groups = groups;
// conv_attrs->channels = nullptr;
conv_attrs->kernel_size = std::move(xtcl::Array<xtcl::xIndexExpr>(nullptr));
conv_attrs->data_layout = "NCHW";
conv_attrs->kernel_layout = "OIHW";
conv_attrs->out_layout = "";
// conv_attrs->out_dtype = "";
CHECK(input_nodes.count(input_var_name));
auto conv_node =
std::make_shared<xtcl::xExpr>(graph_ctx->builder->CreateConv2D(
*input_nodes.at(input_var_name), *filter_const_node, conv_attrs));
graph_ctx->builder->SetLayer(unique_op_type);
// create a bias node if the op has a bias input
// the following bias dimensions are supported
// 0: {oc}
// 1: {1, oc, oh, ow}
// 2: {n, oc, oh, ow}
if (lite::xpu::HasInputArg(op_info, scope, "Bias")) {
auto bias_var_name = op_info->Input("Bias").front();
auto* bias = scope->FindVar(bias_var_name)->GetMutable<lite::Tensor>();
auto bias_dims = bias->dims();
auto bias_data_size = bias_dims.production();
auto output_data_size = output_dims.production();
std::vector<int64_t> bias_shape;
bool is_channel_bias = false;
if (bias_data_size == oc) {
// 0: {oc}
bias_shape = {oc};
is_channel_bias = true;
} else if (bias_data_size == output_data_size / bs) {
// 1: {1, oc, oh, ow}
bias_shape = {1, output_dims[1], output_dims[2], output_dims[3]};
} else if (bias_data_size == output_data_size) {
// 2: {n, oc, oh, ow}
bias_shape = output_dims.Vectorize();
} else {
LOG(ERROR) << "bias dimension " << bias_dims
<< " isn't supported in conv2d Op when output dimension is "
<< output_dims;
}
std::shared_ptr<xtcl::xExpr> bias_node = nullptr;
if (input_nodes.count(bias_var_name)) {
// bias node from input node
bias_node = input_nodes.at(bias_var_name);
} else {
// bias node with const tensor
auto bias_const_node = std::make_shared<xtcl::xExpr>(
graph_ctx->builder->CreateTensor(bias_var_name,
lite::xpu::CvtShape(bias_shape),
::xtcl::Float(32)));
auto bias_const_tensor = lite::xpu::CvtTensor(bias, bias_shape);
graph_ctx->params->emplace(
std::make_pair(bias_var_name, *bias_const_tensor));
bias_node = bias_const_node;
}
std::shared_ptr<xtcl::xExpr> add_node = nullptr;
if (is_channel_bias) {
add_node = std::make_shared<xtcl::xExpr>(
graph_ctx->builder->CreateBiasAdd(*conv_node, *bias_node, 1));
} else {
add_node = std::make_shared<xtcl::xExpr>(
graph_ctx->builder->CreateBinaryOp("add", *conv_node, *bias_node));
}
graph_ctx->builder->SetLayer(unique_op_type + "/add");
conv_node = add_node;
}
// output converted nodes
node_map_type output_nodes;
if (fuse_relu) {
// append relu node if fuse_relu is true
auto relu_node = std::make_shared<xtcl::xExpr>(
graph_ctx->builder->CreateRelu(*conv_node));
graph_ctx->builder->SetLayer(unique_op_type + "/relu");
output_nodes[op_info->Output("Output").front()] = relu_node;
} else {
output_nodes[op_info->Output("Output").front()] = conv_node;
}
return output_nodes;
}
} // namespace bridges
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_XPU_BRIDGE(conv2d, paddle::lite::kernels::xpu::bridges::ConvConverter);
REGISTER_XPU_BRIDGE(depthwise_conv2d,
paddle::lite::kernels::xpu::bridges::ConvConverter);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/conv_op.h"
#include <gtest/gtest.h>
#include <random>
#include "lite/core/op_registry.h"
#include "lite/kernels/xpu/bridges/registry.h"
#include "lite/kernels/xpu/bridges/test_helper.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
namespace bridges {
void conv_ref(const std::shared_ptr<operators::ConvOpLite> op) {
Scope* scope = op->scope();
const OpInfo* op_info = op->op_info();
auto input =
scope->FindVar(op_info->Input("Input").front())->GetMutable<Tensor>();
auto filter =
scope->FindVar(op_info->Input("Filter").front())->GetMutable<Tensor>();
auto output =
scope->FindVar(op_info->Output("Output").front())->GetMutable<Tensor>();
std::vector<int32_t> strides =
op_info->GetAttr<std::vector<int32_t>>("strides");
std::vector<int32_t> paddings =
op_info->GetAttr<std::vector<int32_t>>("paddings");
int32_t groups = op_info->GetAttr<int32_t>("groups");
std::vector<int32_t> dilations =
op_info->GetAttr<std::vector<int32_t>>("dilations");
bool fuse_relu = op_info->GetAttr<bool>("fuse_relu");
auto input_dims = input->dims();
auto filter_dims = filter->dims();
auto output_dims = output->dims();
auto input_data = input->mutable_data<float>();
auto filter_data = filter->mutable_data<float>();
auto output_data = output->mutable_data<float>();
int kernel_w = filter_dims[3];
int kernel_h = filter_dims[2];
int stride_w = strides[1];
int stride_h = strides[0];
int dila_w = dilations[1];
int dila_h = dilations[0];
int pad_w = paddings[1];
int pad_h = paddings[0];
int batch_size = input_dims[0];
int in_ch_size = input_dims[1];
int in_h = input_dims[2];
int in_w = input_dims[3];
int out_ch_size = output_dims[1];
int out_h = output_dims[2];
int out_w = output_dims[3];
int out_c_group = out_ch_size / groups;
int in_c_group = in_ch_size / groups;
Tensor* bias = nullptr;
float* bias_data = nullptr;
bool is_channel_bias = false;
if (op_info->HasInput("Bias")) {
auto bias_var_names = op_info->Input("Bias");
if (bias_var_names.size() > 0) {
auto bias_var_name = bias_var_names.front();
bias = scope->FindVar(bias_var_name)->GetMutable<lite::Tensor>();
auto bias_dims = bias->dims();
is_channel_bias = bias_dims.production() == out_ch_size;
bias_data = bias->mutable_data<float>();
}
}
for (int n = 0; n < batch_size; ++n) {
for (int g = 0; g < groups; ++g) {
for (int oc = 0; oc < out_c_group; ++oc) {
for (int oh = 0; oh < out_h; ++oh) {
for (int ow = 0; ow < out_w; ++ow) {
int out_idx = n * groups * out_c_group * out_h * out_w +
g * out_c_group * out_h * out_w + oc * out_h * out_w +
oh * out_w + ow;
float out_value =
bias_data != nullptr
? (is_channel_bias ? bias_data[g * out_c_group + oc]
: bias_data[out_idx])
: 0;
// + out_value *= beta;
for (int ic = 0; ic < in_c_group; ++ic) {
for (int kh = 0; kh < kernel_h; ++kh) {
for (int kw = 0; kw < kernel_w; ++kw) {
int iw = ow * stride_w - pad_w + kw * (dila_w);
int ih = oh * stride_h - pad_h + kh * (dila_h);
if (iw < 0 || iw >= in_w) continue;
if (ih < 0 || ih >= in_h) continue;
int in_idx = n * in_ch_size * in_h * in_w +
g * in_c_group * in_h * in_w + ic * in_h * in_w +
ih * in_w + iw;
int filter_idx =
g * out_c_group * in_c_group * kernel_h * kernel_w +
oc * in_c_group * kernel_h * kernel_w +
ic * kernel_h * kernel_w + kh * kernel_w + kw;
out_value += input_data[in_idx] * filter_data[filter_idx];
}
}
}
if (fuse_relu) {
out_value = out_value > 0 ? out_value : 0;
}
output_data[out_idx] = out_value;
}
}
}
}
}
}
void test_conv(int bs,
int ic,
int oc,
int ih,
int iw,
bool has_bias,
bool is_channel_bias,
bool fuse_relu,
bool depthwise,
int dilation,
int stride,
int padding,
int kernel) {
// prepare input&output variables
Scope scope;
std::string input_var_name("input");
std::string filter_var_name("filter");
std::string bias_var_name("bias");
std::string output_var_name("output");
std::string output_ref_var_name("output_ref");
auto* input = scope.Var(input_var_name)->GetMutable<Tensor>();
auto* filter = scope.Var(filter_var_name)->GetMutable<Tensor>();
auto* bias = scope.Var(bias_var_name)->GetMutable<Tensor>();
auto* output = scope.Var(output_var_name)->GetMutable<Tensor>();
auto* output_ref = scope.Var(output_ref_var_name)->GetMutable<Tensor>();
// get group size and input&filter shape
int groups = 1;
if (depthwise) { // depthwise convolution ?
groups = oc = ic;
}
std::vector<int64_t> input_shape = {bs, ic, ih, iw};
std::vector<int64_t> filter_shape = {oc, ic / groups, kernel, kernel};
std::vector<int64_t> output_shape({bs, oc});
for (size_t i = 0; i < 2; i++) {
const int dkernel = dilation * (kernel - 1) + 1;
int output_size = (input_shape[i + 2] + 2 * padding - dkernel) / stride + 1;
output_shape.push_back(output_size);
}
input->Resize(input_shape);
filter->Resize(filter_shape);
// initialize input&filter data
FillTensor<float, int>(input);
FillTensor<float, int>(filter);
// initialize op desc
cpp::OpDesc opdesc;
opdesc.SetType(depthwise ? "depthwise_conv2d" : "conv2d");
opdesc.SetInput("Input", {input_var_name});
opdesc.SetInput("Filter", {filter_var_name});
opdesc.SetOutput("Output", {output_var_name});
opdesc.SetAttr("dilations", std::vector<int32_t>({dilation, dilation}));
opdesc.SetAttr("strides", std::vector<int32_t>({stride, stride}));
opdesc.SetAttr("paddings", std::vector<int32_t>({padding, padding}));
opdesc.SetAttr("groups", groups);
opdesc.SetAttr("fuse_relu", static_cast<bool>(fuse_relu));
if (has_bias) {
if (is_channel_bias) {
bias->Resize({1, oc, 1, 1});
} else {
bias->Resize({1, output_shape[1], output_shape[2], output_shape[3]});
}
FillTensor<float, int>(bias);
opdesc.SetInput("Bias", {bias_var_name});
}
// create and convert op to XPU model, then run it on XPU
auto op = CreateOp<operators::ConvOpLite>(opdesc, &scope);
LauchOp(op, {input_var_name}, {output_var_name});
output_ref->CopyDataFrom(*output);
// execute the reference implementation and save the results to the 'output' tensor
conv_ref(op);
// compare results
auto* output_data = output->mutable_data<float>();
auto* output_ref_data = output_ref->mutable_data<float>();
for (int i = 0; i < output->dims().production(); i++) {
VLOG(5) << i;
EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-5);
}
}
TEST(XPUBridges, conv) {
#if 0
for (auto bs : {1, 2}) {
for (auto ic : {3, 6}) {
for (auto oc : {6, 9}) {
for (auto ih : {14, 28}) {
for (auto iw : {14, 28}) {
for (auto has_bias : {false, true}) {
for (auto is_channel_bias : {false, true}) {
for (auto fuse_relu : {false, true}) {
for (auto depthwise : {false, true}) {
for (auto dilation : {1, 2}) {
for (auto stride : {1, 2}) {
for (auto kernel : {1, 3, 5}) {
std::vector<int> paddings = {kernel / 2};
if (kernel / 2 != 0) {
paddings.push_back(0);
}
for (auto padding : paddings) {
VLOG(3) << "bs: " << bs << " ic: " << ic
<< " oc: " << oc << " ih: " << ih
<< " iw: " << iw
<< " has_bias: " << has_bias
<< " is_channel_bias: " << is_channel_bias
<< " fuse_relu: " << fuse_relu
<< " depthwise: " << depthwise
<< " dilation: " << dilation
<< " stride: " << stride
<< " padding: " << padding
<< " kernel: " << kernel;
test_conv(bs,
ic,
oc,
ih,
iw,
has_bias,
is_channel_bias,
fuse_relu,
depthwise,
dilation,
stride,
padding,
kernel);
}
}
}
}
}
}
}
}
}
}
}
}
}
#else
test_conv(1, 1, 1, 4, 4, false, false, false, false, 1, 1, 1, 3);
test_conv(1, 1, 1, 4, 4, true, true, false, false, 1, 1, 1, 3);
test_conv(1, 1, 1, 4, 4, true, false, false, false, 1, 1, 1, 3);
#endif
}
} // namespace bridges
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_OP(conv2d);
USE_XPU_BRIDGE(conv2d);
USE_LITE_OP(depthwise_conv2d);
USE_XPU_BRIDGE(depthwise_conv2d);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/kernels/xpu/bridges/registry.h"
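// Including this header pulls in the registrations below so that the
// converters of all available XPU bridges get linked into the final binary.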
USE_XPU_BRIDGE(relu);
USE_XPU_BRIDGE(conv2d);
USE_XPU_BRIDGE(depthwise_conv2d);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/xpu/bridges/registry.h"
#include <utility>
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
namespace bridges {
Factory& Factory::Instance() {
static Factory g_xpu_bridge;
return g_xpu_bridge;
}
bool Factory::HasType(const std::string& op_type) const {
return map_.count(op_type);
}
void Factory::Insert(const std::string& op_type, const func_type& func_name) {
map_.insert(std::make_pair(op_type, func_name));
}
} // namespace bridges
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <xtcl/xtcl.h>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/utils/macros.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
namespace bridges {
// xpu network builder and constant tensors
class graph_ctx_type {
public:
std::shared_ptr<xtcl::network::xNetworkBuilder> builder;
std::shared_ptr<xtcl::network::xTensorCompiler::ParamNDArrayMap> params;
};
// var_name, xpu node pointer
using node_map_type =
std::unordered_map<std::string, std::shared_ptr<xtcl::xExpr>>;
using func_type = std::function<node_map_type(
const std::shared_ptr<OpLite>, graph_ctx_type*, const node_map_type&)>;
using cvt_map_type = std::unordered_map<std::string, func_type>;
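// A bridge converter must match func_type; for example, the conv2d bridge is
// declared roughly as follows (a sketch, parameter names are illustrative):
//   node_map_type ConvConverter(const std::shared_ptr<OpLite> op,
//                               graph_ctx_type* graph_ctx,
//                               const node_map_type& input_nodes);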
class Factory {
public:
static Factory& Instance();
const cvt_map_type& AllFunctions() const { return map_; }
bool HasType(const std::string& op_type) const;
void Insert(const std::string& op_type, const func_type& func_name);
Factory() = default;
private:
cvt_map_type map_;
DISALLOW_COPY_AND_ASSIGN(Factory);
};
} // namespace bridges
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
// some platform-independent definitions
#if defined(_WIN32)
#define UNUSED
#define __builtin_expect(EXP, C) (EXP)
#else
#define UNUSED __attribute__((unused))
#endif
#define STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE(uniq_name, msg) \
struct __test_global_namespace_##uniq_name##__ {}; \
static_assert(std::is_same<::__test_global_namespace_##uniq_name##__, \
__test_global_namespace_##uniq_name##__>::value, \
msg)
#define REGISTER_XPU_BRIDGE(op_type, cvt_func_name) \
STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \
__reg_xpu_bridge_##op_type##__, \
"REGISTER_XPU_BRIDGE must be called in global namespace only once!"); \
int __reg_xpu_bridge_##op_type##_Insert() { \
paddle::lite::kernels::xpu::bridges::Factory::Instance().Insert( \
#op_type, cvt_func_name); \
return 0; \
}
#define USE_XPU_BRIDGE(op_type) \
extern int __reg_xpu_bridge_##op_type##_Insert(); \
static int __reg_xpu_bridge_##op_type##_Insert_return UNUSED = \
__reg_xpu_bridge_##op_type##_Insert();
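// Usage sketch: a converter registers itself once in its own translation unit
// and is pulled in wherever the bridge is needed, e.g. for conv2d:
//   REGISTER_XPU_BRIDGE(conv2d, paddle::lite::kernels::xpu::bridges::ConvConverter);
//   USE_XPU_BRIDGE(conv2d);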
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/xpu/bridges/test_helper.h"
#include <utility>
#include "lite/backends/xpu/builder.h"
#include "lite/core/op_registry.h"
#include "lite/kernels/xpu/bridges/registry.h"
#include "lite/operators/graph_op.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
namespace bridges {
void LauchOp(const std::shared_ptr<lite::OpLite> op,
const std::vector<std::string>& input_var_names,
const std::vector<std::string>& output_var_names) {
auto scope = op->scope();
auto op_type = op->op_info()->Type();
// convert lite op to XPU op
const auto& bridges = lite::kernels::xpu::bridges::Factory::Instance();
const auto& supported_lists = bridges.AllFunctions();
CHECK(bridges.HasType(op_type));
graph_ctx_type graph_ctx;
graph_ctx.builder = std::make_shared<xtcl::network::xNetworkBuilder>();
graph_ctx.params =
std::make_shared<xtcl::network::xTensorCompiler::ParamNDArrayMap>();
node_map_type input_nodes;
for (auto input_var_name : input_var_names) {
auto input = scope->FindVar(input_var_name)->GetMutable<lite::Tensor>();
auto input_node = std::make_shared<xtcl::xExpr>(
graph_ctx.builder->CreateTensor(input_var_name,
lite::xpu::CvtShape(input->dims()),
::xtcl::Float(32)));
input_nodes[input_var_name] = input_node;
}
auto output_nodes = supported_lists.at(op_type)(op, &graph_ctx, input_nodes);
CHECK_GT(output_nodes.size(), 0);
// build network graph and output model data
std::vector<std::shared_ptr<xtcl::xExpr>> ordered_output_nodes;
for (auto output_var_name : output_var_names) {
ordered_output_nodes.push_back(output_nodes.at(output_var_name));
}
std::string weight_var_name = "weight";
auto weight = scope->Var(weight_var_name)->GetMutable<Tensor>();
weight->set_persistable(true);
weight->set_precision(PRECISION(kInt8));
CHECK(lite::xpu::BuildModel(
graph_ctx.builder, graph_ctx.params, &ordered_output_nodes, weight));
CHECK_GT(weight->numel(), 0);
CHECK(weight->data<uint8_t>() != nullptr);
// create graph op and set inputs and outputs
cpp::OpDesc graph_op_desc;
graph_op_desc.SetType("graph_op");
graph_op_desc.SetInput("Inputs", input_var_names);
graph_op_desc.SetInput("Weight", {weight_var_name});
graph_op_desc.SetOutput("Outputs", output_var_names);
auto graph_op =
std::make_shared<operators::GraphOpLite>(graph_op_desc.Type());
graph_op->SetValidPlaces({Place{TARGET(kXPU), PRECISION(kFloat)}});
CHECK(graph_op->Attach(graph_op_desc, scope));
CHECK(graph_op->CheckShape());
CHECK(graph_op->InferShape());
// create graph op kernel and set XPU context
auto graph_kernels =
graph_op->CreateKernels({Place{TARGET(kXPU), PRECISION(kFloat)}});
CHECK(!graph_kernels.empty());
auto graph_kernel =
std::move(graph_kernels.front()); // use the first kernel by default
auto graph_device = ContextScheduler::Global().NewContext(TARGET(kXPU));
graph_kernel->SetContext(std::move(graph_device));
// run the graph op kernel and store the results to the output variables
graph_kernel->Launch();
lite::xpu::DeviceInfo::Global().Clear();
}
} // namespace bridges
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_OP(graph_op);
USE_LITE_KERNEL(graph_op, kXPU, kFloat, kNCHW, def);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <random>
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
namespace bridges {
template <typename T>
std::shared_ptr<T> CreateOp(const cpp::OpDesc& opdesc, lite::Scope* scope) {
auto op = std::make_shared<T>(opdesc.Type());
op->SetValidPlaces({Place{TARGET(kHost), PRECISION(kFloat)},
Place{TARGET(kARM), PRECISION(kFloat)},
Place{TARGET(kXPU), PRECISION(kFloat)}});
CHECK(op->Attach(opdesc, scope));
CHECK(op->CheckShape());
CHECK(op->InferShape());
return op;
}
// T is the target data type
// R is the range data type, e.g. int, half
template <typename T, typename R = float>
void FillTensor(Tensor* x,
T lower = static_cast<T>(-2),
T upper = static_cast<T>(2)) {
static unsigned int seed = 100;
std::mt19937 rng(seed++);
std::uniform_real_distribution<double> uniform_dist(0, 1);
T* x_data = x->mutable_data<T>();
for (int i = 0; i < x->dims().production(); ++i) {
auto r = uniform_dist(rng) * (upper - lower) + lower;
x_data[i] = static_cast<T>(static_cast<R>(r));
}
}
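// Converts the given op to an XPU graph through its registered bridge, builds
// and runs it on the XPU device, and writes the results back to the output
// variables.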
void LauchOp(const std::shared_ptr<lite::OpLite> op,
const std::vector<std::string>& input_var_names,
const std::vector<std::string>& output_var_names);
} // namespace bridges
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/xpu/graph_compute.h"
#include <sys/time.h>
#include <time.h>
#include <string>
#include <vector>
#include "lite/backends/xpu/runtime.h"
#include "lite/core/op_registry.h"
#include "lite/core/type_system.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
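// Restore the XPU runtime instance from the model data packed into the weight
// tensor when the graph op was built.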
void GraphCompute::PrepareForRun() {
// auto& ctx = this->ctx_->template As<XPUContext>();
auto& param = this->Param<param_t>();
CHECK(param.weight);
CHECK(lite::xpu::LoadModel(*param.weight, &runtime_));
CHECK(runtime_ != nullptr);
}
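// Copy each input tensor into an xNDArray, execute the XPU runtime, then copy
// the outputs back into the corresponding lite tensors.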
void GraphCompute::Run() {
auto& param = this->Param<param_t>();
auto GetCurrentUS = []() -> double {
struct timeval time;
gettimeofday(&time, NULL);
return 1e+6 * time.tv_sec + time.tv_usec;
};
auto start_time = GetCurrentUS();
for (int i = 0; i < param.inputs.size(); i++) {
auto input_var_name = param.inputs[i].first;
auto input_tensor = param.inputs[i].second;
LOG(INFO) << "input dims[" << i << ":" << input_var_name
<< "]: " << input_tensor->dims();
auto input_tensor_data = input_tensor->data<float>();
for (int j = 0; j < input_tensor->dims().production(); j++) {
VLOG(3) << input_tensor_data[j];
}
auto input_ndarray = xtcl::xNDArray::Empty(
input_tensor->dims().Vectorize(), {kDLFloat, 32, 1}, {kDLCPU, 0});
auto input_ndarray_data =
static_cast<float*>(input_ndarray.ToDLPack()->dl_tensor.data);
std::memcpy(input_ndarray_data,
input_tensor_data,
sizeof(float) * input_tensor->dims().production());
runtime_->SetInputZeroCopy(input_var_name,
&input_ndarray.ToDLPack()->dl_tensor);
}
runtime_->Run();
for (int i = 0; i < param.outputs.size(); i++) {
auto output_ndarray = runtime_->GetOutput(i);
auto output_var_name = param.outputs[i].first;
auto output_tensor = param.outputs[i].second;
output_tensor->Resize(output_ndarray.Shape());
LOG(INFO) << "output dims[" << i << ":" << output_var_name
<< "]: " << output_tensor->dims();
auto output_ndarray_data =
static_cast<float*>(output_ndarray.ToDLPack()->dl_tensor.data);
auto output_tensor_data = output_tensor->mutable_data<float>();
std::memcpy(output_tensor_data,
output_ndarray_data,
sizeof(float) * output_tensor->dims().production());
for (int j = 0; j < output_tensor->dims().production(); j++) {
VLOG(3) << output_tensor_data[j];
}
}
LOG(INFO) << "[XPU] Process cost " << GetCurrentUS() - start_time << " us";
}
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(graph_op,
kXPU,
kFloat,
kNCHW,
paddle::lite::kernels::xpu::GraphCompute,
def)
.BindInput("Inputs", {LiteType::GetTensorTy(TARGET(kHost))})
.BindInput("Weight", {LiteType::GetTensorTy(TARGET(kHost))})
.BindOutput("Outputs", {LiteType::GetTensorTy(TARGET(kHost))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <xtcl/xtcl.h>
#include <memory>
#include <string>
#include <vector>
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
#include "lite/core/types.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
class GraphCompute : public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
public:
using param_t = operators::GraphParam;
void PrepareForRun() override;
void Run() override;
virtual ~GraphCompute() = default;
private:
std::shared_ptr<xtcl::network::xRuntimeInstance> runtime_{nullptr};
};
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "lite/operators/graph_op.h"
#include <utility>
#include "lite/core/op_registry.h"
namespace paddle {
......@@ -34,7 +35,8 @@ bool GraphOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
for (auto var : inputs) {
CHECK(scope->FindVar(var));
param_.inputs.push_back(scope->FindVar(var)->GetMutable<lite::Tensor>());
param_.inputs.push_back(
std::make_pair(var, scope->FindVar(var)->GetMutable<lite::Tensor>()));
}
param_.weight = scope->FindVar(weight.front())->GetMutable<lite::Tensor>();
......@@ -42,7 +44,8 @@ bool GraphOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
for (auto var : outputs) {
CHECK(scope->FindVar(var));
param_.outputs.push_back(scope->FindVar(var)->GetMutable<lite::Tensor>());
param_.outputs.push_back(
std::make_pair(var, scope->FindVar(var)->GetMutable<lite::Tensor>()));
}
return true;
......
......@@ -14,6 +14,7 @@
#pragma once
#include <string>
#include <utility>
#include <vector>
#include "lite/api/paddle_place.h"
#include "lite/core/scope.h"
......@@ -69,9 +70,9 @@ struct CalibParam {
};
struct GraphParam {
std::vector<const lite::Tensor*> inputs{};
std::vector<std::pair<std::string, const lite::Tensor*>> inputs{};
lite::Tensor* weight{};
std::vector<lite::Tensor*> outputs{};
std::vector<std::pair<std::string, lite::Tensor*>> outputs{};
};
/// -------------------------- NN operators ------------------------------------
......
if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA) AND (LITE_WITH_X86 OR LITE_WITH_ARM))
if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA AND NOT LITE_WITH_XPU) AND (LITE_WITH_X86 OR LITE_WITH_ARM))
lite_cc_test(test_kernel_scale_compute SRCS scale_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_power_compute SRCS power_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_shuffle_channel_compute SRCS shuffle_channel_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
......
#!/bin/bash
set -ex
# global variables with default values
XPU_SDK_ROOT="$(pwd)/../XPU_SDK" # XPU SDK
TARGET_NAME="lite_compile_deps" # default target
BUILD_EXTRA=ON # ON(with sequence ops)/OFF
WITH_TESTING=ON # ON/OFF
function print_usage {
echo -e "\nUSAGE:"
echo
echo "----------------------------------------"
echo -e "--xpu_sdk_root=<xpu sdk directory>"
echo -e "--target_name=<target name>"
echo "----------------------------------------"
echo
}
# readonly variables with default values
readonly CMAKE_COMMON_OPTIONS="-DWITH_LITE=ON \
-DLITE_WITH_LIGHT_WEIGHT_FRAMEWORK=OFF \
-DWITH_PYTHON=OFF \
-DLITE_WITH_ARM=OFF"
readonly NUM_CORES_FOR_COMPILE=${LITE_BUILD_THREADS:-1}
readonly THIRDPARTY_TAR=https://paddle-inference-dist.bj.bcebos.com/PaddleLite/third-party-05b862.tar.gz
readonly workspace=$(pwd)
function prepare_thirdparty {
if [ ! -d $workspace/third-party -o -f $workspace/third-party-05b862.tar.gz ]; then
rm -rf $workspace/third-party
if [ ! -f $workspace/third-party-05b862.tar.gz ]; then
wget $THIRDPARTY_TAR
fi
tar xzf third-party-05b862.tar.gz
else
git submodule update --init --recursive
fi
}
# For code gen, a source file is generated after a test, but is depended on by some targets in cmake.
# Here we fake an empty file to make cmake work.
function prepare_workspace {
# in build directory
# 1. Prepare gen_code file
GEN_CODE_PATH_PREFIX=lite/gen_code
mkdir -p ./${GEN_CODE_PATH_PREFIX}
touch ./${GEN_CODE_PATH_PREFIX}/__generated_code__.cc
# 2. Prepare debug tool
DEBUG_TOOL_PATH_PREFIX=lite/tools/debug
mkdir -p ./${DEBUG_TOOL_PATH_PREFIX}
cp ../${DEBUG_TOOL_PATH_PREFIX}/analysis_tool.py ./${DEBUG_TOOL_PATH_PREFIX}/
# clone submodule
# git submodule update --init --recursive
prepare_thirdparty
}
function build_xpu {
build_dir=${workspace}/build.lite.xpu
mkdir -p $build_dir
cd $build_dir
export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:$PWD/third_party/install/mklml/lib"
prepare_workspace
cmake .. \
${CMAKE_COMMON_OPTIONS} \
-DWITH_GPU=OFF \
-DWITH_MKLDNN=OFF \
-DLITE_WITH_X86=ON \
-DWITH_MKL=ON \
-DLITE_BUILD_EXTRA=ON \
-DLITE_WITH_XPU=ON \
-DWITH_TESTING=${WITH_TESTING} \
-DXPU_SDK_ROOT=${XPU_SDK_ROOT}
make $TARGET_NAME -j$NUM_CORES_FOR_COMPILE
cd -
echo "Done"
}
function main {
# Parse command line.
for i in "$@"; do
case $i in
--target_name=*)
TARGET_NAME="${i#*=}"
shift
;;
--build_extra=*)
BUILD_EXTRA="${i#*=}"
shift
;;
--xpu_sdk_root=*)
XPU_SDK_ROOT="${i#*=}"
shift
;;
build)
build_xpu
shift
;;
*)
# unknown option
print_usage
exit 1
;;
esac
done
}
main $@
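# Example invocation (the script name and SDK path are illustrative):
#   ./build_xpu.sh --xpu_sdk_root=$(pwd)/../XPU_SDK --target_name=lite_compile_deps build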
......@@ -248,6 +248,63 @@ function build_test_train {
}
function cmake_xpu {
export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:$PWD/third_party/install/mklml/lib"
prepare_workspace
cmake .. \
${common_flags} \
-DWITH_GPU=OFF \
-DWITH_MKLDNN=OFF \
-DLITE_WITH_X86=ON \
-DWITH_MKL=ON \
-DLITE_BUILD_EXTRA=ON \
-DLITE_WITH_XPU=ON \
-DXPU_SDK_ROOT="$(pwd)/../../XPU_SDK"
}
function build_xpu {
make lite_compile_deps -j$NUM_CORES_FOR_COMPILE
}
# It will eagerly run all lite-related unit tests.
function test_xpu {
# Due to the missing XPU kernels, we skip the following tests temporarily.
# TODO(xxx) clear the skip list later
local skip_list=("test_paddle_api" "test_cxx_api" "test_googlenet"
"test_mobilenetv1_lite_x86" "test_mobilenetv2_lite_x86"
"test_inceptionv4_lite_x86" "test_light_api"
"test_apis" "test_model_bin"
)
local to_skip=0
for _test in $(cat $TESTS_FILE); do
to_skip=0
for skip_name in ${skip_list[@]}; do
if [ $skip_name = $_test ]; then
echo "to skip " $skip_name
to_skip=1
fi
done
if [ $to_skip -eq 0 ]; then
ctest -R $_test -V
fi
done
}
# Build the code and run lite XPU tests. This is executed in the CI system.
function build_test_xpu {
cur_dir=$(pwd)
build_dir=$cur_dir/build.lite.xpu
mkdir -p $build_dir
cd $build_dir
cmake_xpu
build_xpu
test_xpu
}
# test_arm_android <some_test_name> <adb_port_number>
function test_arm_android {
local test_name=$1
......@@ -850,6 +907,10 @@ function main {
cmake_x86
shift
;;
cmake_xpu)
cmake_xpu
shift
;;
cmake_opencl)
cmake_opencl $ARM_OS $ARM_ABI $ARM_LANG
shift
......@@ -874,6 +935,10 @@ function main {
test_server
shift
;;
test_xpu)
test_xpu
shift
;;
test_arm)
test_arm $ARM_OS $ARM_ABI $ARM_LANG $ARM_PORT
shift
......@@ -890,6 +955,10 @@ function main {
build_test_server
shift
;;
build_test_xpu)
build_test_xpu
shift
;;
build_test_train)
build_test_train
shift
......
......@@ -13,6 +13,7 @@ if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK OR LITE_ON_MODEL_OPTIMIZE_TOOL)
X86_DEPS ${x86_kernels}
ARM_DEPS ${arm_kernels}
NPU_DEPS ${npu_kernels}
XPU_DEPS ${xpu_kernels}
FPGA_DEPS ${fpga_kernels}
CL_DEPS ${opencl_kernels})
endif()