diff --git a/cmake/external/lite.cmake b/cmake/external/lite.cmake index 9d7c21108601491744e81b15b48ec0cd31d9bf1d..978b0427125be940852426380c9452dc87f97471 100644 --- a/cmake/external/lite.cmake +++ b/cmake/external/lite.cmake @@ -25,7 +25,7 @@ if (NOT LITE_SOURCE_DIR OR NOT LITE_BINARY_DIR) set(LITE_INSTALL_DIR ${THIRD_PARTY_PATH}/install/lite) if(NOT LITE_GIT_TAG) - set(LITE_GIT_TAG ab8af5c4b4dc5b40217633e0aa436315912d7b53) + set(LITE_GIT_TAG 42ab4d559f6659edfc35040fb30fdcec3dc3f8aa) endif() if(NOT CUDA_ARCH_NAME) @@ -83,7 +83,7 @@ message(STATUS "Paddle-lite SOURCE_DIR: ${LITE_SOURCE_DIR}") include_directories(${LITE_SOURCE_DIR}) include_directories(${LITE_BINARY_DIR}) -function(external_lite_static_libs alias path) +function(external_lite_libs alias path) add_library(${alias} SHARED IMPORTED GLOBAL) SET_PROPERTY(TARGET ${alias} PROPERTY IMPORTED_LOCATION ${path}) @@ -92,8 +92,16 @@ function(external_lite_static_libs alias path) endif() endfunction() -external_lite_static_libs(lite_full_static ${LITE_BINARY_DIR}/inference_lite_lib/cxx/lib/libpaddle_full_api_shared.so) +external_lite_libs(lite_full_static ${LITE_BINARY_DIR}/inference_lite_lib/cxx/lib/libpaddle_full_api_shared.so) set(LITE_SHARED_LIB ${LITE_BINARY_DIR}/inference_lite_lib/cxx/lib/libpaddle_full_api_shared.so) +if(XPU_SDK_ROOT) + include_directories("${XPU_SDK_ROOT}/XTDK/include") + include_directories("${XPU_SDK_ROOT}/XTCL/include") + add_definitions(-DPADDLE_WITH_XPU) + LINK_DIRECTORIES("${XPU_SDK_ROOT}/XTDK/shlib/") + LINK_DIRECTORIES("${XPU_SDK_ROOT}/XTDK/runtime/shlib/") +endif() + add_definitions(-DPADDLE_WITH_LITE) add_definitions(-DLITE_WITH_LOG) diff --git a/paddle/fluid/framework/tensor.cc b/paddle/fluid/framework/tensor.cc index df1e0fb6d5b48e0670b0bebb128578c467d19467..544c014eaf98a99b1737809f2cbad39b46fdb276 100644 --- a/paddle/fluid/framework/tensor.cc +++ b/paddle/fluid/framework/tensor.cc @@ -108,8 +108,15 @@ const DDim& Tensor::dims() const { return dims_; } int64_t Tensor::numel() const { return product(dims_); } void Tensor::ResetHolder(std::shared_ptr holder) { + PADDLE_ENFORCE_EQ( + offset_, 0, + platform::errors::Fatal( + "Only the offset is supported to zero when the holder is reset.")); if (holder_) { - PADDLE_ENFORCE_EQ(numel() * SizeOfType(type()), holder->size()); + PADDLE_ENFORCE_LE( + numel() * SizeOfType(type()) + offset_, holder->size(), + paddle::platform::errors::InvalidArgument( + "The size of Holder is not enough to store the Tensor.")); } holder_ = holder; } diff --git a/paddle/fluid/inference/analysis/argument.h b/paddle/fluid/inference/analysis/argument.h index 2fc7f81bf8a59ca6dba3db36dfe7a9c074f03f9b..27bae7a71ea192ac08e4e87cb7bcdb8b84e29dc8 100644 --- a/paddle/fluid/inference/analysis/argument.h +++ b/paddle/fluid/inference/analysis/argument.h @@ -200,6 +200,10 @@ struct Argument { DECL_ARGUMENT_FIELD(lite_ops_filter, LiteOpsFilter, std::vector); DECL_ARGUMENT_FIELD(lite_precision_mode, LitePrecisionMode, AnalysisConfig::Precision); + DECL_ARGUMENT_FIELD(lite_zero_copy, LiteZeroCopy, bool); + + DECL_ARGUMENT_FIELD(use_xpu, UseXpu, bool); + DECL_ARGUMENT_FIELD(xpu_l3_workspace_size, XpuL3WorkspaceSize, int); // Memory optimized related. DECL_ARGUMENT_FIELD(enable_memory_optim, EnableMemoryOptim, bool); diff --git a/paddle/fluid/inference/analysis/ir_pass_manager.cc b/paddle/fluid/inference/analysis/ir_pass_manager.cc index 4a79a3cf3050380c920590355f10bb7a0d34f125..cd8d86d72938417112e17e86e5cc6dd12254a8d1 100644 --- a/paddle/fluid/inference/analysis/ir_pass_manager.cc +++ b/paddle/fluid/inference/analysis/ir_pass_manager.cc @@ -146,6 +146,10 @@ void IRPassManager::CreatePasses(Argument *argument, pass->Set("predictor_id", new int(argument->predictor_id())); pass->Set("enable_int8", new bool(enable_int8)); pass->Set("use_gpu", new bool(argument->use_gpu())); + pass->Set("zero_copy", new bool(argument->lite_zero_copy())); + pass->Set("use_xpu", new bool(argument->use_xpu())); + pass->Set("xpu_l3_workspace_size", + new int(argument->xpu_l3_workspace_size())); } disable_logs_ = argument->disable_logs(); if (pass_name == "fc_fuse_pass") { diff --git a/paddle/fluid/inference/analysis/ir_passes/lite_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/lite_subgraph_pass.cc index 91d0aec3f41fd90159958aa9035cfbf4d1c749fb..6b16a481ddedbad0956d1358de95842ea9a3a101 100644 --- a/paddle/fluid/inference/analysis/ir_passes/lite_subgraph_pass.cc +++ b/paddle/fluid/inference/analysis/ir_passes/lite_subgraph_pass.cc @@ -242,16 +242,33 @@ void LiteSubgraphPass::SetUpEngine( bool use_gpu = Get("use_gpu"); bool enable_int8 = Get("enable_int8"); - lite_api::TargetType target_type = use_gpu ? TARGET(kCUDA) : TARGET(kX86); + bool use_xpu = Get("use_xpu"); + int xpu_l3_workspace_size = Get("xpu_l3_workspace_size"); + + lite_api::TargetType target_type; + if (use_gpu) { + target_type = TARGET(kCUDA); + } else if (use_xpu) { + target_type = TARGET(kXPU); + } else { + target_type = TARGET(kX86); + } + paddle::lite_api::PrecisionType precision_type = - enable_int8 ? PRECISION(kInt8) : PRECISION(kInt64); + enable_int8 ? PRECISION(kInt8) : PRECISION(kFloat); + serialize_params(&config.param, scope, repetitive_params); config.model = program->Proto()->SerializeAsString(); config.valid_places = { + // Notice: The ordering here determines the device where the + // input tensor of the Lite engine is located, and then affects + // whether tensor sharing is feasible. paddle::lite::Place({target_type, precision_type}), + paddle::lite::Place({target_type, PRECISION(kInt64)}), paddle::lite::Place({target_type, PRECISION(kFloat)}), paddle::lite::Place({TARGET(kHost), PRECISION(kFloat)}), }; + config.xpu_l3_workspace_size = xpu_l3_workspace_size; if (dump_model) { lite::StrToBinaryFile("./model.bin", config.model); lite::StrToBinaryFile("./param.bin", config.param); @@ -283,6 +300,7 @@ void LiteSubgraphPass::BuildOperator( op_desc->SetAttr("engine_key", unique_key); op_desc->SetAttr("enable_int8", Get("enable_int8")); op_desc->SetAttr("use_gpu", Get("use_gpu")); + op_desc->SetAttr("zero_copy", Get("zero_copy")); } void LiteSubgraphPass::ApplyImpl(framework::ir::Graph* graph) const { diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc index 994f7c95352631b657edc3709f8f141cd68b3660..61886c225e6548413e6e2eb0415f596d016a988f 100644 --- a/paddle/fluid/inference/api/analysis_config.cc +++ b/paddle/fluid/inference/api/analysis_config.cc @@ -88,6 +88,12 @@ void AnalysisConfig::DisableFCPadding() { Update(); } +void AnalysisConfig::EnableXpu(int l3_workspace_size) { + use_xpu_ = true; + xpu_l3_workspace_size_ = l3_workspace_size; + Update(); +} + AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) { #define CP_MEMBER(member__) member__ = other.member__; @@ -132,6 +138,10 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) { CP_MEMBER(lite_precision_mode_); CP_MEMBER(lite_passes_filter_); CP_MEMBER(lite_ops_filter_); + CP_MEMBER(lite_zero_copy_); + + CP_MEMBER(use_xpu_); + CP_MEMBER(xpu_l3_workspace_size_); // profile related. CP_MEMBER(with_profile_); @@ -344,6 +354,22 @@ void AnalysisConfig::Update() { } } + if (use_xpu_) { +#ifndef PADDLE_WITH_XPU + PADDLE_THROW(platform::errors::Unavailable( + "You tried to use an XPU device, but Paddle was not compiled " + "with XPU-runtime.")); +#endif + if (!use_lite_) { + LOG(WARNING) << "Because XPU currently only works in Paddle-Lite " + "subgraph mode, please make sure you have enabled it."; + } + PADDLE_ENFORCE_EQ(use_gpu_, false, + platform::errors::Unavailable( + "Currently, XPU and GPU cannot be enabled in the " + "same analysis configuration.")); + } + if (ir_debug_) { pass_builder()->TurnOnDebug(); } @@ -387,6 +413,8 @@ std::string AnalysisConfig::SerializeInfoCache() { ss << cpu_math_library_num_threads_; ss << use_lite_; + ss << use_xpu_; + ss << xpu_l3_workspace_size_; ss << thread_local_stream_; @@ -464,13 +492,14 @@ void AnalysisConfig::DisableGlogInfo() { } void AnalysisConfig::EnableLiteEngine( - AnalysisConfig::Precision precision_mode, + AnalysisConfig::Precision precision_mode, bool zero_copy, const std::vector &passes_filter, const std::vector &ops_filter) { use_lite_ = true; lite_precision_mode_ = precision_mode; lite_passes_filter_ = passes_filter; lite_ops_filter_ = ops_filter; + lite_zero_copy_ = zero_copy; Update(); } diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 195053814e6a00094dd7cf59d2c1331d96e3d634..a8c8058c6b714dcd6f283c35b50bef55446e62bb 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -465,6 +465,9 @@ void AnalysisPredictor::PrepareArgument() { argument_.SetLitePrecisionMode(config_.lite_precision_mode_); argument_.SetLitePassesFilter(config_.lite_passes_filter_); argument_.SetLiteOpsFilter(config_.lite_ops_filter_); + argument_.SetLiteZeroCopy(config_.lite_zero_copy_); + argument_.SetUseXpu(config_.use_xpu_); + argument_.SetXpuL3WorkspaceSize(config_.xpu_l3_workspace_size_); LOG(INFO) << "Lite subgraph engine is enabled"; } diff --git a/paddle/fluid/inference/api/paddle_analysis_config.h b/paddle/fluid/inference/api/paddle_analysis_config.h index ec7b08b306707484619d126d4983633aeec9b601..6a31ff281c68e3675d35c14059a453455ef398df 100644 --- a/paddle/fluid/inference/api/paddle_analysis_config.h +++ b/paddle/fluid/inference/api/paddle_analysis_config.h @@ -176,6 +176,8 @@ struct PD_INFER_DECL AnalysisConfig { /// /// void DisableGpu(); + + void EnableXpu(int l3_workspace_size = 0xfffc00); /// /// \brief A boolean state telling whether the GPU is turned on. /// @@ -319,6 +321,7 @@ struct PD_INFER_DECL AnalysisConfig { /// void EnableLiteEngine( AnalysisConfig::Precision precision_mode = Precision::kFloat32, + bool zero_copy = false, const std::vector& passes_filter = {}, const std::vector& ops_filter = {}); @@ -579,8 +582,11 @@ struct PD_INFER_DECL AnalysisConfig { std::vector lite_passes_filter_; std::vector lite_ops_filter_; Precision lite_precision_mode_; + bool lite_zero_copy_; bool thread_local_stream_{false}; + bool use_xpu_{false}; + int xpu_l3_workspace_size_; // mkldnn related. int mkldnn_cache_capacity_{0}; diff --git a/paddle/fluid/inference/lite/CMakeLists.txt b/paddle/fluid/inference/lite/CMakeLists.txt index decf2d830fe0a77c69c80c071248b04c7fdb2f9d..fd513b59588f82716900d4d48e9aac036085baa9 100644 --- a/paddle/fluid/inference/lite/CMakeLists.txt +++ b/paddle/fluid/inference/lite/CMakeLists.txt @@ -1,6 +1,9 @@ +if(XPU_SDK_ROOT) + set(XPU_DEPS xpuapi xpurt) +endif() + cc_library(lite_op_teller SRCS op_teller.cc DEPS lite_full_static framework_proto device_context boost xxhash) -cc_library(lite_engine SRCS engine.cc DEPS lite_full_static framework_proto) +cc_library(lite_engine SRCS engine.cc DEPS lite_full_static framework_proto ${XPU_DEPS}) cc_library(lite_tensor_utils SRCS tensor_utils.cc DEPS memcpy lite_full_static framework_proto boost device_context) - cc_test(test_lite_engine SRCS test_engine.cc DEPS lite_engine protobuf framework_proto glog gtest analysis) cc_test(test_lite_tensor_utils SRCS test_tensor_utils.cc DEPS lite_engine lite_tensor_utils) diff --git a/paddle/fluid/inference/lite/engine.cc b/paddle/fluid/inference/lite/engine.cc index fb3b6e460d5bb23133de1d6a8a106530043cd99a..8e88c94493952ff257ef69bf73f8edebb6ba2eee 100644 --- a/paddle/fluid/inference/lite/engine.cc +++ b/paddle/fluid/inference/lite/engine.cc @@ -16,8 +16,11 @@ #define LITE_WITH_CUDA 1 #endif -#include "paddle/fluid/inference/lite/engine.h" +#ifdef PADDLE_WITH_XPU +#define LITE_WITH_XPU 1 +#endif +#include "paddle/fluid/inference/lite/engine.h" #include "lite/api/paddle_use_passes.h" namespace paddle { @@ -39,10 +42,17 @@ paddle::lite::Predictor* EngineManager::Get(const std::string& name) const { paddle::lite::Predictor* EngineManager::Create(const std::string& name, const EngineConfig& cfg) { - auto* p = new paddle::lite::Predictor(); + if (cfg.valid_places.front().target == TARGET(kCUDA)) { #ifdef PADDLE_WITH_CUDA - paddle::lite::Env::Init(); + paddle::lite::Env::Init(); #endif + } else if (cfg.valid_places.front().target == TARGET(kXPU)) { +#ifdef PADDLE_WITH_XPU + paddle::lite::TargetWrapper::workspace_l3_size_per_thread = + cfg.xpu_l3_workspace_size; +#endif + } + auto* p = new paddle::lite::Predictor(); p->Build("", cfg.model, cfg.param, cfg.valid_places, cfg.neglected_passes, cfg.model_type, cfg.model_from_memory); engines_[name].reset(p); diff --git a/paddle/fluid/inference/lite/engine.h b/paddle/fluid/inference/lite/engine.h index 5f11c51952bd3ce0bb0e09121dbd5e633c6fd3ae..345eb682e9fe81d4ec67a31082c1d347a694fd96 100644 --- a/paddle/fluid/inference/lite/engine.h +++ b/paddle/fluid/inference/lite/engine.h @@ -26,6 +26,7 @@ #include "lite/api/paddle_place.h" #include "lite/core/context.h" #include "lite/core/device_info.h" +#include "lite/core/memory.h" #include "lite/core/op_registry.h" #include "lite/core/tensor.h" #pragma GCC diagnostic pop @@ -42,6 +43,7 @@ struct EngineConfig { std::vector neglected_passes; lite_api::LiteModelType model_type{lite_api::LiteModelType::kProtobuf}; bool model_from_memory{true}; + size_t xpu_l3_workspace_size; }; class EngineManager { diff --git a/paddle/fluid/inference/lite/tensor_utils.cc b/paddle/fluid/inference/lite/tensor_utils.cc index 59087c6fec20360ef4a8f8a34aa810c3328d6e0d..d79a041ccf8a1611247b65b034c03940eabfcccd 100644 --- a/paddle/fluid/inference/lite/tensor_utils.cc +++ b/paddle/fluid/inference/lite/tensor_utils.cc @@ -14,8 +14,10 @@ #include "paddle/fluid/inference/lite/tensor_utils.h" #include +#include #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/inference/lite/engine.h" +#include "paddle/fluid/memory/allocation/allocator.h" namespace paddle { namespace inference { @@ -46,6 +48,9 @@ platform::Place GetNativePlace(const TargetType& type, int id = 0) { return platform::CPUPlace(); case TargetType::kCUDA: return platform::CUDAPlace(id); + case TargetType::kXPU: + LOG(ERROR) << "No corresponding device for XPU yet."; + return platform::Place(); default: PADDLE_THROW( platform::errors::Unavailable("Unsupported target type. Now only " @@ -191,6 +196,31 @@ void TensorCopyAsync(framework::LoDTensor* dst, const paddle::lite::Tensor& src, VLOG(3) << "[Lite memory size] Bytes = " << src.memory_size(); } +template <> +void TensorDataShare(paddle::lite::Tensor* dst, framework::LoDTensor* src) { + const size_t bytes = + static_cast(src->numel()) * framework::SizeOfType(src->type()); + auto buf = std::make_shared(paddle::lite::Buffer( + src->data(), GetLiteTargetType(src->place()), src->memory_size())); + dst->Resize(framework::vectorize(src->dims())); + dst->set_precision(GetLitePrecisionType(src->type())); + SetLoD(dst->mutable_lod(), src->lod()); + dst->ResetBuffer(buf, bytes); +} + +template <> +void TensorDataShare(framework::LoDTensor* dst, paddle::lite::Tensor* src) { + constexpr framework::proto::VarType::Type dtype = + framework::proto::VarType_Type_FP32; + void* src_raw_data = src->raw_data(); + std::shared_ptr holder( + new memory::allocation::Allocation(src_raw_data, src->memory_size(), + GetNativePlace(src->target()))); + dst->Resize(paddle::framework::make_ddim(src->dims().Vectorize())); + SetLoD(dst->mutable_lod(), src->lod()); + dst->ResetHolderWithType(holder, dtype); +} + } // namespace utils } // namespace lite } // namespace inference diff --git a/paddle/fluid/inference/lite/tensor_utils.h b/paddle/fluid/inference/lite/tensor_utils.h index 21c5e794d4195f8dcd040dbf2a59ed87d170cb6d..1b2923bc28033934f5304a48c6a90f158a81a12e 100644 --- a/paddle/fluid/inference/lite/tensor_utils.h +++ b/paddle/fluid/inference/lite/tensor_utils.h @@ -26,6 +26,21 @@ template void TensorCopyAsync(DstTensor* dst, const SrcTensor& src, const platform::DeviceContext& ctx); +template +void TensorDataShare(DstTensor* dst, SrcTensor* src); + +template +void TensorCopy(DstTensor* dst, SrcTensor* src, + const platform::DeviceContext& ctx, bool shared = true) { + if (shared) { + VLOG(3) << "TensorDataShare is running"; + TensorDataShare(dst, src); + } else { + VLOG(3) << "TensorCopyAsync is running"; + TensorCopyAsync(dst, *src, ctx); + } +} + } // namespace utils } // namespace lite } // namespace inference diff --git a/paddle/fluid/inference/lite/test_tensor_utils.cc b/paddle/fluid/inference/lite/test_tensor_utils.cc index eee00e9ba31a6de8688dfb27dd56031e9da4353f..eef7bfb68fe06537d09f3f3e7e5c35283d4739ef 100644 --- a/paddle/fluid/inference/lite/test_tensor_utils.cc +++ b/paddle/fluid/inference/lite/test_tensor_utils.cc @@ -77,7 +77,7 @@ void test_tensor_copy(const platform::DeviceContext& ctx) { // Create LoDTensor. std::vector vector({1, 2, 3, 4}); framework::LoDTensor lod_tensor; - framework::TensorFromVector(vector, &lod_tensor); + framework::TensorFromVector(vector, ctx, &lod_tensor); framework::LoD lod({{0, 2, 4}}); lod_tensor.Resize({4, 1}); lod_tensor.set_lod(lod); @@ -94,7 +94,26 @@ void test_tensor_copy(const platform::DeviceContext& ctx) { } #endif std::vector result; - TensorToVector(lod_tensor_n, &result); + TensorToVector(lod_tensor_n, ctx, &result); + ASSERT_EQ(result, vector); + ASSERT_EQ(lod_tensor_n.lod(), lod_tensor.lod()); +} + +void test_tensor_share(const platform::DeviceContext& ctx) { + std::vector vector({1, 2, 3, 4}); + framework::LoDTensor lod_tensor; + framework::TensorFromVector(vector, ctx, &lod_tensor); + framework::LoD lod({{0, 2, 4}}); + lod_tensor.Resize({4, 1}); + lod_tensor.set_lod(lod); + // Create lite::Tensor and share. + paddle::lite::Tensor lite_tensor; + TensorDataShare(&lite_tensor, &lod_tensor); + // Copy to LoDTensor. + framework::LoDTensor lod_tensor_n; + TensorCopyAsync(&lod_tensor_n, lite_tensor, ctx); + std::vector result; + TensorToVector(lod_tensor_n, ctx, &result); ASSERT_EQ(result, vector); ASSERT_EQ(lod_tensor_n.lod(), lod_tensor.lod()); } @@ -110,6 +129,17 @@ TEST(LiteEngineOp, TensorCopyAsync) { #endif } +TEST(LiteEngineOp, TensorShare) { + auto* ctx_cpu = + platform::DeviceContextPool::Instance().Get(platform::CPUPlace()); + test_tensor_share(*ctx_cpu); +#ifdef PADDLE_WITH_CUDA + auto* ctx_gpu = + platform::DeviceContextPool::Instance().Get(platform::CUDAPlace(0)); + test_tensor_share(*ctx_gpu); +#endif +} + } // namespace utils } // namespace lite } // namespace inference diff --git a/paddle/fluid/operators/lite/lite_engine_op.h b/paddle/fluid/operators/lite/lite_engine_op.h index 3b48615338f729a56db133a2072ceea5e8e94b22..a920bf7c3f505b839f8f1fd252c9f8505393f3a9 100644 --- a/paddle/fluid/operators/lite/lite_engine_op.h +++ b/paddle/fluid/operators/lite/lite_engine_op.h @@ -42,6 +42,7 @@ class LiteEngineOp : public framework::OperatorBase { paddle::lite::Predictor *engine_; framework::proto::VarType::Type precision_; bool use_gpu_; + bool zero_copy_; public: LiteEngineOp(const std::string &type, @@ -60,6 +61,7 @@ class LiteEngineOp : public framework::OperatorBase { precision_ = framework::proto::VarType_Type_FP32; } use_gpu_ = Attr("use_gpu"); + zero_copy_ = Attr("zero_copy"); } protected: @@ -73,13 +75,13 @@ class LiteEngineOp : public framework::OperatorBase { const platform::DeviceContext *ctx = platform::DeviceContextPool::Instance().Get(dev_place); for (size_t i = 0; i < in_names_.size(); i++) { - const framework::LoDTensor &src_t = + framework::LoDTensor src_t = inference::analysis::GetFromScope(scope, in_names_[i]); paddle::lite::Tensor *dst_t = engine_->GetInput(i); - VLOG(3) << "[Copy] fluid -> lite (" << in_names_[i] << " -> " + VLOG(3) << "== fluid -> lite (" << in_names_[i] << " -> " << engine_->GetInputNames()[i] << ")"; - inference::lite::utils::TensorCopyAsync(dst_t, src_t, *ctx); + inference::lite::utils::TensorCopy(dst_t, &src_t, *ctx, zero_copy_); } #ifdef PADDLE_WITH_CUDA if (platform::is_gpu_place(dev_place)) { @@ -91,13 +93,13 @@ class LiteEngineOp : public framework::OperatorBase { engine_->Run(); VLOG(3) << "lite engine run done"; for (size_t i = 0; i < out_names_.size(); i++) { - const paddle::lite::Tensor &src_t = *(engine_->GetOutput(i)); + paddle::lite::Tensor src_t = *(engine_->GetOutput(i)); framework::LoDTensor *dst_t = &inference::analysis::GetFromScope( scope, out_names_[i]); - VLOG(3) << "[Copy] lite -> fluid (" << out_names_[i] << " -> " + VLOG(3) << "== lite -> fluid (" << out_names_[i] << " -> " << engine_->GetOutputNames()[i] << ")"; - inference::lite::utils::TensorCopyAsync(dst_t, src_t, *ctx); + inference::lite::utils::TensorCopy(dst_t, &src_t, *ctx, zero_copy_); } #ifdef PADDLE_WITH_CUDA if (platform::is_gpu_place(dev_place)) { diff --git a/paddle/fluid/operators/lite/lite_engine_op_test.cc b/paddle/fluid/operators/lite/lite_engine_op_test.cc index 3812911e915bc8ad03fd6f1c4ecaeda69b33971b..fb5c0dcb3514de815b97944d0fdbf3bd7853b628 100644 --- a/paddle/fluid/operators/lite/lite_engine_op_test.cc +++ b/paddle/fluid/operators/lite/lite_engine_op_test.cc @@ -100,6 +100,7 @@ TEST(LiteEngineOp, engine_op) { engine_op_desc.SetAttr("engine_key", engine_key); engine_op_desc.SetAttr("enable_int8", false); engine_op_desc.SetAttr("use_gpu", true); + engine_op_desc.SetAttr("zero_copy", true); engine_op_desc.SetBlockAttr("sub_block", &block_desc); inference::Singleton::Global().Create( engine_key, config); diff --git a/paddle/fluid/pybind/inference_api.cc b/paddle/fluid/pybind/inference_api.cc index 5a0b18a34f768f3fb4392abf1d796feb951990c3..ed6e18699d4cbd9fccda4583fbeb51e5e68264ef 100644 --- a/paddle/fluid/pybind/inference_api.cc +++ b/paddle/fluid/pybind/inference_api.cc @@ -433,6 +433,7 @@ void BindAnalysisConfig(py::module *m) { py::arg("disable_trt_plugin_fp16") = false) .def("tensorrt_engine_enabled", &AnalysisConfig::tensorrt_engine_enabled) .def("enable_lite_engine", &AnalysisConfig::EnableLiteEngine, + py::arg("zero_copy") = false, py::arg("precision_mode") = AnalysisConfig::Precision::kFloat32, py::arg("passes_filter") = std::vector(), py::arg("ops_filter") = std::vector())