From 4f717825034febffb8be3704f709577e69a951c8 Mon Sep 17 00:00:00 2001 From: luxuhui Date: Fri, 5 Jun 2020 20:01:26 +0800 Subject: [PATCH] fix: fix input tensor dtype error N/A Signed-off-by: Luxuhui --- mace/libmace/mace.cc | 55 +++++++++++++++++++++++++++++--------------- 1 file changed, 36 insertions(+), 19 deletions(-) diff --git a/mace/libmace/mace.cc b/mace/libmace/mace.cc index d7b79d54..a03e8774 100644 --- a/mace/libmace/mace.cc +++ b/mace/libmace/mace.cc @@ -453,6 +453,8 @@ class MaceEngine::Impl { MaceStatus TransposeOutput(const Tensor *output_tensor, std::pair *output); + Tensor *CreateInputTensor(const std::string &input_name, DataType input_dt); + private: std::unique_ptr model_data_; std::unique_ptr op_registry_; @@ -554,6 +556,20 @@ MaceEngine::Impl::Impl(const MaceEngineConfig &config) MACE_CHECK_NOTNULL(device_); } +Tensor *MaceEngine::Impl::CreateInputTensor(const std::string &input_name, + DataType input_dt) { + Tensor *input_tensor = nullptr; + if (input_dt == DT_FLOAT && + (net_data_type_ == DT_BFLOAT16 || net_data_type_ == DT_FLOAT16)) { + input_tensor = + ws_->CreateTensor(input_name, device_->allocator(), net_data_type_); + } else { + input_tensor = + ws_->CreateTensor(input_name, device_->allocator(), input_dt); + } + return input_tensor; +} + MaceStatus MaceEngine::Impl::Init( const NetDef *net_def, const std::vector &input_nodes, @@ -584,8 +600,7 @@ MaceStatus MaceEngine::Impl::Init( << MakeString(MapKeys(input_info_map_)); } DataType input_dt = input_info_map_[input_name].data_type(); - Tensor *input_tensor = - ws_->CreateTensor(input_name, device_->allocator(), input_dt); + Tensor *input_tensor = CreateInputTensor(input_name, input_dt); // Resize to possible largest shape to avoid resize during running. std::vector shape(input_info_map_[input_name].dims_size()); for (int i = 0; i < input_info_map_[input_name].dims_size(); ++i) { @@ -765,18 +780,19 @@ MaceStatus MaceEngine::Impl::TransposeInput( input.second.shape(), dst_dims, input_data); -#ifdef MACE_ENABLE_BFLOAT16 - } else if (net_data_type_ == DT_BFLOAT16) { - auto *input_data = input_tensor->mutable_data(); - return ops::Transpose(thread_pool_.get(), - input.second.data().get(), - input.second.shape(), - dst_dims, - input_data); -#endif // MACE_ENABLE_BFLOAT16 } else { LOG(FATAL) << "Invalid net data type: " << net_data_type_; } +#ifdef MACE_ENABLE_BFLOAT16 // todo(lichao): add float16 macro + } else if (input_dt == DataType::DT_FLOAT16 || + input_dt == DataType::DT_BFLOAT16) { + auto *input_data = input_tensor->mutable_data(); + return ops::Transpose(thread_pool_.get(), + input.second.data().get(), + input.second.shape(), + dst_dims, + input_data); +#endif // MACE_ENABLE_BFLOAT16 } else if (input_dt == DataType::DT_INT32) { auto input_data = input_tensor->mutable_data(); return ops::Transpose(thread_pool_.get(), @@ -800,17 +816,18 @@ MaceStatus MaceEngine::Impl::TransposeInput( auto input_data = input_tensor->mutable_data(); memcpy(input_data, input.second.data().get(), input_tensor->size() * sizeof(float)); -#ifdef MACE_ENABLE_BFLOAT16 - } else if (net_data_type_ == DataType::DT_BFLOAT16) { - auto input_data = input_tensor->mutable_data(); - const float *data = input.second.data().get(); - for (index_t i = 0; i < input_tensor->size(); ++i) { - input_data[i] = data[i]; - } -#endif // MACE_ENABLE_BFLOAT16 } else { LOG(FATAL) << "Invalid net data type: " << net_data_type_; } +#ifdef MACE_ENABLE_BFLOAT16 // todo(lichao): add float16 macro + } else if (input_dt == DataType::DT_FLOAT16 || + input_dt == DataType::DT_BFLOAT16) { + auto input_data = input_tensor->mutable_data(); + const float *data = input.second.data().get(); + for (index_t i = 0; i < input_tensor->size(); ++i) { + input_data[i] = data[i]; + } +#endif // MACE_ENABLE_BFLOAT16 } else if (input_dt == DataType::DT_INT32) { auto input_data = input_tensor->mutable_data(); memcpy(input_data, input.second.data().get(), -- GitLab