提交 bb44e908 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2851 Fix some mistakes of InplaceADD vm ops and apply in new character of dynamic format

Merge pull request !2851 from liuwenhao/pr_1949
......@@ -27,11 +27,11 @@ accumulate_n_v2_op_info = TBERegOp("AccumulateNV2") \
.input(0, "x", False, "dynamic", "all") \
.output(0, "y", False, "required", "all") \
.op_pattern("broadcast") \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.F16_None, DataType.F16_None) \
.dtype_format(DataType.F32_None, DataType.F32_None) \
.dtype_format(DataType.I32_None, DataType.I32_None) \
.dtype_format(DataType.I8_None, DataType.I8_None) \
.dtype_format(DataType.U8_None, DataType.U8_None) \
.get_op_info()
......
......@@ -28,10 +28,8 @@ approximate_equal_op_info = TBERegOp("ApproximateEqual") \
.input(0, "x1", False, "required", "all") \
.input(1, "x2", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.BOOL_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.BOOL_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.BOOL_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.BOOL_5HD) \
.dtype_format(DataType.F16_None, DataType.F16_None, DataType.BOOL_None) \
.dtype_format(DataType.F32_None, DataType.F32_None, DataType.BOOL_None) \
.get_op_info()
......
......@@ -28,10 +28,8 @@ binary_cross_entropy_op_info = TBERegOp("BinaryCrossEntropy") \
.input(1, "y", False, "required", "all") \
.input(2, "weight", False, "optional", "all") \
.output(0, "output", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
.op_pattern("dynamicFormat") \
.dtype_format(DataType.None_None, DataType.None_None, DataType.None_None, DataType.None_None) \
.get_op_info()
......
......@@ -29,8 +29,8 @@ lin_space_op_info = TBERegOp("LinSpace") \
.input(2, "stop", False, "required", "all") \
.input(3, "num", False, "required", "all") \
.output(0, "output", False, "required", "all") \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.I32_Default,
DataType.F32_Default,) \
.dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None, DataType.I32_None,
DataType.F32_None,) \
.get_op_info()
......
......@@ -26,16 +26,12 @@ mod_op_info = TBERegOp("Mod") \
.input(0, "x1", False, "required", "all") \
.input(1, "x2", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.I8_5HD, DataType.I8_5HD, DataType.I8_5HD) \
.dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.U8_5HD, DataType.U8_5HD, DataType.U8_5HD) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
.op_pattern("broadcast") \
.dtype_format(DataType.I8_None, DataType.I8_None, DataType.I8_None) \
.dtype_format(DataType.U8_None, DataType.U8_None, DataType.U8_None) \
.dtype_format(DataType.I32_None, DataType.I32_None, DataType.I32_None) \
.dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \
.dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \
.get_op_info()
......
......@@ -27,10 +27,11 @@ reduce_mean_d_op_info = TBERegOp("ReduceMeanD") \
.attr("keep_dims", "optional", "bool", "all") \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.op_pattern("reduce") \
.dtype_format(DataType.I8_None, DataType.I8_None) \
.dtype_format(DataType.U8_None, DataType.U8_None) \
.dtype_format(DataType.F16_None, DataType.F16_None) \
.dtype_format(DataType.F32_None, DataType.F32_None) \
.get_op_info()
......
......@@ -26,8 +26,8 @@ softsign_op_info = TBERegOp("Softsign") \
.op_pattern("formatAgnostic") \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F16_None, DataType.F16_None) \
.dtype_format(DataType.F32_None, DataType.F32_None) \
.get_op_info()
......
......@@ -29,28 +29,7 @@ split_v_op_info = TBERegOp("SplitV") \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "dynamic", "all") \
.op_pattern("dynamicFormat") \
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
.dtype_format(DataType.BOOL_NHWC, DataType.BOOL_NHWC) \
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.I8_NHWC, DataType.I8_NHWC) \
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.U8_NHWC, DataType.U8_NHWC) \
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
.dtype_format(DataType.I16_NHWC, DataType.I16_NHWC) \
.dtype_format(DataType.U16_Default, DataType.U16_Default) \
.dtype_format(DataType.U16_NHWC, DataType.U16_NHWC) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I32_NHWC, DataType.I32_NHWC) \
.dtype_format(DataType.U32_Default, DataType.U32_Default) \
.dtype_format(DataType.U32_NHWC, DataType.U32_NHWC) \
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.I64_NHWC, DataType.I64_NHWC) \
.dtype_format(DataType.U64_Default, DataType.U64_Default) \
.dtype_format(DataType.U64_NHWC, DataType.U64_NHWC) \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_NHWC, DataType.F16_NHWC) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_NHWC, DataType.F32_NHWC) \
.dtype_format(DataType.None_None, DataType.None_None) \
.get_op_info()
......
......@@ -3055,7 +3055,7 @@ class InplaceUpdate(PrimitiveWithInfer):
raise ValueError(f'The value of indices must be in [0, {x_shape[0]}), but got {i}.')
x_rank = len(x_shape)
for idx in range(x_rank)[1:]:
validator.check("x dim %d" % idx, x_shape[idx], 'v dim %d' % idx, v_shape[idx], Rel.EQ, self.name)
validator.check('v dim %d' % idx, v_shape[idx], "x dim %d" % idx, x_shape[idx], Rel.EQ, self.name)
return x_shape
......
......@@ -947,7 +947,7 @@ class InplaceAdd(PrimitiveWithInfer):
raise ValueError(f'The value of indices must be in [0, {x_shape[0]}), but got {i}.')
x_rank = len(x_shape)
for idx in range(x_rank)[1:]:
validator.check("x dim %d" % idx, x_shape[idx], 'v dim %d' % idx, v_shape[idx], Rel.EQ, self.name)
validator.check('v dim %d' % idx, v_shape[idx], "x dim %d" % idx, x_shape[idx], Rel.EQ, self.name)
return x_shape
......@@ -1005,7 +1005,7 @@ class InplaceSub(PrimitiveWithInfer):
raise ValueError(f'The value of indices must be in [0, {x_shape[0]}), but got {i}.')
x_rank = len(x_shape)
for idx in range(x_rank)[1:]:
validator.check("x dim %d" % idx, x_shape[idx], 'v dim %d' % idx, v_shape[idx], Rel.EQ, self.name)
validator.check('v dim %d' % idx, v_shape[idx], "x dim %d" % idx, x_shape[idx], Rel.EQ, self.name)
return x_shape
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册