提交 5fb1ec2a 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2043 code review

Merge pull request !2043 from liubuyu/master
......@@ -14,11 +14,9 @@
* limitations under the License.
*/
#include "common/trans.h"
#include <algorithm>
#include <functional>
#include <numeric>
#include <utility>
#include "./securec.h"
#include "common/utils.h"
#include "session/anf_runtime_algorithm.h"
#include "kernel/kernel.h"
......@@ -29,34 +27,7 @@
namespace mindspore {
namespace trans {
namespace {
// Pads a 0-4 dim host shape to 4D. Ranks 1-3 are written starting at index 1
// (the C axis), leaving the leading N axis as 1; a full 4D shape is copied
// through unchanged. Ranks above 4 are rejected.
std::vector<size_t> PaddingShapeTo4dByDefault(const std::vector<size_t> &shape) {
  std::vector<size_t> shape_4d(4, 1);
  const size_t rank = shape.size();
  if (rank > 4) {
    MS_LOG(EXCEPTION) << "Unexpect shape size = " << shape.size();
  }
  if (rank == 4) {
    std::copy(shape.begin(), shape.end(), shape_4d.begin());
  } else {
    // rank 0 copies nothing (all-ones shape); ranks 1-3 fill indices 1..rank.
    std::copy(shape.begin(), shape.end(), shape_4d.begin() + 1);
  }
  return shape_4d;
}
} // namespace
const size_t kNchwDims = 4;
enum kAxis : int { kN = 0, kC, kH, kW, kNchwDims, kNdhwc };
const std::map<TypeId, size_t> type_map = {{kNumberTypeBool, 1}, {kNumberTypeInt, 4}, {kNumberTypeInt8, 1},
{kNumberTypeInt16, 2}, {kNumberTypeInt32, 4}, {kNumberTypeInt64, 8},
{kNumberTypeUInt, 4}, {kNumberTypeUInt8, 1}, {kNumberTypeUInt16, 2},
......@@ -84,7 +55,10 @@ inline void SetData(size_t size, bool pad_zero, size_t src_idx, size_t dst_idx,
// Integer ceiling division: returns ceil(n1 / n2).
// Returns 0 when n2 == 0 instead of dividing by zero.
// NOTE: for unsigned T, n1 == 0 underflows (n1 - 1); callers pass n1 >= 1.
template <typename T>
T DivCeil(T n1, T n2) {
  if (n2 != 0) {
    return (n1 - 1) / n2 + 1;
  }
  return 0;
}
enum DataTypeTransMode {
......@@ -226,8 +200,7 @@ size_t CubeSizeByType(const TypeId data_type) {
}
// Returns the total number of elements described by `shape`.
// An empty shape (scalar) yields 1, the multiplicative identity.
size_t ShapeSize(const std::vector<size_t> &shape) {
  // Initialize with a size_t so the product accumulates in size_t,
  // not int (avoids truncation/overflow for large shapes).
  return std::accumulate(shape.begin(), shape.end(), static_cast<size_t>(1), std::multiplies<size_t>());
}
size_t TypeIdSize(const TypeId data_type) {
......@@ -239,57 +212,9 @@ size_t TypeIdSize(const TypeId data_type) {
return unsupported_type_error;
}
// Reports whether a host shape of `shape_size` dims must be padded to 4D
// before being laid out in `format`.
bool IsNeedPadding(const std::string &format, const size_t shape_size) {
  // Scalars are never padded.
  if (shape_size == 0) {
    return false;
  }
  // Default and FRAC_NZ layouts accept the host rank as-is.
  const bool rank_free_format = (format == kOpFormat_DEFAULT || format == kOpFormat_FRAC_NZ);
  return !rank_free_format && shape_size < 4;
}
// Returns the host shape of `node`'s `index`-th output as int values,
// padded to 4D when the output format requires it.
// For a ValueNode the shape is read from the wrapped tensor; otherwise it
// comes from the node's inferred output shape.
std::vector<int> GetRuntimePaddingShape(const AnfNodePtr &node, size_t index) {
  std::vector<int> shape;
  std::vector<size_t> host_shape;
  if (node->isa<ValueNode>()) {
    // A ValueNode must wrap a tensor; anything else has no shape to report.
    auto value_node = node->cast<ValueNodePtr>();
    auto node_value = value_node->value();
    auto tensor = node_value->cast<tensor::TensorPtr>();
    if (tensor == nullptr) {
      MS_LOG(EXCEPTION) << " the node[ " << node->DebugString() << "]'s cannot convert ";
    }
    auto shape_temp = tensor->shape();
    (void)std::transform(shape_temp.begin(), shape_temp.end(), std::back_inserter(host_shape), IntToSize);
    // A scalar tensor has an empty shape; represent it as [1].
    if (host_shape.empty()) {
      host_shape.push_back(1);
    }
  } else {
    host_shape = AnfAlgo::GetOutputInferShape(node, index);
  }
  // NOTE(review): the padding decision uses output 0's format/reshape type
  // even when `index` != 0 — confirm this is intentional.
  if (trans::IsNeedPadding(AnfAlgo::GetOutputFormat(node, 0), host_shape.size())) {
    host_shape = trans::PaddingShapeTo4d(host_shape, AnfAlgo::GetOutputReshapeType(node, 0));
  }
  std::transform(host_shape.begin(), host_shape.end(), std::back_inserter(shape), SizeToInt);
  return shape;
}
// Pads `shape` to 4D using an explicit per-dimension axis mapping.
// Falls back to positional default padding when the mapping is absent or
// its length does not match the shape's rank.
std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape, const std::vector<kernel::Axis> &padding_axis) {
  const bool axis_map_usable = !padding_axis.empty() && padding_axis.size() == shape.size();
  if (!axis_map_usable) {
    return PaddingShapeTo4dByDefault(shape);
  }
  std::vector<size_t> shape_4d(4, 1);
  for (size_t i = 0; i < shape.size(); ++i) {
    // Each input dim is routed to the 4D axis named by the mapping.
    shape_4d[padding_axis[i]] = shape[i];
  }
  return shape_4d;
}
namespace {
bool CheckDims(const std::vector<size_t> &shape) {
if (shape.size() != 4) {
if (shape.size() != kNchwDims) {
MS_LOG(ERROR) << "Host shape dims shoud be 4";
return false;
}
......@@ -308,10 +233,10 @@ std::vector<size_t> NhwcDeviceShape(const std::vector<size_t> &shape) {
MS_LOG(EXCEPTION) << "Ccheck dims failed.";
}
std::vector<size_t> device_shape;
device_shape.push_back(shape[0]);
device_shape.push_back(shape[2]);
device_shape.push_back(shape[3]);
device_shape.push_back(shape[1]);
device_shape.push_back(shape[kN]);
device_shape.push_back(shape[kH]);
device_shape.push_back(shape[kW]);
device_shape.push_back(shape[kC]);
return device_shape;
}
......@@ -320,10 +245,10 @@ std::vector<size_t> HwchDeviceShape(const std::vector<size_t> &shape) {
MS_LOG(EXCEPTION) << "Check dims failed.";
}
std::vector<size_t> device_shape;
device_shape.push_back(shape[2]);
device_shape.push_back(shape[3]);
device_shape.push_back(shape[1]);
device_shape.push_back(shape[0]);
device_shape.push_back(shape[kH]);
device_shape.push_back(shape[kW]);
device_shape.push_back(shape[kC]);
device_shape.push_back(shape[kN]);
return device_shape;
}
......@@ -332,9 +257,9 @@ std::vector<size_t> FracZDeviceShape(const std::vector<size_t> &shape) {
MS_LOG(EXCEPTION) << "Check dims failed.";
}
std::vector<size_t> device_shape;
size_t cout16 = ((shape[0] + kCubeSize - 1) / kCubeSize) * kCubeSize;
size_t cin16 = ((shape[1] + kCubeSize - 1) / kCubeSize) * kCubeSize;
device_shape.push_back(shape[2] * shape[3] * cin16 / kCubeSize);
const size_t cout16 = ((shape[kN] + kCubeSize - 1) / kCubeSize) * kCubeSize;
const size_t cin16 = ((shape[kC] + kCubeSize - 1) / kCubeSize) * kCubeSize;
device_shape.push_back(shape[kH] * shape[kW] * cin16 / kCubeSize);
device_shape.push_back(cout16 / kCubeSize);
device_shape.push_back(kCubeSize);
device_shape.push_back(kCubeSize);
......@@ -346,12 +271,12 @@ std::vector<size_t> Nc1hwc0DeviceShape(const std::vector<size_t> &shape) {
MS_LOG(EXCEPTION) << "Check dims failed.";
}
std::vector<size_t> device_shape;
size_t C1 = (shape[1] + kCubeSize - 1) / kCubeSize;
size_t C0 = kCubeSize;
device_shape.push_back(shape[0]);
const size_t C1 = (shape[kC] + kCubeSize - 1) / kCubeSize;
const size_t C0 = kCubeSize;
device_shape.push_back(shape[kN]);
device_shape.push_back(C1);
device_shape.push_back(shape[2]);
device_shape.push_back(shape[3]);
device_shape.push_back(shape[kH]);
device_shape.push_back(shape[kW]);
device_shape.push_back(C0);
return device_shape;
}
......@@ -361,10 +286,10 @@ std::vector<size_t> C1hwncoc0DeviceShape(const std::vector<size_t> &shape) {
MS_LOG(EXCEPTION) << "Check dims failed.";
}
std::vector<size_t> device_shape;
device_shape.push_back((shape[1] - 1) / kCubeSize + 1);
device_shape.push_back(shape[2]);
device_shape.push_back(shape[3]);
device_shape.push_back(shape[0]);
device_shape.push_back((shape[kC] - 1) / kCubeSize + 1);
device_shape.push_back(shape[kH]);
device_shape.push_back(shape[kW]);
device_shape.push_back(shape[kN]);
device_shape.push_back(kCubeSize);
device_shape.push_back(kCubeSize);
return device_shape;
......@@ -375,9 +300,9 @@ std::vector<size_t> FracZc04DeviceShape(const std::vector<size_t> &shape) {
MS_LOG(EXCEPTION) << "Check dims failed.";
}
std::vector<size_t> device_shape;
size_t c0 = 4;
auto first_dim = DivCeil(c0 * shape.at(2) * shape.at(3), kCubeSize);
auto no = DivCeil(shape.at(0), kCubeSize);
const size_t c0 = 4;
auto first_dim = DivCeil(c0 * shape[kH] * shape[kW], kCubeSize);
auto no = DivCeil(shape.at(kN), kCubeSize);
device_shape.push_back(first_dim);
device_shape.push_back(no);
device_shape.push_back(kCubeSize);
......@@ -390,24 +315,101 @@ std::vector<size_t> Nc1hwc04DeviceShape(const std::vector<size_t> &shape) {
MS_LOG(EXCEPTION) << "Check dims failed.";
}
std::vector<size_t> device_shape;
size_t C1 = 1;
size_t C0 = 4;
device_shape.push_back(shape[0]);
const size_t C1 = 1;
const size_t C0 = 4;
device_shape.push_back(shape[kN]);
device_shape.push_back(C1);
device_shape.push_back(shape[2]);
device_shape.push_back(shape[3]);
device_shape.push_back(shape[kH]);
device_shape.push_back(shape[kW]);
device_shape.push_back(C0);
return device_shape;
}
// NDHWC host layout is already the device layout; only the rank is checked.
// Merge artifact fixed: the old literal guard (`< 5`) and the new named
// guard were both present, leaving stacked `if` lines and unbalanced braces.
std::vector<size_t> NdhwcDeviceShape(const std::vector<size_t> &shape) {
  if (shape.size() < kNdhwc) {
    MS_LOG(EXCEPTION) << "Shape dims must be 5 when format is ndhwc.";
  }
  return shape;
}
// Pads a 0-4 dim host shape to 4D (NCHW). Ranks 1-3 are placed starting at
// the C axis, leaving N as 1; a full 4D shape is copied through unchanged.
// Ranks above 4 are rejected.
std::vector<size_t> PaddingShapeTo4dByDefault(const std::vector<size_t> &shape) {
  std::vector<size_t> shape_4d(kNchwDims, 1);
  const size_t rank = shape.size();
  if (rank > kNchwDims) {
    MS_LOG(EXCEPTION) << "Unexpect shape size = " << shape.size();
  }
  if (rank == kNchwDims) {
    std::copy(shape.begin(), shape.end(), shape_4d.begin());
  } else {
    // rank 0 copies nothing (all-ones shape); ranks 1-3 fill axes C..W.
    std::copy(shape.begin(), shape.end(), shape_4d.begin() + kC);
  }
  return shape_4d;
}
} // namespace
// Reports whether a host shape of `shape_size` dims must be padded to 4D
// before being laid out in `format`.
bool IsNeedPadding(const std::string &format, const size_t shape_size) {
  // Scalars are never padded.
  if (shape_size == 0) {
    return false;
  }
  // Default and FRAC_NZ layouts accept the host rank as-is.
  if (format == kOpFormat_DEFAULT || format == kOpFormat_FRAC_NZ) {
    return false;
  }
  return shape_size < kNchwDims;
}
// Returns the host shape of `node`'s `index`-th output as int values,
// padded to 4D when the output format requires it.
// For a ValueNode the shape is read from the wrapped tensor; otherwise it
// comes from the node's inferred output shape.
// Throws (via MS_LOG(EXCEPTION)) when a ValueNode does not wrap a tensor.
std::vector<int> GetRuntimePaddingShape(const AnfNodePtr &node, size_t index) {
  MS_EXCEPTION_IF_NULL(node);
  std::vector<int> shape;
  std::vector<size_t> host_shape;
  if (node->isa<ValueNode>()) {
    auto value_node = node->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(value_node);
    auto node_value = value_node->value();
    MS_EXCEPTION_IF_NULL(node_value);
    auto tensor = node_value->cast<tensor::TensorPtr>();
    if (tensor == nullptr) {
      // Fixed broken-English diagnostic (" The node[ X ]'s cannot convert ").
      MS_LOG(EXCEPTION) << "The value of node [" << node->DebugString() << "] cannot be converted to a tensor.";
    }
    auto shape_temp = tensor->shape();
    (void)std::transform(shape_temp.begin(), shape_temp.end(), std::back_inserter(host_shape), IntToSize);
    // A scalar tensor has an empty shape; represent it as [1].
    if (host_shape.empty()) {
      host_shape.push_back(1);
    }
  } else {
    host_shape = AnfAlgo::GetOutputInferShape(node, index);
  }
  // NOTE(review): the padding decision uses output 0's format/reshape type
  // even when `index` != 0 — confirm this is intentional.
  if (trans::IsNeedPadding(AnfAlgo::GetOutputFormat(node, 0), host_shape.size())) {
    host_shape = trans::PaddingShapeTo4d(host_shape, AnfAlgo::GetOutputReshapeType(node, 0));
  }
  (void)std::transform(host_shape.begin(), host_shape.end(), std::back_inserter(shape), SizeToInt);
  return shape;
}
// Pads `shape` to 4D using an explicit per-dimension axis mapping.
// Falls back to positional default padding when the mapping is absent or
// its length does not match the shape's rank.
std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape, const std::vector<kernel::Axis> &padding_axis) {
  const size_t rank = shape.size();
  if (padding_axis.empty() || padding_axis.size() != rank) {
    return PaddingShapeTo4dByDefault(shape);
  }
  std::vector<size_t> shape_4d(kNchwDims, 1);
  for (size_t i = 0; i < rank; ++i) {
    // Each input dim is routed to the 4D axis named by the mapping.
    shape_4d[padding_axis[i]] = shape[i];
  }
  return shape_4d;
}
std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format) {
using DeviceShapeTransfer = std::function<std::vector<size_t>(const std::vector<size_t> &)>;
const std::map<std::string, DeviceShapeTransfer> device_shape_map{{kOpFormat_NCHW, NchwDeviceShape},
......@@ -439,7 +441,7 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s
device_shape.push_back(kCubeSize);
return device_shape;
}
if (shape.size() != 4) {
if (shape.size() != kNchwDims) {
MS_LOG(WARNING) << "Get Device Shape using a shape size is less than 4 ,should be Padding shape by Default firstly";
temp_shape = PaddingShapeTo4dByDefault(shape);
}
......@@ -455,6 +457,8 @@ bool CheckArgs(const FormatArgs &args, size_t *size, size_t *total_size) {
MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
return false;
}
MS_EXCEPTION_IF_NULL(size);
MS_EXCEPTION_IF_NULL(total_size);
*size = TypeIdSize(args.src_data_type);
if (*size < 1) {
MS_LOG(ERROR) << "Illegal dtype.";
......@@ -540,10 +544,10 @@ bool NchwTo4D(const FormatArgs &args, void *result) {
MS_LOG(ERROR) << "Check args failed.";
return false;
}
size_t n = args.host_shape[0];
size_t c = args.host_shape[1];
size_t h = args.host_shape[2];
size_t w = args.host_shape[3];
auto n = args.host_shape[kN];
auto c = args.host_shape[kC];
auto h = args.host_shape[kH];
auto w = args.host_shape[kW];
for (size_t ni = 0; ni < n; ni++) {
for (size_t ci = 0; ci < c; ci++) {
for (size_t hi = 0; hi < h; hi++) {
......@@ -572,10 +576,10 @@ bool ToNchw(const FormatArgs &args, void *result) {
MS_LOG(ERROR) << "Check args failed.";
return false;
}
size_t n = args.host_shape[0];
size_t c = args.host_shape[1];
size_t h = args.host_shape[2];
size_t w = args.host_shape[3];
auto n = args.host_shape[kN];
auto c = args.host_shape[kC];
auto h = args.host_shape[kH];
auto w = args.host_shape[kW];
for (size_t ni = 0; ni < n; ni++) {
for (size_t ci = 0; ci < c; ci++) {
for (size_t hi = 0; hi < h; hi++) {
......@@ -602,32 +606,32 @@ bool NchwToFracZ(const FormatArgs &args, void *result) {
MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
return false;
}
size_t size = TypeIdSize(args.src_data_type);
auto size = TypeIdSize(args.src_data_type);
if (size < 1) {
MS_LOG(ERROR) << "Illegal dtype.";
return false;
}
auto n = args.host_shape[0];
auto c = args.host_shape[1];
auto h = args.host_shape[2];
auto w = args.host_shape[3];
auto n = args.host_shape[kN];
auto c = args.host_shape[kC];
auto h = args.host_shape[kH];
auto w = args.host_shape[kW];
size_t c0 = CubeSizeByType(args.src_data_type);
auto c0 = CubeSizeByType(args.src_data_type);
if (c0 < 1) {
MS_LOG(ERROR) << "Illegal dtype.";
return false;
}
size_t c1 = DivCeil(c, c0);
size_t hw = h * w;
size_t chw = c * hw;
size_t hwc0 = hw * c0;
size_t nchw = n * chw;
size_t hf_cnt = DivCeil(n, kCubeSize);
size_t vf_cnt = c1 * hw;
size_t fractal_ele_cnt = c0 * kCubeSize;
size_t total_ele_cnt = hf_cnt * vf_cnt * fractal_ele_cnt;
size_t dst_size = total_ele_cnt * size;
auto c1 = DivCeil(c, c0);
auto hw = h * w;
auto chw = c * hw;
auto hwc0 = hw * c0;
auto nchw = n * chw;
auto hf_cnt = DivCeil(n, kCubeSize);
auto vf_cnt = c1 * hw;
auto fractal_ele_cnt = c0 * kCubeSize;
auto total_ele_cnt = hf_cnt * vf_cnt * fractal_ele_cnt;
auto dst_size = total_ele_cnt * size;
if (dst_size != args.device_size) {
MS_LOG(ERROR) << "Illegal total data size."
<< "dst size is :" << dst_size << "device size is :" << args.device_size;
......@@ -647,7 +651,7 @@ bool NchwToFracZ(const FormatArgs &args, void *result) {
auto src_ni = hfi * kCubeSize + col;
auto src_idx = src_row_offset + chw * col;
auto dst_idx = gfi * fractal_ele_cnt + col * c0 + row;
auto pad_zero = (src_ni >= n || src_idx >= nchw || src_ci >= c) ? true : false;
auto pad_zero = src_ni >= n || src_idx >= nchw || src_ci >= c;
SetData(size, pad_zero, src_idx, dst_idx, args, result);
}
}
......@@ -663,12 +667,12 @@ bool FracZToNchw(const FormatArgs &args, void *result) {
MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
return false;
}
size_t size = TypeIdSize(args.src_data_type);
auto size = TypeIdSize(args.src_data_type);
if (size < 1) {
MS_LOG(ERROR) << "Illegal dtype.";
return false;
}
size_t total_size = ShapeSize(args.device_shape) * size;
auto total_size = ShapeSize(args.device_shape) * size;
if (total_size != args.device_size) {
MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
return false;
......@@ -677,18 +681,16 @@ bool FracZToNchw(const FormatArgs &args, void *result) {
auto n0 = args.device_shape.at(1);
auto ni = args.device_shape.at(2);
auto c0 = args.device_shape.at(3);
auto n = args.host_shape[0];
auto c = args.host_shape[1];
auto h = args.host_shape[2];
auto w = args.host_shape[3];
size_t nc = ni * n0;
size_t ncc0 = nc * c0;
size_t wncc0 = w * ncc0;
size_t hwncc0 = h * wncc0;
size_t hw = h * w;
size_t chw = c * hw;
auto n = args.host_shape[kN];
auto c = args.host_shape[kC];
auto h = args.host_shape[kH];
auto w = args.host_shape[kW];
auto nc = ni * n0;
auto ncc0 = nc * c0;
auto wncc0 = w * ncc0;
auto hwncc0 = h * wncc0;
auto hw = h * w;
auto chw = c * hw;
for (size_t n_idx = 0; n_idx < n; n_idx++) {
size_t n_head_addr = n_idx * chw;
......@@ -720,20 +722,18 @@ bool NchwToFracZc04(const FormatArgs &args, void *result) {
MS_LOG(ERROR) << "Check args failed.";
return false;
}
size_t cube = kCubeSize;
size_t n = args.host_shape[0];
size_t c = args.host_shape[1];
size_t h = args.host_shape[2];
size_t w = args.host_shape[3];
size_t c0 = 4;
size_t c1 = DivCeil(c, c0);
size_t hwc0 = h * w * c0;
size_t hwc = h * w * c;
size_t nhwc = n * h * w * c;
size_t n_cnt = DivCeil(n, cube);
size_t v_cnt = DivCeil(h * w * c0 * c1, cube);
auto cube = kCubeSize;
auto n = args.host_shape[kN];
auto c = args.host_shape[kC];
auto h = args.host_shape[kH];
auto w = args.host_shape[kW];
const size_t c0 = 4;
auto c1 = DivCeil(c, c0);
auto hwc0 = h * w * c0;
auto hwc = h * w * c;
auto nhwc = n * h * w * c;
auto n_cnt = DivCeil(n, cube);
auto v_cnt = DivCeil(h * w * c0 * c1, cube);
size_t dst_idx = 0;
for (size_t vi = 0; vi < v_cnt; vi++) {
......@@ -929,7 +929,7 @@ bool NchwToNc1hwc0(const FormatArgs &args, void *result) {
MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
return false;
}
size_t size = TypeIdSize(args.src_data_type);
auto size = TypeIdSize(args.src_data_type);
if (size < 1) {
MS_LOG(ERROR) << "Illegal dtype.";
return false;
......@@ -940,20 +940,23 @@ bool NchwToNc1hwc0(const FormatArgs &args, void *result) {
return false;
}
auto n = args.host_shape[0];
auto c = args.host_shape[1];
auto h = args.host_shape[2];
auto w = args.host_shape[3];
size_t c0 = CubeSizeByType(args.src_data_type);
auto n = args.host_shape[kN];
auto c = args.host_shape[kC];
auto h = args.host_shape[kH];
auto w = args.host_shape[kW];
auto c0 = CubeSizeByType(args.src_data_type);
if (c0 < 1) {
MS_LOG(ERROR) << "Illegal dtype.";
return false;
}
size_t c1 = DivCeil(c, c0);
size_t hw = h * w;
size_t chw = c * hw;
size_t c1hwc0 = c1 * hw * c0;
size_t wc0 = w * c0;
if (args.device_format == kOpFormat_NC1HWC0_C04) {
c0 = 4;
}
auto c1 = DivCeil(c, c0);
auto hw = h * w;
auto chw = c * hw;
auto c1hwc0 = c1 * hw * c0;
auto wc0 = w * c0;
for (size_t n_idx = 0; n_idx < n; n_idx++) {
size_t n_head_addr = n_idx * c1hwc0;
......@@ -967,7 +970,7 @@ bool NchwToNc1hwc0(const FormatArgs &args, void *result) {
size_t dst_idx = c0_idx + w_head_addr;
size_t c_idx = c0_idx + c1_idx * c0;
size_t src_idx = n_idx * chw + c_idx * hw + h_idx * w + w_idx;
auto pad_zero = (c_idx < c) ? false : true;
auto pad_zero = c_idx >= c;
SetData(size, pad_zero, src_idx, dst_idx, args, result);
}
}
......@@ -984,29 +987,29 @@ bool Nc1hwc0ToNchw(const FormatArgs &args, void *result) {
MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
return false;
}
size_t size = TypeIdSize(args.src_data_type);
auto size = TypeIdSize(args.src_data_type);
if (size < 1) {
MS_LOG(ERROR) << "Illegal dtype.";
return false;
}
size_t total_size = ShapeSize(args.device_shape) * size;
auto total_size = ShapeSize(args.device_shape) * size;
if (total_size != args.device_size) {
MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
return false;
}
auto n = args.host_shape[0];
auto c = args.host_shape[1];
auto h = args.host_shape[2];
auto w = args.host_shape[3];
auto n = args.host_shape[kN];
auto c = args.host_shape[kC];
auto h = args.host_shape[kH];
auto w = args.host_shape[kW];
auto c1 = args.device_shape[1];
auto c0 = args.device_shape[4];
size_t hw = h * w;
size_t chw = c * hw;
size_t wc0 = w * c0;
size_t hwc0 = h * wc0;
size_t c1hwc0 = c1 * hwc0;
auto hw = h * w;
auto chw = c * hw;
auto wc0 = w * c0;
auto hwc0 = h * wc0;
auto c1hwc0 = c1 * hwc0;
for (size_t n_idx = 0; n_idx < n; n_idx++) {
size_t n_head_addr = n_idx * chw;
......@@ -1037,13 +1040,15 @@ bool NchwToC1hwncoc0(const FormatArgs &args, void *result) {
MS_LOG(ERROR) << "Check args failed.";
return false;
}
auto n = args.host_shape[0];
auto c = args.host_shape[1];
auto h = args.host_shape[2];
auto w = args.host_shape[3];
auto n = args.host_shape[kN];
auto c = args.host_shape[kC];
auto h = args.host_shape[kH];
auto w = args.host_shape[kW];
const int co_idx = 4;
const int c0_idx = 5;
auto c1 = args.device_shape[0];
auto co = args.device_shape[4];
auto c0 = args.device_shape[5];
auto co = args.device_shape[co_idx];
auto c0 = args.device_shape[c0_idx];
for (size_t c1_i = 0; c1_i < c1; c1_i++) {
for (size_t h_i = 0; h_i < h; h_i++) {
......@@ -1055,7 +1060,7 @@ bool NchwToC1hwncoc0(const FormatArgs &args, void *result) {
co_i * c0 + c0_i;
size_t c_i = c0_i + c1_i * c0;
size_t src_idx = n_i * c * h * w + c_i * h * w + h_i * w + w_i;
auto pad_zero = (c_i < c && c0_i == co_i) ? false : true;
auto pad_zero = !(c_i < c && c0_i == co_i);
SetData(size, pad_zero, src_idx, dst_idx, args, result);
}
}
......@@ -1076,12 +1081,14 @@ bool C1hwncoc0ToNchw(const FormatArgs &args, void *result) {
MS_LOG(ERROR) << "Check args failed.";
return false;
}
auto n = args.host_shape[0];
auto c = args.host_shape[1];
auto h = args.host_shape[2];
auto w = args.host_shape[3];
auto co = args.device_shape[4];
auto c0 = args.device_shape[5];
auto n = args.host_shape[kN];
auto c = args.host_shape[kC];
auto h = args.host_shape[kH];
auto w = args.host_shape[kW];
const int co_idx = 4;
const int c0_idx = 5;
auto co = args.device_shape[co_idx];
auto c0 = args.device_shape[c0_idx];
for (size_t n_i = 0; n_i < n; n_i++) {
for (size_t c_i = 0; c_i < c; c_i++) {
for (size_t h_i = 0; h_i < h; h_i++) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册