diff --git a/mace/core/mace.cc b/mace/core/mace.cc
index 5248318164b6ec34dff05a12c3ec54bd87c4c916..557ef20fe16abbdb13b0ffbecc6ca2cdca9a52e0 100644
--- a/mace/core/mace.cc
+++ b/mace/core/mace.cc
@@ -153,7 +153,9 @@ void OperatorDef::CopyFrom(const OperatorDef &from) {
   output_type_.resize(from_data_type.size());
   std::copy(from_data_type.begin(), from_data_type.end(),
             output_type_.begin());
-  mem_id_ = from.mem_id();
+  auto mem_ids = from.mem_id();
+  mem_id_.resize(mem_ids.size());
+  std::copy(mem_ids.begin(), mem_ids.end(), mem_id_.begin());

   // nnlib
   node_id_ = from.node_id();
@@ -186,13 +188,11 @@ void OperatorDef::set_type(const std::string &type_) {
 }
 bool OperatorDef::has_type() const { return (has_bits_ & 0x00000002u) != 0; }
 void OperatorDef::set_has_type() { has_bits_ |= 0x00000002u; }
-int OperatorDef::mem_id() const { return mem_id_; }
-void OperatorDef::set_mem_id(const int mem_id) {
-  set_has_mem_id();
-  mem_id_ = mem_id;
+const std::vector<int> &OperatorDef::mem_id() const { return mem_id_; }
+void OperatorDef::set_mem_id(const std::vector<int> &value) {
+  mem_id_.resize(value.size());
+  std::copy(value.begin(), value.end(), mem_id_.begin());
 }
-bool OperatorDef::has_mem_id() const { return (has_bits_ & 0x00000004u) != 0; }
-void OperatorDef::set_has_mem_id() { has_bits_ |= 0x00000004u; }
 uint32_t OperatorDef::node_id() const { return node_id_; }
 void OperatorDef::set_node_id(uint32_t node_id) { node_id_ = node_id; }
 uint32_t OperatorDef::op_id() const { return op_id_; }
diff --git a/mace/core/operator.cc b/mace/core/operator.cc
index 826f424bfde34a85657bfb55f49b46ce9b2a0822..c670d9aa729dc575a204eacf3789fd56675df4a9 100644
--- a/mace/core/operator.cc
+++ b/mace/core/operator.cc
@@ -83,6 +83,7 @@ extern void Register_WinogradInverseTransform(OperatorRegistry *op_registry);
 extern void Register_Reshape(OperatorRegistry *op_registry);
 extern void Register_Eltwise(OperatorRegistry *op_registry);
 extern void Register_FullyConnected(OperatorRegistry *op_registry);
+extern void Register_Slice(OperatorRegistry *op_registry);

 OperatorRegistry::OperatorRegistry() {
   Register_Activation(this);
@@ -109,6 +110,7 @@ OperatorRegistry::OperatorRegistry() {
   Register_Reshape(this);
   Register_Eltwise(this);
   Register_FullyConnected(this);
+  Register_Slice(this);
 }

 }  // namespace mace
diff --git a/mace/core/workspace.cc b/mace/core/workspace.cc
index 1cfa1802f07ab1bf6e44d8e86518ef7667575263..2cb5e237ab7c81e72c01df0f4850a9d3c5583389 100644
--- a/mace/core/workspace.cc
+++ b/mace/core/workspace.cc
@@ -116,7 +116,7 @@ void Workspace::CreateImageOutputTensor(const NetDef &net_def) {
   // As DSP may have different data output type for each op,
   // we stick to the same concept.
   for (auto &op : net_def.op()) {
-    if (op.has_mem_id()) {
+    if (!op.mem_id().empty()) {
       const DataType op_dtype = static_cast<DataType>(
           ArgumentHelper::GetSingleArgument<OperatorDef, int>(
               op, "T", static_cast<int>(DT_FLOAT)));
@@ -135,18 +135,20 @@ void Workspace::CreateImageOutputTensor(const NetDef &net_def) {
   }
   VLOG(3) << "Preallocate image to tensors";
   for (auto &op : net_def.op()) {
-    if (op.has_mem_id()) {
-      std::unique_ptr<Tensor> tensor(
-          new Tensor(preallocated_allocator_.GetBuffer(op.mem_id()), dtype));
-      tensor->SetSourceOpName(op.name());
-      VLOG(3)
-          << "Tensor: " << op.name() << "(" << op.type() << ")"
-          << "; Mem: " << op.mem_id() << "; Image shape: "
-          << dynamic_cast<Image *>(tensor->UnderlyingBuffer())->image_shape()[0]
-          << ", "
-          << dynamic_cast<Image *>(tensor->UnderlyingBuffer())
-                 ->image_shape()[1];
-      tensor_map_[op.output(0)] = std::move(tensor);
+    if (!op.mem_id().empty()) {
+      auto mem_ids = op.mem_id();
+      int count = mem_ids.size();
+      for (int i = 0; i < count; ++i) {
+        std::unique_ptr<Tensor> tensor
+            (new Tensor(preallocated_allocator_.GetBuffer(mem_ids[i]), dtype));
+        tensor->SetSourceOpName(op.name());
+        VLOG(3) << "Tensor: " << op.name() << "(" << op.type() << ")" << "; Mem: "
+                << mem_ids[i] << "; Image shape: "
+                << dynamic_cast<Image *>(tensor->UnderlyingBuffer())->image_shape()[0]
+                << ", "
+                << dynamic_cast<Image *>(tensor->UnderlyingBuffer())->image_shape()[1];
+        tensor_map_[op.output(i)] = std::move(tensor);
+      }
     }
   }
 }
diff --git a/mace/kernels/opencl/cl/slice.cl b/mace/kernels/opencl/cl/slice.cl
new file mode 100644
index 0000000000000000000000000000000000000000..d8d45bcbcfa4fd6416ab6ea417841e379082af50
--- /dev/null
+++ b/mace/kernels/opencl/cl/slice.cl
@@ -0,0 +1,15 @@
+#include <common.h>
+
+__kernel void slice(__read_only image2d_t input,
+                    __private const int chan_blk_offset,
+                    __write_only image2d_t output) {
+  const int chan_blk_idx = get_global_id(0);
+  const int width_idx = get_global_id(1);
+  const int width = get_global_size(1);
+  const int hb_idx = get_global_id(2);
+  DATA_TYPE4 data = READ_IMAGET(input, SAMPLER,
+                                (int2)(mad24(chan_blk_idx + chan_blk_offset,
+                                             width, width_idx), hb_idx));
+  WRITE_IMAGET(output,
+               (int2)(mad24(chan_blk_idx, width, width_idx), hb_idx), data);
+}
diff --git a/mace/kernels/opencl/concat.cc b/mace/kernels/opencl/concat.cc
index 9cd508bdfde411d2a55c97712d00c2066b0f82c6..119ec7cd61a99f915ff7ec29443839ea2923d3a4 100644
--- a/mace/kernels/opencl/concat.cc
+++ b/mace/kernels/opencl/concat.cc
@@ -72,8 +72,6 @@ static void ConcatN(cl::Kernel *kernel,
   const index_t width = output->dim(2);
   const index_t channel = output->dim(3);

-  const int channel_blk = RoundUpDiv4(channel);
-
   if (kernel->get() == nullptr) {
     auto runtime = OpenCLRuntime::Global();
     std::set<std::string> built_options;
diff --git a/mace/kernels/opencl/slice.cc b/mace/kernels/opencl/slice.cc
new file mode 100644
index 0000000000000000000000000000000000000000..63efc555dbf8a743e3fc6881a06e0202480bbd16
--- /dev/null
+++ b/mace/kernels/opencl/slice.cc
@@ -0,0 +1,73 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#include "mace/kernels/slice.h"
+#include "mace/core/runtime/opencl/opencl_runtime.h"
+#include "mace/kernels/opencl/helper.h"
+#include "mace/utils/tuner.h"
+
+namespace mace {
+namespace kernels {
+
+template <typename T>
+void SliceFunctor<DeviceType::OPENCL, T>::operator()(
+    const Tensor *input,
+    const std::vector<Tensor *> &output_list,
+    StatsFuture *future) {
+  const index_t input_channels = input->dim(3);
+  const size_t outputs_count = output_list.size();
+  const index_t output_channels = input_channels / outputs_count;
+  MACE_CHECK(output_channels % 4 == 0)
+      << "output channels of slice op must be divisible by 4";
+  std::vector<index_t> output_shape({input->dim(0), input->dim(1),
+                                     input->dim(2), output_channels});
+
+  std::vector<size_t> image_shape;
+  CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL, image_shape);
+  for (size_t i = 0; i < outputs_count; ++i) {
+    output_list[i]->ResizeImage(output_shape, image_shape);
+  }
+
+  if (kernel_.get() == nullptr) {
+    auto runtime = OpenCLRuntime::Global();
+    std::set<std::string> built_options;
+    std::string kernel_name = MACE_OBFUSCATE_SYMBOL("slice");
+    built_options.emplace("-Dslice=" + kernel_name);
+    built_options.emplace("-DDATA_TYPE=" + DtToCLDt(DataTypeToEnum<T>::value));
+    built_options.emplace("-DCMD_DATA_TYPE="
+                          + DtToCLCMDDt(DataTypeToEnum<T>::value));
+    kernel_ = runtime->BuildKernel("slice", kernel_name, built_options);
+  }
+  const index_t channel_blk = RoundUpDiv4(output_channels);
+
+  const uint32_t gws[3] = {
+      static_cast<uint32_t>(channel_blk),
+      static_cast<uint32_t>(input->dim(2)),
+      static_cast<uint32_t>(input->dim(0) * input->dim(1)),
+  };
+  const std::vector<uint32_t> lws = {8, 16, 8, 1};
+  std::stringstream ss;
+  ss << "slice_opencl_kernel_"
+     << input->dim(0) << "_"
+     << input->dim(1) << "_"
+     << input->dim(2) << "_"
+     << input_channels << "_"
+     << outputs_count;
+  for (int i = 0; i < outputs_count; ++i) {
+    uint32_t idx = 0;
+    kernel_.setArg(idx++, *(input->opencl_image()));
+    kernel_.setArg(idx++, static_cast<int32_t>(channel_blk * i));
+    kernel_.setArg(idx++, *(output_list[i]->opencl_image()));
+
+    TuningOrRun3DKernel(kernel_, ss.str(), gws, lws, future);
+  }
+}
+
+template
+struct SliceFunctor<DeviceType::OPENCL, float>;
+template
+struct SliceFunctor<DeviceType::OPENCL, half>;
+
+}  // namespace kernels
+}  // namespace mace
diff --git a/mace/kernels/slice.h b/mace/kernels/slice.h
new file mode 100644
index 0000000000000000000000000000000000000000..b08ea7ef4fcd1e235375952085e9965c7f897334
--- /dev/null
+++ b/mace/kernels/slice.h
@@ -0,0 +1,70 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#ifndef MACE_KERNELS_SLICE_H_
+#define MACE_KERNELS_SLICE_H_
+
+#include "mace/core/future.h"
+#include "mace/core/runtime/opencl/cl2_header.h"
+#include "mace/core/tensor.h"
+#include "mace/core/types.h"
+#include "mace/public/mace.h"
+
+namespace mace {
+namespace kernels {
+
+template <DeviceType D, typename T>
+struct SliceFunctor {
+
+  void operator()(const Tensor *input,
+                  const std::vector<Tensor *> &output_list,
+                  StatsFuture *future) {
+    const index_t outer_size = input->dim(0) * input->dim(1) * input->dim(2);
+    const index_t input_channels = input->dim(3);
+    const size_t outputs_count = output_list.size();
+    const index_t output_channels = input_channels / outputs_count;
+    std::vector<T *> output_ptrs(output_list.size(), nullptr);
+
+    std::vector<index_t> output_shape({input->dim(0), input->dim(1),
+                                       input->dim(2), output_channels});
+
+    for (size_t i = 0; i < outputs_count; ++i) {
+      output_list[i]->Resize(output_shape);
+      output_ptrs[i] = output_list[i]->mutable_data<T>();
+    }
+    const T *input_ptr = input->data<T>();
+
+#pragma omp parallel for
+    for (int outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
+      int input_idx = outer_idx * input_channels;
+      int output_idx = outer_idx * output_channels;
+      for (size_t i = 0; i < outputs_count; ++i) {
+        if (DataTypeCanUseMemcpy(DataTypeToEnum<T>::v())) {
+          memcpy(output_ptrs[i] + output_idx, input_ptr + input_idx,
+                 output_channels * sizeof(T));
+        } else {
+          for (index_t k = 0; k < output_channels; ++k) {
+            *(output_ptrs[i] + output_idx + k) = *(input_ptr + input_idx + k);
+          }
+        }
+        input_idx += output_channels;
+      }
+    }
+  }
+};
+
+template <typename T>
+struct SliceFunctor<DeviceType::OPENCL, T> {
+
+  void operator()(const Tensor *input,
+                  const std::vector<Tensor *> &output_list,
+                  StatsFuture *future);
+  cl::Kernel kernel_;
+
+};
+
+}  // namespace kernels
+}  // namespace mace
+
+#endif  // MACE_KERNELS_SLICE_H_
diff --git a/mace/ops/concat.cc b/mace/ops/concat.cc
index 71be2fc3eb1100c36a23907aeccbec41c2dba899..361fce51cf0ce7ecfb60da65e5b17791a6c4067d 100644
--- a/mace/ops/concat.cc
+++ b/mace/ops/concat.cc
@@ -12,11 +12,6 @@ void Register_Concat(OperatorRegistry *op_registry) {
                         .TypeConstraint<float>("T")
                         .Build(),
                     ConcatOp<DeviceType::CPU, float>);
-  REGISTER_OPERATOR(op_registry, OpKeyBuilder("Concat")
-                        .Device(DeviceType::CPU)
-                        .TypeConstraint<half>("T")
-                        .Build(),
-                    ConcatOp<DeviceType::CPU, half>);

   REGISTER_OPERATOR(op_registry, OpKeyBuilder("Concat")
                         .Device(DeviceType::OPENCL)
diff --git a/mace/ops/slice.cc b/mace/ops/slice.cc
new file mode 100644
index 0000000000000000000000000000000000000000..6de3da403fca90031c76597c31126f742bf8ba5f
--- /dev/null
+++ b/mace/ops/slice.cc
@@ -0,0 +1,28 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#include "mace/ops/slice.h"
+
+namespace mace {
+
+void Register_Slice(OperatorRegistry *op_registry) {
+  REGISTER_OPERATOR(op_registry, OpKeyBuilder("Slice")
+                        .Device(DeviceType::CPU)
+                        .TypeConstraint<float>("T")
+                        .Build(),
+                    SliceOp<DeviceType::CPU, float>);
+
+  REGISTER_OPERATOR(op_registry, OpKeyBuilder("Slice")
+                        .Device(DeviceType::OPENCL)
+                        .TypeConstraint<float>("T")
+                        .Build(),
+                    SliceOp<DeviceType::OPENCL, float>);
+  REGISTER_OPERATOR(op_registry, OpKeyBuilder("Slice")
+                        .Device(DeviceType::OPENCL)
+                        .TypeConstraint<half>("T")
+                        .Build(),
+                    SliceOp<DeviceType::OPENCL, half>);
+}
+
+}  // namespace mace
diff --git a/mace/ops/slice.h b/mace/ops/slice.h
new file mode 100644
index 0000000000000000000000000000000000000000..41106a6339bc21ff9ce4b176bb02cdb12b1299c8
--- /dev/null
+++ b/mace/ops/slice.h
@@ -0,0 +1,37 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#ifndef MACE_OPS_SLICE_H_
+#define MACE_OPS_SLICE_H_
+
+#include "mace/core/operator.h"
+#include "mace/kernels/slice.h"
+namespace mace {
+
+template <DeviceType D, typename T>
+class SliceOp : public Operator<D, T> {
+ public:
+  SliceOp(const OperatorDef &op_def, Workspace *ws)
+      : Operator<D, T>(op_def, ws) {}
+
+  bool Run(StatsFuture *future) override {
+    MACE_CHECK(this->OutputSize() >= 2) << "There must be at least two outputs for slicing";
+    const Tensor *input = this->Input(INPUT);
+    const std::vector<Tensor *> output_list = this->Outputs();
+    MACE_CHECK((input->dim(3) % this->OutputSize()) == 0) << "Outputs do not split input equally.";
+
+    functor_(input, output_list, future);
+    return true;
+  }
+
+ private:
+  kernels::SliceFunctor<D, T> functor_;
+
+ private:
+  OP_INPUT_TAGS(INPUT);
+};
+
+}  // namespace mace
+
+#endif  // MACE_OPS_SLICE_H_
diff --git a/mace/ops/slice_benchmark.cc b/mace/ops/slice_benchmark.cc
new file mode 100644
index 0000000000000000000000000000000000000000..5273f592c71084434b265affd4ef5760cf6aef37
--- /dev/null
+++ b/mace/ops/slice_benchmark.cc
@@ -0,0 +1,79 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#include "mace/core/operator.h"
+#include "mace/core/testing/test_benchmark.h"
+#include "mace/ops/ops_test_util.h"
+
+namespace mace {
+template <DeviceType D, typename T>
+static void BMSliceHelper(int iters,
+                          const std::vector<index_t> &input_shape,
+                          const index_t num_outputs) {
+  mace::testing::StopTiming();
+
+  // Construct graph
+  OpsTestNet net;
+
+  const index_t input_size = std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies<index_t>());
+  std::vector<float> input_data(input_size);
+  GenerateRandomRealTypeData(input_shape, input_data);
+  net.AddInputFromArray<D, float>("Input", input_shape, input_data);
+
+  if (D == DeviceType::OPENCL) {
+    BufferToImage<D, T>(net, "Input", "InputImage",
+                        kernels::BufferType::IN_OUT_CHANNEL);
+
+    auto builder = OpDefBuilder("Slice", "SliceTest");
+    builder.Input("InputImage");
+    for (int i = 0; i < num_outputs; ++i) {
+      builder = builder.Output(MakeString("OutputImage", i));
+    }
+    builder
+        .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
+        .Finalize(net.NewOperatorDef());
+  } else {
+    auto builder = OpDefBuilder("Slice", "SliceTest");
+    builder.Input("Input");
+    for (int i = 0; i < num_outputs; ++i) {
+      builder = builder.Output(MakeString("Output", i));
+    }
+    builder.Finalize(net.NewOperatorDef());
+  }
+
+  // Warm-up
+  for (int i = 0; i < 2; ++i) {
+    net.RunOp(D);
+    net.Sync();
+  }
+
+  mace::testing::StartTiming();
+  while (iters--) {
+    net.RunOp(D);
+    net.Sync();
+  }
+}
+
+#define BM_SLICE_MACRO(N, H, W, C, NO, TYPE, DEVICE)                                  \
+  static void BM_SLICE_##N##_##H##_##W##_##C##_##NO##_##TYPE##_##DEVICE(int iters) {  \
+    const int64_t tot = static_cast<int64_t>(iters) * N * H * W * C;                  \
+    mace::testing::MaccProcessed(tot);                                                \
+    mace::testing::BytesProcessed(tot *(sizeof(TYPE)));                               \
+    BMSliceHelper<DEVICE, TYPE>(iters, {N, H, W, C}, NO);                             \
+  }                                                                                   \
+  BENCHMARK(BM_SLICE_##N##_##H##_##W##_##C##_##NO##_##TYPE##_##DEVICE)
+
+#define BM_SLICE(N, H, W, C, NO)                 \
+  BM_SLICE_MACRO(N, H, W, C, NO, float, CPU);    \
+  BM_SLICE_MACRO(N, H, W, C, NO, float, OPENCL); \
+  BM_SLICE_MACRO(N, H, W, C, NO, half, OPENCL);
+
+BM_SLICE(1, 32, 32, 32, 2);
+BM_SLICE(1, 32, 32, 128, 2);
+BM_SLICE(1, 32, 32, 256, 2);
+BM_SLICE(1, 128, 128, 32, 2);
+BM_SLICE(1, 128, 128, 128, 2);
+
+
+}  // namespace mace
diff --git a/mace/ops/slice_test.cc b/mace/ops/slice_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..ff65fffa78ff2e8507944e53e100fbc6b31df506
--- /dev/null
+++ b/mace/ops/slice_test.cc
@@ -0,0 +1,99 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#include "mace/ops/slice.h"
+#include "mace/ops/ops_test_util.h"
+#include "gmock/gmock.h"
+
+using namespace mace;
+
+class SliceOpTest : public OpsTestBase {};
+
+template <DeviceType D, typename T>
+void RandomTest(const int num_outputs) {
+  srand(time(nullptr));
+  const index_t output_channels = 4 * (1 + rand() % 10);
+  const index_t input_channels = num_outputs * output_channels;
+  const index_t batch = 3 + (rand() % 10);
+  const index_t height = 13 + (rand() % 10);
+  const index_t width = 17 + (rand() % 10);
+
+  // Construct graph
+  OpsTestNet net;
+
+  std::vector<index_t> input_shape({batch, height, width, input_channels});
+  const index_t input_size = std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies<index_t>());
+  std::vector<float> input_data(input_size);
+  GenerateRandomRealTypeData(input_shape, input_data);
+  net.AddInputFromArray<D, float>("Input", input_shape, input_data);
+
+  if (D == DeviceType::OPENCL) {
+    BufferToImage<D, T>(net, "Input", "InputImage",
+                        kernels::BufferType::IN_OUT_CHANNEL);
+
+    auto builder = OpDefBuilder("Slice", "SliceTest");
+    builder.Input("InputImage");
+    for (int i = 0; i < num_outputs; ++i) {
+      builder = builder.Output(MakeString("OutputImage", i));
+    }
+    builder
+        .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
+        .Finalize(net.NewOperatorDef());
+  } else {
+    auto builder = OpDefBuilder("Slice", "SliceTest");
+    builder.Input("Input");
+    for (int i = 0; i < num_outputs; ++i) {
+      builder = builder.Output(MakeString("Output", i));
+    }
+    builder.Finalize(net.NewOperatorDef());
+
+  }
+
+  // Run
+  net.RunOp(D);
+
+  if (D == DeviceType::OPENCL) {
+    for (int i = 0; i < num_outputs; ++i) {
+      ImageToBuffer<D, float>(net, MakeString("OutputImage", i), MakeString("Output", i),
+                              kernels::BufferType::IN_OUT_CHANNEL);
+    }
+  }
+
+  // Check
+  std::vector<index_t> expected_shape({batch, height, width, output_channels});
+  const index_t outer_size = std::accumulate(expected_shape.begin(), expected_shape.end() - 1,
+                                             1, std::multiplies<index_t>());
+  const float *input_ptr = input_data.data();
+  const float *output_ptr;
+  for (int i = 0; i < num_outputs; ++i) {
+    auto output = net.GetOutput(MakeString("Output", i).c_str());
+    EXPECT_THAT(output->shape(), ::testing::ContainerEq(expected_shape));
+    Tensor::MappingGuard output_mapper(output);
+    output_ptr = output->data<float>();
+    for (int outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
+      const int idx = outer_idx * input_channels + i * output_channels;
+      for (int j = 0; j < output_channels; ++j) {
+        ASSERT_NEAR(*output_ptr++, input_ptr[idx + j], 1e-2) << "with output " << i << " index " << idx + j;
+      }
+    }
+  }
+}
+
+TEST_F(SliceOpTest, CPU) {
+  RandomTest<DeviceType::CPU, float>(2);
+  RandomTest<DeviceType::CPU, float>(4);
+  RandomTest<DeviceType::CPU, float>(11);
+}
+
+TEST_F(SliceOpTest, OPENCLFloat) {
+  RandomTest<DeviceType::OPENCL, float>(2);
+  RandomTest<DeviceType::OPENCL, float>(4);
+  RandomTest<DeviceType::OPENCL, float>(11);
+}
+
+TEST_F(SliceOpTest, OPENCLHalf) {
+  RandomTest<DeviceType::OPENCL, half>(2);
+  RandomTest<DeviceType::OPENCL, half>(4);
+  RandomTest<DeviceType::OPENCL, half>(11);
+}
diff --git a/mace/public/mace.h b/mace/public/mace.h
index 5d4ad299e7e842d07a801c6d074b948150e6f484..3c8fc778a709d228d369e3321ca730dc8f086f94 100644
--- a/mace/public/mace.h
+++ b/mace/public/mace.h
@@ -174,9 +174,8 @@ class OperatorDef {
   const std::string &type() const;
   void set_type(const std::string &type_);
   bool has_type() const;
-  int mem_id() const;
-  void set_mem_id(const int mem_id);
-  bool has_mem_id() const;
+  const std::vector<int> &mem_id() const;
+  void set_mem_id(const std::vector<int> &value);
   uint32_t node_id() const;
   void set_node_id(uint32_t node_id);
   uint32_t op_id() const;
@@ -220,7 +219,7 @@ class OperatorDef {
   std::vector<OutputShape> output_shape_;
   std::vector<DataType> output_type_;
-  int mem_id_;
+  std::vector<int> mem_id_;

   // nnlib
   uint32_t node_id_;
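
For illustration only (not part of the patch): the following standalone sketch mirrors the channel-wise split that the new Slice op performs on an NHWC tensor, splitting the channel dimension evenly across the outputs. The SliceChannels helper and the example shapes are hypothetical names invented for this note, not MACE APIs.

// Standalone sketch of the Slice semantics, assuming an even channel split.
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iostream>
#include <vector>

// Splits a flattened NHWC tensor evenly along the channel dimension into
// num_outputs tensors of channels / num_outputs channels each, the same
// layout transformation the CPU SliceFunctor applies per outer position.
std::vector<std::vector<float>> SliceChannels(const std::vector<float> &input,
                                              size_t outer_size,  // batch * height * width
                                              size_t channels,
                                              size_t num_outputs) {
  assert(channels % num_outputs == 0);
  const size_t output_channels = channels / num_outputs;
  std::vector<std::vector<float>> outputs(
      num_outputs, std::vector<float>(outer_size * output_channels));
  for (size_t outer = 0; outer < outer_size; ++outer) {
    for (size_t i = 0; i < num_outputs; ++i) {
      const float *src = input.data() + outer * channels + i * output_channels;
      std::copy(src, src + output_channels,
                outputs[i].data() + outer * output_channels);
    }
  }
  return outputs;
}

int main() {
  // Two spatial positions with 4 channels each, split into 2 outputs of 2 channels.
  const std::vector<float> input = {0, 1, 2, 3, 4, 5, 6, 7};
  const auto outputs = SliceChannels(input, /*outer_size=*/2, /*channels=*/4,
                                     /*num_outputs=*/2);
  for (const auto &out : outputs) {
    for (float v : out) std::cout << v << ' ';
    std::cout << '\n';  // prints "0 1 4 5" then "2 3 6 7"
  }
  return 0;
}

The OpenCL path performs the same split one 4-channel block at a time in image space, which is why the kernel-side functor checks that each output's channel count is divisible by 4.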