diff --git a/mace/kernels/opencl/cl/resize_bicubic.cl b/mace/kernels/opencl/cl/resize_bicubic.cl deleted file mode 100644 index 42213094f97054baaf5586bbdcacac10cd87ef25..0000000000000000000000000000000000000000 --- a/mace/kernels/opencl/cl/resize_bicubic.cl +++ /dev/null @@ -1,130 +0,0 @@ -#include -//#include - -const int kTableSize = (1 << 10); - -inline float ComputeCoeffs(int i) { - const float A = -0.75; - float x = (i / 2) * 1.0 / kTableSize; - if (i % 2 == 0){ - float coeff = ((A + 2) * x - (A + 3)) * x * x + 1; - return coeff; - } - else { - x += 1.0; - float coeff = ((A * x - 5 * A) * x + 8 * A) * x - 4 * A; - return coeff; - } -} - -#define BOUND(val, limit) min(limit - 1, max(0, val)) - -__kernel void resize_bicubic_nocache(KERNEL_ERROR_PARAMS - GLOBAL_WORK_GROUP_SIZE_DIM3 - __read_only image2d_t input, - __write_only image2d_t output, - __private const float height_scale, - __private const float width_scale, - __private const int in_height, - __private const int in_width, - __private const int out_height) { - - const int ch_blk = get_global_id(0); - const int w = get_global_id(1); - const int hb = get_global_id(2); - -#ifndef NON_UNIFORM_WORK_GROUP - if (ch_blk >= global_size_dim0 || w >= global_size_dim1 - || hb >= global_size_dim2) { - return; - } - const int ch_blks = global_size_dim0; - const int out_width = global_size_dim1; -#else - const int ch_blks = get_global_size(0); - const int out_width = get_global_size(1); -#endif - - const int b = hb / out_height; - const int h = hb % out_height; - - const float h_in = h * height_scale; - const float w_in = w * width_scale; - - const int in_w_offset = mul24(ch_blk, in_width); - const int in_h_offset = mul24(b, in_height); - - const int h_in_loc = height_scale * h; - const float h_delta = height_scale * h - h_in_loc; - const int h_offset = h_delta * kTableSize + 0.5; - - const int w_in_loc = width_scale * w; - const float w_delta = width_scale * w - w_in_loc; - const int w_offset = w_delta * kTableSize + 0.5; - - float4 y_weights = {ComputeCoeffs(h_offset * 2 + 1), - ComputeCoeffs(h_offset * 2), - ComputeCoeffs((kTableSize - h_offset) * 2), - ComputeCoeffs((kTableSize - h_offset) * 2 + 1)}; - int4 y_indices = {BOUND(h_in_loc - 1, in_height), - BOUND(h_in_loc, in_height), - BOUND(h_in_loc + 1, in_height), - BOUND(h_in_loc + 2, in_height)}; - float4 x_weights = {ComputeCoeffs(w_offset * 2 + 1), - ComputeCoeffs(w_offset * 2), - ComputeCoeffs((kTableSize - w_offset) * 2), - ComputeCoeffs((kTableSize - w_offset) * 2 + 1)}; - int4 x_indices = {BOUND(w_in_loc - 1, in_width), - BOUND(w_in_loc, in_width), - BOUND(w_in_loc + 1, in_width), - BOUND(w_in_loc + 2, in_width)}; - - float4 coeffs0 = {0, 0, 0, 0}; - float4 coeffs1 = {0, 0, 0, 0}; - float4 coeffs2 = {0, 0, 0, 0}; - float4 coeffs3 = {0, 0, 0, 0}; - for (int i = 0; i < 4; ++i) { - int y_index = y_indices.s0; - if ( i == 1 ) { y_index = y_indices.s1; } - if ( i == 2 ) { y_index = y_indices.s2; } - if ( i == 3 ) { y_index = y_indices.s3; } - DATA_TYPE4 data0 = READ_IMAGET(input, SAMPLER, - (int2)(in_w_offset + x_indices.s0, in_h_offset + y_index)); - DATA_TYPE4 data1 = READ_IMAGET(input, SAMPLER, - (int2)(in_w_offset + x_indices.s1, in_h_offset + y_index)); - DATA_TYPE4 data2 = READ_IMAGET(input, SAMPLER, - (int2)(in_w_offset + x_indices.s2, in_h_offset + y_index)); - DATA_TYPE4 data3 = READ_IMAGET(input, SAMPLER, - (int2)(in_w_offset + x_indices.s3, in_h_offset + y_index)); - - float4 xw0 = { x_weights.s0, x_weights.s0, x_weights.s0, x_weights.s0 }; - float4 xw1 = { x_weights.s1, x_weights.s1, x_weights.s1, x_weights.s1 }; - float4 xw2 = { x_weights.s2, x_weights.s2, x_weights.s2, x_weights.s2 }; - float4 xw3 = { x_weights.s3, x_weights.s3, x_weights.s3, x_weights.s3 }; - float4 res = { 0, 0, 0, 0 }; - res = mad(xw0, data0, res); - res = mad(xw1, data1, res); - res = mad(xw2, data2, res); - res = mad(xw3, data3, res); - if ( i == 0 ) { coeffs0 = res; } - if ( i == 1 ) { coeffs1 = res; } - if ( i == 2 ) { coeffs2 = res; } - if ( i == 3 ) { coeffs3 = res; } - } - float4 yw0 = { y_weights.s0, y_weights.s0, y_weights.s0, y_weights.s0 }; - float4 yw1 = { y_weights.s1, y_weights.s1, y_weights.s1, y_weights.s1 }; - float4 yw2 = { y_weights.s2, y_weights.s2, y_weights.s2, y_weights.s2 }; - float4 yw3 = { y_weights.s3, y_weights.s3, y_weights.s3, y_weights.s3 }; - DATA_TYPE4 outdata = { 0, 0, 0, 0 }; - outdata = mad(yw0, coeffs0, outdata); - outdata = mad(yw1, coeffs1, outdata); - outdata = mad(yw2, coeffs2, outdata); - outdata = mad(yw3, coeffs3, outdata); - const int out_w_offset = mul24(ch_blk, out_width); - const int out_h_offset = mul24(b, out_height); - - WRITE_IMAGET(output, (int2)(out_w_offset + w, out_h_offset + h), outdata); -} - - - diff --git a/mace/kernels/opencl/resize_bicubic.cc b/mace/kernels/opencl/resize_bicubic.cc deleted file mode 100644 index 2b043794f26ade30c7e22b07d67b56570514d389..0000000000000000000000000000000000000000 --- a/mace/kernels/opencl/resize_bicubic.cc +++ /dev/null @@ -1,132 +0,0 @@ -// Copyright 2018 Xiaomi, Inc. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "mace/kernels/resize_bicubic.h" -#include "mace/core/runtime/opencl/opencl_runtime.h" -#include "mace/core/tensor.h" -#include "mace/kernels/opencl/helper.h" -#include "mace/utils/tuner.h" -#include "mace/utils/utils.h" - -namespace mace { -namespace kernels { - -namespace { -std::vector LocalWS(const uint32_t *gws, const uint32_t kwg_size) { - std::vector lws(4, 0); - uint64_t cache_size = OpenCLRuntime::Global()->device_global_mem_cache_size(); - uint32_t base = std::max(cache_size / kBaseGPUMemCacheSize, 1); - lws[1] = std::min(gws[1], kwg_size); - if (lws[1] >= base) { - lws[0] = std::min(gws[0], base); - } else { - lws[0] = gws[0] / 8; - if (lws[0] == 0) { - lws[0] = gws[0]; - } - } - lws[0] = std::min(lws[0], kwg_size / lws[1]); - const uint32_t lws_size = lws[0] * lws[1]; - lws[2] = gws[2] / 8; - if (lws[2] == 0) { - lws[2] = gws[2]; - } - lws[2] = std::max(std::min(lws[2], kwg_size / lws_size), - 1); - return lws; -} - -} // namespace - -template -MaceStatus ResizeBicubicFunctor::operator()( - const Tensor *input, Tensor *output, StatsFuture *future) { - const index_t batch = input->dim(0); - const index_t in_height = input->dim(1); - const index_t in_width = input->dim(2); - const index_t channels = input->dim(3); - - const index_t channel_blocks = RoundUpDiv4(channels); - const index_t out_height = out_height_; - const index_t out_width = out_width_; - - const uint32_t gws[3] = {static_cast(channel_blocks), - static_cast(out_width), - static_cast(out_height * batch)}; - - auto runtime = OpenCLRuntime::Global(); - - if (kernel_.get() == nullptr) { - std::set built_options; - OUT_OF_RANGE_CONFIG(kernel_error_); - NON_UNIFORM_WG_CONFIG; - std::string kernel_name = MACE_OBFUSCATE_SYMBOL("resize_bicubic_nocache"); - built_options.emplace("-Dresize_bicubic_nocache=" + kernel_name); - auto dt = DataTypeToEnum::value; - built_options.emplace("-DDATA_TYPE=" + DtToUpCompatibleCLDt(dt)); - built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpCompatibleCLCMDDt(dt)); - MACE_RETURN_IF_ERROR( - runtime->BuildKernel("resize_bicubic", - kernel_name, - built_options, - &kernel_)); - - kwg_size_ = - static_cast(runtime->GetKernelMaxWorkGroupSize(kernel_)); - } - - if (!IsVecEqual(input_shape_, input->shape())) { - MACE_CHECK(out_height > 0 && out_width > 0); - std::vector output_shape{batch, out_height, out_width, channels}; - - std::vector output_image_shape; - CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL, - &output_image_shape); - MACE_RETURN_IF_ERROR(output->ResizeImage(output_shape, output_image_shape)); - - float height_scale = - CalculateResizeScale(in_height, out_height, align_corners_); - float width_scale = - CalculateResizeScale(in_width, out_width, align_corners_); - - uint32_t idx = 0; - OUT_OF_RANGE_SET_ARG; - SET_3D_GWS_ARGS(kernel_); - kernel_.setArg(idx++, *(input->opencl_image())); - kernel_.setArg(idx++, *(output->opencl_image())); - kernel_.setArg(idx++, height_scale); - kernel_.setArg(idx++, width_scale); - kernel_.setArg(idx++, static_cast(in_height)); - kernel_.setArg(idx++, static_cast(in_width)); - kernel_.setArg(idx++, static_cast(out_height)); - - input_shape_ = input->shape(); - } - - const std::vector lws = LocalWS(gws, kwg_size_); - std::string tuning_key = - Concat("resize_bicubic_opencl_kernel", output->dim(0), output->dim(1), - output->dim(2), output->dim(3)); - MACE_RETURN_IF_ERROR(TuningOrRun3DKernel(kernel_, tuning_key, - gws, lws, future)); - - OUT_OF_RANGE_VALIDATION(kernel_error_); - return MACE_SUCCESS; -} - -template struct ResizeBicubicFunctor; -template struct ResizeBicubicFunctor; - -} // namespace kernels -} // namespace mace diff --git a/mace/kernels/resize_bicubic.h b/mace/kernels/resize_bicubic.h deleted file mode 100644 index 4bef1c90d433661153eb1735ef4670a41dc3455d..0000000000000000000000000000000000000000 --- a/mace/kernels/resize_bicubic.h +++ /dev/null @@ -1,224 +0,0 @@ -// Copyright 2018 Xiaomi, Inc. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef MACE_KERNELS_RESIZE_BICUBIC_H_ -#define MACE_KERNELS_RESIZE_BICUBIC_H_ - -#include -#include -#include - -#include "mace/core/future.h" -#include "mace/core/tensor.h" -#include "mace/utils/logging.h" - -#ifdef MACE_ENABLE_OPENCL -#include "mace/core/runtime/opencl/cl2_header.h" -#endif // MACE_ENABLE_OPENCL - -namespace mace { -namespace kernels { - -static const int64_t kTableSize = (1 << 10); - -inline const float* InitCoeffsTable() { - // Allocate and initialize coefficients table using Bicubic - // convolution algorithm. - // https://en.wikipedia.org/wiki/Bicubic_interpolation - float* coeffs_tab = new float[(kTableSize + 1) * 2]; - static const double A = -0.75; - for (int i = 0; i <= kTableSize; ++i) { - float x = i * 1.0 / kTableSize; - coeffs_tab[i * 2] = ((A + 2) * x - (A + 3)) * x * x + 1; - x += 1.0; - coeffs_tab[i * 2 + 1] = ((A * x - 5 * A) * x + 8 * A) * x - 4 * A; - } - return coeffs_tab; -} - -inline const float* GetCoeffsTable() { - // Static so that we initialize it on first use - static const float* coeffs_tab = InitCoeffsTable(); - return coeffs_tab; -} - -inline int64_t Bound(int64_t val, int64_t limit) { - return std::min(limit - 1ll, std::max(0ll, val)); -} - -inline void GetWeightsAndIndices(float scale, int64_t out_loc, int64_t limit, - std::array* weights, - std::array* indices) { - const int64_t in_loc = scale * out_loc; - const float delta = scale * out_loc - in_loc; - const int64_t offset = lrintf(delta * kTableSize); - const float* coeffs_tab = GetCoeffsTable(); - *weights = {{coeffs_tab[offset * 2 + 1], coeffs_tab[offset * 2], - coeffs_tab[(kTableSize - offset) * 2], - coeffs_tab[(kTableSize - offset) * 2 + 1]}}; - *indices = {{Bound(in_loc - 1, limit), Bound(in_loc, limit), - Bound(in_loc + 1, limit), Bound(in_loc + 2, limit)}}; -} - -inline float Interpolate1D(const std::array& weights, - const std::array& values) { - return values[0] * weights[0] + values[1] * weights[1] + - values[2] * weights[2] + values[3] * weights[3]; -} - -inline float CalculateResizeScale(index_t in_size, - index_t out_size, - bool align_corners) { - return (align_corners && out_size > 1) - ? (in_size - 1) / static_cast(out_size - 1) - : in_size / static_cast(out_size); -} - -inline void ResizeImage(const float *images, - const index_t batch_size, - const index_t in_height, - const index_t in_width, - const index_t out_height, - const index_t out_width, - const index_t channels, - const float height_scale, - const float width_scale, - float *output) { - std::array coeff = {{0.0, 0.0, 0.0, 0.0}}; -#pragma omp parallel for collapse(2) - for (index_t b = 0; b < batch_size; ++b) { - for (index_t y = 0; y < out_height; ++y) { - std::array y_weights; - std::array y_indices; - GetWeightsAndIndices(height_scale, y, in_height, &y_weights, - &y_indices); - std::stringstream ss; - for (index_t x = 0; x < out_width; ++x) { - std::array x_weights; - std::array x_indices; - GetWeightsAndIndices(width_scale, x, in_width, &x_weights, - &x_indices); - - for (index_t c = 0; c < channels; ++c) { - // Use a 4x4 patch to compute the interpolated output value at - // (b, y, x, c). - const float *channel_input_ptr = - images + (b * channels + c) * in_height * in_width; - float *channel_output_ptr = - output + (b * channels + c) * out_height * out_width; - for (index_t i = 0; i < 4; ++i) { - const std::array values = { - {static_cast(channel_input_ptr - [y_indices[i] * in_width + x_indices[0]]), - static_cast(channel_input_ptr - [y_indices[i] * in_width + x_indices[1]]), - static_cast(channel_input_ptr - [y_indices[i] * in_width + x_indices[2]]), - static_cast(channel_input_ptr - [y_indices[i] * in_width + x_indices[3]])}}; - coeff[i] = Interpolate1D(x_weights, values); - } - channel_output_ptr[y * out_width + x] = - Interpolate1D(y_weights, coeff); - } - } - } - } -} - -struct ResizeBicubicFunctorBase { - ResizeBicubicFunctorBase(const std::vector &size, - bool align_corners) - : align_corners_(align_corners) { - MACE_CHECK(size.size() == 2); - out_height_ = size[0]; - out_width_ = size[1]; - } - - protected: - bool align_corners_; - index_t out_height_; - index_t out_width_; -}; - -template -struct ResizeBicubicFunctor; - -template<> -struct ResizeBicubicFunctor - : ResizeBicubicFunctorBase { - ResizeBicubicFunctor(const std::vector &size, bool align_corners) - : ResizeBicubicFunctorBase(size, align_corners) {} - - MaceStatus operator()(const Tensor *input, - Tensor *output, - StatsFuture *future) { - MACE_UNUSED(future); - const index_t batch = input->dim(0); - const index_t channels = input->dim(1); - const index_t in_height = input->dim(2); - const index_t in_width = input->dim(3); - - index_t out_height = out_height_; - index_t out_width = out_width_; - MACE_CHECK(out_height > 0 && out_width > 0); - std::vector out_shape{batch, channels, out_height, out_width}; - MACE_RETURN_IF_ERROR(output->Resize(out_shape)); - - Tensor::MappingGuard input_mapper(input); - Tensor::MappingGuard output_mapper(output); - const float *input_data = input->data(); - float *output_data = output->mutable_data(); - - if (out_height == in_height && out_width == in_width) { - std::copy(input_data, - input_data + batch * channels * in_height * in_width, - output_data); - return MACE_SUCCESS; - } - - float height_scale = - CalculateResizeScale(in_height, out_height, align_corners_); - float width_scale = - CalculateResizeScale(in_width, out_width, align_corners_); - - ResizeImage(input_data, batch, in_height, in_width, out_height, out_width, - channels, height_scale, width_scale, output_data); - - return MACE_SUCCESS; - } -}; - -#ifdef MACE_ENABLE_OPENCL -template -struct ResizeBicubicFunctor - : ResizeBicubicFunctorBase { - ResizeBicubicFunctor(const std::vector &size, bool align_corners) - : ResizeBicubicFunctorBase(size, align_corners) {} - - MaceStatus operator()(const Tensor *input, - Tensor *output, - StatsFuture *future); - - cl::Kernel kernel_; - uint32_t kwg_size_; - std::unique_ptr kernel_error_; - std::vector input_shape_; -}; -#endif // MACE_ENABLE_OPENCL - -} // namespace kernels -} // namespace mace - -#endif // MACE_KERNELS_RESIZE_BICUBIC_H_ diff --git a/mace/ops/ops_register.cc b/mace/ops/ops_register.cc index e881083500771f20b87353044d5aa62434017bd4..c318eb4417165ecfe6a2aa49339af3fa98093964 100644 --- a/mace/ops/ops_register.cc +++ b/mace/ops/ops_register.cc @@ -47,7 +47,6 @@ extern void Register_Proposal(OperatorRegistryBase *op_registry); extern void Register_Quantize(OperatorRegistryBase *op_registry); extern void Register_ReduceMean(OperatorRegistryBase *op_registry); extern void Register_Reshape(OperatorRegistryBase *op_registry); -extern void Register_ResizeBicubic(OperatorRegistryBase *op_registry); extern void Register_ResizeBilinear(OperatorRegistryBase *op_registry); extern void Register_ScalarMath(OperatorRegistryBase *op_registry); extern void Register_Shape(OperatorRegistryBase *op_registry); @@ -100,7 +99,6 @@ OperatorRegistry::OperatorRegistry() : OperatorRegistryBase() { ops::Register_Quantize(this); ops::Register_ReduceMean(this); ops::Register_Reshape(this); - ops::Register_ResizeBicubic(this); ops::Register_ResizeBilinear(this); ops::Register_ScalarMath(this); ops::Register_Shape(this); diff --git a/mace/ops/resize_bicubic.cc b/mace/ops/resize_bicubic.cc deleted file mode 100644 index 7a50522f54ed1005bb8c12c138a7158f4034e496..0000000000000000000000000000000000000000 --- a/mace/ops/resize_bicubic.cc +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright 2018 Xiaomi, Inc. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "mace/ops/resize_bicubic.h" - -namespace mace { -namespace ops { - -void Register_ResizeBicubic(OperatorRegistryBase *op_registry) { - MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("ResizeBicubic") - .Device(DeviceType::CPU) - .TypeConstraint("T") - .Build(), - ResizeBicubicOp); - -#ifdef MACE_ENABLE_OPENCL - MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("ResizeBicubic") - .Device(DeviceType::GPU) - .TypeConstraint("T") - .Build(), - ResizeBicubicOp); - - MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("ResizeBicubic") - .Device(DeviceType::GPU) - .TypeConstraint("T") - .Build(), - ResizeBicubicOp); -#endif // MACE_ENABLE_OPENCL -} - -} // namespace ops -} // namespace mace diff --git a/mace/ops/resize_bicubic.h b/mace/ops/resize_bicubic.h deleted file mode 100644 index a83f3a310afc02ca3abd474b4481e16470f28953..0000000000000000000000000000000000000000 --- a/mace/ops/resize_bicubic.h +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright 2018 Xiaomi, Inc. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef MACE_OPS_RESIZE_BICUBIC_H_ -#define MACE_OPS_RESIZE_BICUBIC_H_ - -#include "mace/core/operator.h" -#include "mace/kernels/resize_bicubic.h" - -namespace mace { -namespace ops { - -template -class ResizeBicubicOp : public Operator { - public: - ResizeBicubicOp(const OperatorDef &operator_def, Workspace *ws) - : Operator(operator_def, ws), - functor_(OperatorBase::GetRepeatedArgs("size", {-1, -1}), - OperatorBase::GetOptionalArg("align_corners", false)) {} - - MaceStatus Run(StatsFuture *future) override { - const Tensor *input = this->Input(0); - Tensor *output = this->Output(0); - - MACE_CHECK(input->dim_size() == 4, "input must be 4-dimensional.", - input->dim_size()); - - return functor_(input, output, future); - } - - private: - kernels::ResizeBicubicFunctor functor_; -}; - -} // namespace ops -} // namespace mace - -#endif // MACE_OPS_RESIZE_BICUBIC_H_ - diff --git a/mace/ops/resize_bicubic_benchmark.cc b/mace/ops/resize_bicubic_benchmark.cc deleted file mode 100644 index ba22f4fecdf49267f9f845a0879fe1f38e7faa0f..0000000000000000000000000000000000000000 --- a/mace/ops/resize_bicubic_benchmark.cc +++ /dev/null @@ -1,115 +0,0 @@ -// Copyright 2018 Xiaomi, Inc. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include "mace/core/operator.h" -#include "mace/core/testing/test_benchmark.h" -#include "mace/ops/ops_test_util.h" - -namespace mace { -namespace ops { -namespace test { - -namespace { -template -void ResizeBicubicBenchmark(int iters, - int batch, - int channels, - int input_height, - int input_width, - int output_height, - int output_width) { - mace::testing::StopTiming(); - - OpsTestNet net; - - // Add input data - if (D == DeviceType::CPU) { - net.AddRandomInput("Input", - {batch, channels, input_height, input_width}); - } else if (D == DeviceType::GPU) { - net.AddRandomInput("Input", - {batch, input_height, input_width, channels}); - } else { - MACE_NOT_IMPLEMENTED; - } - net.AddInputFromArray("OutSize", {2}, - {output_height, output_width}); - - if (D == DeviceType::CPU) { - OpDefBuilder("ResizeBicubic", "ResizeBicubicBenchmark") - .Input("Input") - .Input("OutSize") - .Output("Output") - .AddIntsArg("size", {output_height, output_width}) - .AddIntArg("T", static_cast(DataTypeToEnum::value)) - .Finalize(net.NewOperatorDef()); - } else if (D == DeviceType::GPU) { - BufferToImage(&net, "Input", "InputImage", - kernels::BufferType::IN_OUT_CHANNEL); - OpDefBuilder("ResizeBicubic", "ResizeBicubicBenchmark") - .Input("InputImage") - .Input("OutSize") - .Output("OutputImage") - .AddIntsArg("size", {output_height, output_width}) - .AddIntArg("T", static_cast(DataTypeToEnum::value)) - .Finalize(net.NewOperatorDef()); - } else { - MACE_NOT_IMPLEMENTED; - } - - // Warm-up - for (int i = 0; i < 5; ++i) { - net.RunOp(D); - } - - mace::testing::StartTiming(); - while (iters--) { - net.RunOp(D); - } - net.Sync(); -} -} // namespace - -#define MACE_BM_RESIZE_BICUBIC_MACRO(N, C, H0, W0, H1, W1, TYPE, DEVICE) \ - static void \ - MACE_BM_RESIZE_BICUBIC_##N##_##C##_##H0##_##W0##_##H1##_##W1##_##TYPE##_\ - ##DEVICE( \ - int iters) { \ - const int64_t macc = static_cast(iters) * N * C * H1 * W1 * 3; \ - const int64_t tot = static_cast(iters) * N * C * H0 * W0; \ - mace::testing::MaccProcessed(macc); \ - mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \ - ResizeBicubicBenchmark(iters, N, C, H0, W0, H1, W1); \ - } \ - MACE_BENCHMARK( \ - MACE_BM_RESIZE_BICUBIC_##N##_##C##_##H0##_##W0##_##H1##_##W1##_##TYPE##_\ - ##DEVICE) - -#define MACE_BM_RESIZE_BICUBIC(N, C, H0, W0, H1, W1) \ - MACE_BM_RESIZE_BICUBIC_MACRO(N, C, H0, W0, H1, W1, float, CPU); \ - MACE_BM_RESIZE_BICUBIC_MACRO(N, C, H0, W0, H1, W1, float, GPU); \ - MACE_BM_RESIZE_BICUBIC_MACRO(N, C, H0, W0, H1, W1, half, GPU); - -MACE_BM_RESIZE_BICUBIC(1, 128, 120, 120, 480, 480); -MACE_BM_RESIZE_BICUBIC(1, 256, 7, 7, 15, 15); -MACE_BM_RESIZE_BICUBIC(1, 256, 15, 15, 30, 30); -MACE_BM_RESIZE_BICUBIC(1, 128, 30, 30, 60, 60); -MACE_BM_RESIZE_BICUBIC(1, 128, 240, 240, 480, 480); -MACE_BM_RESIZE_BICUBIC(1, 3, 4032, 3016, 480, 480); -MACE_BM_RESIZE_BICUBIC(1, 3, 480, 480, 4032, 3016); - -} // namespace test -} // namespace ops -} // namespace mace diff --git a/mace/ops/resize_bicubic_test.cc b/mace/ops/resize_bicubic_test.cc deleted file mode 100644 index a834d98f387fa983fb056c3e0c61f550ab0595b4..0000000000000000000000000000000000000000 --- a/mace/ops/resize_bicubic_test.cc +++ /dev/null @@ -1,183 +0,0 @@ -// Copyright 2018 Xiaomi, Inc. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include - -#include "mace/core/operator.h" -#include "mace/ops/ops_test_util.h" -#include "mace/ops/resize_bicubic.h" - -namespace mace { -namespace ops { -namespace test { - -class ResizeBicubicTest : public OpsTestBase {}; - -TEST_F(ResizeBicubicTest, CPUResizeBicubicWOAlignCorners) { - testing::internal::LogToStderr(); - // Construct graph - OpsTestNet net; - - // Add input data - std::vector input(24); - std::iota(begin(input), end(input), 0); - net.AddInputFromArray("Input", {1, 2, 4, 3}, input); - net.TransformDataFormat("Input", NHWC, "InputNCHW", - NCHW); - - OpDefBuilder("ResizeBicubic", "ResizeBicubicTest") - .Input("InputNCHW") - .Output("OutputNCHW") - .AddIntsArg("size", {1, 2}) - .Finalize(net.NewOperatorDef()); - - // Run - net.RunOp(); - net.TransformDataFormat("OutputNCHW", NCHW, "Output", - NHWC); - - // Check - auto expected = CreateTensor({1, 1, 2, 3}, {0, 1, 2, 6, 7, 8}); - - ExpectTensorNear(*expected, *net.GetOutput("Output"), 1e-5); -} - -TEST_F(ResizeBicubicTest, CPUResizeBicubicWOAlignCorners1) { - testing::internal::LogToStderr(); - // Construct graph - OpsTestNet net; - - // Add input data - std::vector input(48); - std::iota(begin(input), end(input), 0); - net.AddInputFromArray("Input", {1, 4, 4, 3}, input); - net.TransformDataFormat("Input", NHWC, "InputNCHW", - NCHW); - - OpDefBuilder("ResizeBicubic", "ResizeBicubicTest") - .Input("InputNCHW") - .Output("OutputNCHW") - .AddIntsArg("size", {2, 3}) - .Finalize(net.NewOperatorDef()); - - // Run - net.RunOp(); - net.TransformDataFormat("OutputNCHW", NCHW, "Output", - NHWC); - - // Check - auto expected = CreateTensor({1, 2, 3, 3}, - {0., 1., 2., 4.110297, 5.110297, 6.110297, - 8.223037, 9.223036, 10.223037, 24., 25., 26., - 28.110298, 29.1103, 30.110298, 32.223038, 33.223038, 34.223038}); - - ExpectTensorNear(*expected, *net.GetOutput("Output"), 1e-5); -} - -TEST_F(ResizeBicubicTest, ResizeBicubicWAlignCorners) { - testing::internal::LogToStderr(); - // Construct graph - OpsTestNet net; - - // Add input data - std::vector input(24); - std::iota(begin(input), end(input), 0); - net.AddInputFromArray("Input", {1, 2, 4, 3}, input); - net.TransformDataFormat("Input", NHWC, "InputNCHW", - NCHW); - - OpDefBuilder("ResizeBicubic", "ResizeBicubicTest") - .Input("InputNCHW") - .Output("OutputNCHW") - .AddIntArg("align_corners", 1) - .AddIntsArg("size", {1, 2}) - .Finalize(net.NewOperatorDef()); - - // Run - net.RunOp(); - net.TransformDataFormat("OutputNCHW", NCHW, "Output", - NHWC); - - // Check - auto expected = CreateTensor({1, 1, 2, 3}, {0, 1, 2, 9, 10, 11}); - - ExpectTensorNear(*expected, *net.GetOutput("Output"), 1e-5); -} - -namespace { -template -void TestRandomResizeBicubic() { - testing::internal::LogToStderr(); - static unsigned int seed = time(NULL); - for (int round = 0; round < 10; ++round) { - int batch = 1 + rand_r(&seed) % 5; - int channels = 1 + rand_r(&seed) % 100; - int height = 1 + rand_r(&seed) % 100; - int width = 1 + rand_r(&seed) % 100; - int in_height = 1 + rand_r(&seed) % 100; - int in_width = 1 + rand_r(&seed) % 100; - int align_corners = rand_r(&seed) % 1; - - // Construct graph - OpsTestNet net; - // Add input data - net.AddRandomInput("Input", - {batch, in_height, in_width, channels}); - net.TransformDataFormat("Input", NHWC, "InputNCHW", - NCHW); - - OpDefBuilder("ResizeBicubic", "ResizeBicubicTest") - .Input("InputNCHW") - .Output("OutputNCHW") - .AddIntArg("align_corners", align_corners) - .AddIntsArg("size", {height, width}) - .Finalize(net.NewOperatorDef()); - // Run on CPU - net.RunOp(DeviceType::CPU); - net.TransformDataFormat("OutputNCHW", NCHW, - "Output", NHWC); - - Tensor expected; - expected.Copy(*net.GetOutput("Output")); - - if (D == DeviceType::GPU) { - BufferToImage(&net, "Input", "InputImage", - kernels::BufferType::IN_OUT_CHANNEL); - - OpDefBuilder("ResizeBicubic", "ResizeBicubicTest") - .Input("InputImage") - .Output("OutputImage") - .AddIntArg("align_corners", align_corners) - .AddIntsArg("size", {height, width}) - .Finalize(net.NewOperatorDef()); - // Run - net.RunOp(D); - - ImageToBuffer(&net, "OutputImage", "DeviceOutput", - kernels::BufferType::IN_OUT_CHANNEL); - } - // Check - ExpectTensorNear(expected, *net.GetOutput("DeviceOutput"), 1e-5, - 1e-4); - } -} -} // namespace - -TEST_F(ResizeBicubicTest, OPENCLRandomResizeBicubic) { - TestRandomResizeBicubic(); -} - -} // namespace test -} // namespace ops -} // namespace mace diff --git a/mace/python/tools/converter_tool/base_converter.py b/mace/python/tools/converter_tool/base_converter.py index c25fc0ed795fa2e060c491c5e5d2a8d1521efb17..99fac06f63600bb58d350f0be857e03ef83932f6 100644 --- a/mace/python/tools/converter_tool/base_converter.py +++ b/mace/python/tools/converter_tool/base_converter.py @@ -100,7 +100,6 @@ MaceSupportedOps = [ 'Quantize', 'ReduceMean', 'Reshape', - 'ResizeBicubic', 'ResizeBilinear', 'ScalarMath', 'Slice', diff --git a/mace/python/tools/converter_tool/tensorflow_converter.py b/mace/python/tools/converter_tool/tensorflow_converter.py index 13458769110021d09463894ff44640a7fcd18b1b..da9384fa93affdca98ad05983055e35a6771b0e3 100644 --- a/mace/python/tools/converter_tool/tensorflow_converter.py +++ b/mace/python/tools/converter_tool/tensorflow_converter.py @@ -80,7 +80,6 @@ TFSupportedOps = [ 'Shape', 'Transpose', 'Softmax', - 'ResizeBicubic', 'ResizeBilinear', 'Placeholder', 'SpaceToBatchND', @@ -179,7 +178,6 @@ class TensorflowConverter(base_converter.ConverterInterface): TFOpType.Squeeze.name: self.convert_squeeze, TFOpType.Transpose.name: self.convert_transpose, TFOpType.Softmax.name: self.convert_softmax, - TFOpType.ResizeBicubic.name: self.convert_resize_bicubic, TFOpType.ResizeBilinear.name: self.convert_resize_bilinear, TFOpType.Placeholder.name: self.convert_nop, TFOpType.SpaceToBatchND.name: self.convert_space_batch, @@ -536,20 +534,6 @@ class TensorflowConverter(base_converter.ConverterInterface): op = self.convert_general_op(tf_op) op.type = MaceOp.Softmax.name - def convert_resize_bicubic(self, tf_op): - op = self.convert_general_op(tf_op) - op.type = MaceOp.ResizeBicubic.name - del op.input[1:] - - size_arg = op.arg.add() - size_arg.name = MaceKeyword.mace_resize_size_str - size_value = tf_op.inputs[1].eval().astype(np.int32) - size_arg.ints.extend(size_value) - self._skip_tensor.add(tf_op.inputs[1].name) - align_corners_arg = op.arg.add() - align_corners_arg.name = MaceKeyword.mace_align_corners_str - align_corners_arg.i = tf_op.get_attr(tf_align_corners) - def convert_resize_bilinear(self, tf_op): op = self.convert_general_op(tf_op) op.type = MaceOp.ResizeBilinear.name diff --git a/repository/opencl-kernel/opencl_kernel_configure.bzl b/repository/opencl-kernel/opencl_kernel_configure.bzl index e8e75634ba1219cc748427cb8c4d6f7ae34946d0..0d1b9cf0ca9e7e72d383e9cf593f95e1a60c66ae 100644 --- a/repository/opencl-kernel/opencl_kernel_configure.bzl +++ b/repository/opencl-kernel/opencl_kernel_configure.bzl @@ -41,7 +41,6 @@ def _opencl_encrypt_kernel_impl(repository_ctx): unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/pad.cl")) unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/pooling.cl")) unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/reduce_mean.cl")) - unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/resize_bicubic.cl")) unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/resize_bilinear.cl")) unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/split.cl")) unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/softmax.cl"))