From f70429d671a7e80c324cbbce3c874e1fe6db5fc2 Mon Sep 17 00:00:00 2001 From: liubuyu Date: Mon, 18 May 2020 09:00:14 +0800 Subject: [PATCH] add format transfer --- mindspore/ccsrc/common/trans.cc | 188 ++++++++++++------ mindspore/ccsrc/common/trans.h | 2 + .../device/ascend/ascend_device_address.cc | 25 ++- tests/ut/cpp/common/trans_test.cc | 113 +++++++++++ 4 files changed, 262 insertions(+), 66 deletions(-) create mode 100644 tests/ut/cpp/common/trans_test.cc diff --git a/mindspore/ccsrc/common/trans.cc b/mindspore/ccsrc/common/trans.cc index e53b8bdf2..55e476103 100644 --- a/mindspore/ccsrc/common/trans.cc +++ b/mindspore/ccsrc/common/trans.cc @@ -63,26 +63,24 @@ const std::map type_map = {{kNumberTypeBool, 1}, {kNumberType {kNumberTypeUInt32, 4}, {kNumberTypeUInt64, 8}, {kNumberTypeFloat, 4}, {kNumberTypeFloat16, 2}, {kNumberTypeFloat32, 4}, {kNumberTypeFloat64, 8}}; -#define SetDataBysize(size, pad_zero) \ - do { \ - switch (size) { \ - case 1: \ - static_cast(result)[dst_idx] = pad_zero ? 0 : static_cast(args.data)[src_idx]; \ - break; \ - case 2: \ - static_cast(result)[dst_idx] = pad_zero ? 0 : static_cast(args.data)[src_idx]; \ - break; \ - case 4: \ - static_cast(result)[dst_idx] = pad_zero ? 0 : static_cast(args.data)[src_idx]; \ - break; \ - case 8: \ - static_cast(result)[dst_idx] = pad_zero ? 0 : static_cast(args.data)[src_idx]; \ - break; \ - default: \ - MS_LOG(ERROR) << "Trans data not support size " << size; \ - return false; \ - } \ - } while (0) +inline void SetData(size_t size, bool pad_zero, size_t src_idx, size_t dst_idx, const FormatArgs &args, void *result) { + switch (size) { + case 1: + static_cast(result)[dst_idx] = pad_zero ? 0 : static_cast(args.data)[src_idx]; + break; + case 2: + static_cast(result)[dst_idx] = pad_zero ? 0 : static_cast(args.data)[src_idx]; + break; + case 4: + static_cast(result)[dst_idx] = pad_zero ? 0 : static_cast(args.data)[src_idx]; + break; + case 8: + static_cast(result)[dst_idx] = pad_zero ? 0 : static_cast(args.data)[src_idx]; + break; + default: + MS_LOG(EXCEPTION) << "Trans data not support size " << size; + } +} template T DivCeil(T n1, T n2) { @@ -401,6 +399,13 @@ std::vector Nc1hwc04DeviceShape(const std::vector &shape) { device_shape.push_back(C0); return device_shape; } + +std::vector NdhwcDeviceShape(const std::vector &shape) { + if (shape.size() < 5) { + MS_LOG(EXCEPTION) << "Shape dims must be 5 when format is ndhwc."; + } + return shape; +} } // namespace std::vector TransShapeToDevice(const std::vector &shape, const std::string &format) { @@ -412,7 +417,8 @@ std::vector TransShapeToDevice(const std::vector &shape, const s {kOpFormat_NC1HWC0, Nc1hwc0DeviceShape}, {kOpFormat_C1HWNCoC0, C1hwncoc0DeviceShape}, {kOpFormat_FRACTAL_Z_C04, FracZc04DeviceShape}, - {kOpFormat_NC1HWC0_C04, Nc1hwc04DeviceShape}}; + {kOpFormat_NC1HWC0_C04, Nc1hwc04DeviceShape}, + {kOpFormat_NDHWC, NdhwcDeviceShape}}; if (format == kOpFormat_ND || format == kOpFormat_DEFAULT) { return shape; @@ -482,43 +488,109 @@ bool TransDataType(const TypeIdArgs &args, void *result) { } bool TransFormat(const FormatArgs &args, void *result) { + using FormatTransfer = std::function; + const std::map format_trans_map{ + {kOpFormat_FRAC_Z, NchwToFracZ}, {kOpFormat_FRAC_NZ, NchwToFracNz}, + {kOpFormat_NC1HWC0, NchwToNc1hwc0}, {kOpFormat_C1HWNCoC0, NchwToC1hwncoc0}, + {kOpFormat_FRACTAL_Z_C04, NchwToFracZc04}, {kOpFormat_NC1HWC0_C04, NchwToNc1hwc04}}; MS_LOG(DEBUG) << "Start trans format."; if (TypeIdSize(args.src_data_type) < 1) { MS_LOG(ERROR) << "Invalid datatype.."; return false; } - if (args.device_format == kOpFormat_FRAC_Z) { - return NchwToFracZ(args, result); - } else if (args.device_format == kOpFormat_FRAC_NZ) { - return NchwToFracNz(args, result); - } else if (args.device_format == kOpFormat_NC1HWC0) { - return NchwToNc1hwc0(args, result); - } else if (args.device_format == kOpFormat_C1HWNCoC0) { - return NchwToC1hwncoc0(args, result); - } else if (args.device_format == kOpFormat_FRACTAL_Z_C04) { - return NchwToFracZc04(args, result); - } else if (args.device_format == kOpFormat_NC1HWC0_C04) { - return NchwToNc1hwc04(args, result); + if (args.device_format == kOpFormat_HWCN || args.device_format == kOpFormat_NHWC) { + return NchwTo4D(args, result); } - return true; + auto iter = format_trans_map.find(args.device_format); + if (iter == format_trans_map.end()) { + MS_LOG(EXCEPTION) << "Unexpected format[" << args.device_format << "]"; + } + return iter->second(args, result); } bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result) { + using FormatTransfer = std::function; + const std::map format_trans_map{{kOpFormat_FRAC_Z, FracZToNchw}, + {kOpFormat_FRAC_NZ, FracNzToNchw}, + {kOpFormat_NC1HWC0, Nc1hwc0ToNchw}, + {kOpFormat_C1HWNCoC0, C1hwncoc0ToNchw}, + {kOpFormat_NC1HWC0_C04, Nc1hwc04ToNchw}}; MS_LOG(DEBUG) << "Start trans format."; if (TypeIdSize(args.src_data_type) < 1) { MS_LOG(ERROR) << "Invalid datatype.."; return false; } - if (args.device_format == kOpFormat_FRAC_Z) { - return FracZToNchw(args, result); - } else if (args.device_format == kOpFormat_FRAC_NZ) { - return FracNzToNchw(args, result); - } else if (args.device_format == kOpFormat_NC1HWC0) { - return Nc1hwc0ToNchw(args, result); - } else if (args.device_format == kOpFormat_C1HWNCoC0) { - return C1hwncoc0ToNchw(args, result); - } else if (args.device_format == kOpFormat_NC1HWC0_C04) { - return Nc1hwc04ToNchw(args, result); + if (args.device_format == kOpFormat_HWCN || args.device_format == kOpFormat_NHWC) { + return ToNchw(args, result); + } + auto iter = format_trans_map.find(args.device_format); + if (iter == format_trans_map.end()) { + MS_LOG(EXCEPTION) << "Unexpected format[" << args.device_format << "]"; + } + return iter->second(args, result); +} + +bool NchwTo4D(const FormatArgs &args, void *result) { + // trans nchw to 4d + MS_LOG(DEBUG) << "Trans format from nchw to 4d."; + MS_EXCEPTION_IF_NULL(result); + size_t size = 0; + size_t total_size = 0; + if (!CheckArgs(args, &size, &total_size)) { + MS_LOG(ERROR) << "Check args failed."; + return false; + } + size_t n = args.host_shape[0]; + size_t c = args.host_shape[1]; + size_t h = args.host_shape[2]; + size_t w = args.host_shape[3]; + for (size_t ni = 0; ni < n; ni++) { + for (size_t ci = 0; ci < c; ci++) { + for (size_t hi = 0; hi < h; hi++) { + for (size_t wi = 0; wi < w; wi++) { + auto src_idx = ni * c * h * w + ci * h * w + hi * w + wi; + auto dst_idx = 0; + if (args.device_format == kOpFormat_NHWC) { + dst_idx = ni * h * w * c + hi * w * c + wi * c + ci; + } else if (args.device_format == kOpFormat_HWCN) { + dst_idx = hi * w * c * n + wi * c * n + ci * n + ni; + } + SetData(size, false, src_idx, dst_idx, args, result); + } + } + } + } + return true; +} + +bool ToNchw(const FormatArgs &args, void *result) { + MS_LOG(DEBUG) << "Trans format to nchw from 4d."; + MS_EXCEPTION_IF_NULL(result); + size_t size = 0; + size_t total_size = 0; + if (!CheckArgs(args, &size, &total_size)) { + MS_LOG(ERROR) << "Check args failed."; + return false; + } + size_t n = args.host_shape[0]; + size_t c = args.host_shape[1]; + size_t h = args.host_shape[2]; + size_t w = args.host_shape[3]; + for (size_t ni = 0; ni < n; ni++) { + for (size_t ci = 0; ci < c; ci++) { + for (size_t hi = 0; hi < h; hi++) { + for (size_t wi = 0; wi < w; wi++) { + auto dst_idx = ni * c * h * w + ci * h * w + hi * w + wi; + auto src_idx = 0; + if (args.device_format == kOpFormat_NHWC) { + src_idx = ni * h * w * c + hi * w * c + wi * c + ci; + } else if (args.device_format == kOpFormat_HWCN) { + src_idx = hi * w * c * n + wi * c * n + ci * n + ni; + } + SetData(size, false, src_idx, dst_idx, args, result); + } + } + } } return true; } @@ -575,8 +647,8 @@ bool NchwToFracZ(const FormatArgs &args, void *result) { auto src_ni = hfi * kCubeSize + col; auto src_idx = src_row_offset + chw * col; auto dst_idx = gfi * fractal_ele_cnt + col * c0 + row; - auto pad_zero = (src_ni >= n || src_idx >= nchw || src_ci >= c) ? 1 : 0; - SetDataBysize(size, pad_zero); + auto pad_zero = (src_ni >= n || src_idx >= nchw || src_ci >= c) ? true : false; + SetData(size, pad_zero, src_idx, dst_idx, args, result); } } } @@ -630,7 +702,7 @@ bool FracZToNchw(const FormatArgs &args, void *result) { size_t c0_idx = c_idx % c0; size_t nc_idx = n_idx; size_t src_idx = c1_idx * hwncc0 + h_idx * wncc0 + w_idx * ncc0 + nc_idx * c0 + c0_idx; - SetDataBysize(size, 0); + SetData(size, false, src_idx, dst_idx, args, result); } } } @@ -679,7 +751,7 @@ bool NchwToFracZc04(const FormatArgs &args, void *result) { auto c_idx = desc_c1 * c0 + desc_c0; auto src_idx = desc_g * nhwc + desc_n * hwc + c_idx * h * w + desc_h * w + desc_w; auto pad_zero = desc_g >= 1 || desc_n >= n || c_idx >= c; - SetDataBysize(size, pad_zero); + SetData(size, pad_zero, src_idx, dst_idx, args, result); dst_idx++; } } @@ -773,7 +845,7 @@ bool NchwToFracNz(const FormatArgs &args, void *result) { for (size_t i = 0; i < w0; ++i) { size_t src_idx = src_h_head + w1_idx * w0 + i; size_t dst_idx = h1h0_head + w1_idx * h1h0w0 + i; - SetDataBysize(size, 0); + SetData(size, false, src_idx, dst_idx, args, result); } } auto w1_head = num_w1 * w0; @@ -781,7 +853,7 @@ bool NchwToFracNz(const FormatArgs &args, void *result) { auto src_w_idx = w1_head + w0_idx; size_t dst_idx = h1h0_head + num_w1 * h1h0w0 + w0_idx; size_t src_idx = src_h_head + src_w_idx; - SetDataBysize(size, 0); + SetData(size, false, src_idx, dst_idx, args, result); } } } @@ -835,7 +907,7 @@ bool FracNzToNchw(const FormatArgs &args, void *result) { for (size_t i = 0; i < w0; ++i) { size_t src_idx = h1h0_head + w1_idx * h1h0w0 + i; size_t dst_idx = src_h_head + w1_idx * w0 + i; - SetDataBysize(size, 0); + SetData(size, false, src_idx, dst_idx, args, result); } } auto w1_head = num_w1 * w0; @@ -843,7 +915,7 @@ bool FracNzToNchw(const FormatArgs &args, void *result) { auto src_w_idx = w1_head + w0_idx; size_t src_idx = h1h0_head + num_w1 * h1h0w0 + w0_idx; size_t dst_idx = src_h_head + src_w_idx; - SetDataBysize(size, 0); + SetData(size, false, src_idx, dst_idx, args, result); } } } @@ -895,8 +967,8 @@ bool NchwToNc1hwc0(const FormatArgs &args, void *result) { size_t dst_idx = c0_idx + w_head_addr; size_t c_idx = c0_idx + c1_idx * c0; size_t src_idx = n_idx * chw + c_idx * hw + h_idx * w + w_idx; - auto pad_zero = (c_idx < c) ? 0 : 1; - SetDataBysize(size, pad_zero); + auto pad_zero = (c_idx < c) ? false : true; + SetData(size, pad_zero, src_idx, dst_idx, args, result); } } } @@ -947,7 +1019,7 @@ bool Nc1hwc0ToNchw(const FormatArgs &args, void *result) { size_t c1_idx = c_idx / c0; size_t c0_idx = c_idx % c0; size_t src_idx = n_idx * c1hwc0 + c1_idx * hwc0 + h_idx * wc0 + w_idx * c0 + c0_idx; - SetDataBysize(size, 0); + SetData(size, false, src_idx, dst_idx, args, result); } } } @@ -983,8 +1055,8 @@ bool NchwToC1hwncoc0(const FormatArgs &args, void *result) { co_i * c0 + c0_i; size_t c_i = c0_i + c1_i * c0; size_t src_idx = n_i * c * h * w + c_i * h * w + h_i * w + w_i; - auto pad_zero = (c_i < c && c0_i == co_i) ? 0 : 1; - SetDataBysize(size, pad_zero); + auto pad_zero = (c_i < c && c0_i == co_i) ? false : true; + SetData(size, pad_zero, src_idx, dst_idx, args, result); } } } @@ -1020,7 +1092,7 @@ bool C1hwncoc0ToNchw(const FormatArgs &args, void *result) { size_t co_i = c0_i; size_t src_idx = c1_i * h * w * n * co * c0 + h_i * w * n * co * c0 + w_i * n * co * c0 + n_i * co * c0 + co_i * c0 + c0_i; - SetDataBysize(size, 0); + SetData(size, false, src_idx, dst_idx, args, result); } } } diff --git a/mindspore/ccsrc/common/trans.h b/mindspore/ccsrc/common/trans.h index e15b95e6d..a8fc7c8a0 100644 --- a/mindspore/ccsrc/common/trans.h +++ b/mindspore/ccsrc/common/trans.h @@ -61,6 +61,7 @@ bool TransFormat(const FormatArgs &args, void *result); bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result); // host to device +bool NchwTo4D(const FormatArgs &args, void *result); bool NchwToFracZ(const FormatArgs &args, void *result); bool NchwToFracNz(const FormatArgs &args, void *result); bool NchwToNc1hwc0(const FormatArgs &args, void *result); @@ -68,6 +69,7 @@ bool NchwToFracZc04(const FormatArgs &args, void *result); bool NchwToNc1hwc04(const FormatArgs &args, void *result); bool NchwToC1hwncoc0(const FormatArgs &args, void *result); // device to host +bool ToNchw(const FormatArgs &args, void *result); bool FracZToNchw(const FormatArgs &args, void *result); bool FracNzToNchw(const FormatArgs &args, void *result); bool Nc1hwc0ToNchw(const FormatArgs &args, void *result); diff --git a/mindspore/ccsrc/device/ascend/ascend_device_address.cc b/mindspore/ccsrc/device/ascend/ascend_device_address.cc index 1f452ce9e..40a3eec71 100644 --- a/mindspore/ccsrc/device/ascend/ascend_device_address.cc +++ b/mindspore/ccsrc/device/ascend/ascend_device_address.cc @@ -16,6 +16,7 @@ #include "device/ascend/ascend_device_address.h" #include #include +#include #include #include "runtime/mem.h" #include "device/kernel_runtime_manager.h" @@ -34,6 +35,10 @@ namespace device { namespace ascend { const int FLOAT_LEN = sizeof(float); const int FLOAT16_LEN = 2; // sizeof(float16); +const std::set kOpNeedTransFormat = {kOpFormat_NHWC, kOpFormat_HWCN, kOpFormat_NC1HWC0, + kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0, kOpFormat_FRAC_NZ, + kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04}; + void SyncMemory(void *dst, const void *src, uint64_t size, rtMemcpyKind_t kind) { auto ret_rt_memcpy = rtMemcpy(dst, size, src, size, kind); if (ret_rt_memcpy != RT_ERROR_NONE) { @@ -97,7 +102,7 @@ bool AscendDeviceAddress::SyncDeviceToHost(const std::vector &shape, size_t if (host_shape.empty()) { host_shape.emplace_back(1); } - if (format_ == kOpFormat_NCHW || format_ == kOpFormat_DEFAULT) { + if (format_ == kOpFormat_NCHW || format_ == kOpFormat_DEFAULT || format_ == kOpFormat_NDHWC) { if (type_id_ == type) { SyncMemory(host_ptr, ptr_, size, RT_MEMCPY_DEVICE_TO_HOST); sync_ok = true; @@ -115,9 +120,11 @@ bool AscendDeviceAddress::SyncDeviceToHost(const std::vector &shape, size_t } } } else { - auto iter = kNeedTransFormatSet.find(format_); - if (iter != kNeedTransFormatSet.end()) { + auto iter = kOpNeedTransFormat.find(format_); + if (iter != kOpNeedTransFormat.end()) { sync_ok = SyncDeviceToHostAndConvertFormat(shape, size, type, host_ptr); + } else { + MS_LOG(INFO) << "Can not find format transfer for :" << format_; } } if (!sync_ok) { @@ -141,7 +148,7 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormat(const std::vector &shape, size_t if (host_shape.empty()) { host_shape.emplace_back(1); } - if (format_ == kOpFormat_NCHW || format_ == kOpFormat_DEFAULT) { + if (format_ == kOpFormat_NCHW || format_ == kOpFormat_DEFAULT || format_ == kOpFormat_NDHWC) { if (type_id_ == type) { SyncMemory(ptr_, host_ptr, size_, RT_MEMCPY_HOST_TO_DEVICE); sync_ok = true; @@ -203,9 +210,11 @@ bool AscendDeviceAddress::SyncHostToDevice(const std::vector &shape, size_t SyncMemory(ptr_, host_tmp.data(), size_, RT_MEMCPY_HOST_TO_DEVICE); } } else { - auto iter = kNeedTransFormatSet.find(format_); - if (iter != kNeedTransFormatSet.end()) { + auto iter = kOpNeedTransFormat.find(format_); + if (iter != kOpNeedTransFormat.end()) { sync_ok = ConvertFormatAndSyncHostToDevice(shape, size, type, host_ptr); + } else { + MS_LOG(INFO) << "Can not find format transfer for :" << format_; } } if (!sync_ok) { @@ -227,7 +236,7 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const std::vector device_shape; - if (format_ == kOpFormat_FRAC_NZ) { + if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NDHWC) { device_shape = trans::TransShapeToDevice(host_shape, format_); } else { host_shape = trans::PaddingShapeTo4d(host_shape); diff --git a/tests/ut/cpp/common/trans_test.cc b/tests/ut/cpp/common/trans_test.cc new file mode 100644 index 000000000..559933fd6 --- /dev/null +++ b/tests/ut/cpp/common/trans_test.cc @@ -0,0 +1,113 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "common/common_test.h" +#include "common/trans.h" +#include "utils/utils.h" + +using namespace std; +namespace mindspore { +namespace trans { +class FormatTransTest : public UT::Common { + public: + FormatTransTest() = default; + void SetUp() override {} + void TearDown() override {} +}; + +TEST_F(FormatTransTest, nchw_to_hwcn) { + uint16_t data[2*2*2*2] = {12581,14220,14937,14302, + 15004,14951,14694,14564, + 14069,14554,10507,14787, + 13016,15263,14872,10838}; + uint16_t res[2*2*2*2] = {12581,14069,15004,13016, + 14220,14554,14951,15263, + 14937,10507,14694,14872, + 14302,14787,14564,10838}; + size_t device_size = 32; + auto trans_tmp = std::vector(device_size); + FormatArgs format_args{data, device_size, kOpFormat_NCHW, kOpFormat_HWCN, + {2, 2, 2, 2}, {2, 2, 2, 2}, kNumberTypeFloat16}; + EXPECT_EQ(trans::TransFormat(format_args, trans_tmp.data()), true); + for (size_t i = 0; i < sizeof(res) / sizeof(res[0]); i++) { + EXPECT_EQ((reinterpret_cast(trans_tmp.data()))[i], res[i]); + } +} + +TEST_F(FormatTransTest, hwcn_to_nchw) { + uint16_t data[2*2*2*2] = {12581,14069,15004,13016, + 14220,14554,14951,15263, + 14937,10507,14694,14872, + 14302,14787,14564,10838}; + + uint16_t res[2*2*2*2] = {12581,14220,14937,14302, + 15004,14951,14694,14564, + 14069,14554,10507,14787, + 13016,15263,14872,10838}; + + size_t device_size = 32; + auto trans_tmp = std::vector(device_size); + FormatArgs format_args{data, device_size, kOpFormat_NCHW, kOpFormat_HWCN, + {2, 2, 2, 2}, {2, 2, 2, 2}, kNumberTypeFloat16}; + EXPECT_EQ(trans::TransFormatFromDeviceToHost(format_args, trans_tmp.data()), true); + for (size_t i = 0; i < sizeof(res) / sizeof(res[0]); i++) { + EXPECT_EQ((reinterpret_cast(trans_tmp.data()))[i], res[i]); + } +} + +TEST_F(FormatTransTest, nchw_to_nhwc) { + uint16_t data[2*2*2*2] = {11750,13778,15007,15321, + 15163,13446,15063,14467, + 15056,13284,15219,14797, + 12684,14288,14855,14799}; + uint16_t res[2*2*2*2] = {11750,15163,13778,13446, + 15007,15063,15321,14467, + 15056,12684,13284,14288, + 15219,14855,14797,14799}; + size_t device_size = 32; + auto trans_tmp = std::vector(device_size); + FormatArgs format_args{data, device_size, kOpFormat_NCHW, kOpFormat_NHWC, + {2, 2, 2, 2}, {2, 2, 2, 2}, kNumberTypeFloat16}; + EXPECT_EQ(trans::TransFormat(format_args, trans_tmp.data()), true); + for (size_t i = 0; i < sizeof(res) / sizeof(res[0]); i++) { + EXPECT_EQ((reinterpret_cast(trans_tmp.data()))[i], res[i]); + } +} +TEST_F(FormatTransTest, nhwc_to_nchw) { + uint16_t data[2*2*2*2] = {11750,15163,13778,13446, + 15007,15063,15321,14467, + 15056,12684,13284,14288, + 15219,14855,14797,14799}; + uint16_t res[2*2*2*2] = {11750,13778,15007,15321, + 15163,13446,15063,14467, + 15056,13284,15219,14797, + 12684,14288,14855,14799}; + + size_t device_size = 32; + auto trans_tmp = std::vector(device_size); + FormatArgs format_args{data, device_size, kOpFormat_NCHW, kOpFormat_NHWC, + {2, 2, 2, 2}, {2, 2, 2, 2}, kNumberTypeFloat16}; + EXPECT_EQ(trans::TransFormatFromDeviceToHost(format_args, trans_tmp.data()), true); + for (size_t i = 0; i < sizeof(res) / sizeof(res[0]); i++) { + EXPECT_EQ((reinterpret_cast(trans_tmp.data()))[i], res[i]); + } +} +} // namespace trans +} // namespace mindspore + + + -- GitLab