提交 e5f91c98 编写于 作者: 吴承辉

Merge branch 'transpose' into 'master'

Improve transpose perf

See merge request !507
......@@ -15,6 +15,10 @@
#ifndef MACE_KERNELS_TRANSPOSE_H_
#define MACE_KERNELS_TRANSPOSE_H_
#if defined(MACE_ENABLE_NEON)
#include <arm_neon.h>
#endif
#include <vector>
#include "mace/core/future.h"
......@@ -25,6 +29,64 @@
namespace mace {
namespace kernels {
static void TransposeNHWCToNCHWC3(const float *input,
float *output,
const index_t height,
const index_t width) {
index_t image_size = height * width;
#pragma omp parallel for
for (index_t h = 0; h < height; ++h) {
index_t in_offset = h * width * 3;
index_t out_offset = h * width;
index_t w;
for (w = 0; w + 3 < width; w += 4) {
float32x4x3_t vi = vld3q_f32(input + in_offset);
vst1q_f32(output + out_offset, vi.val[0]);
vst1q_f32(output + out_offset + image_size, vi.val[1]);
vst1q_f32(output + out_offset + image_size * 2, vi.val[2]);
in_offset += 12;
out_offset += 4;
}
for (; w < width; ++w) {
for (index_t c = 0; c < 3; ++c) {
output[h * width + image_size * c + w] =
input[h * width * 3 + w * 3 + c];
}
}
}
}
static void TransposeNCHWToNHWCC2(const float *input,
float *output,
const index_t height,
const index_t width) {
index_t image_size = height * width;
#pragma omp parallel for
for (index_t h = 0; h < height; ++h) {
index_t in_offset = h * width;
index_t out_offset = h * width * 2;
index_t w;
for (w = 0; w + 3 < width; w += 4) {
float32x4_t vi0 = vld1q_f32(input + in_offset);
float32x4_t vi1 = vld1q_f32(input + in_offset + image_size);
vst2q_f32(output + out_offset, {vi0, vi1});
in_offset += 4;
out_offset += 8;
}
for (; w < width; ++w) {
for (index_t c = 0; c < 2; ++c) {
output[h * width * 2 + w * 2 + c] =
input[h * width + image_size * c + w];
}
}
}
}
template<DeviceType D, typename T>
struct TransposeFunctor {
explicit TransposeFunctor(const std::vector<int> &dims) : dims_(dims) {}
......@@ -48,28 +110,48 @@ struct TransposeFunctor {
}
}
} else if (input->dim_size() == 4) {
std::vector<index_t>
in_stride{input_shape[1] * input_shape[2] * input_shape[3],
input_shape[2] * input_shape[3], input_shape[3], 1};
std::vector<index_t>
out_stride{output_shape[1] * output_shape[2] * output_shape[3],
output_shape[2] * output_shape[3], output_shape[3], 1};
std::vector<index_t> idim(4, 0);
std::vector<index_t> odim(4, 0);
for (odim[0] = 0; odim[0] < output_shape[0]; ++odim[0]) {
for (odim[1] = 0; odim[1] < output_shape[1]; ++odim[1]) {
for (odim[2] = 0; odim[2] < output_shape[2]; ++odim[2]) {
for (odim[3] = 0; odim[3] < output_shape[3]; ++odim[3]) {
idim[dims_[0]] = odim[0];
idim[dims_[1]] = odim[1];
idim[dims_[2]] = odim[2];
idim[dims_[3]] = odim[3];
output_data[odim[0] * out_stride[0] + odim[1] * out_stride[1]
+ odim[2] * out_stride[2] + odim[3]] =
input_data[idim[0] * in_stride[0] + idim[1] * in_stride[1]
+ idim[2] * in_stride[2] + idim[3]];
std::vector<int> transpose_order_from_NHWC_to_NCHW{0, 3, 1, 2};
std::vector<int> transpose_order_from_NCHW_to_NHWC{0, 2, 3, 1};
index_t batch_size = input->dim(1) * input->dim(2) * input->dim(3);
if (dims_ == transpose_order_from_NHWC_to_NCHW && input->dim(3) == 3) {
for (index_t b = 0; b < input->dim(0); ++b) {
TransposeNHWCToNCHWC3(input_data + b * batch_size,
output_data + b * batch_size,
input->dim(1),
input->dim(2));
}
} else if (dims_ == transpose_order_from_NCHW_to_NHWC
&& input->dim(1) == 2) {
for (index_t b = 0; b < input->dim(0); ++b) {
TransposeNCHWToNHWCC2(input_data + b * batch_size,
output_data + b * batch_size,
input->dim(2),
input->dim(3));
}
} else {
std::vector<index_t>
in_stride{input_shape[1] * input_shape[2] * input_shape[3],
input_shape[2] * input_shape[3], input_shape[3], 1};
std::vector<index_t>
out_stride{output_shape[1] * output_shape[2] * output_shape[3],
output_shape[2] * output_shape[3], output_shape[3], 1};
std::vector<index_t> idim(4, 0);
std::vector<index_t> odim(4, 0);
for (odim[0] = 0; odim[0] < output_shape[0]; ++odim[0]) {
for (odim[1] = 0; odim[1] < output_shape[1]; ++odim[1]) {
for (odim[2] = 0; odim[2] < output_shape[2]; ++odim[2]) {
for (odim[3] = 0; odim[3] < output_shape[3]; ++odim[3]) {
idim[dims_[0]] = odim[0];
idim[dims_[1]] = odim[1];
idim[dims_[2]] = odim[2];
idim[dims_[3]] = odim[3];
output_data[odim[0] * out_stride[0] + odim[1] * out_stride[1]
+ odim[2] * out_stride[2] + odim[3]] =
input_data[idim[0] * in_stride[0] + idim[1] * in_stride[1]
+ idim[2] * in_stride[2] + idim[3]];
}
}
}
}
......
......@@ -83,6 +83,9 @@ void TransposeBenchmark(int iters,
#define BM_TRANSPOSE4D(N, C, H, W, D0, D1, D2, D3) \
BM_TRANSPOSE4D_MACRO(N, C, H, W, D0, D1, D2, D3, float, CPU);
BM_TRANSPOSE4D(1, 512, 512, 3, 0, 3, 1, 2);
BM_TRANSPOSE4D(1, 2, 512, 512, 0, 2, 3, 1);
BM_TRANSPOSE4D(1, 64, 64, 512, 0, 3, 1, 2);
BM_TRANSPOSE4D(1, 512, 64, 64, 0, 2, 3, 1);
BM_TRANSPOSE2D(128, 128);
......
......@@ -37,16 +37,51 @@ void TransposeNCHWTest(const std::vector<index_t> &input_shape) {
// Run on cpu
net.RunOp();
net.FillNHWCInputToNCHWInput<DeviceType::CPU, float>("InputNCHW", "Input");
net.TransformDataFormat<DeviceType::CPU, float>("Input",
DataFormat::NHWC,
"InputNCHW",
DataFormat::NCHW);
ExpectTensorNear<float>(*net.GetOutput("InputNCHW"),
*net.GetOutput("Output"));
}
void TransposeNHWCTest(const std::vector<index_t> &input_shape) {
// Construct graph
OpsTestNet net;
// Add input data
net.AddRandomInput<CPU, float>("Input", input_shape);
OpDefBuilder("Transpose", "TransposeNHWCTest")
.Input("Input")
.Output("Output")
.AddIntsArg("dims", {0, 2, 3, 1})
.Finalize(net.NewOperatorDef());
// Run on cpu
net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("Input",
DataFormat::NCHW,
"InputNHWC",
DataFormat::NHWC);
ExpectTensorNear<float>(*net.GetOutput("InputNHWC"),
*net.GetOutput("Output"));
}
} // namespace
TEST_F(TransposeOpTest, NCHW) {
TEST_F(TransposeOpTest, NHWC_to_NCHW) {
TransposeNCHWTest({3, 64, 64, 128});
TransposeNCHWTest({1, 64, 48, 128});
TransposeNCHWTest({1, 512, 512, 3});
TransposeNCHWTest({2, 512, 512, 3});
}
TEST_F(TransposeOpTest, NCHW_to_NHWC) {
TransposeNHWCTest({1, 2, 512, 512});
TransposeNHWCTest({1, 3, 512, 512});
TransposeNHWCTest({2, 2, 512, 512});
}
TEST_F(TransposeOpTest, Rank2) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册