diff --git a/docs/development/data_format.md b/docs/development/data_format.md
new file mode 100644
index 0000000000000000000000000000000000000000..83ea17c57526f0244d1ccb3fbb6ebee7acafb6d2
--- /dev/null
+++ b/docs/development/data_format.md
@@ -0,0 +1,64 @@
+Data Format
+===========
+
+Input/output tensors of a CNN model have a data format such as
+`NHWC` (TensorFlow) or `NCHW` (Caffe), but non-CNN models have no data format.
+
+However, in MACE, CNN models running on the `CPU` with the `float` data type
+use the `NCHW` data format, while all others use the `NHWC` data format.
+
+To support all models, there are some concepts in `MACE` you should know.
+
+Source Data Format
+-----------------------
+Source Data Format (`src_df` for short) stands for the original data format of
+the framework the model comes from. For example, if you use Caffe, the `src_df` is `NCHW`.
+We need this data format because some operators (`Reshape` etc.)
+depend on it.
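A small NumPy sketch (illustrative only, not MACE code) shows why `Reshape` is sensitive to the data format: applying the same reshape before and after a layout transpose yields different element orders.

```python
import numpy as np

# A tiny NCHW tensor: batch=1, channels=2, height=2, width=3.
x_nchw = np.arange(12).reshape(1, 2, 2, 3)

# Reshape as the source (NCHW) framework intended: flatten each channel plane.
flat_src = x_nchw.reshape(1, 2, 6)

# If the tensor were stored as NHWC and the same reshape applied naively,
# the element order would differ, so the result would be wrong.
x_nhwc = x_nchw.transpose(0, 2, 3, 1)        # NCHW -> NHWC
flat_wrong = x_nhwc.reshape(1, 2, 6)

print(np.array_equal(flat_src, flat_wrong))  # False: Reshape depends on layout
```

This is why `Reshape`-like operators must see their inputs in the `src_df` layout.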
+
+
+Operators Partition
+--------------------
+Generally, operators can be divided into 2 categories based on whether they
+need inputs with a fixed data format (`NHWC` or `NCHW`):
+one category is operators whose inputs must have a fixed data format (like `Convolution`),
+the other is operators whose inputs should keep the data format of the source framework.
+
+Since the data format an operator needs in MACE may be inconsistent with the
+original framework, we add a `Transpose` operator to transpose the input tensors when necessary.
+
+However, for some operators like `Concat`,
+we can transpose their arguments to eliminate the `Transpose` op and speed things up.
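The `Concat` case can be checked with NumPy (an illustrative sketch, not MACE code): instead of transposing both inputs, concatenating, and transposing back, it suffices to remap the axis argument.

```python
import numpy as np

a = np.random.rand(1, 2, 4, 4)   # NCHW
b = np.random.rand(1, 3, 4, 4)   # NCHW

# Naive: transpose inputs to NHWC, concat on the channel axis, transpose back.
naive = np.concatenate(
    [a.transpose(0, 2, 3, 1), b.transpose(0, 2, 3, 1)], axis=3
).transpose(0, 3, 1, 2)

# Transposes eliminated: concat directly in NCHW with a remapped axis.
direct = np.concatenate([a, b], axis=1)

print(np.array_equal(naive, direct))  # True: only the axis argument changed
```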
+
+Based on these conditions, we partition the ops into 3 categories.
+1. Ops with fixed inputs' data format(`FixedDataFormatOps`): `Convolution`, `Depthwise Convolution`, etc.
+2. Ops could eliminate `Transpose` by transposing their arguments(`TransposableDataFormatOps`): `Concat`, `Element-wise`, etc.
+3. Ops keeping consistent with source platform(`SourceDataFormatOps`): `Reshape`, `ExpandDims`, etc.
+
+By default, operators that are in neither `FixedDataFormatOps` nor
+`TransposableDataFormatOps` belong to `SourceDataFormatOps`.
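The default rule can be sketched as follows (op names are a small sample of the lists in `base_converter.py`; the function name is illustrative):

```python
# Sample members of the first two categories (see base_converter.py for the
# full lists).
FIXED_DATA_FORMAT_OPS = {"Conv2D", "DepthwiseConv2d", "Pooling"}
TRANSPOSABLE_DATA_FORMAT_OPS = {"Concat", "Eltwise", "Softmax"}

def partition(op_type):
    """Classify an op into one of the three data-format categories."""
    if op_type in FIXED_DATA_FORMAT_OPS:
        return "FixedDataFormatOps"
    if op_type in TRANSPOSABLE_DATA_FORMAT_OPS:
        return "TransposableDataFormatOps"
    # Everything else keeps the source framework's data format.
    return "SourceDataFormatOps"

print(partition("Conv2D"))    # FixedDataFormatOps
print(partition("Reshape"))   # SourceDataFormatOps
```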
+
+For detailed information, refer to the [code](https://github.com/XiaoMi/mace/blob/master/mace/python/tools/converter_tool/base_converter.py).
+
+
+Data Format in Operator
+------------------------
+Based on the operator partition strategy, every operator in `MACE` has a
+data format argument which stands for the wanted data format of its inputs;
+the value is one of [`NHWC`, `NCHW`, `AUTO`].
+1. `NHWC` or `NCHW` represents the `src_df`.
+2. `AUTO` means the operator's inputs must have a fixed data format,
+and the real data format will be determined at runtime.
+The data format of operators in `FixedDataFormatOps` must be `AUTO`,
+while the data format of operators in `TransposableDataFormatOps`
+is determined by the data format of their input ops.
+
+MACE automatically transposes the input tensors at runtime based on this data format information.
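The assignment rule above can be sketched like this (simplified from the real transformer logic; all names here are illustrative):

```python
AUTO = "AUTO"

def op_data_format(op_type, src_df, input_dfs,
                   fixed_ops, transposable_ops):
    """Return the data format argument for one op."""
    if op_type in fixed_ops:
        return AUTO                      # real layout decided at runtime
    if op_type in transposable_ops:
        # Follow the inputs: AUTO only if every input op is AUTO.
        if input_dfs and all(df == AUTO for df in input_dfs):
            return AUTO
        return src_df
    return src_df                        # SourceDataFormatOps keep src_df

print(op_data_format("Concat", "NCHW", ["AUTO", "AUTO"],
                     {"Conv2D"}, {"Concat"}))   # AUTO
```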
+
+
+Data Format of Model's Inputs/Outputs
+-------------------------------------
+1. If the model's inputs/outputs have a data format, MACE supports the data
+formats `NHWC` and `NCHW`.
+2. If the model's inputs/outputs have no data format, just set `NONE` for the
+ model's inputs and outputs in the `model deployment file` and `MaceTensor`.
diff --git a/docs/development/dynamic_lstm.md b/docs/development/dynamic_lstm.rst
similarity index 52%
rename from docs/development/dynamic_lstm.md
rename to docs/development/dynamic_lstm.rst
index f7d24a629c02263b42f19f3fb3004e3f4c5c2193..b01bdbc79d3bba56a8d0a821726e5abeccffd8ca 100644
--- a/docs/development/dynamic_lstm.md
+++ b/docs/development/dynamic_lstm.rst
@@ -1,5 +1,5 @@
Dynamic LSTM
-==================
+============
The DynamicLSTM in MACE is implemented for Kaldi's time delay RNN models.
@@ -8,17 +8,16 @@ The following pictures explain how to fuse components into a DynamicLSTMCell.
Before fusing:
-
-
-![how to fuse lstm](imgs/FuseLSTM.png)
-
+.. image:: imgs/FuseLSTM.png
+ :scale: 100 %
+ :align: center
After fusing:
-
-
-![DynamicLSTM](imgs/DynamicLSTM.png)
-
-
+.. image:: imgs/DynamicLSTM.png
+ :scale: 100 %
+ :align: center
For more details about LSTMNonlinear in Kaldi,
please refer to [LstmNonlinearComponent](http://kaldi-asr.org/doc/nnet-simple-component_8h_source.html#l02164)
\ No newline at end of file
diff --git a/docs/index.rst b/docs/index.rst
index c73aa1d349e5be57f55b6294a8a03fe6c0169496..0e31e083a94909fa8e36c1b357ed846def3770f7 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -41,6 +41,8 @@ The main documentation is organized into the following sections:
development/how_to_run_tests
development/how_to_debug
development/memory_layout
+ development/data_format
+ development/dynamic_lstm
.. toctree::
:maxdepth: 1
diff --git a/mace/core/net_def_adapter.cc b/mace/core/net_def_adapter.cc
index 7c7bb86517a96f011955cfd3b98a4f3b0050f9cb..bd1e360513f2f67d9c85362085415b44b1a71bd6 100644
--- a/mace/core/net_def_adapter.cc
+++ b/mace/core/net_def_adapter.cc
@@ -113,8 +113,8 @@ MaceStatus NetDefAdapter::AdaptNetDef(
// quantize model flag
bool is_quantized_model = IsQuantizedModel(*net_def);
- // Const tensors(filter) -> shape
- std::unordered_map<std::string, std::vector<index_t>> tensor_shape_map;
+ // tensor -> shape
+ TensorShapeMap tensor_shape_map;
// Output tensors -> information
TensorInfoMap output_map;
// output tensor : related information
@@ -135,13 +135,13 @@ MaceStatus NetDefAdapter::AdaptNetDef(
<< target_device->device_type();
}
+ DataFormat expected_data_format = GetDefaultDataFormat(
+ target_device->device_type(), is_quantized_model);
int input_size = target_net_def->input_info_size();
for (int i = 0; i < input_size; ++i) {
auto input_info = target_net_def->mutable_input_info(i);
auto input_data_format = static_cast<DataFormat>(
input_info->data_format());
- DataFormat expected_data_format = GetDefaultDataFormat(
- target_device->device_type(), is_quantized_model);
std::vector<index_t> input_shape(input_info->dims().begin(),
input_info->dims().end());
if (input_data_format != DataFormat::NONE
@@ -192,12 +192,14 @@ MaceStatus NetDefAdapter::AdaptNetDef(
&op_def,
is_quantized_model,
&output_map,
+ &tensor_shape_map,
&transformed_set,
&op_output_data_format,
target_net_def));
MACE_RETURN_IF_ERROR(this->AdaptMemoryType(&context,
&op_def,
&output_map,
+ &tensor_shape_map,
&transformed_set,
&op_output_mem_type,
target_net_def));
@@ -205,6 +207,7 @@ MaceStatus NetDefAdapter::AdaptNetDef(
MACE_RETURN_IF_ERROR(this->AdaptMemoryType(&context,
&op_def,
&output_map,
+ &tensor_shape_map,
&transformed_set,
&op_output_mem_type,
target_net_def));
@@ -212,6 +215,7 @@ MaceStatus NetDefAdapter::AdaptNetDef(
&op_def,
is_quantized_model,
&output_map,
+ &tensor_shape_map,
&transformed_set,
&op_output_data_format,
target_net_def));
@@ -227,18 +231,20 @@ MaceStatus NetDefAdapter::AdaptNetDef(
ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
op_def, "T", static_cast<int>(DataType::DT_FLOAT)));
}
+ auto output_shape = op_def.output_shape().empty() ?
+ std::vector<index_t>() :
+ std::vector<index_t>(
+ op_def.output_shape(out_idx).dims().begin(),
+ op_def.output_shape(out_idx).dims().end());
output_map.emplace(
op_def.output(out_idx),
InternalOutputInfo(
op_output_mem_type,
dt,
op_output_data_format,
- op_def.output_shape().empty() ?
- std::vector<index_t>() :
- std::vector<index_t>(
- op_def.output_shape(out_idx).dims().begin(),
- op_def.output_shape(out_idx).dims().end()),
+ output_shape,
target_net_def->op_size()));
+ tensor_shape_map.emplace(op_def.output(out_idx), output_shape);
}
// Add op to target net
target_net_def->add_op()->CopyFrom(op_def);
@@ -357,6 +363,7 @@ MaceStatus NetDefAdapter::AdaptDataFormat(
OperatorDef *op_def,
bool is_quantized_model,
TensorInfoMap *output_map,
+ TensorShapeMap *tensor_shape_map,
std::unordered_set<std::string> *transformed_set,
DataFormat *op_output_df,
NetDef *target_net_def) {
@@ -465,6 +472,8 @@ MaceStatus NetDefAdapter::AdaptDataFormat(
dst_df,
output_shape,
target_net_def->op_size() - 1));
+ // update tensor shape map
+ tensor_shape_map->emplace(transformed_name, output_shape);
// record transformed tensors
transformed_set->insert(transformed_name);
}
@@ -479,6 +488,7 @@ MaceStatus NetDefAdapter::AdaptMemoryType(
OpConditionContext *context,
OperatorDef *op_def,
NetDefAdapter::TensorInfoMap *output_map,
+ TensorShapeMap *tensor_shape_map,
std::unordered_set<std::string> *transformed_set,
MemoryType *op_output_mem_types,
NetDef *target_net_def) {
@@ -545,6 +555,8 @@ MaceStatus NetDefAdapter::AdaptMemoryType(
input_info.data_format,
input_info.shape,
target_net_def->op_size() - 1));
+ // update tensor shape map
+ tensor_shape_map->emplace(transformed_name, input_info.shape);
// record transformed tensors
transformed_set->insert(transformed_name);
}
@@ -555,6 +567,7 @@ MaceStatus NetDefAdapter::AdaptMemoryType(
}
#else
MACE_UNUSED(output_map);
+ MACE_UNUSED(tensor_shape_map);
MACE_UNUSED(transformed_set);
MACE_UNUSED(target_net_def);
#endif // MACE_ENABLE_OPENCL
diff --git a/mace/core/net_def_adapter.h b/mace/core/net_def_adapter.h
index d821ed810c32d2ef7d5644430948ad010c63e646..d924d84c93d492daa824d2547755036922249d28 100644
--- a/mace/core/net_def_adapter.h
+++ b/mace/core/net_def_adapter.h
@@ -32,10 +32,21 @@ class OpRegistryBase;
class Workspace;
class Device;
-/// Conventions:
-/// 1. DataFormat::AUTO stands for formatted (NHWC or NCHW)
-/// 2. if Op with DataFormat::AUTO, the arguments of this op
-/// is formatted to NHWC
+///////////////////////////////////////////////////////////////////////////////
+/// Conventions
+///
+/// 1. For ops that run with a data format (like Conv2D),
+///    the inputs and outputs use DataFormat::NCHW when run on the CPU
+///    with the float data type,
+///    while the inputs and outputs use DataFormat::NHWC in
+///    other situations (run on GPU, quantization, DSP).
+///
+/// 2. An op with DataFormat::AUTO means its inputs must have a
+///    fixed format (NHWC or NCHW), determined at runtime.
+///
+/// 3. If an op has DataFormat::AUTO, its arguments are
+///    formatted to NHWC.
+///////////////////////////////////////////////////////////////////////////////
class NetDefAdapter {
public:
NetDefAdapter(const OpRegistryBase *op_registry,
@@ -77,6 +88,7 @@ class NetDefAdapter {
};
typedef std::unordered_map<std::string, InternalOutputInfo> TensorInfoMap;
+ typedef std::unordered_map<std::string, std::vector<index_t>> TensorShapeMap;
private:
MaceStatus AdaptDevice(OpConditionContext *context,
@@ -92,6 +104,7 @@ class NetDefAdapter {
OperatorDef *op,
bool is_quantized_model,
TensorInfoMap *output_map,
+ TensorShapeMap *tensor_shape_map,
std::unordered_set<std::string> *transformed_set,
DataFormat *op_output_df,
NetDef *target_net_def);
@@ -100,6 +113,7 @@ class NetDefAdapter {
OpConditionContext *context,
OperatorDef *op_def,
TensorInfoMap *output_map,
+ TensorShapeMap *tensor_shape_map,
std::unordered_set<std::string> *transformed_set,
MemoryType *op_output_mem_types,
NetDef *target_net_def);
diff --git a/mace/python/tools/converter_tool/base_converter.py b/mace/python/tools/converter_tool/base_converter.py
index 61e65bae9152ed3337306addd84e6e29c2d9bc57..92c31e4163202f1de3dd5fd2c8bd4257cf4babbb 100644
--- a/mace/python/tools/converter_tool/base_converter.py
+++ b/mace/python/tools/converter_tool/base_converter.py
@@ -167,33 +167,33 @@ MaceSupportedOps = [
MaceOp = Enum('MaceOp', [(op, op) for op in MaceSupportedOps], type=str)
-MaceHasDataFormatOps = [MaceOp.BatchNorm,
- MaceOp.BatchToSpaceND,
- MaceOp.Conv2D,
- MaceOp.Deconv2D,
- MaceOp.DepthToSpace,
- MaceOp.DepthwiseConv2d,
- MaceOp.DepthwiseDeconv2d,
- MaceOp.FullyConnected,
- MaceOp.Pooling,
- MaceOp.ResizeBicubic,
- MaceOp.ResizeBilinear,
- MaceOp.ResizeNearestNeighbor,
- MaceOp.SpaceToBatchND,
- MaceOp.SpaceToDepth]
-
-MaceMayHasDataFormatOps = [MaceOp.Activation,
- MaceOp.AddN,
- MaceOp.BiasAdd,
- MaceOp.ChannelShuffle,
- MaceOp.Concat,
- MaceOp.Crop,
- MaceOp.Eltwise,
- MaceOp.Pad,
- MaceOp.Reduce,
- MaceOp.Softmax,
- MaceOp.Split,
- MaceOp.SqrDiffMean]
+MaceFixedDataFormatOps = [MaceOp.BatchNorm,
+ MaceOp.BatchToSpaceND,
+ MaceOp.Conv2D,
+ MaceOp.Deconv2D,
+ MaceOp.DepthToSpace,
+ MaceOp.DepthwiseConv2d,
+ MaceOp.DepthwiseDeconv2d,
+ MaceOp.FullyConnected,
+ MaceOp.Pooling,
+ MaceOp.ResizeBicubic,
+ MaceOp.ResizeBilinear,
+ MaceOp.ResizeNearestNeighbor,
+ MaceOp.SpaceToBatchND,
+ MaceOp.SpaceToDepth]
+
+MaceTransposableDataFormatOps = [MaceOp.Activation,
+ MaceOp.AddN,
+ MaceOp.BiasAdd,
+ MaceOp.ChannelShuffle,
+ MaceOp.Concat,
+ MaceOp.Crop,
+ MaceOp.Eltwise,
+ MaceOp.Pad,
+ MaceOp.Reduce,
+ MaceOp.Softmax,
+ MaceOp.Split,
+ MaceOp.SqrDiffMean]
class MaceKeyword(object):
diff --git a/mace/python/tools/converter_tool/transformer.py b/mace/python/tools/converter_tool/transformer.py
index 51806961d045e40a9cc9de184238b41b5d953308..380c0051d764ff23200d1ff5a72cb3a516205ecf 100644
--- a/mace/python/tools/converter_tool/transformer.py
+++ b/mace/python/tools/converter_tool/transformer.py
@@ -27,8 +27,8 @@ from mace.python.tools.converter_tool.base_converter import EltwiseType
from mace.python.tools.converter_tool.base_converter import FrameworkType
from mace.python.tools.converter_tool.base_converter import MaceKeyword
from mace.python.tools.converter_tool.base_converter import MaceOp
-from mace.python.tools.converter_tool.base_converter import MaceHasDataFormatOps # noqa
-from mace.python.tools.converter_tool.base_converter import MaceMayHasDataFormatOps # noqa
+from mace.python.tools.converter_tool.base_converter import MaceFixedDataFormatOps # noqa
+from mace.python.tools.converter_tool.base_converter import MaceTransposableDataFormatOps # noqa
from mace.python.tools.converter_tool.base_converter import PaddingMode
from mace.python.tools.converter_tool.base_converter import ReduceType
from mace.python.tools.converter_tool.base_converter import TransformerRule
@@ -1348,9 +1348,9 @@ class Transformer(base_converter.ConverterInterface):
if not df_arg:
df_arg = op.arg.add()
df_arg.name = MaceKeyword.mace_data_format_str
- if op.type in MaceHasDataFormatOps:
+ if op.type in MaceFixedDataFormatOps:
df_arg.i = DataFormat.AUTO.value
- elif op.type in MaceMayHasDataFormatOps:
+ elif op.type in MaceTransposableDataFormatOps:
input_df = DataFormat.AUTO.value
for input_tensor in op.input:
if input_tensor in self._consts:
diff --git a/mace/python/tools/operator.jinja2 b/mace/python/tools/operator.jinja2
index b184b54a3d98f034147866d04a6b48c1af0703f9..e60057ed75be1da5edb7c5cc46fdc7c00f243c8c 100644
--- a/mace/python/tools/operator.jinja2
+++ b/mace/python/tools/operator.jinja2
@@ -96,10 +96,9 @@ void CreateOperator{{i}}(mace::OperatorDef *op) {
{% if net.op[i].output_shape|length > 0 %}
op->mutable_output_shape()->Reserve({{ net.op[i].output_shape|length }});
- mace::OutputShape * output_shape = nullptr;
{% for shape in net.op[i].output_shape %}
{% if shape.dims|length > 0 %}
- output_shape = op->add_output_shape();
+ mace::OutputShape * output_shape = op->add_output_shape();
output_shape->mutable_dims()->Reserve({{ shape.dims|length }});
{% for dim in shape.dims %}