diff --git a/onnx/defs/math/defs.cc b/onnx/defs/math/defs.cc index 2a5458a02089b4bd3e32d23f17b761f85cf15096..1efc98b9dd51595b72b1025c179410e5b822539e 100644 --- a/onnx/defs/math/defs.cc +++ b/onnx/defs/math/defs.cc @@ -2859,6 +2859,9 @@ bool BuildContextDependentFunctionBody( } auto input_type = ctx.getInputType(0)->tensor_type().elem_type(); bool float_input = input_type == TensorProto_DataType_FLOAT; + auto reduction_attr_proto = ctx.getAttribute("reduction"); + std::string reduction_attr = + reduction_attr_proto != nullptr && reduction_attr_proto->has_s() ? reduction_attr_proto->s() : "mean"; std::vector body; body.push_back( {{"const_zero"}, @@ -2898,7 +2901,7 @@ bool BuildContextDependentFunctionBody( {"loss_NCdd", "const_zero", "const_one", "const_one"}}); if (!ctx.hasInput(2)) { - if (ctx.getAttribute("reduction")->s() == "none") { + if (reduction_attr == "none") { body.push_back( {{"loss"}, "Squeeze", @@ -2908,7 +2911,7 @@ bool BuildContextDependentFunctionBody( {{"loss_Ndd"}, "Squeeze", {"loss_N1dd", "axes"}}); - if (ctx.getAttribute("reduction")->s() == "mean") { + if (reduction_attr == "mean") { body.push_back( {{"loss"}, "ReduceMean", @@ -2928,12 +2931,12 @@ bool BuildContextDependentFunctionBody( {{"loss_unweighted"}, "Squeeze", {"loss_N1dd", "axes"}}); - if (ctx.getAttribute("reduction")->s() == "none") { + if (reduction_attr == "none") { body.push_back({{"loss"}, "Mul", {"loss_unweighted", "weight_gather"}}); } else { body.push_back( {{"loss_Ndd"}, "Mul", {"loss_unweighted", "weight_gather"}}); - if (ctx.getAttribute("reduction")->s() == "mean") { + if (reduction_attr == "mean") { body.push_back( {{"loss_sum"}, "ReduceSum", @@ -3052,12 +3055,12 @@ bool BuildContextDependentFunctionBody( {{"loss_unweighted"}, "Squeeze", {"loss_N1dd", "axes"}}); - if (ctx.getAttribute("reduction")->s() == "none") { + if (reduction_attr == "none") { body.push_back({{"loss"}, "Mul", {"loss_unweighted", "weight_gather"}}); } else { body.push_back( {{"loss_Ndd"}, "Mul", {"loss_unweighted", "weight_gather"}}); - if (ctx.getAttribute("reduction")->s() == "mean") { + if (reduction_attr == "mean") { body.push_back( {{"loss_sum"}, "ReduceSum", @@ -3201,7 +3204,7 @@ ONNX_OPERATOR_SET_SCHEMA( TensorShapeProto* output_shape = ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape(); - if (ctx.getAttribute("reduction")->s() == "none") { + if (getAttribute(ctx, "reduction", "mean") == "none") { // output tensor is of shape (N, d1, d2, ..., dk) if // reduction attribute is "none". for (int i = 0; i < input_rank - 1; i++) { diff --git a/onnx/defs/math/old.cc b/onnx/defs/math/old.cc index 3c8d4ddab8af5087cc369d6a9d8a6aa26e60b23b..90ade0a66f9ba114dece4f9e101effe5432cee65 100644 --- a/onnx/defs/math/old.cc +++ b/onnx/defs/math/old.cc @@ -1253,6 +1253,9 @@ bool BuildContextDependentFunctionBody_opset12( } auto input_type = ctx.getInputType(0)->tensor_type().elem_type(); bool float_input = input_type == TensorProto_DataType_FLOAT; + auto reduction_attr_proto = ctx.getAttribute("reduction"); + std::string reduction_attr = + reduction_attr_proto != nullptr && reduction_attr_proto->has_s() ? reduction_attr_proto->s() : "mean"; std::vector body; body.push_back( {{"const_zero"}, @@ -1287,7 +1290,7 @@ bool BuildContextDependentFunctionBody_opset12( {"loss_NCdd", "const_zero", "const_one", "const_one"}}); if (!ctx.hasInput(2)) { - if (ctx.getAttribute("reduction")->s() == "none") { + if (reduction_attr == "none") { body.push_back( {{"loss"}, "Squeeze", @@ -1299,7 +1302,7 @@ bool BuildContextDependentFunctionBody_opset12( "Squeeze", {"loss_N1dd"}, {MakeAttribute("axes", std::vector({1}))}}); - if (ctx.getAttribute("reduction")->s() == "mean") { + if (reduction_attr == "mean") { body.push_back( {{"loss"}, "ReduceMean", @@ -1320,12 +1323,12 @@ bool BuildContextDependentFunctionBody_opset12( "Squeeze", {"loss_N1dd"}, {MakeAttribute("axes", std::vector({1}))}}); - if (ctx.getAttribute("reduction")->s() == "none") { + if (reduction_attr == "none") { body.push_back({{"loss"}, "Mul", {"loss_unweighted", "weight_gather"}}); } else { body.push_back( {{"loss_Ndd"}, "Mul", {"loss_unweighted", "weight_gather"}}); - if (ctx.getAttribute("reduction")->s() == "mean") { + if (reduction_attr == "mean") { body.push_back( {{"loss_sum"}, "ReduceSum", @@ -1447,12 +1450,12 @@ bool BuildContextDependentFunctionBody_opset12( "Squeeze", {"loss_N1dd"}, {MakeAttribute("axes", std::vector({1}))}}); - if (ctx.getAttribute("reduction")->s() == "none") { + if (reduction_attr == "none") { body.push_back({{"loss"}, "Mul", {"loss_unweighted", "weight_gather"}}); } else { body.push_back( {{"loss_Ndd"}, "Mul", {"loss_unweighted", "weight_gather"}}); - if (ctx.getAttribute("reduction")->s() == "mean") { + if (reduction_attr == "mean") { body.push_back( {{"loss_sum"}, "ReduceSum", @@ -1532,64 +1535,62 @@ ONNX_OPERATOR_SET_SCHEMA( "Constrain target to integer types") .SetContextDependentFunctionBodyBuilder( BuildContextDependentFunctionBody_opset12) - .TypeAndShapeInferenceFunction([](InferenceContext& ctx) { - // Type inference - propagateElemTypeFromInputToOutput(ctx, 0, 0); - - // Shape inference - if (hasNInputShapes(ctx, 2)) { - const TensorShapeProto& input_shape = - ctx.getInputType(0)->tensor_type().shape(); - const TensorShapeProto& target_shape = - ctx.getInputType(1)->tensor_type().shape(); - - const int input_rank = static_cast(input_shape.dim_size()); - const int target_rank = static_cast(target_shape.dim_size()); - - if (input_rank < 2) { - fail_shape_inference("Input rank must be >= 2."); - } - if (target_rank != input_rank - 1) { - fail_shape_inference( - "Target rank must be 1 less than the input rank."); - } + .TypeAndShapeInferenceFunction([](InferenceContext& ctx) { + // Type inference + propagateElemTypeFromInputToOutput(ctx, 0, 0); + + // Shape inference + if (hasNInputShapes(ctx, 2)) { + const TensorShapeProto& input_shape = + ctx.getInputType(0)->tensor_type().shape(); + const TensorShapeProto& target_shape = + ctx.getInputType(1)->tensor_type().shape(); + + const int input_rank = static_cast(input_shape.dim_size()); + const int target_rank = static_cast(target_shape.dim_size()); + + if (input_rank < 2) { + fail_shape_inference("Input rank must be >= 2."); + } + if (target_rank != input_rank - 1) { + fail_shape_inference( + "Target rank must be 1 less than the input rank."); + } - // match input dimensions (N, C, d1, ..., dk) with target - // dimensions of (C, d1, ..., dk) - for (int dim = 0; dim < target_rank; dim++) { - const auto input_dim = - dim == 0 ? input_shape.dim(dim) : input_shape.dim(dim + 1); - const auto target_dim = target_shape.dim(dim); - if (input_dim.has_dim_value() && target_dim.has_dim_value() && - input_dim.dim_value() != target_dim.dim_value()) - fail_shape_inference( - "Input and target dimension value mismatch."); - } + // match input dimensions (N, C, d1, ..., dk) with target + // dimensions of (C, d1, ..., dk) + for (int dim = 0; dim < target_rank; dim++) { + const auto input_dim = + dim == 0 ? input_shape.dim(dim) : input_shape.dim(dim + 1); + const auto target_dim = target_shape.dim(dim); + if (input_dim.has_dim_value() && target_dim.has_dim_value() && + input_dim.dim_value() != target_dim.dim_value()) + fail_shape_inference( + "Input and target dimension value mismatch."); + } - if (ctx.getNumInputs() == 3 && hasInputShape(ctx, 2)) { - const TensorShapeProto& weight_shape = - ctx.getInputType(2)->tensor_type().shape(); - if (weight_shape.dim_size() != 1) { - fail_shape_inference("Weight rank must be 1."); - } - } + if (ctx.getNumInputs() == 3 && hasInputShape(ctx, 2)) { + const TensorShapeProto& weight_shape = + ctx.getInputType(2)->tensor_type().shape(); + if (weight_shape.dim_size() != 1) { + fail_shape_inference("Weight rank must be 1."); + } + } - TensorShapeProto* output_shape = - ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape(); - - if (ctx.getAttribute("reduction")->s() == "none") { - // output tensor is of shape (N, d1, d2, ..., dk) if - // reduction attribute is "none". - for (int i = 0; i < input_rank - 1; i++) { - auto* dim = output_shape->add_dim(); - if (i == 0) - *dim = input_shape.dim(i); - else - *dim = input_shape.dim(i + 1); - } - } - // otherwise output is a scalar. - } + TensorShapeProto* output_shape = ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape(); + if (getAttribute(ctx, "reduction", "mean") == "none") { + // output tensor is of shape (N, d1, d2, ..., dk) if + // reduction attribute is "none". + for (int i = 0; i < input_rank - 1; i++) { + auto* dim = output_shape->add_dim(); + if (i == 0) + *dim = input_shape.dim(i); + else + *dim = input_shape.dim(i + 1); + } + } + // otherwise output is a scalar. + } })); const char* reduction_doc_sce_opset12 =