Commit 96452de2 authored by Y yejianwu

Merge branch 'master' of v9.git.n.xiaomi.com:deep-computing/mace into compatible_with_opencl_1.1_and_1.2
@@ -63,7 +63,6 @@ std::unique_ptr<OperatorBase> OperatorRegistry::CreateOperator(
}

namespace ops {
// Keep in lexicographical order
extern void Register_Activation(OperatorRegistry *op_registry);
extern void Register_AddN(OperatorRegistry *op_registry);
@@ -74,6 +73,7 @@ extern void Register_BufferToImage(OperatorRegistry *op_registry);
extern void Register_ChannelShuffle(OperatorRegistry *op_registry);
extern void Register_Concat(OperatorRegistry *op_registry);
extern void Register_Conv2D(OperatorRegistry *op_registry);
extern void Register_DepthToSpace(OperatorRegistry *op_registry);
extern void Register_DepthwiseConv2d(OperatorRegistry *op_registry);
extern void Register_Eltwise(OperatorRegistry *op_registry);
extern void Register_FoldedBatchNorm(OperatorRegistry *op_registry);
@@ -85,11 +85,13 @@ extern void Register_MatMul(OperatorRegistry *op_registry);
extern void Register_Pooling(OperatorRegistry *op_registry);
extern void Register_Proposal(OperatorRegistry *op_registry);
extern void Register_PSROIAlign(OperatorRegistry *op_registry);
extern void Register_ReOrganize(OperatorRegistry *op_registry);
extern void Register_Reshape(OperatorRegistry *op_registry);
extern void Register_ResizeBilinear(OperatorRegistry *op_registry);
extern void Register_Slice(OperatorRegistry *op_registry);
extern void Register_Softmax(OperatorRegistry *op_registry);
extern void Register_SpaceToBatchND(OperatorRegistry *op_registry);
extern void Register_SpaceToDepth(OperatorRegistry *op_registry);
extern void Register_WinogradInverseTransform(OperatorRegistry *op_registry);
extern void Register_WinogradTransform(OperatorRegistry *op_registry);
@@ -107,6 +109,7 @@ OperatorRegistry::OperatorRegistry() {
  ops::Register_ChannelShuffle(this);
  ops::Register_Concat(this);
  ops::Register_Conv2D(this);
  ops::Register_DepthToSpace(this);
  ops::Register_DepthwiseConv2d(this);
  ops::Register_Eltwise(this);
  ops::Register_FoldedBatchNorm(this);
@@ -118,11 +121,13 @@ OperatorRegistry::OperatorRegistry() {
  ops::Register_Pooling(this);
  ops::Register_Proposal(this);
  ops::Register_PSROIAlign(this);
  ops::Register_ReOrganize(this);
  ops::Register_Reshape(this);
  ops::Register_ResizeBilinear(this);
  ops::Register_Slice(this);
  ops::Register_Softmax(this);
  ops::Register_SpaceToBatchND(this);
  ops::Register_SpaceToDepth(this);
  ops::Register_WinogradInverseTransform(this);
  ops::Register_WinogradTransform(this);
}
...
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_KERNELS_DEPTH_TO_SPACE_H_
#define MACE_KERNELS_DEPTH_TO_SPACE_H_
#include <vector>
#include "mace/core/future.h"
#include "mace/core/runtime/opencl/cl2_header.h"
#include "mace/core/tensor.h"
#include "mace/public/mace.h"
namespace mace {
namespace kernels {
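// CPU functor shared by the DepthToSpace and SpaceToDepth ops on NHWC tensors.
// With block size b, d2s_ == true maps (N, H, W, C) -> (N, H*b, W*b, C/(b*b));
// d2s_ == false performs the inverse, (N, H, W, C) -> (N, H/b, W/b, C*b*b).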
template <DeviceType D, typename T>
struct DepthToSpaceOpFunctor {
explicit DepthToSpaceOpFunctor(const int block_size, bool d2s)
: block_size_(block_size), d2s_(d2s) {}
void operator()(const Tensor *input, Tensor *output, StatsFuture *future) {
const int batch_size = input->dim(0);
const int input_height = input->dim(1);
const int input_width = input->dim(2);
const int input_depth = input->dim(3);
index_t output_depth, output_width, output_height;
if (d2s_) {
output_depth = input_depth / (block_size_ * block_size_);
output_width = input_width * block_size_;
output_height = input_height * block_size_;
} else {
output_depth = input_depth * block_size_ * block_size_;
output_width = input_width / block_size_;
output_height = input_height / block_size_;
}
std::vector<index_t> output_shape = {batch_size, output_height,
output_width, output_depth};
output->Resize(output_shape);
Tensor::MappingGuard input_guard(input);
Tensor::MappingGuard output_guard(output);
const T *input_ptr = input->data<T>();
T *output_ptr = output->mutable_data<T>();
if (d2s_) {
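// DepthToSpace: each output pixel (h, w) reads from input pixel
// (h / block_size_, w / block_size_); the position inside the block,
// (h % block_size_, w % block_size_), selects a slice of the input depth.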
#pragma omp parallel for
for (int b = 0; b < batch_size; ++b) {
for (int h = 0; h < output_height; ++h) {
const int in_h = h / block_size_;
const int offset_h = (h % block_size_);
for (int w = 0; w < output_width; ++w) {
const int in_w = w / block_size_;
const int offset_w = w % block_size_;
const int offset_d =
(offset_h * block_size_ + offset_w) * output_depth;
for (int d = 0; d < output_depth; ++d) {
const int in_d = d + offset_d;
const int o_index =
((b * output_height + h) * output_width + w) * output_depth +
d;
const int i_index =
((b * input_height + in_h) * input_width + in_w) *
input_depth +
in_d;
output_ptr[o_index] = input_ptr[i_index];
}
}
}
}
} else {
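// SpaceToDepth: iterate over input pixels and scatter each one into the depth
// dimension of output pixel (h / block_size_, w / block_size_), at a depth
// offset determined by the position inside the block.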
#pragma omp parallel for
for (int b = 0; b < batch_size; ++b) {
for (int h = 0; h < input_height; ++h) {
const int out_h = h / block_size_;
const int offset_h = (h % block_size_);
for (int w = 0; w < input_width; ++w) {
const int out_w = w / block_size_;
const int offset_w = (w % block_size_);
const int offset_d =
(offset_h * block_size_ + offset_w) * input_depth;
for (int d = 0; d < input_depth; ++d) {
const int out_d = d + offset_d;
const int o_index =
((b * output_height + out_h) * output_width + out_w) *
output_depth +
out_d;
const int i_index =
((b * input_height + h) * input_width + w) * input_depth + d;
output_ptr[o_index] = input_ptr[i_index];
}
}
}
}
}
}
const int block_size_;
bool d2s_;
};
template <typename T>
struct DepthToSpaceOpFunctor<DeviceType::OPENCL, T> {
DepthToSpaceOpFunctor(const int block_size, bool d2s)
: block_size_(block_size), d2s_(d2s) {}
void operator()(const Tensor *input, Tensor *output, StatsFuture *future);
cl::Kernel kernel_;
const int block_size_;
bool d2s_;
std::vector<index_t> input_shape_;
};
} // namespace kernels
} // namespace mace
#endif // MACE_KERNELS_DEPTH_TO_SPACE_H_
#include <common.h>
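// Both kernels use MACE's IN_OUT_CHANNEL image layout: each image pixel holds
// 4 consecutive channels of one spatial position, with x = depth_block * width + w
// and y = batch * height + h. The output_depth/input_depth arguments are passed
// by the host code in units of these 4-channel depth blocks.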
__kernel void depth_to_space(__read_only image2d_t input,
__private const int block_size,
__private const int output_depth,
__write_only image2d_t output) {
const int out_d = get_global_id(0);
const int out_w = get_global_id(1);
const int out_h = get_global_id(2);
const int output_width = get_global_size(1);
const int out_pos = mad24(out_d, output_width, out_w);
const int input_width = output_width / block_size;
const int in_h = out_h / block_size;
const int offset_h = out_h % block_size;
const int in_w = out_w / block_size;
const int offset_w = out_w % block_size;
const int offset_d = (offset_h * block_size + offset_w) * output_depth;
const int in_d = out_d + offset_d;
const int in_pos = mad24(in_d, input_width, in_w);
DATA_TYPE4 in_data = READ_IMAGET(input, SAMPLER, (int2)(in_pos, in_h));
WRITE_IMAGET(output, (int2)(out_pos, out_h), in_data);
}
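// space_to_depth is the inverse mapping: one work item per input pixel and
// depth block; the pixel's position inside the block_size x block_size block
// decides which depth block of the corresponding output pixel receives it.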
__kernel void space_to_depth(__read_only image2d_t input,
__private const int block_size,
__private const int input_depth,
__write_only image2d_t output) {
const int d = get_global_id(0);
const int w = get_global_id(1);
const int h = get_global_id(2);
const int input_width = get_global_size(1);
const int in_pos = mad24(d, input_width, w);
const int output_width = input_width / block_size;
const int out_h = h / block_size;
const int offset_h = h % block_size;
const int out_w = w / block_size;
const int offset_w = w % block_size;
const int offset_d = (offset_h * block_size + offset_w) * input_depth;
const int out_d = d + offset_d;
const int out_pos = mad24(out_d, output_width, out_w);
DATA_TYPE4 in_data = READ_IMAGET(input, SAMPLER, (int2)(in_pos, h));
WRITE_IMAGET(output, (int2)(out_pos, out_h), in_data);
}
//
// Copyright (c) 2018 XiaoMi All rights reserved.
//
#include "mace/kernels/depth_to_space.h"
#include "mace/core/runtime/opencl/cl2_header.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/kernels/opencl/helper.h"
#include "mace/utils/tuner.h"
#include "mace/utils/utils.h"
namespace mace {
namespace kernels {
template <typename T>
void DepthToSpaceOpFunctor<DeviceType::OPENCL, T>::operator()(
const Tensor *input, Tensor *output, StatsFuture *future) {
const index_t batch = input->dim(0);
const index_t input_height = input->dim(1);
const index_t input_width = input->dim(2);
const index_t input_depth = input->dim(3);
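// The OpenCL kernels address the tensor in 4-channel blocks (one image pixel
// per block), so the first global-work-size dimension and the depth argument
// passed to the kernel are counted in blocks of 4 channels.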
int depth_blocks = 1;
const char *kernel_name = nullptr;
index_t output_height, output_width, output_depth;
if (d2s_) {
output_height = input_height * block_size_;
output_width = input_width * block_size_;
output_depth = input_depth / (block_size_ * block_size_);
depth_blocks = RoundUpDiv4(output_depth);
kernel_name = "depth_to_space";
} else {
output_height = input_height / block_size_;
output_width = input_width / block_size_;
output_depth = input_depth * block_size_ * block_size_;
depth_blocks = RoundUpDiv4(input_depth);
kernel_name = "space_to_depth";
}
std::vector<index_t> output_shape = {batch, output_height, output_width,
output_depth};
std::vector<size_t> image_shape;
CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL, &image_shape);
output->ResizeImage(output_shape, image_shape);
if (kernel_.get() == nullptr) {
auto runtime = OpenCLRuntime::Global();
std::set<std::string> built_options;
std::string obfuscated_kernel_name = MACE_OBFUSCATE_SYMBOL(kernel_name);
std::stringstream kernel_name_ss;
kernel_name_ss << "-D" << kernel_name << "=" << obfuscated_kernel_name;
built_options.emplace(kernel_name_ss.str());
auto dt = DataTypeToEnum<T>::value;
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
kernel_ =
runtime->BuildKernel("depth_to_space", kernel_name, built_options);
}
if (!IsVecEqual(input_shape_, input->shape())) {
uint32_t idx = 0;
kernel_.setArg(idx++, *(input->opencl_image()));
kernel_.setArg(idx++, block_size_);
kernel_.setArg(idx++, depth_blocks);
kernel_.setArg(idx++, *(output->opencl_image()));
input_shape_ = input->shape();
}
if (d2s_) {
const uint32_t gws[3] = {static_cast<uint32_t>(depth_blocks),
static_cast<uint32_t>(output_width),
static_cast<uint32_t>(output_height * batch)};
const std::vector<uint32_t> lws = {8, 16, 8, 1};
std::stringstream ss;
ss << "depth_to_space_opencl_kernel_" << output->dim(0) << "_"
<< output->dim(1) << "_" << output->dim(2) << "_" << output->dim(3);
TuningOrRun3DKernel(kernel_, ss.str(), gws, lws, future);
} else {
const uint32_t gws[3] = {static_cast<uint32_t>(depth_blocks),
static_cast<uint32_t>(input_width),
static_cast<uint32_t>(input_height * batch)};
const std::vector<uint32_t> lws = {8, 16, 8, 1};
std::stringstream ss;
ss << "space_to_depth_opencl_kernel_" << input->dim(0) << "_"
<< input->dim(1) << "_" << input->dim(2) << "_" << input->dim(3);
TuningOrRun3DKernel(kernel_, ss.str(), gws, lws, future);
}
}
template struct DepthToSpaceOpFunctor<DeviceType::OPENCL, float>;
template struct DepthToSpaceOpFunctor<DeviceType::OPENCL, half>;
} // namespace kernels
} // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_KERNELS_REORGANIZE_H_
#define MACE_KERNELS_REORGANIZE_H_
#include <vector>
#include "mace/core/future.h"
#include "mace/core/runtime/opencl/cl2_header.h"
#include "mace/core/tensor.h"
namespace mace {
namespace kernels {
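// CPU functor that moves data between the width and channel dimensions of an
// NHWC tensor: when the requested output has more channels than the input
// (w2c), width positions are folded into channels; otherwise channels are
// unfolded back into width. The Caffe converter in this change emits a
// ReOrganize op for Caffe Reshape layers (see convert_reshape below).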
template <DeviceType D, typename T>
struct ReOrganizeFunctor {
void operator()(const Tensor *input,
const std::vector<index_t> &out_shape,
Tensor *output,
StatsFuture *future) {
const bool w2c = out_shape[3] > input->dim(3);
const index_t height = input->dim(1);
const index_t input_width = input->dim(2);
const index_t input_chan = input->dim(3);
const index_t output_width = output->dim(2);
const index_t output_chan = output->dim(3);
const T *input_ptr = input->data<T>();
T *output_ptr = output->mutable_data<T>();
if (w2c) {
MACE_CHECK((out_shape[3] % input->dim(3)) == 0);
const index_t multiplier = out_shape[3] / input->dim(3);
#pragma omp parallel for collapse(4)
for (index_t n = 0; n < out_shape[0]; ++n) {
for (index_t h = 0; h < out_shape[1]; ++h) {
for (index_t w = 0; w < out_shape[2]; ++w) {
for (index_t c = 0; c < out_shape[3]; ++c) {
const index_t out_offset =
((n * height + h) * output_width + w)
* output_chan + c;
const index_t in_w_idx = w + (c % multiplier) * output_width;
const index_t in_chan_idx = c / multiplier;
const index_t in_offset =
((n * height + h) * input_width + in_w_idx)
* input_chan + in_chan_idx;
output_ptr[out_offset] = input_ptr[in_offset];
}
}
}
}
} else {
MACE_CHECK((input->dim(3) % out_shape[3]) == 0);
const index_t multiplier = input->dim(3) / out_shape[3];
#pragma omp parallel for collapse(4)
for (index_t n = 0; n < out_shape[0]; ++n) {
for (index_t h = 0; h < out_shape[1]; ++h) {
for (index_t w = 0; w < out_shape[2]; ++w) {
for (index_t c = 0; c < out_shape[3]; ++c) {
const index_t out_offset =
((n * height + h) * output_width + w)
* output_chan + c;
const index_t in_w_idx = w % input_width;
const index_t in_chan_idx = w / input_width + c * multiplier;
const index_t in_offset =
((n * height + h) * input_width + in_w_idx)
* input_chan + in_chan_idx;
output_ptr[out_offset] = input_ptr[in_offset];
}
}
}
}
}
}
};
} // namespace kernels
} // namespace mace
#endif // MACE_KERNELS_REORGANIZE_H_
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/ops/depth_to_space.h"
namespace mace {
namespace ops {
void Register_DepthToSpace(OperatorRegistry *op_registry) {
REGISTER_OPERATOR(op_registry, OpKeyBuilder("DepthToSpace")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
.Build(),
DepthToSpaceOp<DeviceType::CPU, float>);
REGISTER_OPERATOR(op_registry, OpKeyBuilder("DepthToSpace")
.Device(DeviceType::OPENCL)
.TypeConstraint<float>("T")
.Build(),
DepthToSpaceOp<DeviceType::OPENCL, float>);
REGISTER_OPERATOR(op_registry, OpKeyBuilder("DepthToSpace")
.Device(DeviceType::OPENCL)
.TypeConstraint<half>("T")
.Build(),
DepthToSpaceOp<DeviceType::OPENCL, half>);
}
} // namespace ops
} // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_OPS_DEPTH_TO_SPACE_H_
#define MACE_OPS_DEPTH_TO_SPACE_H_
#include <memory>
#include <vector>
#include "mace/core/operator.h"
#include "mace/kernels/depth_to_space.h"
namespace mace {
namespace ops {
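// DepthToSpaceOp constructs the shared kernels::DepthToSpaceOpFunctor with
// d2s = true; SpaceToDepthOp (mace/ops/space_to_depth.h) reuses the same
// functor with d2s = false.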
template <DeviceType D, typename T>
class DepthToSpaceOp : public Operator<D, T> {
public:
DepthToSpaceOp(const OperatorDef &op_def, Workspace *ws)
: Operator<D, T>(op_def, ws),
functor_(OperatorBase::GetSingleArgument<int>("block_size", 1), true) {}
bool Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT);
Tensor *output = this->Output(OUTPUT);
MACE_CHECK(input->dim_size() == 4, "input dim should be 4");
const int block_size =
OperatorBase::GetSingleArgument<int>("block_size", 1);
int input_depth = input->dim(3);
MACE_CHECK(input_depth % (block_size * block_size) == 0,
"input depth should be divisible by block_size * block_size",
input->dim(3));
MACE_CHECK((input_depth % 4) == 0,
"input channel should be divisible by 4");
functor_(input, output, future);
return true;
}
protected:
OP_INPUT_TAGS(INPUT);
OP_OUTPUT_TAGS(OUTPUT);
private:
kernels::DepthToSpaceOpFunctor<D, T> functor_;
};
} // namespace ops
} // namespace mace
#endif // MACE_OPS_DEPTH_TO_SPACE_H_
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/core/operator.h"
#include "mace/core/testing/test_benchmark.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
template <DeviceType D, typename T>
static void DepthToSpace(
int iters, int batch, int channels, int height, int width, int block_size) {
mace::testing::StopTiming();
OpsTestNet net;
// Add input data
net.AddRandomInput<D, float>("Input", {batch, height, width, channels});
if (D == DeviceType::OPENCL) {
BufferToImage<D, float>(&net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL);
OpDefBuilder("DepthToSpace", "DepthToSpaceBM")
.Input("InputImage")
.Output("Output")
.AddIntArg("block_size", block_size)
.Finalize(net.NewOperatorDef());
} else {
OpDefBuilder("DepthToSpace", "DepthToSpaceBM")
.Input("Input")
.Output("Output")
.Finalize(net.NewOperatorDef());
}
// Warm-up
for (int i = 0; i < 5; ++i) {
net.RunOp(D);
}
net.Sync();
mace::testing::StartTiming();
while (iters--) {
net.RunOp(D);
}
net.Sync();
}
#define BM_DEPTH_TO_SPACE_MACRO(N, C, H, W, G, TYPE, DEVICE) \
static void \
BM_DEPTH_TO_SPACE_##N##_##C##_##H##_##W##_##G##_##TYPE##_##DEVICE( \
int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
mace::testing::MaccProcessed(tot); \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
DepthToSpace<DEVICE, TYPE>(iters, N, C, H, W, G); \
} \
BENCHMARK(BM_DEPTH_TO_SPACE_##N##_##C##_##H##_##W##_##G##_##TYPE##_##DEVICE)
#define BM_DEPTH_TO_SPACE(N, C, H, W, G) \
BM_DEPTH_TO_SPACE_MACRO(N, C, H, W, G, float, CPU); \
BM_DEPTH_TO_SPACE_MACRO(N, C, H, W, G, float, OPENCL); \
BM_DEPTH_TO_SPACE_MACRO(N, C, H, W, G, half, OPENCL);
BM_DEPTH_TO_SPACE(1, 64, 64, 64, 4);
BM_DEPTH_TO_SPACE(1, 64, 128, 128, 4);
BM_DEPTH_TO_SPACE(1, 64, 256, 256, 4);
} // namespace test
} // namespace ops
} // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/core/operator.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
template <DeviceType D>
void RunDepthToSpace(const bool d2s,
const std::vector<index_t> &input_shape,
const std::vector<float> &input_data,
const int block_size,
const std::vector<index_t> &expected_shape,
const std::vector<float> &expected_data) {
OpsTestNet net;
net.AddInputFromArray<D, float>("Input", input_shape, input_data);
const char *ops_name = (d2s) ? "DepthToSpace" : "SpaceToDepth";
const char *ops_test_name = (d2s) ? "DepthToSpaceTest" : "SpaceToDepthTest";
// Construct graph
if (D == DeviceType::CPU) {
OpDefBuilder(ops_name, ops_test_name)
.Input("Input")
.Output("Output")
.AddIntArg("block_size", block_size)
.Finalize(net.NewOperatorDef());
} else {
BufferToImage<D, float>(&net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL);
OpDefBuilder(ops_name, ops_test_name)
.Input("InputImage")
.Output("OutputImage")
.AddIntArg("block_size", block_size)
.Finalize(net.NewOperatorDef());
}
// Run
net.RunOp(D);
if (D == DeviceType::OPENCL) {
ImageToBuffer<DeviceType::OPENCL, float>(&net, "OutputImage", "Output",
kernels::BufferType::IN_OUT_CHANNEL);
}
auto expected = CreateTensor<float>(expected_shape, expected_data);
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001);
}
class SpaceToDepthOpTest : public OpsTestBase {};
TEST_F(SpaceToDepthOpTest, Input2x4x4_B2_CPU) {
RunDepthToSpace<DeviceType::CPU>(false, {1, 2, 4, 4},
{0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23,
8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31},
2,
{1, 1, 2, 16},
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31});
}
TEST_F(SpaceToDepthOpTest, Input2x4x4_B2_OPENCL) {
RunDepthToSpace<DeviceType::OPENCL>(false, {1, 2, 4, 4},
{0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23,
8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31},
2,
{1, 1, 2, 16},
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31});
}
TEST_F(SpaceToDepthOpTest, Input2x2x4_B2_CPU) {
RunDepthToSpace<DeviceType::CPU>(false, {1, 2, 2, 4},
{1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15, 16},
2,
{1, 1, 1, 16},
{1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15, 16});
}
TEST_F(SpaceToDepthOpTest, Input2x2x4_B2_OPENCL) {
RunDepthToSpace<DeviceType::OPENCL>(false, {1, 2, 2, 4},
{1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15, 16},
2,
{1, 1, 1, 16},
{1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15, 16});
}
class DepthToSpaceOpTest : public OpsTestBase {};
TEST_F(DepthToSpaceOpTest, Input1x2x16_B2_CPU) {
RunDepthToSpace<DeviceType::CPU>(true, {1, 1, 2, 16},
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31},
2,
{1, 2, 4, 4},
{0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23,
8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31});
}
TEST_F(DepthToSpaceOpTest, Input1x2x16_B2_OPENCL) {
RunDepthToSpace<DeviceType::OPENCL>(true, {1, 1, 2, 16},
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31},
2,
{1, 2, 4, 4},
{0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23,
8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31});
}
TEST_F(DepthToSpaceOpTest, Input1x1x16_B2_CPU) {
RunDepthToSpace<DeviceType::CPU>(true, {1, 1, 1, 16},
{1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15, 16},
2,
{1, 2, 2, 4},
{1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15, 16});
}
TEST_F(DepthToSpaceOpTest, Input1x1x16_B2_OPENCL) {
RunDepthToSpace<DeviceType::OPENCL>(true, {1, 1, 1, 16},
{1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15, 16},
2,
{1, 2, 2, 4},
{1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15, 16});
}
/*
TEST_F(DepthToSpaceOpTest, Input2x2x3_B2_CPU) {
RunDepthToSpace<DeviceType::CPU>({1, 2, 2, 3},
{1, 2, 3, 4, 5, 6,
7, 8, 9, 10, 11, 12},
2,
{1, 1, 1, 12},
{1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12});
}
TEST_F(DepthToSpaceOpTest, Input2x2x3_B2_OPENCL) {
RunDepthToSpace<DeviceType::OPENCL>({1, 2, 2, 6},
{1, 2, 3, 4, 5, 6,
7, 8, 9, 10, 11, 12
},
2,
{1, 1, 1, 12},
{1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12});
}
TEST_F(DepthToSpaceOpTest, Input2x2x2_B2_CPU) {
RunDepthToSpace<DeviceType::CPU>({1, 2, 2, 2},
{1, 10, 2, 20, 3, 30, 4, 40},
2,
{1, 1, 1, 8},
{1, 10, 2, 20, 3, 30, 4, 40});
}
TEST_F(DepthToSpaceOpTest, Input2x2x2_B2_OPENCL) {
RunDepthToSpace<DeviceType::OPENCL>({1, 2, 2, 2},
{1, 10, 2, 20, 3, 30, 4, 40},
2,
{1, 1, 1, 8},
{1, 10, 2, 20, 3, 30, 4, 40});
}*/
} // namespace test
} // namespace ops
} // namespace mace
@@ -16,12 +16,12 @@ class ProposalOp : public Operator<D, T> {
 public:
  ProposalOp(const OperatorDef &operator_def, Workspace *ws)
      : Operator<D, T>(operator_def, ws),
        functor_(OperatorBase::GetSingleArgument<int>("min_size", 16),
                 OperatorBase::GetSingleArgument<float>("nms_thresh", 0.7),
                 OperatorBase::GetSingleArgument<int>("pre_nms_top_n", 6000),
                 OperatorBase::GetSingleArgument<int>("post_nms_top_n", 300),
                 OperatorBase::GetSingleArgument<int>("feat_stride", 0),
                 OperatorBase::GetSingleArgument<int>("base_size", 12),
                 OperatorBase::GetRepeatedArgument<int>("scales"),
                 OperatorBase::GetRepeatedArgument<float>("ratios")) {}
...
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/ops/reorganize.h"
namespace mace {
namespace ops {
void Register_ReOrganize(OperatorRegistry *op_registry) {
REGISTER_OPERATOR(op_registry, OpKeyBuilder("ReOrganize")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
.Build(),
ReOrganizeOp<DeviceType::CPU, float>);
}
} // namespace ops
} // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_OPS_REORGANIZE_H_
#define MACE_OPS_REORGANIZE_H_
#include <vector>
#include "mace/core/operator.h"
#include "mace/kernels/reorganize.h"
namespace mace {
namespace ops {
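// ReOrganizeOp reads the target shape from the "shape" argument (one entry may
// be -1 and is then inferred from the input size, as in a normal reshape) and
// delegates the actual data movement to kernels::ReOrganizeFunctor.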
template <DeviceType D, typename T>
class ReOrganizeOp : public Operator<D, T> {
public:
ReOrganizeOp(const OperatorDef &op_def, Workspace *ws)
: Operator<D, T>(op_def, ws),
shape_(OperatorBase::GetRepeatedArgument<int64_t>("shape")) {}
bool Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT);
const index_t num_dims = shape_.size();
int unknown_idx = -1;
index_t product = 1;
std::vector<index_t> out_shape;
for (int i = 0; i < num_dims; ++i) {
if (shape_[i] == -1) {
MACE_CHECK(unknown_idx == -1) << "Only one input size may be -1";
unknown_idx = i;
out_shape.push_back(1);
} else {
MACE_CHECK(shape_[i] >= 0) << "Shape must be non-negative: "
<< shape_[i];
out_shape.push_back(shape_[i]);
product *= shape_[i];
}
}
if (unknown_idx != -1) {
MACE_CHECK(product != 0)
<< "Cannot infer shape if there is zero shape size.";
const index_t missing = input->size() / product;
MACE_CHECK(missing * product == input->size())
<< "Input size not match reshaped tensor size";
out_shape[unknown_idx] = missing;
}
Tensor *output = this->Output(OUTPUT);
output->Resize(out_shape);
functor_(input, out_shape, output, future);
return true;
}
private:
std::vector<int64_t> shape_;
kernels::ReOrganizeFunctor<D, T> functor_;
protected:
OP_INPUT_TAGS(INPUT);
OP_OUTPUT_TAGS(OUTPUT);
};
} // namespace ops
} // namespace mace
#endif // MACE_OPS_REORGANIZE_H_
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "gmock/gmock.h"
#include "mace/core/operator.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
class ReOrganizeTest : public OpsTestBase {};
void TestReOrganize(const std::vector<index_t> &input_shape,
const std::vector<float> &input_data,
const std::vector<index_t> &output_shape,
const std::vector<float> &output_data) {
const std::vector<int> out_shape(output_shape.begin(), output_shape.end());
// Construct graph
OpsTestNet net;
OpDefBuilder("ReOrganize", "ReOrganizeTest")
.Input("Input")
.Output("Output")
.AddIntsArg("shape", out_shape)
.Finalize(net.NewOperatorDef());
// Add input data
net.AddInputFromArray<DeviceType::CPU, float>("Input",
input_shape, input_data);
// Run
net.RunOp();
auto output = net.GetTensor("Output");
EXPECT_THAT(output->shape(), ::testing::ContainerEq(output_shape));
const float *output_ptr = output->data<float>();
int size = output->size();
for (int i = 0; i < size; ++i) {
ASSERT_EQ(output_data[i], output_ptr[i]) << "With Index " << i;
}
// Reverse reorganize
const std::vector<int> in_shape(input_shape.begin(), input_shape.end());
OpDefBuilder("ReOrganize", "ReOrganizeTest")
.Input("Input")
.Output("Output")
.AddIntsArg("shape", in_shape)
.Finalize(net.NewOperatorDef());
// Add input data
net.AddInputFromArray<DeviceType::CPU, float>("Input",
output_shape, output_data);
// Run
net.RunOp();
output = net.GetTensor("Output");
EXPECT_THAT(output->shape(), ::testing::ContainerEq(input_shape));
output_ptr = output->data<float>();
size = output->size();
for (int i = 0; i < size; ++i) {
ASSERT_EQ(input_data[i], output_ptr[i]) << "With Index " << i;
}
}
TEST_F(ReOrganizeTest, Simple) {
TestReOrganize({1, 1, 4, 6},
{0, 4, 8, 12, 16, 20,
1, 5, 9, 13, 17, 21,
2, 6, 10, 14, 18, 22,
3, 7, 11, 15, 19, 23},
{1, 1, 8, 3},
{0, 8, 16, 1, 9, 17, 2, 10, 18, 3, 11, 19,
4, 12, 20, 5, 13, 21, 6, 14, 22, 7, 15, 23});
TestReOrganize({1, 1, 5, 6},
{0, 5, 10, 15, 20, 25,
1, 6, 11, 16, 21, 26,
2, 7, 12, 17, 22, 27,
3, 8, 13, 18, 23, 28,
4, 9, 14, 19, 24, 29},
{1, 1, 10, 3},
{0, 10, 20, 1, 11, 21, 2, 12, 22, 3, 13, 23,
4, 14, 24, 5, 15, 25, 6, 16, 26, 7, 17, 27,
8, 18, 28, 9, 19, 29});
}
TEST_F(ReOrganizeTest, Complex) {
TestReOrganize({1, 2, 2, 6},
{0, 4, 8, 12, 16, 20,
1, 5, 9, 13, 17, 21,
2, 6, 10, 14, 18, 22,
3, 7, 11, 15, 19, 23},
{1, 2, 6, 2},
{0, 12, 1, 13, 4, 16, 5, 17, 8, 20, 9, 21,
2, 14, 3, 15, 6, 18, 7, 19, 10, 22, 11, 23});
}
} // namespace test
} // namespace ops
} // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/ops/space_to_depth.h"
namespace mace {
namespace ops {
void Register_SpaceToDepth(OperatorRegistry *op_registry) {
REGISTER_OPERATOR(op_registry, OpKeyBuilder("SpaceToDepth")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
.Build(),
SpaceToDepthOp<DeviceType::CPU, float>);
REGISTER_OPERATOR(op_registry, OpKeyBuilder("SpaceToDepth")
.Device(DeviceType::OPENCL)
.TypeConstraint<float>("T")
.Build(),
SpaceToDepthOp<DeviceType::OPENCL, float>);
REGISTER_OPERATOR(op_registry, OpKeyBuilder("SpaceToDepth")
.Device(DeviceType::OPENCL)
.TypeConstraint<half>("T")
.Build(),
SpaceToDepthOp<DeviceType::OPENCL, half>);
}
} // namespace ops
} // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_OPS_SPACE_TO_DEPTH_H_
#define MACE_OPS_SPACE_TO_DEPTH_H_
#include <memory>
#include <vector>
#include "mace/core/operator.h"
#include "mace/kernels/depth_to_space.h"
namespace mace {
namespace ops {
template <DeviceType D, typename T>
class SpaceToDepthOp : public Operator<D, T> {
public:
SpaceToDepthOp(const OperatorDef &op_def, Workspace *ws)
: Operator<D, T>(op_def, ws),
functor_(OperatorBase::GetSingleArgument<int>("block_size", 1), false) {
}
bool Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT);
Tensor *output = this->Output(OUTPUT);
MACE_CHECK(input->dim_size() == 4, "input dim should be 4");
const int block_size =
OperatorBase::GetSingleArgument<int>("block_size", 1);
const int input_height = input->dim(1);
const int input_width = input->dim(2);
const int input_depth = input->dim(3);
MACE_CHECK((input_depth % 4) == 0,
"input channel should be divisible by 4");
MACE_CHECK(
(input_width % block_size == 0) && (input_height % block_size == 0),
"input width and height should be divisible by block_size",
input->dim(3));
functor_(input, output, future);
return true;
}
protected:
OP_INPUT_TAGS(INPUT);
OP_OUTPUT_TAGS(OUTPUT);
private:
kernels::DepthToSpaceOpFunctor<D, T> functor_;
};
} // namespace ops
} // namespace mace
#endif // MACE_OPS_SPACE_TO_DEPTH_H_
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/core/operator.h"
#include "mace/core/testing/test_benchmark.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
template <DeviceType D, typename T>
static void SpaceToDepth(
int iters, int batch, int channels, int height, int width, int block_size) {
mace::testing::StopTiming();
OpsTestNet net;
// Add input data
net.AddRandomInput<D, float>("Input", {batch, height, width, channels});
if (D == DeviceType::OPENCL) {
BufferToImage<D, float>(&net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL);
OpDefBuilder("SpaceToDepth", "SpaceToDepthBM")
.Input("InputImage")
.Output("Output")
.AddIntArg("block_size", block_size)
.Finalize(net.NewOperatorDef());
} else {
OpDefBuilder("SpaceToDepth", "SpaceToDepthBM")
.Input("Input")
.Output("Output")
.Finalize(net.NewOperatorDef());
}
// Warm-up
for (int i = 0; i < 5; ++i) {
net.RunOp(D);
}
net.Sync();
mace::testing::StartTiming();
while (iters--) {
net.RunOp(D);
}
net.Sync();
}
#define BM_SPACE_TO_DEPTH_MACRO(N, C, H, W, G, TYPE, DEVICE) \
static void \
BM_SPACE_TO_DEPTH_##N##_##C##_##H##_##W##_##G##_##TYPE##_##DEVICE( \
int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
mace::testing::MaccProcessed(tot); \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
SpaceToDepth<DEVICE, TYPE>(iters, N, C, H, W, G); \
} \
BENCHMARK(BM_SPACE_TO_DEPTH_##N##_##C##_##H##_##W##_##G##_##TYPE##_##DEVICE)
#define BM_SPACE_TO_DEPTH(N, C, H, W, G) \
BM_SPACE_TO_DEPTH_MACRO(N, C, H, W, G, float, CPU); \
BM_SPACE_TO_DEPTH_MACRO(N, C, H, W, G, float, OPENCL); \
BM_SPACE_TO_DEPTH_MACRO(N, C, H, W, G, half, OPENCL);
BM_SPACE_TO_DEPTH(1, 64, 64, 64, 4);
BM_SPACE_TO_DEPTH(1, 64, 128, 128, 4);
BM_SPACE_TO_DEPTH(1, 64, 256, 256, 4);
} // namespace test
} // namespace ops
} // namespace mace
@@ -404,7 +404,9 @@ message LayerParameter {
  optional ParameterParameter parameter_param = 145;
  optional PoolingParameter pooling_param = 121;
  optional PowerParameter power_param = 122;
  optional ProposalParameter proposal_param = 8266713;
  optional PReLUParameter prelu_param = 131;
  optional PSROIAlignParameter psroi_align_param = 1490;
  optional PythonParameter python_param = 130;
  optional RecurrentParameter recurrent_param = 146;
  optional ReductionParameter reduction_param = 136;
@@ -944,6 +946,19 @@ message PowerParameter {
  optional float shift = 3 [default = 0.0];
}

// Message that stores parameters used by ProposalLayer
message ProposalParameter {
  optional uint32 feat_stride = 1 [default = 16];
  repeated uint32 scales = 2;
  repeated float ratios = 3;
}

message PSROIAlignParameter {
  required float spatial_scale = 1;
  required int32 output_dim = 2;  // output channel number
  required int32 group_size = 3;  // number of groups to encode position-sensitive score maps
}

message PythonParameter {
  optional string module = 1;
  optional string layer = 2;
...
@@ -784,21 +784,89 @@ class CaffeConverter(object):
    self.net_def.op.extend([op_def])
    self.resolved_ops.add(op.name)

  def convert_reshape(self, op):
    op_def = self.CommonConvert(op, 'ReOrganize')
    input_shape = op.parents[0].output_shape_map[op.layer.bottom[0]]
    output_shape = input_shape
    shape_param = np.asarray(op.layer.reshape_param.shape.dim)[[0, 3, 2, 1]]
    print shape_param
    for i in range(len(shape_param)):
      if shape_param[i] != 0:
        output_shape[i] = shape_param[i]
    shape_arg = op_def.arg.add()
    shape_arg.name = 'shape'
    shape_arg.ints.extend(output_shape)
    op.output_shape_map[op.layer.top[0]] = output_shape
    self.add_output_shape(op_def, output_shape)
    op_def.output.extend([op.name + ':0'])
    self.net_def.op.extend([op_def])
    self.resolved_ops.add(op.name)

  def convert_proposal_op(self, op):
    assert self.device == 'cpu'
    op_def = self.CommonConvert(op, op.type)
    if op.layer.HasField('proposal_param'):
      proposal_param = op.layer.proposal_param
      feat_stride_arg = op_def.arg.add()
      feat_stride_arg.name = 'feat_stride'
      feat_stride_arg.i = proposal_param.feat_stride
      scales_arg = op_def.arg.add()
      scales_arg.name = 'scales'
      scales_arg.ints.extend(list(proposal_param.scales))
      ratios_arg = op_def.arg.add()
      ratios_arg.name = 'ratios'
      ratios_arg.floats.extend(list(proposal_param.ratios))
    output_shape = op.parents[0].output_shape_map[op.layer.bottom[0]]
    op.output_shape_map[op.layer.top[0]] = output_shape
    self.add_output_shape(op_def, output_shape)
    op_def.output.extend([op.name + ':0'])
    self.net_def.op.extend([op_def])
    self.resolved_ops.add(op.name)

  def convert_psroi_align(self, op):
    assert self.device == 'cpu'
    op_def = self.CommonConvert(op, op.type)
    if op.layer.HasField('psroi_align_param'):
      psroi_align_param = op.layer.psroi_align_param
      spatial_scale_arg = op_def.arg.add()
      spatial_scale_arg.name = 'spatial_scale'
      spatial_scale_arg.f = psroi_align_param.spatial_scale
      output_dim_arg = op_def.arg.add()
      output_dim_arg.name = 'output_dim'
      output_dim_arg.i = psroi_align_param.output_dim
      group_size_arg = op_def.arg.add()
      group_size_arg.name = 'group_size'
      group_size_arg.i = psroi_align_param.group_size
    output_shape = op.parents[0].output_shape_map[op.layer.bottom[0]]
    op.output_shape_map[op.layer.top[0]] = output_shape
    self.add_output_shape(op_def, output_shape)
    op_def.output.extend([op.name + ':0'])
    self.net_def.op.extend([op_def])
    self.resolved_ops.add(op.name)

  def replace_in_out_name(self, input_names, output_names, is_single):
    in_names = set([input_name + ":0" for input_name in input_names])
    out_names = set([output_name + ":0" for output_name in output_names])
    if is_single:
      for op in self.net_def.op:
        for i in range(len(op.input)):
          if op.input[i] in in_names:
            op.input[i] = MACE_INPUT_NODE_NAME + ':0'
        for i in range(len(op.output)):
          if op.output[i] in out_names:
            op.output[i] = MACE_OUTPUT_NODE_NAME + ':0'
    else:
      for op in self.net_def.op:
        for i in range(len(op.input)):
          if op.input[i] in in_names:
            op.input[i] = MACE_INPUT_NODE_NAME + '_' + op.input[i]
          if op.input[i] in out_names:
            op.input[i] = MACE_OUTPUT_NODE_NAME + '_' + op.input[i]
        for i in range(len(op.output)):
          if op.output[i] in in_names:
            op.output[i] = MACE_INPUT_NODE_NAME + '_' + op.output[i]
          if op.output[i] in out_names:
            op.output[i] = MACE_OUTPUT_NODE_NAME + '_' + op.output[i]

  def add_input_op_shape(self, input_nodes, input_shapes):
    assert len(input_nodes) == len(input_shapes)
@@ -843,10 +911,16 @@ class CaffeConverter(object):
        self.convert_concat(op)
      elif op.type == 'Eltwise':
        self.convert_eltwise(op)
      elif op.type == 'Slice':
        self.convert_slice(op)
      elif op.type == 'Reshape':
        self.convert_reshape(op)
      elif op.type == 'Proposal':
        self.convert_proposal_op(op)
      elif op.type == 'PSROIAlign':
        self.convert_psroi_align(op)
      elif op.type in ['Softmax']:
        self.convert_normal_op(op)
      else:
        raise Exception('Unknown Op: %s, type: %s' % (op.name, op.type))
...
@@ -2,8 +2,6 @@
LIBMACE_TAG=`git describe --abbrev=0 --tags`
MACE_SOURCE_DIR=`/bin/pwd`
PHONE_DATA_DIR="/data/local/tmp/mace_run"
KERNEL_DIR="${PHONE_DATA_DIR}/cl/"
CODEGEN_DIR=${MACE_SOURCE_DIR}/mace/codegen
...
@@ -14,6 +14,7 @@ import subprocess
import sys
import urllib
import yaml
import re

import adb_tools
@@ -64,13 +65,37 @@ def clear_env(target_soc):
  command = "bash tools/clear_env.sh {}".format(target_soc)
  run_command(command)


def input_file_name(input_name):
  return os.environ['INPUT_FILE_NAME'] + '_' + \
      re.sub('[^0-9a-zA-Z]+', '_', input_name)


def generate_random_input(target_soc, model_output_dir,
                          input_names, input_files):
  generate_data_or_not = True
  command = "bash tools/validate_tools.sh {} {} {}".format(
      target_soc, model_output_dir, int(generate_data_or_not))
  run_command(command)

  input_name_list = []
  input_file_list = []
  if isinstance(input_names, list):
    input_name_list.extend(input_names)
  else:
    input_name_list.append(input_names)
  if isinstance(input_files, list):
    input_file_list.extend(input_files)
  else:
    input_file_list.append(input_files)
  assert len(input_file_list) == len(input_name_list)
  for i in range(len(input_file_list)):
    if input_file_list[i] is not None:
      dst_input_file = model_output_dir + '/' + input_file_name(input_name_list[i])
      if input_file_list[i].startswith("http://") or \
          input_file_list[i].startswith("https://"):
        urllib.urlretrieve(input_file_list[i], dst_input_file)
      else:
        print 'Copy input data:', dst_input_file
        shutil.copy(input_file_list[i], dst_input_file)


def generate_model_code():
  command = "bash tools/generate_model_code.sh"
@@ -215,6 +240,13 @@ def parse_args():
      help="SoCs to build, comma seperated list (getprop ro.board.platform)")
  return parser.parse_known_args()


def set_environment(configs):
  os.environ["EMBED_MODEL_DATA"] = str(configs["embed_model_data"])
  os.environ["VLOG_LEVEL"] = str(configs["vlog_level"])
  os.environ["PROJECT_NAME"] = os.path.splitext(os.path.basename(
      FLAGS.config))[0]
  os.environ['INPUT_FILE_NAME'] = "model_input"
  os.environ['OUTPUT_FILE_NAME'] = "model_out"


def main(unused_args):
  configs = parse_model_configs()
@@ -223,10 +255,7 @@ def main(unused_args):
    FLAGS.round = 1
    FLAGS.restart_round = 1

  set_environment(configs)

  if FLAGS.mode == "build" or FLAGS.mode == "all":
    # Remove previous output dirs
@@ -266,6 +295,7 @@ def main(unused_args):
      skip_validation = configs["models"][model_name].get(
          "skip_validation", 0)
      model_config = configs["models"][model_name]
      input_file_list = model_config.get("input_files", [])
      for key in model_config:
        if key in ['input_nodes', 'output_nodes'] and isinstance(
            model_config[key], list):
@@ -310,7 +340,8 @@ def main(unused_args):
      if FLAGS.mode == "build" or FLAGS.mode == "run" or FLAGS.mode == "validate"\
          or FLAGS.mode == "benchmark" or FLAGS.mode == "all":
        generate_random_input(target_soc, model_output_dir,
                              model_config['input_nodes'], input_file_list)

      if FLAGS.mode == "build" or FLAGS.mode == "all":
        generate_model_code()
@@ -336,7 +367,7 @@ def main(unused_args):
  if FLAGS.mode == "throughput_test":
    merged_lib_file = FLAGS.output_dir + "/%s/%s/libmace_%s.%s.a" % \
        (os.environ["PROJECT_NAME"], target_abi, os.environ["PROJECT_NAME"], target_soc)
    generate_random_input(target_soc, FLAGS.output_dir, [], [])
    for model_name in configs["models"]:
      runtime = configs["models"][model_name]["runtime"]
      os.environ["%s_MODEL_TAG" % runtime.upper()] = model_name
...
@@ -97,14 +97,17 @@ def validate_caffe_model(input_names, input_shapes, output_names, output_shapes)
    input_value = load_data(FLAGS.input_file + "_" + input_names[i])
    input_value = input_value.reshape(input_shapes[i]).transpose((0, 3, 1, 2))
    input_blob_name = input_names[i]
    try:
      if input_names[i] in net.top_names:
        input_blob_name = net.top_names[input_names[i]][0]
    except ValueError:
      pass
    net.blobs[input_blob_name].data[0] = input_value

  net.forward()

  for i in range(len(output_names)):
    value = net.blobs[net.top_names[output_names[i]][0]].data
    out_shape = output_shapes[i]
    out_shape[1], out_shape[2], out_shape[3] = out_shape[3], out_shape[1], out_shape[2]
    value = value.reshape(out_shape).transpose((0, 2, 3, 1))
...