Commit 95d50602 authored by: liuqi

Add data format explanation document and fix bugs in NetDefAdapter.

Parent 50cf1737
Data Format
===========
Input/output tensors in CNN models have a data format such as
`NHWC` (TensorFlow) or `NCHW` (Caffe), while non-CNN models have no data format.
In MACE, a CNN model running on the `CPU` with the `float` data type uses the
`NCHW` data format, while all other cases use the `NHWC` data format.
To support all models, there are several concepts in `MACE` you should know.
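The two layouts store the same values in different axis orders; converting between them is a plain axis permutation. A minimal sketch with NumPy (shapes are illustrative):

```python
import numpy as np

# A batch of two 4x4 images with 3 channels in NHWC (TensorFlow) layout.
nhwc = np.arange(2 * 4 * 4 * 3).reshape(2, 4, 4, 3)

# Convert to NCHW (Caffe) layout: N stays first, C moves ahead of H and W.
nchw = nhwc.transpose(0, 3, 1, 2)

print(nhwc.shape)  # (2, 4, 4, 3)
print(nchw.shape)  # (2, 3, 4, 4)
```

Element `[n, c, h, w]` of the NCHW tensor is element `[n, h, w, c]` of the NHWC tensor; the permutation only changes where each value lives in memory.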
Source Data Format
-----------------------
Source Data Format (`src_df` for short) stands for the original data format
the model comes from. For example, if you use Caffe, the `src_df` is `NCHW`.
We need this data format because some operators (`Reshape` etc.) depend on
the data format.
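The dependence is easy to see with `Reshape`: applying the target shape recorded in the source model to a buffer stored in the other layout mixes up channels and pixels. A small illustration:

```python
import numpy as np

# The same logical tensor stored in two layouts.
nhwc = np.arange(1 * 2 * 2 * 3).reshape(1, 2, 2, 3)  # NHWC
nchw = nhwc.transpose(0, 3, 1, 2)                    # NCHW

# A Reshape recorded in an NCHW source model flattens H and W together.
flat_from_nchw = nchw.reshape(1, 3, 4)

# Naively applying the same target shape to the NHWC buffer gives a
# different result, because the flat memory order differs.
flat_from_nhwc = nhwc.reshape(1, 3, 4)

print(np.array_equal(flat_from_nchw, flat_from_nhwc))  # False
```

This is why `Reshape` (and similar ops) must receive inputs in the source framework's layout.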
Operators Partition
--------------------
Generally, operators fall into two categories
based on whether they need inputs with a fixed data format (`NHWC` or `NCHW`):
operators whose inputs must have a fixed data format (like `convolution`),
and operators whose inputs should keep the same data format as in the source framework.
Since the data format an operator needs in MACE may be inconsistent with the original framework,
we need to insert a `Transpose` operator to transpose the input tensors if necessary.
However, for some operators like `concat`,
we can transpose their arguments instead, eliminating the `Transpose` op for acceleration.
Based on these conditions, we partition the ops into 3 categories.
1. Ops with fixed inputs' data format(`FixedDataFormatOps`): `Convolution`, `Depthwise Convolution`, etc.
2. Ops could eliminate `Transpose` by transposing their arguments(`TransposableDataFormatOps`): `Concat`, `Element-wise`, etc.
3. Ops keeping consistent with source platform(`SourceDataFormatOps`): `Reshape`, `ExpandDims`, etc.
By default, operators that are in neither `FixedDataFormatOps` nor `TransposableDataFormatOps`
fall into `SourceDataFormatOps`.
For detailed information, you could refer to [code](https://github.com/XiaoMi/mace/blob/master/mace/python/tools/converter_tool/base_converter.py).
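The three-way partition can be sketched as a simple lookup. The set contents below are illustrative subsets; the authoritative lists are `MaceFixedDataFormatOps` and `MaceTransposableDataFormatOps` in `base_converter.py`:

```python
# Illustrative subsets of the three op categories; the real lists live in
# mace/python/tools/converter_tool/base_converter.py.
FIXED_DATA_FORMAT_OPS = {"Conv2D", "DepthwiseConv2d", "Pooling"}
TRANSPOSABLE_DATA_FORMAT_OPS = {"Concat", "Eltwise", "BiasAdd"}


def classify_op(op_type):
    """Return the category an op type falls into."""
    if op_type in FIXED_DATA_FORMAT_OPS:
        return "FixedDataFormatOps"
    if op_type in TRANSPOSABLE_DATA_FORMAT_OPS:
        return "TransposableDataFormatOps"
    # Everything else keeps the source framework's layout by default.
    return "SourceDataFormatOps"


print(classify_op("Conv2D"))   # FixedDataFormatOps
print(classify_op("Reshape"))  # SourceDataFormatOps
```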
Data Format in Operator
------------------------
Based on the operator partition strategy, every operator in `MACE` has a
data format argument which stands for the data format its inputs expect;
the value is one of [`NHWC`, `NCHW`, `AUTO`].
1. `NHWC` or `NCHW` represent `src_df`.
2. `AUTO` means the operator's inputs must have a fixed data format,
and the real data format will be determined at runtime.
The data format of operators in `FixedDataFormatOps` must be `AUTO`,
while the data format of operators in `TransposableDataFormatOps`
is determined by the data formats of their input operators.
MACE automatically transposes the input tensors at runtime based on this data format information.
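The runtime rule above can be sketched as follows. This is a hedged illustration of the decision, not MACE's internal API; function and parameter names are hypothetical:

```python
# Sketch: when a tensor produced in one layout feeds an op that expects
# another, a Transpose must be inserted between them.
def needed_transpose(tensor_df, op_df, runtime_df):
    """Return the (src, dst) transpose needed for one input, or None.

    tensor_df:  layout of the incoming tensor ("NHWC", "NCHW", or "NONE")
    op_df:      the op's data format argument ("NHWC", "NCHW", or "AUTO")
    runtime_df: the runtime's preferred layout (NCHW on float CPU,
                NHWC on GPU / quantized / DSP)
    """
    # AUTO resolves to whatever layout the runtime prefers.
    wanted = runtime_df if op_df == "AUTO" else op_df
    if tensor_df in ("NHWC", "NCHW") and tensor_df != wanted:
        return (tensor_df, wanted)
    return None  # layouts already agree, or the tensor has no data format


print(needed_transpose("NHWC", "AUTO", "NCHW"))  # ('NHWC', 'NCHW')
print(needed_transpose("NCHW", "AUTO", "NCHW"))  # None
```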
Data Format of Model's Inputs/Outputs
-------------------------------------
1. If the model's inputs/outputs have data format, MACE supports the data format
`NHWC` and `NCHW`.
2. If the model's inputs/outputs do not have a data format, set the data format
to `NONE` for the model's inputs and outputs in the `model deployment file` and
in `MaceTensor`.
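In the model deployment file, the per-tensor formats are declared with the `input_data_formats` / `output_data_formats` fields. A hedged sketch (model and tensor names are illustrative, and the exact nesting may vary by MACE version):

```yaml
models:
  my_model:                     # illustrative model name
    platform: onnx
    subgraphs:
      - input_tensors:
          - input
        input_shapes:
          - 1,3,224,224
        input_data_formats:
          - NCHW                # or NHWC; use NONE if there is no data format
        output_tensors:
          - output
        output_shapes:
          - 1,1000
        output_data_formats:
          - NONE
```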
Dynamic LSTM
============
The DynamicLSTM in MACE is implemented for Kaldi's time delay RNN models.
......@@ -8,17 +8,16 @@ The following pictures explain how to fuse components into a DynamicLSTMCell.
Before fusing:
<div align="left">
<img src="imgs/FuseLSTM.png" width = "320" height = "960" alt="how to fuse lstm" />
</div>
.. image:: imgs/FuseLSTM.png
:scale: 100 %
:align: center
After fusing:
<div align="left">
<img src="imgs/DynamicLSTM.png" width = "358" height = "391" alt="DynamicLSTM" />
</div>
.. image:: imgs/DynamicLSTM.png
:scale: 100 %
:align: center
For more details about LSTMNonlinear in Kaldi,
please refer to [LstmNonlinearComponent](http://kaldi-asr.org/doc/nnet-simple-component_8h_source.html#l02164).
......@@ -41,6 +41,8 @@ The main documentation is organized into the following sections:
development/how_to_run_tests
development/how_to_debug
development/memory_layout
development/data_format
development/dynamic_lstm
.. toctree::
:maxdepth: 1
......
......@@ -113,8 +113,8 @@ MaceStatus NetDefAdapter::AdaptNetDef(
// quantize model flag
bool is_quantized_model = IsQuantizedModel(*net_def);
// Const tensors(filter) -> shape
std::unordered_map<std::string, std::vector<index_t>> tensor_shape_map;
// tensor -> shape
TensorShapeMap tensor_shape_map;
// Output tensors -> information
TensorInfoMap output_map;
// output tensor : related information
......@@ -135,13 +135,13 @@ MaceStatus NetDefAdapter::AdaptNetDef(
<< target_device->device_type();
}
DataFormat expected_data_format = GetDefaultDataFormat(
target_device->device_type(), is_quantized_model);
int input_size = target_net_def->input_info_size();
for (int i = 0; i < input_size; ++i) {
auto input_info = target_net_def->mutable_input_info(i);
auto input_data_format = static_cast<DataFormat>(
input_info->data_format());
DataFormat expected_data_format = GetDefaultDataFormat(
target_device->device_type(), is_quantized_model);
std::vector<index_t> input_shape(input_info->dims().begin(),
input_info->dims().end());
if (input_data_format != DataFormat::NONE
......@@ -192,12 +192,14 @@ MaceStatus NetDefAdapter::AdaptNetDef(
&op_def,
is_quantized_model,
&output_map,
&tensor_shape_map,
&transformed_set,
&op_output_data_format,
target_net_def));
MACE_RETURN_IF_ERROR(this->AdaptMemoryType(&context,
&op_def,
&output_map,
&tensor_shape_map,
&transformed_set,
&op_output_mem_type,
target_net_def));
......@@ -205,6 +207,7 @@ MaceStatus NetDefAdapter::AdaptNetDef(
MACE_RETURN_IF_ERROR(this->AdaptMemoryType(&context,
&op_def,
&output_map,
&tensor_shape_map,
&transformed_set,
&op_output_mem_type,
target_net_def));
......@@ -212,6 +215,7 @@ MaceStatus NetDefAdapter::AdaptNetDef(
&op_def,
is_quantized_model,
&output_map,
&tensor_shape_map,
&transformed_set,
&op_output_data_format,
target_net_def));
......@@ -227,18 +231,20 @@ MaceStatus NetDefAdapter::AdaptNetDef(
ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
op_def, "T", static_cast<int>(DataType::DT_FLOAT)));
}
auto output_shape = op_def.output_shape().empty() ?
std::vector<index_t>() :
std::vector<index_t>(
op_def.output_shape(out_idx).dims().begin(),
op_def.output_shape(out_idx).dims().end());
output_map.emplace(
op_def.output(out_idx),
InternalOutputInfo(
op_output_mem_type,
dt,
op_output_data_format,
op_def.output_shape().empty() ?
std::vector<index_t>() :
std::vector<index_t>(
op_def.output_shape(out_idx).dims().begin(),
op_def.output_shape(out_idx).dims().end()),
output_shape,
target_net_def->op_size()));
tensor_shape_map.emplace(op_def.output(out_idx), output_shape);
}
// Add op to target net
target_net_def->add_op()->CopyFrom(op_def);
......@@ -357,6 +363,7 @@ MaceStatus NetDefAdapter::AdaptDataFormat(
OperatorDef *op_def,
bool is_quantized_model,
TensorInfoMap *output_map,
TensorShapeMap *tensor_shape_map,
std::unordered_set<std::string> *transformed_set,
DataFormat *op_output_df,
NetDef *target_net_def) {
......@@ -465,6 +472,8 @@ MaceStatus NetDefAdapter::AdaptDataFormat(
dst_df,
output_shape,
target_net_def->op_size() - 1));
// update tensor shape map
tensor_shape_map->emplace(transformed_name, output_shape);
// record transformed tensors
transformed_set->insert(transformed_name);
}
......@@ -479,6 +488,7 @@ MaceStatus NetDefAdapter::AdaptMemoryType(
OpConditionContext *context,
OperatorDef *op_def,
NetDefAdapter::TensorInfoMap *output_map,
TensorShapeMap *tensor_shape_map,
std::unordered_set<std::string> *transformed_set,
MemoryType *op_output_mem_types,
NetDef *target_net_def) {
......@@ -545,6 +555,8 @@ MaceStatus NetDefAdapter::AdaptMemoryType(
input_info.data_format,
input_info.shape,
target_net_def->op_size() - 1));
// update tensor shape map
tensor_shape_map->emplace(transformed_name, input_info.shape);
// record transformed tensors
transformed_set->insert(transformed_name);
}
......@@ -555,6 +567,7 @@ MaceStatus NetDefAdapter::AdaptMemoryType(
}
#else
MACE_UNUSED(output_map);
MACE_UNUSED(tensor_shape_map);
MACE_UNUSED(transformed_set);
MACE_UNUSED(target_net_def);
#endif // MACE_ENABLE_OPENCL
......
......@@ -32,10 +32,21 @@ class OpRegistryBase;
class Workspace;
class Device;
/// Conventions:
/// 1. DataFormat::AUTO stands for formatted (NHWC or NCHW)
/// 2. if Op with DataFormat::AUTO, the arguments of this op
/// is formatted to NHWC
///////////////////////////////////////////////////////////////////////////////
/// Conventions
///
/// 1. For ops run with a data format (like Conv2D),
///    the inputs and outputs are DataFormat::NCHW when run on CPU
///    with the float data type,
///    while the inputs and outputs are DataFormat::NHWC in
///    other situations (run on GPU, quantized, DSP).
///
/// 2. DataFormat::AUTO means an op's inputs must have a
///    fixed format (NHWC or NCHW), determined at runtime.
///
/// 3. If an op has DataFormat::AUTO, its arguments are
///    formatted to NHWC.
///////////////////////////////////////////////////////////////////////////////
class NetDefAdapter {
public:
NetDefAdapter(const OpRegistryBase *op_registry,
......@@ -77,6 +88,7 @@ class NetDefAdapter {
};
typedef std::unordered_map<std::string, InternalOutputInfo> TensorInfoMap;
typedef std::unordered_map<std::string, std::vector<index_t>> TensorShapeMap;
private:
MaceStatus AdaptDevice(OpConditionContext *context,
......@@ -92,6 +104,7 @@ class NetDefAdapter {
OperatorDef *op,
bool is_quantized_model,
TensorInfoMap *output_map,
TensorShapeMap *tensor_shape_map,
std::unordered_set<std::string> *transformed_set,
DataFormat *op_output_df,
NetDef *target_net_def);
......@@ -100,6 +113,7 @@ class NetDefAdapter {
OpConditionContext *context,
OperatorDef *op_def,
TensorInfoMap *output_map,
TensorShapeMap *tensor_shape_map,
std::unordered_set<std::string> *transformed_set,
MemoryType *op_output_mem_types,
NetDef *target_net_def);
......
......@@ -167,7 +167,7 @@ MaceSupportedOps = [
MaceOp = Enum('MaceOp', [(op, op) for op in MaceSupportedOps], type=str)
MaceHasDataFormatOps = [MaceOp.BatchNorm,
MaceFixedDataFormatOps = [MaceOp.BatchNorm,
MaceOp.BatchToSpaceND,
MaceOp.Conv2D,
MaceOp.Deconv2D,
......@@ -182,7 +182,7 @@ MaceHasDataFormatOps = [MaceOp.BatchNorm,
MaceOp.SpaceToBatchND,
MaceOp.SpaceToDepth]
MaceMayHasDataFormatOps = [MaceOp.Activation,
MaceTransposableDataFormatOps = [MaceOp.Activation,
MaceOp.AddN,
MaceOp.BiasAdd,
MaceOp.ChannelShuffle,
......
......@@ -27,8 +27,8 @@ from mace.python.tools.converter_tool.base_converter import EltwiseType
from mace.python.tools.converter_tool.base_converter import FrameworkType
from mace.python.tools.converter_tool.base_converter import MaceKeyword
from mace.python.tools.converter_tool.base_converter import MaceOp
from mace.python.tools.converter_tool.base_converter import MaceHasDataFormatOps # noqa
from mace.python.tools.converter_tool.base_converter import MaceMayHasDataFormatOps # noqa
from mace.python.tools.converter_tool.base_converter import MaceFixedDataFormatOps # noqa
from mace.python.tools.converter_tool.base_converter import MaceTransposableDataFormatOps # noqa
from mace.python.tools.converter_tool.base_converter import PaddingMode
from mace.python.tools.converter_tool.base_converter import ReduceType
from mace.python.tools.converter_tool.base_converter import TransformerRule
......@@ -1348,9 +1348,9 @@ class Transformer(base_converter.ConverterInterface):
if not df_arg:
df_arg = op.arg.add()
df_arg.name = MaceKeyword.mace_data_format_str
if op.type in MaceHasDataFormatOps:
if op.type in MaceFixedDataFormatOps:
df_arg.i = DataFormat.AUTO.value
elif op.type in MaceMayHasDataFormatOps:
elif op.type in MaceTransposableDataFormatOps:
input_df = DataFormat.AUTO.value
for input_tensor in op.input:
if input_tensor in self._consts:
......
......@@ -96,10 +96,9 @@ void CreateOperator{{i}}(mace::OperatorDef *op) {
{% if net.op[i].output_shape|length > 0 %}
op->mutable_output_shape()->Reserve({{ net.op[i].output_shape|length }});
mace::OutputShape * output_shape = nullptr;
{% for shape in net.op[i].output_shape %}
{% if shape.dims|length > 0 %}
output_shape = op->add_output_shape();
mace::OutputShape * output_shape = op->add_output_shape();
output_shape->mutable_dims()->Reserve({{ shape.dims|length }});
{% for dim in shape.dims %}
......