Commit 1dcd83b1 authored by 李寅

Add transpose 2D

Parent c384a6e2
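For context, the rank-2 path added in this commit is a plain matrix transpose restricted to dims == {1, 0}. Below is a minimal standalone sketch of the same index mapping (not part of the commit; `rows`/`cols` are illustrative names, and the 2x3 values mirror the Rank2 unit test further down):

#include <cstdio>
#include <vector>

int main() {
  // A 2x3 row-major matrix, as in the Rank2 unit test in this commit.
  const int rows = 2, cols = 3;
  const std::vector<float> input = {1, 2, 3, 4, 5, 6};
  std::vector<float> output(rows * cols);
  // dims == {1, 0}: output[j][i] = input[i][j]; each output row has `rows` elements.
  for (int i = 0; i < rows; ++i) {
    for (int j = 0; j < cols; ++j) {
      output[j * rows + i] = input[i * cols + j];
    }
  }
  for (float v : output) std::printf("%g ", v);  // prints: 1 4 2 5 3 6
  std::printf("\n");
  return 0;
}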
......@@ -37,31 +37,44 @@ struct TransposeFunctor {
const T *input_data = input->data<T>();
T *output_data = output->mutable_data<T>();
    if (input->dim_size() == 2) {
      // New rank-2 path: a plain matrix transpose (only dims == {1, 0}).
      MACE_CHECK(dims_[0] == 1 && dims_[1] == 0, "no need transform");
      index_t stride_i = input_shape[0];
      index_t stride_j = input_shape[1];
      for (int i = 0; i < input_shape[0]; ++i) {
        for (int j = 0; j < input_shape[1]; ++j) {
          output_data[j * stride_i + i] = input_data[i * stride_j + j];
        }
      }
    } else if (input->dim_size() == 4) {
      // Rank-4 path: generic permutation via row-major strides.
      std::vector<index_t>
          in_stride{input_shape[1] * input_shape[2] * input_shape[3],
                    input_shape[2] * input_shape[3], input_shape[3], 1};
      std::vector<index_t>
          out_stride{output_shape[1] * output_shape[2] * output_shape[3],
                     output_shape[2] * output_shape[3], output_shape[3], 1};
      std::vector<index_t> idim(4, 0);
      std::vector<index_t> odim(4, 0);
      for (odim[0] = 0; odim[0] < output_shape[0]; ++odim[0]) {
        for (odim[1] = 0; odim[1] < output_shape[1]; ++odim[1]) {
          for (odim[2] = 0; odim[2] < output_shape[2]; ++odim[2]) {
            for (odim[3] = 0; odim[3] < output_shape[3]; ++odim[3]) {
              // Output axis d takes its index from input axis dims_[d].
              idim[dims_[0]] = odim[0];
              idim[dims_[1]] = odim[1];
              idim[dims_[2]] = odim[2];
              idim[dims_[3]] = odim[3];
              output_data[odim[0] * out_stride[0] + odim[1] * out_stride[1]
                  + odim[2] * out_stride[2] + odim[3]] =
                  input_data[idim[0] * in_stride[0] + idim[1] * in_stride[1]
                      + idim[2] * in_stride[2] + idim[3]];
            }
          }
        }
      }
    } else {
      MACE_NOT_IMPLEMENTED;
    }
  }
......
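Aside (not part of the commit): the rank-4 branch above composes flat offsets from per-dimension row-major strides, with output axis d reading from input axis dims[d]. A self-contained sketch of that arithmetic, assuming a tiny illustrative NHWC-to-NCHW permutation (dims = {0, 3, 1, 2}, matching one of the benchmark cases below):

#include <cstdint>
#include <vector>

int main() {
  // Tiny input tensor: N=1, H=2, W=2, C=3, filled with 0..11 in row-major order.
  const std::vector<int64_t> in_shape = {1, 2, 2, 3};
  const std::vector<int> dims = {0, 3, 1, 2};  // NHWC -> NCHW
  std::vector<int64_t> out_shape(4);
  for (int d = 0; d < 4; ++d) out_shape[d] = in_shape[dims[d]];

  std::vector<float> input(12), output(12);
  for (int i = 0; i < 12; ++i) input[i] = static_cast<float>(i);

  // Row-major strides, computed the same way as in_stride/out_stride above.
  const std::vector<int64_t> in_stride = {
      in_shape[1] * in_shape[2] * in_shape[3],
      in_shape[2] * in_shape[3], in_shape[3], 1};
  const std::vector<int64_t> out_stride = {
      out_shape[1] * out_shape[2] * out_shape[3],
      out_shape[2] * out_shape[3], out_shape[3], 1};

  std::vector<int64_t> idim(4, 0), odim(4, 0);
  for (odim[0] = 0; odim[0] < out_shape[0]; ++odim[0]) {
    for (odim[1] = 0; odim[1] < out_shape[1]; ++odim[1]) {
      for (odim[2] = 0; odim[2] < out_shape[2]; ++odim[2]) {
        for (odim[3] = 0; odim[3] < out_shape[3]; ++odim[3]) {
          // Output coordinate d is the input coordinate along axis dims[d].
          for (int d = 0; d < 4; ++d) idim[dims[d]] = odim[d];
          output[odim[0] * out_stride[0] + odim[1] * out_stride[1] +
                 odim[2] * out_stride[2] + odim[3]] =
              input[idim[0] * in_stride[0] + idim[1] * in_stride[1] +
                    idim[2] * in_stride[2] + idim[3]];
        }
      }
    }
  }
  // `output` now holds the data in NCHW order, e.g. output[4]
  // (n=0, c=1, h=0, w=0) equals input[1] (n=0, h=0, w=0, c=1).
  return 0;
}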
......@@ -28,16 +28,16 @@ class TransposeOp : public Operator<D, T> {
public:
TransposeOp(const OperatorDef &operator_def, Workspace *ws)
: Operator<D, T>(operator_def, ws),
        dims_(OperatorBase::GetRepeatedArgument<int>("dims")),
functor_(dims_) {}
bool Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT);
Tensor *output = this->Output(OUTPUT);
const std::vector<index_t> &input_shape = input->shape();
    MACE_CHECK((input_shape.size() == 4 && dims_.size() == 4)
                   || (input_shape.size() == 2 && dims_.size() == 2),
               "rank should be 2 or 4");
std::vector<index_t> output_shape;
for (int i = 0; i < dims_.size(); ++i) {
output_shape.push_back(input_shape[dims_[i]]);
......
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <vector>
#include "mace/core/operator.h"
#include "mace/core/testing/test_benchmark.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
namespace {
template<DeviceType D, typename T>
void TransposeBenchmark(int iters,
std::vector<index_t> shape,
std::vector<int> dims) {
mace::testing::StopTiming();
OpsTestNet net;
// Add input data
net.AddRandomInput<D, float>("Input", shape);
OpDefBuilder("Transpose", "TransposeBM")
.Input("Input")
.Output("Output")
.AddIntsArg("dims", dims)
.Finalize(net.NewOperatorDef());
// Warm-up
for (int i = 0; i < 5; ++i) {
net.RunOp(D);
}
net.Sync();
mace::testing::StartTiming();
while (iters--) {
net.RunOp(D);
}
net.Sync();
}
} // namespace
#define BM_TRANSPOSE2D_MACRO(H, W, TYPE, DEVICE) \
static void BM_TRANSPOSE2D_##H##_##W##_##TYPE##_##DEVICE( \
int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * H * W; \
mace::testing::MaccProcessed(tot); \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
TransposeBenchmark<DEVICE, TYPE>(iters, {H, W}, {1, 0}); \
} \
BENCHMARK(BM_TRANSPOSE2D_##H##_##W##_##TYPE##_##DEVICE)
#define BM_TRANSPOSE2D(H, W) \
BM_TRANSPOSE2D_MACRO(H, W, float, CPU);
#define BM_TRANSPOSE4D_MACRO(N, C, H, W, D0, D1, D2, D3, TYPE, DEVICE) \
static void \
BM_TRANSPOSE4D_##N##_##C##_##H##_##W##_##D0##D1##D2##D3##_##TYPE##_##DEVICE( \
int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
mace::testing::MaccProcessed(tot); \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
TransposeBenchmark<DEVICE, TYPE>(iters, {N, C, H, W}, {D0, D1, D2, D3}); \
} \
BENCHMARK( \
BM_TRANSPOSE4D_##N##_##C##_##H##_##W##_##D0##D1##D2##D3##_##TYPE##_##DEVICE)
#define BM_TRANSPOSE4D(N, C, H, W, D0, D1, D2, D3) \
BM_TRANSPOSE4D_MACRO(N, C, H, W, D0, D1, D2, D3, float, CPU);
BM_TRANSPOSE4D(1, 64, 64, 512, 0, 3, 1, 2);
BM_TRANSPOSE4D(1, 512, 64, 64, 0, 2, 3, 1);
BM_TRANSPOSE2D(128, 128);
BM_TRANSPOSE2D(512, 512);
} // namespace test
} // namespace ops
} // namespace mace
......@@ -49,6 +49,29 @@ TEST_F(TransposeOpTest, NCHW) {
TransposeNCHWTest({1, 64, 48, 128});
}
TEST_F(TransposeOpTest, Rank2) {
// Construct graph
OpsTestNet net;
// Add input data
net.AddInputFromArray<CPU, float>("Input", {2, 3}, {1, 2, 3, 4, 5, 6});
OpDefBuilder("Transpose", "TransposeNCHWTest")
.Input("Input")
.Output("Output")
.AddIntsArg("dims", {1, 0})
.Finalize(net.NewOperatorDef());
// Run on cpu
net.RunOp();
net.AddInputFromArray<CPU, float>("ExpectedOutput",
{3, 2},
{1, 4, 2, 5, 3, 6});
ExpectTensorNear<float>(*net.GetOutput("ExpectedOutput"),
*net.GetOutput("Output"));
}
} // namespace test
} // namespace ops
} // namespace mace