提交 e396c388 编写于 作者: L Liangliang He

Change resize bilinear from buffer to image2d

上级 48a038ca
......@@ -54,10 +54,11 @@ void *OpenCLAllocator::NewImage(const std::vector<size_t> &image_shape,
cl_int error;
cl::Image2D *cl_image =
new cl::Image2D(OpenCLRuntime::Get()->context(),
CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR ,
CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR,
img_format,
image_shape[0], image_shape[1],
0, nullptr, &error);
MACE_CHECK(error == CL_SUCCESS);
return cl_image;
}
......
#include <common.h>
// Supported data type: half/float
__kernel void resize_bilinear_nocache(__global const DATA_TYPE *input, /* n * c, h, w */
__global DATA_TYPE *output /* n * c, h, w */,
__kernel void resize_bilinear_nocache(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] */
__write_only image2d_t output,
__private const float height_scale,
__private const float width_scale,
__private const int in_height,
__private const int in_width) {
const int c = get_global_id(0);
const int h = get_global_id(1);
const int w = get_global_id(2);
const int channels = get_global_size(0);
const int height = get_global_size(1);
const int width = get_global_size(2);
__private const int in_width,
__private const int out_height) {
const int ch_blk = get_global_id(0);
const int ch_blks = get_global_size(0);
const int w = get_global_id(1);
const int out_width = get_global_size(1);
const int hb = get_global_id(2);
const int b = hb / out_height;
const int h = hb % out_height;
const float h_in = h * height_scale;
const float w_in = w * width_scale;
......@@ -24,16 +25,26 @@ __kernel void resize_bilinear_nocache(__global const DATA_TYPE *input, /* n * c,
const float h_lerp = h_in - h_lower;
const float w_lerp = w_in - w_lower;
const DATA_TYPE *input_base = input + c * in_height * in_width;
DATA_TYPE *output_base = output + c * height * width;
const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
const int in_w_offset = ch_blk * in_width;
const int in_h_offset = b * in_height;
DATA_TYPE top_left = input_base[h_lower * in_width + w_lower];
DATA_TYPE top_right = input_base[h_lower * in_width + w_upper];
DATA_TYPE bottom_left = input_base[h_upper * in_width + w_lower];
DATA_TYPE bottom_right = input_base[h_upper * in_width + w_upper];
DATA_TYPE4 top_left = READ_IMAGET(input, sampler,
(int2)(in_w_offset + w_lower, in_h_offset + h_lower));
DATA_TYPE4 top_right = READ_IMAGET(input, sampler,
(int2)(in_w_offset + w_upper, in_h_offset + h_lower));
DATA_TYPE4 bottom_left = READ_IMAGET(input, sampler,
(int2)(in_w_offset + w_lower, in_h_offset + h_upper));
DATA_TYPE4 bottom_right = READ_IMAGET(input, sampler,
(int2)(in_w_offset + w_upper, in_h_offset + h_upper));
const DATA_TYPE top = top_left + (top_right - top_left) * w_lerp;
const DATA_TYPE bottom = bottom_left + (bottom_right - bottom_left) * w_lerp;
output_base[h * width + w] = top + (bottom - top) * h_lerp;
DATA_TYPE4 top = top_left + (top_right - top_left) * w_lerp;
DATA_TYPE4 bottom = bottom_left + (bottom_right - bottom_left) * w_lerp;
DATA_TYPE4 out = top + (bottom - top) * h_lerp;
const int out_w_offset = ch_blk * out_width;
const int out_h_offset = b * out_height;
WRITE_IMAGET(output, (int2)(out_w_offset + w, out_h_offset + h), out);
}
......@@ -6,24 +6,33 @@
#include "mace/core/tensor.h"
#include "mace/kernels/resize_bilinear.h"
#include "mace/kernels/opencl/helper.h"
#include "mace/utils/utils.h"
namespace mace {
namespace kernels {
// OpenCL implementation of the functor call operator: derives the output size,
// resizes the output tensor, builds the "resize_bilinear_nocache" kernel and
// enqueues it on the runtime's command queue.
template <>
void ResizeBilinearFunctor<DeviceType::OPENCL, float>::operator()(
template <typename T>
void ResizeBilinearFunctor<DeviceType::OPENCL, T>::operator()(
const Tensor *input, const Tensor *resize_dims, Tensor *output) {
const index_t batch = input->dim(0);
const index_t channels = input->dim(1);
const index_t in_height = input->dim(2);
const index_t in_width = input->dim(3);
const index_t in_height = input->dim(1);
const index_t in_width = input->dim(2);
const index_t channels = input->dim(3);
// Channels are processed in groups of four (one image pixel per group).
const index_t channel_blocks = RoundUpDiv4(channels);
index_t out_height;
index_t out_width;
GetOutputSize(resize_dims, &out_height, &out_width);
MACE_CHECK(out_height > 0 && out_width > 0);
std::vector<index_t> out_shape {batch, channels, out_height, out_width};
output->Resize(out_shape);
std::vector<index_t> output_shape {batch, out_height, out_width, channels};
// Image-backed inputs get an image-backed output of the matching 2D shape.
if (input->is_image()) {
std::vector<size_t> output_image_shape;
CalImage2DShape(output_shape, BufferType::IN_OUT, output_image_shape);
output->ResizeImage(output_shape, output_image_shape);
} else {
output->Resize(output_shape);
}
float height_scale =
CalculateResizeScale(in_height, out_height, align_corners_);
......@@ -32,28 +41,35 @@ void ResizeBilinearFunctor<DeviceType::OPENCL, float>::operator()(
auto runtime = OpenCLRuntime::Get();
// Compile-time defines select the element type used inside the kernel.
std::set<std::string> built_options;
built_options.emplace("-DDATA_TYPE=" + DataTypeToCLType(input->dtype()));
built_options.emplace("-DCMD_DATA_TYPE=" + DataTypeToOPENCLCMDDataType(input->dtype()));
auto rb_kernel = runtime->BuildKernel("resize_bilinear", "resize_bilinear_nocache", built_options);
const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(rb_kernel);
uint32_t idx = 0;
rb_kernel.setArg(idx++, *(static_cast<const cl::Buffer *>(input->buffer())));
rb_kernel.setArg(idx++, *(static_cast<cl::Buffer *>(output->buffer())));
// NOTE(review): the buffers are cast to cl::Image2D regardless of
// input->is_image(); the non-image branch above only calls Resize(), so it
// would presumably hand the kernel a plain buffer — confirm.
rb_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(input->buffer())));
rb_kernel.setArg(idx++, *(static_cast<cl::Image2D *>(output->buffer())));
rb_kernel.setArg(idx++, height_scale);
rb_kernel.setArg(idx++, width_scale);
rb_kernel.setArg(idx++, static_cast<int>(in_height));
rb_kernel.setArg(idx++, static_cast<int>(in_width));
rb_kernel.setArg(idx++, static_cast<int32_t>(in_height));
rb_kernel.setArg(idx++, static_cast<int32_t>(in_width));
rb_kernel.setArg(idx++, static_cast<int32_t>(out_height));
auto command_queue = runtime->command_queue();
cl_int error = command_queue.enqueueNDRangeKernel(
rb_kernel, cl::NullRange,
cl::NDRange(static_cast<int>(batch * channels),
static_cast<int>(out_height), static_cast<int>(out_width)),
// TODO (heliangliang) tuning and fix when kwg_size < divisor
cl::NDRange(1, 16, kwg_size / 16),
NULL, OpenCLRuntime::Get()->GetDefaultEvent());
// Global range: (channel blocks, output width, output height * batch).
cl::NDRange(static_cast<int32_t>(channel_blocks),
static_cast<int32_t>(out_width),
static_cast<int32_t>(out_height * batch)),
// TODO tuning
// NOTE(review): the local size in dimension 1 is min(out_width, kwg_size)
// while the global size is out_width. When out_width > kwg_size and is not a
// multiple of it, the enqueue may be rejected (OpenCL 1.x requires the global
// size to be divisible by the local size) — confirm.
cl::NDRange(1, static_cast<int32_t>(out_width > kwg_size ? kwg_size : out_width), 1),
nullptr, OpenCLRuntime::Get()->GetDefaultEvent());
MACE_CHECK(error == CL_SUCCESS, error);
}
// Explicit instantiations of the OpenCL functor for the float and half
// element types.
template struct ResizeBilinearFunctor<DeviceType::OPENCL, float>;
template struct ResizeBilinearFunctor<DeviceType::OPENCL, half>;
} // namespace kernels
} // namespace mace
......@@ -61,63 +61,90 @@ void ResizeImage(const T *images,
const index_t channels,
const std::vector<CachedInterpolation> &xs_vec,
const std::vector<CachedInterpolation> &ys,
float *output) {
const index_t in_channel_size = in_height * in_width;
const index_t in_batch_num_values = channels * in_channel_size;
const index_t out_channel_size = out_height * out_width;
const index_t out_batch_num_values = channels * out_channel_size;
T *output) {
const index_t in_batch_num_values = channels * in_height * in_width;
const index_t out_batch_num_values = channels * out_height * out_width;
const CachedInterpolation *xs = xs_vec.data();
#pragma omp parallel for collapse(2)
#pragma omp parallel for
for (index_t b = 0; b < batch_size; ++b) {
for (index_t c = 0; c < channels; ++c) {
const T *input_ptr =
images + in_batch_num_values * b + in_channel_size * c;
float *output_ptr =
output + out_batch_num_values * b + out_channel_size * c;
for (index_t y = 0; y < out_height; ++y) {
const T *ys_input_lower_ptr = input_ptr + ys[y].lower * in_width;
const T *ys_input_upper_ptr = input_ptr + ys[y].upper * in_width;
const float ys_lerp = ys[y].lerp;
for (index_t x = 0; x < out_width; ++x) {
auto xs_lower = xs[x].lower;
auto xs_upper = xs[x].upper;
auto xs_lerp = xs[x].lerp;
const float top_left = ys_input_lower_ptr[xs_lower];
const float top_right = ys_input_lower_ptr[xs_upper];
const float bottom_left = ys_input_upper_ptr[xs_lower];
const float bottom_right = ys_input_upper_ptr[xs_upper];
output_ptr[x] = ComputeLerp(top_left, top_right, bottom_left,
bottom_right, xs_lerp, ys_lerp);
const T *batch_input_ptr = images + in_batch_num_values * b;;
T *batch_output_ptr = output + out_batch_num_values * b;
for (index_t y = 0; y < out_height; ++y) {
const T *y_lower_input_ptr =
batch_input_ptr + ys[y].lower * in_width * channels;
const T *y_upper_input_ptr =
batch_input_ptr + ys[y].upper * in_width * channels;
T *y_output_ptr = batch_output_ptr + y * out_width * channels;
const float ys_lerp = ys[y].lerp;
for (index_t x = 0; x < out_width; ++x) {
const float xs_lerp = xs[x].lerp;
const T *top_left_ptr = y_lower_input_ptr + xs[x].lower * channels;
const T *top_right_ptr = y_lower_input_ptr + xs[x].upper * channels;
const T *bottom_left_ptr = y_upper_input_ptr + xs[x].lower * channels;
const T *bottom_right_ptr = y_upper_input_ptr + xs[x].upper * channels;
T *output_ptr = y_output_ptr + x * channels;
for (index_t c = 0; c < channels; ++c) {
const T top_left = top_left_ptr[c];
const T top_right = top_right_ptr[c];
const T bottom_left = bottom_left_ptr[c];
const T bottom_right = bottom_right_ptr[c];
output_ptr[c] = ComputeLerp(top_left, top_right, bottom_left,
bottom_right, xs_lerp, ys_lerp);
}
output_ptr += out_width;
}
}
}
}
}
struct ResizeBilinearFunctorBase {
ResizeBilinearFunctorBase(const std::vector<index_t> &size,
bool align_corners)
: align_corners_(align_corners), size_(size) {}
protected:
void GetOutputSize(const Tensor *resize_dims,
index_t *out_height,
index_t *out_width) {
if (size_[0] < 0 || size_[1] < 0) {
MACE_CHECK(resize_dims != nullptr && resize_dims->dim_size() == 1);
Tensor::MappingGuard resize_dims_mapper(resize_dims);
auto dims_data = resize_dims->data<int32_t>();
*out_height = dims_data[0];
*out_width = dims_data[1];
} else {
*out_height = size_[0];
*out_width = size_[1];
}
}
bool align_corners_;
std::vector<index_t> size_;
};
template <DeviceType D, typename T>
class ResizeBilinearFunctor {
public:
struct ResizeBilinearFunctor : ResizeBilinearFunctorBase {
ResizeBilinearFunctor(const std::vector<index_t> &size, bool align_corners)
: align_corners_(align_corners), size_(size) {}
: ResizeBilinearFunctorBase(size, align_corners) {}
void operator()(const Tensor *input,
const Tensor *resize_dims,
Tensor *output) {
const index_t batch = input->dim(0);
const index_t channels = input->dim(1);
const index_t in_height = input->dim(2);
const index_t in_width = input->dim(3);
const index_t in_height = input->dim(1);
const index_t in_width = input->dim(2);
const index_t channels = input->dim(3);
index_t out_height;
index_t out_width;
GetOutputSize(resize_dims, &out_height, &out_width);
MACE_CHECK(out_height > 0 && out_width > 0);
std::vector<index_t> out_shape{batch, channels, out_height, out_width};
std::vector<index_t> out_shape{batch, out_height, out_width, channels};
output->Resize(out_shape);
Tensor::MappingGuard input_mapper(input);
......@@ -146,32 +173,18 @@ class ResizeBilinearFunctor {
ResizeImage(input_data, batch, in_height, in_width, out_height, out_width,
channels, xs, ys, output_data);
}
};
protected:
void GetOutputSize(const Tensor *resize_dims,
index_t *out_height,
index_t *out_width) {
if (size_[0] < 0 || size_[1] < 0) {
MACE_CHECK(resize_dims != nullptr && resize_dims->dim_size() == 1);
Tensor::MappingGuard resize_dims_mapper(resize_dims);
auto dims_data = resize_dims->data<int32_t>();
*out_height = dims_data[0];
*out_width = dims_data[1];
} else {
*out_height = size_[0];
*out_width = size_[1];
}
}
/**
 * OpenCL partial specialization of the functor. It only declares operator();
 * no body is given inside the struct.
 *
 * `align_corners_` and `size_` come from ResizeBilinearFunctorBase and are
 * initialised by the base constructor. They are deliberately not re-declared
 * here: a second copy would shadow the base members and stay uninitialised.
 * operator() is public so the functor can be invoked by its users.
 */
template<typename T>
struct ResizeBilinearFunctor<DeviceType::OPENCL, T> : ResizeBilinearFunctorBase {
  ResizeBilinearFunctor(const std::vector<index_t> &size, bool align_corners)
      : ResizeBilinearFunctorBase(size, align_corners) {}

  void operator()(const Tensor *input,
                  const Tensor *resize_dims,
                  Tensor *output);
};
// Declaration only: explicit specialization of operator() for the
// OPENCL/float functor; no body is provided at this point.
template <>
void ResizeBilinearFunctor<DeviceType::OPENCL, float>::operator()(
const Tensor *input, const Tensor *resize_dims, Tensor *output);
} // namespace kernels
} // namespace mace
......
......@@ -23,4 +23,9 @@ REGISTER_OPENCL_OPERATOR(OpKeyBuilder("ResizeBilinear")
.Build(),
ResizeBilinearOp<DeviceType::OPENCL, float>);
// Register ResizeBilinearOp<OPENCL, half> under the "ResizeBilinear" key with
// the type constraint T = half.
REGISTER_OPENCL_OPERATOR(OpKeyBuilder("ResizeBilinear")
.TypeConstraint<half>("T")
.Build(),
ResizeBilinearOp<DeviceType::OPENCL, half>);
} // namespace mace
......@@ -19,18 +19,30 @@ static void ResizeBilinearBenchmark(int iters,
mace::testing::StopTiming();
OpsTestNet net;
OpDefBuilder("ResizeBilinear", "ResizeBilinearBenchmark")
.Input("Input")
.Input("OutSize")
.Output("Output")
.AddIntsArg("size", {output_height, output_width})
.Finalize(net.NewOperatorDef());
// Add input data
net.AddRandomInput<D, float>("Input",
{batch, channels, input_height, input_width});
{batch, input_height, input_width, channels});
net.AddInputFromArray<D, index_t>("OutSize", {2},
{output_height, output_width});
if (D == DeviceType::OPENCL) {
BufferToImage<D, T>(net, "Input", "InputImage", kernels::BufferType::IN_OUT);
OpDefBuilder("ResizeBilinear", "ResizeBilinearBenchmark")
.Input("InputImage")
.Input("OutSize")
.Output("OutputImage")
.AddIntsArg("size", {output_height, output_width})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
} else {
OpDefBuilder("ResizeBilinear", "ResizeBilinearBenchmark")
.Input("Input")
.Input("OutSize")
.Output("Output")
.AddIntsArg("size", {output_height, output_width})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
}
// Warm-up
for (int i = 0; i < 5; ++i) {
......@@ -58,9 +70,12 @@ static void ResizeBilinearBenchmark(int iters,
// Expands BM_RESIZE_BILINEAR_MACRO once per device (CPU, NEON, OPENCL) with
// the same shape arguments and element type.
// NOTE(review): N, C, H0, W0, H1, W1 are presumably batch, channels, input
// height/width and output height/width — confirm against
// BM_RESIZE_BILINEAR_MACRO.
#define BM_RESIZE_BILINEAR(N, C, H0, W0, H1, W1, TYPE) \
BM_RESIZE_BILINEAR_MACRO(N, C, H0, W0, H1, W1, TYPE, CPU); \
BM_RESIZE_BILINEAR_MACRO(N, C, H0, W0, H1, W1, TYPE, NEON); \
BM_RESIZE_BILINEAR_MACRO(N, C, H0, W0, H1, W1, TYPE, OPENCL);
// SNPE 835 GPU: 6870us
BM_RESIZE_BILINEAR(1, 128, 120, 120, 480, 480, half);
BM_RESIZE_BILINEAR(1, 128, 120, 120, 480, 480, float);
BM_RESIZE_BILINEAR(1, 256, 7, 7, 15, 15, float);
BM_RESIZE_BILINEAR(1, 256, 15, 15, 30, 30, float);
BM_RESIZE_BILINEAR(1, 128, 30, 30, 60, 60, float);
......
......@@ -23,14 +23,14 @@ TEST_F(ResizeBilinearTest, CPUResizeBilinearWOAlignCorners) {
// Add input data
vector<float> input(24);
std::iota(begin(input), end(input), 0);
net.AddInputFromArray<DeviceType::CPU, float>("Input", {1, 3, 2, 4}, input);
net.AddInputFromArray<DeviceType::CPU, float>("Input", {1, 2, 4, 3}, input);
net.AddInputFromArray<DeviceType::CPU, int>("OutSize", {2}, {1, 2});
// Run
net.RunOp();
// Check
auto expected = CreateTensor<float>({1, 3, 1, 2}, {0, 2, 8, 10, 16, 18});
auto expected = CreateTensor<float>({1, 1, 2, 3}, {0, 1, 2, 6, 7, 8});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001);
}
......@@ -49,14 +49,14 @@ TEST_F(ResizeBilinearTest, ResizeBilinearWAlignCorners) {
// Add input data
vector<float> input(24);
std::iota(begin(input), end(input), 0);
net.AddInputFromArray<DeviceType::CPU, float>("Input", {1, 3, 2, 4}, input);
net.AddInputFromArray<DeviceType::CPU, float>("Input", {1, 2, 4, 3}, input);
net.AddInputFromArray<DeviceType::CPU, int>("OutSize", {2}, {1, 2});
// Run
net.RunOp();
// Check
auto expected = CreateTensor<float>({1, 3, 1, 2}, {0, 3, 8, 11, 16, 19});
auto expected = CreateTensor<float>({1, 1, 2, 3}, {0, 1, 2, 9, 10, 11});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001);
}
......@@ -65,6 +65,7 @@ template <DeviceType D>
// Randomized check of ResizeBilinear on device D against the CPU result:
// 10 rounds with random batch, channel and spatial sizes, compared with
// ExpectTensorNear at a tolerance of 0.001.
void TestRandomResizeBilinear() {
srand(time(nullptr));
testing::internal::LogToStderr();
for (int round = 0; round < 10; ++round) {
int batch = 1 + rand() % 5;
int channels = 1 + rand() % 100;
......@@ -72,39 +73,54 @@ void TestRandomResizeBilinear() {
int width = 1 + rand() % 100;
int in_height = 1 + rand() % 100;
int in_width = 1 + rand() % 100;
// NOTE(review): `rand() % 1` is always 0, so align_corners is never 1 in this
// test; `rand() % 2` is presumably what was intended.
int align_corners = rand() % 1;
// Construct graph
OpsTestNet net;
OpDefBuilder("ResizeBilinear", "ResizeBilinearTest")
.Input("Input")
.Input("OutSize")
.Output("Output")
.AddIntArg("align_corners", 1)
.AddIntsArg("size", {height, width})
.Finalize(net.NewOperatorDef());
// Add input data
net.AddRandomInput<D, float>("Input",
{batch, channels, in_height, in_width});
{batch, in_height, in_width, channels});
net.AddInputFromArray<D, int>("OutSize", {2}, {height, width});
// Run
net.RunOp(D);
Tensor actual;
actual.Copy(*net.GetOutput("Output"));
OpDefBuilder("ResizeBilinear", "ResizeBilinearTest")
.Input("Input")
.Input("OutSize")
.Output("Output")
.AddIntArg("align_corners", align_corners)
.AddIntsArg("size", {height, width})
.Finalize(net.NewOperatorDef());
// Run on CPU
net.RunOp(DeviceType::CPU);
Tensor *expected = net.GetOutput("Output");
// Copy the CPU result so a later run cannot overwrite the reference values.
Tensor expected;
expected.Copy(*net.GetOutput("Output"));
if (D == DeviceType::OPENCL) {
// Convert the input buffer to an image, run the op on the image, then convert
// the image output back into the "DeviceOutput" buffer.
BufferToImage<D, float>(net, "Input", "InputImage", kernels::BufferType::IN_OUT);
OpDefBuilder("ResizeBilinear", "ResizeBilinearTest")
.Input("InputImage")
.Input("OutSize")
.Output("OutputImage")
.AddIntArg("align_corners", align_corners)
.AddIntsArg("size", {height, width})
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
ImageToBuffer<D, float>(net, "OutputImage", "DeviceOutput", kernels::BufferType::IN_OUT);
} else {
// TODO support NEON
}
// Check
ExpectTensorNear<float>(*expected, actual, 0.001);
// NOTE(review): "DeviceOutput" is only produced in the OPENCL branch above;
// for any other D nothing visible here creates it — confirm before enabling
// the NEON test.
ExpectTensorNear<float>(expected, *net.GetOutput("DeviceOutput"), 0.001);
}
}
/*
TEST_F(ResizeBilinearTest, NEONRandomResizeBilinear) {
TestRandomResizeBilinear<DeviceType::NEON>();
}
*/
TEST_F(ResizeBilinearTest, OPENCLRandomResizeBilinear) {
TestRandomResizeBilinear<DeviceType::OPENCL>();
......
......@@ -92,6 +92,7 @@ def main(unused_args):
size = tensor_values[input_name]
break
key = '%s(size=%s, align_corners=%s)' % (op.type, size, align_corners)
print(key)
hist_inc(stats, key)
elif op.type in ['AvgPool', 'MaxPool']:
padding = op.get_attr('padding')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册