未验证 提交 d08b3e95 编写于 作者: A Ashwini Khade 提交者: GitHub

crash fix (#3513)

* crash fix
Signed-off-by: NAshwini Khade <askhade@microsoft.com>

* plus updates
Signed-off-by: NAshwini Khade <askhade@microsoft.com>
上级 04971f7d
......@@ -2859,6 +2859,9 @@ bool BuildContextDependentFunctionBody(
}
auto input_type = ctx.getInputType(0)->tensor_type().elem_type();
bool float_input = input_type == TensorProto_DataType_FLOAT;
auto reduction_attr_proto = ctx.getAttribute("reduction");
std::string reduction_attr =
reduction_attr_proto != nullptr && reduction_attr_proto->has_s() ? reduction_attr_proto->s() : "mean";
std::vector<FunctionBodyHelper::NodeDef> body;
body.push_back(
{{"const_zero"},
......@@ -2898,7 +2901,7 @@ bool BuildContextDependentFunctionBody(
{"loss_NCdd", "const_zero", "const_one", "const_one"}});
if (!ctx.hasInput(2)) {
if (ctx.getAttribute("reduction")->s() == "none") {
if (reduction_attr == "none") {
body.push_back(
{{"loss"},
"Squeeze",
......@@ -2908,7 +2911,7 @@ bool BuildContextDependentFunctionBody(
{{"loss_Ndd"},
"Squeeze",
{"loss_N1dd", "axes"}});
if (ctx.getAttribute("reduction")->s() == "mean") {
if (reduction_attr == "mean") {
body.push_back(
{{"loss"},
"ReduceMean",
......@@ -2928,12 +2931,12 @@ bool BuildContextDependentFunctionBody(
{{"loss_unweighted"},
"Squeeze",
{"loss_N1dd", "axes"}});
if (ctx.getAttribute("reduction")->s() == "none") {
if (reduction_attr == "none") {
body.push_back({{"loss"}, "Mul", {"loss_unweighted", "weight_gather"}});
} else {
body.push_back(
{{"loss_Ndd"}, "Mul", {"loss_unweighted", "weight_gather"}});
if (ctx.getAttribute("reduction")->s() == "mean") {
if (reduction_attr == "mean") {
body.push_back(
{{"loss_sum"},
"ReduceSum",
......@@ -3052,12 +3055,12 @@ bool BuildContextDependentFunctionBody(
{{"loss_unweighted"},
"Squeeze",
{"loss_N1dd", "axes"}});
if (ctx.getAttribute("reduction")->s() == "none") {
if (reduction_attr == "none") {
body.push_back({{"loss"}, "Mul", {"loss_unweighted", "weight_gather"}});
} else {
body.push_back(
{{"loss_Ndd"}, "Mul", {"loss_unweighted", "weight_gather"}});
if (ctx.getAttribute("reduction")->s() == "mean") {
if (reduction_attr == "mean") {
body.push_back(
{{"loss_sum"},
"ReduceSum",
......@@ -3201,7 +3204,7 @@ ONNX_OPERATOR_SET_SCHEMA(
TensorShapeProto* output_shape =
ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape();
if (ctx.getAttribute("reduction")->s() == "none") {
if (getAttribute(ctx, "reduction", "mean") == "none") {
// output tensor is of shape (N, d1, d2, ..., dk) if
// reduction attribute is "none".
for (int i = 0; i < input_rank - 1; i++) {
......
......@@ -1253,6 +1253,9 @@ bool BuildContextDependentFunctionBody_opset12(
}
auto input_type = ctx.getInputType(0)->tensor_type().elem_type();
bool float_input = input_type == TensorProto_DataType_FLOAT;
auto reduction_attr_proto = ctx.getAttribute("reduction");
std::string reduction_attr =
reduction_attr_proto != nullptr && reduction_attr_proto->has_s() ? reduction_attr_proto->s() : "mean";
std::vector<FunctionBodyHelper::NodeDef> body;
body.push_back(
{{"const_zero"},
......@@ -1287,7 +1290,7 @@ bool BuildContextDependentFunctionBody_opset12(
{"loss_NCdd", "const_zero", "const_one", "const_one"}});
if (!ctx.hasInput(2)) {
if (ctx.getAttribute("reduction")->s() == "none") {
if (reduction_attr == "none") {
body.push_back(
{{"loss"},
"Squeeze",
......@@ -1299,7 +1302,7 @@ bool BuildContextDependentFunctionBody_opset12(
"Squeeze",
{"loss_N1dd"},
{MakeAttribute("axes", std::vector<int64_t>({1}))}});
if (ctx.getAttribute("reduction")->s() == "mean") {
if (reduction_attr == "mean") {
body.push_back(
{{"loss"},
"ReduceMean",
......@@ -1320,12 +1323,12 @@ bool BuildContextDependentFunctionBody_opset12(
"Squeeze",
{"loss_N1dd"},
{MakeAttribute("axes", std::vector<int64_t>({1}))}});
if (ctx.getAttribute("reduction")->s() == "none") {
if (reduction_attr == "none") {
body.push_back({{"loss"}, "Mul", {"loss_unweighted", "weight_gather"}});
} else {
body.push_back(
{{"loss_Ndd"}, "Mul", {"loss_unweighted", "weight_gather"}});
if (ctx.getAttribute("reduction")->s() == "mean") {
if (reduction_attr == "mean") {
body.push_back(
{{"loss_sum"},
"ReduceSum",
......@@ -1447,12 +1450,12 @@ bool BuildContextDependentFunctionBody_opset12(
"Squeeze",
{"loss_N1dd"},
{MakeAttribute("axes", std::vector<int64_t>({1}))}});
if (ctx.getAttribute("reduction")->s() == "none") {
if (reduction_attr == "none") {
body.push_back({{"loss"}, "Mul", {"loss_unweighted", "weight_gather"}});
} else {
body.push_back(
{{"loss_Ndd"}, "Mul", {"loss_unweighted", "weight_gather"}});
if (ctx.getAttribute("reduction")->s() == "mean") {
if (reduction_attr == "mean") {
body.push_back(
{{"loss_sum"},
"ReduceSum",
......@@ -1532,64 +1535,62 @@ ONNX_OPERATOR_SET_SCHEMA(
"Constrain target to integer types")
.SetContextDependentFunctionBodyBuilder(
BuildContextDependentFunctionBody_opset12)
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
// Type inference
propagateElemTypeFromInputToOutput(ctx, 0, 0);
// Shape inference
if (hasNInputShapes(ctx, 2)) {
const TensorShapeProto& input_shape =
ctx.getInputType(0)->tensor_type().shape();
const TensorShapeProto& target_shape =
ctx.getInputType(1)->tensor_type().shape();
const int input_rank = static_cast<int>(input_shape.dim_size());
const int target_rank = static_cast<int>(target_shape.dim_size());
if (input_rank < 2) {
fail_shape_inference("Input rank must be >= 2.");
}
if (target_rank != input_rank - 1) {
fail_shape_inference(
"Target rank must be 1 less than the input rank.");
}
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
// Type inference
propagateElemTypeFromInputToOutput(ctx, 0, 0);
// Shape inference
if (hasNInputShapes(ctx, 2)) {
const TensorShapeProto& input_shape =
ctx.getInputType(0)->tensor_type().shape();
const TensorShapeProto& target_shape =
ctx.getInputType(1)->tensor_type().shape();
const int input_rank = static_cast<int>(input_shape.dim_size());
const int target_rank = static_cast<int>(target_shape.dim_size());
if (input_rank < 2) {
fail_shape_inference("Input rank must be >= 2.");
}
if (target_rank != input_rank - 1) {
fail_shape_inference(
"Target rank must be 1 less than the input rank.");
}
// match input dimensions (N, C, d1, ..., dk) with target
// dimensions of (C, d1, ..., dk)
for (int dim = 0; dim < target_rank; dim++) {
const auto input_dim =
dim == 0 ? input_shape.dim(dim) : input_shape.dim(dim + 1);
const auto target_dim = target_shape.dim(dim);
if (input_dim.has_dim_value() && target_dim.has_dim_value() &&
input_dim.dim_value() != target_dim.dim_value())
fail_shape_inference(
"Input and target dimension value mismatch.");
}
// match input dimensions (N, C, d1, ..., dk) with target
// dimensions of (C, d1, ..., dk)
for (int dim = 0; dim < target_rank; dim++) {
const auto input_dim =
dim == 0 ? input_shape.dim(dim) : input_shape.dim(dim + 1);
const auto target_dim = target_shape.dim(dim);
if (input_dim.has_dim_value() && target_dim.has_dim_value() &&
input_dim.dim_value() != target_dim.dim_value())
fail_shape_inference(
"Input and target dimension value mismatch.");
}
if (ctx.getNumInputs() == 3 && hasInputShape(ctx, 2)) {
const TensorShapeProto& weight_shape =
ctx.getInputType(2)->tensor_type().shape();
if (weight_shape.dim_size() != 1) {
fail_shape_inference("Weight rank must be 1.");
}
}
if (ctx.getNumInputs() == 3 && hasInputShape(ctx, 2)) {
const TensorShapeProto& weight_shape =
ctx.getInputType(2)->tensor_type().shape();
if (weight_shape.dim_size() != 1) {
fail_shape_inference("Weight rank must be 1.");
}
}
TensorShapeProto* output_shape =
ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape();
if (ctx.getAttribute("reduction")->s() == "none") {
// output tensor is of shape (N, d1, d2, ..., dk) if
// reduction attribute is "none".
for (int i = 0; i < input_rank - 1; i++) {
auto* dim = output_shape->add_dim();
if (i == 0)
*dim = input_shape.dim(i);
else
*dim = input_shape.dim(i + 1);
}
}
// otherwise output is a scalar.
}
TensorShapeProto* output_shape = ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape();
if (getAttribute(ctx, "reduction", "mean") == "none") {
// output tensor is of shape (N, d1, d2, ..., dk) if
// reduction attribute is "none".
for (int i = 0; i < input_rank - 1; i++) {
auto* dim = output_shape->add_dim();
if (i == 0)
*dim = input_shape.dim(i);
else
*dim = input_shape.dim(i + 1);
}
}
// otherwise output is a scalar.
}
}));
const char* reduction_doc_sce_opset12 =
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册