提交 8b32fc9d 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1835 support GPU quantization aware Training

Merge pull request !1835 from yangjie159/quant_aware_train_develop
......@@ -36,7 +36,6 @@ using GraphDefT = mindspore::predict::GraphDefT;
using TensorDefT = mindspore::predict::TensorDefT;
using SubGraphDefT = mindspore::predict::SubGraphDefT;
using SubGraphPtr = std::unique_ptr<mindspore::predict::SubGraphDefT>;
using NodeDef = mindspore::predict::NodeDefT;
using MsDataType = mindspore::predict::DataType;
using MsFormat = mindspore::predict::Format;
using MsKernelKey = void *;
......
......@@ -108,8 +108,7 @@ bool Kernel2Ms::SetGraphOutputIdx(const KernelGraphPtr &kernel_graph_ptr, const
}
bool Kernel2Ms::SetOpOutputIdx(const CNodePtr &c_node_ptr, const TensorPtr &output_tensor,
const TensorCachePtr &tensor_cache, int ref_count, size_t order_index,
NodeDef *ms_node) {
const TensorCachePtr &tensor_cache, int ref_count, size_t order_index, OpDefT *ms_node) {
MS_EXCEPTION_IF_NULL(c_node_ptr);
MS_EXCEPTION_IF_NULL(output_tensor);
MS_EXCEPTION_IF_NULL(ms_node);
......@@ -123,7 +122,7 @@ bool Kernel2Ms::SetOpOutputIdx(const CNodePtr &c_node_ptr, const TensorPtr &outp
std::vector<int> tensor_shape;
(void)std::transform(host_shape.begin(), host_shape.end(), std::back_inserter(tensor_shape), SizeToInt);
int outputIndex = tensor_cache->addExTensor(tensor_key, output_tensor, ref_count, tensor_shape, KERNEL);
ms_node->opDef->outputIndex.push_back(outputIndex);
ms_node->outputIndex.push_back(outputIndex);
return true;
}
......@@ -164,7 +163,7 @@ void Kernel2Ms::GetRealInpoutsPtr(const AnfNodePtr &node, std::vector<AnfNodePtr
}
}
bool Kernel2Ms::SetOpInputIdx(const CNodePtr &c_node_ptr, const TensorCachePtr &tensor_cache, NodeDef *ms_node) {
bool Kernel2Ms::SetOpInputIdx(const CNodePtr &c_node_ptr, const TensorCachePtr &tensor_cache, OpDefT *ms_node) {
MS_EXCEPTION_IF_NULL(c_node_ptr);
MS_EXCEPTION_IF_NULL(tensor_cache);
MS_EXCEPTION_IF_NULL(ms_node);
......@@ -184,7 +183,7 @@ bool Kernel2Ms::SetOpInputIdx(const CNodePtr &c_node_ptr, const TensorCachePtr &
}
ExTensorPtr ex_tensor_ptr = ex_tensor_list[real_output_idx[j]];
ex_tensor_list.clear();
ms_node->opDef->inputIndex.push_back(ex_tensor_ptr->index_);
ms_node->inputIndex.push_back(ex_tensor_ptr->index_);
}
}
return true;
......@@ -397,19 +396,18 @@ bool Kernel2Ms::SetGraphOpTensors(const KernelGraphPtr &kernel_graph_ptr, const
return false;
}
auto kernel_key = node_indexs_[kernel.get()];
std::unique_ptr<NodeDef> ms_node(new NodeDef);
std::unique_ptr<OpDefT> ms_node(new OpDefT);
ms_node->name = kernel->fullname_with_scope();
ms_node->fmkType = mindspore::predict::FmkType_CAFFE;
std::unique_ptr<OpDefT> ms_op(new OpDefT());
auto c_name = AnfAlgo::GetCNodeName(kernel);
auto fun = predict::convert::OpAttrFactory::GetInstance()->GetPackFun(c_name);
if (fun == nullptr) {
MS_LOG(ERROR) << "get node [" << kernel->fullname_with_scope() << "] attr failed.";
return false;
} else if (!fun(kernel, ms_op.get())) {
} else if (!fun(kernel, ms_node.get())) {
MS_LOG(ERROR) << "set node [" << kernel->fullname_with_scope() << "] attr failed.";
return false;
}
ms_node->opDef = std::move(ms_op);
auto output_size = AnfAlgo::GetOutputTensorNum(kernel);
int nodeRefCount = SizeToInt(output_size);
for (size_t j = 0; j < output_size; ++j) {
......@@ -466,7 +464,7 @@ bool Kernel2Ms::KernelGraph2MsGraph(const KernelGraphPtr &kernel_graph_ptr) {
if (!SetOpInputIdx(kernels[i], tensor_cache_ptr_, ms_node)) {
return false;
}
std::unique_ptr<NodeDef> ms_node_tmp(ms_node);
std::unique_ptr<OpDefT> ms_node_tmp(ms_node);
sub_ms_graph->nodes.emplace_back(std::move(ms_node_tmp));
}
if (!SetAllTensors(tensor_cache_ptr_, sub_ms_graph.get())) {
......
......@@ -64,10 +64,10 @@ class Kernel2Ms {
bool SetAllTensors(const TensorCachePtr &tensor_cache, SubGraphDefT *sub_graph_def_t);
bool SetOpInputIdx(const CNodePtr &c_node_ptr, const TensorCachePtr &tensor_cache, NodeDef *ms_node);
bool SetOpInputIdx(const CNodePtr &c_node_ptr, const TensorCachePtr &tensor_cache, OpDefT *ms_node);
bool SetOpOutputIdx(const CNodePtr &c_node_ptr, const TensorPtr &output_tensor, const TensorCachePtr &tensor_cache,
int ref_count, size_t order_index, NodeDef *ms_node);
int ref_count, size_t order_index, OpDefT *ms_node);
bool SetGraphOutputIdx(const KernelGraphPtr &kernel_graph_ptr, const TensorCachePtr &tensor_cache,
SubGraphDefT *sub_graph_def_t, AllOutputTensors *all_output_tensors);
......@@ -102,7 +102,7 @@ class Kernel2Ms {
bool SetMemResue() const;
SubGraphPtr sub_ms_graph_;
AllOutputTensors all_output_tensors_;
std::vector<NodeDef *> tmp_op_nodes_;
std::vector<OpDefT *> tmp_op_nodes_;
std::unordered_map<MsKernelKey, int> node_indexs_;
std::unordered_map<int, MsKernelKey> index_nodes_;
int graph_index_ = 0;
......
......@@ -33,6 +33,14 @@ bool CastPacker(const CNodePtr &c_node_ptr, OpDefT *ms_op);
bool MeanPacker(const CNodePtr &c_node_ptr, OpDefT *ms_op);
bool SoftmaxPacker(const CNodePtr &c_node_ptr, OpDefT *ms_op);
bool ScalePacker(const CNodePtr &c_node_ptr, OpDefT *ms_op);
bool AddFoldPacker(const CNodePtr &c_node_ptr, OpDefT *ms_op);
bool ArgMaxPacker(const CNodePtr &c_node_ptr, OpDefT *ms_op);
bool BatchNormFoldPacker(const CNodePtr &c_node_ptr, OpDefT *ms_op);
bool FakeQuantWithMinMaxPacker(const CNodePtr &c_node_ptr, OpDefT *ms_op);
bool FakeQuantWithMinMaxPerChannelPacker(const CNodePtr &c_node_ptr, OpDefT *ms_op);
bool MulPacker(const CNodePtr &c_node_ptr, OpDefT *ms_op);
bool MulFoldPacker(const CNodePtr &c_node_ptr, OpDefT *ms_op);
bool SqueezePacker(const CNodePtr &c_node_ptr, OpDefT *ms_op);
OpAttrFactory::OpAttrFactory() {
pack_funs_ = {{"Conv2D", Conv2dPacker},
......@@ -60,23 +68,31 @@ OpAttrFactory::OpAttrFactory() {
{"TensorAdd", AddPacker},
{"SoftMax", SoftmaxPacker},
{"SimpleMean", MeanPacker},
{"Scale", ScalePacker}};
{"ReduceMean", MeanPacker},
{"AddFold", AddFoldPacker},
{"ArgMax", ArgMaxPacker},
{"BatchNorm", BatchNormFoldPacker},
{"FakeQuantWithMinMax", FakeQuantWithMinMaxPacker},
{"FakeQuantWithMinMaxPerChannel", FakeQuantWithMinMaxPerChannelPacker},
{"Mul", MulPacker},
{"MulFold", MulFoldPacker},
{"Squeeze", SqueezePacker}};
}
// Looks up the attribute-packing function registered for `opType`.
// Returns nullptr, after logging a warning, when no packer is registered for that name.
OpAttrPackFun OpAttrFactory::GetPackFun(const std::string &opType) {
  // One lookup only: reuse the iterator instead of find() followed by operator[].
  const auto iter = pack_funs_.find(opType);
  if (iter == pack_funs_.end()) {
    // The op name is streamed into the message; the previous ERROR line escaped the quotes and
    // therefore printed the text `<< opType <<` literally instead of the op name.
    MS_LOG(WARNING) << "Op Attr pack fun [" << opType << "] not found.";
    return nullptr;
  }
  return iter->second;
}
// Converts a kernel format string into the lite-model `Format` enum.
// kOpFormat_NCHW maps to Format_NCHW and kOpFormat_NHWC to Format_NHWC; any other string
// falls back to Format_NUM_OF_FORMAT, which is used here as the "not recognized" result.
mindspore::predict::Format GetAttrFormat(const std::string &format) {
  if (format == kOpFormat_NCHW) {
    return predict::Format::Format_NCHW;
  }
  if (format == kOpFormat_NHWC) {
    return predict::Format::Format_NHWC;
  }
  return predict::Format::Format_NUM_OF_FORMAT;
}
......
......@@ -48,7 +48,7 @@ class OpAttrFactory {
std::unordered_map<std::string, OpAttrPackFun> pack_funs_;
};
mindspore::predict::DataFormatType GetAttrFormat(const std::string &format);
mindspore::predict::Format GetAttrFormat(const std::string &format);
mindspore::predict::PadMode GetAttrPadMode(const std::string &pad_mode);
} // namespace convert
......
......@@ -25,7 +25,6 @@ bool AddPacker(const CNodePtr &c_node_ptr, OpDefT *ms_op) {
}
std::unique_ptr<AddT> attr(new AddT());
MS_EXCEPTION_IF_NULL(attr);
attr->format = predict::DataFormatType::DataFormatType_NCHW;
ms_op->name = c_node_ptr->fullname_with_scope();
ms_op->attr.type = OpT_Add;
ms_op->attr.value = attr.release();
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "predict/converter/lite_model/op_attr_packer.h"
namespace mindspore {
namespace predict {
namespace convert {
bool AddFoldPacker(const CNodePtr &c_node_ptr, OpDefT *ms_op) {
if (c_node_ptr == nullptr || ms_op == nullptr) {
return false;
}
std::unique_ptr<AddFoldT> attr(new AddFoldT());
MS_EXCEPTION_IF_NULL(attr);
ms_op->attr.type = OpT_AddFold;
ms_op->attr.value = attr.release();
return true;
}
} // namespace convert
} // namespace predict
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "predict/converter/lite_model/op_attr_packer.h"
namespace mindspore {
namespace predict {
namespace convert {
bool ArgMaxPacker(const CNodePtr &c_node_ptr, OpDefT *ms_op) {
if (c_node_ptr == nullptr || ms_op == nullptr) {
return false;
}
std::unique_ptr<ArgMaxT> attr(new ArgMaxT());
MS_EXCEPTION_IF_NULL(attr);
ms_op->attr.type = OpT_ArgMax;
ms_op->attr.value = attr.release();
return true;
}
} // namespace convert
} // namespace predict
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "predict/converter/lite_model/op_attr_packer.h"
namespace mindspore {
namespace predict {
namespace convert {
bool BatchNormFoldPacker(const CNodePtr &c_node_ptr, OpDefT *ms_op) {
if (c_node_ptr == nullptr || ms_op == nullptr) {
return false;
}
std::unique_ptr<BatchNormFoldT> attr(new BatchNormFoldT());
MS_EXCEPTION_IF_NULL(attr);
ms_op->attr.type = OpT_BatchNormFold;
ms_op->attr.value = attr.release();
return true;
}
} // namespace convert
} // namespace predict
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "predict/converter/lite_model/op_attr_packer.h"
namespace mindspore {
namespace predict {
namespace convert {
bool FakeQuantWithMinMaxPacker(const CNodePtr &c_node_ptr, OpDefT *ms_op) {
if (c_node_ptr == nullptr || ms_op == nullptr) {
return false;
}
std::unique_ptr<FakeQuantWithMinMaxT> attr(new FakeQuantWithMinMaxT());
MS_EXCEPTION_IF_NULL(attr);
ms_op->attr.type = OpT_FakeQuantWithMinMax;
ms_op->attr.value = attr.release();
return true;
}
} // namespace convert
} // namespace predict
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "predict/converter/lite_model/op_attr_packer.h"
namespace mindspore {
namespace predict {
namespace convert {
bool FakeQuantWithMinMaxPerChannelPacker(const CNodePtr &c_node_ptr, OpDefT *ms_op) {
if (c_node_ptr == nullptr || ms_op == nullptr) {
return false;
}
std::unique_ptr<FakeQuantWithMinMaxPerChannelT> attr(new FakeQuantWithMinMaxPerChannelT());
MS_EXCEPTION_IF_NULL(attr);
ms_op->attr.type = OpT_FakeQuantWithMinMaxPerChannel;
ms_op->attr.value = attr.release();
return true;
}
} // namespace convert
} // namespace predict
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "predict/converter/lite_model/op_attr_packer.h"
namespace mindspore {
namespace predict {
namespace convert {
bool MulPacker(const CNodePtr &c_node_ptr, OpDefT *ms_op) {
if (c_node_ptr == nullptr || ms_op == nullptr) {
return false;
}
std::unique_ptr<MulT> attr(new MulT());
MS_EXCEPTION_IF_NULL(attr);
ms_op->attr.type = OpT_Mul;
ms_op->attr.value = attr.release();
return true;
}
} // namespace convert
} // namespace predict
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "predict/converter/lite_model/op_attr_packer.h"
namespace mindspore {
namespace predict {
namespace convert {
bool MulFoldPacker(const CNodePtr &c_node_ptr, OpDefT *ms_op) {
if (c_node_ptr == nullptr || ms_op == nullptr) {
return false;
}
std::unique_ptr<MulFoldT> attr(new MulFoldT());
MS_EXCEPTION_IF_NULL(attr);
ms_op->name = c_node_ptr->fullname_with_scope();
ms_op->attr.type = OpT_MulFold;
ms_op->attr.value = attr.release();
return true;
}
} // namespace convert
} // namespace predict
} // namespace mindspore
......@@ -36,7 +36,6 @@ bool PoolingPacker(const CNodePtr &c_node_ptr, OpDefT *ms_op) {
attr->poolingMode = mindspore::predict::PoolMode::PoolMode_MEAN_POOLING;
} else if (c_name == "GlobalPool") {
ms_op->name = c_node_ptr->fullname_with_scope();
attr->poolingMode = mindspore::predict::PoolMode::PoolMode_GLOBAL_POOING;
} else {
MS_LOG(ERROR) << "unknowed pooling type.";
return false;
......@@ -53,7 +52,6 @@ bool PoolingPacker(const CNodePtr &c_node_ptr, OpDefT *ms_op) {
attr->padDown = 0;
attr->padLeft = 0;
attr->padRight = 0;
attr->caffeMode = false;
ms_op->attr.type = OpT_Pooling;
ms_op->attr.value = attr.release();
return true;
......
......@@ -25,7 +25,7 @@ bool ReshapePacker(const CNodePtr &c_node_ptr, OpDefT *ms_op) {
}
std::unique_ptr<ReshapeT> attr(new ReshapeT());
MS_EXCEPTION_IF_NULL(attr);
attr->format = predict::DataFormatType::DataFormatType_NCHW;
attr->format = predict::Format::Format_NCHW;
ms_op->name = c_node_ptr->fullname_with_scope();
ms_op->attr.type = OpT_Reshape;
ms_op->attr.value = attr.release();
......
......@@ -25,7 +25,7 @@ bool ScalePacker(const CNodePtr &c_node_ptr, OpDefT *ms_op) {
}
std::unique_ptr<ScaleT> attr(new ScaleT());
MS_EXCEPTION_IF_NULL(attr);
attr->format = predict::DataFormatType::DataFormatType_NCHW;
attr->format = predict::Format::Format_NCHW;
ms_op->name = c_node_ptr->fullname_with_scope();
ms_op->attr.type = OpT_Scale;
ms_op->attr.value = attr.release();
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "predict/converter/lite_model/op_attr_packer.h"
namespace mindspore {
namespace predict {
namespace convert {
bool SqueezePacker(const CNodePtr &c_node_ptr, OpDefT *ms_op) {
if (c_node_ptr == nullptr || ms_op == nullptr) {
return false;
}
std::unique_ptr<SqueezeT> attr(new SqueezeT());
MS_EXCEPTION_IF_NULL(attr);
std::vector<int> kernel_axis_value = AnfAlgo::GetNodeAttr<std::vector<int>>(c_node_ptr, "axis");
attr->axis = kernel_axis_value;
ms_op->attr.type = OpT_Squeeze;
ms_op->attr.value = attr.release();
return true;
}
} // namespace convert
} // namespace predict
} // namespace mindspore
......@@ -22,7 +22,7 @@
namespace mindspore {
namespace predictmodel {
void StepConvertGraph(const KernelGraphPtrNew &kernel_graph_ptr) {
void StepConvertGraph(const KernelGraphPtr &kernel_graph_ptr) {
MS_LOG(INFO) << "start convert_graph step";
// get kernel_graph. this graph can be origin or device, depends on which steps to persistence
MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
......@@ -59,15 +59,5 @@ void StepConvertWeight(const std::vector<tensor::TensorPtr> &inputs) {
}
}
}
// Maps a device-target name to an executor target mode:
// "GPU" -> kGPUTarget, "Ascend" -> kCPUTarget, anything else -> kUnknowTarget.
// NOTE(review): "Ascend" resolving to kCPUTarget looks surprising — confirm it is intended.
executor::TargetMode GetDeviceTarget(const std::string &device_target) {
  if (device_target == "GPU") {
    return executor::kGPUTarget;
  }
  if (device_target == "Ascend") {
    return executor::kCPUTarget;
  }
  return executor::kUnknowTarget;
}
} // namespace predictmodel
} // namespace mindspore
......@@ -19,16 +19,14 @@
#include <memory>
#include <vector>
#include <string>
#include "session/session_basic.h"
#include "predict/converter/kernel2ms.h"
namespace mindspore {
namespace predictmodel {
using KernelGraphPtrNew = std::shared_ptr<mindspore::session::KernelGraph>;
void StepConvertGraph(const KernelGraphPtrNew &kernel_graph_ptr);
using KernelGraphPtr = std::shared_ptr<mindspore::session::KernelGraph>;
void StepConvertGraph(const KernelGraphPtr &kernel_graph_ptr);
void StepConvertWeight(const std::vector<tensor::TensorPtr> &inputs);
executor::TargetMode GetDeviceTarget(const std::string &device_target);
} // namespace predictmodel
} // namespace mindspore
#endif // MINDSPORE_MINDSPORE_CCSRC_PREDICT_H_
......@@ -13,42 +13,26 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
include "op.fbs";
namespace mindspore.predict;
enum DataType : int {
DT_FLOAT = 0,
DT_FLOAT16 = 1,
DT_INT8 = 2,
DT_INT32 = 3,
DT_UINT8 = 4,
DT_UINT32 = 8,
DT_UNDEFINED = 16
}
enum Format : int {
NCHW = 0,
NHWC,
NC4HW4 = 100,
NUM_OF_FORMAT
}
enum MSConst: int {
enum MSCONST: int {
WEIGHT_REFCOUNT = 999
}
table QuantizationDef {
// Quantized value q, corresponding float value r:
// r = scale * (q - zero_point), where scale = (rmax - rmin) / (qmax - qmin)
min: [float];
max: [float];
scale: [float];
zero_point: [long];
table QuantParam {
scale: double;
zeroPoint: int;
min: double = 0;
max: double = 0;
narrowRange: bool = true;
numBits: int = 8;
}
// Tensor shape of the specifies dimension.
dimension: int;
// A list of QuantParam entries — presumably one entry per channel; confirm against the producer.
table QuantParamArray {
param: [QuantParam]; // per-channel
}
table TensorDef {
......@@ -60,7 +44,6 @@ table TensorDef {
refCount: int;
offset: int;
data: [ubyte];
quantization: QuantizationDef;
}
union OpT {
......@@ -70,7 +53,6 @@ union OpT {
Conv2D,
FusedBatchNorm,
CaffeBatchNorm,
Squeeze,
BiasAdd,
Pooling,
DepthwiseConv2D,
......@@ -85,57 +67,134 @@ union OpT {
Eltwise,
NetOutput,
Add,
Sub,
MatMul,
StridedSlice,
Power,
Slice,
Stack,
Mul,
RealDiv,
Pad,
Maximum,
Minimum,
CaffePReLU,
LeakyReLU,
ArgMax,
ArgMin,
Exp,
CaffeCrop,
Range,
Rsqrt,
ExpandDims,
Tile,
Cast
// Split
Cast,
Shape,
Nchw2Nhwc,
Nhwc2Nchw,
QuantDTypeCast,
Split,
Permute,
FakeQuantWithMinMaxVars,
Equal,
Less,
Greater,
Min,
Floor,
Abs,
Neg,
Cos,
Sin,
Sqrt,
Square,
Constant,
Log,
Tan,
Atan,
Asin,
Clip,
Transpose,
Squeeze,
Unsqueeze,
Upsample,
Dropout,
Broadcast,
Lrn,
Prelu,
ZerosLike,
TopK,
SpaceToDepth,
SpaceToBatch,
SparseToDense,
ReverseSequence,
Rank,
Gather,
GatherNd,
Fill,
Elu,
DepthToSpace,
BatchToSpace,
AddN,
Ceil,
EmbeddingLookup,
EmbeddingLookupSparse,
FloorDiv,
FloorMod,
L2Norm,
LocalResponseNormalization,
MatrixDiag,
Reduce,
Reverse,
Round,
Select,
Scatter,
Unique,
Unstack,
LogicalAnd,
LogicalOr,
LogicalXor,
LogicalNot,
OnnxInt8Quantize,
OnnxInt8Dequantize,
FakeQuantWithMinMax,
FakeQuantWithMinMaxPerChannel,
BatchNormFold,
MulFold,
AddFold,
SquaredDifference
}
enum QuantType: int {
QUANT_NONE,
QUANT_INT8
AwareTrainning,
WeightQuant,
PostTraining
}
enum FmkType: int {
TF,
CAFFE,
ONNX,
MS,
TFLITE
}
table OpDef {
name: string;
fmkType: FmkType;
attr: OpT;
inputIndex: [uint];
outputIndex: [uint];
isLastConv: bool;
quantType: QuantType = QUANT_NONE;
quantParam: [QuantParamArray];
}
enum FmkType: int {
TF,
CAFFE
}
table NodeDef {
fmkType: FmkType;
opDef: OpDef;
}
table SubGraphDef {
name: string;
inputIndex: [uint];
outputIndex: [uint];
mempoolSize: uint;
nodes: [NodeDef];
nodes: [OpDef];
allTensors: [TensorDef]; // weight + input + output
}
......
......@@ -22,12 +22,30 @@ enum ResizeMethod: byte {
NEAREST_NEIGHBOR = 1
}
enum DataFormatType : byte { // todo combine with mslite.h::Format
UNKNOW = -1,
enum DataType : int {
DT_FLOAT = 0,
DT_FLOAT16 = 1,
DT_INT8 = 2,
DT_INT32 = 3,
DT_UINT8 = 4,
DT_INT16 = 5,
DT_UINT32 = 8,
DT_INT64 = 9,
DT_UINT16 = 10,
DT_UNDEFINED = 16
}
enum Format : int {
NCHW = 0,
NHWC = 1,
HWC = 2, // for input image or resize
CHW = 3, // for input image or resize
NHWC,
HWKC,
HWCK,
KCHW,
CKHW,
KHWC,
CHWK,
NC4HW4 = 100,
NUM_OF_FORMAT
}
enum ActivationType : byte {
......@@ -42,26 +60,47 @@ enum ActivationType : byte {
SOFTSIGN = 8,
SOFTPLUS = 9,
TANH = 10,
UNKNOW = 11
SELU = 11,
HSWISH = 12,
HSIGMOID = 13,
THRESHOLDRELU = 14,
LINEAR = 15,
UNKNOW = 16
}
enum ReduceType : byte {
REDUCE_MAX = 0,
REDUCE_MEAN = 1,
REDUCE_ALL = 2,
REDUCE_ANY = 3,
REDUCE_LOG_SUM_EXP = 4,
REDUCE_PROD = 5,
REDUCE_SUM = 6,
UNKNOW = 7
}
enum PoolMode : byte {
MAX_POOLING = 0,
MEAN_POOLING = 1,
GLOBAL_POOING = 2
}
enum EltwiseMode : byte {
PROD = 0,
SUM = 1,
MAXIMUM = 2
MAXIMUM = 2,
UNKNOW = 3
}
enum PadMode : byte {
NOTSET=0,
SAME=1,
VALID=2,
CAFFE_CEIL_NEW=4
NOTSET = 0,
SAME = 1,
VALID = 2,
CAFFE = 4
}
enum RoundMode : byte {
FLOOR = 0,
CEIL = 1
}
enum PaddingMode : byte {
......@@ -77,7 +116,9 @@ table Pad {
}
table Maximum {
format: DataFormatType = 0;
}
table Minimum {
}
table Concat {
......@@ -94,7 +135,7 @@ table Activation {
}
table Conv2D {
format: DataFormatType = 0;
format: Format = 0;
group: int;
channelIn: int;
channelOut: int;
......@@ -114,15 +155,29 @@ table Conv2D {
}
table FusedBatchNorm {
epsilon: float; // eg. epsilon=0.001
epsilon: float = 0.00001; // eg. epsilon=0.001
momentum: float = 0.9;
spatial: int = 1;
}
table CaffeBatchNorm {
epsilon: float; // eg. epsilon=0.001
}
table Squeeze {
axis: [int];
table Shape {
}
table Nchw2Nhwc {
}
table Nhwc2Nchw {
}
table FakeQuantWithMinMaxVars {
narrowRange: bool;
numBits: int;
}
table BiasAdd {
......@@ -130,8 +185,9 @@ table BiasAdd {
}
table Pooling {
format: DataFormatType = 0;
format: Format = 0;
poolingMode: PoolMode;
global: bool = false;
windowW: int;
windowH: int;
strideW: int;
......@@ -141,12 +197,11 @@ table Pooling {
padDown: int;
padLeft: int;
padRight: int;
// todo replace with padValueMode in convolution pooling and so on
caffeMode: bool = false;
roundMode: RoundMode;
}
table DepthwiseConv2D {
format: DataFormatType = 0;
format: Format = 0;
channelIn: int;
channelMultiplier: int;
kernelW: int;
......@@ -165,7 +220,7 @@ table DepthwiseConv2D {
}
table DeDepthwiseConv2D {
format: DataFormatType = 0;
format: Format = 0;
channelIn: int;
channelMultiplier: int;
kernelW: int;
......@@ -185,7 +240,7 @@ table DeDepthwiseConv2D {
table Resize {
format: DataFormatType = 0;
format: Format = 0;
method: ResizeMethod;
newHeight: long;
newWidth: long;
......@@ -194,7 +249,7 @@ table Resize {
}
table DetectionPostProcess {
format: DataFormatType = 0;
format: Format = 0;
inputSize: int;
hScale: float;
wScale: float;
......@@ -210,8 +265,8 @@ table DetectionPostProcess {
}
table FullConnection {
format: DataFormatType = 0;
hasBias: bool;
axis: int;
}
// Mean(input_tensor, axis, keep_dims)
......@@ -221,7 +276,7 @@ table Mean {
}
table DeConv2D {
format: DataFormatType = 0;
format: Format = 0;
group: int;
channelIn: int;
channelOut: int;
......@@ -241,34 +296,88 @@ table DeConv2D {
}
table Scale {
format: DataFormatType = 0;
format: Format = 0;
}
table Eltwise {
format: DataFormatType = 0;
mode: EltwiseMode;
// todo repeat coeff (default 1)
}
table Add {
format: DataFormatType = 0;
}
table Sub {
}
table Mul {
}
table RealDiv {
}
table Rsqrt {
}
table Equal {
}
table Less {
}
table Greater {
}
table Min {
}
table Slice {
format: DataFormatType = 0;
format: Format = 0;
begin: [int];
end: [int];
stride: [int];
size: [int];
}
table Mul {
table Floor {
}
table Abs {
}
table Neg {
}
table Exp {
}
table Cos {
}
table Sin {
}
table Sqrt {
}
table Square {
}
table Ceil {
}
table Log {
}
table Tan {
}
table Atan {
}
table Asin {
}
table Reshape {
format: DataFormatType = 0;
format: Format = 0;
shape: [long];
}
table Power {
......@@ -280,13 +389,20 @@ table Power {
table ArgMax {
axis: int;
outMaxValue: bool;
topK: int;
topK: int = 1;
keepDims: bool;
axisType: int;
}
table ArgMin {
axis: int;
outMaxValue: bool;
topK: int = 1;
keepDims: bool;
axisType: int;
}
table NetOutput {
format: DataFormatType = 0;
}
table MatMul {
......@@ -298,6 +414,10 @@ table CaffePReLU {
channelShared : bool = false;
}
table LeakyReLU {
negativeSlope: float;
}
table StridedSlice {
beginMask: int;
endMask: int;
......@@ -317,6 +437,7 @@ table Stack {
}
table Range {
dType: DataType;
start: int;
limit: int;
delta: int;
......@@ -335,13 +456,244 @@ table Cast {
dstT: int;
}
//table Split {
// numberSplit: int;
// sizeSplits: [int];
// splitDim: int;
//}
table QuantDTypeCast {
srcT: DataType;
dstT: DataType;
}
table Split {
numberSplit: int;
sizeSplits: [int];
splitDim: int;
}
table CaffeCrop {
axis : long;
offsets : [long];
}
table Permute {
order: [long];
}
table Clip {
max: float;
min: float;
}
table Constant {
}
table Elu {
alpha: float = 1.0;
}
table Broadcast {
}
table Lrn {
alpha: float = 0.0001;
beta: float = 0.75;
bias: float = 1.0;
size: int;
}
enum ReduceMode : byte {
ReduceMean = 0,
ReduceMax = 1,
ReduceMin = 2,
ReduceProd = 3,
ReduceSum = 4,
ReduceSumSquare = 5
}
table Reduce {
axes: [int];
keepDims: int;
mode: ReduceMode;
}
table Prelu {
slope: [float];
}
table Transpose {
perm: [int];
conjugate: bool = false;
}
table Squeeze {
axis: [int];
}
table Unsqueeze {
axis: [int];
}
table Upsample {
mode: string;
scales: [float];
}
table Dropout {
ratio : float = 0.5;
}
table LocalResponseNormalization {
depth_radius: int;
bias: float;
alpha: float;
beta: float;
}
table ZerosLike {
}
table TopK {
k : int;
sorted : bool = true;
}
table SpaceToDepth {
blockSize : int;
format: Format = 0;
}
table SpaceToBatch {
blockShape : [int];
paddings : [int];
}
table SparseToDense {
validateIndices: bool;
}
table ReverseSequence {
seqAxis: int;
batchAxis: int;
}
table Rank {
}
table Gather {
axis: int;
batchDims: int;
}
table GatherNd {
batchDims: int;
}
table Fill {
dims: [int];
}
table DepthToSpace {
blockSize: int;
format: Format = 0;
}
table BatchToSpace {
blockShape: [int];
crops: [int];
}
table AddN {
N: int;
}
table EmbeddingLookup {
ids: [int];
maxNorm: float;
}
table EmbeddingLookupSparse {
spIds: [int];
spWeights: [float];
//combiner: Combiner=0;
// NOTE(review): `maxNortm` looks like a misspelling of `maxNorm` (the field name used by
// EmbeddingLookup). Renaming it changes the generated accessors, so confirm before fixing.
maxNortm: float;
}
table FloorDiv {
}
table FloorMod {
}
table L2Norm {
axis: [int];
epsilon: float;
}
table LogicalAnd {
}
table LogicalOr {
}
table LogicalXor {
}
table LogicalNot {
}
table MatrixDiag {
k: int;
numRows: int;
numCols: int;
paddingValue: float;
}
table Select {
}
table TfReduce {
type: ReduceType = 7;
}
table Reverse {
axis: [int];
}
table Round {
}
table Scatter {
}
table Unique {
}
table Unstack {
num: int;
axis: int;
}
table OnnxInt8Quantize {
}
table OnnxInt8Dequantize {
}
table FakeQuantWithMinMax {
}
table FakeQuantWithMinMaxPerChannel {
}
table BatchNormFold {
}
table MulFold {
}
table AddFold {
}
table SquaredDifference {
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册