Commit 95d50602 authored by liuqi

Add data format explanation document and fix bugs in NetDefAdapter.

Parent 50cf1737
Data Format
===========
Input/output tensors in CNN models have a data format such as
`NHWC` (TensorFlow) or `NCHW` (Caffe), but non-CNN models have no data format at all.
In MACE, CNN models running on the `CPU` with the `float` data type use the `NCHW`
data format, while all other cases use `NHWC`.
To support all models, there are a few concepts in `MACE` you should know.
Source Data Format
-----------------------
Source Data Format (`src_df` for short) stands for the data format of the original
framework the model comes from. For example, if you use Caffe, the `src_df` is `NCHW`.
We need this data format because some operators (`Reshape` etc.)
depend on it.
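A hypothetical illustration (not MACE code) of why `Reshape` depends on the source data format: flattening the same logical tensor stored as `NHWC` versus `NCHW` visits the elements in a different order, so a `Reshape` written against one layout is wrong when fed the other.

```python
# A 1x2x2x3 tensor in NHWC order: nhwc[n][h][w][c]
nhwc = [[[[c + 3 * (w + 2 * h) for c in range(3)] for w in range(2)]
         for h in range(2)]]

# The same data transposed to NCHW order: nchw[n][c][h][w]
nchw = [[[[nhwc[0][h][w][c] for w in range(2)] for h in range(2)]
         for c in range(3)]]

def flatten(t):
    """Walk the nested lists in storage order, as a Reshape to 1-D would."""
    return [x for n in t for a in n for b in a for x in b]

print(flatten(nhwc))  # [0, 1, 2, 3, ...]
print(flatten(nchw))  # [0, 3, 6, 9, ...] -- a different element order
```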
Operators Partition
--------------------
Generally, operators can be divided into two categories
based on whether they need inputs with a fixed data format (`NHWC` or `NCHW`):
operators whose inputs must have a fixed data format (like `convolution`),
and operators whose inputs should stay the same as in the source framework.
Since the data format an operator needs in MACE may be inconsistent with the original framework,
we need to insert a `Transpose` operator to transpose the input tensors when necessary.
However, for some operators like `concat`,
we can transpose their arguments instead, which eliminates the `Transpose` op for acceleration.
Based on these conditions, we partition the ops into 3 categories.
1. Ops with fixed inputs' data format(`FixedDataFormatOps`): `Convolution`, `Depthwise Convolution`, etc.
2. Ops could eliminate `Transpose` by transposing their arguments(`TransposableDataFormatOps`): `Concat`, `Element-wise`, etc.
3. Ops keeping consistent with source platform(`SourceDataFormatOps`): `Reshape`, `ExpandDims`, etc.
By default, operators in neither `FixedDataFormatOps` nor `TransposableDataFormatOps`
fall into `SourceDataFormatOps`.
For detailed information, you could refer to [code](https://github.com/XiaoMi/mace/blob/master/mace/python/tools/converter_tool/base_converter.py).
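The three-way partition above can be sketched as a simple lookup; the category sets here are abbreviated samples (the full lists live in `base_converter.py`), and ops absent from both explicit sets default to `SourceDataFormatOps`:

```python
# Abbreviated category sets; see base_converter.py for the complete lists.
FIXED_DATA_FORMAT_OPS = {"Conv2D", "DepthwiseConv2d", "Pooling", "BatchNorm"}
TRANSPOSABLE_DATA_FORMAT_OPS = {"Concat", "Eltwise", "Softmax", "BiasAdd"}

def partition(op_type):
    """Classify an op into one of the three data-format categories."""
    if op_type in FIXED_DATA_FORMAT_OPS:
        return "FixedDataFormatOps"
    if op_type in TRANSPOSABLE_DATA_FORMAT_OPS:
        return "TransposableDataFormatOps"
    return "SourceDataFormatOps"

print(partition("Conv2D"))   # FixedDataFormatOps
print(partition("Concat"))   # TransposableDataFormatOps
print(partition("Reshape"))  # SourceDataFormatOps (the default)
```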
Data Format in Operator
------------------------
Based on the operator partition strategy, every operator in `MACE` has a
data format argument which stands for the wanted data format of its inputs;
the value is one of [`NHWC`, `NCHW`, `AUTO`].

1. `NHWC` or `NCHW` represent `src_df`.
2. `AUTO` means the operator's inputs must have a fixed data format,
   and the real data format will be determined at runtime.

The data format of operators in `FixedDataFormatOps` must be `AUTO`,
while the data format of operators in `TransposableDataFormatOps`
is determined by the data format of their input ops.
MACE will transpose the input tensors based on the data format information automatically at runtime.
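The assignment rules above can be condensed into a small sketch; the helper and the all-or-nothing rule for transposable ops are simplifying assumptions here, not the exact logic of `transformer.py`:

```python
def wanted_data_format(category, input_dfs, src_df):
    """Hypothetical helper: pick the data_format argument for one op."""
    if category == "FixedDataFormatOps":
        return "AUTO"  # fixed-format ops always get AUTO
    if category == "TransposableDataFormatOps":
        # Follows its inputs: AUTO only if every input already has AUTO
        # (this all-or-nothing rule is a simplifying assumption).
        if input_dfs and all(df == "AUTO" for df in input_dfs):
            return "AUTO"
        return src_df
    return src_df  # SourceDataFormatOps keep the source framework's format

print(wanted_data_format("FixedDataFormatOps", ["NCHW"], "NCHW"))         # AUTO
print(wanted_data_format("TransposableDataFormatOps", ["AUTO"], "NCHW"))  # AUTO
print(wanted_data_format("SourceDataFormatOps", ["AUTO"], "NCHW"))        # NCHW
```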
Data Format of Model's Inputs/Outputs
-------------------------------------
1. If the model's inputs/outputs have a data format, MACE supports
   `NHWC` and `NCHW`.
2. If the model's inputs/outputs have no data format, just set `NONE` for the
   model's inputs and outputs in the `model deployment file` and `MaceTensor`.
Dynamic LSTM
============
The DynamicLSTM in MACE is implemented for Kaldi's time delay RNN models.
...@@ -8,17 +8,16 @@ The following pictures explain how to fuse components into a DynamicLSTMCell.
Before fusing:

.. image:: imgs/FuseLSTM.png
   :scale: 100 %
   :align: center

After fusing:

.. image:: imgs/DynamicLSTM.png
   :scale: 100 %
   :align: center
For more details about LSTMNonlinear in Kaldi,
please refer to [LstmNonlinearComponent](http://kaldi-asr.org/doc/nnet-simple-component_8h_source.html#l02164)
...@@ -41,6 +41,8 @@ The main documentation is organized into the following sections:
development/how_to_run_tests
development/how_to_debug
development/memory_layout
development/data_format
development/dynamic_lstm
.. toctree::
:maxdepth: 1
......
...@@ -113,8 +113,8 @@ MaceStatus NetDefAdapter::AdaptNetDef(
// quantize model flag
bool is_quantized_model = IsQuantizedModel(*net_def);
// tensor -> shape
TensorShapeMap tensor_shape_map;
// Output tensors -> information
TensorInfoMap output_map;
// output tensor : related information
...@@ -135,13 +135,13 @@ MaceStatus NetDefAdapter::AdaptNetDef(
<< target_device->device_type();
}
DataFormat expected_data_format = GetDefaultDataFormat(
target_device->device_type(), is_quantized_model);
int input_size = target_net_def->input_info_size();
for (int i = 0; i < input_size; ++i) {
auto input_info = target_net_def->mutable_input_info(i);
auto input_data_format = static_cast<DataFormat>(
input_info->data_format());
std::vector<index_t> input_shape(input_info->dims().begin(),
input_info->dims().end());
if (input_data_format != DataFormat::NONE
...@@ -192,12 +192,14 @@ MaceStatus NetDefAdapter::AdaptNetDef(
&op_def,
is_quantized_model,
&output_map,
&tensor_shape_map,
&transformed_set,
&op_output_data_format,
target_net_def));
MACE_RETURN_IF_ERROR(this->AdaptMemoryType(&context,
&op_def,
&output_map,
&tensor_shape_map,
&transformed_set,
&op_output_mem_type,
target_net_def));
...@@ -205,6 +207,7 @@ MaceStatus NetDefAdapter::AdaptNetDef(
MACE_RETURN_IF_ERROR(this->AdaptMemoryType(&context,
&op_def,
&output_map,
&tensor_shape_map,
&transformed_set,
&op_output_mem_type,
target_net_def));
...@@ -212,6 +215,7 @@ MaceStatus NetDefAdapter::AdaptNetDef(
&op_def,
is_quantized_model,
&output_map,
&tensor_shape_map,
&transformed_set,
&op_output_data_format,
target_net_def));
...@@ -227,18 +231,20 @@ MaceStatus NetDefAdapter::AdaptNetDef(
ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
op_def, "T", static_cast<int>(DataType::DT_FLOAT)));
}
auto output_shape = op_def.output_shape().empty() ?
std::vector<index_t>() :
std::vector<index_t>(
op_def.output_shape(out_idx).dims().begin(),
op_def.output_shape(out_idx).dims().end());
output_map.emplace(
op_def.output(out_idx),
InternalOutputInfo(
op_output_mem_type,
dt,
op_output_data_format,
output_shape,
target_net_def->op_size()));
tensor_shape_map.emplace(op_def.output(out_idx), output_shape);
}
// Add op to target net
target_net_def->add_op()->CopyFrom(op_def);
...@@ -357,6 +363,7 @@ MaceStatus NetDefAdapter::AdaptDataFormat(
OperatorDef *op_def,
bool is_quantized_model,
TensorInfoMap *output_map,
TensorShapeMap *tensor_shape_map,
std::unordered_set<std::string> *transformed_set,
DataFormat *op_output_df,
NetDef *target_net_def) {
...@@ -465,6 +472,8 @@ MaceStatus NetDefAdapter::AdaptDataFormat(
dst_df,
output_shape,
target_net_def->op_size() - 1));
// update tensor shape map
tensor_shape_map->emplace(transformed_name, output_shape);
// record transformed tensors
transformed_set->insert(transformed_name);
}
...@@ -479,6 +488,7 @@ MaceStatus NetDefAdapter::AdaptMemoryType(
OpConditionContext *context,
OperatorDef *op_def,
NetDefAdapter::TensorInfoMap *output_map,
TensorShapeMap *tensor_shape_map,
std::unordered_set<std::string> *transformed_set,
MemoryType *op_output_mem_types,
NetDef *target_net_def) {
...@@ -545,6 +555,8 @@ MaceStatus NetDefAdapter::AdaptMemoryType(
input_info.data_format,
input_info.shape,
target_net_def->op_size() - 1));
// update tensor shape map
tensor_shape_map->emplace(transformed_name, input_info.shape);
// record transformed tensors
transformed_set->insert(transformed_name);
}
...@@ -555,6 +567,7 @@ MaceStatus NetDefAdapter::AdaptMemoryType(
}
#else
MACE_UNUSED(output_map);
MACE_UNUSED(tensor_shape_map);
MACE_UNUSED(transformed_set);
MACE_UNUSED(target_net_def);
#endif  // MACE_ENABLE_OPENCL
......
...@@ -32,10 +32,21 @@ class OpRegistryBase;
class Workspace;
class Device;
///////////////////////////////////////////////////////////////////////////////
/// Conventions
///
/// 1. For ops that run with a data format (like Conv2D), the inputs
///    and outputs are DataFormat::NCHW when run on CPU with the float
///    data type, and DataFormat::NHWC otherwise (GPU, quantized, DSP).
///
/// 2. DataFormat::AUTO means the op's inputs must have a fixed
///    format (NHWC or NCHW), determined at runtime.
///
/// 3. If an op has DataFormat::AUTO, its arguments are formatted
///    to NHWC.
///////////////////////////////////////////////////////////////////////////////
class NetDefAdapter {
public:
NetDefAdapter(const OpRegistryBase *op_registry,
...@@ -77,6 +88,7 @@ class NetDefAdapter {
};
typedef std::unordered_map<std::string, InternalOutputInfo> TensorInfoMap;
typedef std::unordered_map<std::string, std::vector<index_t>> TensorShapeMap;
private:
MaceStatus AdaptDevice(OpConditionContext *context,
...@@ -92,6 +104,7 @@ class NetDefAdapter {
OperatorDef *op,
bool is_quantized_model,
TensorInfoMap *output_map,
TensorShapeMap *tensor_shape_map,
std::unordered_set<std::string> *transformed_set,
DataFormat *op_output_df,
NetDef *target_net_def);
...@@ -100,6 +113,7 @@ class NetDefAdapter {
OpConditionContext *context,
OperatorDef *op_def,
TensorInfoMap *output_map,
TensorShapeMap *tensor_shape_map,
std::unordered_set<std::string> *transformed_set,
MemoryType *op_output_mem_types,
NetDef *target_net_def);
......
...@@ -167,33 +167,33 @@ MaceSupportedOps = [
MaceOp = Enum('MaceOp', [(op, op) for op in MaceSupportedOps], type=str)
MaceFixedDataFormatOps = [MaceOp.BatchNorm,
MaceOp.BatchToSpaceND,
MaceOp.Conv2D,
MaceOp.Deconv2D,
MaceOp.DepthToSpace,
MaceOp.DepthwiseConv2d,
MaceOp.DepthwiseDeconv2d,
MaceOp.FullyConnected,
MaceOp.Pooling,
MaceOp.ResizeBicubic,
MaceOp.ResizeBilinear,
MaceOp.ResizeNearestNeighbor,
MaceOp.SpaceToBatchND,
MaceOp.SpaceToDepth]
MaceTransposableDataFormatOps = [MaceOp.Activation,
MaceOp.AddN,
MaceOp.BiasAdd,
MaceOp.ChannelShuffle,
MaceOp.Concat,
MaceOp.Crop,
MaceOp.Eltwise,
MaceOp.Pad,
MaceOp.Reduce,
MaceOp.Softmax,
MaceOp.Split,
MaceOp.SqrDiffMean]
class MaceKeyword(object):
......
...@@ -27,8 +27,8 @@ from mace.python.tools.converter_tool.base_converter import EltwiseType
from mace.python.tools.converter_tool.base_converter import FrameworkType
from mace.python.tools.converter_tool.base_converter import MaceKeyword
from mace.python.tools.converter_tool.base_converter import MaceOp
from mace.python.tools.converter_tool.base_converter import MaceFixedDataFormatOps  # noqa
from mace.python.tools.converter_tool.base_converter import MaceTransposableDataFormatOps  # noqa
from mace.python.tools.converter_tool.base_converter import PaddingMode
from mace.python.tools.converter_tool.base_converter import ReduceType
from mace.python.tools.converter_tool.base_converter import TransformerRule
...@@ -1348,9 +1348,9 @@ class Transformer(base_converter.ConverterInterface):
if not df_arg:
df_arg = op.arg.add()
df_arg.name = MaceKeyword.mace_data_format_str
if op.type in MaceFixedDataFormatOps:
df_arg.i = DataFormat.AUTO.value
elif op.type in MaceTransposableDataFormatOps:
input_df = DataFormat.AUTO.value
for input_tensor in op.input:
if input_tensor in self._consts:
......
...@@ -96,10 +96,9 @@ void CreateOperator{{i}}(mace::OperatorDef *op) {
{% if net.op[i].output_shape|length > 0 %}
op->mutable_output_shape()->Reserve({{ net.op[i].output_shape|length }});
{% for shape in net.op[i].output_shape %}
{% if shape.dims|length > 0 %}
mace::OutputShape * output_shape = op->add_output_shape();
output_shape->mutable_dims()->Reserve({{ shape.dims|length }});
{% for dim in shape.dims %}
......