提交 41c969ab 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!414 add 6d format transfer

Merge pull request !414 from liubuyu/dev_lby
......@@ -231,7 +231,98 @@ std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape, const std
return shape_4d;
}
namespace {
bool CheckDims(const std::vector<size_t> &shape) {
if (shape.size() != 4) {
MS_LOG(ERROR) << "Host shape dims shoud be 4";
return false;
}
return true;
}
std::vector<size_t> NchwDeviceShape(const std::vector<size_t> &shape) {
if (!CheckDims(shape)) {
MS_LOG(EXCEPTION) << "Check dims failed.";
}
return shape;
}
std::vector<size_t> NhwcDeviceShape(const std::vector<size_t> &shape) {
if (!CheckDims(shape)) {
MS_LOG(EXCEPTION) << "Ccheck dims failed.";
}
std::vector<size_t> device_shape;
device_shape.push_back(shape[0]);
device_shape.push_back(shape[2]);
device_shape.push_back(shape[3]);
device_shape.push_back(shape[1]);
return device_shape;
}
std::vector<size_t> HwchDeviceShape(const std::vector<size_t> &shape) {
if (!CheckDims(shape)) {
MS_LOG(EXCEPTION) << "Check dims failed.";
}
std::vector<size_t> device_shape;
device_shape.push_back(shape[2]);
device_shape.push_back(shape[3]);
device_shape.push_back(shape[1]);
device_shape.push_back(shape[0]);
return device_shape;
}
std::vector<size_t> FracZDeviceShape(const std::vector<size_t> &shape) {
if (!CheckDims(shape)) {
MS_LOG(EXCEPTION) << "Check dims failed.";
}
std::vector<size_t> device_shape;
size_t cout16 = ((shape[0] + kCubeSize - 1) / kCubeSize) * kCubeSize;
size_t cin16 = ((shape[1] + kCubeSize - 1) / kCubeSize) * kCubeSize;
device_shape.push_back(shape[2] * shape[3] * cin16 / kCubeSize);
device_shape.push_back(cout16 / kCubeSize);
device_shape.push_back(kCubeSize);
device_shape.push_back(kCubeSize);
return device_shape;
}
std::vector<size_t> Nc1hwc0DeviceShape(const std::vector<size_t> &shape) {
if (!CheckDims(shape)) {
MS_LOG(EXCEPTION) << "Check dims failed.";
}
std::vector<size_t> device_shape;
size_t C1 = (shape[1] + kCubeSize - 1) / kCubeSize;
size_t C0 = kCubeSize;
device_shape.push_back(shape[0]);
device_shape.push_back(C1);
device_shape.push_back(shape[2]);
device_shape.push_back(shape[3]);
device_shape.push_back(C0);
return device_shape;
}
std::vector<size_t> C1hwncoc0DeviceShape(const std::vector<size_t> &shape) {
if (!CheckDims(shape)) {
MS_LOG(EXCEPTION) << "Check dims failed.";
}
std::vector<size_t> device_shape;
device_shape.push_back((shape[1] - 1) / kCubeSize + 1);
device_shape.push_back(shape[2]);
device_shape.push_back(shape[3]);
device_shape.push_back(shape[0]);
device_shape.push_back(kCubeSize);
device_shape.push_back(kCubeSize);
return device_shape;
}
} // namespace
std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format) {
using DeviceShapeTransfer = std::function<std::vector<size_t>(const std::vector<size_t> &)>;
const std::map<std::string, DeviceShapeTransfer> device_shape_map{
{kOpFormat_NCHW, NchwDeviceShape}, {kOpFormat_NHWC, NhwcDeviceShape},
{kOpFormat_HWCN, HwchDeviceShape}, {kOpFormat_FRAC_Z, FracZDeviceShape},
{kOpFormat_NC1HWC0, Nc1hwc0DeviceShape}, {kOpFormat_C1HWNCoC0, C1hwncoc0DeviceShape},
};
if (format == kOpFormat_ND || format == kOpFormat_DEFAULT) {
return shape;
}
......@@ -255,37 +346,31 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s
MS_LOG(WARNING) << "Get Device Shape using a shape size is less than 4 ,should be Padding shape by Default firstly";
temp_shape = PaddingShapeTo4dByDefault(shape);
}
if (format == kOpFormat_NC1HWC0) {
size_t C1 = (temp_shape[1] + kCubeSize - 1) / kCubeSize;
size_t C0 = kCubeSize;
device_shape.push_back(temp_shape[0]);
device_shape.push_back(C1);
device_shape.push_back(temp_shape[2]);
device_shape.push_back(temp_shape[3]);
device_shape.push_back(C0);
return device_shape;
} else if (format == kOpFormat_FRAC_Z) {
size_t cout16 = ((temp_shape[0] + kCubeSize - 1) / kCubeSize) * kCubeSize;
size_t cin16 = ((temp_shape[1] + kCubeSize - 1) / kCubeSize) * kCubeSize;
device_shape.push_back(temp_shape[2] * temp_shape[3] * cin16 / kCubeSize);
device_shape.push_back(cout16 / kCubeSize);
device_shape.push_back(kCubeSize);
device_shape.push_back(kCubeSize);
return device_shape;
} else if (format == kOpFormat_NHWC) {
device_shape.push_back(temp_shape[0]);
device_shape.push_back(temp_shape[2]);
device_shape.push_back(temp_shape[3]);
device_shape.push_back(temp_shape[1]);
return device_shape;
} else if (format == kOpFormat_HWCN) {
return {temp_shape[2], temp_shape[3], temp_shape[1], temp_shape[0]};
} else if (format == kOpFormat_NCHW) {
return temp_shape;
auto iter = device_shape_map.find(format);
if (iter != device_shape_map.end()) {
return iter->second(temp_shape);
}
MS_LOG(EXCEPTION) << "Unexpected format[" << format << "]";
}
bool CheckArgs(const FormatArgs &args, size_t *size, size_t *total_size) {
if (args.host_shape.size() != kNchwDims) {
MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
return false;
}
*size = TypeIdSize(args.src_data_type);
if (*size < 1) {
MS_LOG(ERROR) << "Illegal dtype.";
return false;
}
*total_size = ShapeSize(args.device_shape) * (*size);
if (*total_size != args.device_size) {
MS_LOG(ERROR) << "Illegal total data size, total_size:" << *total_size << ", device_size:" << args.device_size;
return false;
}
return true;
}
bool TransDataType(const TypeIdArgs &args, void *result) {
MS_LOG(DEBUG) << "Begin trans datatype from " << TypeIdLabel(args.host_data_type) << " to "
<< TypeIdLabel(args.device_data_type);
......@@ -320,13 +405,14 @@ bool TransFormat(const FormatArgs &args, void *result) {
MS_LOG(ERROR) << "Invalid datatype..";
return false;
}
if ((args.host_format == kOpFormat_NCHW || args.host_format == kOpFormat_ND) &&
args.device_format == kOpFormat_FRAC_Z) {
if (args.device_format == kOpFormat_FRAC_Z) {
return NchwToFracZ(args, result);
} else if (args.device_format == kOpFormat_FRAC_NZ) {
return NchwToFracNz(args, result);
} else if (args.device_format == kOpFormat_NC1HWC0) {
return NchwToNc1hwc0(args, result);
} else if (args.device_format == kOpFormat_C1HWNCoC0) {
return NchwToC1hwncoc0(args, result);
}
return true;
}
......@@ -337,13 +423,14 @@ bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result) {
MS_LOG(ERROR) << "Invalid datatype..";
return false;
}
if ((args.host_format == kOpFormat_NCHW || args.host_format == kOpFormat_ND) &&
args.device_format == kOpFormat_FRAC_Z) {
if (args.device_format == kOpFormat_FRAC_Z) {
return FracZToNchw(args, result);
} else if (args.device_format == kOpFormat_FRAC_NZ) {
return FracNzToNchw(args, result);
} else if (args.device_format == kOpFormat_NC1HWC0) {
return Nc1hwc0ToNchw(args, result);
} else if (args.device_format == kOpFormat_C1HWNCoC0) {
return C1hwncoc0ToNchw(args, result);
}
return true;
}
......@@ -801,5 +888,99 @@ bool Nc1hwc0ToNchw(const FormatArgs &args, void *result) {
}
return true;
}
bool NchwToC1hwncoc0(const FormatArgs &args, void *result) {
// trans nchw to c1hwncoc0
MS_LOG(DEBUG) << "Trans format from nchw to c1hwncoc0.";
MS_EXCEPTION_IF_NULL(result);
size_t size = 0;
size_t total_size = 0;
if (!CheckArgs(args, &size, &total_size)) {
MS_LOG(ERROR) << "Check args failed.";
return false;
}
auto n = args.host_shape[0];
auto c = args.host_shape[1];
auto h = args.host_shape[2];
auto w = args.host_shape[3];
auto c1 = args.device_shape[0];
auto co = args.device_shape[4];
auto c0 = args.device_shape[5];
for (size_t c1_i = 0; c1_i < c1; c1_i++) {
for (size_t h_i = 0; h_i < h; h_i++) {
for (size_t w_i = 0; w_i < w; w_i++) {
for (size_t n_i = 0; n_i < n; n_i++) {
for (size_t co_i = 0; co_i < co; co_i++) {
for (size_t c0_i = 0; c0_i < c0; c0_i++) {
size_t dst_offset = (c1_i * h * w * n * co * c0 + h_i * w * n * co * c0 + w_i * n * co * c0 +
n_i * co * c0 + co_i * c0 + c0_i) *
size;
size_t protected_size = total_size - dst_offset < static_cast<size_t>(SECUREC_MEM_MAX_LEN)
? total_size - dst_offset
: static_cast<size_t>(SECUREC_MEM_MAX_LEN);
size_t c_i = c0_i + c1_i * c0;
size_t src_offset = (n_i * c * h * w + c_i * h * w + h_i * w + w_i) * size;
error_t ret;
if (c_i < c && c0_i == co_i) {
ret = memcpy_s(static_cast<uint8_t *>(result) + dst_offset, protected_size,
static_cast<uint8_t const *>(args.data) + src_offset, size);
} else {
ret = memset_s(static_cast<uint8_t *>(result) + dst_offset, protected_size, 0, size);
}
if (ret != EOK) {
MS_LOG(ERROR) << "Failed to operate the dst memory, error-code:" << ret;
return false;
}
}
}
}
}
}
}
return true;
}
bool C1hwncoc0ToNchw(const FormatArgs &args, void *result) {
// trans c1hwncoc0 to nchw
MS_LOG(DEBUG) << "Trans format from c1hwncoc0 to nchw";
MS_EXCEPTION_IF_NULL(result);
size_t size = 0;
size_t total_size = 0;
if (!CheckArgs(args, &size, &total_size)) {
MS_LOG(ERROR) << "Check args failed.";
return false;
}
auto n = args.host_shape[0];
auto c = args.host_shape[1];
auto h = args.host_shape[2];
auto w = args.host_shape[3];
auto co = args.device_shape[4];
auto c0 = args.device_shape[5];
for (size_t n_i = 0; n_i < n; n_i++) {
for (size_t c_i = 0; c_i < c; c_i++) {
for (size_t h_i = 0; h_i < h; h_i++) {
for (size_t w_i = 0; w_i < w; w_i++) {
size_t dst_offset = (n_i * c * h * w + c_i * h * w + h_i * w + w_i) * size;
size_t c1_i = c_i / kCubeSize;
size_t c0_i = c_i % kCubeSize;
size_t co_i = c0_i;
size_t src_offset = (c1_i * h * w * n * co * c0 + h_i * w * n * co * c0 + w_i * n * co * c0 + n_i * co * c0 +
co_i * c0 + c0_i) *
size;
size_t protected_size = total_size - dst_offset < static_cast<size_t>(SECUREC_MEM_MAX_LEN)
? total_size - dst_offset
: static_cast<size_t>(SECUREC_MEM_MAX_LEN);
auto ret = memcpy_s(static_cast<uint8_t *>(result) + dst_offset, protected_size,
static_cast<uint8_t const *>(args.data) + src_offset, size);
if (ret != EOK) {
MS_LOG(ERROR) << "Failed to operate the dst memory, error-code:" << ret;
return false;
}
}
}
}
}
return true;
}
} // namespace trans
} // namespace mindspore
......@@ -63,10 +63,12 @@ bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result);
bool NchwToFracZ(const FormatArgs &args, void *result);
bool NchwToFracNz(const FormatArgs &args, void *result);
bool NchwToNc1hwc0(const FormatArgs &args, void *result);
bool NchwToC1hwncoc0(const FormatArgs &args, void *result);
// device to host
bool FracZToNchw(const FormatArgs &args, void *result);
bool FracNzToNchw(const FormatArgs &args, void *result);
bool Nc1hwc0ToNchw(const FormatArgs &args, void *result);
bool C1hwncoc0ToNchw(const FormatArgs &args, void *result);
} // namespace trans
} // namespace mindspore
......
......@@ -114,8 +114,11 @@ bool AscendDeviceAddress::SyncDeviceToHost(const std::vector<int> &shape, size_t
return false;
}
}
} else if (format_ == kOpFormat_NC1HWC0 || format_ == kOpFormat_FRAC_Z || format_ == kOpFormat_FRAC_NZ) {
sync_ok = SyncDeviceToHostAndConvertFormat(shape, size, type, host_ptr);
} else {
auto iter = kNeedTransFormatSet.find(format_);
if (iter != kNeedTransFormatSet.end()) {
sync_ok = ConvertFormatAndSyncHostToDevice(shape, size, type, host_ptr);
}
}
if (!sync_ok) {
MS_LOG(ERROR) << "Not support to trans, dev_format:" << format_ << ", dev_type:" << TypeIdLabel(type_id_)
......@@ -199,9 +202,12 @@ bool AscendDeviceAddress::SyncHostToDevice(const std::vector<int> &shape, size_t
}
SyncMemory(ptr_, host_tmp.data(), size_, RT_MEMCPY_HOST_TO_DEVICE);
}
} else if (format_ == kOpFormat_NC1HWC0 || format_ == kOpFormat_FRAC_Z || format_ == kOpFormat_FRAC_NZ) {
} else {
auto iter = kNeedTransFormatSet.find(format_);
if (iter != kNeedTransFormatSet.end()) {
sync_ok = ConvertFormatAndSyncHostToDevice(shape, size, type, host_ptr);
}
}
if (!sync_ok) {
MS_LOG(ERROR) << "Not support to trans, dev_format:" << format_ << ", dev_type:" << TypeIdLabel(type_id_)
<< ", host_type:" << TypeIdLabel(type);
......
......@@ -187,7 +187,9 @@ constexpr auto kOpFormat_FRAC_NZ = "FRACTAL_NZ";
constexpr auto kOpFormat_C1HWNCoC0 = "C1HWNCoC0";
constexpr auto kOpFormat_NC1HWC0_C04 = "NC1HWC0_C04";
const std::set<std::string> k1DSupportFormat = {kOpFormat_DEFAULT, kOpFormat_NCHW, kOpFormat_NHWC,
kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0};
kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0,
kOpFormat_C1HWNCoC0};
const std::set<std::string> k2DSupportFormat = {kOpFormat_DEFAULT, kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_FRAC_Z,
kOpFormat_NC1KHKWHWC0};
const std::set<std::string> k3DSupportFormat = {kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册