未验证 提交 878b6972 编写于 作者: W Wilber 提交者: GitHub

make lite subgraph support multiple tensor precision. (#30055)

上级 477b0c46
......@@ -34,7 +34,7 @@ if (NOT LITE_SOURCE_DIR OR NOT LITE_BINARY_DIR)
set(LITE_INSTALL_DIR ${THIRD_PARTY_PATH}/install/lite)
if(NOT LITE_GIT_TAG)
set(LITE_GIT_TAG 345545e2ce2f3895a332be88d5c3d495d9b206d3)
set(LITE_GIT_TAG d3a3a6931b6d22d504d21ba32b3ae972770e9204)
endif()
if(NOT CUDA_ARCH_NAME)
......
......@@ -272,6 +272,8 @@ void LiteSubgraphPass::SetUpEngine(
paddle::lite_api::Place({target_type, PRECISION(kInt64)}),
paddle::lite_api::Place({target_type, PRECISION(kFloat)}),
paddle::lite_api::Place({TARGET(kHost), PRECISION(kFloat)}),
paddle::lite_api::Place({TARGET(kX86), precision_type}),
paddle::lite_api::Place({TARGET(kX86), PRECISION(kFloat)}),
};
config.cpu_math_library_num_threads = cpu_math_library_num_threads;
config.xpu_l3_workspace_size = xpu_l3_workspace_size;
......
......@@ -195,10 +195,8 @@ void InitDstTensor(paddle::lite_api::Tensor* dst,
void InitDstTensor(framework::LoDTensor* dst,
const paddle::lite_api::Tensor& src) {
constexpr framework::proto::VarType::Type dtype =
framework::proto::VarType_Type_FP32;
dst->mutable_data(inference::lite::utils::GetNativePlace(src.target()),
dtype);
GetNativePrecisionType(src.precision()));
SetLoD(dst->mutable_lod(), src.lod());
}
......@@ -254,17 +252,17 @@ void TensorDataShare(paddle::lite_api::Tensor* dst, framework::LoDTensor* src) {
template <>
void TensorDataShare(framework::LoDTensor* dst, paddle::lite_api::Tensor* src) {
constexpr framework::proto::VarType::Type dtype =
framework::proto::VarType_Type_FP32;
void* src_raw_data =
GetLiteTensorDataPtr(src, GetLitePrecisionType(dtype), src->target());
size_t memory_size = GetLiteTensorNumel(*src) * sizeof(float);
GetLiteTensorDataPtr(src, src->precision(), src->target());
size_t memory_size =
GetLiteTensorNumel(*src) *
framework::SizeOfType(GetNativePrecisionType(src->precision()));
std::shared_ptr<memory::allocation::Allocation> holder(
new memory::allocation::Allocation(src_raw_data, memory_size,
GetNativePlace(src->target())));
dst->Resize(paddle::framework::make_ddim(src->shape()));
SetLoD(dst->mutable_lod(), src->lod());
dst->ResetHolderWithType(holder, dtype);
dst->ResetHolderWithType(holder, GetNativePrecisionType(src->precision()));
}
} // namespace utils
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册