Commit 1998fd46 authored by Liangliang He

Merge branch 'resize' into 'master'

Resize

See merge request !93
@@ -27,18 +27,18 @@ struct Conv2dFunctor {
     MACE_CHECK_NOTNULL(filter);
     MACE_CHECK_NOTNULL(output);
-    index_t batch = output->shape()[0];
-    index_t channels = output->shape()[1];
-    index_t height = output->shape()[2];
-    index_t width = output->shape()[3];
-    index_t input_batch = input->shape()[0];
-    index_t input_channels = input->shape()[1];
-    index_t input_height = input->shape()[2];
-    index_t input_width = input->shape()[3];
-    index_t kernel_h = filter->shape()[2];
-    index_t kernel_w = filter->shape()[3];
+    index_t batch = output->dim(0);
+    index_t channels = output->dim(1);
+    index_t height = output->dim(2);
+    index_t width = output->dim(3);
+    index_t input_batch = input->dim(0);
+    index_t input_channels = input->dim(1);
+    index_t input_height = input->dim(2);
+    index_t input_width = input->dim(3);
+    index_t kernel_h = filter->dim(2);
+    index_t kernel_w = filter->dim(3);
     int stride_h = strides_[0];
     int stride_w = strides_[1];

@@ -61,8 +61,8 @@ void Conv2dFunctor<DeviceType::NEON, float>::operator()(const Tensor *input,
       {nullptr, nullptr},
       {Conv2dNeonK5x5S1, nullptr}};
   // not implemented yet
-  index_t kernel_h = filter->shape()[2];
-  index_t kernel_w = filter->shape()[3];
+  index_t kernel_h = filter->dim(2);
+  index_t kernel_w = filter->dim(3);
   if (kernel_h != kernel_w || kernel_h > 5 || strides_[0] != strides_[1] ||
       strides_[0] > 2 || dilations_[0] != 1 || dilations_[1] != 1 ||
       selector[kernel_h - 1][strides_[0] - 1] == nullptr) {

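A note on the dispatch pattern these conv2d hunks touch: selector is a two-dimensional table of function pointers indexed by [kernel_size - 1][stride - 1], and a nullptr entry means no specialized kernel exists, so the guard above falls through to the generic implementation. A minimal standalone sketch of the idea (types and names here are illustrative, not MACE's):

typedef void (*Conv2dKernel)(const float *input, const float *filter,
                             float *output);

// Look up a specialized kernel for (kernel_size, stride); a nullptr
// entry means "not implemented yet", so the caller must fall back to
// the generic path.
bool TryDispatch(int kernel_size, int stride, const float *in,
                 const float *filter, float *out) {
  static const Conv2dKernel selector[5][2] = {};  // populated per platform
  if (kernel_size < 1 || kernel_size > 5 || stride < 1 || stride > 2)
    return false;
  const Conv2dKernel kernel = selector[kernel_size - 1][stride - 1];
  if (kernel == nullptr) return false;
  kernel(in, filter, out);
  return true;
}
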
__kernel void resize_bilinear_nocache(__global const float *input, /* n * c, h, w */
                                      __global float *output /* n * c, h, w */,
                                      __private const float height_scale,
                                      __private const float width_scale,
                                      __private const int in_height,
                                      __private const int in_width) {
  const int c = get_global_id(0);
  const int h = get_global_id(1);
  const int w = get_global_id(2);
  const int channels = get_global_size(0);
  const int height = get_global_size(1);
  const int width = get_global_size(2);

  const float h_in = h * height_scale;
  const float w_in = w * width_scale;
  const int h_lower = max(0, (int) floor(h_in));
  const int h_upper = min(in_height - 1, h_lower + 1);
  const int w_lower = max(0, (int) floor(w_in));
  const int w_upper = min(in_width - 1, w_lower + 1);

  const float h_lerp = h_in - h_lower;
  const float w_lerp = w_in - w_lower;

  // Base pointers for the (h, w) plane of flattened channel c.
  __global const float *input_base = input + c * in_height * in_width;
  __global float *output_base = output + c * height * width;

  float top_left = input_base[h_lower * in_width + w_lower];
  float top_right = input_base[h_lower * in_width + w_upper];
  float bottom_left = input_base[h_upper * in_width + w_lower];
  float bottom_right = input_base[h_upper * in_width + w_upper];

  const float top = top_left + (top_right - top_left) * w_lerp;
  const float bottom = bottom_left + (bottom_right - bottom_left) * w_lerp;
  output_base[h * width + w] = top + (bottom - top) * h_lerp;
}

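For reference, the two lerps at the end of the kernel are algebraically the standard bilinear weighting. With a_h and a_w the fractional offsets (h_lerp and w_lerp above), (h_0, w_0) the clamped top-left source pixel and (h_1, w_1) its clamped lower-right neighbor:

\[
\mathrm{out}(h, w) = (1 - a_h)(1 - a_w)\,\mathrm{in}(h_0, w_0)
                   + (1 - a_h)\,a_w\,\mathrm{in}(h_0, w_1)
                   + a_h\,(1 - a_w)\,\mathrm{in}(h_1, w_0)
                   + a_h\,a_w\,\mathrm{in}(h_1, w_1)
\]
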
@@ -30,8 +30,8 @@ void Conv2dFunctor<DeviceType::OPENCL, float>::operator()(const Tensor *input,
       {nullptr, nullptr},
       {nullptr, nullptr}};
-  index_t kernel_h = filter->shape()[2];
-  index_t kernel_w = filter->shape()[3];
+  index_t kernel_h = filter->dim(2);
+  index_t kernel_w = filter->dim(3);
   if (kernel_h != kernel_w || kernel_h > 5 || strides_[0] != strides_[1] ||
       strides_[0] > 2 || dilations_[0] != 1 || dilations_[1] != 1 ||
       selector[kernel_h - 1][strides_[0] - 1] == nullptr) {

@@ -15,11 +15,11 @@ void Conv1x1Naive(const Tensor *input,
                   const Tensor *filter,
                   const Tensor *bias,
                   Tensor *output) {
-  const index_t batch = output->shape()[0];
-  const index_t channels = output->shape()[1];
-  const index_t height = output->shape()[2];
-  const index_t width = output->shape()[3];
-  const index_t input_channels = input->shape()[1];
+  const index_t batch = output->dim(0);
+  const index_t channels = output->dim(1);
+  const index_t height = output->dim(2);
+  const index_t width = output->dim(3);
+  const index_t input_channels = input->dim(1);
   auto runtime = OpenCLRuntime::Get();
   auto program = runtime->program();

@@ -46,11 +46,11 @@ void Conv1x1V2(const Tensor *input,
                const Tensor *filter,
                const Tensor *bias,
                Tensor *output) {
-  const index_t batch = output->shape()[0];
-  const index_t channels = output->shape()[1];
-  const index_t height = output->shape()[2];
-  const index_t width = output->shape()[3];
-  const index_t input_channels = input->shape()[1];
+  const index_t batch = output->dim(0);
+  const index_t channels = output->dim(1);
+  const index_t height = output->dim(2);
+  const index_t width = output->dim(3);
+  const index_t input_channels = input->dim(1);
   auto runtime = OpenCLRuntime::Get();
   auto program = runtime->program();

@@ -88,11 +88,11 @@ void Conv1x1V3(const Tensor *input,
               const Tensor *filter,
               const Tensor *bias,
               Tensor *output) {
-  const index_t batch = output->shape()[0];
-  const index_t channels = output->shape()[1];
-  const index_t height = output->shape()[2];
-  const index_t width = output->shape()[3];
-  const index_t input_channels = input->shape()[1];
+  const index_t batch = output->dim(0);
+  const index_t channels = output->dim(1);
+  const index_t height = output->dim(2);
+  const index_t width = output->dim(3);
+  const index_t input_channels = input->dim(1);
   auto runtime = OpenCLRuntime::Get();
   auto program = runtime->program();

@@ -174,13 +174,13 @@ extern void Conv2dOpenclK1x1S1(const Tensor *input,
                                const Tensor *filter,
                                const Tensor *bias,
                                Tensor *output) {
-  const index_t batch = output->shape()[0];
-  const index_t height = output->shape()[2];
-  const index_t width = output->shape()[3];
-  const index_t input_batch = input->shape()[0];
-  const index_t input_height = input->shape()[2];
-  const index_t input_width = input->shape()[3];
+  const index_t batch = output->dim(0);
+  const index_t height = output->dim(2);
+  const index_t width = output->dim(3);
+  const index_t input_batch = input->dim(0);
+  const index_t input_height = input->dim(2);
+  const index_t input_width = input->dim(3);
   MACE_CHECK(input_batch == batch && input_height == height &&
              input_width == width);

@@ -11,9 +11,9 @@ namespace kernels {
 static void InnerConv2dK3x3S12(const Tensor *input, const Tensor *filter,
                                const Tensor *bias, const uint32_t stride, Tensor *output) {
-  const index_t channels = output->shape()[1];
-  const index_t height = output->shape()[2];
-  const index_t width = output->shape()[3];
+  const index_t channels = output->dim(1);
+  const index_t height = output->dim(2);
+  const index_t width = output->dim(3);
   MACE_CHECK(input->dim(0) == output->dim(0));

@@ -27,8 +27,8 @@ void DepthwiseConv2dFunctor<DeviceType::OPENCL, float>::operator()(const Tensor
       {nullptr, nullptr},
       {nullptr, nullptr}};
-  index_t kernel_h = filter->shape()[2];
-  index_t kernel_w = filter->shape()[3];
+  index_t kernel_h = filter->dim(2);
+  index_t kernel_w = filter->dim(3);
   if (kernel_h != kernel_w || kernel_h > 5 || strides_[0] != strides_[1] ||
       strides_[0] > 2 || dilations_[0] != 1 || dilations_[1] != 1 ||
       selector[kernel_h - 1][strides_[0] - 1] == nullptr) {

@@ -2,15 +2,59 @@
 // Copyright (c) 2017 XiaoMi All rights reserved.
 //
-#include "mace/kernels/resize_bilinear.h"
+#include "mace/core/runtime/opencl/opencl_runtime.h"
 #include "mace/core/tensor.h"
+#include "mace/kernels/resize_bilinear.h"
 
 namespace mace {
 namespace kernels {
 
 template <>
 void ResizeBilinearFunctor<DeviceType::OPENCL, float>::operator()(
-    const Tensor *input, const Tensor *resize_dims, Tensor *output) {}
+    const Tensor *input, const Tensor *resize_dims, Tensor *output) {
+  const index_t batch = input->dim(0);
+  const index_t channels = input->dim(1);
+  const index_t in_height = input->dim(2);
+  const index_t in_width = input->dim(3);
+
+  index_t out_height;
+  index_t out_width;
+  {
+    MACE_CHECK(resize_dims->dim_size() == 1);
+    Tensor::MappingGuard resize_dims_mapper(resize_dims);
+    auto dims_data = resize_dims->data<index_t>();
+    out_height = dims_data[0];
+    out_width = dims_data[1];
+  }
+  std::vector<index_t> out_shape{batch, channels, out_height, out_width};
+  output->Resize(out_shape);
+
+  float height_scale =
+      CalculateResizeScale(in_height, out_height, align_corners_);
+  float width_scale = CalculateResizeScale(in_width, out_width, align_corners_);
+
+  auto runtime = OpenCLRuntime::Get();
+  auto program = runtime->program();
+  auto rb_kernel = cl::Kernel(program, "resize_bilinear_nocache");
+  const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(rb_kernel);
+
+  uint32_t idx = 0;
+  rb_kernel.setArg(idx++, *(static_cast<const cl::Buffer *>(input->buffer())));
+  rb_kernel.setArg(idx++, *(static_cast<cl::Buffer *>(output->buffer())));
+  rb_kernel.setArg(idx++, static_cast<float>(height_scale));
+  rb_kernel.setArg(idx++, static_cast<float>(width_scale));
+  rb_kernel.setArg(idx++, static_cast<int>(in_height));
+  rb_kernel.setArg(idx++, static_cast<int>(in_width));
+
+  auto command_queue = runtime->command_queue();
+  cl_int error = command_queue.enqueueNDRangeKernel(
+      rb_kernel, cl::NullRange,
+      cl::NDRange(static_cast<int>(batch * channels),
+                  static_cast<int>(out_height), static_cast<int>(out_width)),
+      cl::NDRange(1, 16, kwg_size / 16));
+  MACE_CHECK(error == CL_SUCCESS, error);
+}
 
 } // namespace kernels
 } // namespace mace

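CalculateResizeScale itself is defined elsewhere in mace/kernels/resize_bilinear.h and is not shown in this diff; given the align_corners_ flag it takes, it presumably follows the usual TensorFlow-style convention, sketched here as an assumption rather than the actual MACE code:

#include <cstdint>

// Assumed semantics (not shown in this diff): with align_corners the
// corner pixels of input and output coincide, so the scale is computed
// from (size - 1); otherwise it is the plain size ratio.
inline float CalculateResizeScale(int64_t in_size, int64_t out_size,
                                  bool align_corners) {
  return (align_corners && out_size > 1)
             ? (in_size - 1) / static_cast<float>(out_size - 1)
             : in_size / static_cast<float>(out_size);
}

One launch-geometry caveat worth noting for review: the enqueue pairs a global size of (batch * channels, out_height, out_width) with a fixed local size of (1, 16, kwg_size / 16); on OpenCL 1.x the global size must be divisible by the local size in every dimension, which this code does not guarantee for arbitrary output shapes.
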
@@ -127,6 +127,8 @@ struct ResizeBilinearFunctor {
     vector<index_t> out_shape{n, channels, out_height, out_width};
     output->Resize(out_shape);
+    Tensor::MappingGuard input_mapper(input);
+    Tensor::MappingGuard output_mapper(output);
     const T *input_data = input->data<T>();
     T *output_data = output->mutable_data<T>();

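The two MappingGuard lines added here are the host-side complement of the OpenCL path: when a tensor's storage is a device buffer, the raw pointers from data<T>() and mutable_data<T>() are only safe to dereference while the buffer is mapped into host memory. A generic sketch of the RAII idiom (illustrative, not MACE's actual class):

// Maps a buffer on construction and unmaps on destruction, so raw
// pointers obtained in between stay valid for the guard's lifetime.
template <typename BufferLike>
class ScopedMapping {
 public:
  explicit ScopedMapping(BufferLike *buffer) : buffer_(buffer) {
    buffer_->Map();  // e.g. clEnqueueMapBuffer under an OpenCL backend
  }
  ~ScopedMapping() { buffer_->Unmap(); }
  ScopedMapping(const ScopedMapping &) = delete;
  ScopedMapping &operator=(const ScopedMapping &) = delete;

 private:
  BufferLike *buffer_;
};
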
@@ -28,7 +28,6 @@ class ResizeBilinearOp : public Operator<D, T> {
     MACE_CHECK(resize_dims->dim_size() == 1,
                "resize dim must be 1-dimensional.", resize_dims->dim_size());
     functor_(input, resize_dims, output);
     return true;
   }

@@ -26,10 +26,10 @@ static void ResizeBilinearBenchmark(int iters,
       .Finalize(net.NewOperatorDef());
 
   // Add input data
-  net.AddRandomInput<DeviceType::CPU, float>(
-      "Input", {batch, channels, input_height, input_width});
-  net.AddInputFromArray<DeviceType::CPU, index_t>(
-      "OutSize", {2}, {output_height, output_width});
+  net.AddRandomInput<D, float>("Input",
+                               {batch, channels, input_height, input_width});
+  net.AddInputFromArray<D, index_t>("OutSize", {2},
+                                    {output_height, output_width});
 
   // Warm-up
   for (int i = 0; i < 5; ++i) {

@@ -40,6 +40,7 @@ static void ResizeBilinearBenchmark(int iters,
   while (iters--) {
     net.RunOp(D);
   }
+  net.Sync();
 }
 
 #define BM_RESIZE_BILINEAR_MACRO(N, C, H0, W0, H1, W1, TYPE, DEVICE) \

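The added net.Sync() matters for the device backends: RunOp typically only enqueues work on the command queue, so without a final blocking synchronization the loop would measure enqueue overhead rather than kernel time. The general shape of the fix, as a hedged sketch with illustrative names:

#include <chrono>

// Time an asynchronous-device workload correctly: enqueue all
// iterations, then block once before reading the clock.
template <typename Net>
double AverageSecondsPerIter(Net &net, int iters) {
  const auto start = std::chrono::steady_clock::now();
  for (int i = 0; i < iters; ++i) {
    net.Run();   // may return before the device work finishes
  }
  net.Sync();    // block until every queued kernel has completed
  const std::chrono::duration<double> elapsed =
      std::chrono::steady_clock::now() - start;
  return elapsed.count() / iters;
}
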
@@ -10,7 +10,7 @@ using namespace mace;
 
 class ResizeBilinearTest : public OpsTestBase {};
 
-TEST_F(ResizeBilinearTest, ResizeBilinearWOAlignCorners) {
+TEST_F(ResizeBilinearTest, CPUResizeBilinearWOAlignCorners) {
   testing::internal::LogToStderr();
   // Construct graph
   auto &net = test_net();

@@ -60,3 +60,51 @@ TEST_F(ResizeBilinearTest, ResizeBilinearWAlignCorners) {
   ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001);
 }
+
+template <DeviceType D>
+void TestRandomResizeBilinear() {
+  srand(time(nullptr));
+  testing::internal::LogToStderr();
+  for (int round = 0; round < 10; ++round) {
+    index_t batch = 1 + rand() % 5;
+    index_t channels = 1 + rand() % 100;
+    index_t height = 1 + rand() % 100;
+    index_t width = 1 + rand() % 100;
+    index_t in_height = 1 + rand() % 100;
+    index_t in_width = 1 + rand() % 100;
+
+    // Construct graph
+    OpsTestNet net;
+    OpDefBuilder("ResizeBilinear", "ResizeBilinearTest")
+        .Input("Input")
+        .Input("OutSize")
+        .Output("Output")
+        .AddIntArg("align_corners", 1)
+        .Finalize(net.NewOperatorDef());
+
+    // Add input data
+    net.AddRandomInput<D, float>("Input",
+                                 {batch, channels, in_height, in_width});
+    net.AddInputFromArray<D, index_t>("OutSize", {2}, {height, width});
+
+    // Run
+    net.RunOp(D);
+    Tensor actual;
+    actual.Copy(*net.GetOutput("Output"));
+
+    // Run on CPU
+    net.RunOp(DeviceType::CPU);
+    Tensor *expected = net.GetOutput("Output");
+
+    // Check
+    ExpectTensorNear<float>(*expected, actual, 0.001);
+  }
+}
+
+TEST_F(ResizeBilinearTest, NEONRandomResizeBilinear) {
+  TestRandomResizeBilinear<DeviceType::NEON>();
+}
+
+TEST_F(ResizeBilinearTest, OPENCLRandomResizeBilinear) {
+  TestRandomResizeBilinear<DeviceType::OPENCL>();
+}