提交 f002c764 编写于 作者: 叶剑武

Merge branch 'cpu-fp16' into 'master'

feature: cpu support fp16_fp32 type.

See merge request !893
......@@ -67,10 +67,10 @@ cc_library(
"//mace/codegen:generated_version",
"//mace/proto:mace_cc",
"//mace/utils",
"@half//:half",
] + if_opencl_enabled([
":opencl_headers",
"//mace/codegen:generated_opencl",
"@half//:half",
]) + if_quantize_enabled([
"@gemmlowp",
]) + if_hexagon_enabled([
......
......@@ -34,9 +34,7 @@ bool DataTypeCanUseMemcpy(DataType dt) {
std::string DataTypeToString(const DataType dt) {
static std::map<DataType, std::string> dtype_string_map = {
{DT_FLOAT, "DT_FLOAT"},
#ifdef MACE_ENABLE_OPENCL
{DT_HALF, "DT_HALF"},
#endif
{DT_UINT8, "DT_UINT8"},
{DT_INT32, "DT_UINT32"}};
MACE_CHECK(dt != DT_INVALID, "Not support Invalid data type");
......@@ -47,10 +45,8 @@ size_t GetEnumTypeSize(const DataType dt) {
switch (dt) {
case DT_FLOAT:
return sizeof(float);
#ifdef MACE_ENABLE_OPENCL
case DT_HALF:
return sizeof(half);
#endif
case DT_UINT8:
return sizeof(uint8_t);
case DT_INT32:
......
......@@ -19,17 +19,13 @@
#include <string>
#include "mace/proto/mace.pb.h"
#ifdef MACE_ENABLE_OPENCL
#include "include/half.hpp"
#endif
namespace mace {
typedef int64_t index_t;
#ifdef MACE_ENABLE_OPENCL
using half = half_float::half;
#endif
bool DataTypeCanUseMemcpy(DataType dt);
......@@ -54,9 +50,7 @@ struct EnumToDataType;
typedef DATA_TYPE Type; \
};
#ifdef MACE_ENABLE_OPENCL
MACE_MAPPING_DATA_TYPE_AND_ENUM(half, DT_HALF);
#endif
MACE_MAPPING_DATA_TYPE_AND_ENUM(float, DT_FLOAT);
MACE_MAPPING_DATA_TYPE_AND_ENUM(uint8_t, DT_UINT8);
MACE_MAPPING_DATA_TYPE_AND_ENUM(int32_t, DT_INT32);
......
......@@ -28,14 +28,24 @@
namespace mace {
namespace {
bool HasQuantizeOp(const NetDef &net_def) {
for (auto &op : net_def.op()) {
if (op.type() == "Quantize") {
bool HasQuantizedTensor(const NetDef &net_def) {
for (auto &tensor : net_def.tensors()) {
if (tensor.quantized()) {
return true;
}
}
return false;
}
bool HasHalfTensor(const NetDef &net_def) {
for (auto &tensor : net_def.tensors()) {
if (tensor.data_type() == DataType::DT_HALF) {
return true;
}
}
return false;
}
} // namespace
Workspace::Workspace() = default;
......@@ -93,10 +103,16 @@ MaceStatus Workspace::LoadModelTensor(const NetDef &net_def,
const DeviceType device_type = device->device_type();
if (model_data_size > 0) {
bool is_quantize_model = IsQuantizedModel(net_def);
diffused_buffer_ = (device_type == DeviceType::CPU &&
(HasHalfTensor(net_def) ||
(!is_quantize_model && HasQuantizedTensor(net_def))));
#ifdef MACE_ENABLE_OPENCL
if (device_type == DeviceType::GPU &&
diffused_buffer_ = diffused_buffer_ || (device_type == DeviceType::GPU &&
device->opencl_runtime()->GetDeviceMaxMemAllocSize() <=
static_cast<uint64_t>(model_data_size)) {
static_cast<uint64_t>(model_data_size));
#endif
if (diffused_buffer_) {
for (auto &const_tensor : net_def.tensors()) {
MACE_LATENCY_LOGGER(2, "Load tensor ", const_tensor.name());
VLOG(3) << "Tensor name: " << const_tensor.name()
......@@ -108,32 +124,63 @@ MaceStatus Workspace::LoadModelTensor(const NetDef &net_def,
dims.push_back(d);
}
DataType dst_data_type = const_tensor.data_type();
if (device_type == DeviceType::CPU &&
const_tensor.data_type() == DataType::DT_HALF) {
dst_data_type = DataType::DT_FLOAT;
}
std::unique_ptr<Tensor> tensor(
new Tensor(device->allocator(),
const_tensor.data_type(), true));
new Tensor(device->allocator(), dst_data_type, true));
tensor->Resize(dims);
MACE_CHECK(tensor->size() == const_tensor.data_size(),
"Tensor's data_size not equal with the shape");
MACE_CHECK(const_tensor.offset() + tensor->raw_size() <=
MACE_CHECK(static_cast<index_t>(const_tensor.offset() +
tensor->size() * GetEnumTypeSize(const_tensor.data_type())) <=
model_data_size,
"buffer offset + length (",
const_tensor.offset(),
" + ",
tensor->raw_size(),
tensor->size() * GetEnumTypeSize(const_tensor.data_type()),
") should <= ",
model_data_size);
tensor->CopyBytes(model_data + const_tensor.offset(),
const_tensor.data_size() *
GetEnumTypeSize(const_tensor.data_type()));
if (device_type == DeviceType::CPU) {
if (const_tensor.data_type() == DataType::DT_HALF) {
// uncompress the weights of fp16
auto org_data = reinterpret_cast<const half *>(
model_data + const_tensor.offset());
float *dst_data = tensor->mutable_data<float>();
for (int i = 0; i < const_tensor.data_size(); ++i) {
dst_data[i] = half_float::half_cast<float>(org_data[i]);
}
} else if (!is_quantize_model && const_tensor.quantized()) {
// uncompress the weights of uint8
std::unique_ptr<Tensor> dequantized_tensor(new Tensor(true));
dequantized_tensor->Resize(dims);
auto quantized_data = reinterpret_cast<const uint8_t *>(
model_data + const_tensor.offset());
auto dequantized_data = tensor->mutable_data<float>();
Dequantize(quantized_data,
tensor->size(),
const_tensor.scale(),
const_tensor.zero_point(),
dequantized_data);
} else {
tensor->CopyBytes(model_data + const_tensor.offset(),
const_tensor.data_size() *
GetEnumTypeSize(const_tensor.data_type()));
}
} else {
tensor->CopyBytes(model_data + const_tensor.offset(),
const_tensor.data_size() *
GetEnumTypeSize(const_tensor.data_type()));
}
tensor_map_[const_tensor.name()] = std::move(tensor);
}
fused_buffer_ = false;
} else {
#else
{
#endif
if (device_type == DeviceType::CPU) {
tensor_buffer_ = std::unique_ptr<Buffer>(
new Buffer(device->allocator(),
......@@ -148,7 +195,6 @@ MaceStatus Workspace::LoadModelTensor(const NetDef &net_def,
0, model_data_size);
tensor_buffer_->UnMap();
}
bool has_quantize_op = HasQuantizeOp(net_def);
for (auto &const_tensor : net_def.tensors()) {
MACE_LATENCY_LOGGER(2, "Load tensor ", const_tensor.name());
VLOG(3) << "Tensor name: " << const_tensor.name()
......@@ -173,25 +219,8 @@ MaceStatus Workspace::LoadModelTensor(const NetDef &net_def,
tensor->SetScale(const_tensor.scale());
tensor->SetZeroPoint(const_tensor.zero_point());
// Only weights are quantized
if (const_tensor.quantized() && !has_quantize_op) {
std::unique_ptr<Tensor> dequantized_tensor(new Tensor(true));
dequantized_tensor->Resize(dims);
Tensor::MappingGuard quantize_guard(tensor.get());
Tensor::MappingGuard dequantize_guard(dequantized_tensor.get());
auto quantized_data = tensor->data<uint8_t>();
auto dequantized_data = dequantized_tensor->mutable_data<float>();
Dequantize(quantized_data,
tensor->size(),
tensor->scale(),
tensor->zero_point(),
dequantized_data);
tensor_map_[const_tensor.name()] = std::move(dequantized_tensor);
} else {
tensor_map_[const_tensor.name()] = std::move(tensor);
}
tensor_map_[const_tensor.name()] = std::move(tensor);
}
fused_buffer_ = true;
}
}
return MaceStatus::MACE_SUCCESS;
......@@ -305,7 +334,7 @@ void Workspace::RemoveAndReloadBuffer(const NetDef &net_def,
auto iter = tensor_map_.find(const_tensor.name());
if (iter->second->unused()) {
tensor_map_.erase(iter);
} else if (fused_buffer_) {
} else if (!diffused_buffer_) {
tensor_map_.erase(iter);
std::vector<index_t> dims;
for (const index_t d : const_tensor.dims()) {
......
......@@ -45,6 +45,10 @@ class Workspace {
return tensor_map_.find(name) != tensor_map_.end();
}
inline bool diffused_buffer() const {
return diffused_buffer_;
}
const Tensor *GetTensor(const std::string &name) const;
Tensor *GetTensor(const std::string &name);
......@@ -67,18 +71,14 @@ class Workspace {
void RemoveTensor(const std::string &name);
private:
MaceStatus CreateOutputTensorBuffer(const NetDef &net_def,
Device *device);
TensorMap tensor_map_;
std::unique_ptr<BufferBase> tensor_buffer_;
PreallocatedPooledAllocator preallocated_allocator_;
bool fused_buffer_;
bool diffused_buffer_;
MACE_DISABLE_COPY_AND_ASSIGN(Workspace);
};
......
......@@ -502,8 +502,10 @@ MaceStatus MaceEngine::Impl::Init(
MACE_RETURN_IF_ERROR(Init(net_def, input_nodes, output_nodes, model_data_));
if (device_type_ == DeviceType::GPU || device_type_ == DeviceType::HEXAGON) {
if (device_type_ == DeviceType::GPU || device_type_ == DeviceType::HEXAGON ||
(device_type_ == DeviceType::CPU && ws_->diffused_buffer())) {
UnloadModelData(model_data_, model_data_size_);
model_data_ = nullptr;
}
return MaceStatus::MACE_SUCCESS;
}
......
......@@ -45,13 +45,12 @@ data_format_map = {
def parse_data_type(data_type, device_type):
if device_type == cvt.DeviceType.GPU.value:
if device_type == cvt.DeviceType.CPU.value or\
device_type == cvt.DeviceType.GPU.value:
if data_type == 'fp32_fp32':
return mace_pb2.DT_FLOAT
else:
return mace_pb2.DT_HALF
elif device_type == cvt.DeviceType.CPU.value:
return mace_pb2.DT_FLOAT
elif device_type == cvt.DeviceType.HEXAGON.value:
return mace_pb2.DT_UINT8
else:
......
......@@ -106,7 +106,6 @@ class Transformer(base_converter.ConverterInterface):
self._consumers = {}
self._producer = {}
self._target_data_format = DataFormat.NHWC
self._output_op_names = set()
self._quantize_activation_info = {}
self._quantized_tensor = set()
......@@ -1276,10 +1275,7 @@ class Transformer(base_converter.ConverterInterface):
print("update op with float data type")
net = self._model
# TODO(liuqi): unify the data_type when CPU support half storage
data_type = self._option.data_type
if self._option.device == DeviceType.CPU.value:
data_type = mace_pb2.DT_HALF
for op in net.op:
data_type_arg = ConverterUtil.get_arg(
op, MaceKeyword.mace_op_data_type_str)
......@@ -1288,8 +1284,7 @@ class Transformer(base_converter.ConverterInterface):
data_type_arg.name = MaceKeyword.mace_op_data_type_str
data_type_arg.i = data_type
elif data_type_arg.i != data_type \
and data_type_arg.i == mace_pb2.DT_FLOAT \
and op.name not in self._output_op_names:
and data_type_arg.i == mace_pb2.DT_FLOAT:
data_type_arg.i = data_type
return False
......
......@@ -26,14 +26,6 @@ from jinja2 import Environment, FileSystemLoader
GENERATED_NAME = set()
GPUDataTypeStrs = [
"fp16_fp32",
"fp32_fp32",
]
GPUDataType = \
Enum('GPUDataType', [(ele, ele) for ele in GPUDataTypeStrs], type=str)
class ModelFormat(object):
file = "file"
......@@ -129,13 +121,12 @@ class TensorInfo:
tensor.data_type)
def update_tensor_infos(net_def, data_type, device):
def update_tensor_infos(net_def, data_type):
offset = 0
counter = 0
tensor_infos = []
for tensor in net_def.tensors:
if device == cvt.DeviceType.GPU.value and\
tensor.data_type == mace_pb2.DT_FLOAT:
if tensor.data_type == mace_pb2.DT_FLOAT:
tensor.data_type = data_type
# Add offset and data_size
......@@ -283,7 +274,7 @@ def save_model(option, net_def, model_checksum, weight_checksum, template_dir,
output_dir = output_dir + '/'
# update tensor type
update_tensor_infos(net_def, option.data_type, option.device)
update_tensor_infos(net_def, option.data_type)
if model_graph_format == ModelFormat.file or not embed_model_data:
save_model_data(net_def, model_tag, output_dir)
......
......@@ -138,21 +138,13 @@ InputDataType = Enum('InputDataType',
[(ele, ele) for ele in InputDataTypeStrs],
type=str)
CPUDataTypeStrs = [
"fp32",
]
CPUDataType = Enum('CPUDataType', [(ele, ele) for ele in CPUDataTypeStrs],
type=str)
GPUDataTypeStrs = [
FPDataTypeStrs = [
"fp16_fp32",
"fp32_fp32",
]
GPUDataType = Enum('GPUDataType', [(ele, ele) for ele in GPUDataTypeStrs],
type=str)
FPDataType = Enum('GPUDataType', [(ele, ele) for ele in FPDataTypeStrs],
type=str)
DSPDataTypeStrs = [
"uint8",
......@@ -438,28 +430,7 @@ def format_model_config(flags):
"host only support cpu runtime now.")
data_type = model_config.get(YAMLKeyword.data_type, "")
if runtime == RuntimeType.cpu_gpu and data_type not in GPUDataTypeStrs:
model_config[YAMLKeyword.data_type] = \
GPUDataType.fp16_fp32.value
elif runtime == RuntimeType.cpu:
if len(data_type) > 0:
mace_check(data_type in CPUDataTypeStrs,
ModuleName.YAML_CONFIG,
"'data_type' must be in " + str(CPUDataTypeStrs)
+ " for cpu runtime")
else:
model_config[YAMLKeyword.data_type] = \
CPUDataType.fp32.value
elif runtime == RuntimeType.gpu:
if len(data_type) > 0:
mace_check(data_type in GPUDataTypeStrs,
ModuleName.YAML_CONFIG,
"'data_type' must be in " + str(GPUDataTypeStrs)
+ " for gpu runtime")
else:
model_config[YAMLKeyword.data_type] =\
GPUDataType.fp16_fp32.value
elif runtime == RuntimeType.dsp:
if runtime == RuntimeType.dsp:
if len(data_type) > 0:
mace_check(data_type in DSPDataTypeStrs,
ModuleName.YAML_CONFIG,
......@@ -468,6 +439,19 @@ def format_model_config(flags):
else:
model_config[YAMLKeyword.data_type] = \
DSPDataType.uint8.value
else:
if len(data_type) > 0:
mace_check(data_type in FPDataTypeStrs,
ModuleName.YAML_CONFIG,
"'data_type' must be in " + str(FPDataTypeStrs)
+ " for cpu runtime")
else:
if runtime == RuntimeType.cpu:
model_config[YAMLKeyword.data_type] = \
FPDataType.fp32_fp32.value
else:
model_config[YAMLKeyword.data_type] = \
FPDataType.fp16_fp32.value
subgraphs = model_config.get(YAMLKeyword.subgraphs, "")
mace_check(len(subgraphs) > 0, ModuleName.YAML_CONFIG,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册