Unverified commit 878b6972, authored by Wilber, committed via GitHub

make lite subgraph support multiple tensor precision. (#30055)

上级 477b0c46
...@@ -34,7 +34,7 @@ if (NOT LITE_SOURCE_DIR OR NOT LITE_BINARY_DIR) ...@@ -34,7 +34,7 @@ if (NOT LITE_SOURCE_DIR OR NOT LITE_BINARY_DIR)
set(LITE_INSTALL_DIR ${THIRD_PARTY_PATH}/install/lite) set(LITE_INSTALL_DIR ${THIRD_PARTY_PATH}/install/lite)
if(NOT LITE_GIT_TAG) if(NOT LITE_GIT_TAG)
set(LITE_GIT_TAG 345545e2ce2f3895a332be88d5c3d495d9b206d3) set(LITE_GIT_TAG d3a3a6931b6d22d504d21ba32b3ae972770e9204)
endif() endif()
if(NOT CUDA_ARCH_NAME) if(NOT CUDA_ARCH_NAME)
......
...@@ -272,6 +272,8 @@ void LiteSubgraphPass::SetUpEngine( ...@@ -272,6 +272,8 @@ void LiteSubgraphPass::SetUpEngine(
paddle::lite_api::Place({target_type, PRECISION(kInt64)}), paddle::lite_api::Place({target_type, PRECISION(kInt64)}),
paddle::lite_api::Place({target_type, PRECISION(kFloat)}), paddle::lite_api::Place({target_type, PRECISION(kFloat)}),
paddle::lite_api::Place({TARGET(kHost), PRECISION(kFloat)}), paddle::lite_api::Place({TARGET(kHost), PRECISION(kFloat)}),
paddle::lite_api::Place({TARGET(kX86), precision_type}),
paddle::lite_api::Place({TARGET(kX86), PRECISION(kFloat)}),
}; };
config.cpu_math_library_num_threads = cpu_math_library_num_threads; config.cpu_math_library_num_threads = cpu_math_library_num_threads;
config.xpu_l3_workspace_size = xpu_l3_workspace_size; config.xpu_l3_workspace_size = xpu_l3_workspace_size;
......
...@@ -195,10 +195,8 @@ void InitDstTensor(paddle::lite_api::Tensor* dst, ...@@ -195,10 +195,8 @@ void InitDstTensor(paddle::lite_api::Tensor* dst,
void InitDstTensor(framework::LoDTensor* dst, void InitDstTensor(framework::LoDTensor* dst,
const paddle::lite_api::Tensor& src) { const paddle::lite_api::Tensor& src) {
constexpr framework::proto::VarType::Type dtype =
framework::proto::VarType_Type_FP32;
dst->mutable_data(inference::lite::utils::GetNativePlace(src.target()), dst->mutable_data(inference::lite::utils::GetNativePlace(src.target()),
dtype); GetNativePrecisionType(src.precision()));
SetLoD(dst->mutable_lod(), src.lod()); SetLoD(dst->mutable_lod(), src.lod());
} }
...@@ -254,17 +252,17 @@ void TensorDataShare(paddle::lite_api::Tensor* dst, framework::LoDTensor* src) { ...@@ -254,17 +252,17 @@ void TensorDataShare(paddle::lite_api::Tensor* dst, framework::LoDTensor* src) {
template <> template <>
void TensorDataShare(framework::LoDTensor* dst, paddle::lite_api::Tensor* src) { void TensorDataShare(framework::LoDTensor* dst, paddle::lite_api::Tensor* src) {
constexpr framework::proto::VarType::Type dtype =
framework::proto::VarType_Type_FP32;
void* src_raw_data = void* src_raw_data =
GetLiteTensorDataPtr(src, GetLitePrecisionType(dtype), src->target()); GetLiteTensorDataPtr(src, src->precision(), src->target());
size_t memory_size = GetLiteTensorNumel(*src) * sizeof(float); size_t memory_size =
GetLiteTensorNumel(*src) *
framework::SizeOfType(GetNativePrecisionType(src->precision()));
std::shared_ptr<memory::allocation::Allocation> holder( std::shared_ptr<memory::allocation::Allocation> holder(
new memory::allocation::Allocation(src_raw_data, memory_size, new memory::allocation::Allocation(src_raw_data, memory_size,
GetNativePlace(src->target()))); GetNativePlace(src->target())));
dst->Resize(paddle::framework::make_ddim(src->shape())); dst->Resize(paddle::framework::make_ddim(src->shape()));
SetLoD(dst->mutable_lod(), src->lod()); SetLoD(dst->mutable_lod(), src->lod());
dst->ResetHolderWithType(holder, dtype); dst->ResetHolderWithType(holder, GetNativePrecisionType(src->precision()));
} }
} // namespace utils } // namespace utils
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册