From 07927d5afbd64170176bd1374217aba96ffe412c Mon Sep 17 00:00:00 2001 From: liuqi Date: Wed, 28 Feb 2018 17:20:56 +0800 Subject: [PATCH] Add FC op. --- mace/core/operator.cc | 2 + mace/kernels/fully_connected.h | 97 ++++++++ mace/kernels/opencl/activation_opencl.cc | 2 +- mace/kernels/opencl/batch_norm_opencl.cc | 2 +- mace/kernels/opencl/buffer_to_image.cc | 5 + mace/kernels/opencl/cl/batch_norm.cl | 4 +- mace/kernels/opencl/cl/common.h | 4 +- mace/kernels/opencl/cl/conv_2d.cl | 4 +- mace/kernels/opencl/cl/conv_2d_1x1.cl | 4 +- mace/kernels/opencl/cl/conv_2d_3x3.cl | 4 +- mace/kernels/opencl/cl/depthwise_conv2d.cl | 4 +- mace/kernels/opencl/cl/fc.cl | 58 +++++ mace/kernels/opencl/cl/winograd_transform.cl | 4 +- mace/kernels/opencl/conv_2d_opencl_1x1.cc | 2 +- mace/kernels/opencl/conv_2d_opencl_3x3.cc | 2 +- mace/kernels/opencl/conv_2d_opencl_general.cc | 2 +- mace/kernels/opencl/depthwise_conv_opencl.cc | 2 +- mace/kernels/opencl/fully_connected_opencl.cc | 106 +++++++++ mace/kernels/opencl/helper.cc | 12 + mace/kernels/opencl/helper.h | 1 + mace/kernels/opencl/winograd_transform.cc | 2 +- mace/ops/fully_connected.cc | 29 +++ mace/ops/fully_connected.h | 49 ++++ mace/ops/fully_connected_benchmark.cc | 77 ++++++ mace/ops/fully_connected_test.cc | 220 ++++++++++++++++++ 25 files changed, 677 insertions(+), 21 deletions(-) create mode 100644 mace/kernels/fully_connected.h create mode 100644 mace/kernels/opencl/cl/fc.cl create mode 100644 mace/kernels/opencl/fully_connected_opencl.cc create mode 100644 mace/ops/fully_connected.cc create mode 100644 mace/ops/fully_connected.h create mode 100644 mace/ops/fully_connected_benchmark.cc create mode 100644 mace/ops/fully_connected_test.cc diff --git a/mace/core/operator.cc b/mace/core/operator.cc index e9a5f1b6..826f424b 100644 --- a/mace/core/operator.cc +++ b/mace/core/operator.cc @@ -82,6 +82,7 @@ extern void Register_WinogradTransform(OperatorRegistry *op_registry); extern void 
Register_WinogradInverseTransform(OperatorRegistry *op_registry); extern void Register_Reshape(OperatorRegistry *op_registry); extern void Register_Eltwise(OperatorRegistry *op_registry); +extern void Register_FullyConnected(OperatorRegistry *op_registry); OperatorRegistry::OperatorRegistry() { Register_Activation(this); @@ -107,6 +108,7 @@ OperatorRegistry::OperatorRegistry() { Register_WinogradInverseTransform(this); Register_Reshape(this); Register_Eltwise(this); + Register_FullyConnected(this); } } // namespace mace diff --git a/mace/kernels/fully_connected.h b/mace/kernels/fully_connected.h new file mode 100644 index 00000000..d95d8a48 --- /dev/null +++ b/mace/kernels/fully_connected.h @@ -0,0 +1,97 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#ifndef MACE_KERNELS_FULLY_CONNECTED_H_ +#define MACE_KERNELS_FULLY_CONNECTED_H_ + +#include "mace/core/future.h" +#include "mace/core/tensor.h" +#include "mace/core/runtime/opencl/cl2_header.h" +#include "mace/kernels/activation.h" + +namespace mace { +namespace kernels { + +struct FullyConnectedBase { + FullyConnectedBase(const ActivationType activation, + const float relux_max_limit, + const float prelu_alpha) + : activation_(activation), + relux_max_limit_(relux_max_limit), + prelu_alpha_(prelu_alpha) {} + + const ActivationType activation_; + const float relux_max_limit_; + const float prelu_alpha_; +}; + +template +struct FullyConnectedFunctor : FullyConnectedBase{ + FullyConnectedFunctor(const ActivationType activation, + const float relux_max_limit, + const float prelu_alpha) : + FullyConnectedBase(activation, relux_max_limit, prelu_alpha){} + + void operator()(const Tensor *input, + const Tensor *weight, + const Tensor *bias, + Tensor *output, + StatsFuture *future) { + + std::vector output_shape = {input->dim(0), 1, 1, weight->dim(0)}; + output->Resize(output_shape); + const index_t N = output->dim(0); + const index_t input_size = weight->dim(1); + const index_t output_size = weight->dim(0); 
+ Tensor::MappingGuard guard_input(input); + Tensor::MappingGuard guard_weight(weight); + Tensor::MappingGuard guard_bias(bias); + Tensor::MappingGuard guard_output(output); + const T *input_ptr = input->data(); + const T *weight_ptr = weight->data(); + const T *bias_ptr = bias == nullptr ? nullptr : bias->data(); + T *output_ptr = output->mutable_data(); + +#pragma omp parallel for collapse(2) + for (int i = 0; i < N; ++i) { + for (int out_idx = 0; out_idx < output_size; ++out_idx) { + T sum = 0; + if (bias_ptr != nullptr) sum = bias_ptr[out_idx]; + index_t input_offset = i * input_size; + index_t weight_offset = out_idx * input_size; + for (int in_idx = 0; in_idx < input_size; ++in_idx) { + sum += input_ptr[input_offset] * weight_ptr[weight_offset]; + input_offset++; + weight_offset++; + } + output_ptr[i * output_size + out_idx] = sum; + } + } + + DoActivation(output_ptr, output_ptr, output->NumElements(), activation_, + relux_max_limit_, prelu_alpha_); + } +}; + + +template +struct FullyConnectedFunctor : FullyConnectedBase{ + FullyConnectedFunctor(const ActivationType activation, + const float relux_max_limit, + const float prelu_alpha) : + FullyConnectedBase(activation, relux_max_limit, prelu_alpha){} + + void operator()(const Tensor *input, + const Tensor *weight, + const Tensor *bias, + Tensor *output, + StatsFuture *future); + + cl::Kernel kernel_; +}; + +} // namespace kernels +} // namespace mace + +#endif // MACE_KERNELS_FULLY_CONNECTED_H_ diff --git a/mace/kernels/opencl/activation_opencl.cc b/mace/kernels/opencl/activation_opencl.cc index 935b3576..50ad3063 100644 --- a/mace/kernels/opencl/activation_opencl.cc +++ b/mace/kernels/opencl/activation_opencl.cc @@ -54,7 +54,7 @@ void ActivationFunctor::operator()(const Tensor *input, tuning_key_prefix = "sigmoid_opencl_kernel_"; built_options.emplace("-DUSE_SIGMOID"); break; - defeult: + default: LOG(FATAL) << "Unknown activation type: " << activation_; } kernel_ = diff --git 
a/mace/kernels/opencl/batch_norm_opencl.cc b/mace/kernels/opencl/batch_norm_opencl.cc index 3b63c7ac..d88fed51 100644 --- a/mace/kernels/opencl/batch_norm_opencl.cc +++ b/mace/kernels/opencl/batch_norm_opencl.cc @@ -59,7 +59,7 @@ void BatchNormFunctor::operator()(const Tensor *input, case SIGMOID: built_options.emplace("-DUSE_SIGMOID"); break; - defeult: + default: LOG(FATAL) << "Unknown activation type: " << activation_; } diff --git a/mace/kernels/opencl/buffer_to_image.cc b/mace/kernels/opencl/buffer_to_image.cc index 6a7696ae..f9d9e781 100644 --- a/mace/kernels/opencl/buffer_to_image.cc +++ b/mace/kernels/opencl/buffer_to_image.cc @@ -48,6 +48,7 @@ void BufferToImageFunctor::operator()(Tensor *buffer, kernel_name = i2b_ ? "arg_image_to_buffer" : "arg_buffer_to_image"; break; case IN_OUT_HEIGHT: + case WEIGHT_HEIGHT: kernel_name = i2b_ ? "in_out_height_image_to_buffer" : "in_out_height_buffer_to_image"; break; case IN_OUT_WIDTH: @@ -80,6 +81,10 @@ void BufferToImageFunctor::operator()(Tensor *buffer, b2f_kernel.setArg(idx++, *(static_cast(buffer->buffer()))); if (type == ARGUMENT) { b2f_kernel.setArg(idx++, static_cast(buffer->dim(0))); + } else if(type == WEIGHT_HEIGHT) { + b2f_kernel.setArg(idx++, static_cast(buffer->dim(0))); + b2f_kernel.setArg(idx++, static_cast(buffer->dim(1))); + b2f_kernel.setArg(idx++, 1); } else { b2f_kernel.setArg(idx++, static_cast(buffer->dim(1))); b2f_kernel.setArg(idx++, static_cast(buffer->dim(2))); diff --git a/mace/kernels/opencl/cl/batch_norm.cl b/mace/kernels/opencl/cl/batch_norm.cl index 995abc8c..99c00fab 100644 --- a/mace/kernels/opencl/cl/batch_norm.cl +++ b/mace/kernels/opencl/cl/batch_norm.cl @@ -9,8 +9,8 @@ __kernel void batch_norm(__read_only image2d_t input, __private const float epsilon, #endif __write_only image2d_t output, - __private const DATA_TYPE relux_max_limit, - __private const DATA_TYPE prelu_alpha) { + __private const float relux_max_limit, + __private const float prelu_alpha) { const int ch_blk = 
get_global_id(0); const int w = get_global_id(1); const int hb = get_global_id(2); diff --git a/mace/kernels/opencl/cl/common.h b/mace/kernels/opencl/cl/common.h index 792d2b49..13b20e05 100644 --- a/mace/kernels/opencl/cl/common.h +++ b/mace/kernels/opencl/cl/common.h @@ -22,8 +22,8 @@ __constant sampler_t SAMPLER = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | inline DATA_TYPE4 do_activation(DATA_TYPE4 in, - __private const DATA_TYPE relux_max_limit, - __private const DATA_TYPE prelu_alpha) { + __private const float relux_max_limit, + __private const float prelu_alpha) { DATA_TYPE4 out; #ifdef USE_RELU out = fmax(in, 0); diff --git a/mace/kernels/opencl/cl/conv_2d.cl b/mace/kernels/opencl/cl/conv_2d.cl index 522f28c7..35e17da8 100644 --- a/mace/kernels/opencl/cl/conv_2d.cl +++ b/mace/kernels/opencl/cl/conv_2d.cl @@ -6,8 +6,8 @@ __kernel void conv_2d(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] */ __read_only image2d_t bias, /* cout%4 * cout/4 */ #endif __write_only image2d_t output, - __private const DATA_TYPE relux_max_limit, - __private const DATA_TYPE prelu_alpha, + __private const float relux_max_limit, + __private const float prelu_alpha, __private const int in_height, __private const int in_width, __private const int in_ch_blks, diff --git a/mace/kernels/opencl/cl/conv_2d_1x1.cl b/mace/kernels/opencl/cl/conv_2d_1x1.cl index de19cd77..0eecdb19 100644 --- a/mace/kernels/opencl/cl/conv_2d_1x1.cl +++ b/mace/kernels/opencl/cl/conv_2d_1x1.cl @@ -6,8 +6,8 @@ __kernel void conv_2d_1x1(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] __read_only image2d_t bias, /* cout%4 * cout/4 */ #endif __write_only image2d_t output, - __private const DATA_TYPE relux_max_limit, - __private const DATA_TYPE prelu_alpha, + __private const float relux_max_limit, + __private const float prelu_alpha, __private const int in_height, __private const int in_width, __private const int in_ch_blks, diff --git a/mace/kernels/opencl/cl/conv_2d_3x3.cl 
b/mace/kernels/opencl/cl/conv_2d_3x3.cl index 9403c905..d37ec7f1 100644 --- a/mace/kernels/opencl/cl/conv_2d_3x3.cl +++ b/mace/kernels/opencl/cl/conv_2d_3x3.cl @@ -6,8 +6,8 @@ __kernel void conv_2d_3x3(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] __read_only image2d_t bias, /* cout%4 * cout/4 */ #endif __write_only image2d_t output, - __private const DATA_TYPE relux_max_limit, - __private const DATA_TYPE prelu_alpha, + __private const float relux_max_limit, + __private const float prelu_alpha, __private const int in_height, __private const int in_width, __private const int in_ch_blks, diff --git a/mace/kernels/opencl/cl/depthwise_conv2d.cl b/mace/kernels/opencl/cl/depthwise_conv2d.cl index d9c94007..5ba07d73 100644 --- a/mace/kernels/opencl/cl/depthwise_conv2d.cl +++ b/mace/kernels/opencl/cl/depthwise_conv2d.cl @@ -7,8 +7,8 @@ __kernel void depthwise_conv2d(__read_only image2d_t input, /* [c%4 * w * c/4, h __read_only image2d_t bias, /* cout%4 * cout/4 */ #endif __write_only image2d_t output, - __private const DATA_TYPE relux_max_limit, - __private const DATA_TYPE prelu_alpha, + __private const float relux_max_limit, + __private const float prelu_alpha, __private const short in_height, __private const short in_width, __private const short in_ch_blks, diff --git a/mace/kernels/opencl/cl/fc.cl b/mace/kernels/opencl/cl/fc.cl new file mode 100644 index 00000000..ced7adf7 --- /dev/null +++ b/mace/kernels/opencl/cl/fc.cl @@ -0,0 +1,58 @@ +#include + +// output = weight * input + bias +__kernel void fc(__read_only image2d_t input, + __read_only image2d_t weight, +#ifdef BIAS + __read_only image2d_t bias, +#endif + __write_only image2d_t output, + __private const int input_height, + __private const int input_width, + __private const int input_channel, + __private const float relux_max_limit, + __private const float prelu_alpha) { + const int batch_idx = get_global_id(0); + const int out_blk_idx = get_global_id(1); + const int input_chan_blk = (input_channel + 3) 
>> 2; + + float4 input_value; + float4 w0, w1, w2, w3; + +#ifdef BIAS + DATA_TYPE4 result = READ_IMAGET(bias, SAMPLER, (int2)(out_blk_idx, 0)); +#else + DATA_TYPE4 result = (DATA_TYPE4)(0, 0, 0, 0); +#endif + + int2 input_coord = (int2)(0, mul24(batch_idx, input_height)); + int weight_x = 0; + for (short h_idx = 0; h_idx < input_height; ++h_idx) { + for (short w_idx = 0; w_idx < input_width; ++w_idx) { + input_coord.x = w_idx; + weight_x = (h_idx * input_width + w_idx) * input_channel; +#pragma unroll + for (short chan_idx = 0; chan_idx < input_chan_blk; ++chan_idx) { + input_value = READ_IMAGET(input, SAMPLER, input_coord); + + w0 = READ_IMAGET(weight, SAMPLER, (int2)(weight_x++, out_blk_idx)); + w1 = READ_IMAGET(weight, SAMPLER, (int2)(weight_x++, out_blk_idx)); + w2 = READ_IMAGET(weight, SAMPLER, (int2)(weight_x++, out_blk_idx)); + w3 = READ_IMAGET(weight, SAMPLER, (int2)(weight_x++, out_blk_idx)); + + result = mad(input_value.x, w0, result); + result = mad(input_value.y, w1, result); + result = mad(input_value.z, w2, result); + result = mad(input_value.w, w3, result); + + input_coord.x += input_width; + } + } + input_coord.y++; + } + +#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_PRELU) || defined(USE_TANH) || defined(USE_SIGMOID) + result = do_activation(result, relux_max_limit, prelu_alpha); +#endif + WRITE_IMAGET(output, (int2)(out_blk_idx, batch_idx), result); +} diff --git a/mace/kernels/opencl/cl/winograd_transform.cl b/mace/kernels/opencl/cl/winograd_transform.cl index daecd39f..e4b31598 100644 --- a/mace/kernels/opencl/cl/winograd_transform.cl +++ b/mace/kernels/opencl/cl/winograd_transform.cl @@ -115,8 +115,8 @@ __kernel void winograd_inverse_transform_2x2(__read_only image2d_t input, __private const int out_width, __private const int round_hw, __private const int round_w, - __private const DATA_TYPE relux_max_limit, - __private const DATA_TYPE prelu_alpha) { + __private const float relux_max_limit, + __private const float prelu_alpha) { 
const int width_idx = get_global_id(0); const int height_idx = get_global_id(1); const int out_channel = get_global_size(1); diff --git a/mace/kernels/opencl/conv_2d_opencl_1x1.cc b/mace/kernels/opencl/conv_2d_opencl_1x1.cc index 06d93ea6..aa4bbc6b 100644 --- a/mace/kernels/opencl/conv_2d_opencl_1x1.cc +++ b/mace/kernels/opencl/conv_2d_opencl_1x1.cc @@ -66,7 +66,7 @@ extern void Conv2dOpenclK1x1(cl::Kernel *kernel, case SIGMOID: built_options.emplace("-DUSE_SIGMOID"); break; - defeult: + default: LOG(FATAL) << "Unknown activation type: " << activation; } diff --git a/mace/kernels/opencl/conv_2d_opencl_3x3.cc b/mace/kernels/opencl/conv_2d_opencl_3x3.cc index 47fb5605..3a185faf 100644 --- a/mace/kernels/opencl/conv_2d_opencl_3x3.cc +++ b/mace/kernels/opencl/conv_2d_opencl_3x3.cc @@ -61,7 +61,7 @@ extern void Conv2dOpenclK3x3(cl::Kernel *kernel, case SIGMOID: built_options.emplace("-DUSE_SIGMOID"); break; - defeult: + default: LOG(FATAL) << "Unknown activation type: " << activation; } diff --git a/mace/kernels/opencl/conv_2d_opencl_general.cc b/mace/kernels/opencl/conv_2d_opencl_general.cc index 7414abc6..30a1a751 100644 --- a/mace/kernels/opencl/conv_2d_opencl_general.cc +++ b/mace/kernels/opencl/conv_2d_opencl_general.cc @@ -61,7 +61,7 @@ extern void Conv2dOpencl(cl::Kernel *kernel, case SIGMOID: built_options.emplace("-DUSE_SIGMOID"); break; - defeult: + default: LOG(FATAL) << "Unknown activation type: " << activation; } diff --git a/mace/kernels/opencl/depthwise_conv_opencl.cc b/mace/kernels/opencl/depthwise_conv_opencl.cc index e4b615b0..4365c02f 100644 --- a/mace/kernels/opencl/depthwise_conv_opencl.cc +++ b/mace/kernels/opencl/depthwise_conv_opencl.cc @@ -78,7 +78,7 @@ void DepthwiseConv2d(cl::Kernel *kernel, case SIGMOID: built_options.emplace("-DUSE_SIGMOID"); break; - defeult: + default: LOG(FATAL) << "Unknown activation type: " << activation; } diff --git a/mace/kernels/opencl/fully_connected_opencl.cc b/mace/kernels/opencl/fully_connected_opencl.cc new 
file mode 100644 index 00000000..06dc74e2 --- /dev/null +++ b/mace/kernels/opencl/fully_connected_opencl.cc @@ -0,0 +1,106 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include "mace/kernels/fully_connected.h" +#include "mace/core/runtime/opencl/opencl_runtime.h" +#include "mace/kernels/opencl/helper.h" +#include "mace/utils/tuner.h" + +namespace mace { +namespace kernels { + +template +void FullyConnectedFunctor::operator()( + const Tensor *input, + const Tensor *weight, + const Tensor *bias, + Tensor *output, + StatsFuture *future) { + + std::vector output_shape = {input->dim(0), 1, 1, weight->dim(0)}; + std::vector output_image_shape; + CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL, output_image_shape); + output->ResizeImage(output_shape, output_image_shape); + + const index_t batch = output->dim(0); + const index_t output_size = output->dim(3); + + const index_t output_blocks = RoundUpDiv4(output_size); + + if (kernel_.get() == nullptr) { + auto runtime = OpenCLRuntime::Global(); + std::set built_options; + auto dt = DataTypeToEnum::value; + std::string kernel_name = MACE_OBFUSCATE_SYMBOL("fc"); + built_options.emplace("-Dfc=" + kernel_name); + built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt)); + built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt)); + if (bias != nullptr) { + built_options.emplace("-DBIAS"); + } + switch (activation_) { + case NOOP: + break; + case RELU: + built_options.emplace("-DUSE_RELU"); + break; + case RELUX: + built_options.emplace("-DUSE_RELUX"); + break; + case PRELU: + built_options.emplace("-DUSE_PRELU"); + break; + case TANH: + built_options.emplace("-DUSE_TANH"); + break; + case SIGMOID: + built_options.emplace("-DUSE_SIGMOID"); + break; + default: + LOG(FATAL) << "Unknown activation type: " << activation_; + } + kernel_ = runtime->BuildKernel("fc", kernel_name, built_options); + + uint32_t idx = 0; + kernel_.setArg(idx++, + *(static_cast(input->buffer()))); + 
kernel_.setArg(idx++, + *(static_cast(weight->buffer()))); + if (bias != nullptr) { + kernel_.setArg(idx++, + *(static_cast(bias->buffer()))); + } + kernel_.setArg(idx++, + *(static_cast(output->buffer()))); + kernel_.setArg(idx++, static_cast(input->dim(1))); + kernel_.setArg(idx++, static_cast(input->dim(2))); + kernel_.setArg(idx++, static_cast(input->dim(3))); + // FIXME handle flexible data type: half not supported + kernel_.setArg(idx++, relux_max_limit_); + kernel_.setArg(idx++, prelu_alpha_); + } + + const uint32_t gws[2] = { + static_cast(batch), + static_cast(output_blocks), + }; + const std::vector lws = {16, 64, 1}; + std::stringstream ss; + ss << "fc_opencl_kernel_" + << output->dim(0) << "_" + << output->dim(1) << "_" + << output->dim(2) << "_" + << output->dim(3); + TuningOrRun2DKernel(kernel_, ss.str(), gws, lws, future); + +} + +template +struct FullyConnectedFunctor; + +template +struct FullyConnectedFunctor; + +} // namespace kernels +} // namespace mace diff --git a/mace/kernels/opencl/helper.cc b/mace/kernels/opencl/helper.cc index a9923a62..ee141adb 100644 --- a/mace/kernels/opencl/helper.cc +++ b/mace/kernels/opencl/helper.cc @@ -73,6 +73,15 @@ void CalInOutWidthImageShape(const std::vector &shape, /* NHWC */ image_shape[1] = shape[0] * shape[1]; } +// [W, (H + 3) / 4] +void CalWeightHeightImageShape(const std::vector &shape, /* HW */ + std::vector &image_shape) { + MACE_CHECK(shape.size() == 2); + image_shape.resize(2); + image_shape[0] = shape[1]; + image_shape[1] = RoundUpDiv4(shape[0]); +} + void CalImage2DShape(const std::vector &shape, /* NHWC */ const BufferType type, std::vector &image_shape) { @@ -98,6 +107,9 @@ void CalImage2DShape(const std::vector &shape, /* NHWC */ case WINOGRAD_FILTER: CalWinogradFilterImageShape(shape, image_shape); break; + case WEIGHT_HEIGHT: + CalWeightHeightImageShape(shape, image_shape); + break; default: LOG(FATAL) << "Mace not supported yet."; } diff --git a/mace/kernels/opencl/helper.h 
b/mace/kernels/opencl/helper.h index 89278592..cc68466f 100644 --- a/mace/kernels/opencl/helper.h +++ b/mace/kernels/opencl/helper.h @@ -24,6 +24,7 @@ enum BufferType { IN_OUT_WIDTH = 4, WINOGRAD_FILTER = 5, DW_CONV2D_FILTER = 6, + WEIGHT_HEIGHT = 7, }; void CalImage2DShape(const std::vector &shape, /* NHWC */ diff --git a/mace/kernels/opencl/winograd_transform.cc b/mace/kernels/opencl/winograd_transform.cc index 1dc949d6..bacbdb63 100644 --- a/mace/kernels/opencl/winograd_transform.cc +++ b/mace/kernels/opencl/winograd_transform.cc @@ -101,7 +101,7 @@ void WinogradInverseTransformFunctor::operator()(const Te case SIGMOID: built_options.emplace("-DUSE_SIGMOID"); break; - defeult: + default: LOG(FATAL) << "Unknown activation type: " << activation_; } diff --git a/mace/ops/fully_connected.cc b/mace/ops/fully_connected.cc new file mode 100644 index 00000000..9b733794 --- /dev/null +++ b/mace/ops/fully_connected.cc @@ -0,0 +1,29 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include "mace/ops/fully_connected.h" + +namespace mace { + +void Register_FullyConnected(OperatorRegistry *op_registry) { + REGISTER_OPERATOR(op_registry, OpKeyBuilder("FC") + .Device(DeviceType::CPU) + .TypeConstraint("T") + .Build(), + FullyConnectedOp); + + REGISTER_OPERATOR(op_registry, OpKeyBuilder("FC") + .Device(DeviceType::OPENCL) + .TypeConstraint("T") + .Build(), + FullyConnectedOp); + + REGISTER_OPERATOR(op_registry, OpKeyBuilder("FC") + .Device(DeviceType::OPENCL) + .TypeConstraint("T") + .Build(), + FullyConnectedOp); +} + +} // namespace mace diff --git a/mace/ops/fully_connected.h b/mace/ops/fully_connected.h new file mode 100644 index 00000000..0ee90e2b --- /dev/null +++ b/mace/ops/fully_connected.h @@ -0,0 +1,49 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. 
+// + +#ifndef MACE_OPS_FULLY_CONNECTED_H_ +#define MACE_OPS_FULLY_CONNECTED_H_ + +#include "mace/core/operator.h" +#include "mace/kernels/fully_connected.h" + +namespace mace { + +template +class FullyConnectedOp : public Operator { + public: + FullyConnectedOp(const OperatorDef &operator_def, Workspace *ws) + : Operator(operator_def, ws), + functor_( + kernels::StringToActivationType( + OperatorBase::GetSingleArgument("activation", + "NOOP")), + OperatorBase::GetSingleArgument("max_limit", 0.0f), + OperatorBase::GetSingleArgument("alpha", 0.0f)) {} + + bool Run(StatsFuture *future) override { + const Tensor *input = this->Input(INPUT); + const Tensor *weight = this->Input(WEIGHT); + const Tensor *bias = this->InputSize() >= 3 ? this->Input(BIAS) : nullptr; + Tensor *output = this->Output(OUTPUT); + + const index_t input_size = input->dim(1) * input->dim(2) * input->dim(3); + MACE_CHECK(input_size == weight->dim(1) && (bias == nullptr || weight->dim(0) == bias->dim(0))) + << "The size of Input, Weight and Bias don't match."; + + functor_(input, weight, bias, output, future); + return true; + } + + private: + kernels::FullyConnectedFunctor functor_; + + protected: + OP_INPUT_TAGS(INPUT, WEIGHT, BIAS); + OP_OUTPUT_TAGS(OUTPUT); +}; + +} // namespace mace + +#endif // MACE_OPS_FULLY_CONNECTED_H_ diff --git a/mace/ops/fully_connected_benchmark.cc b/mace/ops/fully_connected_benchmark.cc new file mode 100644 index 00000000..04776899 --- /dev/null +++ b/mace/ops/fully_connected_benchmark.cc @@ -0,0 +1,77 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. 
+// + +#include +#include "mace/core/operator.h" +#include "mace/core/testing/test_benchmark.h" +#include "mace/ops/ops_test_util.h" + +namespace mace { +template +static void FCBenchmark( + int iters, int batch, int height, int width, int channel, int out_channel) { + mace::testing::StopTiming(); + + OpsTestNet net; + + // Add input data + net.AddRandomInput("Input", {batch, height, width, channel}); + net.AddRandomInput("Weight", {out_channel, height * width * channel}); + net.AddRandomInput("Bias", {out_channel}); + + if (D == DeviceType::OPENCL) { + BufferToImage(net, "Input", "InputImage", + kernels::BufferType::IN_OUT_CHANNEL); + BufferToImage(net, "Weight", "WeightImage", + kernels::BufferType::WEIGHT_HEIGHT); + BufferToImage(net, "Bias", "BiasImage", + kernels::BufferType::ARGUMENT); + + OpDefBuilder("FC", "FullyConnectedTest") + .Input("InputImage") + .Input("WeightImage") + .Input("BiasImage") + .Output("OutputImage") + .AddIntArg("T", static_cast(DataTypeToEnum::value)) + .Finalize(net.NewOperatorDef()); + } else { + OpDefBuilder("FC", "FullyConnectedTest") + .Input("Input") + .Input("Weight") + .Input("Bias") + .Output("Output") + .Finalize(net.NewOperatorDef()); + } + + // Warm-up + for (int i = 0; i < 5; ++i) { + net.RunOp(D); + } + net.Sync(); + + mace::testing::StartTiming(); + while (iters--) { + net.RunOp(D); + } + net.Sync(); +} + +#define BM_FC_MACRO(N, H, W, C, OC, TYPE, DEVICE) \ + static void BM_FC_##N##_##H##_##W##_##C##_##OC##_##TYPE##_##DEVICE(int iters) { \ + const int64_t macc = static_cast(iters) * N * C * H * W * OC + OC; \ + const int64_t tot = static_cast(iters) * (N + OC) * C * H * W + OC; \ + mace::testing::MaccProcessed(macc); \ + mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \ + FCBenchmark(iters, N, H, W, C, OC); \ + } \ + BENCHMARK(BM_FC_##N##_##H##_##W##_##C##_##OC##_##TYPE##_##DEVICE) + +#define BM_FC(N, H, W, C, OC) \ + BM_FC_MACRO(N, H, W, C, OC, float, CPU); \ + BM_FC_MACRO(N, H, W, C, OC, float, OPENCL); \ + 
BM_FC_MACRO(N, H, W, C, OC, half, OPENCL); + +BM_FC(1, 16, 16, 32, 32); +BM_FC(1, 8, 8, 32, 1000); +} // namespace mace diff --git a/mace/ops/fully_connected_test.cc b/mace/ops/fully_connected_test.cc new file mode 100644 index 00000000..a945f41a --- /dev/null +++ b/mace/ops/fully_connected_test.cc @@ -0,0 +1,220 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include +#include "mace/core/operator.h" +#include "mace/ops/ops_test_util.h" + +namespace mace { + +class FullyConnectedOpTest : public OpsTestBase {}; + +template +void Simple(const std::vector &input_shape, + const std::vector &input_value, + const std::vector &weight_shape, + const std::vector &weight_value, + const std::vector &bias_shape, + const std::vector &bias_value, + const std::vector &output_shape, + const std::vector &output_value) { + OpsTestNet net; + + // Add input data + net.AddInputFromArray("Input", input_shape, input_value); + net.AddInputFromArray("Weight", weight_shape, weight_value); + net.AddInputFromArray("Bias", bias_shape, bias_value); + + if (D == DeviceType::OPENCL) { + BufferToImage(net, "Input", "InputImage", + kernels::BufferType::IN_OUT_CHANNEL); + BufferToImage(net, "Weight", "WeightImage", + kernels::BufferType::WEIGHT_HEIGHT); + BufferToImage(net, "Bias", "BiasImage", + kernels::BufferType::ARGUMENT); + + OpDefBuilder("FC", "FullyConnectedTest") + .Input("InputImage") + .Input("WeightImage") + .Input("BiasImage") + .Output("OutputImage") + .Finalize(net.NewOperatorDef()); + // Run + net.RunOp(D); + + // Transfer output + ImageToBuffer(net, "OutputImage", "Output", + kernels::BufferType::IN_OUT_CHANNEL); + } else { + OpDefBuilder("FC", "FullyConnectedTest") + .Input("Input") + .Input("Weight") + .Input("Bias") + .Output("Output") + .Finalize(net.NewOperatorDef()); + // Run + net.RunOp(D); + } + + // Check + auto expected = + CreateTensor(output_shape, output_value); + + ExpectTensorNear(*expected, *net.GetOutput("Output"), 1e-5); +} + 
+TEST_F(FullyConnectedOpTest, SimpleCPU) { + Simple({1, 2, 2, 2}, + {1, 2, 3, 4, 5, 6, 7, 8}, + {1, 8}, + {1, 2, 3, 4, 5, 6, 7, 8}, + {1}, {2}, + {1, 1, 1, 1}, {206}); + Simple({1, 1, 2, 5}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + {2, 10}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, + 10, 20, 30, 40, 50, 60, 70, 80, 90, 100}, + {2}, {2, 3}, + {1, 1, 1, 2}, {387, 3853}); + Simple({1, 1, 2, 3}, + {1, 2, 3, 4, 5, 6}, + {5, 6}, + {1, 2, 3, 4, 5, 6, + 10, 20, 30, 40, 50, 60, + 1, 2, 3, 4, 5, 6, + 10, 20, 30, 40, 50, 60, + 1, 2, 3, 4, 5, 6}, + {5}, {1, 2, 3, 4, 5}, + {1, 1, 1, 5}, {92, 912, 94, 914, 96}); +} + +TEST_F(FullyConnectedOpTest, SimpleCPUWithBatch) { + Simple({2, 1, 2, 2}, + {1, 2, 3, 4, 5, 6, 7, 8}, + {1, 4}, + {1, 2, 3, 4}, + {1}, {2}, + {2, 1, 1, 1}, {32, 72}); +} + +TEST_F(FullyConnectedOpTest, SimpleOPENCL) { + Simple({1, 2, 2, 2}, + {1, 2, 3, 4, 5, 6, 7, 8}, + {1, 8}, + {1, 2, 3, 4, 5, 6, 7, 8}, + {1}, {2}, + {1, 1, 1, 1}, {206}); + Simple({1, 1, 2, 5}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + {2, 10}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, + 10, 20, 30, 40, 50, 60, 70, 80, 90, 100}, + {2}, {2, 3}, + {1, 1, 1, 2}, {387, 3853}); + Simple({1, 1, 2, 3}, + {1, 2, 3, 4, 5, 6}, + {5, 6}, + {1, 2, 3, 4, 5, 6, + 10, 20, 30, 40, 50, 60, + 1, 2, 3, 4, 5, 6, + 10, 20, 30, 40, 50, 60, + 1, 2, 3, 4, 5, 6}, + {5}, {1, 2, 3, 4, 5}, + {1, 1, 1, 5}, {92, 912, 94, 914, 96}); +} + +TEST_F(FullyConnectedOpTest, SimpleGPUWithBatch) { + Simple({2, 1, 2, 2}, + {1, 2, 3, 4, 5, 6, 7, 8}, + {1, 4}, + {1, 2, 3, 4}, + {1}, {2}, + {2, 1, 1, 1}, {32, 72}); +} + +template +void Complex(const index_t batch, + const index_t height, + const index_t width, + const index_t channels, + const index_t out_channel) { + srand(time(NULL)); + + // Construct graph + OpsTestNet net; + OpDefBuilder("FC", "FullyConnectedTest") + .Input("Input") + .Input("Weight") + .Input("Bias") + .Output("Output") + .Finalize(net.NewOperatorDef()); + + // Add input data + net.AddRandomInput( + "Input", {batch, height, width, channels}); 
+ net.AddRandomInput( + "Weight", {out_channel, height * width * channels}); + net.AddRandomInput( + "Bias", {out_channel}); + + // run cpu + net.RunOp(); + + // Check + Tensor expected; + expected.Copy(*net.GetOutput("Output")); + + // Run on opencl + BufferToImage(net, "Input", "InputImage", + kernels::BufferType::IN_OUT_CHANNEL); + BufferToImage(net, "Weight", "WeightImage", + kernels::BufferType::WEIGHT_HEIGHT); + BufferToImage(net, "Bias", "BiasImage", + kernels::BufferType::ARGUMENT); + + OpDefBuilder("FC", "FullyConnectedTest") + .Input("InputImage") + .Input("WeightImage") + .Input("BiasImage") + .Output("OutputImage") + .AddIntArg("T", static_cast(DataTypeToEnum::value)) + .Finalize(net.NewOperatorDef()); + + // Run on opencl + net.RunOp(DeviceType::OPENCL); + + ImageToBuffer(net, "OutputImage", "OPENCLOutput", + kernels::BufferType::IN_OUT_CHANNEL); + if (DataTypeToEnum::value == DataType::DT_HALF) { + ExpectTensorNear(expected, *net.GetOutput("OPENCLOutput"), 1); + } else { + ExpectTensorNear(expected, *net.GetOutput("OPENCLOutput"), 1e-3); + } +} + +TEST_F(FullyConnectedOpTest, OPENCLAlignedWithoutBatch) { + Complex(1, 16, 16, 32, 16); + Complex(1, 16, 32, 32, 32); +} +TEST_F(FullyConnectedOpTest, OPENCLUnAlignedWithoutBatch) { + Complex(1, 13, 11, 11, 17); + Complex(1, 23, 29, 23, 113); +} +TEST_F(FullyConnectedOpTest, OPENCLUnAlignedWithBatch) { + Complex(16, 11, 13, 23, 17); + Complex(31, 13, 11, 29, 113); +} +TEST_F(FullyConnectedOpTest, OPENCLHalfAlignedWithoutBatch) { + Complex(1, 16, 16, 32, 16); + Complex(1, 16, 32, 32, 32); +} +TEST_F(FullyConnectedOpTest, OPENCLHalfUnAlignedWithBatch) { + Complex(2, 11, 13, 61, 17); + Complex(16, 13, 12, 31, 113); + Complex(31, 21, 11, 23, 103); +} + +} + -- GitLab