Commit 9304d370 authored by 张知劲

Merge branch 'fix_bf16_bug' into 'master'

fix: fix input tensor dtype error

See merge request deep-computing/mace!1279
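In short: before this change the engine created the user-facing input tensor with the caller's dtype (typically DT_FLOAT) even when the net itself runs in BF16/FP16, and the transpose/copy paths had to special-case net_data_type_ while writing BFloat16 data into a float tensor. The new CreateInputTensor helper picks the net's low-precision type up front, so TransposeInput can branch on the input tensor's own dtype. Below is a minimal standalone sketch of that rule, not MACE code: ChooseInputDtype and the BFloat16 struct are illustrative stand-ins for the real types.

// Sketch of the dtype-selection rule this MR introduces (illustrative only).
#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

enum class DataType { DT_FLOAT, DT_FLOAT16, DT_BFLOAT16, DT_INT32 };

// bfloat16 keeps the upper 16 bits of an IEEE-754 float.
struct BFloat16 {
  uint16_t bits = 0;
  BFloat16 &operator=(float f) {
    uint32_t u;
    std::memcpy(&u, &f, sizeof(u));
    bits = static_cast<uint16_t>(u >> 16);
    return *this;
  }
};

// Mirrors CreateInputTensor's decision: if the caller feeds float data but
// the net runs in a 16-bit float type, store the input in the net's type.
DataType ChooseInputDtype(DataType input_dt, DataType net_dt) {
  if (input_dt == DataType::DT_FLOAT &&
      (net_dt == DataType::DT_BFLOAT16 || net_dt == DataType::DT_FLOAT16)) {
    return net_dt;
  }
  return input_dt;
}

int main() {
  std::vector<float> user_input = {1.5f, -2.25f, 3.0f};
  DataType storage =
      ChooseInputDtype(DataType::DT_FLOAT, DataType::DT_BFLOAT16);
  if (storage == DataType::DT_BFLOAT16) {
    // Element-wise conversion, as in the copy branch of TransposeInput;
    // a raw memcpy of float bytes into 16-bit storage would be wrong.
    std::vector<BFloat16> tensor(user_input.size());
    for (size_t i = 0; i < user_input.size(); ++i) {
      tensor[i] = user_input[i];
    }
    std::cout << "stored " << tensor.size() << " bfloat16 elements\n";
  }
  return 0;
}

Note how the decision keys on both the caller's dtype and net_data_type_, which is why the TransposeInput branches in the diff below can now test the input tensor's own dtype (input_dt) instead of net_data_type_.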
@@ -453,6 +453,8 @@ class MaceEngine::Impl {
   MaceStatus TransposeOutput(const Tensor *output_tensor,
                              std::pair<const std::string, MaceTensor> *output);
+  Tensor *CreateInputTensor(const std::string &input_name, DataType input_dt);
  private:
   std::unique_ptr<port::ReadOnlyMemoryRegion> model_data_;
   std::unique_ptr<OpRegistry> op_registry_;
@@ -554,6 +556,20 @@ MaceEngine::Impl::Impl(const MaceEngineConfig &config)
   MACE_CHECK_NOTNULL(device_);
 }
+Tensor *MaceEngine::Impl::CreateInputTensor(const std::string &input_name,
+                                            DataType input_dt) {
+  Tensor *input_tensor = nullptr;
+  if (input_dt == DT_FLOAT &&
+      (net_data_type_ == DT_BFLOAT16 || net_data_type_ == DT_FLOAT16)) {
+    input_tensor =
+        ws_->CreateTensor(input_name, device_->allocator(), net_data_type_);
+  } else {
+    input_tensor =
+        ws_->CreateTensor(input_name, device_->allocator(), input_dt);
+  }
+  return input_tensor;
+}
 MaceStatus MaceEngine::Impl::Init(
     const NetDef *net_def,
     const std::vector<std::string> &input_nodes,
@@ -584,8 +600,7 @@ MaceStatus MaceEngine::Impl::Init(
                  << MakeString(MapKeys(input_info_map_));
     }
     DataType input_dt = input_info_map_[input_name].data_type();
-    Tensor *input_tensor =
-        ws_->CreateTensor(input_name, device_->allocator(), input_dt);
+    Tensor *input_tensor = CreateInputTensor(input_name, input_dt);
     // Resize to possible largest shape to avoid resize during running.
     std::vector<index_t> shape(input_info_map_[input_name].dims_size());
     for (int i = 0; i < input_info_map_[input_name].dims_size(); ++i) {
@@ -765,18 +780,19 @@ MaceStatus MaceEngine::Impl::TransposeInput(
                                 input.second.shape(),
                                 dst_dims,
                                 input_data);
-#ifdef MACE_ENABLE_BFLOAT16
-        } else if (net_data_type_ == DT_BFLOAT16) {
-          auto *input_data = input_tensor->mutable_data<BFloat16>();
-          return ops::Transpose(thread_pool_.get(),
-                                input.second.data<float>().get(),
-                                input.second.shape(),
-                                dst_dims,
-                                input_data);
-#endif  // MACE_ENABLE_BFLOAT16
         } else {
           LOG(FATAL) << "Invalid net data type: " << net_data_type_;
         }
+#ifdef MACE_ENABLE_BFLOAT16  // todo(lichao): add float16 macro
+      } else if (input_dt == DataType::DT_FLOAT16 ||
+                 input_dt == DataType::DT_BFLOAT16) {
+        auto *input_data = input_tensor->mutable_data<BFloat16>();
+        return ops::Transpose(thread_pool_.get(),
+                              input.second.data<float>().get(),
+                              input.second.shape(),
+                              dst_dims,
+                              input_data);
+#endif  // MACE_ENABLE_BFLOAT16
       } else if (input_dt == DataType::DT_INT32) {
         auto input_data = input_tensor->mutable_data<int>();
         return ops::Transpose(thread_pool_.get(),
@@ -800,17 +816,18 @@ MaceStatus MaceEngine::Impl::TransposeInput(
       auto input_data = input_tensor->mutable_data<float>();
       memcpy(input_data, input.second.data().get(),
              input_tensor->size() * sizeof(float));
-#ifdef MACE_ENABLE_BFLOAT16
-    } else if (net_data_type_ == DataType::DT_BFLOAT16) {
-      auto input_data = input_tensor->mutable_data<BFloat16>();
-      const float *data = input.second.data().get();
-      for (index_t i = 0; i < input_tensor->size(); ++i) {
-        input_data[i] = data[i];
-      }
-#endif  // MACE_ENABLE_BFLOAT16
     } else {
       LOG(FATAL) << "Invalid net data type: " << net_data_type_;
     }
+#ifdef MACE_ENABLE_BFLOAT16  // todo(lichao): add float16 macro
+  } else if (input_dt == DataType::DT_FLOAT16 ||
+             input_dt == DataType::DT_BFLOAT16) {
+    auto input_data = input_tensor->mutable_data<BFloat16>();
+    const float *data = input.second.data().get();
+    for (index_t i = 0; i < input_tensor->size(); ++i) {
+      input_data[i] = data[i];
+    }
+#endif  // MACE_ENABLE_BFLOAT16
   } else if (input_dt == DataType::DT_INT32) {
     auto input_data = input_tensor->mutable_data<int>();
     memcpy(input_data, input.second.data().get(),