Commit 10fa6290 authored by 李超

Merge branch 'add-3d-transpose' into 'master'

Add support for 3D-input transpose

See merge request !1133
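For context before the diff: a rank-3 transpose with permutation dst_dims produces output_shape[i] = input_shape[dst_dims[i]] and gathers each output element from the input coordinate whose axis dst_dims[i] equals the output index along axis i. The specialized branches added below are all instances of this one mapping. A minimal, standalone reference sketch follows; the name NaiveTranspose3D and the plain int64_t / std::vector types are illustrative, not MACE API:

#include <cstdint>
#include <vector>

// Illustrative reference only (not MACE code): output shape is the input
// shape permuted by dims, and each output element is gathered from the
// matching input coordinate.
template <typename T>
std::vector<T> NaiveTranspose3D(const std::vector<T> &input,
                                const std::vector<int64_t> &in_shape,
                                const std::vector<int> &dims) {
  std::vector<int64_t> out_shape{in_shape[dims[0]], in_shape[dims[1]],
                                 in_shape[dims[2]]};
  std::vector<T> output(input.size());
  const int64_t in_stride0 = in_shape[1] * in_shape[2];
  const int64_t in_stride1 = in_shape[2];
  for (int64_t o0 = 0; o0 < out_shape[0]; ++o0) {
    for (int64_t o1 = 0; o1 < out_shape[1]; ++o1) {
      for (int64_t o2 = 0; o2 < out_shape[2]; ++o2) {
        int64_t in_coord[3];
        in_coord[dims[0]] = o0;  // output axis 0 came from input axis dims[0]
        in_coord[dims[1]] = o1;
        in_coord[dims[2]] = o2;
        output[(o0 * out_shape[1] + o1) * out_shape[2] + o2] =
            input[in_coord[0] * in_stride0 + in_coord[1] * in_stride1 +
                  in_coord[2]];
      }
    }
  }
  return output;
}

Checked against the {2, 3, 2} data in the Rank3 test further down, this mapping yields the listed expectations; for example dims {0, 2, 1} gives {1, 3, 5, 2, 4, 6, 7, 9, 11, 8, 10, 12}.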
@@ -162,8 +162,9 @@ MaceStatus Transpose(utils::ThreadPool *thread_pool,
const std::vector<int> &dst_dims,
T *output) {
MACE_CHECK((input_shape.size() == 2 && dst_dims.size() == 2) ||
- (input_shape.size() == 4 && dst_dims.size() == 4),
- "Only support 2D or 4D transpose");
+ (input_shape.size() == 3 && dst_dims.size() == 3) ||
+ (input_shape.size() == 4 && dst_dims.size() == 4),
+ "Only support 2D, 3D or 4D transpose");
std::vector<index_t> output_shape;
for (size_t i = 0; i < dst_dims.size(); ++i) {
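The hunk is truncated inside the shape loop, but the rule it applies is simply output_shape[i] = input_shape[dst_dims[i]]. A tiny standalone sketch of that rule, using values taken from the Rank3 test below:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Same rule as at the end of the hunk above:
  // output_shape[i] = input_shape[dst_dims[i]].
  const std::vector<int64_t> input_shape{2, 3, 2};
  const std::vector<int> dst_dims{0, 2, 1};
  std::vector<int64_t> output_shape;
  for (size_t i = 0; i < dst_dims.size(); ++i) {
    output_shape.push_back(input_shape[dst_dims[i]]);
  }
  for (int64_t d : output_shape) std::cout << d << ' ';  // prints: 2 2 3
  return 0;
}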
@@ -219,7 +220,7 @@ MaceStatus Transpose(utils::ThreadPool *thread_pool,
index_t height = input_shape[1];
index_t width = input_shape[2];
index_t channel = input_shape[3];
- index_t channel_raw_size = channel * sizeof(T);
+ size_t channel_raw_size = channel * sizeof(T);
index_t stride_i = height;
index_t stride_j = width;
index_t tile_size = std::max(static_cast<index_t>(1),
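Two things worth noting in this hunk: channel_raw_size switches from index_t to size_t, presumably to match the std::size_t byte count that memcpy takes, and tile_size (cut off at the end of the hunk) sizes cache-blocked loops. The same blocking idea reappears in the new {0, 2, 1} branch below. A minimal standalone sketch of that idea, outside any MACE types; TiledTranspose2D is an illustrative name, not MACE API:

#include <algorithm>
#include <cstdint>
#include <vector>

// Minimal cache-blocked 2D transpose sketch (illustrative, not MACE code).
// Working one tile at a time keeps both the row-major reads from `input`
// and the strided writes into `output` inside a cache-friendly footprint.
template <typename T>
void TiledTranspose2D(const T *input, int64_t height, int64_t width,
                      int64_t tile_size, T *output) {
  for (int64_t i = 0; i < height; i += tile_size) {
    for (int64_t j = 0; j < width; j += tile_size) {
      const int64_t end_i = std::min(i + tile_size, height);
      const int64_t end_j = std::min(j + tile_size, width);
      for (int64_t ti = i; ti < end_i; ++ti) {
        for (int64_t tj = j; tj < end_j; ++tj) {
          output[tj * height + ti] = input[ti * width + tj];
        }
      }
    }
  }
}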
@@ -266,6 +267,89 @@ MaceStatus Transpose(utils::ThreadPool *thread_pool,
}
}
}
} else if (input_shape.size() == 3) {
const index_t batch = input_shape[0];
const index_t height = input_shape[1];
const index_t width = input_shape[2];
if (dst_dims == std::vector<int>{0, 2, 1}) {
index_t stride_h = height;
index_t stride_w = width;
index_t stride_hw = height * width;
index_t tile_size = height > 512 || width > 512 ? 64 : 32;
thread_pool->Compute3D([=](index_t start, index_t end, index_t step,
index_t start0, index_t end0, index_t step0,
index_t start1, index_t end1, index_t step1) {
for (index_t b = start; b < end; b += step) {
for (index_t i = start0; i < end0; i += step0) {
for (index_t j = start1; j < end1; j += step1) {
index_t end_i = std::min(i + tile_size, height);
index_t end_j = std::min(j + tile_size, width);
for (index_t tile_i = i; tile_i < end_i; ++tile_i) {
for (index_t tile_j = j; tile_j < end_j; ++tile_j) {
output[b * stride_hw + tile_j * stride_h + tile_i] =
input[b * stride_hw + tile_i * stride_w + tile_j];
}
}
}
}
}
}, 0, batch, 1, 0, height, tile_size, 0, width, tile_size);
} else if (dst_dims == std::vector<int>{1, 0, 2}) {
size_t width_raw_size = width * sizeof(T);
thread_pool->Compute2D([=](index_t start0, index_t end0, index_t step0,
index_t start1, index_t end1, index_t step1) {
for (int i = start0; i < end0; i += step0) {
for (int j = start1; j < end1; j += step1) {
memcpy(output + (j * batch + i) * width,
input + (i * height + j) * width,
width_raw_size);
}
}
}, 0, batch, 1, 0, height, 1);
} else if (dst_dims == std::vector<int>{1, 2, 0}) {
thread_pool->Compute3D([=](index_t start, index_t end, index_t step,
index_t start0, index_t end0, index_t step0,
index_t start1, index_t end1, index_t step1) {
for (int i = start; i < end; i += step) {
for (int j = start0; j < end0; j += step0) {
for (int k = start1; k < end1; k += step1) {
output[(j * width + k) * batch + i] =
input[(i * height + j) * width + k];
}
}
}
}, 0, batch, 1, 0, height, 1, 0, width, 1);
} else if (dst_dims == std::vector<int>{2, 0, 1}) {
thread_pool->Compute3D([=](index_t start, index_t end, index_t step,
index_t start0, index_t end0, index_t step0,
index_t start1, index_t end1, index_t step1) {
for (int i = start; i < end; i += step) {
for (int j = start0; j < end0; j += step0) {
for (int k = start1; k < end1; k += step1) {
output[(k * batch + i) * height + j] =
input[(i * height + j) * width + k];
}
}
}
}, 0, batch, 1, 0, height, 1, 0, width, 1);
} else if (dst_dims == std::vector<int>{2, 1, 0}) {
thread_pool->Compute3D([=](index_t start, index_t end, index_t step,
index_t start0, index_t end0, index_t step0,
index_t start1, index_t end1, index_t step1) {
for (int i = start; i < end; i += step) {
for (int j = start0; j < end0; j += step0) {
for (int k = start1; k < end1; k += step1) {
output[(k * height + j) * batch + i] =
input[(i * height + j) * width + k];
}
}
}
}, 0, batch, 1, 0, height, 1, 0, width, 1);
} else {
MACE_CHECK(false, "no need to transform.");
}
} else {
MACE_NOT_IMPLEMENTED;
}
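The new branch covers the five non-trivial rank-3 permutations: {0, 2, 1} swaps height and width per batch with a tiled loop, {1, 0, 2} keeps the innermost axis in place so each row of `width` elements stays contiguous and is copied with a single memcpy, and {1, 2, 0}, {2, 0, 1}, {2, 1, 0} move elements one by one; any other dims value, including the identity {0, 1, 2}, hits the MACE_CHECK(false, ...). As a hedged, standalone restatement of the {1, 0, 2} fast path (plain loops in place of the thread pool), using the 2x3x2 data from the Rank3 test below:

#include <cassert>
#include <cstdint>
#include <cstring>
#include <vector>

int main() {
  // Because the innermost axis stays innermost under {1, 0, 2}, every (i, j)
  // row of `width` elements is contiguous in both tensors and can be memcpy'd.
  const int64_t batch = 2, height = 3, width = 2;
  const std::vector<float> input{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
  std::vector<float> output(input.size());
  for (int64_t i = 0; i < batch; ++i) {
    for (int64_t j = 0; j < height; ++j) {
      std::memcpy(output.data() + (j * batch + i) * width,
                  input.data() + (i * height + j) * width,
                  width * sizeof(float));
    }
  }
  // Matches the Rank3 test expectation for dims {1, 0, 2} further down.
  const std::vector<float> expected{1, 2, 7, 8, 3, 4, 9, 10, 5, 6, 11, 12};
  assert(output == expected);
  return 0;
}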
@@ -42,8 +42,9 @@ class TransposeOp<D, float> : public Operation {
Tensor *output = this->Output(0);
const std::vector<index_t> &input_shape = input->shape();
MACE_CHECK((input_shape.size() == 4 && dims_.size() == 4) ||
- (input_shape.size() == 2 && dims_.size() == 2),
- "rank should be 2 or 4");
+ (input_shape.size() == 3 && dims_.size() == 3) ||
+ (input_shape.size() == 2 && dims_.size() == 2),
+ "rank should be 2, 3 or 4");
std::vector<index_t> output_shape;
for (size_t i = 0; i < dims_.size(); ++i) {
output_shape.push_back(input_shape[dims_[i]]);
@@ -101,6 +101,64 @@ TEST_F(TransposeOpTest, Rank2) {
*net.GetOutput("Output"));
}
namespace {
void Transpose3DTest(const std::vector<index_t> &input_shape,
const std::vector<float> &input_data,
const std::vector<int> &dest_dims,
const std::vector<index_t> &expected_shape,
const std::vector<float> &expected_data) {
// Construct graph
OpsTestNet net;
// Add input data
net.AddInputFromArray<DeviceType::CPU, float>("Input",
input_shape,
input_data);
OpDefBuilder("Transpose", "TransposeNCHWTest")
.Input("Input")
.Output("Output")
.AddIntsArg("dims", dest_dims)
.Finalize(net.NewOperatorDef());
// Run on cpu
net.RunOp();
net.AddInputFromArray<CPU, float>("ExpectedOutput", expected_shape,
expected_data);
ExpectTensorNear<float>(*net.GetOutput("ExpectedOutput"),
*net.GetOutput("Output"));
}
} // namespace
TEST_F(TransposeOpTest, Rank3) {
Transpose3DTest({2, 3, 2},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
{0, 2, 1},
{2, 2, 3},
{1, 3, 5, 2, 4, 6, 7, 9, 11, 8, 10, 12});
Transpose3DTest({2, 3, 2},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
{1, 0, 2},
{3, 2, 2},
{1, 2, 7, 8, 3, 4, 9, 10, 5, 6, 11, 12});
Transpose3DTest({2, 3, 2},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
{1, 2, 0},
{3, 2, 2},
{1, 7, 2, 8, 3, 9, 4, 10, 5, 11, 6, 12});
Transpose3DTest({2, 3, 2},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
{2, 0, 1},
{2, 2, 3},
{1, 3, 5, 7, 9, 11, 2, 4, 6, 8, 10, 12});
Transpose3DTest({2, 3, 2},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
{2, 1, 0},
{2, 3, 2},
{1, 7, 3, 9, 5, 11, 2, 8, 4, 10, 6, 12});
}
} // namespace test
} // namespace ops
} // namespace mace
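A note on coverage: three axes admit 3! = 6 orderings, and the Rank3 test exercises the five non-identity ones, matching the branches in the kernel. If more shapes were added, the expected tensors could be generated rather than hand-written. A hedged sketch of that idea; the reference transpose is assumed to be something like the NaiveTranspose3D sketch near the top of this page, not an existing MACE helper:

#include <algorithm>
#include <iostream>
#include <vector>

int main() {
  // Illustrative only: walk every ordering of three axes and list the ones a
  // Rank3-style test needs, i.e. all permutations except the identity, which
  // the kernel rejects with "no need to transform."
  std::vector<int> dims{0, 1, 2};
  do {
    if (dims != std::vector<int>{0, 1, 2}) {
      std::cout << "{" << dims[0] << ", " << dims[1] << ", " << dims[2] << "}\n";
      // Expected data for each dims could come from a naive reference
      // transpose instead of being typed out by hand.
    }
  } while (std::next_permutation(dims.begin(), dims.end()));
  return 0;
}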