Commit ac6fc00c authored by Liangliang He

Merge branch 'reshape' into 'master'

Reshape and Eltwise ops

See merge request !237
@@ -80,6 +80,8 @@ extern void Register_SpaceToBatchND(OperatorRegistry *op_registry);
 extern void Register_MatMul(OperatorRegistry *op_registry);
 extern void Register_WinogradTransform(OperatorRegistry *op_registry);
 extern void Register_WinogradInverseTransform(OperatorRegistry *op_registry);
+extern void Register_Reshape(OperatorRegistry *op_registry);
+extern void Register_Eltwise(OperatorRegistry *op_registry);
 OperatorRegistry::OperatorRegistry() {
   Register_Activation(this);
@@ -103,6 +105,8 @@ OperatorRegistry::OperatorRegistry() {
   Register_MatMul(this);
   Register_WinogradTransform(this);
   Register_WinogradInverseTransform(this);
+  Register_Reshape(this);
+  Register_Eltwise(this);
 }
 }  // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_KERNELS_ELTWISE_H_
#define MACE_KERNELS_ELTWISE_H_
#include "mace/core/future.h"
#include "mace/core/tensor.h"
#include "mace/core/runtime/opencl/cl2_header.h"
namespace mace {
namespace kernels {
enum EltwiseType {
PROD = 0,
SUM = 1,
MAX = 2,
MIN = 3,
};
struct EltwiseFunctorBase {
EltwiseFunctorBase(const EltwiseType type,
const std::vector<float> &coeff)
: type_(type), coeff_(coeff) {}
EltwiseType type_;
std::vector<float> coeff_;
};
template <DeviceType D, typename T>
struct EltwiseFunctor : EltwiseFunctorBase {
EltwiseFunctor(const EltwiseType type,
const std::vector<float> &coeff)
: EltwiseFunctorBase(type, coeff) {}
void operator()(const Tensor *input0,
const Tensor *input1,
Tensor *output,
StatsFuture *future) {
Tensor::MappingGuard input0_guard(input0);
Tensor::MappingGuard input1_guard(input1);
Tensor::MappingGuard output_guard(output);
const T *input0_ptr = input0->data<T>();
const T *input1_ptr = input1->data<T>();
T *output_ptr = output->mutable_data<T>();
const index_t size = input0->size();
switch (type_) {
case PROD:
#pragma omp parallel for
for (index_t i = 0; i < size; ++i) {
output_ptr[i] = input0_ptr[i] * input1_ptr[i];
}
break;
case SUM:
if (coeff_.empty()) {
#pragma omp parallel for
for (index_t i = 0; i < size; ++i) {
output_ptr[i] = input0_ptr[i] + input1_ptr[i];
}
} else {
#pragma omp parallel for
for (index_t i = 0; i < size; ++i) {
output_ptr[i] = coeff_[0] * input0_ptr[i] + coeff_[1] * input1_ptr[i];
}
}
break;
case MAX:
#pragma omp parallel for
for (index_t i = 0; i < size; ++i) {
output_ptr[i] = std::max<T>(input0_ptr[i], input1_ptr[i]);
}
break;
case MIN:
#pragma omp parallel for
for (index_t i = 0; i < size; ++i) {
output_ptr[i] = std::min<T>(input0_ptr[i], input1_ptr[i]);
}
break;
default:
LOG(FATAL) << "Eltwise op does not support type " << type_;
}
}
};
template <typename T>
struct EltwiseFunctor<DeviceType::OPENCL, T> : EltwiseFunctorBase {
EltwiseFunctor(const EltwiseType type,
const std::vector<float> &coeff)
: EltwiseFunctorBase(type, coeff) {}
void operator()(const Tensor *input0,
const Tensor *input1,
Tensor *output,
StatsFuture *future);
cl::Kernel kernel_;
};
} // namespace kernels
} // namespace mace
#endif // MACE_KERNELS_ELTWISE_H_
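For orientation, here is a minimal standalone sketch (plain C++ with hypothetical names, independent of MACE's Tensor and OpenMP machinery) of the element-wise arithmetic the CPU functor above implements; when a coefficient list is supplied, SUM weights the two inputs by coeff_[0] and coeff_[1]:

// Hypothetical standalone illustration; not MACE API.
#include <algorithm>
#include <cstdio>
#include <vector>

enum class Mode { PROD, SUM, MAX, MIN };

std::vector<float> Eltwise(Mode mode,
                           const std::vector<float> &a,
                           const std::vector<float> &b,
                           const std::vector<float> &coeff = {}) {
  std::vector<float> out(a.size());
  for (size_t i = 0; i < a.size(); ++i) {
    switch (mode) {
      case Mode::PROD: out[i] = a[i] * b[i]; break;
      case Mode::SUM:
        // Weighted sum when coefficients are given, plain sum otherwise.
        out[i] = coeff.empty() ? a[i] + b[i]
                               : coeff[0] * a[i] + coeff[1] * b[i];
        break;
      case Mode::MAX: out[i] = std::max(a[i], b[i]); break;
      case Mode::MIN: out[i] = std::min(a[i], b[i]); break;
    }
  }
  return out;
}

int main() {
  // SUM with coeff {2, 1}: 2*1+4, 2*2+5, 2*3+6 -> prints "6 9 12".
  auto out = Eltwise(Mode::SUM, {1, 2, 3}, {4, 5, 6}, {2.0f, 1.0f});
  for (float v : out) std::printf("%g ", v);
  std::printf("\n");
  return 0;
}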
#include <common.h>
__kernel void eltwise(__read_only image2d_t input0, /* [c%4 * w * c/4, h * b] */
__read_only image2d_t input1,
#ifdef COEFF_SUM
__private const float coeff0,
__private const float coeff1,
#endif
__write_only image2d_t output) {
const int w = get_global_id(0);
const int hb = get_global_id(1);
DATA_TYPE4 in0 = READ_IMAGET(input0, SAMPLER, (int2)(w, hb));
DATA_TYPE4 in1 = READ_IMAGET(input1, SAMPLER, (int2)(w, hb));
DATA_TYPE4 out;
#if ELTWISE_TYPE == 0
out = in0 * in1;
#elif ELTWISE_TYPE == 1
#ifdef COEFF_SUM
out = mad(coeff0, in0, mad(coeff1, in1, 0));
#else
out = in0 + in1;
#endif
#elif ELTWISE_TYPE == 2
out = fmax(in0, in1);
#elif ELTWISE_TYPE == 3
out = fmin(in0, in1);
#endif
WRITE_IMAGET(output, (int2)(w, hb), out);
}
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/kernels/eltwise.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/kernels/opencl/helper.h"
#include "mace/utils/tuner.h"
namespace mace {
namespace kernels {
template <typename T>
void EltwiseFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input0,
const Tensor *input1,
Tensor *output,
StatsFuture *future) {
const index_t batch = input0->dim(0);
const index_t height = input0->dim(1);
const index_t width = input0->dim(2);
const index_t channels = input0->dim(3);
const index_t channel_blocks = RoundUpDiv4(channels);
const index_t width_pixels = channel_blocks * width;
const index_t batch_height_pixels = batch * height;
if (kernel_.get() == nullptr) {
auto runtime = OpenCLRuntime::Global();
std::set<std::string> built_options;
auto dt = DataTypeToEnum<T>::value;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("eltwise");
built_options.emplace("-Deltwise=" + kernel_name);
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
built_options.emplace("-DELTWISE_TYPE=" + ToString(type_));
if (!coeff_.empty()) built_options.emplace("-DCOEFF_SUM");
kernel_ = runtime->BuildKernel("eltwise", kernel_name, built_options);
uint32_t idx = 0;
kernel_.setArg(idx++,
*(static_cast<const cl::Image2D *>(input0->buffer())));
kernel_.setArg(idx++,
*(static_cast<const cl::Image2D *>(input1->buffer())));
if (!coeff_.empty()) {
kernel_.setArg(idx++, coeff_[0]);
kernel_.setArg(idx++, coeff_[1]);
}
kernel_.setArg(idx++, *(static_cast<cl::Image2D *>(output->buffer())));
}
const uint32_t gws[2] = {
static_cast<uint32_t>(width_pixels),
static_cast<uint32_t>(batch_height_pixels)
};
const std::vector<uint32_t> lws = {64, 16, 1};
std::stringstream ss;
ss << "eltwise_opencl_kernel_"
<< output->dim(0) << "_"
<< output->dim(1) << "_"
<< output->dim(2) << "_"
<< output->dim(3);
TuningOrRun2DKernel(kernel_, ss.str(), gws, lws, future);
}
template struct EltwiseFunctor<DeviceType::OPENCL, float>;
template struct EltwiseFunctor<DeviceType::OPENCL, half>;
} // namespace kernels
} // namespace mace
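As a standalone sketch (hypothetical helper outside MACE; RoundUpDiv4 is assumed to be the usual ceiling division by 4), this is how the 2D global work size above is derived: each work-item covers one 4-channel pixel, so the grid spans channel_blocks * width in get_global_id(0) and batch * height in get_global_id(1):

// Hypothetical illustration of the global work size computation; not MACE API.
#include <cstdint>
#include <cstdio>

static int64_t RoundUpDiv4(int64_t v) { return (v + 3) / 4; }  // assumed definition

int main() {
  const int64_t batch = 1, height = 128, width = 128, channels = 32;
  const int64_t width_pixels = RoundUpDiv4(channels) * width;  // gws[0]
  const int64_t batch_height_pixels = batch * height;          // gws[1]
  std::printf("gws = {%lld, %lld}\n",
              static_cast<long long>(width_pixels),
              static_cast<long long>(batch_height_pixels));
  return 0;
}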
@@ -54,8 +54,8 @@ void WinogradTransformFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input_tensor,
     kernel_.setArg(idx++, static_cast<uint32_t>(paddings[1] / 2));
   }
-  const uint32_t gws[2] = {static_cast<size_t>(out_width),
-                           static_cast<size_t>(RoundUpDiv4(input_tensor->dim(3)))};
+  const uint32_t gws[2] = {static_cast<uint32_t>(out_width),
+                           static_cast<uint32_t>(RoundUpDiv4(input_tensor->dim(3)))};
   const std::vector<uint32_t> lws = {128, 8, 1};
   std::stringstream ss;
   ss << "winograd_transform_kernel_"
@@ -126,8 +126,8 @@ void WinogradInverseTransformFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input_tensor,
     kernel_.setArg(idx++, prelu_alpha_);
   }
-  const uint32_t gws[2] = {static_cast<size_t>(input_tensor->dim(2)),
-                           static_cast<size_t>(RoundUpDiv4(input_tensor->dim(1)))};
+  const uint32_t gws[2] = {static_cast<uint32_t>(input_tensor->dim(2)),
+                           static_cast<uint32_t>(RoundUpDiv4(input_tensor->dim(1)))};
   const std::vector<uint32_t> lws = {128, 8, 1};
   std::stringstream ss;
...
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_KERNELS_RESHAPE_H_
#define MACE_KERNELS_RESHAPE_H_
#include "mace/core/future.h"
#include "mace/core/tensor.h"
#include "mace/core/runtime/opencl/cl2_header.h"
namespace mace {
namespace kernels {
template <DeviceType D, typename T>
struct ReshapeFunctor {
ReshapeFunctor() {}
void operator()(const Tensor *input,
const std::vector<index_t> &out_shape,
Tensor *output,
StatsFuture *future) {
output->Resize(out_shape);
// TODO: use copy-on-write to avoid this copy.
output->CopyBytes(input->raw_data(), input->size() * sizeof(T));
}
};
} // namespace kernels
} // namespace mace
#endif // MACE_KERNELS_RESHAPE_H_
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/ops/eltwise.h"
namespace mace {
void Register_Eltwise(OperatorRegistry *op_registry) {
REGISTER_OPERATOR(op_registry, OpKeyBuilder("Eltwise")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
.Build(),
EltwiseOp<DeviceType::CPU, float>);
REGISTER_OPERATOR(op_registry, OpKeyBuilder("Eltwise")
.Device(DeviceType::OPENCL)
.TypeConstraint<float>("T")
.Build(),
EltwiseOp<DeviceType::OPENCL, float>);
REGISTER_OPERATOR(op_registry, OpKeyBuilder("Eltwise")
.Device(DeviceType::OPENCL)
.TypeConstraint<half>("T")
.Build(),
EltwiseOp<DeviceType::OPENCL, half>);
}
} // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_OPS_ELTWISE_H_
#define MACE_OPS_ELTWISE_H_
#include "mace/core/operator.h"
#include "mace/kernels/eltwise.h"
namespace mace {
template <DeviceType D, typename T>
class EltwiseOp : public Operator<D, T> {
public:
EltwiseOp(const OperatorDef &op_def, Workspace *ws)
: Operator<D, T>(op_def, ws),
functor_(static_cast<kernels::EltwiseType>(
OperatorBase::GetSingleArgument<int>(
"type", static_cast<int>(kernels::EltwiseType::SUM))),
OperatorBase::GetRepeatedArgument<float>("coeff")){}
bool Run(StatsFuture *future) override {
const Tensor *input0 = this->Input(0);
const Tensor *input1 = this->Input(1);
Tensor *output = this->Output(OUTPUT);
MACE_CHECK(input0->dim_size() == input1->dim_size()) << "Inputs of Eltwise op must have the same shape";
for (int i = 0; i < input0->dim_size(); ++i) {
MACE_CHECK(input0->dim(i) == input1->dim(i)) << "Inputs of Eltwise op must have the same shape";
}
output->ResizeLike(input0);
functor_(input0, input1, output, future);
return true;
}
private:
kernels::EltwiseFunctor<D, T> functor_;
private:
OP_OUTPUT_TAGS(OUTPUT);
};
} // namespace mace
#endif // MACE_OPS_ELTWISE_H_
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include <string>
#include "mace/core/operator.h"
#include "mace/core/testing/test_benchmark.h"
#include "mace/ops/ops_test_util.h"
#include "mace/kernels/eltwise.h"
namespace mace {
template <DeviceType D, typename T>
static void EltwiseBenchmark(int iters, kernels::EltwiseType type, int n, int h, int w, int c) {
mace::testing::StopTiming();
OpsTestNet net;
// Add input data
net.AddRandomInput<D, T>("Input0", {n, h, w, c});
net.AddRandomInput<D, T>("Input1", {n, h, w, c});
if (D == DeviceType::OPENCL) {
BufferToImage<D, half>(net, "Input0", "InputImg0", kernels::BufferType::IN_OUT_CHANNEL);
BufferToImage<D, half>(net, "Input1", "InputImg1", kernels::BufferType::IN_OUT_CHANNEL);
OpDefBuilder("Eltwise", "EltwiseTest")
.Input("InputImg0")
.Input("InputImg1")
.AddIntArg("type", static_cast<int>(type))
.AddFloatsArg("coeff", {1.2f, 2.1f})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Output("OutputImg")
.Finalize(net.NewOperatorDef());
} else {
OpDefBuilder("Eltwise", "EltwiseTest")
.Input("Input0")
.Input("Input1")
.AddIntArg("type", static_cast<int>(type))
.AddFloatsArg("coeff", {1.2f, 2.1f})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Output("Output")
.Finalize(net.NewOperatorDef());
}
// Warm-up
for (int i = 0; i < 5; ++i) {
net.RunOp(D);
net.Sync();
}
mace::testing::StartTiming();
while (iters--) {
net.RunOp(D);
net.Sync();
}
}
#define BM_ELTWISE_MACRO(ELT_TYPE, N, H, W, C, TYPE, DEVICE) \
static void BM_ELTWISE_##ELT_TYPE##_##N##_##H##_##W##_##C##_##TYPE##_##DEVICE( \
int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * H * W * C; \
mace::testing::ItemsProcessed(tot); \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
EltwiseBenchmark<DEVICE, TYPE>(iters, static_cast<kernels::EltwiseType>(ELT_TYPE), N, H, W, C); \
} \
BENCHMARK(BM_ELTWISE_##ELT_TYPE##_##N##_##H##_##W##_##C##_##TYPE##_##DEVICE)
#define BM_ELTWISE(ELT_TYPE, N, H, W, C) \
BM_ELTWISE_MACRO(ELT_TYPE, N, H, W, C, float, CPU); \
BM_ELTWISE_MACRO(ELT_TYPE, N, H, W, C, float, OPENCL); \
BM_ELTWISE_MACRO(ELT_TYPE, N, H, W, C, half, OPENCL);
BM_ELTWISE(0, 1, 256, 256, 32);
BM_ELTWISE(0, 1, 128, 128, 32);
BM_ELTWISE(1, 1, 128, 128, 32);
BM_ELTWISE(2, 1, 128, 128, 32);
BM_ELTWISE(0, 1, 240, 240, 256);
BM_ELTWISE(1, 1, 240, 240, 256);
BM_ELTWISE(2, 1, 240, 240, 256);
} // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/core/operator.h"
#include "mace/ops/ops_test_util.h"
#include "mace/kernels/eltwise.h"
namespace mace {
class EltwiseOpTest : public OpsTestBase {};
template<DeviceType D>
void Simple(const kernels::EltwiseType type,
const std::vector<index_t> &shape,
const std::vector<float> &input0,
const std::vector<float> &input1,
const std::vector<float> &output,
const std::vector<float> coeff = {}) {
// Construct graph
OpsTestNet net;
// Add input data
net.AddInputFromArray<D, float>("Input1", shape, input0);
net.AddInputFromArray<D, float>("Input2", shape, input1);
if (D == DeviceType::CPU) {
OpDefBuilder("Eltwise", "EltwiseTest")
.Input("Input1")
.Input("Input2")
.AddIntArg("type", static_cast<int>(type))
.AddFloatsArg("coeff", coeff)
.Output("Output")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
} else {
BufferToImage<D, half>(net, "Input1", "InputImg1", kernels::BufferType::IN_OUT_CHANNEL);
BufferToImage<D, half>(net, "Input2", "InputImg2", kernels::BufferType::IN_OUT_CHANNEL);
OpDefBuilder("Eltwise", "EltwiseTest")
.Input("InputImg1")
.Input("InputImg2")
.AddIntArg("type", static_cast<int>(type))
.AddFloatsArg("coeff", coeff)
.Output("OutputImg")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
ImageToBuffer<D, float>(net, "OutputImg", "Output", kernels::BufferType::IN_OUT_CHANNEL);
}
auto expected = CreateTensor<float>(shape, output);
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-3);
}
TEST_F(EltwiseOpTest, CPUSimple) {
Simple<DeviceType::CPU>(kernels::EltwiseType::PROD,
{1, 1, 2, 3},
{1, 2, 3, 4, 5, 6},
{1, 2, 3, 4, 5, 6},
{1, 4, 9, 16, 25, 36});
Simple<DeviceType::CPU>(kernels::EltwiseType::SUM,
{1, 1, 2, 3},
{1, 2, 3, 4, 5, 6},
{1, 2, 3, 4, 5, 6},
{2, 4, 6, 8, 10, 12});
Simple<DeviceType::CPU>(kernels::EltwiseType::SUM,
{1, 1, 2, 3},
{1, 2, 3, 4, 5, 6},
{1, 2, 3, 4, 5, 6},
{3, 6, 9, 12, 15, 18},
{2, 1});
Simple<DeviceType::CPU>(kernels::EltwiseType::MAX,
{1, 1, 2, 3},
{1, 2, 3, 4, 5, 6},
{1, 1, 3, 3, 6, 6},
{1, 2, 3, 4, 6, 6});
Simple<DeviceType::CPU>(kernels::EltwiseType::MIN,
{1, 1, 2, 3},
{1, 2, 3, 4, 5, 6},
{1, 1, 3, 3, 6, 6},
{1, 1, 3, 3, 5, 6});
}
TEST_F(EltwiseOpTest, GPUSimple) {
Simple<DeviceType::OPENCL>(kernels::EltwiseType::PROD,
{1, 1, 2, 3},
{1, 2, 3, 4, 5, 6},
{1, 2, 3, 4, 5, 6},
{1, 4, 9, 16, 25, 36});
Simple<DeviceType::OPENCL>(kernels::EltwiseType::SUM,
{1, 1, 2, 3},
{1, 2, 3, 4, 5, 6},
{1, 2, 3, 4, 5, 6},
{2, 4, 6, 8, 10, 12});
Simple<DeviceType::OPENCL>(kernels::EltwiseType::SUM,
{1, 1, 2, 3},
{1, 2, 3, 4, 5, 6},
{1, 2, 3, 4, 5, 6},
{3, 6, 9, 12, 15, 18},
{2, 1});
Simple<DeviceType::OPENCL>(kernels::EltwiseType::MAX,
{1, 1, 2, 3},
{1, 2, 3, 4, 5, 6},
{1, 1, 3, 3, 6, 6},
{1, 2, 3, 4, 6, 6});
Simple<DeviceType::OPENCL>(kernels::EltwiseType::MIN,
{1, 1, 2, 3},
{1, 2, 3, 4, 5, 6},
{1, 1, 3, 3, 6, 6},
{1, 1, 3, 3, 5, 6});
}
template<DeviceType D, typename T>
void RandomTest(const kernels::EltwiseType type,
const std::vector<index_t> &shape) {
testing::internal::LogToStderr();
srand(time(NULL));
// Construct graph
OpsTestNet net;
// Add input data
net.AddRandomInput<D, float>("Input1", shape);
net.AddRandomInput<D, float>("Input2", shape);
OpDefBuilder("Eltwise", "EltwiseTest")
.Input("Input1")
.Input("Input2")
.AddIntArg("type", static_cast<int>(type))
.AddFloatsArg("coeff", {1.2f, 2.1f})
.Output("Output")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp();
BufferToImage<D, T>(net, "Input1", "InputImg1", kernels::BufferType::IN_OUT_CHANNEL);
BufferToImage<D, T>(net, "Input2", "InputImg2", kernels::BufferType::IN_OUT_CHANNEL);
OpDefBuilder("Eltwise", "EltwiseTest")
.Input("InputImg1")
.Input("InputImg2")
.AddIntArg("type", static_cast<int>(type))
.AddFloatsArg("coeff", {1.2f, 2.1f})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Output("OutputImg")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
ImageToBuffer<D, float>(net, "OutputImg", "OPENCLOutput", kernels::BufferType::IN_OUT_CHANNEL);
if (DataTypeToEnum<T>::value == DT_FLOAT) {
ExpectTensorNear<float>(*net.GetTensor("Output"), *net.GetOutput("OPENCLOutput"), 1e-3);
} else {
ExpectTensorNear<float>(*net.GetTensor("Output"), *net.GetOutput("OPENCLOutput"), 1e-1);
}
}
TEST_F(EltwiseOpTest, OPENCLRandomFloat) {
RandomTest<DeviceType::OPENCL, float>(kernels::EltwiseType::PROD,
{3, 23, 37, 19});
RandomTest<DeviceType::OPENCL, float>(kernels::EltwiseType::SUM,
{13, 32, 32, 64});
RandomTest<DeviceType::OPENCL, float>(kernels::EltwiseType::MAX,
{3, 32, 32, 64});
RandomTest<DeviceType::OPENCL, float>(kernels::EltwiseType::MIN,
{13, 32, 32, 64});
}
TEST_F(EltwiseOpTest, OPENCLRandomHalf) {
RandomTest<DeviceType::OPENCL, half>(kernels::EltwiseType::PROD,
{3, 23, 37, 19});
RandomTest<DeviceType::OPENCL, half>(kernels::EltwiseType::SUM,
{13, 32, 32, 64});
RandomTest<DeviceType::OPENCL, half>(kernels::EltwiseType::MAX,
{3, 32, 32, 64});
RandomTest<DeviceType::OPENCL, half>(kernels::EltwiseType::MIN,
{13, 32, 32, 64});
}
} // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/ops/reshape.h"
namespace mace {
void Register_Reshape(OperatorRegistry *op_registry) {
REGISTER_OPERATOR(op_registry, OpKeyBuilder("Reshape")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
.Build(),
ReshapeOp<DeviceType::CPU, float>);
}
} // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_OPS_RESHAPE_H_
#define MACE_OPS_RESHAPE_H_
#include "mace/core/operator.h"
#include "mace/kernels/reshape.h"
namespace mace {
template <DeviceType D, typename T>
class ReshapeOp : public Operator<D, T> {
public:
ReshapeOp(const OperatorDef &op_def, Workspace *ws)
: Operator<D, T>(op_def, ws),
shape_(OperatorBase::GetRepeatedArgument<int64_t>("shape")){}
bool Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT);
const index_t num_dims = shape_.size();
int unknown_idx = -1;
index_t product = 1;
std::vector<index_t> out_shape;
for (int i = 0; i < num_dims; ++i) {
if (shape_[i] == -1) {
MACE_CHECK(unknown_idx == -1) << "Only one input size may be -1";
unknown_idx = i;
out_shape.push_back(1);
} else if (shape_[i] < 0) {
LOG(ERROR) << "Shape must be non-negative";
} else {
out_shape.push_back(shape_[i]);
product *= shape_[i];
}
}
if (unknown_idx != -1) {
MACE_CHECK(product != 0) << "Cannot infer the missing dimension when the product of the known dimensions is zero";
const index_t missing = input->size() / product;
MACE_CHECK(missing * product == input->size()) << "Input size does not match the reshaped tensor size";
out_shape[unknown_idx] = missing;
}
Tensor *output = this->Output(OUTPUT);
functor_(input, out_shape, output, future);
return true;
}
private:
std::vector<int64_t> shape_;
kernels::ReshapeFunctor<D, T> functor_;
private:
OP_INPUT_TAGS(INPUT);
OP_OUTPUT_TAGS(OUTPUT);
};
} // namespace mace
#endif // MACE_OPS_RESHAPE_H_
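As a hedged illustration, here is a minimal standalone sketch (hypothetical helper, not part of MACE) of the -1 shape inference performed in ReshapeOp::Run above: at most one dimension may be -1, and its value is inferred as the input size divided by the product of the known dimensions:

// Hypothetical illustration of Reshape's -1 dimension inference; not MACE API.
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

std::vector<int64_t> InferShape(int64_t input_size, std::vector<int64_t> shape) {
  int unknown_idx = -1;
  int64_t product = 1;
  for (size_t i = 0; i < shape.size(); ++i) {
    if (shape[i] == -1) {
      assert(unknown_idx == -1 && "Only one dimension may be -1");
      unknown_idx = static_cast<int>(i);
    } else {
      product *= shape[i];
    }
  }
  if (unknown_idx != -1) {
    assert(product != 0 && input_size % product == 0);
    shape[unknown_idx] = input_size / product;  // infer the missing dimension
  }
  return shape;
}

int main() {
  auto s = InferShape(24, {1, 2, -1, 4});  // expected: 1 2 3 4
  for (int64_t d : s) std::printf("%lld ", static_cast<long long>(d));
  std::printf("\n");
  return 0;
}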
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "gmock/gmock.h"
#include "mace/core/operator.h"
#include "mace/ops/ops_test_util.h"
using namespace mace;
class ReshapeTest : public OpsTestBase {};
void TestReshape(const std::vector<index_t> &org_shape,
const std::vector<int> &output_shape,
const std::vector<index_t> &res_shape) {
// Construct graph
OpsTestNet net;
OpDefBuilder("Reshape", "ReshapeTest")
.Input("Input")
.Output("Output")
.AddIntsArg("shape", output_shape)
.Finalize(net.NewOperatorDef());
// Add input data
net.AddRandomInput<DeviceType::CPU, float>("Input", org_shape);
// Run
net.RunOp();
auto input = net.GetTensor("Input");
auto output = net.GetTensor("Output");
EXPECT_THAT(output->shape(), ::testing::ContainerEq(res_shape));
const float *input_ptr = input->data<float>();
const float *output_ptr = output->data<float>();
const int size = output->size();
for (int i = 0; i < size; ++i) {
ASSERT_EQ(input_ptr[i], output_ptr[i]);
}
}
TEST_F(ReshapeTest, Simple) {
TestReshape({1, 2, 3, 4}, {1, 2, -1, 4}, {1, 2, 3, 4});
TestReshape({1, 2, 3, 4}, {1, 2, -1, 2}, {1, 2, 6, 2});
TestReshape({1, 2, 3, 4}, {1, -1, 3, 2}, {1, 4, 3, 2});
TestReshape({1, 2, 3, 4}, {2, 2, 3, 2}, {2, 2, 3, 2});
}
TEST_F(ReshapeTest, Complex) {
TestReshape({1, 2, 3, 4}, {-1}, {24});
TestReshape({1, 2, 3, 4}, {1, -1}, {1, 24});
TestReshape({1, 2, 3, 4}, {-1, 1}, {24, 1});
TestReshape({1, 2, 3, 4}, {1, 3, 8}, {1, 3, 8});
}