diff --git a/mace/kernels/fully_connected.h b/mace/kernels/fully_connected.h
index 740faaccb9a7e1cd41a699ada5c2aed39ed79a02..b8a740215f3c3d23a85cc4d55184ab0b65e4c13e 100644
--- a/mace/kernels/fully_connected.h
+++ b/mace/kernels/fully_connected.h
@@ -9,24 +9,30 @@
 #include "mace/core/runtime/opencl/cl2_header.h"
 #include "mace/core/tensor.h"
 #include "mace/kernels/activation.h"
+#include "mace/kernels/opencl/helper.h"
 
 namespace mace {
 namespace kernels {
 
 struct FullyConnectedBase {
-  FullyConnectedBase(const ActivationType activation,
+  FullyConnectedBase(const BufferType weight_type,
+                     const ActivationType activation,
                      const float relux_max_limit)
-      : activation_(activation), relux_max_limit_(relux_max_limit) {}
+      : weight_type_(weight_type),
+        activation_(activation),
+        relux_max_limit_(relux_max_limit) {}
 
+  const int weight_type_;
   const ActivationType activation_;
   const float relux_max_limit_;
 };
 
 template <DeviceType D, typename T>
 struct FullyConnectedFunctor : FullyConnectedBase {
-  FullyConnectedFunctor(const ActivationType activation,
+  FullyConnectedFunctor(const BufferType weight_type,
+                        const ActivationType activation,
                         const float relux_max_limit)
-      : FullyConnectedBase(activation, relux_max_limit) {}
+      : FullyConnectedBase(weight_type, activation, relux_max_limit) {}
 
   void operator()(const Tensor *input,
                   const Tensor *weight,
@@ -70,9 +76,10 @@ struct FullyConnectedFunctor : FullyConnectedBase {
 
 template <typename T>
 struct FullyConnectedFunctor<DeviceType::OPENCL, T> : FullyConnectedBase {
-  FullyConnectedFunctor(const ActivationType activation,
+  FullyConnectedFunctor(const BufferType weight_type,
+                        const ActivationType activation,
                         const float relux_max_limit)
-      : FullyConnectedBase(activation, relux_max_limit) {}
+      : FullyConnectedBase(weight_type, activation, relux_max_limit) {}
 
   void operator()(const Tensor *input,
                   const Tensor *weight,
@@ -81,6 +88,8 @@ struct FullyConnectedFunctor<DeviceType::OPENCL, T> : FullyConnectedBase {
                   StatsFuture *future);
 
   cl::Kernel kernel_;
+  std::vector<uint32_t> gws_;
+  std::vector<uint32_t> lws_;
 };
 
 }  // namespace kernels
diff --git a/mace/kernels/opencl/buffer_to_image.cc b/mace/kernels/opencl/buffer_to_image.cc
index 7b48446461cf65bc0604384ae5f302cacaf67bc2..b0fa30a5cf146fd0da2ccd0ea9bc9ea419349f32 100644
--- a/mace/kernels/opencl/buffer_to_image.cc
+++ b/mace/kernels/opencl/buffer_to_image.cc
@@ -49,6 +49,7 @@ void BufferToImageFunctor<DeviceType::OPENCL, T>::operator()(
               : "in_out_height_buffer_to_image";
       break;
     case IN_OUT_WIDTH:
+    case WEIGHT_WIDTH:
      MACE_CHECK(!i2b_) << "IN_OUT_WIDTH only support buffer to image now";
      kernel_name = "in_out_width_buffer_to_image";
      break;
@@ -88,7 +89,7 @@ void BufferToImageFunctor<DeviceType::OPENCL, T>::operator()(
   }
   if (type == ARGUMENT) {
     b2f_kernel.setArg(idx++, static_cast<int>(buffer->dim(0)));
-  } else if (type == WEIGHT_HEIGHT) {
+  } else if (type == WEIGHT_HEIGHT || type == WEIGHT_WIDTH) {
     b2f_kernel.setArg(idx++, static_cast<int>(buffer->dim(0)));
     b2f_kernel.setArg(idx++, static_cast<int>(buffer->dim(1)));
     b2f_kernel.setArg(idx++, 1);
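For orientation: the WEIGHT_WIDTH case added above reuses the in_out_width kernel (shown next) with shape arguments (H, W, 1), so a 2-D weight [H, W] lands in a [(W + 3) / 4, H] image with four row-consecutive elements packed into each pixel. A minimal host-side sketch of that coordinate mapping (the helper name is illustrative, not part of the patch):

// Sketch only: where weight element (h, w) ends up under the WEIGHT_WIDTH
// image layout; the lane within the pixel is w % 4.
#include <cstdint>
#include <utility>

inline std::pair<int64_t, int64_t> WeightWidthCoord(int64_t h, int64_t w) {
  return {w / 4, h};  // {image_x over 4-wide blocks of W, image_y = h}
}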
diff --git a/mace/kernels/opencl/cl/buffer_to_image.cl b/mace/kernels/opencl/cl/buffer_to_image.cl
index 42ab6617db1f6294513ae47b5a770072d8a59a0f..781d21e363a2a9320999ab4dcc933ffab5fcc0fa 100644
--- a/mace/kernels/opencl/cl/buffer_to_image.cl
+++ b/mace/kernels/opencl/cl/buffer_to_image.cl
@@ -318,10 +318,11 @@ __kernel void in_out_width_buffer_to_image(__global const DATA_TYPE *input, /* nhwc */
                                            __write_only image2d_t output) {
   int w = get_global_id(0);
   int h = get_global_id(1);
+  const int width_blks = (width + 3) / 4;
   const int batch_idx = h / height;
   const int height_idx = h % height;
-  const int width_idx = (w % width) << 2;
-  const int channel_idx = w / width;
+  const int width_idx = (w % width_blks) << 2;
+  const int channel_idx = w / width_blks;
   const int offset = input_offset +
                      ((batch_idx * height + height_idx) * width + width_idx) * channels +
                      channel_idx;
diff --git a/mace/kernels/opencl/cl/fully_connected.cl b/mace/kernels/opencl/cl/fully_connected.cl
index 89264d82b890ca0effd9de53ea935a8821506f19..217224db93b5b7a9bf57e46890e3570b1cb62ed2 100644
--- a/mace/kernels/opencl/cl/fully_connected.cl
+++ b/mace/kernels/opencl/cl/fully_connected.cl
@@ -4,7 +4,7 @@
 __kernel void fully_connected(__read_only image2d_t input,
                               __read_only image2d_t weight,
 #ifdef BIAS
-    __read_only image2d_t bias,
+                              __read_only image2d_t bias,
 #endif
                               __write_only image2d_t output,
                               __private const int input_height,
@@ -55,3 +55,76 @@ __kernel void fully_connected(__read_only image2d_t input,
 #endif
   WRITE_IMAGET(output, (int2)(out_blk_idx, batch_idx), result);
 }
+
+// output = weight * input + bias
+__kernel void fully_connected_width(__read_only image2d_t input,
+                                    __read_only image2d_t weight,
+#ifdef BIAS
+                                    __read_only image2d_t bias,
+#endif
+                                    __write_only image2d_t output,
+                                    __local float *intermediate_output,
+                                    __private const int input_height,
+                                    __private const int input_width,
+                                    __private const short in_chan_blks,
+                                    __private const float relux_max_limit) {
+  const int inter_out_idx = get_global_id(0);
+  const int width_blk_idx = get_global_id(1);
+  const int width_blk_count = get_global_size(1);
+  const int out_blk_idx = get_global_id(2);
+
+  const short in_outer_size = mul24(input_width, in_chan_blks);
+  const short weight_y = mad24(out_blk_idx, 4, inter_out_idx);
+
+  int2 input_coord, weight_coord;
+  DATA_TYPE4 in, w;
+  DATA_TYPE sum = 0.0;
+
+  input_coord = (int2)(0, 0);
+
+  for (short h_idx = 0; h_idx < input_height; ++h_idx) {
+    short weight_x_base = mul24(h_idx, in_outer_size);
+    for (short w_idx = (short)width_blk_idx; w_idx < input_width;
+         w_idx += width_blk_count) {
+      short weight_x = mad24(w_idx, in_chan_blks, weight_x_base);
+      weight_coord = (int2)(weight_x, weight_y);
+      input_coord.x = w_idx;
+#pragma unroll
+      for (short chan_idx = 0; chan_idx < in_chan_blks; ++chan_idx) {
+        in = READ_IMAGET(input, SAMPLER, input_coord);
+        w = READ_IMAGET(weight, SAMPLER, weight_coord);
+        sum += dot(in, w);
+        input_coord.x += input_width;
+        weight_coord.x += 1;
+      }
+    }
+    input_coord.y++;
+  }
+
+  const short inter_out_offset = mad24(get_local_id(1), 4, get_local_id(0));
+  const short local_width_blk_size = (short)get_local_size(1);
+  const short local_size = mul24((short)get_local_size(0),
+                                 local_width_blk_size);
+  short inter_idx = mad24((short)get_local_id(2), local_size, inter_out_offset);
+  intermediate_output[inter_idx] = sum;
+
+  // A local-memory fence is required before lane 0 reads the other
+  // work-items' partial sums below.
+  barrier(CLK_LOCAL_MEM_FENCE);
+
+  if (inter_out_offset == 0) {
+#ifdef BIAS
+    DATA_TYPE4 result = READ_IMAGET(bias, SAMPLER, (int2)(out_blk_idx, 0));
+#else
+    DATA_TYPE4 result = (DATA_TYPE4)(0, 0, 0, 0);
+#endif
+
+    for (short i = 0; i < local_width_blk_size; ++i) {
+      result += vload4(0, intermediate_output + inter_idx);
+      inter_idx += 4;
+    }
+
+#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID)
+    result = do_activation(result, relux_max_limit);
+#endif
+    WRITE_IMAGET(output, (int2)(out_blk_idx, 0), result);
+  }
+}
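The kernel above splits each output channel's dot product across work-items: dimension 0 picks one of the four channels in an output block, dimension 1 strides across the spatial width, and the partial sums meet in __local memory. A plain C++ reference model of that two-stage computation for one output channel (the names and the N == 1 flattening are assumptions for illustration, not code from the patch):

// Hypothetical reference model of fully_connected_width for one output
// channel: width-block work-items accumulate strided partial sums (stage 1),
// then a single lane reduces them (stage 2), mirroring intermediate_output.
#include <vector>

float FCWidthReference(const std::vector<float> &input,       // [H][W][C], N == 1
                       const std::vector<float> &weight_row,  // same layout
                       int H, int W, int C, int width_blk_count) {
  std::vector<float> partial(width_blk_count, 0.f);
  // Stage 1: width block `blk` handles w = blk, blk + count, blk + 2*count, ...
  for (int blk = 0; blk < width_blk_count; ++blk)
    for (int h = 0; h < H; ++h)
      for (int w = blk; w < W; w += width_blk_count)
        for (int c = 0; c < C; ++c) {
          const int i = (h * W + w) * C + c;
          partial[blk] += input[i] * weight_row[i];
        }
  // Stage 2: the reduction the kernel performs through local memory.
  float sum = 0.f;
  for (float p : partial) sum += p;
  return sum;
}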
diff --git a/mace/kernels/opencl/fully_connected_opencl.cc b/mace/kernels/opencl/fully_connected_opencl.cc
index 0e208cf4706a3f28e1aaa86feae886f1c3969ab1..ca07b98934986cf32692357fde529561393b2380 100644
--- a/mace/kernels/opencl/fully_connected_opencl.cc
+++ b/mace/kernels/opencl/fully_connected_opencl.cc
@@ -3,31 +3,105 @@
 //
 
 #include "mace/kernels/fully_connected.h"
-#include "mace/core/runtime/opencl/opencl_runtime.h"
-#include "mace/kernels/opencl/helper.h"
 #include "mace/utils/tuner.h"
 
 namespace mace {
 namespace kernels {
 
 template <typename T>
-void FullyConnectedFunctor<DeviceType::OPENCL, T>::operator()(
-    const Tensor *input,
-    const Tensor *weight,
-    const Tensor *bias,
-    Tensor *output,
-    StatsFuture *future) {
-  std::vector<index_t> output_shape = {input->dim(0), 1, 1, weight->dim(0)};
-  std::vector<size_t> output_image_shape;
-  CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL, output_image_shape);
-  output->ResizeImage(output_shape, output_image_shape);
+void FCWXKernel(cl::Kernel *kernel,
+                const Tensor *input,
+                const Tensor *weight,
+                const Tensor *bias,
+                Tensor *output,
+                const ActivationType activation,
+                std::vector<uint32_t> &gws,
+                std::vector<uint32_t> &lws,
+                const float relux_max_limit,
+                StatsFuture *future) {
+  MACE_CHECK(input->dim(3) % 4 == 0)
+      << "FC width kernel only supports input channels divisible by 4.";
+  auto runtime = OpenCLRuntime::Global();
+
+  if (kernel->get() == nullptr) {
+    std::set<std::string> built_options;
+    auto dt = DataTypeToEnum<T>::value;
+    std::string kernel_name = MACE_OBFUSCATE_SYMBOL("fully_connected_width");
+    built_options.emplace("-Dfully_connected_width=" + kernel_name);
+    built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
+    built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
+    if (bias != nullptr) {
+      built_options.emplace("-DBIAS");
+    }
+    switch (activation) {
+      case NOOP:
+        break;
+      case RELU:
+        built_options.emplace("-DUSE_RELU");
+        break;
+      case RELUX:
+        built_options.emplace("-DUSE_RELUX");
+        break;
+      case TANH:
+        built_options.emplace("-DUSE_TANH");
+        break;
+      case SIGMOID:
+        built_options.emplace("-DUSE_SIGMOID");
+        break;
+      default:
+        LOG(FATAL) << "Unknown activation type: " << activation;
+    }
+
+    *kernel =
+        runtime->BuildKernel("fully_connected", kernel_name, built_options);
+
+    const index_t batch = output->dim(0);
+    const index_t output_size = output->dim(3);
+    const index_t output_blocks = RoundUpDiv4(output_size);
 
-  const index_t batch = output->dim(0);
-  const index_t output_size = output->dim(3);
+    gws = {4, 8, static_cast<uint32_t>(output_blocks)};
 
-  const index_t output_blocks = RoundUpDiv4(output_size);
+    const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(*kernel);
+    const uint32_t inter_local_blks = kwg_size / (gws[0] * gws[1]);
+    lws = {gws[0], gws[1], inter_local_blks};
 
-  if (kernel_.get() == nullptr) {
+    uint32_t idx = 0;
+    kernel->setArg(idx++, *(input->opencl_image()));
+    kernel->setArg(idx++, *(weight->opencl_image()));
+    if (bias != nullptr) {
+      kernel->setArg(idx++, *(bias->opencl_image()));
+    }
+    kernel->setArg(idx++, *(output->opencl_image()));
+    // A size-only setArg allocates the kernel's __local scratch buffer.
+    kernel->setArg(idx++, (lws[0] * lws[1] * lws[2] * sizeof(float)), nullptr);
+    kernel->setArg(idx++, static_cast<int>(input->dim(1)));
+    kernel->setArg(idx++, static_cast<int>(input->dim(2)));
+    kernel->setArg(idx++, static_cast<short>(RoundUpDiv4(input->dim(3))));
+    kernel->setArg(idx++, relux_max_limit);
+  }
+  cl::Event event;
+  cl_int error = runtime->command_queue().enqueueNDRangeKernel(
+      *kernel, cl::NullRange, cl::NDRange(gws[0], gws[1], gws[2]),
+      cl::NDRange(lws[0], lws[1], lws[2]), nullptr, &event);
+  MACE_CHECK(error == CL_SUCCESS) << "Error code: " << error;
+
+  if (future != nullptr) {
+    future->wait_fn = [runtime, event](CallStats *stats) {
+      event.wait();
+      if (stats != nullptr) {
+        runtime->GetCallStats(event, stats);
+      }
+    };
+  }
+}
+
+template <typename T>
+void FCWTXKernel(cl::Kernel *kernel,
+                 const Tensor *input,
+                 const Tensor *weight,
+                 const Tensor *bias,
+                 Tensor *output,
+                 const ActivationType activation,
+                 std::vector<uint32_t> &gws,
+                 std::vector<uint32_t> &lws,
+                 const float relux_max_limit,
+                 StatsFuture *future) {
+  if (kernel->get() == nullptr) {
     auto runtime = OpenCLRuntime::Global();
     std::set<std::string> built_options;
     auto dt = DataTypeToEnum<T>::value;
@@ -38,7 +112,7 @@ void FullyConnectedFunctor<DeviceType::OPENCL, T>::operator()(
     if (bias != nullptr) {
       built_options.emplace("-DBIAS");
     }
-    switch (activation_) {
+    switch (activation) {
       case NOOP:
         break;
       case RELU:
@@ -54,33 +128,61 @@ void FullyConnectedFunctor<DeviceType::OPENCL, T>::operator()(
         built_options.emplace("-DUSE_SIGMOID");
         break;
       default:
-        LOG(FATAL) << "Unknown activation type: " << activation_;
+        LOG(FATAL) << "Unknown activation type: " << activation;
     }
 
-    kernel_ =
+    *kernel =
         runtime->BuildKernel("fully_connected", kernel_name, built_options);
 
     uint32_t idx = 0;
-    kernel_.setArg(idx++, *(input->opencl_image()));
-    kernel_.setArg(idx++, *(weight->opencl_image()));
+    kernel->setArg(idx++, *(input->opencl_image()));
+    kernel->setArg(idx++, *(weight->opencl_image()));
     if (bias != nullptr) {
-      kernel_.setArg(idx++, *(bias->opencl_image()));
+      kernel->setArg(idx++, *(bias->opencl_image()));
     }
-    kernel_.setArg(idx++, *(output->opencl_image()));
-    kernel_.setArg(idx++, static_cast<int>(input->dim(1)));
-    kernel_.setArg(idx++, static_cast<int>(input->dim(2)));
-    kernel_.setArg(idx++, static_cast<int>(input->dim(3)));
+    kernel->setArg(idx++, *(output->opencl_image()));
+    kernel->setArg(idx++, static_cast<int>(input->dim(1)));
+    kernel->setArg(idx++, static_cast<int>(input->dim(2)));
+    kernel->setArg(idx++, static_cast<int>(input->dim(3)));
     // FIXME handle flexable data type: half not supported
-    kernel_.setArg(idx++, relux_max_limit_);
+    kernel->setArg(idx++, relux_max_limit);
+
+    const index_t batch = output->dim(0);
+    const index_t output_size = output->dim(3);
+
+    const index_t output_blocks = RoundUpDiv4(output_size);
+
+    gws = {
+        static_cast<uint32_t>(batch), static_cast<uint32_t>(output_blocks),
+    };
+    lws = {16, 64, 1};
   }
-  const uint32_t gws[2] = {
-      static_cast<uint32_t>(batch), static_cast<uint32_t>(output_blocks),
-  };
-  const std::vector<uint32_t> lws = {16, 64, 1};
   std::stringstream ss;
   ss << "fc_opencl_kernel_" << output->dim(0) << "_" << output->dim(1) << "_"
      << output->dim(2) << "_" << output->dim(3);
-  TuningOrRun2DKernel(kernel_, ss.str(), gws, lws, future);
+  TuningOrRun2DKernel(*kernel, ss.str(), gws.data(), lws, future);
+}
+
+template <typename T>
+void FullyConnectedFunctor<DeviceType::OPENCL, T>::operator()(
+    const Tensor *input,
+    const Tensor *weight,
+    const Tensor *bias,
+    Tensor *output,
+    StatsFuture *future) {
+  std::vector<index_t> output_shape = {input->dim(0), 1, 1, weight->dim(0)};
+  std::vector<size_t> output_image_shape;
+  CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL, output_image_shape);
+  output->ResizeImage(output_shape, output_image_shape);
+
+  if (weight_type_ == BufferType::WEIGHT_HEIGHT) {
+    FCWTXKernel<T>(&kernel_, input, weight, bias, output,
+                   activation_, gws_, lws_, relux_max_limit_, future);
+  } else {
+    FCWXKernel<T>(&kernel_, input, weight, bias, output,
+                  activation_, gws_, lws_, relux_max_limit_, future);
+  }
+}
 
 template struct FullyConnectedFunctor<DeviceType::OPENCL, float>;
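To make the launch configuration in FCWXKernel concrete, here is the arithmetic under an assumed device limit (the 256 is an example value, not something the patch guarantees):

// Worked example of FCWXKernel's work-group sizing with an assumed
// max work-group size of 256.
#include <cstddef>
#include <cstdint>
#include <cstdio>

int main() {
  const uint32_t gws0 = 4, gws1 = 8;  // channel lanes per block, width blocks
  const uint32_t kwg_size = 256;      // assumed GetKernelMaxWorkGroupSize result
  const uint32_t inter_local_blks = kwg_size / (gws0 * gws1);  // 256 / 32 = 8
  // lws = {4, 8, 8}; the __local scratch buffer holds 4 * 8 * 8 floats = 1 KB.
  std::printf("local bytes = %zu\n",
              static_cast<std::size_t>(gws0) * gws1 * inter_local_blks *
                  sizeof(float));
  return 0;
}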
diff --git a/mace/kernels/opencl/helper.cc b/mace/kernels/opencl/helper.cc
index 791db16742408f69520f128be435417ab6056be4..3f41966299f2b8ec4b61e65d1191eaef1d94b533 100644
--- a/mace/kernels/opencl/helper.cc
+++ b/mace/kernels/opencl/helper.cc
@@ -84,6 +84,15 @@ void CalWeightHeightImageShape(const std::vector<index_t> &shape, /* HW */
   image_shape[1] = RoundUpDiv4(shape[0]);
 }
 
+// [(W + 3) / 4, H]
+void CalWeightWidthImageShape(const std::vector<index_t> &shape, /* HW */
+                              std::vector<size_t> &image_shape) {
+  MACE_CHECK(shape.size() == 2);
+  image_shape.resize(2);
+  image_shape[0] = RoundUpDiv4(shape[1]);
+  image_shape[1] = shape[0];
+}
+
 void CalImage2DShape(const std::vector<index_t> &shape, /* NHWC */
                      const BufferType type,
                      std::vector<size_t> &image_shape) {
@@ -112,6 +121,9 @@ void CalImage2DShape(const std::vector<index_t> &shape, /* NHWC */
     case WEIGHT_HEIGHT:
       CalWeightHeightImageShape(shape, image_shape);
       break;
+    case WEIGHT_WIDTH:
+      CalWeightWidthImageShape(shape, image_shape);
+      break;
     default:
       LOG(FATAL) << "Mace not supported yet.";
   }
diff --git a/mace/kernels/opencl/helper.h b/mace/kernels/opencl/helper.h
index 19cc6ff3b9cab9620533e577e9ab38d9579cd554..6513415a02a00574dbfc1b22c1c909e94e6bfd49 100644
--- a/mace/kernels/opencl/helper.h
+++ b/mace/kernels/opencl/helper.h
@@ -25,6 +25,7 @@ enum BufferType {
   WINOGRAD_FILTER = 5,
   DW_CONV2D_FILTER = 6,
   WEIGHT_HEIGHT = 7,
+  WEIGHT_WIDTH = 8,
 };
 
 void CalImage2DShape(const std::vector<index_t> &shape, /* NHWC */

A quick worked example for CalWeightWidthImageShape above, using the largest benchmark shape added below (BM_FC(1, 7, 7, 512, 4096), so the weight is [4096, 7 * 7 * 512]):

// Worked example: [H, W] = [4096, 25088] becomes a 6272 x 4096 image,
// each pixel packing 4 consecutive weights of one row.
#include <cstddef>
#include <cstdint>
#include <vector>

int main() {
  const std::vector<int64_t> shape = {4096, 7 * 7 * 512};  // [H, W]
  std::vector<std::size_t> image_shape(2);
  image_shape[0] = (shape[1] + 3) / 4;  // RoundUpDiv4(25088) == 6272
  image_shape[1] = shape[0];            // 4096
  return 0;
}

diff --git a/mace/ops/fully_connected.h b/mace/ops/fully_connected.h
index 2f915149084f6729a2cc7517239528869e49cf38..282804a5793a4b8c5bb6bab26b2389e41f0f5170 100644
--- a/mace/ops/fully_connected.h
+++ b/mace/ops/fully_connected.h
@@ -15,7 +15,11 @@ class FullyConnectedOp : public Operator<D, T> {
  public:
   FullyConnectedOp(const OperatorDef &operator_def, Workspace *ws)
       : Operator<D, T>(operator_def, ws),
-        functor_(kernels::StringToActivationType(
+        functor_(static_cast<kernels::BufferType>(
+                     OperatorBase::GetSingleArgument<int>(
+                         "weight_type",
+                         static_cast<int>(kernels::WEIGHT_WIDTH))),
+                 kernels::StringToActivationType(
                      OperatorBase::GetSingleArgument<std::string>("activation",
                                                                   "NOOP")),
                  OperatorBase::GetSingleArgument<float>("max_limit", 0.0f)) {}
diff --git a/mace/ops/fully_connected_benchmark.cc b/mace/ops/fully_connected_benchmark.cc
index 9ada2c5420d3a1571db9398c34fc5086fb1a2c59..c7f3dee445a127f7eebef354c2a39835a1ab6a26 100644
--- a/mace/ops/fully_connected_benchmark.cc
+++ b/mace/ops/fully_connected_benchmark.cc
@@ -22,10 +22,18 @@ static void FCBenchmark(
   net.AddRandomInput<D, T>("Bias", {out_channel});
 
   if (D == DeviceType::OPENCL) {
+    const int width_size = height * width * channel;
+    kernels::BufferType weight_type = kernels::BufferType::WEIGHT_HEIGHT;
+    // if (width_size > 16384) {
+    BufferToImage<D, T>(net, "Weight", "WeightImage",
+                        kernels::BufferType::WEIGHT_WIDTH);
+    weight_type = kernels::BufferType::WEIGHT_WIDTH;
+    // } else {
+    //   BufferToImage<D, T>(net, "Weight", "WeightImage",
+    //                       kernels::BufferType::WEIGHT_HEIGHT);
+    // }
     BufferToImage<D, T>(net, "Input", "InputImage",
                         kernels::BufferType::IN_OUT_CHANNEL);
-    BufferToImage<D, T>(net, "Weight", "WeightImage",
-                        kernels::BufferType::WEIGHT_HEIGHT);
     BufferToImage<D, T>(net, "Bias", "BiasImage",
                         kernels::BufferType::ARGUMENT);
 
@@ -34,6 +42,7 @@ static void FCBenchmark(
         .Input("WeightImage")
         .Input("BiasImage")
         .Output("OutputImage")
+        .AddIntArg("weight_type", static_cast<int>(weight_type))
         .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
         .Finalize(net.NewOperatorDef());
   } else {
@@ -78,4 +87,6 @@ static void FCBenchmark(
 BM_FC(1, 16, 16, 32, 32);
 BM_FC(1, 8, 8, 32, 1000);
+BM_FC(1, 2, 2, 512, 2);
+BM_FC(1, 7, 7, 512, 4096);
 }  // namespace mace
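The commented-out branch in the benchmark above sketches a size-based layout choice that this patch leaves disabled (it always converts the weight to WEIGHT_WIDTH). If it were enabled, the selection would look roughly like this; the 16384 threshold comes from the commented condition and the enum stand-ins are illustrative:

// Sketch of the disabled heuristic: prefer WEIGHT_WIDTH only for large
// weights. Not active anywhere in this patch.
#include <cstdint>

enum class WeightLayout { kWeightHeight, kWeightWidth };  // stands in for kernels::BufferType

WeightLayout PickWeightLayout(int64_t height, int64_t width, int64_t channel) {
  const int64_t width_size = height * width * channel;
  return width_size > 16384 ? WeightLayout::kWeightWidth
                            : WeightLayout::kWeightHeight;
}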
diff --git a/mace/ops/fully_connected_test.cc b/mace/ops/fully_connected_test.cc
index 3a41dd87f766ac92a6f12ed32d801f59b943835b..cfbc679644404a37befcb0c611189855bb6f9e41 100644
--- a/mace/ops/fully_connected_test.cc
+++ b/mace/ops/fully_connected_test.cc
@@ -39,6 +39,7 @@ void Simple(const std::vector<index_t> &input_shape,
         .Input("WeightImage")
         .Input("BiasImage")
         .Output("OutputImage")
+        .AddIntArg("weight_type", kernels::BufferType::WEIGHT_HEIGHT)
         .Finalize(net.NewOperatorDef());
     // Run
     net.RunOp(D);
@@ -147,6 +148,7 @@ void Complex(const index_t batch,
       .Input("WeightImage")
       .Input("BiasImage")
      .Output("OutputImage")
+      .AddIntArg("weight_type", kernels::BufferType::WEIGHT_HEIGHT)
       .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
       .Finalize(net.NewOperatorDef());
 
@@ -183,4 +185,75 @@ TEST_F(FullyConnectedOpTest, OPENCLHalfUnAlignedWithBatch) {
   Complex<half>(16, 13, 12, 31, 113);
   Complex<half>(31, 21, 11, 23, 103);
 }
+
+template <typename T>
+void TestWeightWidthFormat(const index_t batch,
+                           const index_t height,
+                           const index_t width,
+                           const index_t channels,
+                           const index_t out_channel) {
+  srand(time(NULL));
+
+  // Construct graph
+  OpsTestNet net;
+  OpDefBuilder("FC", "FullyConnectedTest")
+      .Input("Input")
+      .Input("Weight")
+      .Input("Bias")
+      .Output("Output")
+      .Finalize(net.NewOperatorDef());
+
+  // Add input data
+  net.AddRandomInput<DeviceType::OPENCL, float>(
+      "Input", {batch, height, width, channels});
+  net.AddRandomInput<DeviceType::OPENCL, float>(
+      "Weight", {out_channel, height * width * channels});
+  net.AddRandomInput<DeviceType::OPENCL, float>("Bias", {out_channel});
+
+  // run cpu
+  net.RunOp();
+
+  // Check
+  Tensor expected;
+  expected.Copy(*net.GetOutput("Output"));
+
+  // Run on opencl
+  BufferToImage<DeviceType::OPENCL, T>(net, "Input", "InputImage",
+                                       kernels::BufferType::IN_OUT_CHANNEL);
+  BufferToImage<DeviceType::OPENCL, T>(net, "Weight", "WeightImage",
+                                       kernels::BufferType::WEIGHT_WIDTH);
+  BufferToImage<DeviceType::OPENCL, T>(net, "Bias", "BiasImage",
+                                       kernels::BufferType::ARGUMENT);
+
+  OpDefBuilder("FC", "FullyConnectedTest")
+      .Input("InputImage")
+      .Input("WeightImage")
+      .Input("BiasImage")
+      .Output("OutputImage")
+      .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
+      .Finalize(net.NewOperatorDef());
+
+  // Run on opencl
+  net.RunOp(DeviceType::OPENCL);
+
+  ImageToBuffer<DeviceType::OPENCL, float>(net, "OutputImage", "OPENCLOutput",
+                                           kernels::BufferType::IN_OUT_CHANNEL);
+  if (DataTypeToEnum<T>::value == DataType::DT_HALF) {
+    ExpectTensorNear<float>(expected, *net.GetOutput("OPENCLOutput"), 1);
+  } else {
+    ExpectTensorNear<float>(expected, *net.GetOutput("OPENCLOutput"), 1e-2);
+  }
+}
+
+TEST_F(FullyConnectedOpTest, OPENCLWidthFormatAligned) {
+  TestWeightWidthFormat<float>(1, 7, 7, 32, 16);
+  TestWeightWidthFormat<float>(1, 7, 7, 512, 128);
+  TestWeightWidthFormat<float>(1, 1, 1, 2048, 1024);
+}
+
+TEST_F(FullyConnectedOpTest, OPENCLHalfWidthFormatAligned) {
+  TestWeightWidthFormat<half>(1, 2, 2, 512, 2);
+  TestWeightWidthFormat<half>(1, 11, 11, 32, 16);
+  TestWeightWidthFormat<half>(1, 16, 32, 32, 32);
+}
+
 }  // namespace mace