diff --git a/mace/core/operator.cc b/mace/core/operator.cc
index 029c99f1667d6ddac71b778e76672710674d0787..c56ad46a4246fdf1eeac31ae23d939963ea0d03e 100644
--- a/mace/core/operator.cc
+++ b/mace/core/operator.cc
@@ -83,6 +83,7 @@ extern void Register_FusedConv2D(OperatorRegistry *op_registry);
 extern void Register_GlobalAvgPooling(OperatorRegistry *op_registry);
 extern void Register_ImageToBuffer(OperatorRegistry *op_registry);
 extern void Register_MatMul(OperatorRegistry *op_registry);
+extern void Register_Pad(OperatorRegistry *op_registry);  // defined in mace/ops/pad.cc
 extern void Register_Pooling(OperatorRegistry *op_registry);
 extern void Register_Proposal(OperatorRegistry *op_registry);
 extern void Register_PSROIAlign(OperatorRegistry *op_registry);
@@ -119,6 +120,7 @@ OperatorRegistry::OperatorRegistry() {
   ops::Register_GlobalAvgPooling(this);
   ops::Register_ImageToBuffer(this);
   ops::Register_MatMul(this);
+  ops::Register_Pad(this);  // adds the "Pad" operator keys to this registry
   ops::Register_Pooling(this);
   ops::Register_Proposal(this);
   ops::Register_PSROIAlign(this);
diff --git a/mace/kernels/opencl/cl/pad.cl b/mace/kernels/opencl/cl/pad.cl
new file mode 100644
index 0000000000000000000000000000000000000000..1ccaa29af28df2bbf329f6716511d736d2624964
--- /dev/null
+++ b/mace/kernels/opencl/cl/pad.cl
@@ -0,0 +1,47 @@
+#include  // NOTE(review): include target is missing in this copy of the patch - restore from upstream
+
+__kernel void pad(KERNEL_ERROR_PARAMS  // writes one output texel per work-item: input data inside the window, constant_value elsewhere
+                  GLOBAL_WORK_GROUP_SIZE_DIM3
+                  __read_only image2d_t input,
+                  __write_only image2d_t output,
+                  __private const float constant_value,  // value written to padded positions
+                  __private const int input_height,
+                  __private const int input_width,
+                  __private const int output_height,  // used to split dim 2 into (batch, row)
+                  __private const int height_padding,  // first output row that maps to input row 0
+                  __private const int width_padding) {  // first output column that maps to input column 0
+  const int chan_blk_idx = get_global_id(0);
+  const int width_idx = get_global_id(1);
+  const int hb_idx = get_global_id(2);  // combined index: batch_idx * output_height + height_idx
+  const int batch_idx = hb_idx / output_height;
+  const int height_idx = hb_idx % output_height;
+  const int input_padded_height = input_height + height_padding;  // exclusive end of input rows in output coordinates
+  const int input_padded_width = input_width + width_padding;  // exclusive end of input columns in output coordinates
+
+#ifndef NON_UNIFORM_WORK_GROUP
+  if (chan_blk_idx >= global_size_dim0 || width_idx >= global_size_dim1
+      || hb_idx >= global_size_dim2) {
+    return;  // work-item lies outside the requested global size
+  }
+  const int width = global_size_dim1;  // output width comes from the explicit size argument
+#else
+  const int width = get_global_size(1);  // output width equals the launched size of dim 1
+#endif
+
+
+  DATA_TYPE4 data = constant_value;  // default to the pad value; replaced below when inside the input window
+  if ((height_padding <= height_idx && height_idx < input_padded_height) &&
+      (width_padding <= width_idx && width_idx < input_padded_width)) {
+    const int in_hb_idx = mad24(batch_idx, input_height,  // input row: batch_idx * input_height + (height_idx - height_padding)
+                                height_idx - height_padding);
+    data = READ_IMAGET(input,
+                       SAMPLER,
+                       (int2)(mad24(chan_blk_idx, input_width,  // input column: chan_blk_idx * input_width + (width_idx - width_padding)
+                                    width_idx - width_padding),
+                              in_hb_idx));
+  }
+
+  const int pos = mad24(chan_blk_idx, width, width_idx);  // output column: chan_blk_idx * width + width_idx
+
+  WRITE_IMAGET(output, (int2)(pos, hb_idx), data);
+}
diff --git a/mace/kernels/opencl/pad.cc b/mace/kernels/opencl/pad.cc
new file mode 100644
index 0000000000000000000000000000000000000000..09d40c8441fa6803c5e1323c2d68ab6664f95d58
--- /dev/null
+++ b/mace/kernels/opencl/pad.cc
@@ -0,0 +1,113 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#include "mace/kernels/pad.h"
+#include "mace/core/runtime/opencl/opencl_runtime.h"
+#include "mace/kernels/opencl/helper.h"
+#include "mace/utils/tuner.h"
+
+namespace mace {
+namespace kernels {
+
+template  // NOTE(review): template parameter/argument lists are missing throughout this copy of the patch
+void PadFunctor::operator()(  // OpenCL path: sizes the output image, builds the "pad" kernel on first use, binds args and runs it
+    const Tensor *input,
+    Tensor *output,
+    StatsFuture *future) {
+  MACE_CHECK(this->paddings_.size() == (input->dim_size() * 2));  // two padding entries per input dimension
+  MACE_CHECK((this->paddings_[0] == 0) && (this->paddings_[1] == 0)
+             && (this->paddings_[6] == 0) && (this->paddings_[7] == 0))
+    << "Mace only support height/width dimension now";  // entries for dims 0 and 3 must be zero on this path
+  auto input_shape = input->shape();
+  std::vector
+    output_shape = {input_shape[0] + this->paddings_[0] + this->paddings_[1],  // each dim grows by its pair of padding entries
+                    input_shape[1] + this->paddings_[2] + this->paddings_[3],
+                    input_shape[2] + this->paddings_[4] + this->paddings_[5],
+                    input_shape[3] + this->paddings_[6] + this->paddings_[7]};
+
+  std::vector image_shape;
+  CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL, &image_shape);
+  output->ResizeImage(output_shape, image_shape);
+
+  const index_t batch = output->dim(0);
+  const index_t height = output->dim(1);
+  const index_t width = output->dim(2);
+  const index_t channels = output->dim(3);
+
+  const index_t channel_blocks = RoundUpDiv4(channels);  // presumably ceil(channels / 4) - confirm in helper.h
+
+  auto runtime = OpenCLRuntime::Global();
+
+  if (kernel_.get() == nullptr) {  // first call: build the kernel once and keep it in kernel_
+    std::set built_options;
+    std::string kernel_name = MACE_OBFUSCATE_SYMBOL("pad");
+    built_options.emplace("-Dpad=" + kernel_name);  // renames the kernel entry point to kernel_name at compile time
+    auto dt = DataTypeToEnum::value;
+    built_options.emplace("-DDATA_TYPE=" + DtToCLDt(dt));
+    built_options.emplace("-DCMD_DATA_TYPE=" + DtToCLCMDDt(dt));
+    if (runtime->IsOutOfRangeCheckEnabled()) {
+      built_options.emplace("-DOUT_OF_RANGE_CHECK");
+      kernel_error_ = std::move(std::unique_ptr(
+          new Buffer(GetDeviceAllocator(DeviceType::OPENCL), 1)));  // 1-element buffer holding the kernel error code
+      kernel_error_->Map(nullptr);
+      *(kernel_error_->mutable_data()) = 0;  // start with "no error"
+      kernel_error_->UnMap();
+    }
+    if (runtime->IsNonUniformWorkgroupsSupported()) {
+      built_options.emplace("-DNON_UNIFORM_WORK_GROUP");  // selects the kernel branch without explicit global-size arguments
+    }
+    kernel_ = runtime->BuildKernel("pad", kernel_name, built_options);
+
+    kwg_size_ =
+        static_cast(runtime->GetKernelMaxWorkGroupSize(kernel_));
+  }
+
+  const uint32_t gws[3] = {static_cast(channel_blocks),  // global work size: channel blocks x width x (height * batch)
+                           static_cast(width),
+                           static_cast(height * batch)};
+
+  if (!IsVecEqual(input_shape_, input->shape())) {  // kernel args are re-bound only when the input shape differs from the cached one
+    int idx = 0;
+    if (runtime->IsOutOfRangeCheckEnabled()) {
+      kernel_.setArg(idx++,
+          *(static_cast(kernel_error_->buffer())));
+    }
+    if (!runtime->IsNonUniformWorkgroupsSupported()) {
+      kernel_.setArg(idx++, gws[0]);  // explicit global sizes, read by the kernel's bounds check
+      kernel_.setArg(idx++, gws[1]);
+      kernel_.setArg(idx++, gws[2]);
+    }
+    kernel_.setArg(idx++, *(input->opencl_image()));
+    kernel_.setArg(idx++, *(output->opencl_image()));
+    kernel_.setArg(idx++, this->constant_value_);
+    kernel_.setArg(idx++, static_cast(input_shape[1]));  // kernel parameter input_height
+    kernel_.setArg(idx++, static_cast(input_shape[2]));  // kernel parameter input_width
+    kernel_.setArg(idx++, static_cast(output_shape[1]));  // kernel parameter output_height
+    kernel_.setArg(idx++, this->paddings_[2]);  // kernel parameter height_padding
+    kernel_.setArg(idx++, this->paddings_[4]);  // kernel parameter width_padding
+
+    input_shape_ = input->shape();  // remember the shape the args were bound for
+  }
+
+  const std::vector lws = {8, kwg_size_ / 64, 8, 1};  // local work size passed to TuningOrRun3DKernel
+  std::string tuning_key =
+      Concat("pad", output->dim(0), output->dim(1), output->dim(2),
+             output->dim(3));  // key is built from the op name and the output dims
+  TuningOrRun3DKernel(kernel_, tuning_key, gws, lws, future);
+
+  if (runtime->IsOutOfRangeCheckEnabled()) {
+    kernel_error_->Map(nullptr);
+    char *kerror_code = kernel_error_->mutable_data();
+    MACE_CHECK(*kerror_code == 0) << "Kernel error code: " << *kerror_code;  // fails if the error byte is non-zero after the run
+    kernel_error_->UnMap();
+  }
+}
+
+template
+struct PadFunctor;  // explicit instantiation; NOTE(review): argument list missing in this copy
+template
+struct PadFunctor;  // explicit instantiation; NOTE(review): argument list missing in this copy
+
+}  // namespace kernels
+}  // namespace mace
diff --git a/mace/kernels/pad.h b/mace/kernels/pad.h
new file mode 100644
index 0000000000000000000000000000000000000000..6fbb1c7663388dcb7dcd9845c07274747f5f0165
--- /dev/null
+++ b/mace/kernels/pad.h
@@ -0,0 +1,90 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+#ifndef MACE_KERNELS_PAD_H_
+#define MACE_KERNELS_PAD_H_
+
+#include  // NOTE(review): the three system include targets are missing in this copy of the patch
+#include
+#include
+
+#include "mace/core/future.h"
+#include "mace/core/runtime/opencl/cl2_header.h"
+#include "mace/core/tensor.h"
+
+namespace mace {
+namespace kernels {
+
+struct PadFunctorBase {  // holds the parameters shared by all Pad functors
+  PadFunctorBase(const std::vector &paddings,
+                 const float constant_value)
+      : paddings_(paddings), constant_value_(constant_value) {}
+
+  std::vector paddings_;  // two entries per input dimension: index 2*d is added before the data, 2*d+1 after
+  float constant_value_;  // value written to padded positions
+};
+
+template  // NOTE(review): template parameter list is missing in this copy of the patch
+struct PadFunctor : public PadFunctorBase {  // generic functor: fills the output with the pad value, then copies the input into place
+  PadFunctor(const std::vector &paddings,
+             const float constant_value)
+      : PadFunctorBase(paddings, constant_value) {}
+
+  void operator()(const Tensor *input,
+                  Tensor *output,
+                  StatsFuture *future) {  // future is not used in this implementation
+    MACE_CHECK(this->paddings_.size() == (input->dim_size() * 2));  // two padding entries per input dimension
+    auto input_shape = input->shape();
+    output->Resize({input_shape[0] + this->paddings_[0] + this->paddings_[1],  // indexes input_shape[0..3], i.e. assumes a 4-D input
+                    input_shape[1] + this->paddings_[2] + this->paddings_[3],
+                    input_shape[2] + this->paddings_[4] + this->paddings_[5],
+                    input_shape[3] + this->paddings_[6] + this->paddings_[7]});
+
+    Tensor::MappingGuard input_guard(input);
+    Tensor::MappingGuard output_guard(output);
+    auto input_ptr = input->data();
+    T *output_ptr = output->mutable_data();
+    std::fill(output_ptr, output_ptr + output->size(), this->constant_value_);  // pre-fill the whole output with the pad value
+
+    const index_t batch = input->dim(0);
+    const index_t height = input->dim(1);
+    const index_t width = input->dim(2);
+    const index_t channel = input->dim(3);
+#pragma omp parallel for collapse(3)
+    for (index_t b = 0; b < batch; ++b) {
+      for (index_t h = 0; h < height; ++h) {
+        for (index_t w = 0; w < width; ++w) {
+          const index_t in_offset = (((b * height + h) * width) + w) * channel;  // first channel of input element (b, h, w)
+          const index_t out_offset = (((b + this->paddings_[0]) * output->dim(1)  // same element shifted by the leading padding of every dim
+              + (h + this->paddings_[2])) * output->dim(2)
+              + (w + this->paddings_[4])) * output->dim(3)
+              + this->paddings_[6];
+          memcpy(output_ptr + out_offset,  // copy all channels of this element in one go
+                 input_ptr + in_offset,
+                 channel * sizeof(T));
+        }
+      }
+    }
+  }
+};
+
+template  // NOTE(review): template parameter/argument lists are missing in this copy of the patch
+struct PadFunctor : PadFunctorBase {  // variant carrying OpenCL state; operator() is defined in mace/kernels/opencl/pad.cc
+  PadFunctor(const std::vector &paddings,
+             const float constant_value)
+      : PadFunctorBase(paddings, constant_value) {}
+
+  void operator()(const Tensor *input,
+                  Tensor *output,
+                  StatsFuture *future);
+
+  cl::Kernel kernel_;  // built on the first call and reused afterwards
+  uint32_t kwg_size_;  // result of GetKernelMaxWorkGroupSize for kernel_
+  std::unique_ptr kernel_error_;  // allocated only when the out-of-range check is enabled
+  std::vector input_shape_;  // input shape the kernel arguments were last bound for
+};
+
+}  // namespace kernels
+}  // namespace mace
+
+#endif  // MACE_KERNELS_PAD_H_
diff --git a/mace/ops/pad.cc b/mace/ops/pad.cc
new file mode 100644
index 0000000000000000000000000000000000000000..67f6608a9b57afaf000c7231a2a2f73e7b6297c6
--- /dev/null
+++ b/mace/ops/pad.cc
@@ -0,0 +1,30 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#include "mace/ops/pad.h"
+
+namespace mace {
+namespace ops {
+
+void Register_Pad(OperatorRegistry *op_registry) {  // registers PadOp under the "Pad" key: one CPU and two OPENCL registrations
+  REGISTER_OPERATOR(op_registry, OpKeyBuilder("Pad")
+                                     .Device(DeviceType::CPU)
+                                     .TypeConstraint("T")
+                                     .Build(),
+                    PadOp);  // NOTE(review): PadOp template arguments are missing in this copy of the patch
+
+  REGISTER_OPERATOR(op_registry, OpKeyBuilder("Pad")
+                                     .Device(DeviceType::OPENCL)
+                                     .TypeConstraint("T")
+                                     .Build(),
+                    PadOp);
+  REGISTER_OPERATOR(op_registry, OpKeyBuilder("Pad")
+                                     .Device(DeviceType::OPENCL)
+                                     .TypeConstraint("T")
+                                     .Build(),
+                    PadOp);  // presumably differs from the previous OPENCL registration by data type - confirm
+}
+
+}  // namespace ops
+}  // namespace mace
diff --git a/mace/ops/pad.h b/mace/ops/pad.h
new file mode 100644
index 0000000000000000000000000000000000000000..1ee3319043578382d49b9d8f7642cf2d2afca9b5
--- /dev/null
+++ b/mace/ops/pad.h
@@ -0,0 +1,39 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#ifndef MACE_OPS_PAD_H_
+#define MACE_OPS_PAD_H_
+
+#include  // NOTE(review): include target is missing in this copy of the patch
+
+#include "mace/core/operator.h"
+#include "mace/kernels/pad.h"
+
+namespace mace {
+namespace ops {
+
+template  // NOTE(review): template parameter list is missing in this copy of the patch
+class PadOp : public Operator {  // operator wrapper that forwards input 0 / output 0 to kernels::PadFunctor
+ public:
+  PadOp(const OperatorDef &operator_def, Workspace *ws)
+      : Operator(operator_def, ws),
+        functor_(OperatorBase::GetRepeatedArgument("paddings"),  // repeated "paddings" argument from the op definition
+                 OperatorBase::GetSingleArgument("constant_value", 0.0))  // pad value; falls back to 0.0 when the argument is absent
+  {}
+
+  bool Run(StatsFuture *future) override {  // runs the functor once; always returns true
+    const Tensor *input_tensor = this->Input(0);
+    Tensor *output_tensor = this->Output(0);
+    functor_(input_tensor, output_tensor, future);
+    return true;
+  }
+
+ private:
+  kernels::PadFunctor functor_;  // configured once in the constructor
+};
+
+}  // namespace ops
+}  // namespace mace
+
+#endif  // MACE_OPS_PAD_H_
diff --git a/mace/ops/pad_benchmark.cc b/mace/ops/pad_benchmark.cc
new file mode 100644
index 0000000000000000000000000000000000000000..947c7aa8d83e7dc55271f6985f8c1a38ddc2e050
--- /dev/null
+++ b/mace/ops/pad_benchmark.cc
@@ -0,0 +1,78 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#include "mace/core/operator.h"
+#include "mace/core/runtime/opencl/opencl_runtime.h"
+#include "mace/core/testing/test_benchmark.h"
+#include "mace/ops/ops_test_util.h"
+
+namespace mace {
+namespace ops {
+namespace test {
+
+template  // NOTE(review): template parameter list is missing in this copy of the patch
+static void Pad(int iters, int batch, int height,  // benchmark body: builds one Pad op, warms up, then times `iters` runs
+                int width, int channels, int pad) {
+  mace::testing::StopTiming();  // keep graph setup out of the measured time
+
+  OpsTestNet net;
+
+  // Add input data
+  net.AddRandomInput("Input", {batch, height, width, channels});
+
+  const std::vector paddings = {0, 0, pad, pad, pad, pad, 0, 0};  // `pad` on both sides of dims 1 and 2, none on dims 0 and 3
+  if (D == DeviceType::OPENCL) {
+    BufferToImage(&net, "Input", "InputImage",  // the OPENCL branch feeds the op from "InputImage" instead of "Input"
+                  kernels::BufferType::IN_OUT_CHANNEL);
+    OpDefBuilder("Pad", "PadTest")
+        .Input("InputImage")
+        .Output("OutputImage")
+        .AddIntsArg("paddings", paddings)
+        .AddFloatArg("constant_value", 1.0)
+        .Finalize(net.NewOperatorDef());
+  } else {
+    OpDefBuilder("Pad", "PadTest")
+        .Input("Input")
+        .Output("Output")
+        .AddIntsArg("paddings", paddings)
+        .AddFloatArg("constant_value", 1.0)
+        .Finalize(net.NewOperatorDef());
+  }
+
+  // Warm-up
+  for (int i = 0; i < 5; ++i) {
+    net.RunOp(D);
+  }
+  net.Sync();
+
+  mace::testing::StartTiming();  // measured section starts here
+  while (iters--) {
+    net.Run();
+  }
+  net.Sync();
+}
+
+#define BM_PAD_MACRO(N, H, W, C, PAD, TYPE, DEVICE)                        \
+  static void BM_PAD_##N##_##H##_##W##_##C##_##PAD##_##TYPE##_##DEVICE(    \
+      int iters) {                                                         \
+    const int64_t tot = static_cast(iters) * N * C * H * W;                \
+    mace::testing::MaccProcessed(tot);                                     \
+    mace::testing::BytesProcessed(tot *(sizeof(TYPE)));                    \
+    Pad(iters, N, H, W, C, PAD);                                           \
+  }                                                                        \
+  BENCHMARK(BM_PAD_##N##_##H##_##W##_##C##_##PAD##_##TYPE##_##DEVICE)
+
+#define BM_PAD(N, H, W, C, PAD)                  \
+  BM_PAD_MACRO(N, H, W, C, PAD, float, CPU);     \
+  BM_PAD_MACRO(N, H, W, C, PAD, float, OPENCL);  \
+  BM_PAD_MACRO(N, H, W, C, PAD, half, OPENCL);
+
+BM_PAD(1, 512, 512, 1, 2);  // arguments are N, H, W, C, PAD; each line expands to float/CPU, float/OPENCL and half/OPENCL benchmarks
+BM_PAD(1, 112, 112, 64, 1);
+BM_PAD(1, 256, 256, 32, 2);
+BM_PAD(1, 512, 512, 16, 2);
+
+}  // namespace test
+}  // namespace ops
+}  // namespace mace
diff --git
a/mace/ops/pad_test.cc b/mace/ops/pad_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..4aac54bbc340fe409920097ef22a77c18d327d5b
--- /dev/null
+++ b/mace/ops/pad_test.cc
@@ -0,0 +1,158 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#include "mace/core/operator.h"
+#include "mace/ops/ops_test_util.h"
+
+namespace mace {
+namespace ops {
+namespace test {
+
+class PadTest : public OpsTestBase {};
+
+template  // NOTE(review): template parameter list is missing in this copy of the patch
+void Simple() {  // pads a {1, 2, 3, 1} input to {1, 5, 6, 1} and compares against a hand-written result
+  // Construct graph
+  OpsTestNet net;
+
+  // Add input data
+  net.AddRepeatedInput("Input", {1, 2, 3, 1}, 2);  // presumably fills the input with the value 2 (matches the 2s expected below)
+  if (D == DeviceType::OPENCL) {
+    BufferToImage(&net, "Input", "InputImage",
+                  kernels::BufferType::IN_OUT_CHANNEL);
+    OpDefBuilder("Pad", "PadTest")
+        .Input("InputImage")
+        .Output("OutputImage")
+        .AddIntsArg("paddings", {0, 0, 1, 2, 1, 2, 0, 0})  // dims 1 and 2: 1 before, 2 after; dims 0 and 3 unpadded
+        .AddFloatArg("constant_value", 1.0)
+        .Finalize(net.NewOperatorDef());
+
+    // Run
+    net.RunOp(D);
+
+    ImageToBuffer(&net, "OutputImage", "Output",  // convert back so both branches read the result from "Output"
+                  kernels::BufferType::IN_OUT_CHANNEL);
+  } else {
+    OpDefBuilder("Pad", "PadTest")
+        .Input("Input")
+        .Output("Output")
+        .AddIntsArg("paddings", {0, 0, 1, 2, 1, 2, 0, 0})  // same paddings as the OPENCL branch
+        .AddFloatArg("constant_value", 1.0)
+        .Finalize(net.NewOperatorDef());
+
+    // Run
+    net.RunOp();
+  }
+
+  auto output = net.GetTensor("Output");
+
+  auto expected = CreateTensor({1, 5, 6, 1},  // 2 + 1 + 2 = 5 along dim 1, 3 + 1 + 2 = 6 along dim 2
+    {
+      1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+      1.0, 2, 2, 2, 1.0, 1.0,
+      1.0, 2, 2, 2, 1.0, 1.0,
+      1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+      1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+    });
+  ExpectTensorNear(*expected, *output, 1e-5);
+}
+
+TEST_F(PadTest, SimpleCPU) {
+  Simple();  // NOTE(review): template argument is missing in this copy of the patch
+}
+
+TEST_F(PadTest, SimpleGPU) {
+  Simple();  // NOTE(review): template argument is missing in this copy of the patch
+}
+
+TEST_F(PadTest, ComplexCPU) {  // default-device run that also pads dim 3
+  // Construct graph
+  OpsTestNet net;
+
+  // Add input data
+  net.AddRepeatedInput("Input", {1, 1, 1, 2}, 2);
+  OpDefBuilder("Pad", "PadTest")
+      .Input("Input")
+      .Output("Output")
+      .AddIntsArg("paddings", {0, 0, 1, 1, 1, 1, 1, 1})  // one element on each side of dims 1, 2 and 3
+      .AddFloatArg("constant_value", 1.0)
+      .Finalize(net.NewOperatorDef());
+
+  // Run
+  net.RunOp();
+
+  auto output = net.GetTensor("Output");
+
+  auto expected = CreateTensor(
+    {1, 3, 3, 4},
+    {
+      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+      1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+    });
+  ExpectTensorNear(*expected, *output, 1e-5);
+}
+
+template  // NOTE(review): template parameter list is missing in this copy of the patch
+void Complex(const std::vector &input_shape,  // cross-check: default-device output is the reference for the OPENCL output
+             const std::vector &paddings) {
+  // Construct graph
+  OpsTestNet net;
+
+  // Add input data
+  net.AddRandomInput("Input", input_shape);
+
+  OpDefBuilder("Pad", "PadTest")
+      .Input("Input")
+      .Output("Output")
+      .AddIntsArg("paddings", paddings)
+      .AddFloatArg("constant_value", 1.0)
+      .Finalize(net.NewOperatorDef());
+
+  // Run
+  net.RunOp();
+
+  Tensor expected;
+  expected.Copy(*net.GetOutput("Output"));  // keep a copy of the reference result before the second run
+
+  BufferToImage(&net, "Input", "InputImage",
+                kernels::BufferType::IN_OUT_CHANNEL);
+  OpDefBuilder("Pad", "PadTest")
+      .Input("InputImage")
+      .Output("OutputImage")
+      .AddIntsArg("paddings", paddings)
+      .AddFloatArg("constant_value", 1.0)
+      .Finalize(net.NewOperatorDef());
+
+  // Run
+  net.RunOp(DeviceType::OPENCL);
+
+  ImageToBuffer(&net, "OutputImage", "OpenCLOutput",
+                kernels::BufferType::IN_OUT_CHANNEL);
+
+  auto output = net.GetTensor("OpenCLOutput");
+
+  if (DataTypeToEnum::value == DT_HALF) {
+    ExpectTensorNear(expected, *output, 1e-1);  // looser tolerance for half precision
+  } else {
+    ExpectTensorNear(expected, *output, 1e-5);
+  }
+}
+
+TEST_F(PadTest, ComplexFloat) {
+  Complex({1, 32, 32, 4}, {0, 0, 2, 2, 1, 1, 0, 0});  // symmetric padding on dims 1 and 2
+  Complex({1, 31, 37, 16}, {0, 0, 2, 0, 1, 0, 0, 0});  // leading-side padding only
+  Complex({1, 128, 128, 32}, {0, 0, 0, 1, 0, 2, 0, 0});  // trailing-side padding only
+}
+
+TEST_F(PadTest, ComplexHalf) {
+  Complex({1, 32, 32, 4}, {0, 0, 2, 2, 1, 1, 0, 0});  // same cases as ComplexFloat
+  Complex({1, 31, 37, 16}, {0, 0, 2, 0, 1, 0, 0, 0});
+  Complex({1, 128, 128, 32}, {0, 0, 0, 1, 0, 2, 0, 0});
+}
+
+}  // namespace test
+}  // namespace ops
+}  // namespace mace
+
diff --git a/mace/python/tools/tf_converter_lib.py b/mace/python/tools/tf_converter_lib.py
index a721e0fd2b849bea84a9c0e07a557078c5a6e217..56b3f04d07fbdfaae284236f601110e41b6c9dba 100644
--- a/mace/python/tools/tf_converter_lib.py
+++ b/mace/python/tools/tf_converter_lib.py
@@ -992,6 +992,29 @@ class TFConverter(object):
         self.add_output_shape([shape], op_def)
         self.resolved_ops[reshape_op.name] = 1
 
+    def convert_pad(self, op):  # emits a MACE "Pad" op def from a TF pad op
+        op_def = self.net_def.op.add()
+        arg = op_def.arg.add()
+        arg.name = 'T'
+        arg.i = self.dt
+        op_def.name = op.name
+        op_def.type = "Pad"
+        op_def.input.extend([op.inputs[0].name])  # only the data tensor stays a runtime input
+        op_def.output.extend([output.name for output in op.outputs])
+        paddings_arg = op_def.arg.add()
+        paddings_arg.name = 'paddings'
+        paddings_arg.ints.extend(  # input 1 is evaluated now and stored flattened as an int list
+            get_input_tensor(op, 1).eval().astype(np.int32).flat)
+        self.unused_tensor.add(get_input_tensor(op, 1).name)  # folded into the arg above
+        if len(op.inputs) == 3:  # optional third input carries the pad value
+            constant_value_arg = op_def.arg.add()
+            constant_value_arg.name = 'constant_value'
+            constant_value_arg.f = \
+                float(get_input_tensor(op, 2).eval().flat[0])
+            self.unused_tensor.add(get_input_tensor(op, 2).name)  # folded into the arg above
+        self.add_output_shape(op.outputs, op_def)
+        self.resolved_ops[op.name] = 1  # mark as converted
+
     def convert_normal_op(self, op):
         op_def = self.net_def.op.add()
         arg = op_def.arg.add()
@@ -1084,6 +1107,8 @@ class TFConverter(object):
                 else:
                     raise Exception('Unknown Op: %s, type: %s' % (op.name,
                                                                   op.type))
+            elif op.type == 'Pad':  # only the op type 'Pad' is routed to convert_pad
+                self.convert_pad(op)
             # elif op.type in ['']:
             #     self.convert_normal_op(op)
             else: