提交 9c5f6b91 编写于 作者: W wandongdong

add onnx ops for deepfm

上级 a5161a96
......@@ -249,6 +249,13 @@ OPERATOR_ONNX_CONVERT_DEFINE(
.Attr("padding", "auto_pad", onnx::AttributeProto_AttributeType_STRING, SetPoolingPadMode)
.Attr("strides", "strides", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>))
OPERATOR_ONNX_CONVERT_DEFINE(GatherV2, Gather, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(make_tuple, SequenceConstruct, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(Concat, Concat, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(RealDiv, Div, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(ReduceSum, ReduceSum, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(Sub, Sub, OpNameInfo())
#define OP_CONVERT_FUNCTION_NAME(name) GetOpOnnxConvertInfo_##name
void RegisterOpConverters(const std::function<void(OpNameInfo &&)> &fn) {
......@@ -269,6 +276,12 @@ void RegisterOpConverters(const std::function<void(OpNameInfo &&)> &fn) {
fn(OP_CONVERT_FUNCTION_NAME(Squeeze)());
fn(OP_CONVERT_FUNCTION_NAME(BatchNorm)());
fn(OP_CONVERT_FUNCTION_NAME(MatMul)());
fn(OP_CONVERT_FUNCTION_NAME(make_tuple)());
fn(OP_CONVERT_FUNCTION_NAME(Concat)());
fn(OP_CONVERT_FUNCTION_NAME(RealDiv)());
fn(OP_CONVERT_FUNCTION_NAME(BiasAdd)());
fn(OP_CONVERT_FUNCTION_NAME(Sub)());
}
class OpConvertRegistry {
......@@ -325,8 +338,8 @@ class OnnxExporter {
void ExportPrimReshape(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
void ExportPrimReduceMean(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
void ExportPrimReduce(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
void ExportPrimCast(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
onnx::GraphProto *graph_proto);
void ExportPrimPReLU(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
......@@ -335,6 +348,12 @@ class OnnxExporter {
onnx::GraphProto *graph_proto);
void ExportPrimDepthwiseConv2d(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
void ExportPrimTile(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
onnx::GraphProto *graph_proto);
void ExportPrimSquare(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
void ExportPrimGatherV2(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
void ExportMergeConv(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
onnx::GraphProto *graph_proto);
......@@ -628,16 +647,19 @@ void OnnxExporter::ExportPrimReshape(const FuncGraphPtr & /*func_graph*/, const
node_proto->add_input(name_shape);
}
void OnnxExporter::ExportPrimReduceMean(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr,
onnx::GraphProto *const graph_proto) {
void OnnxExporter::ExportPrimReduce(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) {
auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto);
auto input_axis = node->input(2);
auto node_idx = AllocateNodeIndex();
(*node_map_ptr)[node] = node_idx;
onnx::NodeProto *node_proto = graph_proto->add_node();
node_proto->set_op_type(prim::kPrimReduceMean->name());
auto name = prim::kPrimReduceMean->name();
if (node->IsApply(prim::kPrimReduceSum)) {
name = prim::kPrimReduceSum->name();
}
node_proto->set_op_type(name);
node_proto->add_output(std::to_string(node_idx));
node_proto->add_input(input_data);
......@@ -646,13 +668,18 @@ void OnnxExporter::ExportPrimReduceMean(const FuncGraphPtr & /*func_graph*/, con
attr_proto->set_name("axes");
attr_proto->set_type(onnx::AttributeProto_AttributeType_INTS);
auto axis_value = dyn_cast<ValueNode>(input_axis)->value();
auto tuple_ptr = dyn_cast<ValueTuple>(axis_value);
MS_EXCEPTION_IF_NULL(tuple_ptr);
for (size_t i = 0; i < tuple_ptr->size(); ++i) {
attr_proto->add_ints(GetValue<int>((*tuple_ptr)[i]));
auto int_ptr = dyn_cast<Int32Imm>(axis_value);
if (int_ptr == nullptr) {
auto tuple_ptr = dyn_cast<ValueTuple>(axis_value);
MS_EXCEPTION_IF_NULL(tuple_ptr);
for (size_t i = 0; i < tuple_ptr->size(); ++i) {
attr_proto->add_ints(GetValue<int>((*tuple_ptr)[i]));
}
} else {
attr_proto->add_ints(int_ptr->value());
}
} else {
MS_LOG(EXCEPTION) << "Need to insert op convert variable from tuple to attributes for ReduceMean.";
MS_LOG(EXCEPTION) << "Need to insert op convert variable from tuple to attributes for " << name;
}
}
......@@ -826,6 +853,83 @@ void OnnxExporter::ExportPrimDepthwiseConv2d(const FuncGraphPtr & /*func_graph*/
SetAttrTupleValueToProto<2>(prim->GetAttr("stride"), onnx::AttributeProto_AttributeType_INTS, onnx_attr_proto, prim);
}
void OnnxExporter::ExportPrimTile(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) {
auto name_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto);
auto multiples = node->input(2);
std::string name_multiples;
if (multiples->isa<ValueNode>()) {
auto const_node_idx = AllocateNodeIndex();
(*node_map_ptr)[multiples] = const_node_idx;
onnx::NodeProto *node_proto = graph_proto->add_node();
name_multiples = std::to_string(const_node_idx);
node_proto->add_output(name_multiples);
node_proto->set_op_type("Constant");
onnx::AttributeProto *attr_proto = node_proto->add_attribute();
attr_proto->set_name("repeat");
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
ConvertTupleToTensor(dyn_cast<ValueNode>(multiples)->value(), attr_proto->mutable_t());
} else {
name_multiples = GetNodeInputName(multiples, node_map_ptr, graph_proto);
MS_LOG(EXCEPTION) << "Need to insert op convert variable from tuple to tensor for Tile.";
}
auto node_idx = AllocateNodeIndex();
(*node_map_ptr)[node] = node_idx;
onnx::NodeProto *node_proto = graph_proto->add_node();
node_proto->set_op_type("Tile");
node_proto->add_output(std::to_string(node_idx));
node_proto->add_input(name_x);
node_proto->add_input(name_multiples);
}
void OnnxExporter::ExportPrimSquare(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) {
auto name_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto);
std::string name_exponent;
auto const_node_idx = AllocateNodeIndex();
onnx::NodeProto *node_proto_exp = graph_proto->add_node();
name_exponent = std::to_string(const_node_idx);
node_proto_exp->add_output(name_exponent);
node_proto_exp->set_op_type("Constant");
onnx::AttributeProto *attr_proto = node_proto_exp->add_attribute();
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
onnx::TensorProto *tensor_proto = attr_proto->mutable_t();
tensor_proto->set_name("exponent");
tensor_proto->add_dims(static_cast<::google::protobuf::int64>(1));
tensor_proto->set_data_type(onnx::TensorProto_DataType_INT64);
tensor_proto->add_int64_data(2);
auto node_idx = AllocateNodeIndex();
(*node_map_ptr)[node] = node_idx;
onnx::NodeProto *node_proto = graph_proto->add_node();
node_proto->set_op_type("Pow");
node_proto->add_output(std::to_string(node_idx));
node_proto->add_input(name_x);
node_proto->add_input(name_exponent);
}
void OnnxExporter::ExportPrimGatherV2(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) {
auto name_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto);
auto name_indices = GetNodeInputName(node->input(2), node_map_ptr, graph_proto);
auto axis = node->input(3)->cast<ValueNodePtr>()->value();
auto node_idx = AllocateNodeIndex();
(*node_map_ptr)[node] = node_idx;
onnx::NodeProto *node_proto = graph_proto->add_node();
node_proto->set_op_type("Gather");
node_proto->add_output(std::to_string(node_idx));
node_proto->add_input(name_x);
node_proto->add_input(name_indices);
onnx::AttributeProto *attr_proto = node_proto->add_attribute();
attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
attr_proto->set_i(static_cast<::google::protobuf::int64>(dyn_cast<Int32Imm>(axis)->value()));
}
void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) {
// Type of the 2nd input of 'Reshape' of MindSpore is tuple, but ONNX's is tensor, need to do some convert
......@@ -833,8 +937,8 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &n
return ExportPrimReshape(func_graph, node, node_map_ptr, graph_proto);
}
if (node->IsApply(prim::kPrimReduceMean)) {
return ExportPrimReduceMean(func_graph, node, node_map_ptr, graph_proto);
if (node->IsApply(prim::kPrimReduceMean) || node->IsApply(prim::kPrimReduceSum)) {
return ExportPrimReduce(func_graph, node, node_map_ptr, graph_proto);
}
// MindSpore Cast(x, T) --> ONNX Cast[to=T](x)
......@@ -857,6 +961,21 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &n
return ExportPrimDepthwiseConv2d(func_graph, node, node_map_ptr, graph_proto);
}
// MindSpore Tile(x) --> ONNX Tile(x, repeat)
if (node->IsApply(prim::kPrimTile)) {
return ExportPrimTile(func_graph, node, node_map_ptr, graph_proto);
}
// MindSpore Square(x) --> ONNX Pow(x, 2)
if (node->IsApply(prim::kPrimSquare)) {
return ExportPrimSquare(func_graph, node, node_map_ptr, graph_proto);
}
// MindSpore GatherV2(x, indices, axis) --> ONNX Pow(x, indices)
if (node->IsApply(prim::kPrimGatherV2)) {
return ExportPrimGatherV2(func_graph, node, node_map_ptr, graph_proto);
}
auto inputs = node->inputs();
if (inputs.size() < 1) {
MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
......@@ -1054,7 +1173,30 @@ void OnnxExporter::SetNodeAttribute(const ValuePtr &value, onnx::NodeProto *cons
node_proto->set_op_type("Constant");
onnx::AttributeProto *attr_proto = node_proto->add_attribute();
attr_proto->set_name("value");
MS_LOG(EXCEPTION) << "Need to set value " << value->ToString() << " attribute for Constant node";
if (value->isa<Int32Imm>()) {
attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
auto casted_value = dyn_cast<Int32Imm>(value);
if (casted_value == nullptr) {
MS_LOG(EXCEPTION) << "Cast value " << value->ToString() << " to type T failed.";
}
auto attr_value = casted_value->value();
attr_proto->set_i(static_cast<::google::protobuf::int64>(attr_value));
attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
} else if (value->isa<tensor::Tensor>()) {
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
onnx::TensorProto *tensor_proto = attr_proto->mutable_t();
auto data = dyn_cast<tensor::Tensor>(value);
tensor_proto->set_raw_data(data->data().request(true).ptr, static_cast<size_t>(data->data().nbytes()));
auto dtype = data->data_type();
auto shape = data->shape_c();
tensor_proto->set_data_type(GetOnnxDataType(dtype));
for (const auto &dim : shape) {
tensor_proto->add_dims(dim);
}
} else {
MS_LOG(EXCEPTION) << "Need to set value " << value->ToString() << " attribute for Constant node";
}
}
std::string GetOnnxProtoString(const FuncGraphPtr &func_graph) {
......
......@@ -142,6 +142,20 @@ class DepthwiseConv2dAndReLU6(nn.Cell):
x = self.relu6(x)
return x
class DeepFMOpNet(nn.Cell):
"""Net definition with Gatherv2 and Tile and Square."""
def __init__(self):
super(DeepFMOpNet, self).__init__()
self.gather = P.GatherV2()
self.square = P.Square()
self.tile = P.Tile()
def construct(self, x, y):
x = self.tile(x, (1000, 1))
x = self.square(x)
x = self.gather(x, y, 0)
return x
# generate mindspore Tensor by shape and numpy datatype
def gen_tensor(shape, dtype=np.float32):
......@@ -153,6 +167,7 @@ net_cfgs = [
('lenet', LeNet5(), gen_tensor([1, 1, 32, 32])),
('maxpoolwithargmax', DefinedNet(), gen_tensor([1, 3, 224, 224])),
('depthwiseconv_relu6', DepthwiseConv2dAndReLU6(3, kernel_size=3), gen_tensor([1, 3, 32, 32])),
('deepfm_ops', DeepFMOpNet(), (gen_tensor([1, 1]), gen_tensor([1000, 1], dtype=np.int32)))
]
......@@ -164,7 +179,10 @@ def get_id(cfg):
@pytest.mark.parametrize('name, net, inp', net_cfgs, ids=get_id(net_cfgs))
def test_onnx_export(name, net, inp):
onnx_file = name + ".onnx"
export(net, inp, file_name=onnx_file, file_format='ONNX')
if isinstance(inp, (tuple, list)):
export(net, *inp, file_name=onnx_file, file_format='ONNX')
else:
export(net, inp, file_name=onnx_file, file_format='ONNX')
# check existence of exported onnx file and delete it
assert os.path.exists(onnx_file)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册
新手
引导
客服 返回
顶部