未验证 提交 f9a21106 编写于 作者: C Chun-Wei Chen 提交者: GitHub

Make shape inference handle MapProto (#3772)

* make shape inference handle MapProto
Signed-off-by: Chun-Wei Chen <jacky82226@gmail.com>

* add more parts which need map handling
Signed-off-by: Chun-Wei Chen <jacky82226@gmail.com>

* use ToDataTypeString again
Signed-off-by: Chun-Wei Chen <jacky82226@gmail.com>

* refactor mergeInShapeInfo
Signed-off-by: Chun-Wei Chen <jacky82226@gmail.com>

* add AddExistingSymbolicDims and propagateElemTypeWithValidation
Signed-off-by: Chun-Wei Chen <jacky82226@gmail.com>

* remove duplicate data_type_utils.h
Signed-off-by: Chun-Wei Chen <jacky82226@gmail.com>

* move to .cc
Signed-off-by: Chun-Wei Chen <jacky82226@gmail.com>
Co-authored-by: Ashwini Khade <askhade@microsoft.com>
上级 780b0361
......@@ -6,6 +6,173 @@
namespace ONNX_NAMESPACE {
// Note: for all methods below for propagating type or shape, callers are
// responsible to handle optional inputs/outputs and ensure that the specified
// index value is less than NumInputs/NumOutputs.
// Supports mixed tensor and sparse tensor
void propagateElemTypeFromTensorInputToOutput(InferenceContext& ctx, size_t inputIndex, size_t outputIndex) {
  const auto* in_type = ctx.getInputType(inputIndex);
  if (in_type == nullptr) {
    fail_type_inference("Input type was null");
  }
  const auto in_case = in_type->value_case();
  if (in_case != TypeProto::kTensorType && in_case != TypeProto::kSparseTensorType) {
    fail_type_inference("Input ", inputIndex, " expected to have tensor or sparse tensor type. Got: ", in_case);
  }
  const auto elem_type = getTensorElementType(*in_type);
  if (elem_type == TensorProto::UNDEFINED) {
    fail_type_inference("Element type of input ", inputIndex, " unknown");
  }
  auto* out_type = ctx.getOutputType(outputIndex);
  const auto out_case = out_type->value_case();
  switch (out_case) {
    case TypeProto::kTensorType:
    case TypeProto::kSparseTensorType:
      // Output is already committed to (sparse) tensor: keep its kind, set the element type.
      setTensorElementType(elem_type, out_case, *out_type);
      break;
    case TypeProto::VALUE_NOT_SET:
      // Output kind not yet decided: assume it mirrors the input's kind.
      setTensorElementType(elem_type, in_case, *out_type);
      break;
    default:
      // This is not expected to happen
      fail_type_inference("Output ", outputIndex, " expected to have tensor or sparse tensor type. Got: ", out_case);
  }
}
// Propagates the element type of a sequence-typed input at inputIndex to the
// output at outputIndex. Fails type inference if the input is missing, is not
// a sequence, or has no element type set.
void propagateElemTypeFromSequenceInputToOutput(InferenceContext& ctx, size_t inputIndex, size_t outputIndex) {
  auto input_type = ctx.getInputType(inputIndex);
  if (nullptr == input_type || input_type->value_case() != TypeProto::kSequenceType) {
    fail_type_inference("Input ", inputIndex, " expected to have sequence type");
  }
  // Bind by const reference: taking the proto by value would deep-copy the
  // whole (possibly nested) element TypeProto.
  const auto& input_seq_type = input_type->sequence_type();
  if (!input_seq_type.has_elem_type()) {
    fail_type_inference("Element type of sequence input ", inputIndex, " unknown");
  }
  auto output_type = ctx.getOutputType(outputIndex);
  output_type->mutable_sequence_type()->mutable_elem_type()->CopyFrom(input_seq_type.elem_type());
}
// Propagates the element type of an optional-typed input at inputIndex to the
// output at outputIndex. Fails type inference if the input is missing, is not
// optional, or has no element type set.
void propagateElemTypeFromOptionalInputToOutput(InferenceContext& ctx, size_t inputIndex, size_t outputIndex) {
  auto input_type = ctx.getInputType(inputIndex);
  if (nullptr == input_type || input_type->value_case() != TypeProto::kOptionalType) {
    fail_type_inference("Input ", inputIndex, " expected to have optional type");
  }
  // Bind by const reference: taking the proto by value would deep-copy the
  // whole (possibly nested) element TypeProto.
  const auto& input_opt_type = input_type->optional_type();
  if (!input_opt_type.has_elem_type()) {
    fail_type_inference("Element type of optional input ", inputIndex, " unknown");
  }
  auto output_type = ctx.getOutputType(outputIndex);
  output_type->mutable_optional_type()->mutable_elem_type()->CopyFrom(input_opt_type.elem_type());
}
// Propagates both the key type and the value type of a map-typed input at
// inputIndex to the output at outputIndex. Fails type inference if the input
// is missing, is not a map, or either component type is unset.
void propagateElemTypeFromMapInputToOutput(InferenceContext& ctx, size_t inputIndex, size_t outputIndex) {
  auto input_type = ctx.getInputType(inputIndex);
  if (nullptr == input_type || input_type->value_case() != TypeProto::kMapType) {
    fail_type_inference("Input ", inputIndex, " expected to have map type");
  }
  // Bind by const reference: copying the map proto would deep-copy its value
  // TypeProto, which may itself be an arbitrarily nested type.
  const auto& input_map_type = input_type->map_type();
  if (!input_map_type.has_key_type()) {
    fail_type_inference("Key type of map input ", inputIndex, " unknown");
  }
  if (!input_map_type.has_value_type()) {
    fail_type_inference("Value type of map input ", inputIndex, " unknown");
  }
  // Hoist the mutable accessor so the output map proto is looked up once.
  auto* output_map_type = ctx.getOutputType(outputIndex)->mutable_map_type();
  output_map_type->set_key_type(input_map_type.key_type());
  output_map_type->mutable_value_type()->CopyFrom(input_map_type.value_type());
}
// Dispatches element-type propagation to the handler matching the input's
// type kind (tensor/sparse-tensor, sequence, optional, or map).
void propagateElemTypeFromInputToOutput(InferenceContext& ctx, size_t inputIndex, size_t outputIndex) {
  const auto* in_type = ctx.getInputType(inputIndex);
  if (in_type == nullptr) {
    fail_type_inference("Input ", inputIndex, " expected to have type but instead is null");
  }
  switch (in_type->value_case()) {
    case TypeProto::kTensorType:
    case TypeProto::kSparseTensorType:
      propagateElemTypeFromTensorInputToOutput(ctx, inputIndex, outputIndex);
      break;
    case TypeProto::kSequenceType:
      propagateElemTypeFromSequenceInputToOutput(ctx, inputIndex, outputIndex);
      break;
    case TypeProto::kOptionalType:
      propagateElemTypeFromOptionalInputToOutput(ctx, inputIndex, outputIndex);
      break;
    case TypeProto::kMapType:
      propagateElemTypeFromMapInputToOutput(ctx, inputIndex, outputIndex);
      break;
    default:
      // Any other type kind is deliberately left untouched here.
      break;
  }
}
/*
Merge shape information from a source shape into a target shape.
* merges each TensorShapeProto_Dimension separately.
* prefer values over params.
* If both have values, values must match.
* prefer target param over source param if mismatched.
* Fail if there are mismatches in number of dimensions or dimension values.
*/
void mergeInShapeInfo(const TensorShapeProto& source, TensorShapeProto& target) {
  const int source_rank = source.dim_size();
  const int target_rank = target.dim_size();
  if (source_rank != target_rank) {
    fail_shape_inference(
        "Mismatch between number of source and target dimensions. Source=", source_rank, " Target=", target_rank);
  }
  // Ranks agree; reconcile the shapes one dimension at a time.
  for (int d = 0; d < source_rank; ++d) {
    mergeInDimensionInfo(source.dim(d), *target.mutable_dim(d), d);
  }
}
// Merge a shape into a tensor type: adopt the shape outright when the target
// has none, otherwise reconcile dimension by dimension.
void mergeInShapeInfo(const TensorShapeProto& source_shape, TypeProto_Tensor& target_type) {
  if (!target_type.has_shape()) {
    // No prior shape on the target: copy the source shape wholesale.
    *target_type.mutable_shape() = source_shape;
    return;
  }
  mergeInShapeInfo(source_shape, *target_type.mutable_shape());
}
// Sparse-tensor counterpart: adopt the shape outright when the target has
// none, otherwise reconcile dimension by dimension.
void mergeInShapeInfo(const TensorShapeProto& source_shape, TypeProto_SparseTensor& target_type) {
  if (!target_type.has_shape()) {
    // No prior shape on the target: copy the source shape wholesale.
    *target_type.mutable_shape() = source_shape;
    return;
  }
  mergeInShapeInfo(source_shape, *target_type.mutable_shape());
}
/*
Merge the shape information from two TypeProto_Tensor instances.
Values are merged into target from source.
If target has no shape information, copy from source.
If source has no shape information, ignore source.
If both have shape information:
- merge each TensorShapeProto_Dimension separately.
- Prefer values over params. If both have values, values must match.
- Prefer target param over source param if mismatched.
Fail if there are mismatches in number of dimensions or dimension values.
*/
void mergeInShapeInfo(const TypeProto_Tensor& source, TypeProto_Tensor& target) {
  if (!source.has_shape()) {
    return; // source carries no shape info; leave target untouched
  }
  mergeInShapeInfo(source.shape(), target);
}
// Sparse-tensor counterpart of the TypeProto_Tensor overload above.
void mergeInShapeInfo(const TypeProto_SparseTensor& source, TypeProto_SparseTensor& target) {
  if (!source.has_shape()) {
    return; // source carries no shape info; leave target untouched
  }
  mergeInShapeInfo(source.shape(), target);
}
/// <summary>
/// Utility function for UnionShapeInfoForTensor.
/// Both shapes must be of the same rank
......@@ -108,6 +275,29 @@ void UnionTypeInfo(const TypeProto& source_type, TypeProto& target_type) {
fail_type_inference("target optional type missing element type.");
}
UnionTypeInfo(source_type.optional_type().elem_type(), *target_type.mutable_optional_type()->mutable_elem_type());
} else if (target_case == TypeProto::ValueCase::kMapType) {
if (!source_type.map_type().has_key_type()) {
fail_type_inference("source map type missing key type.");
}
if (!target_type.map_type().has_key_type()) {
fail_type_inference("target map type missing key type.");
}
auto source_key_type = source_type.map_type().key_type();
auto target_key_type = target_type.map_type().key_type();
if (source_key_type != target_key_type) {
fail_type_inference(
"Mismatched map tensor key type:", " source=",
Utils::DataTypeUtils::ToDataTypeString(source_key_type),
" target=", Utils::DataTypeUtils::ToDataTypeString(target_key_type));
}
if (!source_type.map_type().has_value_type()) {
fail_type_inference("source map type missing value type.");
}
if (!target_type.map_type().has_value_type()) {
fail_type_inference("target map type missing value type.");
}
UnionTypeInfo(source_type.map_type().value_type(), *target_type.mutable_map_type()->mutable_value_type());
}
}
......@@ -167,7 +357,7 @@ void propagateSequenceElemTypeWithValidation(const TypeProto* input_type, TypePr
propagateElemTypeWithValidation(
&input_seq_type.elem_type(), output_type->mutable_sequence_type()->mutable_elem_type());
} else {
fail_type_inference("Element type of input was unknown");
fail_type_inference("Element type of sequence input was unknown");
}
}
......@@ -186,8 +376,30 @@ void propagateOptionalElemTypeWithValidation(const TypeProto* input_type, TypePr
propagateElemTypeWithValidation(
&input_opt_type.elem_type(), output_type->mutable_optional_type()->mutable_elem_type());
} else {
fail_type_inference("Element type of input was unknown");
fail_type_inference("Element type of optional input was unknown");
}
}
// Propagates a map input's key and value types to output_type. The key type
// is copied directly; the value type is forwarded through
// propagateElemTypeWithValidation so a conflicting pre-existing output value
// type is reported. Fails type inference if the input is null, is not a map,
// or either component type is unset.
void propagateMapElemTypeWithValidation(const TypeProto* input_type, TypeProto* output_type) {
  if (nullptr == input_type) {
    fail_type_inference("Input type was null");
  }
  if (input_type->value_case() != TypeProto::kMapType) {
    fail_type_inference("Input was expected to have map type. Got ", input_type->value_case());
  }
  // Bind by const reference: copying the map proto would deep-copy its value
  // TypeProto, which may itself be an arbitrarily nested type.
  const auto& input_map_type = input_type->map_type();
  if (!input_map_type.has_key_type()) {
    fail_type_inference("Key type of map input was unknown");
  }
  if (!input_map_type.has_value_type()) {
    fail_type_inference("Value type of map input was unknown");
  }
  output_type->mutable_map_type()->set_key_type(input_map_type.key_type());
  propagateElemTypeWithValidation(
      &input_map_type.value_type(), output_type->mutable_map_type()->mutable_value_type());
}
// propagate the element type from an input type to an output type.
......@@ -204,8 +416,10 @@ void propagateElemTypeWithValidation(const TypeProto* input_type, TypeProto* out
propagateSequenceElemTypeWithValidation(input_type, output_type);
} else if (input_value_case == TypeProto::kOptionalType) {
propagateOptionalElemTypeWithValidation(input_type, output_type);
} else if (input_value_case == TypeProto::kMapType) {
propagateMapElemTypeWithValidation(input_type, output_type);
} else {
fail_type_inference("Input was expected to have either tensor, sequence, or optional type. Got ", input_value_case);
fail_type_inference("Input was expected to have either tensor, sequence, optional or map type. Got ", input_value_case);
}
}
......
......@@ -222,82 +222,7 @@ inline void setTensorElementType(int32_t elem_type, TypeProto::ValueCase value_c
void propagateElemTypeWithValidation(const TypeProto* input_type, TypeProto* output_type);
// Note: for all methods below for propagating type or shape, callers are
// responsible to handle optional inputs/outputs and ensure that the specified
// index value is less than NumInputs/NumOutputs.
// Supports mixed tensor and sparse tensor
inline void propagateElemTypeFromTensorInputToOutput(InferenceContext& ctx, size_t inputIndex, size_t outputIndex) {
  const auto* in_type = ctx.getInputType(inputIndex);
  if (in_type == nullptr) {
    fail_type_inference("Input type was null");
  }
  const auto in_case = in_type->value_case();
  if (in_case != TypeProto::kTensorType && in_case != TypeProto::kSparseTensorType) {
    fail_type_inference("Input ", inputIndex, " expected to have tensor or sparse tensor type. Got: ", in_case);
  }
  const auto elem_type = getTensorElementType(*in_type);
  if (elem_type == TensorProto::UNDEFINED) {
    fail_type_inference("Element type of input ", inputIndex, " unknown");
  }
  auto* out_type = ctx.getOutputType(outputIndex);
  const auto out_case = out_type->value_case();
  switch (out_case) {
    case TypeProto::kTensorType:
    case TypeProto::kSparseTensorType:
      // Output is already committed to (sparse) tensor: keep its kind, set the element type.
      setTensorElementType(elem_type, out_case, *out_type);
      break;
    case TypeProto::VALUE_NOT_SET:
      // Output kind not yet decided: assume it mirrors the input's kind.
      setTensorElementType(elem_type, in_case, *out_type);
      break;
    default:
      // This is not expected to happen
      fail_type_inference("Output ", outputIndex, " expected to have tensor or sparse tensor type. Got: ", out_case);
  }
}
// Copies the element type of a sequence-typed input to the output's sequence
// element type; fails if the input is absent, non-sequence, or untyped.
inline void propagateElemTypeFromSequenceInputToOutput(InferenceContext& ctx, size_t inputIndex, size_t outputIndex) {
  const auto* in_type = ctx.getInputType(inputIndex);
  if (in_type == nullptr || in_type->value_case() != TypeProto::kSequenceType) {
    fail_type_inference("Input ", inputIndex, " expected to have sequence type");
  }
  const auto& seq = in_type->sequence_type();
  if (!seq.has_elem_type()) {
    fail_type_inference("Element type of sequence input ", inputIndex, " unknown");
  }
  ctx.getOutputType(outputIndex)->mutable_sequence_type()->mutable_elem_type()->CopyFrom(seq.elem_type());
}
// Copies the element type of an optional-typed input to the output's optional
// element type; fails if the input is absent, non-optional, or untyped.
inline void propagateElemTypeFromOptionalInputToOutput(InferenceContext& ctx, size_t inputIndex, size_t outputIndex) {
  const auto* in_type = ctx.getInputType(inputIndex);
  if (in_type == nullptr || in_type->value_case() != TypeProto::kOptionalType) {
    fail_type_inference("Input ", inputIndex, " expected to have optional type");
  }
  const auto& opt = in_type->optional_type();
  if (!opt.has_elem_type()) {
    fail_type_inference("Element type of optional input ", inputIndex, " unknown");
  }
  ctx.getOutputType(outputIndex)->mutable_optional_type()->mutable_elem_type()->CopyFrom(opt.elem_type());
}
// Dispatches element-type propagation to the handler matching the input's
// type kind (tensor/sparse-tensor, sequence, or optional).
inline void propagateElemTypeFromInputToOutput(InferenceContext& ctx, size_t inputIndex, size_t outputIndex) {
  const auto* in_type = ctx.getInputType(inputIndex);
  if (in_type == nullptr) {
    fail_type_inference("Input ", inputIndex, " expected to have type but instead is null");
  }
  switch (in_type->value_case()) {
    case TypeProto::kTensorType:
    case TypeProto::kSparseTensorType:
      propagateElemTypeFromTensorInputToOutput(ctx, inputIndex, outputIndex);
      break;
    case TypeProto::kSequenceType:
      propagateElemTypeFromSequenceInputToOutput(ctx, inputIndex, outputIndex);
      break;
    case TypeProto::kOptionalType:
      propagateElemTypeFromOptionalInputToOutput(ctx, inputIndex, outputIndex);
      break;
    default:
      // Any other type kind is deliberately left untouched here.
      break;
  }
}
void propagateElemTypeFromInputToOutput(InferenceContext& ctx, size_t inputIndex, size_t outputIndex);
inline void propagateElemTypeFromDtypeToOutput(
InferenceContext& ctx,
......@@ -438,6 +363,8 @@ inline void propagateShape(const TypeProto* from_type, TypeProto* to_type) {
propagateShape(&from_type->sequence_type().elem_type(), to_type->mutable_sequence_type()->mutable_elem_type());
} else if (TypeProto::kOptionalType == from_type_case) {
propagateShape(&from_type->optional_type().elem_type(), to_type->mutable_optional_type()->mutable_elem_type());
} else if (TypeProto::kMapType == from_type_case) {
propagateShape(&from_type->map_type().value_type(), to_type->mutable_map_type()->mutable_value_type());
} else {
fail_shape_inference("Unsupported Source/Target type=", from_type_case);
}
......@@ -703,54 +630,9 @@ inline void mergeInDimensionInfo(
}
}
/*
Merge shape information from a source shape into a target shape.
* merges each TensorShapeProto_Dimension separately.
* prefer values over params.
* If both have values, values must match.
* prefer target param over source param if mismatched.
* Fail if there are mismatches in number of dimensions or dimension values.
*/
inline void mergeInShapeInfo(const TensorShapeProto& source, TensorShapeProto& target) {
  const int source_rank = source.dim_size();
  const int target_rank = target.dim_size();
  if (source_rank != target_rank) {
    fail_shape_inference(
        "Mismatch between number of source and target dimensions. Source=", source_rank, " Target=", target_rank);
  }
  // Ranks agree; reconcile the shapes one dimension at a time.
  for (int d = 0; d < source_rank; ++d) {
    mergeInDimensionInfo(source.dim(d), *target.mutable_dim(d), d);
  }
}
// Merge a shape into a tensor type: adopt the shape outright when the target
// has none, otherwise reconcile dimension by dimension.
inline void mergeInShapeInfo(const TensorShapeProto& source_shape, TypeProto_Tensor& target_type) {
  if (!target_type.has_shape()) {
    // No prior shape on the target: copy the source shape wholesale.
    *target_type.mutable_shape() = source_shape;
    return;
  }
  mergeInShapeInfo(source_shape, *target_type.mutable_shape());
}
void mergeInShapeInfo(const TensorShapeProto& source_shape, TypeProto_Tensor& target_type);
// Sparse-tensor counterpart: adopt the shape outright when the target has
// none, otherwise reconcile dimension by dimension.
inline void mergeInShapeInfo(const TensorShapeProto& source_shape, TypeProto_SparseTensor& target_type) {
  if (!target_type.has_shape()) {
    // No prior shape on the target: copy the source shape wholesale.
    *target_type.mutable_shape() = source_shape;
    return;
  }
  mergeInShapeInfo(source_shape, *target_type.mutable_shape());
}
void mergeInShapeInfo(const TensorShapeProto& source_shape, TypeProto_SparseTensor& target_type);
/*
Merge the shape information from two TypeProto_Tensor instances.
......@@ -763,15 +645,9 @@ If both have shape information:
- Prefer target param over source param if mismatched.
Fail if there are mismatches in number of dimensions or dimension values.
*/
// Merge tensor-to-tensor type shape info; a source without a shape is a no-op.
inline void mergeInShapeInfo(const TypeProto_Tensor& source, TypeProto_Tensor& target) {
  if (!source.has_shape()) {
    return; // source carries no shape info; leave target untouched
  }
  mergeInShapeInfo(source.shape(), target);
}
void mergeInShapeInfo(const TypeProto_Tensor& source, TypeProto_Tensor& target);
// Sparse-tensor counterpart of the TypeProto_Tensor overload above.
inline void mergeInShapeInfo(const TypeProto_SparseTensor& source, TypeProto_SparseTensor& target) {
  if (!source.has_shape()) {
    return; // source carries no shape info; leave target untouched
  }
  mergeInShapeInfo(source.shape(), target);
}
void mergeInShapeInfo(const TypeProto_SparseTensor& source, TypeProto_SparseTensor& target);
// Return a copy of a type, with a specified dimension removed from its shape.
inline TypeProto RemoveIthDimensionFromShape(const TypeProto& proto, int removed_dim) {
......
......@@ -6,6 +6,7 @@
#include <fstream>
#include <list>
#include "onnx/checker.h"
#include "onnx/defs/data_type_utils.h"
#include "onnx/string_utils.h"
namespace ONNX_NAMESPACE {
......@@ -115,6 +116,15 @@ void checkShapesAndTypes(const TypeProto& inferredType, const TypeProto& existin
checkShapesAndTypes(inferredType.sequence_type().elem_type(), existingType.sequence_type().elem_type());
} else if (inferredTypeCase == TypeProto::kOptionalType && existingTypeCase == TypeProto::kOptionalType) {
checkShapesAndTypes(inferredType.optional_type().elem_type(), existingType.optional_type().elem_type());
} else if (inferredTypeCase == TypeProto::TypeProto::kMapType && existingTypeCase == TypeProto::TypeProto::kMapType) {
if (inferredType.map_type().key_type() != existingType.map_type().key_type()) {
fail_type_inference(
"key type mismatch from MapProto. existing=",
Utils::DataTypeUtils::ToDataTypeString(existingType.map_type().key_type()),
" inferred=",
Utils::DataTypeUtils::ToDataTypeString(inferredType.map_type().key_type()));
}
checkShapesAndTypes(inferredType.map_type().value_type(), existingType.map_type().value_type());
} else {
fail_type_inference("type case unsupported. existing=", existingTypeCase, " inferred=", inferredTypeCase);
}
......@@ -145,6 +155,8 @@ void materializeSymbolicShape(TypeProto* inferredType, SymbolTable& symbolTable)
materializeSymbolicShape(inferredType->mutable_sequence_type()->mutable_elem_type(), symbolTable);
} else if (inferred_val_case == TypeProto::kOptionalType) {
materializeSymbolicShape(inferredType->mutable_optional_type()->mutable_elem_type(), symbolTable);
} else if (inferred_val_case == TypeProto::TypeProto::kMapType) {
materializeSymbolicShape(inferredType->mutable_map_type()->mutable_value_type(), symbolTable);
} else {
fail_shape_inference("type case unsupported for symbolic shape inference. inferred=", inferred_val_case);
}
......@@ -220,6 +232,9 @@ void mergeShapesAndTypes(const TypeProto& inferredType, TypeProto* existingType)
} else if (inferred_val_case == TypeProto::kOptionalType) {
mergeShapesAndTypes(
inferredType.optional_type().elem_type(), existingType->mutable_optional_type()->mutable_elem_type());
} else if (inferred_val_case == TypeProto::kMapType) {
mergeShapesAndTypes(
inferredType.map_type().value_type(), existingType->mutable_map_type()->mutable_value_type());
}
}
......
......@@ -60,6 +60,12 @@ class SymbolTableImpl : public SymbolTable {
case TypeProto::kSequenceType:
AddExistingSymbolicDims(typeProto.sequence_type().elem_type());
break;
case TypeProto::kOptionalType:
AddExistingSymbolicDims(typeProto.optional_type().elem_type());
break;
case TypeProto::kMapType:
AddExistingSymbolicDims(typeProto.map_type().value_type());
break;
default:
break;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册