From efe40e91b7cc76a71b20dec876a3df7fb29b61da Mon Sep 17 00:00:00 2001 From: liuwenhao4 Date: Fri, 19 Jun 2020 15:33:33 +0800 Subject: [PATCH] Fix some mistakes of TransData vm ops --- mindspore/ops/_op_impl/tbe/trans_data.py | 123 ++++++++++++++++++----- 1 file changed, 96 insertions(+), 27 deletions(-) diff --git a/mindspore/ops/_op_impl/tbe/trans_data.py b/mindspore/ops/_op_impl/tbe/trans_data.py index da5ae6e1b..c0cce302c 100644 --- a/mindspore/ops/_op_impl/tbe/trans_data.py +++ b/mindspore/ops/_op_impl/tbe/trans_data.py @@ -23,43 +23,112 @@ trans_data_op_info = TBERegOp("TransData") \ .compute_cost(10) \ .kernel_name("trans_data") \ .partial_flag(True) \ - .attr("src_format", "required", "str", "DefaultFormat,NC1HWC0,FracZ,FRACTAL_NZ,HWCN,C1HWNCoC0")\ - .attr("dst_format", "required", "str", "DefaultFormat,NC1HWC0,FracZ,FRACTAL_NZ,HWCN,C1HWNCoC0")\ + .attr("src_format", "required", "str", "DefaultFormat, NC1HWC0, FracZ, FRACTAL_NZ, HWCN, C1HWNCoC0, NDHWC, NHWC") \ + .attr("dst_format", "required", "str", "DefaultFormat, NC1HWC0, FracZ, FRACTAL_NZ, HWCN, C1HWNCoC0, NDHWC, NHWC") \ .input(0, "src", False, "required", "all") \ .output(0, "dst", False, "required", "all") \ - .dtype_format(DataType.U16_Default, DataType.U16_5HD) \ - .dtype_format(DataType.U16_Default, DataType.U16_FracZ) \ - .dtype_format(DataType.U16_Default, DataType.U16_FracNZ) \ - .dtype_format(DataType.U16_FracZ, DataType.U16_Default) \ - .dtype_format(DataType.U16_FracZ, DataType.U16_HWCN) \ - .dtype_format(DataType.U16_FracNZ, DataType.U16_Default) \ - .dtype_format(DataType.U16_5HD, DataType.U16_Default) \ - .dtype_format(DataType.U16_HWCN, DataType.U16_FracZ) \ - .dtype_format(DataType.U16_HWCN, DataType.U16_C1HWNCoC0) \ - .dtype_format(DataType.U16_C1HWNCoC0, DataType.U16_HWCN) \ - .dtype_format(DataType.BOOL_Default, DataType.BOOL_5HD) \ - .dtype_format(DataType.F16_Default, DataType.F16_5HD) \ + .dtype_format(DataType.F32_NHWC, DataType.F32_5HD) \ + .dtype_format(DataType.F32_Default, DataType.F32_5HD) \ + .dtype_format(DataType.F32_5HD, DataType.F32_NHWC) \ + .dtype_format(DataType.F32_5HD, DataType.F32_Default) \ + .dtype_format(DataType.F32_FracZ, DataType.F32_Default) \ + .dtype_format(DataType.F32_Default, DataType.F32_FracZ) \ + .dtype_format(DataType.F32_HWCN, DataType.F32_FracZ) \ + .dtype_format(DataType.F32_FracZ, DataType.F32_HWCN) \ + .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_HWCN) \ + .dtype_format(DataType.F32_HWCN, DataType.F32_C1HWNCoC0) \ .dtype_format(DataType.F16_Default, DataType.F16_FracZ) \ - .dtype_format(DataType.F16_Default, DataType.F16_FracNZ) \ - .dtype_format(DataType.F16_FracZ, DataType.F16_Default) \ - .dtype_format(DataType.F16_FracZ, DataType.F16_HWCN) \ - .dtype_format(DataType.F16_FracNZ, DataType.F16_Default) \ + .dtype_format(DataType.F16_NHWC, DataType.F16_FracZ) \ + .dtype_format(DataType.F16_HWCN, DataType.F16_FracZ) \ + .dtype_format(DataType.F16_Default, DataType.F16_5HD) \ + .dtype_format(DataType.F16_NHWC, DataType.F16_5HD) \ + .dtype_format(DataType.F16_HWCN, DataType.F16_5HD) \ + .dtype_format(DataType.F16_5HD, DataType.F16_NHWC) \ .dtype_format(DataType.F16_5HD, DataType.F16_Default) \ + .dtype_format(DataType.F16_FracZ, DataType.F16_Default) \ + .dtype_format(DataType.F16_Default, DataType.F16_FracZ) \ .dtype_format(DataType.F16_HWCN, DataType.F16_FracZ) \ - .dtype_format(DataType.F16_HWCN, DataType.F16_C1HWNCoC0) \ + .dtype_format(DataType.F16_FracZ, DataType.F16_HWCN) \ .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_HWCN) \ - .dtype_format(DataType.F32_Default, DataType.F32_5HD) \ - .dtype_format(DataType.F32_Default, DataType.F32_FracZ) \ + .dtype_format(DataType.F16_HWCN, DataType.F16_5HD) \ + .dtype_format(DataType.F16_Default, DataType.F16_FracNZ) \ .dtype_format(DataType.F32_Default, DataType.F32_FracNZ) \ - .dtype_format(DataType.F32_FracZ, DataType.F32_Default) \ - .dtype_format(DataType.F32_FracZ, DataType.F32_HWCN) \ + .dtype_format(DataType.F16_FracNZ, DataType.F16_Default) \ .dtype_format(DataType.F32_FracNZ, DataType.F32_Default) \ - .dtype_format(DataType.F32_5HD, DataType.F32_Default) \ - .dtype_format(DataType.F32_HWCN, DataType.F32_FracZ) \ - .dtype_format(DataType.F32_HWCN, DataType.F32_C1HWNCoC0) \ - .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_HWCN) \ + .dtype_format(DataType.BOOL_NHWC, DataType.BOOL_5HD) \ + .dtype_format(DataType.BOOL_Default, DataType.BOOL_5HD) \ + .dtype_format(DataType.BOOL_5HD, DataType.BOOL_NHWC) \ + .dtype_format(DataType.BOOL_5HD, DataType.BOOL_Default) \ + .dtype_format(DataType.F16_Default, DataType.F16_NHWC) \ + .dtype_format(DataType.F16_Default, DataType.F16_HWCN) \ + .dtype_format(DataType.F16_NHWC, DataType.F16_Default) \ + .dtype_format(DataType.F16_NHWC, DataType.F16_HWCN) \ + .dtype_format(DataType.F16_HWCN, DataType.F16_Default) \ + .dtype_format(DataType.F16_HWCN, DataType.F16_NHWC) \ + .dtype_format(DataType.F32_Default, DataType.F32_NHWC) \ .dtype_format(DataType.F32_Default, DataType.F32_HWCN) \ + .dtype_format(DataType.F32_NHWC, DataType.F32_Default) \ + .dtype_format(DataType.F32_NHWC, DataType.F32_HWCN) \ .dtype_format(DataType.F32_HWCN, DataType.F32_Default) \ + .dtype_format(DataType.F32_HWCN, DataType.F32_NHWC) \ + .dtype_format(DataType.I8_Default, DataType.I8_FracNZ) \ + .dtype_format(DataType.I8_Default, DataType.I8_FracZ) \ + .dtype_format(DataType.I8_Default, DataType.I8_NHWC) \ + .dtype_format(DataType.I8_Default, DataType.I8_HWCN) \ + .dtype_format(DataType.I8_NHWC, DataType.I8_Default) \ + .dtype_format(DataType.I8_NHWC, DataType.I8_HWCN) \ + .dtype_format(DataType.I8_HWCN, DataType.I8_Default) \ + .dtype_format(DataType.I8_HWCN, DataType.I8_NHWC) \ + .dtype_format(DataType.I16_Default, DataType.I16_NHWC) \ + .dtype_format(DataType.I16_Default, DataType.I16_HWCN) \ + .dtype_format(DataType.I16_NHWC, DataType.I16_Default) \ + .dtype_format(DataType.I16_NHWC, DataType.I16_HWCN) \ + .dtype_format(DataType.I16_HWCN, DataType.I16_Default) \ + .dtype_format(DataType.I16_HWCN, DataType.I16_NHWC) \ + .dtype_format(DataType.I32_Default, DataType.I32_NHWC) \ + .dtype_format(DataType.I32_Default, DataType.I32_HWCN) \ + .dtype_format(DataType.I32_NHWC, DataType.I32_Default) \ + .dtype_format(DataType.I32_NHWC, DataType.I32_HWCN) \ + .dtype_format(DataType.I32_HWCN, DataType.I32_Default) \ + .dtype_format(DataType.I32_HWCN, DataType.I32_NHWC) \ + .dtype_format(DataType.I64_Default, DataType.I64_NHWC) \ + .dtype_format(DataType.I64_Default, DataType.I64_HWCN) \ + .dtype_format(DataType.I64_NHWC, DataType.I64_Default) \ + .dtype_format(DataType.I64_NHWC, DataType.I64_HWCN) \ + .dtype_format(DataType.I64_HWCN, DataType.I64_Default) \ + .dtype_format(DataType.I64_HWCN, DataType.I64_NHWC) \ + .dtype_format(DataType.U8_Default, DataType.U8_NHWC) \ + .dtype_format(DataType.U8_Default, DataType.U8_HWCN) \ + .dtype_format(DataType.U8_NHWC, DataType.U8_Default) \ + .dtype_format(DataType.U8_NHWC, DataType.U8_HWCN) \ + .dtype_format(DataType.U8_HWCN, DataType.U8_Default) \ + .dtype_format(DataType.U8_HWCN, DataType.U8_NHWC) \ + .dtype_format(DataType.U16_Default, DataType.U16_NHWC) \ + .dtype_format(DataType.U16_Default, DataType.U16_HWCN) \ + .dtype_format(DataType.U16_NHWC, DataType.U16_Default) \ + .dtype_format(DataType.U16_NHWC, DataType.U16_HWCN) \ + .dtype_format(DataType.U16_HWCN, DataType.U16_Default) \ + .dtype_format(DataType.U16_HWCN, DataType.U16_NHWC) \ + .dtype_format(DataType.U32_Default, DataType.U32_NHWC) \ + .dtype_format(DataType.U32_Default, DataType.U32_HWCN) \ + .dtype_format(DataType.U32_NHWC, DataType.U32_Default) \ + .dtype_format(DataType.U32_NHWC, DataType.U32_HWCN) \ + .dtype_format(DataType.U32_HWCN, DataType.U32_Default) \ + .dtype_format(DataType.U32_HWCN, DataType.U32_NHWC) \ + .dtype_format(DataType.U64_Default, DataType.U64_NHWC) \ + .dtype_format(DataType.U64_Default, DataType.U64_HWCN) \ + .dtype_format(DataType.U64_NHWC, DataType.U64_Default) \ + .dtype_format(DataType.U64_NHWC, DataType.U64_HWCN) \ + .dtype_format(DataType.U64_HWCN, DataType.U64_Default) \ + .dtype_format(DataType.U64_HWCN, DataType.U64_NHWC) \ + .dtype_format(DataType.I32_FracNZ, DataType.I32_Default) \ + .dtype_format(DataType.F16_NDHWC, DataType.F16_5HD) \ + .dtype_format(DataType.F16_5HD, DataType.F16_NDHWC) \ + .dtype_format(DataType.I8_HWCN, DataType.I8_C1HWNCoC0) \ + .dtype_format(DataType.F16_HWCN, DataType.F16_FracZ) \ + .dtype_format(DataType.F16_FracZ, DataType.F16_HWCN) \ + .dtype_format(DataType.F16_HWCN, DataType.F16_FracNZ) \ + .dtype_format(DataType.F32_HWCN, DataType.F16_FracNZ) \ .get_op_info() -- GitLab