提交 828d0b12 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1209 add format trans function

Merge pull request !1209 from liubuyu/master
...@@ -63,26 +63,24 @@ const std::map<TypeId, size_t> type_map = {{kNumberTypeBool, 1}, {kNumberType ...@@ -63,26 +63,24 @@ const std::map<TypeId, size_t> type_map = {{kNumberTypeBool, 1}, {kNumberType
{kNumberTypeUInt32, 4}, {kNumberTypeUInt64, 8}, {kNumberTypeFloat, 4}, {kNumberTypeUInt32, 4}, {kNumberTypeUInt64, 8}, {kNumberTypeFloat, 4},
{kNumberTypeFloat16, 2}, {kNumberTypeFloat32, 4}, {kNumberTypeFloat64, 8}}; {kNumberTypeFloat16, 2}, {kNumberTypeFloat32, 4}, {kNumberTypeFloat64, 8}};
#define SetDataBysize(size, pad_zero) \ inline void SetData(size_t size, bool pad_zero, size_t src_idx, size_t dst_idx, const FormatArgs &args, void *result) {
do { \ switch (size) {
switch (size) { \ case 1:
case 1: \ static_cast<uint8_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint8_t *>(args.data)[src_idx];
static_cast<uint8_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint8_t *>(args.data)[src_idx]; \ break;
break; \ case 2:
case 2: \ static_cast<uint16_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint16_t *>(args.data)[src_idx];
static_cast<uint16_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint16_t *>(args.data)[src_idx]; \ break;
break; \ case 4:
case 4: \ static_cast<uint32_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint32_t *>(args.data)[src_idx];
static_cast<uint32_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint32_t *>(args.data)[src_idx]; \ break;
break; \ case 8:
case 8: \ static_cast<uint64_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint64_t *>(args.data)[src_idx];
static_cast<uint64_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint64_t *>(args.data)[src_idx]; \ break;
break; \ default:
default: \ MS_LOG(EXCEPTION) << "Trans data not support size " << size;
MS_LOG(ERROR) << "Trans data not support size " << size; \ }
return false; \ }
} \
} while (0)
template <typename T> template <typename T>
T DivCeil(T n1, T n2) { T DivCeil(T n1, T n2) {
...@@ -401,6 +399,13 @@ std::vector<size_t> Nc1hwc04DeviceShape(const std::vector<size_t> &shape) { ...@@ -401,6 +399,13 @@ std::vector<size_t> Nc1hwc04DeviceShape(const std::vector<size_t> &shape) {
device_shape.push_back(C0); device_shape.push_back(C0);
return device_shape; return device_shape;
} }
// Device shape for the NDHWC format: the host shape is handed back unchanged.
// A shape with fewer than five dimensions is reported through MS_LOG(EXCEPTION).
// NOTE(review): the message says the rank "must be 5" while the check only rejects
// ranks below 5 — confirm whether higher ranks are meant to be accepted.
std::vector<size_t> NdhwcDeviceShape(const std::vector<size_t> &shape) {
  const bool rank_is_sufficient = shape.size() >= 5;
  if (!rank_is_sufficient) {
    MS_LOG(EXCEPTION) << "Shape dims must be 5 when format is ndhwc.";
  }
  return shape;
}
} // namespace } // namespace
std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format) { std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format) {
...@@ -412,7 +417,8 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s ...@@ -412,7 +417,8 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s
{kOpFormat_NC1HWC0, Nc1hwc0DeviceShape}, {kOpFormat_NC1HWC0, Nc1hwc0DeviceShape},
{kOpFormat_C1HWNCoC0, C1hwncoc0DeviceShape}, {kOpFormat_C1HWNCoC0, C1hwncoc0DeviceShape},
{kOpFormat_FRACTAL_Z_C04, FracZc04DeviceShape}, {kOpFormat_FRACTAL_Z_C04, FracZc04DeviceShape},
{kOpFormat_NC1HWC0_C04, Nc1hwc04DeviceShape}}; {kOpFormat_NC1HWC0_C04, Nc1hwc04DeviceShape},
{kOpFormat_NDHWC, NdhwcDeviceShape}};
if (format == kOpFormat_ND || format == kOpFormat_DEFAULT) { if (format == kOpFormat_ND || format == kOpFormat_DEFAULT) {
return shape; return shape;
...@@ -482,43 +488,109 @@ bool TransDataType(const TypeIdArgs &args, void *result) { ...@@ -482,43 +488,109 @@ bool TransDataType(const TypeIdArgs &args, void *result) {
} }
bool TransFormat(const FormatArgs &args, void *result) { bool TransFormat(const FormatArgs &args, void *result) {
using FormatTransfer = std::function<bool(const FormatArgs &, void *)>;
const std::map<std::string, FormatTransfer> format_trans_map{
{kOpFormat_FRAC_Z, NchwToFracZ}, {kOpFormat_FRAC_NZ, NchwToFracNz},
{kOpFormat_NC1HWC0, NchwToNc1hwc0}, {kOpFormat_C1HWNCoC0, NchwToC1hwncoc0},
{kOpFormat_FRACTAL_Z_C04, NchwToFracZc04}, {kOpFormat_NC1HWC0_C04, NchwToNc1hwc04}};
MS_LOG(DEBUG) << "Start trans format."; MS_LOG(DEBUG) << "Start trans format.";
if (TypeIdSize(args.src_data_type) < 1) { if (TypeIdSize(args.src_data_type) < 1) {
MS_LOG(ERROR) << "Invalid datatype.."; MS_LOG(ERROR) << "Invalid datatype..";
return false; return false;
} }
if (args.device_format == kOpFormat_FRAC_Z) { if (args.device_format == kOpFormat_HWCN || args.device_format == kOpFormat_NHWC) {
return NchwToFracZ(args, result); return NchwTo4D(args, result);
} else if (args.device_format == kOpFormat_FRAC_NZ) {
return NchwToFracNz(args, result);
} else if (args.device_format == kOpFormat_NC1HWC0) {
return NchwToNc1hwc0(args, result);
} else if (args.device_format == kOpFormat_C1HWNCoC0) {
return NchwToC1hwncoc0(args, result);
} else if (args.device_format == kOpFormat_FRACTAL_Z_C04) {
return NchwToFracZc04(args, result);
} else if (args.device_format == kOpFormat_NC1HWC0_C04) {
return NchwToNc1hwc04(args, result);
} }
return true; auto iter = format_trans_map.find(args.device_format);
if (iter == format_trans_map.end()) {
MS_LOG(EXCEPTION) << "Unexpected format[" << args.device_format << "]";
}
return iter->second(args, result);
} }
bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result) { bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result) {
using FormatTransfer = std::function<bool(const FormatArgs &, void *)>;
const std::map<std::string, FormatTransfer> format_trans_map{{kOpFormat_FRAC_Z, FracZToNchw},
{kOpFormat_FRAC_NZ, FracNzToNchw},
{kOpFormat_NC1HWC0, Nc1hwc0ToNchw},
{kOpFormat_C1HWNCoC0, C1hwncoc0ToNchw},
{kOpFormat_NC1HWC0_C04, Nc1hwc04ToNchw}};
MS_LOG(DEBUG) << "Start trans format."; MS_LOG(DEBUG) << "Start trans format.";
if (TypeIdSize(args.src_data_type) < 1) { if (TypeIdSize(args.src_data_type) < 1) {
MS_LOG(ERROR) << "Invalid datatype.."; MS_LOG(ERROR) << "Invalid datatype..";
return false; return false;
} }
if (args.device_format == kOpFormat_FRAC_Z) { if (args.device_format == kOpFormat_HWCN || args.device_format == kOpFormat_NHWC) {
return FracZToNchw(args, result); return ToNchw(args, result);
} else if (args.device_format == kOpFormat_FRAC_NZ) { }
return FracNzToNchw(args, result); auto iter = format_trans_map.find(args.device_format);
} else if (args.device_format == kOpFormat_NC1HWC0) { if (iter == format_trans_map.end()) {
return Nc1hwc0ToNchw(args, result); MS_LOG(EXCEPTION) << "Unexpected format[" << args.device_format << "]";
} else if (args.device_format == kOpFormat_C1HWNCoC0) { }
return C1hwncoc0ToNchw(args, result); return iter->second(args, result);
} else if (args.device_format == kOpFormat_NC1HWC0_C04) { }
return Nc1hwc04ToNchw(args, result);
bool NchwTo4D(const FormatArgs &args, void *result) {
// trans nchw to 4d
MS_LOG(DEBUG) << "Trans format from nchw to 4d.";
MS_EXCEPTION_IF_NULL(result);
size_t size = 0;
size_t total_size = 0;
if (!CheckArgs(args, &size, &total_size)) {
MS_LOG(ERROR) << "Check args failed.";
return false;
}
size_t n = args.host_shape[0];
size_t c = args.host_shape[1];
size_t h = args.host_shape[2];
size_t w = args.host_shape[3];
for (size_t ni = 0; ni < n; ni++) {
for (size_t ci = 0; ci < c; ci++) {
for (size_t hi = 0; hi < h; hi++) {
for (size_t wi = 0; wi < w; wi++) {
auto src_idx = ni * c * h * w + ci * h * w + hi * w + wi;
auto dst_idx = 0;
if (args.device_format == kOpFormat_NHWC) {
dst_idx = ni * h * w * c + hi * w * c + wi * c + ci;
} else if (args.device_format == kOpFormat_HWCN) {
dst_idx = hi * w * c * n + wi * c * n + ci * n + ni;
}
SetData(size, false, src_idx, dst_idx, args, result);
}
}
}
}
return true;
}
// Copies a 4-D tensor stored in NHWC or HWCN order (selected by
// args.device_format) into `result` in NCHW order.
// host_shape[0..3] are read as n, c, h, w. Returns false when CheckArgs fails.
bool ToNchw(const FormatArgs &args, void *result) {
MS_LOG(DEBUG) << "Trans format to nchw from 4d.";
MS_EXCEPTION_IF_NULL(result);
size_t size = 0;
size_t total_size = 0;
if (!CheckArgs(args, &size, &total_size)) {
MS_LOG(ERROR) << "Check args failed.";
return false;
}
size_t n = args.host_shape[0];
size_t c = args.host_shape[1];
size_t h = args.host_shape[2];
size_t w = args.host_shape[3];
for (size_t ni = 0; ni < n; ni++) {
for (size_t ci = 0; ci < c; ci++) {
for (size_t hi = 0; hi < h; hi++) {
for (size_t wi = 0; wi < w; wi++) {
// Destination offset in NCHW order.
auto dst_idx = ni * c * h * w + ci * h * w + hi * w + wi;
// NOTE(review): `auto ... = 0` deduces int, so the size_t offsets assigned
// below are narrowed; consider declaring this as size_t.
auto src_idx = 0;
if (args.device_format == kOpFormat_NHWC) {
src_idx = ni * h * w * c + hi * w * c + wi * c + ci;
} else if (args.device_format == kOpFormat_HWCN) {
src_idx = hi * w * c * n + wi * c * n + ci * n + ni;
}
// For any other format src_idx stays 0, so element 0 is copied everywhere.
SetData(size, false, src_idx, dst_idx, args, result);
}
}
}
} }
return true; return true;
} }
...@@ -575,8 +647,8 @@ bool NchwToFracZ(const FormatArgs &args, void *result) { ...@@ -575,8 +647,8 @@ bool NchwToFracZ(const FormatArgs &args, void *result) {
auto src_ni = hfi * kCubeSize + col; auto src_ni = hfi * kCubeSize + col;
auto src_idx = src_row_offset + chw * col; auto src_idx = src_row_offset + chw * col;
auto dst_idx = gfi * fractal_ele_cnt + col * c0 + row; auto dst_idx = gfi * fractal_ele_cnt + col * c0 + row;
auto pad_zero = (src_ni >= n || src_idx >= nchw || src_ci >= c) ? 1 : 0; auto pad_zero = (src_ni >= n || src_idx >= nchw || src_ci >= c) ? true : false;
SetDataBysize(size, pad_zero); SetData(size, pad_zero, src_idx, dst_idx, args, result);
} }
} }
} }
...@@ -630,7 +702,7 @@ bool FracZToNchw(const FormatArgs &args, void *result) { ...@@ -630,7 +702,7 @@ bool FracZToNchw(const FormatArgs &args, void *result) {
size_t c0_idx = c_idx % c0; size_t c0_idx = c_idx % c0;
size_t nc_idx = n_idx; size_t nc_idx = n_idx;
size_t src_idx = c1_idx * hwncc0 + h_idx * wncc0 + w_idx * ncc0 + nc_idx * c0 + c0_idx; size_t src_idx = c1_idx * hwncc0 + h_idx * wncc0 + w_idx * ncc0 + nc_idx * c0 + c0_idx;
SetDataBysize(size, 0); SetData(size, false, src_idx, dst_idx, args, result);
} }
} }
} }
...@@ -679,7 +751,7 @@ bool NchwToFracZc04(const FormatArgs &args, void *result) { ...@@ -679,7 +751,7 @@ bool NchwToFracZc04(const FormatArgs &args, void *result) {
auto c_idx = desc_c1 * c0 + desc_c0; auto c_idx = desc_c1 * c0 + desc_c0;
auto src_idx = desc_g * nhwc + desc_n * hwc + c_idx * h * w + desc_h * w + desc_w; auto src_idx = desc_g * nhwc + desc_n * hwc + c_idx * h * w + desc_h * w + desc_w;
auto pad_zero = desc_g >= 1 || desc_n >= n || c_idx >= c; auto pad_zero = desc_g >= 1 || desc_n >= n || c_idx >= c;
SetDataBysize(size, pad_zero); SetData(size, pad_zero, src_idx, dst_idx, args, result);
dst_idx++; dst_idx++;
} }
} }
...@@ -773,7 +845,7 @@ bool NchwToFracNz(const FormatArgs &args, void *result) { ...@@ -773,7 +845,7 @@ bool NchwToFracNz(const FormatArgs &args, void *result) {
for (size_t i = 0; i < w0; ++i) { for (size_t i = 0; i < w0; ++i) {
size_t src_idx = src_h_head + w1_idx * w0 + i; size_t src_idx = src_h_head + w1_idx * w0 + i;
size_t dst_idx = h1h0_head + w1_idx * h1h0w0 + i; size_t dst_idx = h1h0_head + w1_idx * h1h0w0 + i;
SetDataBysize(size, 0); SetData(size, false, src_idx, dst_idx, args, result);
} }
} }
auto w1_head = num_w1 * w0; auto w1_head = num_w1 * w0;
...@@ -781,7 +853,7 @@ bool NchwToFracNz(const FormatArgs &args, void *result) { ...@@ -781,7 +853,7 @@ bool NchwToFracNz(const FormatArgs &args, void *result) {
auto src_w_idx = w1_head + w0_idx; auto src_w_idx = w1_head + w0_idx;
size_t dst_idx = h1h0_head + num_w1 * h1h0w0 + w0_idx; size_t dst_idx = h1h0_head + num_w1 * h1h0w0 + w0_idx;
size_t src_idx = src_h_head + src_w_idx; size_t src_idx = src_h_head + src_w_idx;
SetDataBysize(size, 0); SetData(size, false, src_idx, dst_idx, args, result);
} }
} }
} }
...@@ -835,7 +907,7 @@ bool FracNzToNchw(const FormatArgs &args, void *result) { ...@@ -835,7 +907,7 @@ bool FracNzToNchw(const FormatArgs &args, void *result) {
for (size_t i = 0; i < w0; ++i) { for (size_t i = 0; i < w0; ++i) {
size_t src_idx = h1h0_head + w1_idx * h1h0w0 + i; size_t src_idx = h1h0_head + w1_idx * h1h0w0 + i;
size_t dst_idx = src_h_head + w1_idx * w0 + i; size_t dst_idx = src_h_head + w1_idx * w0 + i;
SetDataBysize(size, 0); SetData(size, false, src_idx, dst_idx, args, result);
} }
} }
auto w1_head = num_w1 * w0; auto w1_head = num_w1 * w0;
...@@ -843,7 +915,7 @@ bool FracNzToNchw(const FormatArgs &args, void *result) { ...@@ -843,7 +915,7 @@ bool FracNzToNchw(const FormatArgs &args, void *result) {
auto src_w_idx = w1_head + w0_idx; auto src_w_idx = w1_head + w0_idx;
size_t src_idx = h1h0_head + num_w1 * h1h0w0 + w0_idx; size_t src_idx = h1h0_head + num_w1 * h1h0w0 + w0_idx;
size_t dst_idx = src_h_head + src_w_idx; size_t dst_idx = src_h_head + src_w_idx;
SetDataBysize(size, 0); SetData(size, false, src_idx, dst_idx, args, result);
} }
} }
} }
...@@ -895,8 +967,8 @@ bool NchwToNc1hwc0(const FormatArgs &args, void *result) { ...@@ -895,8 +967,8 @@ bool NchwToNc1hwc0(const FormatArgs &args, void *result) {
size_t dst_idx = c0_idx + w_head_addr; size_t dst_idx = c0_idx + w_head_addr;
size_t c_idx = c0_idx + c1_idx * c0; size_t c_idx = c0_idx + c1_idx * c0;
size_t src_idx = n_idx * chw + c_idx * hw + h_idx * w + w_idx; size_t src_idx = n_idx * chw + c_idx * hw + h_idx * w + w_idx;
auto pad_zero = (c_idx < c) ? 0 : 1; auto pad_zero = (c_idx < c) ? false : true;
SetDataBysize(size, pad_zero); SetData(size, pad_zero, src_idx, dst_idx, args, result);
} }
} }
} }
...@@ -947,7 +1019,7 @@ bool Nc1hwc0ToNchw(const FormatArgs &args, void *result) { ...@@ -947,7 +1019,7 @@ bool Nc1hwc0ToNchw(const FormatArgs &args, void *result) {
size_t c1_idx = c_idx / c0; size_t c1_idx = c_idx / c0;
size_t c0_idx = c_idx % c0; size_t c0_idx = c_idx % c0;
size_t src_idx = n_idx * c1hwc0 + c1_idx * hwc0 + h_idx * wc0 + w_idx * c0 + c0_idx; size_t src_idx = n_idx * c1hwc0 + c1_idx * hwc0 + h_idx * wc0 + w_idx * c0 + c0_idx;
SetDataBysize(size, 0); SetData(size, false, src_idx, dst_idx, args, result);
} }
} }
} }
...@@ -983,8 +1055,8 @@ bool NchwToC1hwncoc0(const FormatArgs &args, void *result) { ...@@ -983,8 +1055,8 @@ bool NchwToC1hwncoc0(const FormatArgs &args, void *result) {
co_i * c0 + c0_i; co_i * c0 + c0_i;
size_t c_i = c0_i + c1_i * c0; size_t c_i = c0_i + c1_i * c0;
size_t src_idx = n_i * c * h * w + c_i * h * w + h_i * w + w_i; size_t src_idx = n_i * c * h * w + c_i * h * w + h_i * w + w_i;
auto pad_zero = (c_i < c && c0_i == co_i) ? 0 : 1; auto pad_zero = (c_i < c && c0_i == co_i) ? false : true;
SetDataBysize(size, pad_zero); SetData(size, pad_zero, src_idx, dst_idx, args, result);
} }
} }
} }
...@@ -1020,7 +1092,7 @@ bool C1hwncoc0ToNchw(const FormatArgs &args, void *result) { ...@@ -1020,7 +1092,7 @@ bool C1hwncoc0ToNchw(const FormatArgs &args, void *result) {
size_t co_i = c0_i; size_t co_i = c0_i;
size_t src_idx = size_t src_idx =
c1_i * h * w * n * co * c0 + h_i * w * n * co * c0 + w_i * n * co * c0 + n_i * co * c0 + co_i * c0 + c0_i; c1_i * h * w * n * co * c0 + h_i * w * n * co * c0 + w_i * n * co * c0 + n_i * co * c0 + co_i * c0 + c0_i;
SetDataBysize(size, 0); SetData(size, false, src_idx, dst_idx, args, result);
} }
} }
} }
......
...@@ -61,6 +61,7 @@ bool TransFormat(const FormatArgs &args, void *result); ...@@ -61,6 +61,7 @@ bool TransFormat(const FormatArgs &args, void *result);
bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result); bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result);
// host to device // host to device
bool NchwTo4D(const FormatArgs &args, void *result);
bool NchwToFracZ(const FormatArgs &args, void *result); bool NchwToFracZ(const FormatArgs &args, void *result);
bool NchwToFracNz(const FormatArgs &args, void *result); bool NchwToFracNz(const FormatArgs &args, void *result);
bool NchwToNc1hwc0(const FormatArgs &args, void *result); bool NchwToNc1hwc0(const FormatArgs &args, void *result);
...@@ -68,6 +69,7 @@ bool NchwToFracZc04(const FormatArgs &args, void *result); ...@@ -68,6 +69,7 @@ bool NchwToFracZc04(const FormatArgs &args, void *result);
bool NchwToNc1hwc04(const FormatArgs &args, void *result); bool NchwToNc1hwc04(const FormatArgs &args, void *result);
bool NchwToC1hwncoc0(const FormatArgs &args, void *result); bool NchwToC1hwncoc0(const FormatArgs &args, void *result);
// device to host // device to host
bool ToNchw(const FormatArgs &args, void *result);
bool FracZToNchw(const FormatArgs &args, void *result); bool FracZToNchw(const FormatArgs &args, void *result);
bool FracNzToNchw(const FormatArgs &args, void *result); bool FracNzToNchw(const FormatArgs &args, void *result);
bool Nc1hwc0ToNchw(const FormatArgs &args, void *result); bool Nc1hwc0ToNchw(const FormatArgs &args, void *result);
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include "device/ascend/ascend_device_address.h" #include "device/ascend/ascend_device_address.h"
#include <memory> #include <memory>
#include <vector> #include <vector>
#include <set>
#include <algorithm> #include <algorithm>
#include "runtime/mem.h" #include "runtime/mem.h"
#include "device/kernel_runtime_manager.h" #include "device/kernel_runtime_manager.h"
...@@ -34,6 +35,10 @@ namespace device { ...@@ -34,6 +35,10 @@ namespace device {
namespace ascend { namespace ascend {
const int FLOAT_LEN = sizeof(float); const int FLOAT_LEN = sizeof(float);
const int FLOAT16_LEN = 2; // sizeof(float16); const int FLOAT16_LEN = 2; // sizeof(float16);
// Device formats for which the sync routines in this file take the
// convert-format path instead of a plain memory copy.
const std::set<std::string> kOpNeedTransFormat = {kOpFormat_NHWC, kOpFormat_HWCN, kOpFormat_NC1HWC0,
kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0, kOpFormat_FRAC_NZ,
kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04};
void SyncMemory(void *dst, const void *src, uint64_t size, rtMemcpyKind_t kind) { void SyncMemory(void *dst, const void *src, uint64_t size, rtMemcpyKind_t kind) {
auto ret_rt_memcpy = rtMemcpy(dst, size, src, size, kind); auto ret_rt_memcpy = rtMemcpy(dst, size, src, size, kind);
if (ret_rt_memcpy != RT_ERROR_NONE) { if (ret_rt_memcpy != RT_ERROR_NONE) {
...@@ -97,7 +102,7 @@ bool AscendDeviceAddress::SyncDeviceToHost(const std::vector<int> &shape, size_t ...@@ -97,7 +102,7 @@ bool AscendDeviceAddress::SyncDeviceToHost(const std::vector<int> &shape, size_t
if (host_shape.empty()) { if (host_shape.empty()) {
host_shape.emplace_back(1); host_shape.emplace_back(1);
} }
if (format_ == kOpFormat_NCHW || format_ == kOpFormat_DEFAULT) { if (format_ == kOpFormat_NCHW || format_ == kOpFormat_DEFAULT || format_ == kOpFormat_NDHWC) {
if (type_id_ == type) { if (type_id_ == type) {
SyncMemory(host_ptr, ptr_, size, RT_MEMCPY_DEVICE_TO_HOST); SyncMemory(host_ptr, ptr_, size, RT_MEMCPY_DEVICE_TO_HOST);
sync_ok = true; sync_ok = true;
...@@ -115,9 +120,11 @@ bool AscendDeviceAddress::SyncDeviceToHost(const std::vector<int> &shape, size_t ...@@ -115,9 +120,11 @@ bool AscendDeviceAddress::SyncDeviceToHost(const std::vector<int> &shape, size_t
} }
} }
} else { } else {
auto iter = kNeedTransFormatSet.find(format_); auto iter = kOpNeedTransFormat.find(format_);
if (iter != kNeedTransFormatSet.end()) { if (iter != kOpNeedTransFormat.end()) {
sync_ok = SyncDeviceToHostAndConvertFormat(shape, size, type, host_ptr); sync_ok = SyncDeviceToHostAndConvertFormat(shape, size, type, host_ptr);
} else {
MS_LOG(INFO) << "Can not find format transfer for :" << format_;
} }
} }
if (!sync_ok) { if (!sync_ok) {
...@@ -141,7 +148,7 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormat(const std::vector<int ...@@ -141,7 +148,7 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormat(const std::vector<int
if (host_shape.empty()) { if (host_shape.empty()) {
host_shape.emplace_back(1); host_shape.emplace_back(1);
} }
if (format_ == kOpFormat_FRAC_NZ) { if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NDHWC) {
device_shape = trans::TransShapeToDevice(host_shape, format_); device_shape = trans::TransShapeToDevice(host_shape, format_);
} else { } else {
host_shape = trans::PaddingShapeTo4d(host_shape); host_shape = trans::PaddingShapeTo4d(host_shape);
...@@ -185,7 +192,7 @@ bool AscendDeviceAddress::SyncHostToDevice(const std::vector<int> &shape, size_t ...@@ -185,7 +192,7 @@ bool AscendDeviceAddress::SyncHostToDevice(const std::vector<int> &shape, size_t
if (host_shape.empty()) { if (host_shape.empty()) {
host_shape.emplace_back(1); host_shape.emplace_back(1);
} }
if (format_ == kOpFormat_NCHW || format_ == kOpFormat_DEFAULT) { if (format_ == kOpFormat_NCHW || format_ == kOpFormat_DEFAULT || format_ == kOpFormat_NDHWC) {
if (type_id_ == type) { if (type_id_ == type) {
SyncMemory(ptr_, host_ptr, size_, RT_MEMCPY_HOST_TO_DEVICE); SyncMemory(ptr_, host_ptr, size_, RT_MEMCPY_HOST_TO_DEVICE);
sync_ok = true; sync_ok = true;
...@@ -203,9 +210,11 @@ bool AscendDeviceAddress::SyncHostToDevice(const std::vector<int> &shape, size_t ...@@ -203,9 +210,11 @@ bool AscendDeviceAddress::SyncHostToDevice(const std::vector<int> &shape, size_t
SyncMemory(ptr_, host_tmp.data(), size_, RT_MEMCPY_HOST_TO_DEVICE); SyncMemory(ptr_, host_tmp.data(), size_, RT_MEMCPY_HOST_TO_DEVICE);
} }
} else { } else {
auto iter = kNeedTransFormatSet.find(format_); auto iter = kOpNeedTransFormat.find(format_);
if (iter != kNeedTransFormatSet.end()) { if (iter != kOpNeedTransFormat.end()) {
sync_ok = ConvertFormatAndSyncHostToDevice(shape, size, type, host_ptr); sync_ok = ConvertFormatAndSyncHostToDevice(shape, size, type, host_ptr);
} else {
MS_LOG(INFO) << "Can not find format transfer for :" << format_;
} }
} }
if (!sync_ok) { if (!sync_ok) {
...@@ -227,7 +236,7 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const std::vector<int ...@@ -227,7 +236,7 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const std::vector<int
host_shape.emplace_back(1); host_shape.emplace_back(1);
} }
std::vector<size_t> device_shape; std::vector<size_t> device_shape;
if (format_ == kOpFormat_FRAC_NZ) { if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NDHWC) {
device_shape = trans::TransShapeToDevice(host_shape, format_); device_shape = trans::TransShapeToDevice(host_shape, format_);
} else { } else {
host_shape = trans::PaddingShapeTo4d(host_shape); host_shape = trans::PaddingShapeTo4d(host_shape);
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include "common/common_test.h"
#include "common/trans.h"
#include "utils/utils.h"
using namespace std;
namespace mindspore {
namespace trans {
// Test fixture for the format conversion cases in this file.
// It holds no state; SetUp and TearDown are intentionally empty.
class FormatTransTest : public UT::Common {
public:
FormatTransTest() = default;
// No per-test initialisation is needed.
void SetUp() override {}
// Nothing to release after a test.
void TearDown() override {}
};
// Host-to-device case: a 2x2x2x2 tensor of uint16_t values (tagged kNumberTypeFloat16)
// goes through trans::TransFormat with device format HWCN, and every converted
// element is compared with the expected table.
TEST_F(FormatTransTest, nchw_to_hwcn) {
  uint16_t nchw_input[2 * 2 * 2 * 2] = {12581, 14220, 14937, 14302,
                                        15004, 14951, 14694, 14564,
                                        14069, 14554, 10507, 14787,
                                        13016, 15263, 14872, 10838};
  uint16_t hwcn_expected[2 * 2 * 2 * 2] = {12581, 14069, 15004, 13016,
                                           14220, 14554, 14951, 15263,
                                           14937, 10507, 14694, 14872,
                                           14302, 14787, 14564, 10838};
  size_t buffer_bytes = 32;
  std::vector<uint8_t> converted(buffer_bytes);
  FormatArgs format_args{nchw_input, buffer_bytes, kOpFormat_NCHW, kOpFormat_HWCN,
                         {2, 2, 2, 2}, {2, 2, 2, 2}, kNumberTypeFloat16};
  EXPECT_EQ(trans::TransFormat(format_args, converted.data()), true);

  // View the raw output bytes as 16-bit elements for the comparison.
  const size_t element_count = sizeof(hwcn_expected) / sizeof(hwcn_expected[0]);
  uint16_t *actual = reinterpret_cast<uint16_t *>(converted.data());
  for (size_t idx = 0; idx < element_count; ++idx) {
    EXPECT_EQ(actual[idx], hwcn_expected[idx]);
  }
}
// Device-to-host case: a 2x2x2x2 tensor of uint16_t values (tagged kNumberTypeFloat16)
// goes through trans::TransFormatFromDeviceToHost with device format HWCN, and
// every converted element is compared with the expected table.
TEST_F(FormatTransTest, hwcn_to_nchw) {
  uint16_t hwcn_input[2 * 2 * 2 * 2] = {12581, 14069, 15004, 13016,
                                        14220, 14554, 14951, 15263,
                                        14937, 10507, 14694, 14872,
                                        14302, 14787, 14564, 10838};
  uint16_t nchw_expected[2 * 2 * 2 * 2] = {12581, 14220, 14937, 14302,
                                           15004, 14951, 14694, 14564,
                                           14069, 14554, 10507, 14787,
                                           13016, 15263, 14872, 10838};
  size_t buffer_bytes = 32;
  std::vector<uint8_t> converted(buffer_bytes);
  FormatArgs format_args{hwcn_input, buffer_bytes, kOpFormat_NCHW, kOpFormat_HWCN,
                         {2, 2, 2, 2}, {2, 2, 2, 2}, kNumberTypeFloat16};
  EXPECT_EQ(trans::TransFormatFromDeviceToHost(format_args, converted.data()), true);

  // View the raw output bytes as 16-bit elements for the comparison.
  const size_t element_count = sizeof(nchw_expected) / sizeof(nchw_expected[0]);
  uint16_t *actual = reinterpret_cast<uint16_t *>(converted.data());
  for (size_t idx = 0; idx < element_count; ++idx) {
    EXPECT_EQ(actual[idx], nchw_expected[idx]);
  }
}
// Host-to-device case: a 2x2x2x2 tensor of uint16_t values (tagged kNumberTypeFloat16)
// goes through trans::TransFormat with device format NHWC, and every converted
// element is compared with the expected table.
TEST_F(FormatTransTest, nchw_to_nhwc) {
  uint16_t nchw_input[2 * 2 * 2 * 2] = {11750, 13778, 15007, 15321,
                                        15163, 13446, 15063, 14467,
                                        15056, 13284, 15219, 14797,
                                        12684, 14288, 14855, 14799};
  uint16_t nhwc_expected[2 * 2 * 2 * 2] = {11750, 15163, 13778, 13446,
                                           15007, 15063, 15321, 14467,
                                           15056, 12684, 13284, 14288,
                                           15219, 14855, 14797, 14799};
  size_t buffer_bytes = 32;
  std::vector<uint8_t> converted(buffer_bytes);
  FormatArgs format_args{nchw_input, buffer_bytes, kOpFormat_NCHW, kOpFormat_NHWC,
                         {2, 2, 2, 2}, {2, 2, 2, 2}, kNumberTypeFloat16};
  EXPECT_EQ(trans::TransFormat(format_args, converted.data()), true);

  // View the raw output bytes as 16-bit elements for the comparison.
  const size_t element_count = sizeof(nhwc_expected) / sizeof(nhwc_expected[0]);
  uint16_t *actual = reinterpret_cast<uint16_t *>(converted.data());
  for (size_t idx = 0; idx < element_count; ++idx) {
    EXPECT_EQ(actual[idx], nhwc_expected[idx]);
  }
}
// Device-to-host case: a 2x2x2x2 tensor of uint16_t values (tagged kNumberTypeFloat16)
// goes through trans::TransFormatFromDeviceToHost with device format NHWC, and
// every converted element is compared with the expected table.
TEST_F(FormatTransTest, nhwc_to_nchw) {
  uint16_t nhwc_input[2 * 2 * 2 * 2] = {11750, 15163, 13778, 13446,
                                        15007, 15063, 15321, 14467,
                                        15056, 12684, 13284, 14288,
                                        15219, 14855, 14797, 14799};
  uint16_t nchw_expected[2 * 2 * 2 * 2] = {11750, 13778, 15007, 15321,
                                           15163, 13446, 15063, 14467,
                                           15056, 13284, 15219, 14797,
                                           12684, 14288, 14855, 14799};
  size_t buffer_bytes = 32;
  std::vector<uint8_t> converted(buffer_bytes);
  FormatArgs format_args{nhwc_input, buffer_bytes, kOpFormat_NCHW, kOpFormat_NHWC,
                         {2, 2, 2, 2}, {2, 2, 2, 2}, kNumberTypeFloat16};
  EXPECT_EQ(trans::TransFormatFromDeviceToHost(format_args, converted.data()), true);

  // View the raw output bytes as 16-bit elements for the comparison.
  const size_t element_count = sizeof(nchw_expected) / sizeof(nchw_expected[0]);
  uint16_t *actual = reinterpret_cast<uint16_t *>(converted.data());
  for (size_t idx = 0; idx < element_count; ++idx) {
    EXPECT_EQ(actual[idx], nchw_expected[idx]);
  }
}
} // namespace trans
} // namespace mindspore
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册