Commit c741936c authored by 李滨

Merge branch 'mnmt' into 'master'

Improve transpose performance

See merge request !840
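In short: this MR tidies the existing tiled 2D transpose (hoisting input_shape[0]/input_shape[1] into height/width locals) and adds a cache-blocked fast path for the {0, 2, 1, 3} permutation of 4D tensors, plus a benchmark case that covers it. The underlying technique in both hunks is loop tiling: a naive transpose reads one buffer sequentially but writes the other with a stride of a full row, so written cache lines are evicted before they are reused, while processing the matrix in small square tiles keeps both the read and the write working sets cache-resident. A minimal sketch of the idea, assuming plain float buffers (illustrative only, not MACE's implementation; all names below are made up):

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Naive 2D transpose: the writes stride through `out` by `height` elements,
// so each destination cache line is touched once and soon evicted.
void TransposeNaive(const float *in, float *out,
                    int64_t height, int64_t width) {
  for (int64_t i = 0; i < height; ++i)
    for (int64_t j = 0; j < width; ++j)
      out[j * height + i] = in[i * width + j];
}

// Tiled 2D transpose: the same loops, blocked so that one input tile and
// one output tile stay cache-resident while every element in them is moved.
void TransposeTiled(const float *in, float *out,
                    int64_t height, int64_t width, int64_t tile = 32) {
  for (int64_t i = 0; i < height; i += tile)
    for (int64_t j = 0; j < width; j += tile)
      for (int64_t ti = i, ei = std::min(i + tile, height); ti < ei; ++ti)
        for (int64_t tj = j, ej = std::min(j + tile, width); tj < ej; ++tj)
          out[tj * height + ti] = in[ti * width + tj];
}

int main() {
  const int64_t h = 300, w = 200;
  std::vector<float> in(h * w), a(h * w), b(h * w);
  for (int64_t k = 0; k < h * w; ++k) in[k] = static_cast<float>(k);
  TransposeNaive(in.data(), a.data(), h, w);
  TransposeTiled(in.data(), b.data(), h, w);
  assert(a == b);  // both variants must produce the same permutation
  return 0;
}
```

The gap between the two variants grows with matrix size, once rows no longer fit in cache; that is the effect the MACE_BM_TRANSPOSE2D benchmarks below measure.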
@@ -122,16 +122,16 @@ struct TransposeFunctor : OpKernel {
     if (input->dim_size() == 2) {
       MACE_CHECK(dims_[0] == 1 && dims_[1] == 0, "no need transform");
-      index_t stride_i = input_shape[0];
-      index_t stride_j = input_shape[1];
-      index_t tile_size = input_shape[0] > 512 || input_shape[1] > 512
-                          ? 64 : 32;
+      index_t height = input_shape[0];
+      index_t width = input_shape[1];
+      index_t stride_i = height;
+      index_t stride_j = width;
+      index_t tile_size = height > 512 || width > 512 ? 64 : 32;
 #pragma omp parallel for collapse(2)
-      for (index_t i = 0; i < input_shape[0]; i += tile_size) {
-        for (index_t j = 0; j < input_shape[1]; j += tile_size) {
-          index_t end_i = std::min(i + tile_size, input_shape[0]);
-          index_t end_j = std::min(j + tile_size, input_shape[1]);
+      for (index_t i = 0; i < height; i += tile_size) {
+        for (index_t j = 0; j < width; j += tile_size) {
+          index_t end_i = std::min(i + tile_size, height);
+          index_t end_j = std::min(j + tile_size, width);
           for (index_t tile_i = i; tile_i < end_i; ++tile_i) {
             for (index_t tile_j = j; tile_j < end_j; ++tile_j) {
               output_data[tile_j * stride_i + tile_i] =
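A note on the constants above (rationale inferred; the MR does not state it): with 4-byte elements a 32x32 tile is 4 KiB, so one source tile plus one destination tile occupy about 8 KiB and fit comfortably in a typical 32 KiB L1 data cache. Once either dimension exceeds 512, the code steps up to 64x64 tiles (16 KiB each), accepting more cache pressure in exchange for a quarter as many tiles, which also gives each OpenMP iteration under collapse(2) more work.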
@@ -144,6 +144,7 @@ struct TransposeFunctor : OpKernel {
       std::vector<int> transpose_order_from_NHWC_to_NCHW{0, 3, 1, 2};
       std::vector<int> transpose_order_from_NCHW_to_NHWC{0, 2, 3, 1};
       index_t batch_size = input->dim(1) * input->dim(2) * input->dim(3);
+
       if (dims_ == transpose_order_from_NHWC_to_NCHW && input->dim(3) == 3) {
         for (index_t b = 0; b < input->dim(0); ++b) {
           TransposeNHWCToNCHWC3(input_data + b * batch_size,
@@ -159,6 +160,30 @@ struct TransposeFunctor : OpKernel {
                                 input->dim(2),
                                 input->dim(3));
         }
+      } else if (dims_ == std::vector<int>{0, 2, 1, 3}) {
+        index_t height = input_shape[1];
+        index_t width = input_shape[2];
+        index_t channel = input_shape[3];
+        index_t channel_raw_size = channel * sizeof(T);
+        index_t stride_i = height;
+        index_t stride_j = width;
+        index_t tile_size = std::max(static_cast<index_t>(1),
+                                     static_cast<index_t>(std::sqrt(
+                                         8 * 1024 / channel)));
+#pragma omp parallel for collapse(2)
+        for (index_t i = 0; i < height; i += tile_size) {
+          for (index_t j = 0; j < width; j += tile_size) {
+            index_t end_i = std::min(i + tile_size, height);
+            index_t end_j = std::min(j + tile_size, width);
+            for (index_t tile_i = i; tile_i < end_i; ++tile_i) {
+              for (index_t tile_j = j; tile_j < end_j; ++tile_j) {
+                memcpy(output_data + (tile_j * stride_i + tile_i) * channel,
+                       input_data + (tile_i * stride_j + tile_j) * channel,
+                       channel_raw_size);
+              }
+            }
+          }
+        }
       } else {
         std::vector<index_t>
             in_stride{input_shape[1] * input_shape[2] * input_shape[3],
......
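The 24 added lines above are the core of the MR: a fast path for the {0, 2, 1, 3} permutation, sitting alongside the existing specialized paths such as TransposeNHWCToNCHWC3. It swaps the height and width axes while the innermost channel axis stays contiguous, so every (i, j) position is a whole channel-long row that can be moved with a single memcpy, and the H x W grid of rows is tiled just like the 2D case. tile_size is picked so that one tile spans roughly 8 * 1024 elements: sqrt(8 * 1024 / channel), clamped to at least 1. For channel = 64 that truncates to 11, i.e. 11 * 11 * 64 = 7744 elements per tile. Note that the indexing shown carries no batch offset, so the visible code assumes dim(0) == 1 (the new benchmark below uses batch 1); any guard would live outside these hunks. A tiny standalone program that reproduces the tile-size arithmetic (illustrative, not part of the MR):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  // Reproduce the tile-size heuristic from the new {0, 2, 1, 3} branch:
  // pick tile_size so one tile spans roughly 8 * 1024 elements in total.
  for (int64_t channel : {3, 64, 512}) {
    int64_t tile_size = std::max<int64_t>(
        1, static_cast<int64_t>(std::sqrt(8 * 1024 / channel)));
    std::printf("channel=%5lld -> tile_size=%3lld (%lld elements per tile)\n",
                static_cast<long long>(channel),
                static_cast<long long>(tile_size),
                static_cast<long long>(tile_size * tile_size * channel));
  }
  return 0;
}
```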
@@ -88,6 +88,7 @@ MACE_BM_TRANSPOSE4D(1, 512, 512, 3, 0, 3, 1, 2);
 MACE_BM_TRANSPOSE4D(1, 2, 512, 512, 0, 2, 3, 1);
 MACE_BM_TRANSPOSE4D(1, 64, 64, 512, 0, 3, 1, 2);
 MACE_BM_TRANSPOSE4D(1, 512, 64, 64, 0, 2, 3, 1);
+MACE_BM_TRANSPOSE4D(1, 4, 20, 64, 0, 2, 1, 3);
 MACE_BM_TRANSPOSE2D(128, 128);
 MACE_BM_TRANSPOSE2D(512, 512);
 MACE_BM_TRANSPOSE2D(1024, 1024);
......
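The one added benchmark case, MACE_BM_TRANSPOSE4D(1, 4, 20, 64, 0, 2, 1, 3), covers the new fast path. Reading the arguments like the neighbouring lines, that is presumably batch 1, height 4, width 20, channel 64, with permutation {0, 2, 1, 3}; channel = 64 yields tile_size = 11 under the new heuristic, so the 4 x 20 grid also exercises the partial-tile edges of both tile loops.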