diff --git a/mace/core/operator.cc b/mace/core/operator.cc index 10e74a8e6a5cf6449d98fdad573c510dc5603652..38a220eb4b2dcf1cf2091bc9a4f38742648cc977 100644 --- a/mace/core/operator.cc +++ b/mace/core/operator.cc @@ -63,7 +63,6 @@ std::unique_ptr OperatorRegistry::CreateOperator( } namespace ops { - // Keep in lexicographical order extern void Register_Activation(OperatorRegistry *op_registry); extern void Register_AddN(OperatorRegistry *op_registry); @@ -74,6 +73,7 @@ extern void Register_BufferToImage(OperatorRegistry *op_registry); extern void Register_ChannelShuffle(OperatorRegistry *op_registry); extern void Register_Concat(OperatorRegistry *op_registry); extern void Register_Conv2D(OperatorRegistry *op_registry); +extern void Register_DepthToSpace(OperatorRegistry *op_registry); extern void Register_DepthwiseConv2d(OperatorRegistry *op_registry); extern void Register_Eltwise(OperatorRegistry *op_registry); extern void Register_FoldedBatchNorm(OperatorRegistry *op_registry); @@ -85,11 +85,13 @@ extern void Register_MatMul(OperatorRegistry *op_registry); extern void Register_Pooling(OperatorRegistry *op_registry); extern void Register_Proposal(OperatorRegistry *op_registry); extern void Register_PSROIAlign(OperatorRegistry *op_registry); +extern void Register_ReOrganize(OperatorRegistry *op_registry); extern void Register_Reshape(OperatorRegistry *op_registry); extern void Register_ResizeBilinear(OperatorRegistry *op_registry); extern void Register_Slice(OperatorRegistry *op_registry); extern void Register_Softmax(OperatorRegistry *op_registry); extern void Register_SpaceToBatchND(OperatorRegistry *op_registry); +extern void Register_SpaceToDepth(OperatorRegistry *op_registry); extern void Register_WinogradInverseTransform(OperatorRegistry *op_registry); extern void Register_WinogradTransform(OperatorRegistry *op_registry); @@ -107,6 +109,7 @@ OperatorRegistry::OperatorRegistry() { ops::Register_ChannelShuffle(this); ops::Register_Concat(this); ops::Register_Conv2D(this); + ops::Register_DepthToSpace(this); ops::Register_DepthwiseConv2d(this); ops::Register_Eltwise(this); ops::Register_FoldedBatchNorm(this); @@ -118,11 +121,13 @@ OperatorRegistry::OperatorRegistry() { ops::Register_Pooling(this); ops::Register_Proposal(this); ops::Register_PSROIAlign(this); + ops::Register_ReOrganize(this); ops::Register_Reshape(this); ops::Register_ResizeBilinear(this); ops::Register_Slice(this); ops::Register_Softmax(this); ops::Register_SpaceToBatchND(this); + ops::Register_SpaceToDepth(this); ops::Register_WinogradInverseTransform(this); ops::Register_WinogradTransform(this); } diff --git a/mace/kernels/depth_to_space.h b/mace/kernels/depth_to_space.h new file mode 100644 index 0000000000000000000000000000000000000000..3f6577f32159309bba931eaef58011902ecc2045 --- /dev/null +++ b/mace/kernels/depth_to_space.h @@ -0,0 +1,119 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. 
+//
+
+#ifndef MACE_KERNELS_DEPTH_TO_SPACE_H_
+#define MACE_KERNELS_DEPTH_TO_SPACE_H_
+#include <vector>
+
+#include "mace/core/future.h"
+#include "mace/core/runtime/opencl/cl2_header.h"
+#include "mace/core/tensor.h"
+#include "mace/public/mace.h"
+
+namespace mace {
+namespace kernels {
+
+template <DeviceType D, typename T>
+struct DepthToSpaceOpFunctor {
+  explicit DepthToSpaceOpFunctor(const int block_size, bool d2s)
+      : block_size_(block_size), d2s_(d2s) {}
+  void operator()(const Tensor *input, Tensor *output, StatsFuture *future) {
+    const int batch_size = input->dim(0);
+    const int input_height = input->dim(1);
+    const int input_width = input->dim(2);
+    const int input_depth = input->dim(3);
+
+    index_t output_depth, output_width, output_height;
+
+    if (d2s_) {
+      output_depth = input_depth / (block_size_ * block_size_);
+      output_width = input_width * block_size_;
+      output_height = input_height * block_size_;
+    } else {
+      output_depth = input_depth * block_size_ * block_size_;
+      output_width = input_width / block_size_;
+      output_height = input_height / block_size_;
+    }
+    std::vector<index_t> output_shape = {batch_size, output_height,
+                                         output_width, output_depth};
+
+    output->Resize(output_shape);
+
+    Tensor::MappingGuard logits_guard(input);
+    Tensor::MappingGuard output_guard(output);
+    const T *input_ptr = input->data<T>();
+    T *output_ptr = output->mutable_data<T>();
+
+    if (d2s_) {
+#pragma omp parallel for
+      for (int b = 0; b < batch_size; ++b) {
+        for (int h = 0; h < output_height; ++h) {
+          const int in_h = h / block_size_;
+          const int offset_h = (h % block_size_);
+          for (int w = 0; w < output_width; ++w) {
+            const int in_w = w / block_size_;
+            const int offset_w = w % block_size_;
+            const int offset_d =
+                (offset_h * block_size_ + offset_w) * output_depth;
+            for (int d = 0; d < output_depth; ++d) {
+              const int in_d = d + offset_d;
+              const int o_index =
+                  ((b * output_height + h) * output_width + w) * output_depth +
+                  d;
+              const int i_index =
+                  ((b * input_height + in_h) * input_width + in_w) *
+                      input_depth +
+                  in_d;
+              output_ptr[o_index] = input_ptr[i_index];
+            }
+          }
+        }
+      }
+    } else {
+#pragma omp parallel for
+      for (int b = 0; b < batch_size; ++b) {
+        for (int h = 0; h < input_height; ++h) {
+          const int out_h = h / block_size_;
+          const int offset_h = (h % block_size_);
+          for (int w = 0; w < input_width; ++w) {
+            const int out_w = w / block_size_;
+            const int offset_w = (w % block_size_);
+            const int offset_d =
+                (offset_h * block_size_ + offset_w) * input_depth;
+            for (int d = 0; d < input_depth; ++d) {
+              const int out_d = d + offset_d;
+              const int o_index =
+                  ((b * output_height + out_h) * output_width + out_w) *
+                      output_depth +
+                  out_d;
+              const int i_index =
+                  ((b * input_height + h) * input_width + w) * input_depth + d;
+              output_ptr[o_index] = input_ptr[i_index];
+            }
+          }
+        }
+      }
+    }
+  }
+
+  const int block_size_;
+  bool d2s_;
+};
+
+template <typename T>
+struct DepthToSpaceOpFunctor<DeviceType::OPENCL, T> {
+  DepthToSpaceOpFunctor(const int block_size, bool d2s)
+      : block_size_(block_size), d2s_(d2s) {}
+  void operator()(const Tensor *input, Tensor *output, StatsFuture *future);
+
+  cl::Kernel kernel_;
+  const int block_size_;
+  bool d2s_;
+  std::vector<index_t> input_shape_;
+};
+
+}  // namespace kernels
+}  // namespace mace
+
+#endif  // MACE_KERNELS_DEPTH_TO_SPACE_H_
diff --git a/mace/kernels/opencl/cl/depth_to_space.cl b/mace/kernels/opencl/cl/depth_to_space.cl
new file mode 100644
index 0000000000000000000000000000000000000000..824f82665542975da3b000d2e0b1865ceabf4a3c
--- /dev/null
+++ b/mace/kernels/opencl/cl/depth_to_space.cl
@@ -0,0 +1,52 @@
+#include <common.h>
+ +__kernel void depth_to_space(__read_only image2d_t input, + __private const int block_size, + __private const int output_depth, + __write_only image2d_t output) { + const int out_d = get_global_id(0); + const int out_w = get_global_id(1); + const int out_h = get_global_id(2); + const int output_width = get_global_size(1); + + const int out_pos = mad24(out_d, output_width, out_w); + + const int input_width = output_width / block_size; + + const int in_h = out_h / block_size; + const int offset_h = out_h % block_size; + const int in_w = out_w / block_size; + const int offset_w = out_w % block_size; + + const int offset_d = (offset_h * block_size + offset_w) * output_depth; + const int in_d = out_d + offset_d; + + const int in_pos = mad24(in_d, input_width, in_w); + + DATA_TYPE4 in_data = READ_IMAGET(input, SAMPLER, (int2)(in_pos, in_h)); + WRITE_IMAGET(output, (int2)(out_pos, out_h), in_data); +} + +__kernel void space_to_depth(__read_only image2d_t input, + __private const int block_size, + __private const int input_depth, + __write_only image2d_t output) { + const int d = get_global_id(0); + const int w = get_global_id(1); + const int h = get_global_id(2); + const int input_width = get_global_size(1); + const int in_pos = mad24(d, input_width, w); + const int output_width = input_width / block_size; + + const int out_h = h / block_size; + const int offset_h = h % block_size; + const int out_w = w / block_size; + const int offset_w = w % block_size; + const int offset_d = (offset_h * block_size + offset_w) * input_depth; + const int out_d = d + offset_d; + const int out_pos = mad24(out_d, output_width, out_w); + + DATA_TYPE4 in_data = READ_IMAGET(input, SAMPLER, (int2)(in_pos, h)); + + WRITE_IMAGET(output, (int2)(out_pos, out_h), in_data); +} diff --git a/mace/kernels/opencl/depth_to_space_opencl.cc b/mace/kernels/opencl/depth_to_space_opencl.cc new file mode 100644 index 0000000000000000000000000000000000000000..c39c1a342c837e7aef4e9b5da03e401b012fc5e2 --- /dev/null +++ b/mace/kernels/opencl/depth_to_space_opencl.cc @@ -0,0 +1,96 @@ +// +// Copyright (c) 2018 XiaoMi All rights reserved. 
+//
+
+#include "mace/kernels/depth_to_space.h"
+#include "mace/core/runtime/opencl/cl2_header.h"
+#include "mace/core/runtime/opencl/opencl_runtime.h"
+#include "mace/kernels/opencl/helper.h"
+#include "mace/utils/tuner.h"
+#include "mace/utils/utils.h"
+
+namespace mace {
+namespace kernels {
+
+template <typename T>
+void DepthToSpaceOpFunctor<DeviceType::OPENCL, T>::operator()(
+    const Tensor *input, Tensor *output, StatsFuture *future) {
+  const index_t batch = input->dim(0);
+  const index_t input_height = input->dim(1);
+  const index_t input_width = input->dim(2);
+  const index_t input_depth = input->dim(3);
+
+  int depth_blocks = 1;
+  const char *kernel_name = nullptr;
+
+  index_t output_height, output_width, output_depth;
+  if (d2s_) {
+    output_height = input_height * block_size_;
+    output_width = input_width * block_size_;
+    output_depth = input_depth / (block_size_ * block_size_);
+    depth_blocks = RoundUpDiv4(output_depth);
+    kernel_name = "depth_to_space";
+  } else {
+    output_height = input_height / block_size_;
+    output_width = input_width / block_size_;
+    output_depth = input_depth * block_size_ * block_size_;
+    depth_blocks = RoundUpDiv4(input_depth);
+    kernel_name = "space_to_depth";
+  }
+
+  std::vector<index_t> output_shape = {batch, output_height, output_width,
+                                       output_depth};
+
+  std::vector<size_t> image_shape;
+  CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL, &image_shape);
+  output->ResizeImage(output_shape, image_shape);
+
+  if (kernel_.get() == nullptr) {
+    auto runtime = OpenCLRuntime::Global();
+    std::set<std::string> built_options;
+    std::string obfuscated_kernel_name = MACE_OBFUSCATE_SYMBOL(kernel_name);
+    std::stringstream kernel_name_ss;
+    kernel_name_ss << "-D" << kernel_name << "=" << obfuscated_kernel_name;
+    built_options.emplace(kernel_name_ss.str());
+    auto dt = DataTypeToEnum<T>::value;
+    built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
+    built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
+    kernel_ =
+        runtime->BuildKernel("depth_to_space", kernel_name, built_options);
+  }
+  if (!IsVecEqual(input_shape_, input->shape())) {
+    uint32_t idx = 0;
+    kernel_.setArg(idx++, *(input->opencl_image()));
+    kernel_.setArg(idx++, block_size_);
+    kernel_.setArg(idx++, depth_blocks);
+    kernel_.setArg(idx++, *(output->opencl_image()));
+    input_shape_ = input->shape();
+  }
+
+  if (d2s_) {
+    const uint32_t gws[3] = {static_cast<uint32_t>(depth_blocks),
+                             static_cast<uint32_t>(output_width),
+                             static_cast<uint32_t>(output_height * batch)};
+    const std::vector<uint32_t> lws = {8, 16, 8, 1};
+    std::stringstream ss;
+    ss << "depth_to_space_opencl_kernel_" << output->dim(0) << "_"
+       << output->dim(1) << "_" << output->dim(2) << "_" << output->dim(3);
+
+    TuningOrRun3DKernel(kernel_, ss.str(), gws, lws, future);
+  } else {
+    const uint32_t gws[3] = {static_cast<uint32_t>(depth_blocks),
+                             static_cast<uint32_t>(input_width),
+                             static_cast<uint32_t>(input_height * batch)};
+    const std::vector<uint32_t> lws = {8, 16, 8, 1};
+    std::stringstream ss;
+    ss << "space_to_depth_opencl_kernel_" << input->dim(0) << "_"
+       << input->dim(1) << "_" << input->dim(2) << "_" << input->dim(3);
+    TuningOrRun3DKernel(kernel_, ss.str(), gws, lws, future);
+  }
+}
+
+template struct DepthToSpaceOpFunctor<DeviceType::OPENCL, float>;
+template struct DepthToSpaceOpFunctor<DeviceType::OPENCL, half>;
+
+}  // namespace kernels
+}  // namespace mace
diff --git a/mace/kernels/reorganize.h b/mace/kernels/reorganize.h
new file mode 100644
index 0000000000000000000000000000000000000000..68c772090d5db75c5cf609da23ea82f2ccc844eb
--- /dev/null
+++ b/mace/kernels/reorganize.h
@@ -0,0 +1,84 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+#ifndef MACE_KERNELS_REORGANIZE_H_
+#define MACE_KERNELS_REORGANIZE_H_
+
+#include <vector>
+
+#include "mace/core/future.h"
+#include "mace/core/runtime/opencl/cl2_header.h"
+#include "mace/core/tensor.h"
+
+namespace mace {
+namespace kernels {
+
+template <DeviceType D, typename T>
+struct ReOrganizeFunctor {
+  void operator()(const Tensor *input,
+                  const std::vector<index_t> &out_shape,
+                  Tensor *output,
+                  StatsFuture *future) {
+    const bool w2c = out_shape[3] > input->dim(3);
+
+    const index_t height = input->dim(1);
+    const index_t input_width = input->dim(2);
+    const index_t input_chan = input->dim(3);
+    const index_t output_width = output->dim(2);
+    const index_t output_chan = output->dim(3);
+
+    const T *input_ptr = input->data<T>();
+    T *output_ptr = output->mutable_data<T>();
+
+    if (w2c) {
+      MACE_CHECK((out_shape[3] % input->dim(3)) == 0);
+      const index_t multiplier = out_shape[3] / input->dim(3);
+#pragma omp parallel for collapse(4)
+      for (index_t n = 0; n < out_shape[0]; ++n) {
+        for (index_t h = 0; h < out_shape[1]; ++h) {
+          for (index_t w = 0; w < out_shape[2]; ++w) {
+            for (index_t c = 0; c < out_shape[3]; ++c) {
+              const index_t out_offset =
+                  ((n * height + h) * output_width + w)
+                      * output_chan + c;
+              const index_t in_w_idx = w + (c % multiplier) * output_width;
+              const index_t in_chan_idx = c / multiplier;
+              const index_t in_offset =
+                  ((n * height + h) * input_width + in_w_idx)
+                      * input_chan + in_chan_idx;
+              output_ptr[out_offset] = input_ptr[in_offset];
+            }
+          }
+        }
+      }
+    } else {
+      MACE_CHECK((input->dim(3) % out_shape[3]) == 0);
+      const index_t multiplier = input->dim(3) / out_shape[3];
+
+#pragma omp parallel for collapse(4)
+      for (index_t n = 0; n < out_shape[0]; ++n) {
+        for (index_t h = 0; h < out_shape[1]; ++h) {
+          for (index_t w = 0; w < out_shape[2]; ++w) {
+            for (index_t c = 0; c < out_shape[3]; ++c) {
+              const index_t out_offset =
+                  ((n * height + h) * output_width + w)
+                      * output_chan + c;
+              const index_t in_w_idx = w % input_width;
+              const index_t in_chan_idx = w / input_width + c * multiplier;
+              const index_t in_offset =
+                  ((n * height + h) * input_width + in_w_idx)
+                      * input_chan + in_chan_idx;
+              output_ptr[out_offset] = input_ptr[in_offset];
+            }
+          }
+        }
+      }
+    }
+
+  }
+};
+
+}  // namespace kernels
+}  // namespace mace
+
+#endif  // MACE_KERNELS_REORGANIZE_H_
diff --git a/mace/ops/depth_to_space.cc b/mace/ops/depth_to_space.cc
new file mode 100644
index 0000000000000000000000000000000000000000..a8c4ef55bdef9dfe2c4290f7cf4e3215a852e6fb
--- /dev/null
+++ b/mace/ops/depth_to_space.cc
@@ -0,0 +1,31 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#include "mace/ops/depth_to_space.h"
+
+namespace mace {
+namespace ops {
+
+void Register_DepthToSpace(OperatorRegistry *op_registry) {
+  REGISTER_OPERATOR(op_registry, OpKeyBuilder("DepthToSpace")
+                                     .Device(DeviceType::CPU)
+                                     .TypeConstraint<float>("T")
+                                     .Build(),
+                    DepthToSpaceOp<DeviceType::CPU, float>);
+
+  REGISTER_OPERATOR(op_registry, OpKeyBuilder("DepthToSpace")
+                                     .Device(DeviceType::OPENCL)
+                                     .TypeConstraint<float>("T")
+                                     .Build(),
+                    DepthToSpaceOp<DeviceType::OPENCL, float>);
+
+  REGISTER_OPERATOR(op_registry, OpKeyBuilder("DepthToSpace")
+                                     .Device(DeviceType::OPENCL)
+                                     .TypeConstraint<half>("T")
+                                     .Build(),
+                    DepthToSpaceOp<DeviceType::OPENCL, half>);
+}
+
+}  // namespace ops
+}  // namespace mace
diff --git a/mace/ops/depth_to_space.h b/mace/ops/depth_to_space.h
new file mode 100644
index 0000000000000000000000000000000000000000..78ff39191943f1cc7c215e219fcdec607d3e6718
--- /dev/null
+++ b/mace/ops/depth_to_space.h
@@ -0,0 +1,53 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#ifndef MACE_OPS_DEPTH_TO_SPACE_H_
+#define MACE_OPS_DEPTH_TO_SPACE_H_
+
+#include <memory>
+#include <vector>
+
+#include "mace/core/operator.h"
+#include "mace/kernels/depth_to_space.h"
+
+namespace mace {
+namespace ops {
+
+template <DeviceType D, typename T>
+class DepthToSpaceOp : public Operator<D, T> {
+ public:
+  DepthToSpaceOp(const OperatorDef &op_def, Workspace *ws)
+      : Operator<D, T>(op_def, ws),
+        functor_(OperatorBase::GetSingleArgument<int>("block_size", 1),
+                 true) {}
+
+  bool Run(StatsFuture *future) override {
+    const Tensor *input = this->Input(INPUT);
+    Tensor *output = this->Output(OUTPUT);
+    MACE_CHECK(input->dim_size() == 4, "input dim should be 4");
+
+    const int block_size =
+        OperatorBase::GetSingleArgument<int>("block_size", 1);
+
+    int input_depth = input->dim(3);
+    MACE_CHECK(input_depth % (block_size * block_size) == 0,
+               "input depth should be dividable by block_size * block_size",
+               input->dim(3));
+    MACE_CHECK((input_depth % 4) == 0,
+               "input channel should be dividable by 4");
+    functor_(input, output, future);
+    return true;
+  }
+
+ protected:
+  OP_INPUT_TAGS(INPUT);
+  OP_OUTPUT_TAGS(OUTPUT);
+
+ private:
+  kernels::DepthToSpaceOpFunctor<D, T> functor_;
+};
+
+}  // namespace ops
+}  // namespace mace
+
+#endif  // MACE_OPS_DEPTH_TO_SPACE_H_
diff --git a/mace/ops/depth_to_space_benchmark.cc b/mace/ops/depth_to_space_benchmark.cc
new file mode 100644
index 0000000000000000000000000000000000000000..c90a8bd81c278dc5dfc3a2470097234c6dbb39f6
--- /dev/null
+++ b/mace/ops/depth_to_space_benchmark.cc
@@ -0,0 +1,74 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#include "mace/core/operator.h"
+#include "mace/core/testing/test_benchmark.h"
+#include "mace/ops/ops_test_util.h"
+
+namespace mace {
+namespace ops {
+namespace test {
+
+template <DeviceType D, typename T>
+static void DepthToSpace(
+    int iters, int batch, int channels, int height, int width, int block_size) {
+  mace::testing::StopTiming();
+
+  OpsTestNet net;
+
+  // Add input data
+  net.AddRandomInput<D, float>("Input", {batch, height, width, channels});
+
+  if (D == DeviceType::OPENCL) {
+    BufferToImage<D, T>(&net, "Input", "InputImage",
+                        kernels::BufferType::IN_OUT_CHANNEL);
+
+    OpDefBuilder("DepthToSpace", "DepthToSpaceBM")
+        .Input("InputImage")
+        .Output("Output")
+        .AddIntArg("block_size", block_size)
+        .Finalize(net.NewOperatorDef());
+  } else {
+    OpDefBuilder("DepthToSpace", "DepthToSpaceBM")
+        .Input("Input")
+        .Output("Output")
+        .AddIntArg("block_size", block_size)
+        .Finalize(net.NewOperatorDef());
+  }
+
+  // Warm-up
+  for (int i = 0; i < 5; ++i) {
+    net.RunOp(D);
+  }
+  net.Sync();
+
+  mace::testing::StartTiming();
+  while (iters--) {
+    net.RunOp(D);
+  }
+  net.Sync();
+}
+
+#define BM_DEPTH_TO_SPACE_MACRO(N, C, H, W, G, TYPE, DEVICE)             \
+  static void                                                            \
+      BM_DEPTH_TO_SPACE_##N##_##C##_##H##_##W##_##G##_##TYPE##_##DEVICE( \
+          int iters) {                                                   \
+    const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W;     \
+    mace::testing::MaccProcessed(tot);                                   \
+    mace::testing::BytesProcessed(tot *(sizeof(TYPE)));                  \
+    DepthToSpace<DEVICE, TYPE>(iters, N, C, H, W, G);                    \
+  }                                                                      \
+  BENCHMARK(BM_DEPTH_TO_SPACE_##N##_##C##_##H##_##W##_##G##_##TYPE##_##DEVICE)
+
+#define BM_DEPTH_TO_SPACE(N, C, H, W, G)                 \
+  BM_DEPTH_TO_SPACE_MACRO(N, C, H, W, G, float, CPU);    \
+  BM_DEPTH_TO_SPACE_MACRO(N, C, H, W, G, float, OPENCL); \
+  BM_DEPTH_TO_SPACE_MACRO(N, C, H, W, G, half, OPENCL);
+
+BM_DEPTH_TO_SPACE(1, 64, 64, 64, 4);
+BM_DEPTH_TO_SPACE(1, 64, 128, 128, 4);
+BM_DEPTH_TO_SPACE(1, 64, 256, 256, 4);
+
+}  // namespace test
+}  // namespace ops
+}  // namespace mace
diff --git a/mace/ops/depth_to_space_test.cc b/mace/ops/depth_to_space_test.cc
new file mode 100644
index
0000000000000000000000000000000000000000..ba31174d5362001d5484bec51130a0a0b1f3c018 --- /dev/null +++ b/mace/ops/depth_to_space_test.cc @@ -0,0 +1,177 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include "mace/core/operator.h" +#include "mace/ops/ops_test_util.h" + +namespace mace { +namespace ops { +namespace test { + +template +void RunDepthToSpace(const bool d2s, + const std::vector &input_shape, + const std::vector &input_data, + const int block_size, + const std::vector &expected_shape, + const std::vector &expected_data) { + OpsTestNet net; + net.AddInputFromArray("Input", input_shape, input_data); + const char *ops_name = (d2s) ? "DepthToSpace" : "SpaceToDepth"; + const char *ops_test_name = (d2s) ? "DepthToSpaceTest" : "SpaceToDepthTest"; + // Construct graph + if (D == DeviceType::CPU) { + OpDefBuilder(ops_name, ops_test_name) + .Input("Input") + .Output("Output") + .AddIntArg("block_size", block_size) + .Finalize(net.NewOperatorDef()); + + } else { + BufferToImage(&net, "Input", "InputImage", + kernels::BufferType::IN_OUT_CHANNEL); + OpDefBuilder(ops_name, ops_test_name) + .Input("InputImage") + .Output("OutputImage") + .AddIntArg("block_size", block_size) + .Finalize(net.NewOperatorDef()); + } + // Run + net.RunOp(D); + + if (D == DeviceType::OPENCL) { + ImageToBuffer(&net, "OutputImage", "Output", + kernels::BufferType::IN_OUT_CHANNEL); + } + auto expected = CreateTensor(expected_shape, expected_data); + ExpectTensorNear(*expected, *net.GetOutput("Output"), 0.001); +} + +class SpaceToDepthOpTest : public OpsTestBase {}; + +TEST_F(SpaceToDepthOpTest, Input2x4x4_B2_CPU) { + RunDepthToSpace(false, {1, 2, 4, 4}, + {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23, + 8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31}, + 2, + {1, 1, 2, 16}, + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}); +} + +TEST_F(SpaceToDepthOpTest, Input2x4x4_B2_OPENCL) { + RunDepthToSpace(false, {1, 2, 4, 4}, + {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23, + 8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31}, + 2, + {1, 1, 2, 16}, + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}); +} + +TEST_F(SpaceToDepthOpTest, Input2x2x4_B2_CPU) { + RunDepthToSpace(false, {1, 2, 2, 4}, + {1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15, 16}, + 2, + {1, 1, 1, 16}, + {1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15, 16}); +} + +TEST_F(SpaceToDepthOpTest, Input4x4x1_B2_OPENCL) { + RunDepthToSpace(false, {1, 2, 2, 4}, + {1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15, 16}, + 2, + {1, 1, 1, 16}, + {1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15, 16}); +} + +class DepthToSpaceOpTest : public OpsTestBase {}; + +TEST_F(DepthToSpaceOpTest, Input1x2x16_B2_CPU) { + RunDepthToSpace(true, {1, 1, 2, 16}, + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, + 2, + {1, 2, 4, 4}, + {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23, + 8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31}); +} + +TEST_F(DepthToSpaceOpTest, Input1x2x16_B2_OPENCL) { + RunDepthToSpace(true, {1, 1, 2, 16}, + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, + 2, + {1, 2, 4, 4}, + {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23, + 8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 
29, 30, 31}); +} + +TEST_F(DepthToSpaceOpTest, Input1x1x16_B2_CPU) { + RunDepthToSpace(true, {1, 1, 1, 16}, + {1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15, 16}, + 2, + {1, 2, 2, 4}, + {1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15, 16}); +} + +TEST_F(DepthToSpaceOpTest, Input1x1x16_B2_OPENCL) { + RunDepthToSpace(true, {1, 1, 1, 16}, + {1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15, 16}, + 2, + {1, 2, 2, 4}, + {1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15, 16}); +} + +/* +TEST_F(DepthToSpaceOpTest, Input2x2x3_B2_CPU) { + + RunDepthToSpace({1, 2, 2, 3}, + {1, 2, 3, 4, 5, 6, + 7, 8, 9, 10, 11, 12}, + 2, + {1, 1, 1, 12}, + {1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12}); +} + +TEST_F(DepthToSpaceOpTest, Input2x2x3_B2_OPENCL) { + RunDepthToSpace({1, 2, 2, 6}, + {1, 2, 3, 4, 5, 6, + 7, 8, 9, 10, 11, 12 + }, + 2, + {1, 1, 1, 12}, + {1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12}); +} + +TEST_F(DepthToSpaceOpTest, Input2x2x2_B2_CPU) { + + RunDepthToSpace({1, 2, 2, 2}, + {1, 10, 2, 20, 3, 30, 4, 40}, + 2, + {1, 1, 1, 8}, + {1, 10, 2, 20, 3, 30, 4, 40}); +} + +TEST_F(DepthToSpaceOpTest, Input2x2x2_B2_OPENCL) { + + RunDepthToSpace({1, 2, 2, 2}, + {1, 10, 2, 20, 3, 30, 4, 40}, + 2, + {1, 1, 1, 8}, + {1, 10, 2, 20, 3, 30, 4, 40}); +}*/ +} // namespace test +} // namespace ops +} // namespace mace diff --git a/mace/ops/proposal.h b/mace/ops/proposal.h index 06dcc8a1b02b030a82e8bf5508421f0342decc46..6bd1c15917b57e7334a604986ec40cb03471871c 100644 --- a/mace/ops/proposal.h +++ b/mace/ops/proposal.h @@ -16,12 +16,12 @@ class ProposalOp : public Operator { public: ProposalOp(const OperatorDef &operator_def, Workspace *ws) : Operator(operator_def, ws), - functor_(OperatorBase::GetSingleArgument("min_size", 0), - OperatorBase::GetSingleArgument("nms_thresh", 0), - OperatorBase::GetSingleArgument("pre_nms_top_n", 0), - OperatorBase::GetSingleArgument("post_nms_top_n", 0), + functor_(OperatorBase::GetSingleArgument("min_size", 16), + OperatorBase::GetSingleArgument("nms_thresh", 0.7), + OperatorBase::GetSingleArgument("pre_nms_top_n", 6000), + OperatorBase::GetSingleArgument("post_nms_top_n", 300), OperatorBase::GetSingleArgument("feat_stride", 0), - OperatorBase::GetSingleArgument("base_size", 16), + OperatorBase::GetSingleArgument("base_size", 12), OperatorBase::GetRepeatedArgument("scales"), OperatorBase::GetRepeatedArgument("ratios")) {} diff --git a/mace/ops/reorganize.cc b/mace/ops/reorganize.cc new file mode 100644 index 0000000000000000000000000000000000000000..794464cfb473005cb4dc76271bc470227191d104 --- /dev/null +++ b/mace/ops/reorganize.cc @@ -0,0 +1,19 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include "mace/ops/reorganize.h" + +namespace mace { +namespace ops { + +void Register_ReOrganize(OperatorRegistry *op_registry) { + REGISTER_OPERATOR(op_registry, OpKeyBuilder("ReOrganize") + .Device(DeviceType::CPU) + .TypeConstraint("T") + .Build(), + ReOrganizeOp); +} + +} // namespace ops +} // namespace mace diff --git a/mace/ops/reorganize.h b/mace/ops/reorganize.h new file mode 100644 index 0000000000000000000000000000000000000000..63b6110701a9c477982ef38f54489479ead89a1b --- /dev/null +++ b/mace/ops/reorganize.h @@ -0,0 +1,71 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. 
+// + +#ifndef MACE_OPS_REORGANIZE_H_ +#define MACE_OPS_REORGANIZE_H_ + +#include + +#include "mace/core/operator.h" +#include "mace/kernels/reorganize.h" + +namespace mace { +namespace ops { + +template +class ReOrganizeOp : public Operator { + public: + ReOrganizeOp(const OperatorDef &op_def, Workspace *ws) + : Operator(op_def, ws), + shape_(OperatorBase::GetRepeatedArgument("shape")) {} + + bool Run(StatsFuture *future) override { + const Tensor *input = this->Input(INPUT); + const index_t num_dims = shape_.size(); + int unknown_idx = -1; + index_t product = 1; + std::vector out_shape; + + for (int i = 0; i < num_dims; ++i) { + if (shape_[i] == -1) { + MACE_CHECK(unknown_idx == -1) << "Only one input size may be -1"; + unknown_idx = i; + out_shape.push_back(1); + } else { + MACE_CHECK(shape_[i] >= 0) << "Shape must be non-negative: " + << shape_[i]; + out_shape.push_back(shape_[i]); + product *= shape_[i]; + } + } + + if (unknown_idx != -1) { + MACE_CHECK(product != 0) + << "Cannot infer shape if there is zero shape size."; + const index_t missing = input->size() / product; + MACE_CHECK(missing * product == input->size()) + << "Input size not match reshaped tensor size"; + out_shape[unknown_idx] = missing; + } + + Tensor *output = this->Output(OUTPUT); + output->Resize(out_shape); + + functor_(input, out_shape, output, future); + return true; + } + + private: + std::vector shape_; + kernels::ReOrganizeFunctor functor_; + + protected: + OP_INPUT_TAGS(INPUT); + OP_OUTPUT_TAGS(OUTPUT); +}; + +} // namespace ops +} // namespace mace + +#endif // MACE_OPS_REORGANIZE_H_ diff --git a/mace/ops/reorganize_test.cc b/mace/ops/reorganize_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..68e0886718d8728371878eff2eaa2e2b505d22d6 --- /dev/null +++ b/mace/ops/reorganize_test.cc @@ -0,0 +1,107 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. 
+// + +#include "gmock/gmock.h" +#include "mace/core/operator.h" +#include "mace/ops/ops_test_util.h" + +namespace mace { +namespace ops { +namespace test { + +class ReOrganizeTest : public OpsTestBase {}; + +void TestReOrganize(const std::vector &input_shape, + const std::vector &input_data, + const std::vector &output_shape, + const std::vector &output_data) { + const std::vector out_shape(output_shape.begin(), output_shape.end()); + + // Construct graph + OpsTestNet net; + + OpDefBuilder("ReOrganize", "ReOrganizeTest") + .Input("Input") + .Output("Output") + .AddIntsArg("shape", out_shape) + .Finalize(net.NewOperatorDef()); + + // Add input data + net.AddInputFromArray("Input", + input_shape, input_data); + + // Run + net.RunOp(); + + auto output = net.GetTensor("Output"); + + EXPECT_THAT(output->shape(), ::testing::ContainerEq(output_shape)); + + const float *output_ptr = output->data(); + int size = output->size(); + for (int i = 0; i < size; ++i) { + ASSERT_EQ(output_data[i], output_ptr[i]) << "With Index " << i; + } + + // Reverse reorganzie + const std::vector in_shape(input_shape.begin(), input_shape.end()); + OpDefBuilder("ReOrganize", "ReOrganizeTest") + .Input("Input") + .Output("Output") + .AddIntsArg("shape", in_shape) + .Finalize(net.NewOperatorDef()); + + // Add input data + net.AddInputFromArray("Input", + output_shape, output_data); + + // Run + net.RunOp(); + + output = net.GetTensor("Output"); + + EXPECT_THAT(output->shape(), ::testing::ContainerEq(input_shape)); + + output_ptr = output->data(); + size = output->size(); + for (int i = 0; i < size; ++i) { + ASSERT_EQ(input_data[i], output_ptr[i]) << "With Index " << i; + } +} + +TEST_F(ReOrganizeTest, Simple) { + TestReOrganize({1, 1, 4, 6}, + {0, 4, 8, 12, 16, 20, + 1, 5, 9, 13, 17, 21, + 2, 6, 10, 14, 18, 22, + 3, 7, 11, 15, 19, 23}, + {1, 1, 8, 3}, + {0, 8, 16, 1, 9, 17, 2, 10, 18, 3, 11, 19, + 4, 12, 20, 5, 13, 21, 6, 14, 22, 7, 15, 23}); + TestReOrganize({1, 1, 5, 6}, + {0, 5, 10, 15, 20, 25, + 1, 6, 11, 16, 21, 26, + 2, 7, 12, 17, 22, 27, + 3, 8, 13, 18, 23, 28, + 4, 9, 14, 19, 24, 29}, + {1, 1, 10, 3}, + {0, 10, 20, 1, 11, 21, 2, 12, 22, 3, 13, 23, + 4, 14, 24, 5, 15, 25, 6, 16, 26, 7, 17, 27, + 8, 18, 28, 9, 19, 29}); +} + +TEST_F(ReOrganizeTest, Complex) { + TestReOrganize({1, 2, 2, 6}, + {0, 4, 8, 12, 16, 20, + 1, 5, 9, 13, 17, 21, + 2, 6, 10, 14, 18, 22, + 3, 7, 11, 15, 19, 23}, + {1, 2, 6, 2}, + {0, 12, 1, 13, 4, 16, 5, 17, 8, 20, 9, 21, + 2, 14, 3, 15, 6, 18, 7, 19, 10, 22, 11, 23}); +} + +} // namespace test +} // namespace ops +} // namespace mace diff --git a/mace/ops/space_to_depth.cc b/mace/ops/space_to_depth.cc new file mode 100644 index 0000000000000000000000000000000000000000..55f1a13a4f80b5a88c1f318733f11db1abf2a872 --- /dev/null +++ b/mace/ops/space_to_depth.cc @@ -0,0 +1,31 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. 
+// + +#include "mace/ops/space_to_depth.h" + +namespace mace { +namespace ops { + +void Register_SpaceToDepth(OperatorRegistry *op_registry) { + REGISTER_OPERATOR(op_registry, OpKeyBuilder("SpaceToDepth") + .Device(DeviceType::CPU) + .TypeConstraint("T") + .Build(), + SpaceToDepthOp); + + REGISTER_OPERATOR(op_registry, OpKeyBuilder("SpaceToDepth") + .Device(DeviceType::OPENCL) + .TypeConstraint("T") + .Build(), + SpaceToDepthOp); + + REGISTER_OPERATOR(op_registry, OpKeyBuilder("SpaceToDepth") + .Device(DeviceType::OPENCL) + .TypeConstraint("T") + .Build(), + SpaceToDepthOp); +} + +} // namespace ops +} // namespace mace diff --git a/mace/ops/space_to_depth.h b/mace/ops/space_to_depth.h new file mode 100644 index 0000000000000000000000000000000000000000..517d8ccc8f8938214aefc50cfea091133d455466 --- /dev/null +++ b/mace/ops/space_to_depth.h @@ -0,0 +1,55 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#ifndef MACE_OPS_SPACE_TO_DEPTH_H_ +#define MACE_OPS_SPACE_TO_DEPTH_H_ + +#include +#include + +#include "mace/core/operator.h" +#include "mace/kernels/depth_to_space.h" + +namespace mace { +namespace ops { + +template +class SpaceToDepthOp : public Operator { + public: + SpaceToDepthOp(const OperatorDef &op_def, Workspace *ws) + : Operator(op_def, ws), + functor_(OperatorBase::GetSingleArgument("block_size", 1), false) { + } + + bool Run(StatsFuture *future) override { + const Tensor *input = this->Input(INPUT); + Tensor *output = this->Output(OUTPUT); + MACE_CHECK(input->dim_size() == 4, "input dim should be 4"); + const int block_size = + OperatorBase::GetSingleArgument("block_size", 1); + const int input_height = input->dim(1); + const int input_width = input->dim(2); + const int input_depth = input->dim(3); + MACE_CHECK((input_depth % 4) == 0, + "input channel should be dividable by 4"); + MACE_CHECK( + (input_width%block_size == 0)&&(input_height%block_size == 0), + "input width and height should be dividable by block_size", + input->dim(3)); + functor_(input, output, future); + return true; + } + + protected: + OP_INPUT_TAGS(INPUT); + OP_OUTPUT_TAGS(OUTPUT); + + private: + kernels::DepthToSpaceOpFunctor functor_; +}; + +} // namespace ops +} // namespace mace + +#endif // MACE_OPS_SPACE_TO_DEPTH_H_ diff --git a/mace/ops/space_to_depth_benchmark.cc b/mace/ops/space_to_depth_benchmark.cc new file mode 100644 index 0000000000000000000000000000000000000000..c97028c4c85cd792769f4fd69fc19ffe9a1280c0 --- /dev/null +++ b/mace/ops/space_to_depth_benchmark.cc @@ -0,0 +1,74 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. 
+// + +#include "mace/core/operator.h" +#include "mace/core/testing/test_benchmark.h" +#include "mace/ops/ops_test_util.h" + +namespace mace { +namespace ops { +namespace test { + +template +static void SpaceToDepth( + int iters, int batch, int channels, int height, int width, int block_size) { + mace::testing::StopTiming(); + + OpsTestNet net; + + // Add input data + net.AddRandomInput("Input", {batch, height, width, channels}); + + if (D == DeviceType::OPENCL) { + BufferToImage(&net, "Input", "InputImage", + kernels::BufferType::IN_OUT_CHANNEL); + + OpDefBuilder("SpaceToDepth", "SpaceToDepthBM") + .Input("InputImage") + .Output("Output") + .AddIntArg("block_size", block_size) + .Finalize(net.NewOperatorDef()); + } else { + OpDefBuilder("SpaceToDepth", "SpaceToDepthBM") + .Input("Input") + .Output("Output") + .Finalize(net.NewOperatorDef()); + } + + // Warm-up + for (int i = 0; i < 5; ++i) { + net.RunOp(D); + } + net.Sync(); + + mace::testing::StartTiming(); + while (iters--) { + net.RunOp(D); + } + net.Sync(); +} + +#define BM_SPACE_TO_DEPTH_MACRO(N, C, H, W, G, TYPE, DEVICE) \ + static void \ + BM_SPACE_TO_DEPTH_##N##_##C##_##H##_##W##_##G##_##TYPE##_##DEVICE( \ + int iters) { \ + const int64_t tot = static_cast(iters) * N * C * H * W; \ + mace::testing::MaccProcessed(tot); \ + mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \ + SpaceToDepth(iters, N, C, H, W, G); \ + } \ + BENCHMARK(BM_SPACE_TO_DEPTH_##N##_##C##_##H##_##W##_##G##_##TYPE##_##DEVICE) + +#define BM_SPACE_TO_DEPTH(N, C, H, W, G) \ + BM_SPACE_TO_DEPTH_MACRO(N, C, H, W, G, float, CPU); \ + BM_SPACE_TO_DEPTH_MACRO(N, C, H, W, G, float, OPENCL); \ + BM_SPACE_TO_DEPTH_MACRO(N, C, H, W, G, half, OPENCL); + +BM_SPACE_TO_DEPTH(1, 64, 64, 64, 4); +BM_SPACE_TO_DEPTH(1, 64, 128, 128, 4); +BM_SPACE_TO_DEPTH(1, 64, 256, 256, 4); + +} // namespace test +} // namespace ops +} // namespace mace diff --git a/mace/proto/caffe.proto b/mace/proto/caffe.proto index 22764abc33fda32026bf436b685d79aa18ade460..cec617b99b8a5cde9e93e2bb14be0cab21794908 100644 --- a/mace/proto/caffe.proto +++ b/mace/proto/caffe.proto @@ -404,7 +404,9 @@ message LayerParameter { optional ParameterParameter parameter_param = 145; optional PoolingParameter pooling_param = 121; optional PowerParameter power_param = 122; + optional ProposalParameter proposal_param = 8266713; optional PReLUParameter prelu_param = 131; + optional PSROIAlignParameter psroi_align_param = 1490; optional PythonParameter python_param = 130; optional RecurrentParameter recurrent_param = 146; optional ReductionParameter reduction_param = 136; @@ -944,6 +946,19 @@ message PowerParameter { optional float shift = 3 [default = 0.0]; } +// Message that stores parameters used by ProposalLayer +message ProposalParameter { + optional uint32 feat_stride = 1 [default = 16]; + repeated uint32 scales = 2; + repeated float ratios = 3; +} + +message PSROIAlignParameter { + required float spatial_scale = 1; + required int32 output_dim = 2; // output channel number + required int32 group_size = 3; // number of groups to encode position-sensitive score maps +} + message PythonParameter { optional string module = 1; optional string layer = 2; diff --git a/mace/python/tools/caffe_converter_lib.py b/mace/python/tools/caffe_converter_lib.py index 5db1be84f8fd5ba79962f498de31551645477fb8..7c7cd9abd71cb8b4720f782ffc71835033c3e97c 100644 --- a/mace/python/tools/caffe_converter_lib.py +++ b/mace/python/tools/caffe_converter_lib.py @@ -784,21 +784,89 @@ class CaffeConverter(object): self.net_def.op.extend([op_def]) 
self.resolved_ops.add(op.name) + def convert_reshape(self, op): + op_def = self.CommonConvert(op, 'ReOrganize') + input_shape = op.parents[0].output_shape_map[op.layer.bottom[0]] + output_shape = input_shape + shape_param = np.asarray(op.layer.reshape_param.shape.dim)[[0, 3, 2, 1]] + print shape_param + for i in range(len(shape_param)): + if shape_param[i] != 0: + output_shape[i] = shape_param[i] + shape_arg = op_def.arg.add() + shape_arg.name = 'shape' + shape_arg.ints.extend(output_shape) + op.output_shape_map[op.layer.top[0]] = output_shape + self.add_output_shape(op_def, output_shape) + op_def.output.extend([op.name + ':0']) + self.net_def.op.extend([op_def]) + self.resolved_ops.add(op.name) + + def convert_proposal_op(self, op): + assert self.device == 'cpu' + op_def = self.CommonConvert(op, op.type) + if op.layer.HasField('proposal_param'): + proposal_param = op.layer.proposal_param + feat_stride_arg = op_def.arg.add() + feat_stride_arg.name = 'feat_stride' + feat_stride_arg.i = proposal_param.feat_stride + scales_arg = op_def.arg.add() + scales_arg.name = 'scales' + scales_arg.ints.extend(list(proposal_param.scales)) + ratios_arg = op_def.arg.add() + ratios_arg.name = 'ratios' + ratios_arg.floats.extend(list(proposal_param.ratios)) + output_shape = op.parents[0].output_shape_map[op.layer.bottom[0]] + op.output_shape_map[op.layer.top[0]] = output_shape + self.add_output_shape(op_def, output_shape) + op_def.output.extend([op.name + ':0']) + self.net_def.op.extend([op_def]) + self.resolved_ops.add(op.name) + + def convert_psroi_align(self, op): + assert self.device == 'cpu' + op_def = self.CommonConvert(op, op.type) + if op.layer.HasField('psroi_align_param'): + psroi_align_param = op.layer.psroi_align_param + spatial_scale_arg = op_def.arg.add() + spatial_scale_arg.name = 'spatial_scale' + spatial_scale_arg.f = psroi_align_param.spatial_scale + output_dim_arg = op_def.arg.add() + output_dim_arg.name = 'output_dim' + output_dim_arg.i = psroi_align_param.output_dim + group_size_arg = op_def.arg.add() + group_size_arg.name = 'group_size' + group_size_arg.i = psroi_align_param.group_size + output_shape = op.parents[0].output_shape_map[op.layer.bottom[0]] + op.output_shape_map[op.layer.top[0]] = output_shape + self.add_output_shape(op_def, output_shape) + op_def.output.extend([op.name + ':0']) + self.net_def.op.extend([op_def]) + self.resolved_ops.add(op.name) + def replace_in_out_name(self, input_names, output_names, is_single): in_names = set([input_name + ":0" for input_name in input_names]) out_names = set([output_name + ":0" for output_name in output_names]) if is_single: for op in self.net_def.op: - if len(op.input) > 0 and op.input[0] in in_names: - op.input[0] = MACE_INPUT_NODE_NAME + ':0' - if len(op.output) > 0 and op.output[0] in out_names: - op.output[0] = MACE_OUTPUT_NODE_NAME + ':0' + for i in range(len(op.input)): + if op.input[i] in in_names: + op.input[i] = MACE_INPUT_NODE_NAME + ':0' + for i in range(len(op.output)): + if op.output[i] in out_names: + op.output[i] = MACE_OUTPUT_NODE_NAME + ':0' else: for op in self.net_def.op: - if len(op.input) > 0 and op.input[0] in in_names: - op.input[0] = MACE_INPUT_NODE_NAME + '_' + op.input[0] - if len(op.output) > 0 and op.output[0] in out_names: - op.output[0] = MACE_OUTPUT_NODE_NAME + '_' + op.output[0] + for i in range(len(op.input)): + if op.input[i] in in_names: + op.input[i] = MACE_INPUT_NODE_NAME + '_' + op.input[i] + if op.input[i] in out_names: + op.input[i] = MACE_OUTPUT_NODE_NAME + '_' + op.input[i] + for i in 
range(len(op.output)): + if op.output[i] in in_names: + op.output[i] = MACE_INPUT_NODE_NAME + '_' + op.output[i] + if op.output[i] in out_names: + op.output[i] = MACE_OUTPUT_NODE_NAME + '_' + op.output[i] def add_input_op_shape(self, input_nodes, input_shapes): assert len(input_nodes) == len(input_shapes) @@ -843,10 +911,16 @@ class CaffeConverter(object): self.convert_concat(op) elif op.type == 'Eltwise': self.convert_eltwise(op) - elif op.type in ['Softmax']: - self.convert_normal_op(op) elif op.type == 'Slice': self.convert_slice(op) + elif op.type == 'Reshape': + self.convert_reshape(op) + elif op.type == 'Proposal': + self.convert_proposal_op(op) + elif op.type == 'PSROIAlign': + self.convert_psroi_align(op) + elif op.type in ['Softmax']: + self.convert_normal_op(op) else: raise Exception('Unknown Op: %s, type: %s' % (op.name, op.type)) diff --git a/tools/env.sh b/tools/env.sh index f48787a8956ac79349379207b054b0e7c4723e5f..e61180e1dd32a0fca24c886e9aab2cf2a5542c53 100644 --- a/tools/env.sh +++ b/tools/env.sh @@ -2,8 +2,6 @@ LIBMACE_TAG=`git describe --abbrev=0 --tags` MACE_SOURCE_DIR=`/bin/pwd` -INPUT_FILE_NAME="model_input" -OUTPUT_FILE_NAME="model_out" PHONE_DATA_DIR="/data/local/tmp/mace_run" KERNEL_DIR="${PHONE_DATA_DIR}/cl/" CODEGEN_DIR=${MACE_SOURCE_DIR}/mace/codegen diff --git a/tools/mace_tools.py b/tools/mace_tools.py index 5afe2ee168dd9e1eb69ced19a5653d64204d5d34..c9a22f6472e33f8b8245cee9da5796c32d5d5e1d 100644 --- a/tools/mace_tools.py +++ b/tools/mace_tools.py @@ -14,6 +14,7 @@ import subprocess import sys import urllib import yaml +import re import adb_tools @@ -64,13 +65,37 @@ def clear_env(target_soc): command = "bash tools/clear_env.sh {}".format(target_soc) run_command(command) +def input_file_name(input_name): + return os.environ['INPUT_FILE_NAME'] + '_' + \ + re.sub('[^0-9a-zA-Z]+', '_', input_name) -def generate_random_input(target_soc, model_output_dir): +def generate_random_input(target_soc, model_output_dir, + input_names, input_files): generate_data_or_not = True command = "bash tools/validate_tools.sh {} {} {}".format( target_soc, model_output_dir, int(generate_data_or_not)) run_command(command) + input_name_list = [] + input_file_list = [] + if isinstance(input_names, list): + input_name_list.extend(input_names) + else: + input_name_list.append(input_names) + if isinstance(input_files, list): + input_file_list.extend(input_files) + else: + input_file_list.append(input_files) + assert len(input_file_list) == len(input_name_list) + for i in range(len(input_file_list)): + if input_file_list[i] is not None: + dst_input_file = model_output_dir + '/' + input_file_name(input_name_list[i]) + if input_file_list[i].startswith("http://") or \ + input_file_list[i].startswith("https://"): + urllib.urlretrieve(input_file_list[i], dst_input_file) + else: + print 'Copy input data:', dst_input_file + shutil.copy(input_file_list[i], dst_input_file) def generate_model_code(): command = "bash tools/generate_model_code.sh" @@ -215,6 +240,13 @@ def parse_args(): help="SoCs to build, comma seperated list (getprop ro.board.platform)") return parser.parse_known_args() +def set_environment(configs): + os.environ["EMBED_MODEL_DATA"] = str(configs["embed_model_data"]) + os.environ["VLOG_LEVEL"] = str(configs["vlog_level"]) + os.environ["PROJECT_NAME"] = os.path.splitext(os.path.basename( + FLAGS.config))[0] + os.environ['INPUT_FILE_NAME'] = "model_input" + os.environ['OUTPUT_FILE_NAME'] = "model_out" def main(unused_args): configs = parse_model_configs() @@ -223,10 +255,7 @@ def 
main(unused_args): FLAGS.round = 1 FLAGS.restart_round = 1 - os.environ["EMBED_MODEL_DATA"] = str(configs["embed_model_data"]) - os.environ["VLOG_LEVEL"] = str(configs["vlog_level"]) - os.environ["PROJECT_NAME"] = os.path.splitext(os.path.basename( - FLAGS.config))[0] + set_environment(configs) if FLAGS.mode == "build" or FLAGS.mode == "all": # Remove previous output dirs @@ -266,6 +295,7 @@ def main(unused_args): skip_validation = configs["models"][model_name].get( "skip_validation", 0) model_config = configs["models"][model_name] + input_file_list = model_config.get("input_files", []) for key in model_config: if key in ['input_nodes', 'output_nodes'] and isinstance( model_config[key], list): @@ -310,7 +340,8 @@ def main(unused_args): if FLAGS.mode == "build" or FLAGS.mode == "run" or FLAGS.mode == "validate"\ or FLAGS.mode == "benchmark" or FLAGS.mode == "all": - generate_random_input(target_soc, model_output_dir) + generate_random_input(target_soc, model_output_dir, + model_config['input_nodes'], input_file_list) if FLAGS.mode == "build" or FLAGS.mode == "all": generate_model_code() @@ -336,7 +367,7 @@ def main(unused_args): if FLAGS.mode == "throughput_test": merged_lib_file = FLAGS.output_dir + "/%s/%s/libmace_%s.%s.a" % \ (os.environ["PROJECT_NAME"], target_abi, os.environ["PROJECT_NAME"], target_soc) - generate_random_input(target_soc, FLAGS.output_dir) + generate_random_input(target_soc, FLAGS.output_dir, [], []) for model_name in configs["models"]: runtime = configs["models"][model_name]["runtime"] os.environ["%s_MODEL_TAG" % runtime.upper()] = model_name diff --git a/tools/validate.py b/tools/validate.py index 4aaceacdc3fe8c5b52108e40e850c8d737bad208..d46284dcfc01067cbd2641877592c107cce8f460 100644 --- a/tools/validate.py +++ b/tools/validate.py @@ -97,14 +97,17 @@ def validate_caffe_model(input_names, input_shapes, output_names, output_shapes) input_value = load_data(FLAGS.input_file + "_" + input_names[i]) input_value = input_value.reshape(input_shapes[i]).transpose((0, 3, 1, 2)) input_blob_name = input_names[i] - if input_names[i] in net.top_names: - input_blob_name = net.top_names[input_names[i]][0] + try: + if input_names[i] in net.top_names: + input_blob_name = net.top_names[input_names[i]][0] + except ValueError: + pass net.blobs[input_blob_name].data[0] = input_value net.forward() for i in range(len(output_names)): - value = net.blobs[net.top_names[output_names[i]][0]].data[0] + value = net.blobs[net.top_names[output_names[i]][0]].data out_shape = output_shapes[i] out_shape[1], out_shape[2], out_shape[3] = out_shape[3], out_shape[1], out_shape[2] value = value.reshape(out_shape).transpose((0, 2, 3, 1))
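
Reviewer note (not part of the patch): the NHWC index mapping implemented by the new DepthToSpace CPU functor, and mirrored by the OpenCL kernels, can be spot-checked with the short standalone C++ sketch below. It is an illustration only; the shapes and the checked values come from the 1x1x2x16 -> 1x2x4x4 case in mace/ops/depth_to_space_test.cc, and the variable names are arbitrary.

#include <cassert>
#include <cstdio>
#include <vector>

int main() {
  const int B = 2;  // block size
  const int batch = 1, in_h = 1, in_w = 2, in_d = 16;
  const int out_h = in_h * B, out_w = in_w * B, out_d = in_d / (B * B);

  // Input holds 0..31, matching the unit test data layout (NHWC).
  std::vector<float> input(batch * in_h * in_w * in_d);
  for (size_t i = 0; i < input.size(); ++i) input[i] = static_cast<float>(i);

  std::vector<float> output(input.size());
  for (int b = 0; b < batch; ++b)
    for (int h = 0; h < out_h; ++h)
      for (int w = 0; w < out_w; ++w)
        for (int d = 0; d < out_d; ++d) {
          // Same mapping as the d2s branch of DepthToSpaceOpFunctor.
          const int src_h = h / B, src_w = w / B;
          const int src_d = d + ((h % B) * B + (w % B)) * out_d;
          const int o = ((b * out_h + h) * out_w + w) * out_d + d;
          const int i = ((b * in_h + src_h) * in_w + src_w) * in_d + src_d;
          output[o] = input[i];
        }

  // First output row starts 0..7 then 16..23, as asserted by the unit test.
  assert(output[0] == 0.f && output[8] == 16.f);
  std::printf("depth_to_space mapping check passed\n");
  return 0;
}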