提交 d0014e47 编写于 作者: 刘托 提交者: 赵奇可

Merge branch 'resizebicubic' into 'master'

Resizebicubic

See merge request !766
#include <common.h>
// Bicubic weight for a tap that lies within one pixel of the sample point.
// `i` is a sub-pixel offset expressed in table steps; it is normalised by
// TABLE_SIZE and fed to the polynomial (1.25 t - 2.25) t^2 + 1.
inline float coeff_even(float i) {
  const float t = i / TABLE_SIZE;
  return (1.25f * t - 2.25f) * t * t + 1.0f;
}
// Bicubic weight for a tap that lies between one and two pixels from the
// sample point. `i` is a sub-pixel offset in table steps; the polynomial
// ((-0.75 t + 3.75) t - 6) t + 3 is evaluated at t = i / TABLE_SIZE + 1.
inline float coeff_odd(float i) {
  const float t = i / TABLE_SIZE + 1.0f;
  return ((-0.75f * t + 3.75f) * t - 6.0f) * t + 3.0f;
}
// Bicubic resize of an image-backed tensor. The interpolation weights are
// evaluated per work-item with coeff_even()/coeff_odd() instead of being read
// from a table.
//
// Work-item layout (see the get_global_id calls below):
//   dim 0 -> channel block, dim 1 -> output column,
//   dim 2 -> flattened (batch, output row) index.
// Image addressing used below: x = channel_block * width + column,
//                              y = batch * height + row.
__kernel void resize_bicubic_nocache(KERNEL_ERROR_PARAMS
GLOBAL_WORK_GROUP_SIZE_DIM3
__read_only image2d_t input,
__write_only image2d_t output,
__private const float height_scale,
__private const float width_scale,
__private const int in_height,
__private const int in_width,
__private const int out_height) {
const int ch_blk = get_global_id(0);
const int w = get_global_id(1);
const int hb = get_global_id(2);
#ifndef NON_UNIFORM_WORK_GROUP
// Work-items outside the logical global size have nothing to do.
if (ch_blk >= global_size_dim0 || w >= global_size_dim1
|| hb >= global_size_dim2) {
return;
}
const int ch_blks = global_size_dim0;
const int out_width = global_size_dim1;
#else
const int ch_blks = get_global_size(0);
const int out_width = get_global_size(1);
#endif
// Split the flattened dim-2 index into batch and output row.
const int b = hb / out_height;
const int h = hb % out_height;
// Source position (fractional) of this output pixel.
const float h_in = h * height_scale;
const float w_in = w * width_scale;
const int in_w_offset = mul24(ch_blk, in_width);
const int in_h_offset = mul24(b, in_height);
// Integer source position plus the fractional part quantised to the nearest
// of TABLE_SIZE steps (the + 0.5f rounds before the int truncation).
const int h_in_loc = (int)h_in;
const float h_delta = h_in - h_in_loc;
const int h_offset = h_delta * TABLE_SIZE + 0.5f;
const int w_in_loc = (int)w_in;
const float w_delta = w_in - w_in_loc;
const int w_offset = w_delta * TABLE_SIZE + 0.5f;
// Vertical weights for the four rows h_in_loc - 1 .. h_in_loc + 2.
const float h_offset_l = h_offset;
const float h_offset_r = TABLE_SIZE - h_offset_l;
float4 y_weights = {coeff_odd(h_offset_l), coeff_even(h_offset_l),
coeff_even(h_offset_r), coeff_odd(h_offset_r)};
int4 y_indices = {h_in_loc - 1, h_in_loc, h_in_loc + 1, h_in_loc + 2};
// Clamp row indices to [0, in_height - 1].
y_indices = min(max(y_indices, 0), in_height - 1);
// Horizontal weights for the four columns w_in_loc - 1 .. w_in_loc + 2.
const float w_offset_l = w_offset;
const float w_offset_r = TABLE_SIZE - w_offset_l;
float4 x_weights = {coeff_odd(w_offset_l), coeff_even(w_offset_l),
coeff_even(w_offset_r), coeff_odd(w_offset_r)};
int4 x_indices = {w_in_loc - 1, w_in_loc, w_in_loc + 1, w_in_loc + 2};
// Clamp column indices to [0, in_width - 1].
x_indices = min(max(x_indices, 0), in_width - 1);
// coeffsN holds the horizontally interpolated value of source row N.
float4 coeffs0 = 0, coeffs1 = 0, coeffs2 = 0, coeffs3 = 0;
for (int i = 0; i < 4; ++i) {
// Select component i of y_indices using compile-time component names.
int y_index = y_indices.s0;
if ( i == 1 ) { y_index = y_indices.s1; }
if ( i == 2 ) { y_index = y_indices.s2; }
if ( i == 3 ) { y_index = y_indices.s3; }
const int in_h_index = in_h_offset + y_index;
// Read the four horizontal taps of this source row.
DATA_TYPE4 data0 = READ_IMAGET(input, SAMPLER,
(int2)(in_w_offset + x_indices.s0, in_h_index));
DATA_TYPE4 data1 = READ_IMAGET(input, SAMPLER,
(int2)(in_w_offset + x_indices.s1, in_h_index));
DATA_TYPE4 data2 = READ_IMAGET(input, SAMPLER,
(int2)(in_w_offset + x_indices.s2, in_h_index));
DATA_TYPE4 data3 = READ_IMAGET(input, SAMPLER,
(int2)(in_w_offset + x_indices.s3, in_h_index));
// Horizontal pass: weighted sum of the four taps.
float4 res = 0;
res = mad(data0, x_weights.s0, res);
res = mad(data1, x_weights.s1, res);
res = mad(data2, x_weights.s2, res);
res = mad(data3, x_weights.s3, res);
if ( i == 0 ) { coeffs0 = res; }
if ( i == 1 ) { coeffs1 = res; }
if ( i == 2 ) { coeffs2 = res; }
if ( i == 3 ) { coeffs3 = res; }
}
// Vertical pass: weighted sum of the four row results.
DATA_TYPE4 outdata = 0;
outdata = mad(coeffs0, y_weights.s0, outdata);
outdata = mad(coeffs1, y_weights.s1, outdata);
outdata = mad(coeffs2, y_weights.s2, outdata);
outdata = mad(coeffs3, y_weights.s3, outdata);
const int out_w_offset = mul24(ch_blk, out_width);
const int out_h_offset = mul24(b, out_height);
WRITE_IMAGET(output, (int2)(out_w_offset + w, out_h_offset + h), outdata);
}
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/kernels/resize_bicubic.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/core/tensor.h"
#include "mace/kernels/opencl/helper.h"
#include "mace/utils/tuner.h"
#include "mace/utils/utils.h"
namespace mace {
namespace kernels {
namespace {
// Picks a default local work-group size for the 3-D global size `gws`.
//
// gws      : global work size, three entries (dim 0, dim 1, dim 2).
// kwg_size : maximum work-group size supported by the kernel.
// Returns a 4-element vector; entries 0..2 hold the local size per dimension
// and entry 3 is left at 0.
std::vector<uint32_t> LocalWS(const uint32_t *gws, const uint32_t kwg_size) {
  std::vector<uint32_t> lws(4, 0);

  // Scale factor derived from the device's global memory cache size, at
  // least 1.
  const uint64_t cache_size =
      OpenCLRuntime::Global()->device_global_mem_cache_size();
  const uint32_t base =
      std::max<uint32_t>(cache_size / kBaseGPUMemCacheSize, 1);

  // Dimension 1 takes as much of the work-group budget as it can use.
  lws[1] = std::min<uint32_t>(gws[1], kwg_size);

  // Dimension 0: bounded by `base` when dim 1 is already large, otherwise an
  // eighth of the global size (or all of it when that rounds down to zero).
  uint32_t dim0 = 0;
  if (lws[1] >= base) {
    dim0 = std::min<uint32_t>(gws[0], base);
  } else {
    dim0 = gws[0] / 8;
    if (dim0 == 0) {
      dim0 = gws[0];
    }
  }
  lws[0] = std::min<uint32_t>(dim0, kwg_size / lws[1]);

  // Dimension 2 gets whatever budget remains, but never less than 1.
  const uint32_t lws_size = lws[0] * lws[1];
  uint32_t dim2 = gws[2] / 8;
  if (dim2 == 0) {
    dim2 = gws[2];
  }
  lws[2] = std::max<uint32_t>(
      std::min<uint32_t>(dim2, kwg_size / lws_size), 1);

  return lws;
}
} // namespace
// GPU implementation of the bicubic resize functor.
//
// input  : 4-D tensor read as (batch, height, width, channels).
// output : resized to (batch, out_height_, out_width_, channels) whenever the
//          input shape differs from the cached input_shape_.
// future : forwarded to TuningOrRun3DKernel.
// Returns MACE_SUCCESS, or the error propagated by MACE_RETURN_IF_ERROR.
template <typename T>
MaceStatus ResizeBicubicFunctor<DeviceType::GPU, T>::operator()(
const Tensor *input, Tensor *output, StatsFuture *future) {
const index_t batch = input->dim(0);
const index_t in_height = input->dim(1);
const index_t in_width = input->dim(2);
const index_t channels = input->dim(3);
const index_t channel_blocks = RoundUpDiv4(channels);
const index_t out_height = out_height_;
const index_t out_width = out_width_;
// Global work size: (channel blocks, output width, output height * batch).
const uint32_t gws[3] = {static_cast<uint32_t>(channel_blocks),
static_cast<uint32_t>(out_width),
static_cast<uint32_t>(out_height * batch)};
auto runtime = OpenCLRuntime::Global();
// Build the kernel lazily on the first call and cache it in kernel_.
if (kernel_.get() == nullptr) {
std::set<std::string> built_options;
OUT_OF_RANGE_CONFIG(kernel_error_);
NON_UNIFORM_WG_CONFIG;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("resize_bicubic_nocache");
built_options.emplace("-Dresize_bicubic_nocache=" + kernel_name);
auto dt = DataTypeToEnum<T>::value;
built_options.emplace("-DDATA_TYPE=" + DtToUpCompatibleCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpCompatibleCLCMDDt(dt));
// The kernel computes its weights with the same table resolution as the
// host-side coefficient table.
built_options.emplace(MakeString("-DTABLE_SIZE=", kTableSize));
MACE_RETURN_IF_ERROR(
runtime->BuildKernel("resize_bicubic",
kernel_name,
built_options,
&kernel_));
kwg_size_ =
static_cast<uint32_t>(runtime->GetKernelMaxWorkGroupSize(kernel_));
}
// Resize the output and (re)bind kernel arguments only when the input shape
// changed since the previous call.
if (!IsVecEqual(input_shape_, input->shape())) {
MACE_CHECK(out_height > 0 && out_width > 0);
std::vector<index_t> output_shape{batch, out_height, out_width, channels};
std::vector<size_t> output_image_shape;
CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL,
&output_image_shape);
MACE_RETURN_IF_ERROR(output->ResizeImage(output_shape, output_image_shape));
float height_scale =
CalculateResizeScale(in_height, out_height, align_corners_);
float width_scale =
CalculateResizeScale(in_width, out_width, align_corners_);
// Argument order must match the kernel's parameter list.
uint32_t idx = 0;
OUT_OF_RANGE_SET_ARG;
SET_3D_GWS_ARGS(kernel_);
kernel_.setArg(idx++, *(input->opencl_image()));
kernel_.setArg(idx++, *(output->opencl_image()));
kernel_.setArg(idx++, height_scale);
kernel_.setArg(idx++, width_scale);
kernel_.setArg(idx++, static_cast<int32_t>(in_height));
kernel_.setArg(idx++, static_cast<int32_t>(in_width));
kernel_.setArg(idx++, static_cast<int32_t>(out_height));
input_shape_ = input->shape();
}
const std::vector<uint32_t> lws = LocalWS(gws, kwg_size_);
// The tuning key is built from the output dimensions.
std::string tuning_key =
Concat("resize_bicubic_opencl_kernel", output->dim(0), output->dim(1),
output->dim(2), output->dim(3));
MACE_RETURN_IF_ERROR(TuningOrRun3DKernel(kernel_, tuning_key,
gws, lws, future));
OUT_OF_RANGE_VALIDATION(kernel_error_);
return MACE_SUCCESS;
}
// Explicit instantiations for the element types registered for the GPU op.
template struct ResizeBicubicFunctor<DeviceType::GPU, float>;
template struct ResizeBicubicFunctor<DeviceType::GPU, half>;
} // namespace kernels
} // namespace mace
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_KERNELS_RESIZE_BICUBIC_H_
#define MACE_KERNELS_RESIZE_BICUBIC_H_
#include <algorithm>
#include <memory>
#include <vector>
#include "mace/core/future.h"
#include "mace/core/tensor.h"
#include "mace/utils/logging.h"
#ifdef MACE_ENABLE_OPENCL
#include "mace/core/runtime/opencl/cl2_header.h"
#endif // MACE_ENABLE_OPENCL
namespace mace {
namespace kernels {
// Number of sub-pixel steps the fractional source position is quantised to.
static const int64_t kTableSize = (1 << 10);

// Allocates and fills the bicubic convolution coefficient table
// (https://en.wikipedia.org/wiki/Bicubic_interpolation) with A = -0.75.
//
// The table has kTableSize + 1 entries of two floats each. For entry i, with
// t = i / kTableSize:
//   [2 * i]     weight of a tap at distance t      (0 <= t <= 1)
//   [2 * i + 1] weight of a tap at distance t + 1  (1 <= t + 1 <= 2)
// The returned array is heap-allocated and never freed by this function.
inline const float *InitCoeffsTable() {
  static const double A = -0.75;
  float *table = new float[(kTableSize + 1) * 2];
  for (int step = 0; step <= kTableSize; ++step) {
    float *entry = table + step * 2;
    float t = step * 1.0 / kTableSize;
    entry[0] = ((A + 2) * t - (A + 3)) * t * t + 1;
    t += 1.0;
    entry[1] = ((A * t - 5 * A) * t + 8 * A) * t - 4 * A;
  }
  return table;
}
// Returns the shared coefficient table, building it with InitCoeffsTable()
// the first time this function is called.
inline const float *GetCoeffsTable() {
  // Function-local static: initialised once, on first use.
  static const float *const table = InitCoeffsTable();
  return table;
}
// Clamps `val` from below at 0, then from above at `limit - 1`.
inline int64_t Bound(int64_t val, int64_t limit) {
  const int64_t upper = limit - 1ll;
  const int64_t floored = (val < 0ll) ? 0ll : val;
  return (floored < upper) ? floored : upper;
}
inline void GetWeightsAndIndices(float scale, int64_t out_loc, int64_t limit,
std::vector<float> *weights,
std::vector<int64_t> *indices) {
const int64_t in_loc = scale * out_loc;
const float delta = scale * out_loc - in_loc;
const int64_t offset = lrintf(delta * kTableSize);
const float *coeffs_tab = GetCoeffsTable();
*weights = {coeffs_tab[offset * 2 + 1],
coeffs_tab[offset * 2],
coeffs_tab[(kTableSize - offset) * 2],
coeffs_tab[(kTableSize - offset) * 2 + 1]};
*indices = {Bound(in_loc - 1, limit), Bound(in_loc, limit),
Bound(in_loc + 1, limit), Bound(in_loc + 2, limit)};
}
// Returns the weighted sum of the first four entries of `values`, using the
// first four entries of `weights`. Both vectors must hold at least four
// elements.
inline float Interpolate1D(const std::vector<float> &weights,
                           const std::vector<float> &values) {
  // Accumulate left to right so the summation order is fixed.
  float sum = values[0] * weights[0];
  sum = sum + values[1] * weights[1];
  sum = sum + values[2] * weights[2];
  sum = sum + values[3] * weights[3];
  return sum;
}
// Returns the factor that maps an output coordinate to a source coordinate.
//
// With align_corners set and more than one output element, the factor is
// (in_size - 1) / (out_size - 1); otherwise it is in_size / out_size.
inline float CalculateResizeScale(index_t in_size,
                                  index_t out_size,
                                  bool align_corners) {
  if (align_corners && out_size > 1) {
    return (in_size - 1) / static_cast<float>(out_size - 1);
  }
  return in_size / static_cast<float>(out_size);
}
inline void ResizeImage(const float *images,
const index_t batch_size,
const index_t in_height,
const index_t in_width,
const index_t out_height,
const index_t out_width,
const index_t channels,
const float height_scale,
const float width_scale,
float *output) {
#pragma omp parallel for collapse(2)
for (index_t b = 0; b < batch_size; ++b) {
for (index_t y = 0; y < out_height; ++y) {
std::vector<float> y_weights;
std::vector<index_t> y_indices;
GetWeightsAndIndices(height_scale, y, in_height, &y_weights,
&y_indices);
for (index_t x = 0; x < out_width; ++x) {
std::vector<float> x_weights;
std::vector<index_t> x_indices;
GetWeightsAndIndices(width_scale, x, in_width, &x_weights,
&x_indices);
for (index_t c = 0; c < channels; ++c) {
// Use a 4x4 patch to compute the interpolated output value at
// (b, y, x, c).
const float *channel_input_ptr =
images + (b * channels + c) * in_height * in_width;
float *channel_output_ptr =
output + (b * channels + c) * out_height * out_width;
std::vector<float> coeff(4, 0.0);
for (index_t i = 0; i < 4; ++i) {
const std::vector<float> values = {
static_cast<float>(channel_input_ptr
[y_indices[i] * in_width + x_indices[0]]),
static_cast<float>(channel_input_ptr
[y_indices[i] * in_width + x_indices[1]]),
static_cast<float>(channel_input_ptr
[y_indices[i] * in_width + x_indices[2]]),
static_cast<float>(channel_input_ptr
[y_indices[i] * in_width + x_indices[3]])};
coeff[i] = Interpolate1D(x_weights, values);
}
channel_output_ptr[y * out_width + x] =
Interpolate1D(y_weights, coeff);
}
}
}
}
}
// State shared by the CPU and GPU bicubic resize functors: the requested
// output size and the align_corners flag.
struct ResizeBicubicFunctorBase {
// size          : must contain exactly two entries, {out_height, out_width}.
// align_corners : stored and later passed to CalculateResizeScale().
ResizeBicubicFunctorBase(const std::vector<index_t> &size,
bool align_corners)
: align_corners_(align_corners) {
// Validate before indexing into `size`.
MACE_CHECK(size.size() == 2);
out_height_ = size[0];
out_width_ = size[1];
}
protected:
bool align_corners_;
index_t out_height_;
index_t out_width_;
};
// Primary template; only the specialisations below are defined.
template<DeviceType D, typename T>
struct ResizeBicubicFunctor;

// CPU / float specialisation. The input tensor is read as
// (batch, channels, height, width).
template<>
struct ResizeBicubicFunctor<DeviceType::CPU, float>
    : ResizeBicubicFunctorBase {
  ResizeBicubicFunctor(const std::vector<index_t> &size, bool align_corners)
      : ResizeBicubicFunctorBase(size, align_corners) {}

  // Resizes `input` into `output` using bicubic interpolation.
  // `future` is unused on the CPU path.
  // Returns MACE_SUCCESS, or the error reported when resizing `output`.
  MaceStatus operator()(const Tensor *input,
                        Tensor *output,
                        StatsFuture *future) {
    MACE_UNUSED(future);

    const index_t batch = input->dim(0);
    const index_t channels = input->dim(1);
    const index_t in_height = input->dim(2);
    const index_t in_width = input->dim(3);
    const index_t out_height = out_height_;
    const index_t out_width = out_width_;
    MACE_CHECK(out_height > 0 && out_width > 0);

    std::vector<index_t> out_shape{batch, channels, out_height, out_width};
    MACE_RETURN_IF_ERROR(output->Resize(out_shape));

    Tensor::MappingGuard input_mapper(input);
    Tensor::MappingGuard output_mapper(output);
    const float *src = input->data<float>();
    float *dst = output->mutable_data<float>();

    const bool needs_resampling =
        !(out_height == in_height && out_width == in_width);
    if (needs_resampling) {
      const float height_scale =
          CalculateResizeScale(in_height, out_height, align_corners_);
      const float width_scale =
          CalculateResizeScale(in_width, out_width, align_corners_);
      ResizeImage(src, batch, in_height, in_width, out_height, out_width,
                  channels, height_scale, width_scale, dst);
      return MACE_SUCCESS;
    }

    // Same spatial size: the output is a plain copy of the input.
    std::copy(src, src + batch * channels * in_height * in_width, dst);
    return MACE_SUCCESS;
  }
};
#ifdef MACE_ENABLE_OPENCL
// GPU (OpenCL) specialisation; operator() is defined out of line.
template<typename T>
struct ResizeBicubicFunctor<DeviceType::GPU, T>
: ResizeBicubicFunctorBase {
ResizeBicubicFunctor(const std::vector<index_t> &size, bool align_corners)
: ResizeBicubicFunctorBase(size, align_corners) {}
MaceStatus operator()(const Tensor *input,
Tensor *output,
StatsFuture *future);
// Compiled kernel, built on the first call to operator().
cl::Kernel kernel_;
// Maximum work-group size of kernel_; assigned when the kernel is built.
uint32_t kwg_size_;
// Buffer handed to the OUT_OF_RANGE_* macros in operator().
std::unique_ptr<BufferBase> kernel_error_;
// Input shape for which the kernel arguments were last set.
std::vector<index_t> input_shape_;
};
#endif // MACE_ENABLE_OPENCL
} // namespace kernels
} // namespace mace
#endif // MACE_KERNELS_RESIZE_BICUBIC_H_
......@@ -48,6 +48,7 @@ extern void Register_Proposal(OperatorRegistryBase *op_registry);
extern void Register_Quantize(OperatorRegistryBase *op_registry);
extern void Register_ReduceMean(OperatorRegistryBase *op_registry);
extern void Register_Reshape(OperatorRegistryBase *op_registry);
extern void Register_ResizeBicubic(OperatorRegistryBase *op_registry);
extern void Register_ResizeBilinear(OperatorRegistryBase *op_registry);
extern void Register_ScalarMath(OperatorRegistryBase *op_registry);
extern void Register_Shape(OperatorRegistryBase *op_registry);
......@@ -101,6 +102,7 @@ OperatorRegistry::OperatorRegistry() : OperatorRegistryBase() {
ops::Register_Quantize(this);
ops::Register_ReduceMean(this);
ops::Register_Reshape(this);
ops::Register_ResizeBicubic(this);
ops::Register_ResizeBilinear(this);
ops::Register_ScalarMath(this);
ops::Register_Shape(this);
......
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/resize_bicubic.h"
namespace mace {
namespace ops {
// Registers the "ResizeBicubic" operator with `op_registry`: a float CPU
// variant always, plus float and half GPU variants when MACE_ENABLE_OPENCL is
// defined.
void Register_ResizeBicubic(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("ResizeBicubic")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
.Build(),
ResizeBicubicOp<DeviceType::CPU, float>);
#ifdef MACE_ENABLE_OPENCL
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("ResizeBicubic")
.Device(DeviceType::GPU)
.TypeConstraint<float>("T")
.Build(),
ResizeBicubicOp<DeviceType::GPU, float>);
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("ResizeBicubic")
.Device(DeviceType::GPU)
.TypeConstraint<half>("T")
.Build(),
ResizeBicubicOp<DeviceType::GPU, half>);
#endif // MACE_ENABLE_OPENCL
}
} // namespace ops
} // namespace mace
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_RESIZE_BICUBIC_H_
#define MACE_OPS_RESIZE_BICUBIC_H_
#include "mace/core/operator.h"
#include "mace/kernels/resize_bicubic.h"
namespace mace {
namespace ops {
// Operator wrapper that forwards to kernels::ResizeBicubicFunctor.
//
// Arguments read from the operator definition:
//   "size"          : repeated ints, defaults to {-1, -1}.
//   "align_corners" : optional bool, defaults to false.
template <DeviceType D, class T>
class ResizeBicubicOp : public Operator<D, T> {
public:
ResizeBicubicOp(const OperatorDef &operator_def, Workspace *ws)
: Operator<D, T>(operator_def, ws),
functor_(OperatorBase::GetRepeatedArgs<index_t>("size", {-1, -1}),
OperatorBase::GetOptionalArg<bool>("align_corners", false)) {}
// Checks that input 0 is 4-dimensional and runs the functor on it, writing
// to output 0.
MaceStatus Run(StatsFuture *future) override {
const Tensor *input = this->Input(0);
Tensor *output = this->Output(0);
MACE_CHECK(input->dim_size() == 4, "input must be 4-dimensional.",
input->dim_size());
return functor_(input, output, future);
}
private:
kernels::ResizeBicubicFunctor<D, T> functor_;
};
} // namespace ops
} // namespace mace
#endif // MACE_OPS_RESIZE_BICUBIC_H_
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "mace/core/operator.h"
#include "mace/core/testing/test_benchmark.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
namespace {
// Runs the ResizeBicubic operator `iters` times on device D with element type
// T and a random input, after five untimed warm-up runs.
//
// The input is created as (batch, channels, height, width) for CPU and as
// (batch, height, width, channels) for GPU. Other devices are not implemented.
template <DeviceType D, typename T>
void ResizeBicubicBenchmark(int iters,
int batch,
int channels,
int input_height,
int input_width,
int output_height,
int output_width) {
// Exclude graph construction from the measured time.
mace::testing::StopTiming();
OpsTestNet net;
// Add input data
if (D == DeviceType::CPU) {
net.AddRandomInput<D, float>("Input",
{batch, channels, input_height, input_width});
} else if (D == DeviceType::GPU) {
net.AddRandomInput<D, float>("Input",
{batch, input_height, input_width, channels});
} else {
MACE_NOT_IMPLEMENTED;
}
net.AddInputFromArray<D, int>("OutSize", {2},
{output_height, output_width});
if (D == DeviceType::CPU) {
OpDefBuilder("ResizeBicubic", "ResizeBicubicBenchmark")
.Input("Input")
.Input("OutSize")
.Output("Output")
.AddIntsArg("size", {output_height, output_width})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
} else if (D == DeviceType::GPU) {
// The GPU operator reads its input from an image, so convert first.
BufferToImage<D, T>(&net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL);
OpDefBuilder("ResizeBicubic", "ResizeBicubicBenchmark")
.Input("InputImage")
.Input("OutSize")
.Output("OutputImage")
.AddIntsArg("size", {output_height, output_width})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
} else {
MACE_NOT_IMPLEMENTED;
}
// Warm-up
for (int i = 0; i < 5; ++i) {
net.RunOp(D);
}
// Timed section.
mace::testing::StartTiming();
while (iters--) {
net.RunOp(D);
}
net.Sync();
}
} // namespace
// Defines and registers one benchmark function for the given configuration.
// N, C: batch and channels; H0 x W0: input size; H1 x W1: output size;
// TYPE: element type; DEVICE: device the benchmark runs on.
// The function reports iters * N * C * H1 * W1 * 3 to MaccProcessed and
// iters * N * C * H0 * W0 * sizeof(TYPE) to BytesProcessed.
#define MACE_BM_RESIZE_BICUBIC_MACRO(N, C, H0, W0, H1, W1, TYPE, DEVICE) \
static void \
MACE_BM_RESIZE_BICUBIC_##N##_##C##_##H0##_##W0##_##H1##_##W1##_##TYPE##_\
##DEVICE( \
int iters) { \
const int64_t macc = static_cast<int64_t>(iters) * N * C * H1 * W1 * 3; \
const int64_t tot = static_cast<int64_t>(iters) * N * C * H0 * W0; \
mace::testing::MaccProcessed(macc); \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
ResizeBicubicBenchmark<DEVICE, TYPE>(iters, N, C, H0, W0, H1, W1); \
} \
MACE_BENCHMARK( \
MACE_BM_RESIZE_BICUBIC_##N##_##C##_##H0##_##W0##_##H1##_##W1##_##TYPE##_\
##DEVICE)
// Instantiates the benchmark for float on CPU, float on GPU and half on GPU.
#define MACE_BM_RESIZE_BICUBIC(N, C, H0, W0, H1, W1) \
MACE_BM_RESIZE_BICUBIC_MACRO(N, C, H0, W0, H1, W1, float, CPU); \
MACE_BM_RESIZE_BICUBIC_MACRO(N, C, H0, W0, H1, W1, float, GPU); \
MACE_BM_RESIZE_BICUBIC_MACRO(N, C, H0, W0, H1, W1, half, GPU);
// Benchmarked configurations: (batch, channels, in_h, in_w, out_h, out_w).
MACE_BM_RESIZE_BICUBIC(1, 128, 120, 120, 480, 480);
MACE_BM_RESIZE_BICUBIC(1, 256, 7, 7, 15, 15);
MACE_BM_RESIZE_BICUBIC(1, 256, 15, 15, 30, 30);
MACE_BM_RESIZE_BICUBIC(1, 128, 30, 30, 60, 60);
MACE_BM_RESIZE_BICUBIC(1, 128, 240, 240, 480, 480);
MACE_BM_RESIZE_BICUBIC(1, 3, 4032, 3016, 480, 480);
MACE_BM_RESIZE_BICUBIC(1, 3, 480, 480, 4032, 3016);
} // namespace test
} // namespace ops
} // namespace mace
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <vector>
#include "mace/core/operator.h"
#include "mace/ops/ops_test_util.h"
#include "mace/ops/resize_bicubic.h"
namespace mace {
namespace ops {
namespace test {
// Fixture shared by the ResizeBicubic operator tests.
class ResizeBicubicTest : public OpsTestBase {};
// CPU, align_corners left at its default: a 1x2x4x3 (NHWC) ramp 0..23 is
// resized to 1x2 and compared against fixed expected values.
TEST_F(ResizeBicubicTest, CPUResizeBicubicWOAlignCorners) {
testing::internal::LogToStderr();
// Construct graph
OpsTestNet net;
// Add input data
std::vector<float> input(24);
std::iota(begin(input), end(input), 0);
net.AddInputFromArray<DeviceType::CPU, float>("Input", {1, 2, 4, 3}, input);
// The CPU operator is fed the NCHW-transformed tensor.
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW",
NCHW);
OpDefBuilder("ResizeBicubic", "ResizeBicubicTest")
.Input("InputNCHW")
.Output("OutputNCHW")
.AddIntsArg("size", {1, 2})
.Finalize(net.NewOperatorDef());
// Run
net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output",
NHWC);
// Check
auto expected = CreateTensor<float>({1, 1, 2, 3}, {0, 1, 2, 6, 7, 8});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5);
}
// CPU, align_corners left at its default: a 1x4x4x3 (NHWC) ramp 0..47 is
// resized to 2x3, which yields non-integer interpolated values.
TEST_F(ResizeBicubicTest, CPUResizeBicubicWOAlignCornersFloat) {
testing::internal::LogToStderr();
// Construct graph
OpsTestNet net;
// Add input data
std::vector<float> input(48);
std::iota(begin(input), end(input), 0);
net.AddInputFromArray<DeviceType::CPU, float>("Input", {1, 4, 4, 3}, input);
// The CPU operator is fed the NCHW-transformed tensor.
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW",
NCHW);
OpDefBuilder("ResizeBicubic", "ResizeBicubicTest")
.Input("InputNCHW")
.Output("OutputNCHW")
.AddIntsArg("size", {2, 3})
.Finalize(net.NewOperatorDef());
// Run
net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output",
NHWC);
// Check
auto expected = CreateTensor<float>({1, 2, 3, 3},
{0., 1., 2., 4.110297, 5.110297, 6.110297,
8.223037, 9.223036, 10.223037, 24., 25., 26.,
28.110298, 29.1103, 30.110298, 32.223038, 33.223038, 34.223038});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5);
}
// CPU with align_corners = 1: a 1x2x4x3 (NHWC) ramp 0..23 is resized to 1x2
// and compared against fixed expected values.
TEST_F(ResizeBicubicTest, ResizeBicubicWAlignCorners) {
testing::internal::LogToStderr();
// Construct graph
OpsTestNet net;
// Add input data
std::vector<float> input(24);
std::iota(begin(input), end(input), 0);
net.AddInputFromArray<DeviceType::CPU, float>("Input", {1, 2, 4, 3}, input);
// The CPU operator is fed the NCHW-transformed tensor.
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW",
NCHW);
OpDefBuilder("ResizeBicubic", "ResizeBicubicTest")
.Input("InputNCHW")
.Output("OutputNCHW")
.AddIntArg("align_corners", 1)
.AddIntsArg("size", {1, 2})
.Finalize(net.NewOperatorDef());
// Run
net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output",
NHWC);
// Check
auto expected = CreateTensor<float>({1, 1, 2, 3}, {0, 1, 2, 9, 10, 11});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5);
}
namespace {
// Compares the device-D ResizeBicubic operator against the CPU implementation
// on ten randomly sized problems.
//
// Each round draws batch (1..5), channels, output size and input size
// (1..100 each) and an align_corners flag (0 or 1), runs the operator on CPU
// to obtain the reference, then runs it on device D and checks the results
// agree within 1e-2.
//
// NOTE: "DeviceOutput" is only produced inside the GPU branch, so this helper
// is only meaningful when instantiated with DeviceType::GPU.
template <DeviceType D>
void TestRandomResizeBicubic() {
  testing::internal::LogToStderr();
  static unsigned int seed = time(NULL);
  for (int round = 0; round < 10; ++round) {
    int batch = 1 + rand_r(&seed) % 5;
    int channels = 1 + rand_r(&seed) % 100;
    int height = 1 + rand_r(&seed) % 100;
    int width = 1 + rand_r(&seed) % 100;
    int in_height = 1 + rand_r(&seed) % 100;
    int in_width = 1 + rand_r(&seed) % 100;
    // Fix: the original used `% 1`, which is always 0, so the
    // align_corners = 1 path was never exercised. `% 2` yields 0 or 1.
    int align_corners = rand_r(&seed) % 2;

    // Construct graph
    OpsTestNet net;
    // Add input data
    net.AddRandomInput<D, float>("Input",
                                 {batch, in_height, in_width, channels},
                                 true, true);
    // The CPU operator is fed the NCHW-transformed tensor.
    net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW",
                                                    NCHW);

    OpDefBuilder("ResizeBicubic", "ResizeBicubicTest")
        .Input("InputNCHW")
        .Output("OutputNCHW")
        .AddIntArg("align_corners", align_corners)
        .AddIntsArg("size", {height, width})
        .Finalize(net.NewOperatorDef());
    // Run on CPU
    net.RunOp(DeviceType::CPU);
    net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW,
                                                    "Output", NHWC);

    // Keep a copy of the CPU result as the reference.
    Tensor expected;
    expected.Copy(*net.GetOutput("Output"));

    if (D == DeviceType::GPU) {
      BufferToImage<D, float>(&net, "Input", "InputImage",
                              kernels::BufferType::IN_OUT_CHANNEL);

      OpDefBuilder("ResizeBicubic", "ResizeBicubicTest")
          .Input("InputImage")
          .Output("OutputImage")
          .AddIntArg("align_corners", align_corners)
          .AddIntsArg("size", {height, width})
          .Finalize(net.NewOperatorDef());
      // Run
      net.RunOp(D);

      ImageToBuffer<D, float>(&net, "OutputImage", "DeviceOutput",
                              kernels::BufferType::IN_OUT_CHANNEL);
    }
    // Check
    ExpectTensorNear<float>(expected, *net.GetOutput("DeviceOutput"), 1e-2,
                            1e-2);
  }
}
} // namespace
// Runs the randomized CPU-vs-device comparison with the GPU (OpenCL) device.
TEST_F(ResizeBicubicTest, OPENCLRandomResizeBicubic) {
TestRandomResizeBicubic<DeviceType::GPU>();
}
} // namespace test
} // namespace ops
} // namespace mace
......@@ -101,6 +101,7 @@ MaceSupportedOps = [
'Quantize',
'ReduceMean',
'Reshape',
'ResizeBicubic',
'ResizeBilinear',
'ScalarMath',
'Slice',
......
......@@ -81,6 +81,7 @@ TFSupportedOps = [
'Shape',
'Transpose',
'Softmax',
'ResizeBicubic',
'ResizeBilinear',
'Placeholder',
'SpaceToBatchND',
......@@ -181,6 +182,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
TFOpType.Squeeze.name: self.convert_squeeze,
TFOpType.Transpose.name: self.convert_transpose,
TFOpType.Softmax.name: self.convert_softmax,
TFOpType.ResizeBicubic.name: self.convert_resize_bicubic,
TFOpType.ResizeBilinear.name: self.convert_resize_bilinear,
TFOpType.Placeholder.name: self.convert_nop,
TFOpType.SpaceToBatchND.name: self.convert_space_batch,
......@@ -537,6 +539,20 @@ class TensorflowConverter(base_converter.ConverterInterface):
op = self.convert_general_op(tf_op)
op.type = MaceOp.Softmax.name
def convert_resize_bicubic(self, tf_op):
    """Convert a TensorFlow bicubic-resize op into a MACE ResizeBicubic op.

    The op's second input (the target size) is evaluated and stored as an
    integer-list argument, and that tensor is recorded in
    ``self._skip_tensor``. The op's ``align_corners`` attribute is copied
    into an integer argument.
    """
    op = self.convert_general_op(tf_op)
    op.type = MaceOp.ResizeBicubic.name
    # Only the first input stays a graph input of the converted op.
    del op.input[1:]

    # Target size: evaluated now and stored on the op as int32 values.
    size_arg = op.arg.add()
    size_arg.name = MaceKeyword.mace_resize_size_str
    size_tensor = tf_op.inputs[1]
    size_arg.ints.extend(size_tensor.eval().astype(np.int32))
    self._skip_tensor.add(size_tensor.name)

    # Carry the align_corners attribute over unchanged.
    align_corners_arg = op.arg.add()
    align_corners_arg.name = MaceKeyword.mace_align_corners_str
    align_corners_arg.i = tf_op.get_attr(tf_align_corners)
def convert_resize_bilinear(self, tf_op):
op = self.convert_general_op(tf_op)
op.type = MaceOp.ResizeBilinear.name
......
......@@ -41,6 +41,7 @@ def _opencl_encrypt_kernel_impl(repository_ctx):
unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/pad.cl"))
unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/pooling.cl"))
unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/reduce_mean.cl"))
unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/resize_bicubic.cl"))
unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/resize_bilinear.cl"))
unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/split.cl"))
unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/softmax.cl"))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册