diff --git a/mace/core/opencl_allocator.cc b/mace/core/opencl_allocator.cc index 75004f75276a54c47f82626eefa818dcef3941da..0c4cf8f0f87069d20650622c578308983d61560b 100644 --- a/mace/core/opencl_allocator.cc +++ b/mace/core/opencl_allocator.cc @@ -54,10 +54,11 @@ void *OpenCLAllocator::NewImage(const std::vector &image_shape, cl_int error; cl::Image2D *cl_image = new cl::Image2D(OpenCLRuntime::Get()->context(), - CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR , + CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, img_format, image_shape[0], image_shape[1], 0, nullptr, &error); + MACE_CHECK(error == CL_SUCCESS); return cl_image; } diff --git a/mace/kernels/opencl/addn.cc b/mace/kernels/opencl/addn.cc index 6c5106db25c0f12cb625b6e5e0c80c0497541804..9f1ed60afc91ad4f2dfdcd13aa6eebf8fd2839b6 100644 --- a/mace/kernels/opencl/addn.cc +++ b/mace/kernels/opencl/addn.cc @@ -17,7 +17,7 @@ static void Add2(const Tensor *input0, const Tensor *input1, Tensor *output) { auto runtime = OpenCLRuntime::Get(); std::set built_options; - built_options.emplace("-DDATA_TYPE=" + DataTypeToCLType(output->dtype())); + built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(output->dtype())); auto addn_kernel = runtime->BuildKernel("addn", "add2", built_options); const uint32_t lws = runtime->GetKernelMaxWorkGroupSize(addn_kernel); diff --git a/mace/kernels/opencl/batch_norm_opencl.cc b/mace/kernels/opencl/batch_norm_opencl.cc index 271ef9da07f5dd4cac421bbecc317c8125c65589..e9dc00b9d5d93fe4792ea15cc25e40791afff6e2 100644 --- a/mace/kernels/opencl/batch_norm_opencl.cc +++ b/mace/kernels/opencl/batch_norm_opencl.cc @@ -35,8 +35,8 @@ void BatchNormFunctor::operator()( auto runtime = OpenCLRuntime::Get(); std::set built_options; - built_options.emplace("-DDATA_TYPE=" + DataTypeToCLType(input->dtype())); - built_options.emplace("-DCMD_DATA_TYPE=" + DataTypeToOPENCLCMDDataType(input->dtype())); + built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(input->dtype())); + built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(input->dtype())); auto bm_kernel = runtime->BuildKernel("batch_norm", "batch_norm", built_options); const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(bm_kernel); diff --git a/mace/kernels/opencl/buffer_to_image.cc b/mace/kernels/opencl/buffer_to_image.cc index 61faa995ce86792a302068af11aed7b784b2834f..f3af3d22622bd5e893347d958da76dbec71a450a 100644 --- a/mace/kernels/opencl/buffer_to_image.cc +++ b/mace/kernels/opencl/buffer_to_image.cc @@ -24,8 +24,13 @@ void BufferToImageFunctor::operator()(Tensor *buffer, } std::set built_options; - built_options.emplace("-DDATA_TYPE=" + DataTypeToCLType(DataTypeToEnum::value)); - built_options.emplace("-DCMD_DATA_TYPE=" + DataTypeToOPENCLCMDDataType(DataTypeToEnum::value)); + if (buffer->dtype() == image->dtype()) { + built_options.emplace("-DDATA_TYPE=" + DtToCLDt(DataTypeToEnum::value)); + built_options.emplace("-DCMD_DATA_TYPE=" + DtToCLCMDDt(DataTypeToEnum::value)); + } else { + built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(DataTypeToEnum::value)); + built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(DataTypeToEnum::value)); + } auto runtime = OpenCLRuntime::Get(); string kernel_name; switch (type) { diff --git a/mace/kernels/opencl/cl/resize_bilinear.cl b/mace/kernels/opencl/cl/resize_bilinear.cl index f34e63cbf07b1a360957fcf5eaf74661ec22b8c1..efb769d27b7ab7836d0681c2b84775047942805a 100644 --- a/mace/kernels/opencl/cl/resize_bilinear.cl +++ b/mace/kernels/opencl/cl/resize_bilinear.cl @@ -1,18 +1,19 @@ #include -// Supported data type: half/float -__kernel void resize_bilinear_nocache(__global const DATA_TYPE *input, /* n * c, h, w */ - __global DATA_TYPE *output /* n * c, h, w */, +__kernel void resize_bilinear_nocache(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] */ + __write_only image2d_t output, __private const float height_scale, __private const float width_scale, __private const int in_height, - __private const int in_width) { - const int c = get_global_id(0); - const int h = get_global_id(1); - const int w = get_global_id(2); - const int channels = get_global_size(0); - const int height = get_global_size(1); - const int width = get_global_size(2); + __private const int in_width, + __private const int out_height) { + const int ch_blk = get_global_id(0); + const int ch_blks = get_global_size(0); + const int w = get_global_id(1); + const int out_width = get_global_size(1); + const int hb = get_global_id(2); + const int b = hb / out_height; + const int h = hb % out_height; const float h_in = h * height_scale; const float w_in = w * width_scale; @@ -24,16 +25,26 @@ __kernel void resize_bilinear_nocache(__global const DATA_TYPE *input, /* n * c, const float h_lerp = h_in - h_lower; const float w_lerp = w_in - w_lower; - const DATA_TYPE *input_base = input + c * in_height * in_width; - DATA_TYPE *output_base = output + c * height * width; + const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; + const int in_w_offset = ch_blk * in_width; + const int in_h_offset = b * in_height; - DATA_TYPE top_left = input_base[h_lower * in_width + w_lower]; - DATA_TYPE top_right = input_base[h_lower * in_width + w_upper]; - DATA_TYPE bottom_left = input_base[h_upper * in_width + w_lower]; - DATA_TYPE bottom_right = input_base[h_upper * in_width + w_upper]; + DATA_TYPE4 top_left = READ_IMAGET(input, sampler, + (int2)(in_w_offset + w_lower, in_h_offset + h_lower)); + DATA_TYPE4 top_right = READ_IMAGET(input, sampler, + (int2)(in_w_offset + w_upper, in_h_offset + h_lower)); + DATA_TYPE4 bottom_left = READ_IMAGET(input, sampler, + (int2)(in_w_offset + w_lower, in_h_offset + h_upper)); + DATA_TYPE4 bottom_right = READ_IMAGET(input, sampler, + (int2)(in_w_offset + w_upper, in_h_offset + h_upper)); - const DATA_TYPE top = top_left + (top_right - top_left) * w_lerp; - const DATA_TYPE bottom = bottom_left + (bottom_right - bottom_left) * w_lerp; - output_base[h * width + w] = top + (bottom - top) * h_lerp; + DATA_TYPE4 top = top_left + (top_right - top_left) * w_lerp; + DATA_TYPE4 bottom = bottom_left + (bottom_right - bottom_left) * w_lerp; + + DATA_TYPE4 out = top + (bottom - top) * h_lerp; + + const int out_w_offset = ch_blk * out_width; + const int out_h_offset = b * out_height; + WRITE_IMAGET(output, (int2)(out_w_offset + w, out_h_offset + h), out); } diff --git a/mace/kernels/opencl/conv_2d_opencl_1x1.cc b/mace/kernels/opencl/conv_2d_opencl_1x1.cc index c3a17c7bd77b760b3e9dfe31c9a3158e3348db58..d759689c6dc1ee8ffbfa98f2a4a58577a50c4271 100644 --- a/mace/kernels/opencl/conv_2d_opencl_1x1.cc +++ b/mace/kernels/opencl/conv_2d_opencl_1x1.cc @@ -34,8 +34,8 @@ void Conv1x1(const Tensor *input, MACE_CHECK(input_batch == batch); std::set built_options; - built_options.emplace("-DDATA_TYPE=" + DataTypeToCLType(dt)); - built_options.emplace("-DCMD_DATA_TYPE=" + DataTypeToOPENCLCMDDataType(dt)); + built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt)); + built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt)); built_options.emplace("-DSTRIDE=" + ToString(stride)); if (bias != nullptr) { built_options.emplace("-DBIAS"); diff --git a/mace/kernels/opencl/conv_2d_opencl_3x3.cc b/mace/kernels/opencl/conv_2d_opencl_3x3.cc index e29c4d92f7ee52ad6db3b9714e1fe94749a4c3d4..24bf90a1178961665ec1cf65935809d8409987bd 100644 --- a/mace/kernels/opencl/conv_2d_opencl_3x3.cc +++ b/mace/kernels/opencl/conv_2d_opencl_3x3.cc @@ -26,8 +26,8 @@ static void Conv2d3x3S12(const Tensor *input, const Tensor *filter, const index_t width_blocks = RoundUpDiv(width); std::set built_options; - built_options.emplace("-DDATA_TYPE=" + DataTypeToCLType(dt)); - built_options.emplace("-DCMD_DATA_TYPE=" + DataTypeToOPENCLCMDDataType(dt)); + built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt)); + built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt)); built_options.emplace(bias != nullptr ? "-DBIAS" : ""); built_options.emplace("-DSTRIDE=" + ToString(stride)); if (fused_relu) { diff --git a/mace/kernels/opencl/depthwise_conv_opencl_3x3.cc b/mace/kernels/opencl/depthwise_conv_opencl_3x3.cc index 60ce2a829a78a0a0439dd1e287c61f2dee4b490b..1402131df164cb0d1ba348617b3988e78f71c574 100644 --- a/mace/kernels/opencl/depthwise_conv_opencl_3x3.cc +++ b/mace/kernels/opencl/depthwise_conv_opencl_3x3.cc @@ -32,7 +32,7 @@ static void InnerDepthwiseConvOpenclK3x3S12(const Tensor *input, auto runtime = OpenCLRuntime::Get(); std::set built_options; - built_options.emplace("-DDATA_TYPE=" + DataTypeToCLType(input->dtype())); + built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(input->dtype())); built_options.emplace(stride == 1 ? "-DSTRIDE_1" : ""); built_options.emplace(bias != nullptr ? "-DBIAS" : ""); auto conv_kernel = runtime->BuildKernel("depthwise_conv_3x3", "depthwise_conv_3x3", built_options); diff --git a/mace/kernels/opencl/helper.cc b/mace/kernels/opencl/helper.cc index 4f4d1c56147df61da58a5a6478f1958e2b289a39..2c1dc264bd5ac1ddaeeaf47ea54a6e8b9e32e13a 100644 --- a/mace/kernels/opencl/helper.cc +++ b/mace/kernels/opencl/helper.cc @@ -54,34 +54,42 @@ void CalImage2DShape(const std::vector &shape, /* NHWC */ } -std::string DataTypeToCLType(const DataType dt) { +std::string DtToCLDt(const DataType dt) { + switch (dt) { + case DT_FLOAT: + return "float"; + case DT_HALF: + return "half"; + default: + LOG(FATAL) << "Unsupported data type"; + return ""; + } +} + +std::string DtToCLCMDDt(const DataType dt) { + switch (dt) { + case DT_FLOAT: + return "f"; + case DT_HALF: + return "h"; + default: + LOG(FATAL) << "Not supported data type for opencl cmd data type"; + return ""; + } +} + +std::string DtToUpstreamCLDt(const DataType dt) { switch (dt) { case DT_FLOAT: case DT_HALF: return "float"; - case DT_UINT8: - return "uchar"; - case DT_INT8: - return "char"; - case DT_DOUBLE: - return "double"; - case DT_INT32: - return "int"; - case DT_UINT32: - return "int"; - case DT_UINT16: - return "ushort"; - case DT_INT16: - return "short"; - case DT_INT64: - return "long"; default: LOG(FATAL) << "Unsupported data type"; return ""; } } -std::string DataTypeToOPENCLCMDDataType(const DataType dt) { +std::string DtToUpstreamCLCMDDt(const DataType dt) { switch (dt) { case DT_FLOAT: case DT_HALF: diff --git a/mace/kernels/opencl/helper.h b/mace/kernels/opencl/helper.h index 1ad94aa5d2545f059ec785c0b4ec36a87155fb49..70d74e5886c61a50c0a5fb684d02ecc6e00403cd 100644 --- a/mace/kernels/opencl/helper.h +++ b/mace/kernels/opencl/helper.h @@ -19,10 +19,13 @@ void CalImage2DShape(const std::vector &shape, /* NHWC */ const BufferType type, std::vector &image_shape); -std::string DataTypeToOPENCLCMDDataType(const DataType dt); +std::string DtToCLCMDDt(const DataType dt); -std::string DataTypeToCLType(const DataType dt); +std::string DtToUpstreamCLCMDDt(const DataType dt); +std::string DtToCLDt(const DataType dt); + +std::string DtToUpstreamCLDt(const DataType dt); } // namespace kernels } // namespace mace diff --git a/mace/kernels/opencl/pooling_opencl.cc b/mace/kernels/opencl/pooling_opencl.cc index 0aaa89ae2c649583dddafaffbcce428d4ffc94fd..fb9216f767dfd2770a6ccfc405283e51dea2ffe5 100644 --- a/mace/kernels/opencl/pooling_opencl.cc +++ b/mace/kernels/opencl/pooling_opencl.cc @@ -32,7 +32,7 @@ static void Pooling3(const Tensor *input, auto runtime = OpenCLRuntime::Get(); std::set built_options; - built_options.emplace("-DDATA_TYPE=" + DataTypeToCLType(input->dtype())); + built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(input->dtype())); built_options.emplace(stride[0] == 1 ? "-DSTRIDE_1" : ""); auto pooling_kernel = runtime->BuildKernel("pooling", "pooling3", built_options); @@ -80,7 +80,7 @@ static void PoolingN(const Tensor *input, auto runtime = OpenCLRuntime::Get(); std::set built_options; - built_options.emplace("-DDATA_TYPE=" + DataTypeToCLType(input->dtype())); + built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(input->dtype())); auto pooling_kernel = runtime->BuildKernel("pooling", "poolingn", built_options); const uint32_t lws[3] = {1, 8, 128}; diff --git a/mace/kernels/opencl/relu_opencl.cc b/mace/kernels/opencl/relu_opencl.cc index 1149b965a2fc91c5394c97b7028d872b827dc125..e7f527a5380a8f965d3781335f4b2a580fdcd3e7 100644 --- a/mace/kernels/opencl/relu_opencl.cc +++ b/mace/kernels/opencl/relu_opencl.cc @@ -23,7 +23,7 @@ void ReluFunctor::operator()(const Tensor *input, auto program = runtime->program(); std::set built_options; - built_options.emplace("-DDATA_TYPE=" + DataTypeToCLType(input->dtype())); + built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(input->dtype())); if (max_limit_ < 0) { auto relu_kernel = runtime->BuildKernel("relu", "relu", built_options); const uint32_t lws = runtime->GetKernelMaxWorkGroupSize(relu_kernel); diff --git a/mace/kernels/opencl/resize_bilinear_opencl.cc b/mace/kernels/opencl/resize_bilinear_opencl.cc index 7b77afea0fdd3aed146b22d736cacc5c6c165e79..27dd8e62b96422c368e324d249900b5e8d5f7767 100644 --- a/mace/kernels/opencl/resize_bilinear_opencl.cc +++ b/mace/kernels/opencl/resize_bilinear_opencl.cc @@ -6,24 +6,33 @@ #include "mace/core/tensor.h" #include "mace/kernels/resize_bilinear.h" #include "mace/kernels/opencl/helper.h" +#include "mace/utils/utils.h" namespace mace { namespace kernels { -template <> -void ResizeBilinearFunctor::operator()( +template +void ResizeBilinearFunctor::operator()( const Tensor *input, const Tensor *resize_dims, Tensor *output) { const index_t batch = input->dim(0); - const index_t channels = input->dim(1); - const index_t in_height = input->dim(2); - const index_t in_width = input->dim(3); + const index_t in_height = input->dim(1); + const index_t in_width = input->dim(2); + const index_t channels = input->dim(3); + + const index_t channel_blocks = RoundUpDiv4(channels); index_t out_height; index_t out_width; GetOutputSize(resize_dims, &out_height, &out_width); MACE_CHECK(out_height > 0 && out_width > 0); - std::vector out_shape {batch, channels, out_height, out_width}; - output->Resize(out_shape); + std::vector output_shape {batch, out_height, out_width, channels}; + if (input->is_image()) { + std::vector output_image_shape; + CalImage2DShape(output_shape, BufferType::IN_OUT, output_image_shape); + output->ResizeImage(output_shape, output_image_shape); + } else { + output->Resize(output_shape); + } float height_scale = CalculateResizeScale(in_height, out_height, align_corners_); @@ -31,29 +40,37 @@ void ResizeBilinearFunctor::operator()( auto runtime = OpenCLRuntime::Get(); std::set built_options; - built_options.emplace("-DDATA_TYPE=" + DataTypeToCLType(input->dtype())); + auto dt = DataTypeToEnum::value; + built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt)); + built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt)); auto rb_kernel = runtime->BuildKernel("resize_bilinear", "resize_bilinear_nocache", built_options); const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(rb_kernel); + uint32_t idx = 0; - rb_kernel.setArg(idx++, *(static_cast(input->buffer()))); - rb_kernel.setArg(idx++, *(static_cast(output->buffer()))); + rb_kernel.setArg(idx++, *(static_cast(input->buffer()))); + rb_kernel.setArg(idx++, *(static_cast(output->buffer()))); rb_kernel.setArg(idx++, height_scale); rb_kernel.setArg(idx++, width_scale); - rb_kernel.setArg(idx++, static_cast(in_height)); - rb_kernel.setArg(idx++, static_cast(in_width)); + rb_kernel.setArg(idx++, static_cast(in_height)); + rb_kernel.setArg(idx++, static_cast(in_width)); + rb_kernel.setArg(idx++, static_cast(out_height)); auto command_queue = runtime->command_queue(); cl_int error = command_queue.enqueueNDRangeKernel( rb_kernel, cl::NullRange, - cl::NDRange(static_cast(batch * channels), - static_cast(out_height), static_cast(out_width)), - // TODO (heliangliang) tuning and fix when kwg_size < devisor - cl::NDRange(1, 16, kwg_size / 16), - NULL, OpenCLRuntime::Get()->GetDefaultEvent()); + cl::NDRange(static_cast(channel_blocks), + static_cast(out_width), + static_cast(out_height * batch)), + // TODO tuning + cl::NDRange(1, static_cast(out_width > kwg_size ? kwg_size : out_width), 1), + nullptr, OpenCLRuntime::Get()->GetDefaultEvent()); MACE_CHECK(error == CL_SUCCESS, error); } +template struct ResizeBilinearFunctor; +template struct ResizeBilinearFunctor; + } // namespace kernels } // namespace mace diff --git a/mace/kernels/opencl/space_to_batch_opecl.cc b/mace/kernels/opencl/space_to_batch_opecl.cc index 2716501c880fcd4fb2232e292b9396e27cfff2f3..72590be5e87ca1c5b721972855b8869e397df82c 100644 --- a/mace/kernels/opencl/space_to_batch_opecl.cc +++ b/mace/kernels/opencl/space_to_batch_opecl.cc @@ -20,7 +20,7 @@ void SpaceToBatchFunctor::operator()(Tensor *space_te Tensor *batch_tensor) { auto runtime = OpenCLRuntime::Get(); std::set built_options; - built_options.emplace("-DDATA_TYPE=" + DataTypeToCLType(space_tensor->dtype())); + built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(space_tensor->dtype())); auto s2b_kernel = runtime->BuildKernel("space_to_batch", "space_to_batch", built_options); uint32_t idx = 0; diff --git a/mace/kernels/resize_bilinear.h b/mace/kernels/resize_bilinear.h index 59bb2505c9c379c1b0700d7a515a880a704d72db..27415ebdd8e61ff904360d1c520aab8ecf2b7591 100644 --- a/mace/kernels/resize_bilinear.h +++ b/mace/kernels/resize_bilinear.h @@ -61,63 +61,90 @@ void ResizeImage(const T *images, const index_t channels, const std::vector &xs_vec, const std::vector &ys, - float *output) { - const index_t in_channel_size = in_height * in_width; - const index_t in_batch_num_values = channels * in_channel_size; - const index_t out_channel_size = out_height * out_width; - const index_t out_batch_num_values = channels * out_channel_size; + T *output) { + const index_t in_batch_num_values = channels * in_height * in_width; + const index_t out_batch_num_values = channels * out_height * out_width; const CachedInterpolation *xs = xs_vec.data(); -#pragma omp parallel for collapse(2) +#pragma omp parallel for for (index_t b = 0; b < batch_size; ++b) { - for (index_t c = 0; c < channels; ++c) { - const T *input_ptr = - images + in_batch_num_values * b + in_channel_size * c; - float *output_ptr = - output + out_batch_num_values * b + out_channel_size * c; - for (index_t y = 0; y < out_height; ++y) { - const T *ys_input_lower_ptr = input_ptr + ys[y].lower * in_width; - const T *ys_input_upper_ptr = input_ptr + ys[y].upper * in_width; - const float ys_lerp = ys[y].lerp; - for (index_t x = 0; x < out_width; ++x) { - auto xs_lower = xs[x].lower; - auto xs_upper = xs[x].upper; - auto xs_lerp = xs[x].lerp; - - const float top_left = ys_input_lower_ptr[xs_lower]; - const float top_right = ys_input_lower_ptr[xs_upper]; - const float bottom_left = ys_input_upper_ptr[xs_lower]; - const float bottom_right = ys_input_upper_ptr[xs_upper]; - - output_ptr[x] = ComputeLerp(top_left, top_right, bottom_left, - bottom_right, xs_lerp, ys_lerp); + const T *batch_input_ptr = images + in_batch_num_values * b;; + T *batch_output_ptr = output + out_batch_num_values * b; + + for (index_t y = 0; y < out_height; ++y) { + const T *y_lower_input_ptr = + batch_input_ptr + ys[y].lower * in_width * channels; + const T *y_upper_input_ptr = + batch_input_ptr + ys[y].upper * in_width * channels; + T *y_output_ptr = batch_output_ptr + y * out_width * channels; + const float ys_lerp = ys[y].lerp; + + for (index_t x = 0; x < out_width; ++x) { + const float xs_lerp = xs[x].lerp; + const T *top_left_ptr = y_lower_input_ptr + xs[x].lower * channels; + const T *top_right_ptr = y_lower_input_ptr + xs[x].upper * channels; + const T *bottom_left_ptr = y_upper_input_ptr + xs[x].lower * channels; + const T *bottom_right_ptr = y_upper_input_ptr + xs[x].upper * channels; + T *output_ptr = y_output_ptr + x * channels; + + for (index_t c = 0; c < channels; ++c) { + const T top_left = top_left_ptr[c]; + const T top_right = top_right_ptr[c]; + const T bottom_left = bottom_left_ptr[c]; + const T bottom_right = bottom_right_ptr[c]; + + output_ptr[c] = ComputeLerp(top_left, top_right, bottom_left, + bottom_right, xs_lerp, ys_lerp); } - output_ptr += out_width; } } } } } +struct ResizeBilinearFunctorBase { + ResizeBilinearFunctorBase(const std::vector &size, + bool align_corners) + : align_corners_(align_corners), size_(size) {} + + protected: + void GetOutputSize(const Tensor *resize_dims, + index_t *out_height, + index_t *out_width) { + if (size_[0] < 0 || size_[1] < 0) { + MACE_CHECK(resize_dims != nullptr && resize_dims->dim_size() == 1); + Tensor::MappingGuard resize_dims_mapper(resize_dims); + auto dims_data = resize_dims->data(); + *out_height = dims_data[0]; + *out_width = dims_data[1]; + } else { + *out_height = size_[0]; + *out_width = size_[1]; + } + } + + bool align_corners_; + std::vector size_; +}; + template -class ResizeBilinearFunctor { - public: +struct ResizeBilinearFunctor : ResizeBilinearFunctorBase { ResizeBilinearFunctor(const std::vector &size, bool align_corners) - : align_corners_(align_corners), size_(size) {} + : ResizeBilinearFunctorBase(size, align_corners) {} void operator()(const Tensor *input, const Tensor *resize_dims, Tensor *output) { const index_t batch = input->dim(0); - const index_t channels = input->dim(1); - const index_t in_height = input->dim(2); - const index_t in_width = input->dim(3); + const index_t in_height = input->dim(1); + const index_t in_width = input->dim(2); + const index_t channels = input->dim(3); index_t out_height; index_t out_width; GetOutputSize(resize_dims, &out_height, &out_width); MACE_CHECK(out_height > 0 && out_width > 0); - std::vector out_shape{batch, channels, out_height, out_width}; + std::vector out_shape{batch, out_height, out_width, channels}; output->Resize(out_shape); Tensor::MappingGuard input_mapper(input); @@ -146,32 +173,18 @@ class ResizeBilinearFunctor { ResizeImage(input_data, batch, in_height, in_width, out_height, out_width, channels, xs, ys, output_data); } +}; - protected: - void GetOutputSize(const Tensor *resize_dims, - index_t *out_height, - index_t *out_width) { - if (size_[0] < 0 || size_[1] < 0) { - MACE_CHECK(resize_dims != nullptr && resize_dims->dim_size() == 1); - Tensor::MappingGuard resize_dims_mapper(resize_dims); - auto dims_data = resize_dims->data(); - *out_height = dims_data[0]; - *out_width = dims_data[1]; - } else { - *out_height = size_[0]; - *out_width = size_[1]; - } - } +template +struct ResizeBilinearFunctor : ResizeBilinearFunctorBase { + ResizeBilinearFunctor(const std::vector &size, bool align_corners) + : ResizeBilinearFunctorBase(size, align_corners) {} - private: - bool align_corners_; - std::vector size_; + void operator()(const Tensor *input, + const Tensor *resize_dims, + Tensor *output); }; -template <> -void ResizeBilinearFunctor::operator()( - const Tensor *input, const Tensor *resize_dims, Tensor *output); - } // namespace kernels } // namespace mace diff --git a/mace/ops/buffer_to_image_test.cc b/mace/ops/buffer_to_image_test.cc index 7bd667ca3988320529a702224e3045a99ca38de8..3836a7ae90291dbbfb80da20cf78a1bb1c79d87e 100644 --- a/mace/ops/buffer_to_image_test.cc +++ b/mace/ops/buffer_to_image_test.cc @@ -43,7 +43,7 @@ TEST(BufferToImageTest, ArgSmall) { } TEST(BufferToImageTest, ArgHalfSmall) { - TestBidirectionTransform(kernels::ARGUMENT, {1}); + TestBidirectionTransform(kernels::ARGUMENT, {11}); } TEST(BufferToImageTest, ArgMedia) { @@ -97,3 +97,37 @@ TEST(BufferToImageTest, Filter3x3Meida) { TEST(BufferToImageTest, Filter3x3Large) { TestBidirectionTransform(kernels::FILTER, {3, 3, 128, 256}); } + +template +void TestDiffTypeBidirectionTransform(const int type, const std::vector &input_shape) { + OpsTestNet net; + OpDefBuilder("BufferToImage", "BufferToImageTest") + .Input("Input") + .Output("B2IOutput") + .AddIntArg("buffer_type", type) + .AddIntArg("T", DataTypeToEnum::value) + .Finalize(net.NewOperatorDef()); + + // Add input data + net.AddRandomInput("Input", input_shape); + + // Run + net.RunOp(D); + + OpDefBuilder("ImageToBuffer", "ImageToBufferTest") + .Input("B2IOutput") + .Output("I2BOutput") + .AddIntArg("buffer_type", type) + .AddIntArg("T", DataTypeToEnum::value) + .Finalize(net.NewOperatorDef()); + + // Run + net.RunOp(D); + + // Check + ExpectTensorNear(*net.GetOutput("Input"), *net.GetOutput("I2BOutput"), 1e-2); +} + +TEST(BufferToImageTest, ArgFloatToHalfSmall) { + TestDiffTypeBidirectionTransform(kernels::ARGUMENT, {11}); +} diff --git a/mace/ops/resize_bilinear.cc b/mace/ops/resize_bilinear.cc index c3510f688311bbb0210150759ea359c4e7ef6883..8eae71819537a99cc08454e1585844f7d77f52e3 100644 --- a/mace/ops/resize_bilinear.cc +++ b/mace/ops/resize_bilinear.cc @@ -23,4 +23,9 @@ REGISTER_OPENCL_OPERATOR(OpKeyBuilder("ResizeBilinear") .Build(), ResizeBilinearOp); +REGISTER_OPENCL_OPERATOR(OpKeyBuilder("ResizeBilinear") + .TypeConstraint("T") + .Build(), + ResizeBilinearOp); + } // namespace mace diff --git a/mace/ops/resize_bilinear_benchmark.cc b/mace/ops/resize_bilinear_benchmark.cc index 8429fd6bee0f8617e98268cd4ce97be43935a44c..d9453908c11bff15ad8ee3c996af03523d6fb7d1 100644 --- a/mace/ops/resize_bilinear_benchmark.cc +++ b/mace/ops/resize_bilinear_benchmark.cc @@ -19,18 +19,30 @@ static void ResizeBilinearBenchmark(int iters, mace::testing::StopTiming(); OpsTestNet net; - OpDefBuilder("ResizeBilinear", "ResizeBilinearBenchmark") - .Input("Input") - .Input("OutSize") - .Output("Output") - .AddIntsArg("size", {output_height, output_width}) - .Finalize(net.NewOperatorDef()); // Add input data net.AddRandomInput("Input", - {batch, channels, input_height, input_width}); + {batch, input_height, input_width, channels}); net.AddInputFromArray("OutSize", {2}, {output_height, output_width}); + if (D == DeviceType::OPENCL) { + BufferToImage(net, "Input", "InputImage", kernels::BufferType::IN_OUT); + OpDefBuilder("ResizeBilinear", "ResizeBilinearBenchmark") + .Input("InputImage") + .Input("OutSize") + .Output("OutputImage") + .AddIntsArg("size", {output_height, output_width}) + .AddIntArg("T", static_cast(DataTypeToEnum::value)) + .Finalize(net.NewOperatorDef()); + } else { + OpDefBuilder("ResizeBilinear", "ResizeBilinearBenchmark") + .Input("Input") + .Input("OutSize") + .Output("Output") + .AddIntsArg("size", {output_height, output_width}) + .AddIntArg("T", static_cast(DataTypeToEnum::value)) + .Finalize(net.NewOperatorDef()); + } // Warm-up for (int i = 0; i < 5; ++i) { @@ -58,9 +70,12 @@ static void ResizeBilinearBenchmark(int iters, #define BM_RESIZE_BILINEAR(N, C, H0, W0, H1, W1, TYPE) \ BM_RESIZE_BILINEAR_MACRO(N, C, H0, W0, H1, W1, TYPE, CPU); \ - BM_RESIZE_BILINEAR_MACRO(N, C, H0, W0, H1, W1, TYPE, NEON); \ BM_RESIZE_BILINEAR_MACRO(N, C, H0, W0, H1, W1, TYPE, OPENCL); +// SNPE 835 GPU: 6870us +BM_RESIZE_BILINEAR(1, 128, 120, 120, 480, 480, half); +BM_RESIZE_BILINEAR(1, 128, 120, 120, 480, 480, float); + BM_RESIZE_BILINEAR(1, 256, 7, 7, 15, 15, float); BM_RESIZE_BILINEAR(1, 256, 15, 15, 30, 30, float); BM_RESIZE_BILINEAR(1, 128, 30, 30, 60, 60, float); diff --git a/mace/ops/resize_bilinear_test.cc b/mace/ops/resize_bilinear_test.cc index 7b7cee9d97da3afd98e80ff710815f06cf1d8eef..3e50c3b4c15133238fb2e7b937430dc8d13dffdd 100644 --- a/mace/ops/resize_bilinear_test.cc +++ b/mace/ops/resize_bilinear_test.cc @@ -23,14 +23,14 @@ TEST_F(ResizeBilinearTest, CPUResizeBilinearWOAlignCorners) { // Add input data vector input(24); std::iota(begin(input), end(input), 0); - net.AddInputFromArray("Input", {1, 3, 2, 4}, input); + net.AddInputFromArray("Input", {1, 2, 4, 3}, input); net.AddInputFromArray("OutSize", {2}, {1, 2}); // Run net.RunOp(); // Check - auto expected = CreateTensor({1, 3, 1, 2}, {0, 2, 8, 10, 16, 18}); + auto expected = CreateTensor({1, 1, 2, 3}, {0, 1, 2, 6, 7, 8}); ExpectTensorNear(*expected, *net.GetOutput("Output"), 0.001); } @@ -49,14 +49,14 @@ TEST_F(ResizeBilinearTest, ResizeBilinearWAlignCorners) { // Add input data vector input(24); std::iota(begin(input), end(input), 0); - net.AddInputFromArray("Input", {1, 3, 2, 4}, input); + net.AddInputFromArray("Input", {1, 2, 4, 3}, input); net.AddInputFromArray("OutSize", {2}, {1, 2}); // Run net.RunOp(); // Check - auto expected = CreateTensor({1, 3, 1, 2}, {0, 3, 8, 11, 16, 19}); + auto expected = CreateTensor({1, 1, 2, 3}, {0, 1, 2, 9, 10, 11}); ExpectTensorNear(*expected, *net.GetOutput("Output"), 0.001); } @@ -65,6 +65,7 @@ template void TestRandomResizeBilinear() { srand(time(nullptr)); testing::internal::LogToStderr(); + for (int round = 0; round < 10; ++round) { int batch = 1 + rand() % 5; int channels = 1 + rand() % 100; @@ -72,39 +73,54 @@ void TestRandomResizeBilinear() { int width = 1 + rand() % 100; int in_height = 1 + rand() % 100; int in_width = 1 + rand() % 100; + int align_corners = rand() % 1; // Construct graph OpsTestNet net; - OpDefBuilder("ResizeBilinear", "ResizeBilinearTest") - .Input("Input") - .Input("OutSize") - .Output("Output") - .AddIntArg("align_corners", 1) - .AddIntsArg("size", {height, width}) - .Finalize(net.NewOperatorDef()); - // Add input data net.AddRandomInput("Input", - {batch, channels, in_height, in_width}); + {batch, in_height, in_width, channels}); net.AddInputFromArray("OutSize", {2}, {height, width}); - // Run - net.RunOp(D); - Tensor actual; - actual.Copy(*net.GetOutput("Output")); - + OpDefBuilder("ResizeBilinear", "ResizeBilinearTest") + .Input("Input") + .Input("OutSize") + .Output("Output") + .AddIntArg("align_corners", align_corners) + .AddIntsArg("size", {height, width}) + .Finalize(net.NewOperatorDef()); // Run on CPU net.RunOp(DeviceType::CPU); - Tensor *expected = net.GetOutput("Output"); + Tensor expected; + expected.Copy(*net.GetOutput("Output")); + + if (D == DeviceType::OPENCL) { + BufferToImage(net, "Input", "InputImage", kernels::BufferType::IN_OUT); + + OpDefBuilder("ResizeBilinear", "ResizeBilinearTest") + .Input("InputImage") + .Input("OutSize") + .Output("OutputImage") + .AddIntArg("align_corners", align_corners) + .AddIntsArg("size", {height, width}) + .Finalize(net.NewOperatorDef()); + // Run + net.RunOp(D); + ImageToBuffer(net, "OutputImage", "DeviceOutput", kernels::BufferType::IN_OUT); + } else { + // TODO support NEON + } // Check - ExpectTensorNear(*expected, actual, 0.001); + ExpectTensorNear(expected, *net.GetOutput("DeviceOutput"), 0.001); } } +/* TEST_F(ResizeBilinearTest, NEONRandomResizeBilinear) { TestRandomResizeBilinear(); } +*/ TEST_F(ResizeBilinearTest, OPENCLRandomResizeBilinear) { TestRandomResizeBilinear(); diff --git a/mace/python/tools/tf_ops_stats.py b/mace/python/tools/tf_ops_stats.py index d1016affa93b7142600dc4802eaf128788c5185a..d60487a96434bf1fbda63f0bb456a973e4c07b9b 100644 --- a/mace/python/tools/tf_ops_stats.py +++ b/mace/python/tools/tf_ops_stats.py @@ -92,6 +92,7 @@ def main(unused_args): size = tensor_values[input_name] break key = '%s(size=%s, align_corners=%s)' % (op.type, size, align_corners) + print(key) hist_inc(stats, key) elif op.type in ['AvgPool', 'MaxPool']: padding = op.get_attr('padding')