提交 4b6fa1c9 编写于 作者: L liuqi

Add reshape op (CPU version only)

上级 07f8ff18
......@@ -80,6 +80,7 @@ extern void Register_SpaceToBatchND(OperatorRegistry *op_registry);
extern void Register_MatMul(OperatorRegistry *op_registry);
extern void Register_WinogradTransform(OperatorRegistry *op_registry);
extern void Register_WinogradInverseTransform(OperatorRegistry *op_registry);
extern void Register_Reshape(OperatorRegistry *op_registry);
OperatorRegistry::OperatorRegistry() {
Register_Activation(this);
......@@ -103,6 +104,7 @@ OperatorRegistry::OperatorRegistry() {
Register_MatMul(this);
Register_WinogradTransform(this);
Register_WinogradInverseTransform(this);
Register_Reshape(this);
}
} // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_KERNELS_RESHAPE_H_
#define MACE_KERNELS_RESHAPE_H_
#include "mace/core/future.h"
#include "mace/core/tensor.h"
#include "mace/core/runtime/opencl/cl2_header.h"
namespace mace {
namespace kernels {
template <DeviceType D, typename T>
struct ReshapeFunctor {
ReshapeFunctor() {}
void operator()(const Tensor *input,
const std::vector<index_t> &out_shape,
Tensor *output,
StatsFuture *future) {
output->Resize(out_shape);
output->CopyBytes(input->raw_data(), input->size() * sizeof(T));
}
};
} // namespace kernels
} // namespace mace
#endif // MACE_KERNELS_RESHAPE_H_
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/ops/reshape.h"
namespace mace {
void Register_Reshape(OperatorRegistry *op_registry) {
REGISTER_OPERATOR(op_registry, OpKeyBuilder("Reshape")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
.Build(),
ReshapeOp<DeviceType::CPU, float>);
}
} // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_OPS_RESHAPE_H_
#define MACE_OPS_RESHAPE_H_
#include "mace/core/operator.h"
#include "mace/kernels/reshape.h"
namespace mace {
template <DeviceType D, typename T>
class ReshapeOp : public Operator<D, T> {
public:
ReshapeOp(const OperatorDef &op_def, Workspace *ws)
: Operator<D, T>(op_def, ws),
shape_(OperatorBase::GetRepeatedArgument<int64_t>("shape")){}
bool Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT);
const index_t num_dims = shape_.size();
int unknown_idx = -1;
index_t product = 1;
std::vector<index_t> out_shape;
for (int i = 0; i < num_dims; ++i) {
if (shape_[i] == -1) {
MACE_CHECK(unknown_idx == -1) << "Only one input size may be -1";
unknown_idx = i;
out_shape.push_back(1);
} else if (shape_[i] < 0) {
VLOG(ERROR) << "Shape must be non-negative";
} else {
out_shape.push_back(shape_[i]);
product *= shape_[i];
}
}
if (unknown_idx != -1) {
MACE_CHECK(product != 0) << "Cannot infer shape if there is zero shape size.";
const index_t missing = input->size() / product;
MACE_CHECK(missing * product == input->size()) << "Input size not match reshaped tensor size";
out_shape[unknown_idx] = missing;
}
Tensor *output = this->Output(OUTPUT);
functor_(input, out_shape, output, future);
return true;
}
private:
std::vector<int64_t> shape_;
kernels::ReshapeFunctor<D, T> functor_;
private:
OP_INPUT_TAGS(INPUT);
OP_OUTPUT_TAGS(OUTPUT);
};
} // namespace mace
#endif // MACE_OPS_RESHAPE_H_
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "gmock/gmock.h"
#include "mace/core/operator.h"
#include "mace/ops/ops_test_util.h"
using namespace mace;
class ReshapeTest : public OpsTestBase {};
void TestReshape(const std::vector<index_t> &org_shape,
const std::vector<int> &output_shape,
const std::vector<index_t> &res_shape) {
// Construct graph
OpsTestNet net;
OpDefBuilder("Reshape", "ReshapeTest")
.Input("Input")
.Output("Output")
.AddIntsArg("shape", output_shape)
.Finalize(net.NewOperatorDef());
// Add input data
net.AddRandomInput<DeviceType::CPU, float>("Input", org_shape);
// Run
net.RunOp();
auto input = net.GetTensor("Input");
auto output = net.GetTensor("Output");
EXPECT_THAT(output->shape(), ::testing::ContainerEq(res_shape));
const float *input_ptr = input->data<float>();
const float *output_ptr = output->data<float>();
const int size = output->size();
for (int i = 0; i < size; ++i) {
ASSERT_EQ(input_ptr[i], output_ptr[i]);
}
}
TEST_F(ReshapeTest, Simple) {
TestReshape({1, 2, 3, 4}, {1, 2, -1, 4}, {1, 2, 3, 4});
TestReshape({1, 2, 3, 4}, {1, 2, -1, 2}, {1, 2, 6, 2});
TestReshape({1, 2, 3, 4}, {1, -1, 3, 2}, {1, 4, 3, 2});
TestReshape({1, 2, 3, 4}, {2, 2, 3, 2}, {2, 2, 3, 2});
}
TEST_F(ReshapeTest, Complex) {
TestReshape({1, 2, 3, 4}, {-1}, {24});
TestReshape({1, 2, 3, 4}, {1, -1}, {1, 24});
TestReshape({1, 2, 3, 4}, {-1, 1}, {24, 1});
TestReshape({1, 2, 3, 4}, {1, 3, 8}, {1, 3, 8});
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册