提交 43bd2a35 编写于 作者: 李寅

Improve transpose perf

上级 139a62b9
......@@ -15,6 +15,10 @@
#ifndef MACE_KERNELS_TRANSPOSE_H_
#define MACE_KERNELS_TRANSPOSE_H_
#if defined(MACE_ENABLE_NEON)
#include <arm_neon.h>
#endif
#include <vector>
#include "mace/core/future.h"
......@@ -25,6 +29,64 @@
namespace mace {
namespace kernels {
static void TransposeNHWCToNCHWC3(const float *input,
float *output,
const index_t height,
const index_t width) {
index_t image_size = height * width;
#pragma omp parallel for
for (index_t h = 0; h < height; ++h) {
index_t in_offset = h * width * 3;
index_t out_offset = h * width;
index_t w;
for (w = 0; w + 3 < width; w += 4) {
float32x4x3_t vi = vld3q_f32(input + in_offset);
vst1q_f32(output + out_offset, vi.val[0]);
vst1q_f32(output + out_offset + image_size, vi.val[1]);
vst1q_f32(output + out_offset + image_size * 2, vi.val[2]);
in_offset += 12;
out_offset += 4;
}
for (; w < width; ++w) {
for (index_t c = 0; c < 3; ++c) {
output[h * width + image_size * c + w] =
input[h * width * 3 + w * 3 + c];
}
}
}
}
static void TransposeNCHWToNHWCC2(const float *input,
float *output,
const index_t height,
const index_t width) {
index_t image_size = height * width;
#pragma omp parallel for
for (index_t h = 0; h < height; ++h) {
index_t in_offset = h * width;
index_t out_offset = h * width * 2;
index_t w;
for (w = 0; w + 3 < width; w += 4) {
float32x4_t vi0 = vld1q_f32(input + in_offset);
float32x4_t vi1 = vld1q_f32(input + in_offset + image_size);
vst2q_f32(output + out_offset, {vi0, vi1});
in_offset += 4;
out_offset += 8;
}
for (; w < width; ++w) {
for (index_t c = 0; c < 2; ++c) {
output[h * width * 2 + w * 2 + c] =
input[h * width + image_size * c + w];
}
}
}
}
template<DeviceType D, typename T>
struct TransposeFunctor {
explicit TransposeFunctor(const std::vector<int> &dims) : dims_(dims) {}
......@@ -48,6 +110,25 @@ struct TransposeFunctor {
}
}
} else if (input->dim_size() == 4) {
std::vector<int> transpose_order_from_NHWC_to_NCHW{0, 3, 1, 2};
std::vector<int> transpose_order_from_NCHW_to_NHWC{0, 2, 3, 1};
index_t batch_size = input->dim(1) * input->dim(2) * input->dim(3);
if (dims_ == transpose_order_from_NHWC_to_NCHW && input->dim(3) == 3) {
for (index_t b = 0; b < input->dim(0); ++b) {
TransposeNHWCToNCHWC3(input_data + b * batch_size,
output_data + b * batch_size,
input->dim(1),
input->dim(2));
}
} else if (dims_ == transpose_order_from_NCHW_to_NHWC
&& input->dim(1) == 2) {
for (index_t b = 0; b < input->dim(0); ++b) {
TransposeNCHWToNHWCC2(input_data + b * batch_size,
output_data + b * batch_size,
input->dim(2),
input->dim(3));
}
} else {
std::vector<index_t>
in_stride{input_shape[1] * input_shape[2] * input_shape[3],
input_shape[2] * input_shape[3], input_shape[3], 1};
......@@ -74,6 +155,7 @@ struct TransposeFunctor {
}
}
}
}
} else {
MACE_NOT_IMPLEMENTED;
}
......
......@@ -83,6 +83,9 @@ void TransposeBenchmark(int iters,
#define BM_TRANSPOSE4D(N, C, H, W, D0, D1, D2, D3) \
BM_TRANSPOSE4D_MACRO(N, C, H, W, D0, D1, D2, D3, float, CPU);
BM_TRANSPOSE4D(1, 512, 512, 3, 0, 3, 1, 2);
BM_TRANSPOSE4D(1, 2, 512, 512, 0, 2, 3, 1);
BM_TRANSPOSE4D(1, 64, 64, 512, 0, 3, 1, 2);
BM_TRANSPOSE4D(1, 512, 64, 64, 0, 2, 3, 1);
BM_TRANSPOSE2D(128, 128);
......
......@@ -37,16 +37,51 @@ void TransposeNCHWTest(const std::vector<index_t> &input_shape) {
// Run on cpu
net.RunOp();
net.FillNHWCInputToNCHWInput<DeviceType::CPU, float>("InputNCHW", "Input");
net.TransformDataFormat<DeviceType::CPU, float>("Input",
DataFormat::NHWC,
"InputNCHW",
DataFormat::NCHW);
ExpectTensorNear<float>(*net.GetOutput("InputNCHW"),
*net.GetOutput("Output"));
}
void TransposeNHWCTest(const std::vector<index_t> &input_shape) {
// Construct graph
OpsTestNet net;
// Add input data
net.AddRandomInput<CPU, float>("Input", input_shape);
OpDefBuilder("Transpose", "TransposeNHWCTest")
.Input("Input")
.Output("Output")
.AddIntsArg("dims", {0, 2, 3, 1})
.Finalize(net.NewOperatorDef());
// Run on cpu
net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("Input",
DataFormat::NCHW,
"InputNHWC",
DataFormat::NHWC);
ExpectTensorNear<float>(*net.GetOutput("InputNHWC"),
*net.GetOutput("Output"));
}
} // namespace
TEST_F(TransposeOpTest, NCHW) {
TEST_F(TransposeOpTest, NHWC_to_NCHW) {
TransposeNCHWTest({3, 64, 64, 128});
TransposeNCHWTest({1, 64, 48, 128});
TransposeNCHWTest({1, 512, 512, 3});
TransposeNCHWTest({2, 512, 512, 3});
}
TEST_F(TransposeOpTest, NCHW_to_NHWC) {
TransposeNHWCTest({1, 2, 512, 512});
TransposeNHWCTest({1, 3, 512, 512});
TransposeNHWCTest({2, 2, 512, 512});
}
TEST_F(TransposeOpTest, Rank2) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册