diff --git a/mace/libmace/mace.cc b/mace/libmace/mace.cc index d7b79d544bbe3796f403add1442c427f425a326f..a03e877499785913671524ab31753535c146acca 100644 --- a/mace/libmace/mace.cc +++ b/mace/libmace/mace.cc @@ -453,6 +453,8 @@ class MaceEngine::Impl { MaceStatus TransposeOutput(const Tensor *output_tensor, std::pair *output); + Tensor *CreateInputTensor(const std::string &input_name, DataType input_dt); + private: std::unique_ptr model_data_; std::unique_ptr op_registry_; @@ -554,6 +556,20 @@ MaceEngine::Impl::Impl(const MaceEngineConfig &config) MACE_CHECK_NOTNULL(device_); } +Tensor *MaceEngine::Impl::CreateInputTensor(const std::string &input_name, + DataType input_dt) { + Tensor *input_tensor = nullptr; + if (input_dt == DT_FLOAT && + (net_data_type_ == DT_BFLOAT16 || net_data_type_ == DT_FLOAT16)) { + input_tensor = + ws_->CreateTensor(input_name, device_->allocator(), net_data_type_); + } else { + input_tensor = + ws_->CreateTensor(input_name, device_->allocator(), input_dt); + } + return input_tensor; +} + MaceStatus MaceEngine::Impl::Init( const NetDef *net_def, const std::vector &input_nodes, @@ -584,8 +600,7 @@ MaceStatus MaceEngine::Impl::Init( << MakeString(MapKeys(input_info_map_)); } DataType input_dt = input_info_map_[input_name].data_type(); - Tensor *input_tensor = - ws_->CreateTensor(input_name, device_->allocator(), input_dt); + Tensor *input_tensor = CreateInputTensor(input_name, input_dt); // Resize to possible largest shape to avoid resize during running. std::vector shape(input_info_map_[input_name].dims_size()); for (int i = 0; i < input_info_map_[input_name].dims_size(); ++i) { @@ -765,18 +780,19 @@ MaceStatus MaceEngine::Impl::TransposeInput( input.second.shape(), dst_dims, input_data); -#ifdef MACE_ENABLE_BFLOAT16 - } else if (net_data_type_ == DT_BFLOAT16) { - auto *input_data = input_tensor->mutable_data(); - return ops::Transpose(thread_pool_.get(), - input.second.data().get(), - input.second.shape(), - dst_dims, - input_data); -#endif // MACE_ENABLE_BFLOAT16 } else { LOG(FATAL) << "Invalid net data type: " << net_data_type_; } +#ifdef MACE_ENABLE_BFLOAT16 // todo(lichao): add float16 macro + } else if (input_dt == DataType::DT_FLOAT16 || + input_dt == DataType::DT_BFLOAT16) { + auto *input_data = input_tensor->mutable_data(); + return ops::Transpose(thread_pool_.get(), + input.second.data().get(), + input.second.shape(), + dst_dims, + input_data); +#endif // MACE_ENABLE_BFLOAT16 } else if (input_dt == DataType::DT_INT32) { auto input_data = input_tensor->mutable_data(); return ops::Transpose(thread_pool_.get(), @@ -800,17 +816,18 @@ MaceStatus MaceEngine::Impl::TransposeInput( auto input_data = input_tensor->mutable_data(); memcpy(input_data, input.second.data().get(), input_tensor->size() * sizeof(float)); -#ifdef MACE_ENABLE_BFLOAT16 - } else if (net_data_type_ == DataType::DT_BFLOAT16) { - auto input_data = input_tensor->mutable_data(); - const float *data = input.second.data().get(); - for (index_t i = 0; i < input_tensor->size(); ++i) { - input_data[i] = data[i]; - } -#endif // MACE_ENABLE_BFLOAT16 } else { LOG(FATAL) << "Invalid net data type: " << net_data_type_; } +#ifdef MACE_ENABLE_BFLOAT16 // todo(lichao): add float16 macro + } else if (input_dt == DataType::DT_FLOAT16 || + input_dt == DataType::DT_BFLOAT16) { + auto input_data = input_tensor->mutable_data(); + const float *data = input.second.data().get(); + for (index_t i = 0; i < input_tensor->size(); ++i) { + input_data[i] = data[i]; + } +#endif // MACE_ENABLE_BFLOAT16 } else if (input_dt == DataType::DT_INT32) { auto input_data = input_tensor->mutable_data(); memcpy(input_data, input.second.data().get(),