diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 0dc2a667dc3a754c947f58114ede30ed885dedcf..d2b6f9a39790379f878301de3b4b4ce8c7c041ed 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -678,6 +678,13 @@ class FusedBatchNormEx(PrimitiveWithInfer): >>> op = P.FusedBatchNormEx() >>> output = op(input_x, scale, bias, mean, variance) """ + __mindspore_signature__ = ( + ('input_x', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T2), + ('scale', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), + ('bias', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), + ('mean', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), + ('variance', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), + ) @prim_attr_register def __init__(self, mode=0, epsilon=1e-5, momentum=0.1):