diff --git a/mace/kernels/opencl/cl/resize_bicubic.cl b/mace/kernels/opencl/cl/resize_bicubic.cl
new file mode 100644
index 0000000000000000000000000000000000000000..767ee3911c3c2fdbc62e8c2c485e6c790e9fdeda
--- /dev/null
+++ b/mace/kernels/opencl/cl/resize_bicubic.cl
@@ -0,0 +1,118 @@
+#include <common.h>
+
+// Cubic-convolution weight (A = -0.75) of an inner tap, i.e. a tap at
+// distance x = i / TABLE_SIZE (0 <= x <= 1) from the sample position.
+inline float coeff_even(float i) {
+  float x = i / TABLE_SIZE;
+  return (1.25f * x - 2.25f) * x * x + 1.0f;
+}
+
+// Cubic-convolution weight (A = -0.75) of an outer tap, i.e. a tap at
+// distance x = i / TABLE_SIZE + 1 (1 <= x <= 2) from the sample position.
+inline float coeff_odd(float i) {
+  float x = i / TABLE_SIZE + 1.0f;
+  return ((-0.75f * x + 3.75f) * x - 6.0f) * x + 3.0f;
+}
+
+// Bicubic resize of an image2d tensor. Every work item produces one output
+// texel: global id 0 selects the channel block, id 1 the output column and
+// id 2 the fused (batch, output row) index.
+__kernel void resize_bicubic_nocache(KERNEL_ERROR_PARAMS
+                                     GLOBAL_WORK_GROUP_SIZE_DIM3
+                                     __read_only image2d_t input,
+                                     __write_only image2d_t output,
+                                     __private const float height_scale,
+                                     __private const float width_scale,
+                                     __private const int in_height,
+                                     __private const int in_width,
+                                     __private const int out_height) {
+  const int ch_blk = get_global_id(0);
+  const int w = get_global_id(1);
+  const int hb = get_global_id(2);
+
+#ifndef NON_UNIFORM_WORK_GROUP
+  if (ch_blk >= global_size_dim0 || w >= global_size_dim1
+      || hb >= global_size_dim2) {
+    return;
+  }
+  const int out_width = global_size_dim1;
+#else
+  const int out_width = get_global_size(1);
+#endif
+
+  const int b = hb / out_height;
+  const int h = hb % out_height;
+
+  const float h_in = h * height_scale;
+  const float w_in = w * width_scale;
+
+  const int in_w_offset = mul24(ch_blk, in_width);
+  const int in_h_offset = mul24(b, in_height);
+
+  // Split each source coordinate into its integer part and a fractional
+  // part quantized to TABLE_SIZE steps (rounded to nearest).
+  const int h_in_loc = (int)h_in;
+  const float h_delta = h_in - h_in_loc;
+  const int h_offset = h_delta * TABLE_SIZE + 0.5f;
+
+  const int w_in_loc = (int)w_in;
+  const float w_delta = w_in - w_in_loc;
+  const int w_offset = w_delta * TABLE_SIZE + 0.5f;
+
+  // Weights and edge-clamped source rows of the four vertical taps.
+  const float h_offset_l = h_offset;
+  const float h_offset_r = TABLE_SIZE - h_offset_l;
+  float4 y_weights = {coeff_odd(h_offset_l), coeff_even(h_offset_l),
+                      coeff_even(h_offset_r), coeff_odd(h_offset_r)};
+  int4 y_indices = {h_in_loc - 1, h_in_loc, h_in_loc + 1, h_in_loc + 2};
+  y_indices = min(max(y_indices, 0), in_height - 1);
+
+  // Weights and edge-clamped source columns of the four horizontal taps.
+  const float w_offset_l = w_offset;
+  const float w_offset_r = TABLE_SIZE - w_offset_l;
+  float4 x_weights = {coeff_odd(w_offset_l), coeff_even(w_offset_l),
+                      coeff_even(w_offset_r), coeff_odd(w_offset_r)};
+  int4 x_indices = {w_in_loc - 1, w_in_loc, w_in_loc + 1, w_in_loc + 2};
+  x_indices = min(max(x_indices, 0), in_width - 1);
+
+  // Interpolate horizontally within each of the four source rows. Vector
+  // components cannot be selected with a runtime index, hence the explicit
+  // tests on i.
+  float4 coeffs0 = 0, coeffs1 = 0, coeffs2 = 0, coeffs3 = 0;
+  for (int i = 0; i < 4; ++i) {
+    int y_index = y_indices.s0;
+    if ( i == 1 ) { y_index = y_indices.s1; }
+    if ( i == 2 ) { y_index = y_indices.s2; }
+    if ( i == 3 ) { y_index = y_indices.s3; }
+    const int in_h_index = in_h_offset + y_index;
+    DATA_TYPE4 data0 = READ_IMAGET(input, SAMPLER,
+        (int2)(in_w_offset + x_indices.s0, in_h_index));
+    DATA_TYPE4 data1 = READ_IMAGET(input, SAMPLER,
+        (int2)(in_w_offset + x_indices.s1, in_h_index));
+    DATA_TYPE4 data2 = READ_IMAGET(input, SAMPLER,
+        (int2)(in_w_offset + x_indices.s2, in_h_index));
+    DATA_TYPE4 data3 = READ_IMAGET(input, SAMPLER,
+        (int2)(in_w_offset + x_indices.s3, in_h_index));
+
+    float4 res = 0;
+    res = mad(data0, x_weights.s0, res);
+    res = mad(data1, x_weights.s1, res);
+    res = mad(data2, x_weights.s2, res);
+    res = mad(data3, x_weights.s3, res);
+    if ( i == 0 ) { coeffs0 = res; }
+    if ( i == 1 ) { coeffs1 = res; }
+    if ( i == 2 ) { coeffs2 = res; }
+    if ( i == 3 ) { coeffs3 = res; }
+  }
+
+  // Combine the four row results vertically.
+  DATA_TYPE4 outdata = 0;
+  outdata = mad(coeffs0, y_weights.s0, outdata);
+  outdata = mad(coeffs1, y_weights.s1, outdata);
+  outdata = mad(coeffs2, y_weights.s2, outdata);
+  outdata = mad(coeffs3, y_weights.s3, outdata);
+  const int out_w_offset = mul24(ch_blk, out_width);
+  const int out_h_offset = mul24(b, out_height);
+
+  WRITE_IMAGET(output, (int2)(out_w_offset + w, out_h_offset + h), outdata);
+}
diff --git a/mace/kernels/opencl/resize_bicubic.cc
b/mace/kernels/opencl/resize_bicubic.cc
new file mode 100644
index 0000000000000000000000000000000000000000..f8a33383e99a2db4978f40a9c17ce9034df30218
--- /dev/null
+++ b/mace/kernels/opencl/resize_bicubic.cc
@@ -0,0 +1,146 @@
+// Copyright 2018 Xiaomi, Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "mace/kernels/resize_bicubic.h"
+#include "mace/core/runtime/opencl/opencl_runtime.h"
+#include "mace/core/tensor.h"
+#include "mace/kernels/opencl/helper.h"
+#include "mace/utils/tuner.h"
+#include "mace/utils/utils.h"
+
+namespace mace {
+namespace kernels {
+
+namespace {
+// Picks the local work-group size for the 3-D global size gws (channel
+// blocks, output width, batch * output height), keeping the product of the
+// three local dimensions within kwg_size. The returned vector has four
+// entries; only the first three are set here.
+std::vector<uint32_t> LocalWS(const uint32_t *gws, const uint32_t kwg_size) {
+  std::vector<uint32_t> lws(4, 0);
+  if (kwg_size == 0) {
+    // No usable work-group limit was reported: fall back to 1x1x1 instead of
+    // dividing by zero below.
+    lws[0] = lws[1] = lws[2] = 1;
+    return lws;
+  }
+  uint64_t cache_size = OpenCLRuntime::Global()->device_global_mem_cache_size();
+  uint32_t base = std::max<uint32_t>(cache_size / kBaseGPUMemCacheSize, 1);
+  lws[1] = std::min<uint32_t>(gws[1], kwg_size);
+  if (lws[1] >= base) {
+    lws[0] = std::min<uint32_t>(gws[0], base);
+  } else {
+    lws[0] = gws[0] / 8;
+    if (lws[0] == 0) {
+      lws[0] = gws[0];
+    }
+  }
+  lws[0] = std::min<uint32_t>(lws[0], kwg_size / lws[1]);
+  const uint32_t lws_size = lws[0] * lws[1];
+  lws[2] = gws[2] / 8;
+  if (lws[2] == 0) {
+    lws[2] = gws[2];
+  }
+  lws[2] = std::max<uint32_t>(std::min<uint32_t>(lws[2], kwg_size / lws_size),
+                              1);
+  return lws;
+}
+
+}  // namespace
+
+// Runs the resize_bicubic_nocache OpenCL kernel on an NHWC image tensor.
+// The kernel is built on first use; the output is resized and the kernel
+// arguments are refreshed only when the input shape changes.
+template <typename T>
+MaceStatus ResizeBicubicFunctor<DeviceType::GPU, T>::operator()(
+    const Tensor *input, Tensor *output, StatsFuture *future) {
+  const index_t batch = input->dim(0);
+  const index_t in_height = input->dim(1);
+  const index_t in_width = input->dim(2);
+  const index_t channels = input->dim(3);
+
+  const index_t channel_blocks = RoundUpDiv4(channels);
+  const index_t out_height = out_height_;
+  const index_t out_width = out_width_;
+
+  const uint32_t gws[3] = {static_cast<uint32_t>(channel_blocks),
+                           static_cast<uint32_t>(out_width),
+                           static_cast<uint32_t>(out_height * batch)};
+
+  auto runtime = OpenCLRuntime::Global();
+
+  if (kernel_.get() == nullptr) {
+    std::set<std::string> built_options;
+    OUT_OF_RANGE_CONFIG(kernel_error_);
+    NON_UNIFORM_WG_CONFIG;
+    std::string kernel_name = MACE_OBFUSCATE_SYMBOL("resize_bicubic_nocache");
+    built_options.emplace("-Dresize_bicubic_nocache=" + kernel_name);
+    auto dt = DataTypeToEnum<T>::value;
+    built_options.emplace("-DDATA_TYPE=" + DtToUpCompatibleCLDt(dt));
+    built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpCompatibleCLCMDDt(dt));
+    built_options.emplace(MakeString("-DTABLE_SIZE=", kTableSize));
+    MACE_RETURN_IF_ERROR(
+        runtime->BuildKernel("resize_bicubic",
+                             kernel_name,
+                             built_options,
+                             &kernel_));
+
+    kwg_size_ =
+        static_cast<uint32_t>(runtime->GetKernelMaxWorkGroupSize(kernel_));
+  }
+
+  if (!IsVecEqual(input_shape_, input->shape())) {
+    MACE_CHECK(out_height > 0 && out_width > 0);
+    std::vector<index_t> output_shape{batch, out_height, out_width, channels};
+
+    std::vector<size_t> output_image_shape;
+    CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL,
+                    &output_image_shape);
+    MACE_RETURN_IF_ERROR(output->ResizeImage(output_shape, output_image_shape));
+
+    float height_scale =
+        CalculateResizeScale(in_height, out_height, align_corners_);
+    float width_scale =
+        CalculateResizeScale(in_width, out_width, align_corners_);
+
+    uint32_t idx = 0;
+    OUT_OF_RANGE_SET_ARG;
+    SET_3D_GWS_ARGS(kernel_);
+    kernel_.setArg(idx++, *(input->opencl_image()));
+    kernel_.setArg(idx++, *(output->opencl_image()));
+    kernel_.setArg(idx++, height_scale);
+    kernel_.setArg(idx++, width_scale);
+    kernel_.setArg(idx++, static_cast<int32_t>(in_height));
+    kernel_.setArg(idx++, static_cast<int32_t>(in_width));
+    kernel_.setArg(idx++, static_cast<int32_t>(out_height));
+
+    input_shape_ = input->shape();
+  }
+
+  const std::vector<uint32_t> lws = LocalWS(gws, kwg_size_);
+  std::string tuning_key =
+      Concat("resize_bicubic_opencl_kernel", output->dim(0), output->dim(1),
+             output->dim(2), output->dim(3));
+  MACE_RETURN_IF_ERROR(TuningOrRun3DKernel(kernel_, tuning_key,
+                                           gws, lws, future));
+
+  OUT_OF_RANGE_VALIDATION(kernel_error_);
+  return MACE_SUCCESS;
+}
+
+template struct ResizeBicubicFunctor<DeviceType::GPU, float>;
+template struct ResizeBicubicFunctor<DeviceType::GPU, half>;
+
+}  // namespace kernels
+}  // namespace mace
diff --git a/mace/kernels/resize_bicubic.h b/mace/kernels/resize_bicubic.h
new file mode 100644
index 0000000000000000000000000000000000000000..b620b51d70822190d74e531d017a7be54c501d74
--- /dev/null
+++ b/mace/kernels/resize_bicubic.h
@@ -0,0 +1,246 @@
+// Copyright 2018 Xiaomi, Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef MACE_KERNELS_RESIZE_BICUBIC_H_
+#define MACE_KERNELS_RESIZE_BICUBIC_H_
+
+#include <algorithm>
+#include <cmath>
+#include <memory>
+#include <vector>
+
+#include "mace/core/future.h"
+#include "mace/core/tensor.h"
+#include "mace/utils/logging.h"
+
+#ifdef MACE_ENABLE_OPENCL
+#include "mace/core/runtime/opencl/cl2_header.h"
+#endif  // MACE_ENABLE_OPENCL
+
+namespace mace {
+namespace kernels {
+
+// Number of steps into which the fractional sample offset is quantized.
+static const int64_t kTableSize = (1 << 10);
+
+// Builds the weight table used by GetWeightsAndIndices(): entry 2 * i holds
+// the inner-tap weight and entry 2 * i + 1 the outer-tap weight for a
+// fractional offset of i / kTableSize. The caller owns the returned array;
+// GetCoeffsTable() keeps it for the lifetime of the process.
+inline const float *InitCoeffsTable() {
+  // Allocate and initialize coefficients table using Bicubic
+  // convolution algorithm.
+  // https://en.wikipedia.org/wiki/Bicubic_interpolation
+  float *coeffs_tab = new float[(kTableSize + 1) * 2];
+  static const double A = -0.75;
+  for (int i = 0; i <= kTableSize; ++i) {
+    float x = i * 1.0 / kTableSize;
+    coeffs_tab[i * 2] = ((A + 2) * x - (A + 3)) * x * x + 1;
+    x += 1.0;
+    coeffs_tab[i * 2 + 1] = ((A * x - 5 * A) * x + 8 * A) * x - 4 * A;
+  }
+  return coeffs_tab;
+}
+
+// Returns the shared coefficient table, building it on first use.
+inline const float *GetCoeffsTable() {
+  // Static so that we initialize it on first use
+  static const float *coeffs_tab = InitCoeffsTable();
+  return coeffs_tab;
+}
+
+// Clamps val to the valid index range [0, limit - 1].
+inline int64_t Bound(int64_t val, int64_t limit) {
+  return std::min<int64_t>(limit - 1ll, std::max<int64_t>(0ll, val));
+}
+
+// Computes, for output position out_loc along one axis, the four tap weights
+// and the four source indices (clamped to [0, limit - 1]) of the bicubic
+// kernel. scale is the input/output size ratio for that axis.
+inline void GetWeightsAndIndices(float scale, int64_t out_loc, int64_t limit,
+                                 std::vector<float> *weights,
+                                 std::vector<int64_t> *indices) {
+  const int64_t in_loc = scale * out_loc;
+  const float delta = scale * out_loc - in_loc;
+  const int64_t offset = lrintf(delta * kTableSize);
+  const float *coeffs_tab = GetCoeffsTable();
+  *weights = {coeffs_tab[offset * 2 + 1],
+              coeffs_tab[offset * 2],
+              coeffs_tab[(kTableSize - offset) * 2],
+              coeffs_tab[(kTableSize - offset) * 2 + 1]};
+  *indices = {Bound(in_loc - 1, limit), Bound(in_loc, limit),
+              Bound(in_loc + 1, limit), Bound(in_loc + 2, limit)};
+}
+
+// Weighted sum of four samples.
+inline float Interpolate1D(const std::vector<float> &weights,
+                           const std::vector<float> &values) {
+  return values[0] * weights[0] + values[1] * weights[1] +
+         values[2] * weights[2] + values[3] * weights[3];
+}
+
+// Ratio between input and output size along one axis. With align_corners the
+// first and last samples of both grids coincide.
+inline float CalculateResizeScale(index_t in_size,
+                                  index_t out_size,
+                                  bool align_corners) {
+  return (align_corners && out_size > 1)
+         ? (in_size - 1) / static_cast<float>(out_size - 1)
+         : in_size / static_cast<float>(out_size);
+}
+
+// Bicubic resize of a float NCHW buffer of shape (batch_size, channels,
+// in_height, in_width) into output of shape (batch_size, channels,
+// out_height, out_width).
+inline void ResizeImage(const float *images,
+                        const index_t batch_size,
+                        const index_t in_height,
+                        const index_t in_width,
+                        const index_t out_height,
+                        const index_t out_width,
+                        const index_t channels,
+                        const float height_scale,
+                        const float width_scale,
+                        float *output) {
+#pragma omp parallel for collapse(2)
+  for (index_t b = 0; b < batch_size; ++b) {
+    for (index_t y = 0; y < out_height; ++y) {
+      std::vector<float> y_weights;
+      std::vector<int64_t> y_indices;
+      GetWeightsAndIndices(height_scale, y, in_height, &y_weights,
+                           &y_indices);
+      for (index_t x = 0; x < out_width; ++x) {
+        std::vector<float> x_weights;
+        std::vector<int64_t> x_indices;
+        GetWeightsAndIndices(width_scale, x, in_width, &x_weights,
+                             &x_indices);
+
+        for (index_t c = 0; c < channels; ++c) {
+          // Use a 4x4 patch to compute the interpolated output value at
+          // (b, y, x, c).
+          const float *channel_input_ptr =
+              images + (b * channels + c) * in_height * in_width;
+          float *channel_output_ptr =
+              output + (b * channels + c) * out_height * out_width;
+          std::vector<float> coeff(4, 0.0);
+          for (index_t i = 0; i < 4; ++i) {
+            const std::vector<float> values = {
+                static_cast<float>(channel_input_ptr
+                    [y_indices[i] * in_width + x_indices[0]]),
+                static_cast<float>(channel_input_ptr
+                    [y_indices[i] * in_width + x_indices[1]]),
+                static_cast<float>(channel_input_ptr
+                    [y_indices[i] * in_width + x_indices[2]]),
+                static_cast<float>(channel_input_ptr
+                    [y_indices[i] * in_width + x_indices[3]])};
+            coeff[i] = Interpolate1D(x_weights, values);
+          }
+          channel_output_ptr[y * out_width + x] =
+              Interpolate1D(y_weights, coeff);
+        }
+      }
+    }
+  }
+}
+
+// Holds the options shared by all bicubic resize functors.
+struct ResizeBicubicFunctorBase {
+  ResizeBicubicFunctorBase(const std::vector<index_t> &size,
+                           bool align_corners)
+      : align_corners_(align_corners) {
+    MACE_CHECK(size.size() == 2);
+    out_height_ = size[0];
+    out_width_ = size[1];
+  }
+
+ protected:
+  bool align_corners_;
+  index_t out_height_;
+  index_t out_width_;
+};
+
+template<DeviceType D, typename T>
+struct ResizeBicubicFunctor;
+
+// CPU implementation operating on float NCHW tensors.
+template<>
+struct ResizeBicubicFunctor<DeviceType::CPU, float>
+    : ResizeBicubicFunctorBase {
+  ResizeBicubicFunctor(const std::vector<index_t> &size, bool align_corners)
+      : ResizeBicubicFunctorBase(size, align_corners) {}
+
+  MaceStatus operator()(const Tensor *input,
+                        Tensor *output,
+                        StatsFuture *future) {
+    MACE_UNUSED(future);
+    const index_t batch = input->dim(0);
+    const index_t channels = input->dim(1);
+    const index_t in_height = input->dim(2);
+    const index_t in_width = input->dim(3);
+
+    index_t out_height = out_height_;
+    index_t out_width = out_width_;
+    MACE_CHECK(out_height > 0 && out_width > 0);
+    std::vector<index_t> out_shape{batch, channels, out_height, out_width};
+    MACE_RETURN_IF_ERROR(output->Resize(out_shape));
+
+    Tensor::MappingGuard input_mapper(input);
+    Tensor::MappingGuard output_mapper(output);
+    const float *input_data = input->data<float>();
+    float *output_data = output->mutable_data<float>();
+
+    // Same size: the resize is a plain copy.
+    if (out_height == in_height && out_width == in_width) {
+      std::copy(input_data,
+                input_data + batch * channels * in_height * in_width,
+                output_data);
+      return MACE_SUCCESS;
+    }
+
+    float height_scale =
+        CalculateResizeScale(in_height, out_height, align_corners_);
+    float width_scale =
+        CalculateResizeScale(in_width, out_width, align_corners_);
+
+    ResizeImage(input_data, batch, in_height, in_width, out_height, out_width,
+                channels, height_scale, width_scale, output_data);
+
+    return MACE_SUCCESS;
+  }
+};
+
+#ifdef MACE_ENABLE_OPENCL
+// OpenCL implementation; operator() is defined in
+// mace/kernels/opencl/resize_bicubic.cc.
+template<typename T>
+struct ResizeBicubicFunctor<DeviceType::GPU, T>
+    : ResizeBicubicFunctorBase {
+  ResizeBicubicFunctor(const std::vector<index_t> &size, bool align_corners)
+      : ResizeBicubicFunctorBase(size, align_corners) {}
+
+  MaceStatus operator()(const Tensor *input,
+                        Tensor *output,
+                        StatsFuture *future);
+
+  cl::Kernel kernel_;
+  uint32_t kwg_size_;
+  std::unique_ptr<BufferBase> kernel_error_;
+  std::vector<index_t> input_shape_;
+};
+#endif  // MACE_ENABLE_OPENCL
+
+}  // namespace kernels
+}  // namespace mace
+
+#endif  // MACE_KERNELS_RESIZE_BICUBIC_H_
diff --git
a/mace/ops/ops_register.cc b/mace/ops/ops_register.cc
index a798015380b4d933320c85b2e48b801ea6c793be..a2aa5e478d9ce6a1ff19b8249c7f9ff2391fa956 100644
--- a/mace/ops/ops_register.cc
+++ b/mace/ops/ops_register.cc
@@ -48,6 +48,7 @@ extern void Register_Proposal(OperatorRegistryBase *op_registry);
 extern void Register_Quantize(OperatorRegistryBase *op_registry);
 extern void Register_ReduceMean(OperatorRegistryBase *op_registry);
 extern void Register_Reshape(OperatorRegistryBase *op_registry);
+extern void Register_ResizeBicubic(OperatorRegistryBase *op_registry);
 extern void Register_ResizeBilinear(OperatorRegistryBase *op_registry);
 extern void Register_ScalarMath(OperatorRegistryBase *op_registry);
 extern void Register_Shape(OperatorRegistryBase *op_registry);
@@ -101,6 +102,7 @@ OperatorRegistry::OperatorRegistry() : OperatorRegistryBase() {
   ops::Register_Quantize(this);
   ops::Register_ReduceMean(this);
   ops::Register_Reshape(this);
+  ops::Register_ResizeBicubic(this);
   ops::Register_ResizeBilinear(this);
   ops::Register_ScalarMath(this);
   ops::Register_Shape(this);
diff --git a/mace/ops/resize_bicubic.cc b/mace/ops/resize_bicubic.cc
new file mode 100644
index 0000000000000000000000000000000000000000..7a50522f54ed1005bb8c12c138a7158f4034e496
--- /dev/null
+++ b/mace/ops/resize_bicubic.cc
@@ -0,0 +1,45 @@
+// Copyright 2018 Xiaomi, Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "mace/ops/resize_bicubic.h"
+
+namespace mace {
+namespace ops {
+
+// Registers the ResizeBicubic operator: float on CPU and, when OpenCL is
+// enabled, float and half on GPU.
+void Register_ResizeBicubic(OperatorRegistryBase *op_registry) {
+  MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("ResizeBicubic")
+                                          .Device(DeviceType::CPU)
+                                          .TypeConstraint<float>("T")
+                                          .Build(),
+                         ResizeBicubicOp<DeviceType::CPU, float>);
+
+#ifdef MACE_ENABLE_OPENCL
+  MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("ResizeBicubic")
+                                          .Device(DeviceType::GPU)
+                                          .TypeConstraint<float>("T")
+                                          .Build(),
+                         ResizeBicubicOp<DeviceType::GPU, float>);
+
+  MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("ResizeBicubic")
+                                          .Device(DeviceType::GPU)
+                                          .TypeConstraint<half>("T")
+                                          .Build(),
+                         ResizeBicubicOp<DeviceType::GPU, half>);
+#endif  // MACE_ENABLE_OPENCL
+}
+
+}  // namespace ops
+}  // namespace mace
diff --git a/mace/ops/resize_bicubic.h b/mace/ops/resize_bicubic.h
new file mode 100644
index 0000000000000000000000000000000000000000..a83f3a310afc02ca3abd474b4481e16470f28953
--- /dev/null
+++ b/mace/ops/resize_bicubic.h
@@ -0,0 +1,52 @@
+// Copyright 2018 Xiaomi, Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef MACE_OPS_RESIZE_BICUBIC_H_
+#define MACE_OPS_RESIZE_BICUBIC_H_
+
+#include "mace/core/operator.h"
+#include "mace/kernels/resize_bicubic.h"
+
+namespace mace {
+namespace ops {
+
+// Operator wrapper that reads the "size" ({height, width}) and
+// "align_corners" arguments and forwards to kernels::ResizeBicubicFunctor.
+template <DeviceType D, class T>
+class ResizeBicubicOp : public Operator<D, T> {
+ public:
+  ResizeBicubicOp(const OperatorDef &operator_def, Workspace *ws)
+      : Operator<D, T>(operator_def, ws),
+        functor_(OperatorBase::GetRepeatedArgs<index_t>("size", {-1, -1}),
+                 OperatorBase::GetOptionalArg<bool>("align_corners", false)) {}
+
+  MaceStatus Run(StatsFuture *future) override {
+    const Tensor *input = this->Input(0);
+    Tensor *output = this->Output(0);
+
+    MACE_CHECK(input->dim_size() == 4, "input must be 4-dimensional.",
+               input->dim_size());
+
+    return functor_(input, output, future);
+  }
+
+ private:
+  kernels::ResizeBicubicFunctor<D, T> functor_;
+};
+
+}  // namespace ops
+}  // namespace mace
+
+#endif  // MACE_OPS_RESIZE_BICUBIC_H_
+
diff --git a/mace/ops/resize_bicubic_benchmark.cc b/mace/ops/resize_bicubic_benchmark.cc
new file mode 100644
index 0000000000000000000000000000000000000000..ba22f4fecdf49267f9f845a0879fe1f38e7faa0f
--- /dev/null
+++ b/mace/ops/resize_bicubic_benchmark.cc
@@ -0,0 +1,120 @@
+// Copyright 2018 Xiaomi, Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <string>
+#include "mace/core/operator.h"
+#include "mace/core/testing/test_benchmark.h"
+#include "mace/ops/ops_test_util.h"
+
+namespace mace {
+namespace ops {
+namespace test {
+
+namespace {
+// Times iters runs of the ResizeBicubic op on device D (after five warm-up
+// runs) for an input of the given size. T selects the GPU image data type.
+template <DeviceType D, typename T>
+void ResizeBicubicBenchmark(int iters,
+                            int batch,
+                            int channels,
+                            int input_height,
+                            int input_width,
+                            int output_height,
+                            int output_width) {
+  mace::testing::StopTiming();
+
+  OpsTestNet net;
+
+  // Add input data
+  if (D == DeviceType::CPU) {
+    net.AddRandomInput<D, float>("Input",
+                                 {batch, channels, input_height, input_width});
+  } else if (D == DeviceType::GPU) {
+    net.AddRandomInput<D, float>("Input",
+                                 {batch, input_height, input_width, channels});
+  } else {
+    MACE_NOT_IMPLEMENTED;
+  }
+  net.AddInputFromArray<D, index_t>("OutSize", {2},
+                                    {output_height, output_width});
+
+  if (D == DeviceType::CPU) {
+    OpDefBuilder("ResizeBicubic", "ResizeBicubicBenchmark")
+        .Input("Input")
+        .Input("OutSize")
+        .Output("Output")
+        .AddIntsArg("size", {output_height, output_width})
+        .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
+        .Finalize(net.NewOperatorDef());
+  } else if (D == DeviceType::GPU) {
+    BufferToImage<D, T>(&net, "Input", "InputImage",
+                        kernels::BufferType::IN_OUT_CHANNEL);
+    OpDefBuilder("ResizeBicubic", "ResizeBicubicBenchmark")
+        .Input("InputImage")
+        .Input("OutSize")
+        .Output("OutputImage")
+        .AddIntsArg("size", {output_height, output_width})
+        .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
+        .Finalize(net.NewOperatorDef());
+  } else {
+    MACE_NOT_IMPLEMENTED;
+  }
+
+  // Warm-up
+  for (int i = 0; i < 5; ++i) {
+    net.RunOp(D);
+  }
+
+  mace::testing::StartTiming();
+  while (iters--) {
+    net.RunOp(D);
+  }
+  net.Sync();
+}
+}  // namespace
+
+// NOTE(review): the MACC estimate below counts 3 multiply-accumulates per
+// output element, but the CPU path performs 20 multiplies per element (four
+// 4-tap rows plus one 4-tap column) -- confirm and adjust the factor.
+#define MACE_BM_RESIZE_BICUBIC_MACRO(N, C, H0, W0, H1, W1, TYPE, DEVICE) \
+  static void \
+  MACE_BM_RESIZE_BICUBIC_##N##_##C##_##H0##_##W0##_##H1##_##W1##_##TYPE##_\
+    ##DEVICE( \
+      int iters) { \
+    const int64_t macc = static_cast<int64_t>(iters) * N * C * H1 * W1 * 3; \
+    const int64_t tot = static_cast<int64_t>(iters) * N * C * H0 * W0; \
+    mace::testing::MaccProcessed(macc); \
+    mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
+    ResizeBicubicBenchmark<DEVICE, TYPE>(iters, N, C, H0, W0, H1, W1); \
+  } \
+  MACE_BENCHMARK( \
+    MACE_BM_RESIZE_BICUBIC_##N##_##C##_##H0##_##W0##_##H1##_##W1##_##TYPE##_\
+      ##DEVICE)
+
+#define MACE_BM_RESIZE_BICUBIC(N, C, H0, W0, H1, W1) \
+  MACE_BM_RESIZE_BICUBIC_MACRO(N, C, H0, W0, H1, W1, float, CPU); \
+  MACE_BM_RESIZE_BICUBIC_MACRO(N, C, H0, W0, H1, W1, float, GPU); \
+  MACE_BM_RESIZE_BICUBIC_MACRO(N, C, H0, W0, H1, W1, half, GPU);
+
+MACE_BM_RESIZE_BICUBIC(1, 128, 120, 120, 480, 480);
+MACE_BM_RESIZE_BICUBIC(1, 256, 7, 7, 15, 15);
+MACE_BM_RESIZE_BICUBIC(1, 256, 15, 15, 30, 30);
+MACE_BM_RESIZE_BICUBIC(1, 128, 30, 30, 60, 60);
+MACE_BM_RESIZE_BICUBIC(1, 128, 240, 240, 480, 480);
+MACE_BM_RESIZE_BICUBIC(1, 3, 4032, 3016, 480, 480);
+MACE_BM_RESIZE_BICUBIC(1, 3, 480, 480, 4032, 3016);
+
+}  // namespace test
+}  // namespace ops
+}  // namespace mace
diff --git a/mace/ops/resize_bicubic_test.cc b/mace/ops/resize_bicubic_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..ad4669f7ca1939ba5bf8b56966cc8c12f62f0b18
--- /dev/null
+++ b/mace/ops/resize_bicubic_test.cc
@@ -0,0 +1,187 @@
+// Copyright 2018 Xiaomi, Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <numeric>
+#include <vector>
+
+#include "mace/core/operator.h"
+#include "mace/ops/ops_test_util.h"
+#include "mace/ops/resize_bicubic.h"
+
+namespace mace {
+namespace ops {
+namespace test {
+
+class ResizeBicubicTest : public OpsTestBase {};
+
+TEST_F(ResizeBicubicTest, CPUResizeBicubicWOAlignCorners) {
+  testing::internal::LogToStderr();
+  // Construct graph
+  OpsTestNet net;
+
+  // Add input data
+  std::vector<float> input(24);
+  std::iota(begin(input), end(input), 0);
+  net.AddInputFromArray<DeviceType::CPU, float>("Input", {1, 2, 4, 3}, input);
+  net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW",
+                                                  NCHW);
+
+  OpDefBuilder("ResizeBicubic", "ResizeBicubicTest")
+      .Input("InputNCHW")
+      .Output("OutputNCHW")
+      .AddIntsArg("size", {1, 2})
+      .Finalize(net.NewOperatorDef());
+
+  // Run
+  net.RunOp();
+  net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output",
+                                                  NHWC);
+
+  // Check
+  auto expected = CreateTensor<float>({1, 1, 2, 3}, {0, 1, 2, 6, 7, 8});
+
+  ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5);
+}
+
+TEST_F(ResizeBicubicTest, CPUResizeBicubicWOAlignCornersFloat) {
+  testing::internal::LogToStderr();
+  // Construct graph
+  OpsTestNet net;
+
+  // Add input data
+  std::vector<float> input(48);
+  std::iota(begin(input), end(input), 0);
+  net.AddInputFromArray<DeviceType::CPU, float>("Input", {1, 4, 4, 3}, input);
+  net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW",
+                                                  NCHW);
+
+  OpDefBuilder("ResizeBicubic", "ResizeBicubicTest")
+      .Input("InputNCHW")
+      .Output("OutputNCHW")
+      .AddIntsArg("size", {2, 3})
+      .Finalize(net.NewOperatorDef());
+
+  // Run
+  net.RunOp();
+  net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output",
+                                                  NHWC);
+
+  // Check
+  auto expected = CreateTensor<float>({1, 2, 3, 3},
+    {0., 1., 2., 4.110297, 5.110297, 6.110297,
+     8.223037, 9.223036, 10.223037, 24., 25., 26.,
+     28.110298, 29.1103, 30.110298, 32.223038, 33.223038, 34.223038});
+
+  ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5);
+}
+
+TEST_F(ResizeBicubicTest, ResizeBicubicWAlignCorners) {
+  testing::internal::LogToStderr();
+  // Construct graph
+  OpsTestNet net;
+
+  // Add input data
+  std::vector<float> input(24);
+  std::iota(begin(input), end(input), 0);
+  net.AddInputFromArray<DeviceType::CPU, float>("Input", {1, 2, 4, 3}, input);
+  net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW",
+                                                  NCHW);
+
+  OpDefBuilder("ResizeBicubic", "ResizeBicubicTest")
+      .Input("InputNCHW")
+      .Output("OutputNCHW")
+      .AddIntArg("align_corners", 1)
+      .AddIntsArg("size", {1, 2})
+      .Finalize(net.NewOperatorDef());
+
+  // Run
+  net.RunOp();
+  net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output",
+                                                  NHWC);
+
+  // Check
+  auto expected = CreateTensor<float>({1, 1, 2, 3}, {0, 1, 2, 9, 10, 11});
+
+  ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5);
+}
+
+namespace {
+// Runs ResizeBicubic on random inputs on the CPU and, for D == GPU, checks
+// that the OpenCL result matches the CPU result.
+template <DeviceType D>
+void TestRandomResizeBicubic() {
+  testing::internal::LogToStderr();
+  static unsigned int seed = time(NULL);
+  for (int round = 0; round < 10; ++round) {
+    int batch = 1 + rand_r(&seed) % 5;
+    int channels = 1 + rand_r(&seed) % 100;
+    int height = 1 + rand_r(&seed) % 100;
+    int width = 1 + rand_r(&seed) % 100;
+    int in_height = 1 + rand_r(&seed) % 100;
+    int in_width = 1 + rand_r(&seed) % 100;
+    int align_corners = rand_r(&seed) % 2;
+
+    // Construct graph
+    OpsTestNet net;
+    // Add input data
+    net.AddRandomInput<D, float>("Input",
+                                 {batch, in_height, in_width, channels},
+                                 true, true);
+    net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW",
+                                                    NCHW);
+
+    OpDefBuilder("ResizeBicubic", "ResizeBicubicTest")
+        .Input("InputNCHW")
+        .Output("OutputNCHW")
+        .AddIntArg("align_corners", align_corners)
+        .AddIntsArg("size", {height, width})
+        .Finalize(net.NewOperatorDef());
+    // Run on CPU
+    net.RunOp(DeviceType::CPU);
+    net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW,
+                                                    "Output", NHWC);
+
+    Tensor expected;
+    expected.Copy(*net.GetOutput("Output"));
+
+    if (D == DeviceType::GPU) {
+      BufferToImage<D, float>(&net, "Input", "InputImage",
+                              kernels::BufferType::IN_OUT_CHANNEL);
+
+      OpDefBuilder("ResizeBicubic", "ResizeBicubicTest")
+          .Input("InputImage")
+          .Output("OutputImage")
+          .AddIntArg("align_corners", align_corners)
+          .AddIntsArg("size", {height, width})
+          .Finalize(net.NewOperatorDef());
+      // Run
+      net.RunOp(D);
+
+      ImageToBuffer<D, float>(&net, "OutputImage", "DeviceOutput",
+                              kernels::BufferType::IN_OUT_CHANNEL);
+      // Check
+      ExpectTensorNear<float>(expected, *net.GetOutput("DeviceOutput"), 1e-2,
+                              1e-2);
+    }
+  }
+}
+}  // namespace
+
+TEST_F(ResizeBicubicTest, OPENCLRandomResizeBicubic) {
+  TestRandomResizeBicubic<DeviceType::GPU>();
+}
+
+}  // namespace test
+}  // namespace ops
+}  // namespace mace
diff --git a/mace/python/tools/converter_tool/base_converter.py b/mace/python/tools/converter_tool/base_converter.py
index 7f873dda2b7ac68a5db162fb5e12073ce54638a2..e200d3ec3238b86143b8ced861085d282c3abbd9 100644
--- a/mace/python/tools/converter_tool/base_converter.py
+++ b/mace/python/tools/converter_tool/base_converter.py
@@ -101,6 +101,7 @@ MaceSupportedOps = [
     'Quantize',
     'ReduceMean',
     'Reshape',
+    'ResizeBicubic',
     'ResizeBilinear',
     'ScalarMath',
     'Slice',
diff --git a/mace/python/tools/converter_tool/tensorflow_converter.py b/mace/python/tools/converter_tool/tensorflow_converter.py
index 732abdcee5ef44ad5326556a5678ef2b21f9b852..640f983e96d5abae5faefcabe39808ed52be02e5 100644
--- a/mace/python/tools/converter_tool/tensorflow_converter.py
+++ b/mace/python/tools/converter_tool/tensorflow_converter.py
@@ -81,6 +81,7 @@ TFSupportedOps = [
     'Shape',
     'Transpose',
     'Softmax',
+    'ResizeBicubic',
     'ResizeBilinear',
     'Placeholder',
     'SpaceToBatchND',
@@ -181,6 +182,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
             TFOpType.Squeeze.name: self.convert_squeeze,
             TFOpType.Transpose.name: self.convert_transpose,
             TFOpType.Softmax.name: self.convert_softmax,
+            TFOpType.ResizeBicubic.name: self.convert_resize_bicubic,
             TFOpType.ResizeBilinear.name: self.convert_resize_bilinear,
             TFOpType.Placeholder.name: self.convert_nop,
             TFOpType.SpaceToBatchND.name: self.convert_space_batch,
@@ -537,6 +539,22 @@ class TensorflowConverter(base_converter.ConverterInterface):
         op = self.convert_general_op(tf_op)
         op.type = MaceOp.Softmax.name
 
+    def convert_resize_bicubic(self, tf_op):
+        """Convert a TF ResizeBicubic op: the evaluated size input becomes
+        the resize-size argument and align_corners is copied over."""
+        op = self.convert_general_op(tf_op)
+        op.type = MaceOp.ResizeBicubic.name
+        del op.input[1:]
+
+        size_arg = op.arg.add()
+        size_arg.name = MaceKeyword.mace_resize_size_str
+        size_value = tf_op.inputs[1].eval().astype(np.int32)
+        size_arg.ints.extend(size_value)
+        self._skip_tensor.add(tf_op.inputs[1].name)
+        align_corners_arg = op.arg.add()
+        align_corners_arg.name = MaceKeyword.mace_align_corners_str
+        align_corners_arg.i = tf_op.get_attr(tf_align_corners)
+
     def convert_resize_bilinear(self, tf_op):
         op = self.convert_general_op(tf_op)
         op.type = MaceOp.ResizeBilinear.name
diff --git a/repository/opencl-kernel/opencl_kernel_configure.bzl b/repository/opencl-kernel/opencl_kernel_configure.bzl
index 0d1b9cf0ca9e7e72d383e9cf593f95e1a60c66ae..e8e75634ba1219cc748427cb8c4d6f7ae34946d0 100644
--- a/repository/opencl-kernel/opencl_kernel_configure.bzl
+++ b/repository/opencl-kernel/opencl_kernel_configure.bzl
@@ -41,6 +41,7 @@ def _opencl_encrypt_kernel_impl(repository_ctx):
   unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/pad.cl"))
   unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/pooling.cl"))
   unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/reduce_mean.cl"))
+  unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/resize_bicubic.cl"))
   unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/resize_bilinear.cl"))
   unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/split.cl"))
   unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/softmax.cl"))