提交 0e58b7b6 编写于 作者: L liuqi

Add reorganize op corresponding to caffe reshape layer.

上级 9985cbf4
......@@ -85,6 +85,7 @@ extern void Register_MatMul(OperatorRegistry *op_registry);
extern void Register_Pooling(OperatorRegistry *op_registry);
extern void Register_Proposal(OperatorRegistry *op_registry);
extern void Register_PSROIAlign(OperatorRegistry *op_registry);
extern void Register_ReOrganize(OperatorRegistry *op_registry);
extern void Register_Reshape(OperatorRegistry *op_registry);
extern void Register_ResizeBilinear(OperatorRegistry *op_registry);
extern void Register_Slice(OperatorRegistry *op_registry);
......@@ -118,6 +119,7 @@ OperatorRegistry::OperatorRegistry() {
ops::Register_Pooling(this);
ops::Register_Proposal(this);
ops::Register_PSROIAlign(this);
ops::Register_ReOrganize(this);
ops::Register_Reshape(this);
ops::Register_ResizeBilinear(this);
ops::Register_Slice(this);
......
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_KERNELS_REORGANIZE_H_
#define MACE_KERNELS_REORGANIZE_H_
#include <vector>
#include "mace/core/future.h"
#include "mace/core/runtime/opencl/cl2_header.h"
#include "mace/core/tensor.h"
namespace mace {
namespace kernels {
template <DeviceType D, typename T>
struct ReOrganizeFunctor {
void operator()(const Tensor *input,
const std::vector<index_t> &out_shape,
Tensor *output,
StatsFuture *future) {
const bool w2c = out_shape[3] > input->dim(3);
const index_t height = input->dim(1);
const index_t input_width = input->dim(2);
const index_t input_chan = input->dim(3);
const index_t output_width = output->dim(2);
const index_t output_chan = output->dim(3);
const T *input_ptr = input->data<T>();
T *output_ptr = output->mutable_data<T>();
if (w2c) {
MACE_CHECK((out_shape[3] % input->dim(3)) == 0);
const index_t multiplier = out_shape[3] / input->dim(3);
#pragma omp parallel for collapse(4)
for (index_t n = 0; n < out_shape[0]; ++n) {
for (index_t h = 0; h < out_shape[1]; ++h) {
for (index_t w = 0; w < out_shape[2]; ++w) {
for (index_t c = 0; c < out_shape[3]; ++c) {
const index_t out_offset =
((n * height + h) * output_width + w)
* output_chan + c;
const index_t in_w_idx = w + (c % multiplier) * output_width;
const index_t in_chan_idx = c / multiplier;
const index_t in_offset =
((n * height + h) * input_width + in_w_idx)
* input_chan + in_chan_idx;
output_ptr[out_offset] = input_ptr[in_offset];
}
}
}
}
} else {
MACE_CHECK((input->dim(3) % out_shape[3]) == 0);
const index_t multiplier = input->dim(3) / out_shape[3];
#pragma omp parallel for collapse(4)
for (index_t n = 0; n < out_shape[0]; ++n) {
for (index_t h = 0; h < out_shape[1]; ++h) {
for (index_t w = 0; w < out_shape[2]; ++w) {
for (index_t c = 0; c < out_shape[3]; ++c) {
const index_t out_offset =
((n * height + h) * output_width + w)
* output_chan + c;
const index_t in_w_idx = w % input_width;
const index_t in_chan_idx = w / input_width + c * multiplier;
const index_t in_offset =
((n * height + h) * input_width + in_w_idx)
* input_chan + in_chan_idx;
output_ptr[out_offset] = input_ptr[in_offset];
}
}
}
}
}
}
};
} // namespace kernels
} // namespace mace
#endif // MACE_KERNELS_REORGANIZE_H_
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/ops/reorganize.h"
namespace mace {
namespace ops {
void Register_ReOrganize(OperatorRegistry *op_registry) {
REGISTER_OPERATOR(op_registry, OpKeyBuilder("ReOrganize")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
.Build(),
ReOrganizeOp<DeviceType::CPU, float>);
}
} // namespace ops
} // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_OPS_REORGANIZE_H_
#define MACE_OPS_REORGANIZE_H_
#include <vector>
#include "mace/core/operator.h"
#include "mace/kernels/reorganize.h"
namespace mace {
namespace ops {
template <DeviceType D, typename T>
class ReOrganizeOp : public Operator<D, T> {
public:
ReOrganizeOp(const OperatorDef &op_def, Workspace *ws)
: Operator<D, T>(op_def, ws),
shape_(OperatorBase::GetRepeatedArgument<int64_t>("shape")) {}
bool Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT);
const index_t num_dims = shape_.size();
int unknown_idx = -1;
index_t product = 1;
std::vector<index_t> out_shape;
for (int i = 0; i < num_dims; ++i) {
if (shape_[i] == -1) {
MACE_CHECK(unknown_idx == -1) << "Only one input size may be -1";
unknown_idx = i;
out_shape.push_back(1);
} else {
MACE_CHECK(shape_[i] >= 0) << "Shape must be non-negative: "
<< shape_[i];
out_shape.push_back(shape_[i]);
product *= shape_[i];
}
}
if (unknown_idx != -1) {
MACE_CHECK(product != 0)
<< "Cannot infer shape if there is zero shape size.";
const index_t missing = input->size() / product;
MACE_CHECK(missing * product == input->size())
<< "Input size not match reshaped tensor size";
out_shape[unknown_idx] = missing;
}
Tensor *output = this->Output(OUTPUT);
output->Resize(out_shape);
functor_(input, out_shape, output, future);
return true;
}
private:
std::vector<int64_t> shape_;
kernels::ReOrganizeFunctor<D, T> functor_;
protected:
OP_INPUT_TAGS(INPUT);
OP_OUTPUT_TAGS(OUTPUT);
};
} // namespace ops
} // namespace mace
#endif // MACE_OPS_REORGANIZE_H_
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "gmock/gmock.h"
#include "mace/core/operator.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
class ReOrganizeTest : public OpsTestBase {};
void TestReOrganize(const std::vector<index_t> &input_shape,
const std::vector<float> &input_data,
const std::vector<index_t> &output_shape,
const std::vector<float> &output_data) {
const std::vector<int> out_shape(output_shape.begin(), output_shape.end());
// Construct graph
OpsTestNet net;
OpDefBuilder("ReOrganize", "ReOrganizeTest")
.Input("Input")
.Output("Output")
.AddIntsArg("shape", out_shape)
.Finalize(net.NewOperatorDef());
// Add input data
net.AddInputFromArray<DeviceType::CPU, float>("Input",
input_shape, input_data);
// Run
net.RunOp();
auto output = net.GetTensor("Output");
EXPECT_THAT(output->shape(), ::testing::ContainerEq(output_shape));
const float *output_ptr = output->data<float>();
int size = output->size();
for (int i = 0; i < size; ++i) {
ASSERT_EQ(output_data[i], output_ptr[i]) << "With Index " << i;
}
// Reverse reorganzie
const std::vector<int> in_shape(input_shape.begin(), input_shape.end());
OpDefBuilder("ReOrganize", "ReOrganizeTest")
.Input("Input")
.Output("Output")
.AddIntsArg("shape", in_shape)
.Finalize(net.NewOperatorDef());
// Add input data
net.AddInputFromArray<DeviceType::CPU, float>("Input",
output_shape, output_data);
// Run
net.RunOp();
output = net.GetTensor("Output");
EXPECT_THAT(output->shape(), ::testing::ContainerEq(input_shape));
output_ptr = output->data<float>();
size = output->size();
for (int i = 0; i < size; ++i) {
ASSERT_EQ(input_data[i], output_ptr[i]) << "With Index " << i;
}
}
TEST_F(ReOrganizeTest, Simple) {
TestReOrganize({1, 1, 4, 6},
{0, 4, 8, 12, 16, 20,
1, 5, 9, 13, 17, 21,
2, 6, 10, 14, 18, 22,
3, 7, 11, 15, 19, 23},
{1, 1, 8, 3},
{0, 8, 16, 1, 9, 17, 2, 10, 18, 3, 11, 19,
4, 12, 20, 5, 13, 21, 6, 14, 22, 7, 15, 23});
TestReOrganize({1, 1, 5, 6},
{0, 5, 10, 15, 20, 25,
1, 6, 11, 16, 21, 26,
2, 7, 12, 17, 22, 27,
3, 8, 13, 18, 23, 28,
4, 9, 14, 19, 24, 29},
{1, 1, 10, 3},
{0, 10, 20, 1, 11, 21, 2, 12, 22, 3, 13, 23,
4, 14, 24, 5, 15, 25, 6, 16, 26, 7, 17, 27,
8, 18, 28, 9, 19, 29});
}
TEST_F(ReOrganizeTest, Complex) {
TestReOrganize({1, 2, 2, 6},
{0, 4, 8, 12, 16, 20,
1, 5, 9, 13, 17, 21,
2, 6, 10, 14, 18, 22,
3, 7, 11, 15, 19, 23},
{1, 2, 6, 2},
{0, 12, 1, 13, 4, 16, 5, 17, 8, 20, 9, 21,
2, 14, 3, 15, 6, 18, 7, 19, 10, 22, 11, 23});
}
} // namespace test
} // namespace ops
} // namespace mace
......@@ -785,10 +785,10 @@ class CaffeConverter(object):
self.resolved_ops.add(op.name)
def convert_reshape(self, op):
op_def = self.CommonConvert(op, op.type)
op_def = self.CommonConvert(op, 'ReOrganize')
input_shape = op.parents[0].output_shape_map[op.layer.bottom[0]]
output_shape = input_shape
shape_param = np.asarray(op.layer.reshape_param.shape.dim)[[0, 2, 3, 1]]
shape_param = np.asarray(op.layer.reshape_param.shape.dim)[[0, 3, 2, 1]]
print shape_param
for i in range(len(shape_param)):
if shape_param[i] != 0:
......
......@@ -97,14 +97,17 @@ def validate_caffe_model(input_names, input_shapes, output_names, output_shapes)
input_value = load_data(FLAGS.input_file + "_" + input_names[i])
input_value = input_value.reshape(input_shapes[i]).transpose((0, 3, 1, 2))
input_blob_name = input_names[i]
if input_names[i] in net.top_names:
input_blob_name = net.top_names[input_names[i]][0]
try:
if input_names[i] in net.top_names:
input_blob_name = net.top_names[input_names[i]][0]
except ValueError:
pass
net.blobs[input_blob_name].data[0] = input_value
net.forward()
for i in range(len(output_names)):
value = net.blobs[net.top_names[output_names[i]][0]].data[0]
value = net.blobs[net.top_names[output_names[i]][0]].data
out_shape = output_shapes[i]
out_shape[1], out_shape[2], out_shape[3] = out_shape[3], out_shape[1], out_shape[2]
value = value.reshape(out_shape).transpose((0, 2, 3, 1))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册