Commit 8023095b authored by liuqi

Optimize FC op using local memory.

Parent 9ba7bce9
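This commit adds a second OpenCL fully-connected kernel, fully_connected_width, which splits each output's dot product across work-items along the input width, parks the partial sums in __local memory, and lets one lane per work-group finish the reduction. A minimal sketch of that pattern, independent of the MACE kernel and its DATA_TYPE/READ_IMAGET macros (kernel and argument names here are illustrative only):

    // Each work-item accumulates a strided partial dot product, publishes
    // it to __local memory, and lane 0 of the work-group sums the partials.
    __kernel void partial_dot(__global const float *a,
                              __global const float *b,
                              __global float *out,     // one sum per group
                              __local float *partial,  // one slot per work-item
                              const int n) {
      const int lid = get_local_id(0);
      const int lsz = get_local_size(0);
      float sum = 0.0f;
      for (int i = get_global_id(0); i < n; i += get_global_size(0)) {
        sum += a[i] * b[i];
      }
      partial[lid] = sum;
      barrier(CLK_LOCAL_MEM_FENCE);  // make every partial visible
      if (lid == 0) {
        float acc = 0.0f;
        for (int i = 0; i < lsz; ++i) acc += partial[i];
        out[get_group_id(0)] = acc;
      }
    }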
......@@ -9,24 +9,30 @@
#include "mace/core/runtime/opencl/cl2_header.h"
#include "mace/core/tensor.h"
#include "mace/kernels/activation.h"
#include "mace/kernels/opencl/helper.h"
namespace mace {
namespace kernels {
struct FullyConnectedBase {
-  FullyConnectedBase(const ActivationType activation,
+  FullyConnectedBase(const BufferType weight_type,
+                     const ActivationType activation,
                      const float relux_max_limit)
-      : activation_(activation), relux_max_limit_(relux_max_limit) {}
+      : weight_type_(weight_type),
+        activation_(activation),
+        relux_max_limit_(relux_max_limit) {}
+
+  const int weight_type_;
const ActivationType activation_;
const float relux_max_limit_;
};
template <DeviceType D, typename T>
struct FullyConnectedFunctor : FullyConnectedBase {
-  FullyConnectedFunctor(const ActivationType activation,
+  FullyConnectedFunctor(const BufferType weight_type,
+                        const ActivationType activation,
                         const float relux_max_limit)
-      : FullyConnectedBase(activation, relux_max_limit) {}
+      : FullyConnectedBase(weight_type, activation, relux_max_limit) {}
void operator()(const Tensor *input,
const Tensor *weight,
......@@ -70,9 +76,10 @@ struct FullyConnectedFunctor : FullyConnectedBase {
template <typename T>
struct FullyConnectedFunctor<DeviceType::OPENCL, T> : FullyConnectedBase {
-  FullyConnectedFunctor(const ActivationType activation,
+  FullyConnectedFunctor(const BufferType weight_type,
+                        const ActivationType activation,
                         const float relux_max_limit)
-      : FullyConnectedBase(activation, relux_max_limit) {}
+      : FullyConnectedBase(weight_type, activation, relux_max_limit) {}
void operator()(const Tensor *input,
const Tensor *weight,
......@@ -81,6 +88,8 @@ struct FullyConnectedFunctor<DeviceType::OPENCL, T> : FullyConnectedBase {
StatsFuture *future);
cl::Kernel kernel_;
std::vector<uint32_t> gws_;
std::vector<uint32_t> lws_;
};
} // namespace kernels
......
......@@ -49,6 +49,7 @@ void BufferToImageFunctor<DeviceType::OPENCL, T>::operator()(
: "in_out_height_buffer_to_image";
break;
case IN_OUT_WIDTH:
case WEIGHT_WIDTH:
MACE_CHECK(!i2b_) << "IN_OUT_WIDTH and WEIGHT_WIDTH only support buffer-to-image now";
kernel_name = "in_out_width_buffer_to_image";
break;
......@@ -88,7 +89,7 @@ void BufferToImageFunctor<DeviceType::OPENCL, T>::operator()(
}
if (type == ARGUMENT) {
b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(0)));
-  } else if (type == WEIGHT_HEIGHT) {
+  } else if (type == WEIGHT_HEIGHT || type == WEIGHT_WIDTH) {
b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(0)));
b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(1)));
b2f_kernel.setArg(idx++, 1);
......
......@@ -318,10 +318,11 @@ __kernel void in_out_width_buffer_to_image(__global const DATA_TYPE *input, /* n
__write_only image2d_t output) {
int w = get_global_id(0);
int h = get_global_id(1);
+  const int width_blks = (width + 3) / 4;
   const int batch_idx = h / height;
   const int height_idx = h % height;
-  const int width_idx = (w % width) << 2;
-  const int channel_idx = w / width;
+  const int width_idx = (w % width_blks) << 2;
+  const int channel_idx = w / width_blks;
const int offset = input_offset + ((batch_idx * height + height_idx) * width + width_idx) * channels
+ channel_idx;
......
......@@ -4,7 +4,7 @@
__kernel void fully_connected(__read_only image2d_t input,
__read_only image2d_t weight,
#ifdef BIAS
__read_only image2d_t bias,
#endif
__write_only image2d_t output,
__private const int input_height,
......@@ -55,3 +55,76 @@ __kernel void fully_connected(__read_only image2d_t input,
#endif
WRITE_IMAGET(output, (int2)(out_blk_idx, batch_idx), result);
}
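The original fully_connected kernel above assigns one work-item per (output block, batch) pair, so a single work-item walks the whole input to produce four outputs. The fully_connected_width variant below spreads that walk across get_global_size(1) work-items along the input width and merges their partial sums through the __local intermediate_output buffer.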
// output = weight * input + bias
__kernel void fully_connected_width(__read_only image2d_t input,
__read_only image2d_t weight,
#ifdef BIAS
__read_only image2d_t bias,
#endif
__write_only image2d_t output,
__local float *intermediate_output,
__private const int input_height,
__private const int input_width,
__private const short in_chan_blks,
__private const float relux_max_limit) {
const int inter_out_idx = get_global_id(0);
const int width_blk_idx = get_global_id(1);
const int width_blk_count = get_global_size(1);
const int out_blk_idx = get_global_id(2);
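  // Work decomposition: gid0 (0..3) selects one channel within the output
  // block, gid1 strides across the input width, and gid2 selects the output
  // block, so each work-item owns one partial sum of one output value.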
const short in_outer_size = mul24(input_width, in_chan_blks);
const short weight_y = mad24(out_blk_idx, 4, inter_out_idx);
int2 input_coord, weight_coord;
DATA_TYPE4 in, w;
DATA_TYPE sum = 0.0;
input_coord = (int2)(0, 0);
for (short h_idx = 0; h_idx < input_height; ++h_idx) {
short weight_x_base = mul24(h_idx, in_outer_size);
for (short w_idx = (short)width_blk_idx; w_idx < input_width; w_idx += width_blk_count) {
short weight_x = mad24(w_idx, in_chan_blks, weight_x_base);
weight_coord = (int2)(weight_x, weight_y);
input_coord.x = w_idx;
#pragma unroll
for (short chan_idx = 0; chan_idx < in_chan_blks; ++chan_idx) {
in = READ_IMAGET(input, SAMPLER, input_coord);
w = READ_IMAGET(weight, SAMPLER, weight_coord);
sum += dot(in, w);
input_coord.x += input_width;
weight_coord.x += 1;
}
}
input_coord.y++;
}
const short inter_out_offset = mad24(get_local_id(1), 4, get_local_id(0));
const short local_width_blk_size = (short)get_local_size(1);
const short local_size = mul24((short)get_local_size(0),
local_width_blk_size);
short inter_idx = mad24((short)get_local_id(2), local_size, inter_out_offset);
  intermediate_output[inter_idx] = sum;
  // Every lane must publish its partial sum before lane 0 reduces them;
  // without a barrier the loads below read undefined local memory.
  barrier(CLK_LOCAL_MEM_FENCE);

  if (inter_out_offset == 0) {
#ifdef BIAS
DATA_TYPE4 result = READ_IMAGET(bias, SAMPLER, (int2)(out_blk_idx, 0));
#else
DATA_TYPE4 result = (DATA_TYPE4)(0, 0, 0, 0);
#endif
for (short i = 0; i < local_width_blk_size; ++i) {
result += vload4(0, intermediate_output+inter_idx);
inter_idx += 4;
}
#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID)
result = do_activation(result, relux_max_limit);
#endif
WRITE_IMAGET(output, (int2)(out_blk_idx, 0), result);
}
}
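Two notes on the kernel above. First, with lws = {4, 8, inter_local_blks} (see the host code below), work-item (l0, l1, l2) writes its partial sum to local slot l2*32 + l1*4 + l0, so the reducing lane (l0 == l1 == 0) of each output block walks eight contiguous float4 chunks with vload4. A tiny standalone check of that arithmetic (hypothetical helper, not MACE code):

    #include <assert.h>

    /* slot() mirrors inter_idx in fully_connected_width for lws = {4, 8, L}. */
    static int slot(int l0, int l1, int l2) { return l2 * (4 * 8) + l1 * 4 + l0; }

    int main(void) {
      /* Chunk i of output block l2 starts 4 * i floats after the block base,
         so the reduction loop can advance inter_idx by 4 per iteration. */
      for (int l2 = 0; l2 < 2; ++l2)
        for (int i = 0; i < 8; ++i)
          assert(slot(0, i, l2) == slot(0, 0, l2) + 4 * i);
      return 0;
    }

Second, the kernel always writes row 0 of the output image and never offsets its reads by batch, so it appears to assume batch == 1; every width-format test added below runs with batch 1.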
......@@ -3,31 +3,105 @@
//
#include "mace/kernels/fully_connected.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/kernels/opencl/helper.h"
#include "mace/utils/tuner.h"
namespace mace {
namespace kernels {
template <typename T>
-void FullyConnectedFunctor<DeviceType::OPENCL, T>::operator()(
-    const Tensor *input,
-    const Tensor *weight,
-    const Tensor *bias,
-    Tensor *output,
-    StatsFuture *future) {
-  std::vector<index_t> output_shape = {input->dim(0), 1, 1, weight->dim(0)};
-  std::vector<size_t> output_image_shape;
-  CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL, output_image_shape);
-  output->ResizeImage(output_shape, output_image_shape);
+void FCWXKernel(cl::Kernel *kernel,
+                const Tensor *input,
+                const Tensor *weight,
+                const Tensor *bias,
+                Tensor *output,
+                const ActivationType activation,
+                std::vector<uint32_t> &gws,
+                std::vector<uint32_t> &lws,
+                const float relux_max_limit,
+                StatsFuture *future) {
MACE_CHECK(input->dim(3) % 4 == 0)
    << "FC width kernel only supports inputs whose channel count is a multiple of 4.";
auto runtime = OpenCLRuntime::Global();
if (kernel->get() == nullptr) {
std::set<std::string> built_options;
auto dt = DataTypeToEnum<T>::value;
-    std::string kernel_name = MACE_OBFUSCATE_SYMBOL("fully_connected");
+    std::string kernel_name = MACE_OBFUSCATE_SYMBOL("fully_connected_width");
+    built_options.emplace("-Dfully_connected_width=" + kernel_name);
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
if (bias != nullptr) {
built_options.emplace("-DBIAS");
}
switch (activation) {
  case NOOP:
    break;
  case RELU:
    built_options.emplace("-DUSE_RELU");
    break;
  case RELUX:
    built_options.emplace("-DUSE_RELUX");
    break;
  case TANH:
    built_options.emplace("-DUSE_TANH");
    break;
  case SIGMOID:
    built_options.emplace("-DUSE_SIGMOID");
    break;
  default:
    LOG(FATAL) << "Unknown activation type: " << activation;
}
*kernel =
runtime->BuildKernel("fully_connected", kernel_name, built_options);
-  const index_t batch = output->dim(0);
-  const index_t output_size = output->dim(3);
-  const index_t output_blocks = RoundUpDiv4(output_size);
+    const index_t batch = output->dim(0);
+    const index_t output_size = output->dim(3);
+    const index_t output_blocks = RoundUpDiv4(output_size);
+
+    gws = {4, 8, static_cast<uint32_t>(output_blocks)};
+
+    const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(*kernel);
+    const uint32_t inter_local_blks = kwg_size / (gws[0] * gws[1]);
+    lws = {gws[0], gws[1], inter_local_blks};
-  if (kernel_.get() == nullptr) {
uint32_t idx = 0;
kernel->setArg(idx++, *(input->opencl_image()));
kernel->setArg(idx++, *(weight->opencl_image()));
if (bias != nullptr) {
kernel->setArg(idx++, *(bias->opencl_image()));
}
kernel->setArg(idx++, *(output->opencl_image()));
kernel->setArg(idx++, (lws[0] * lws[1] * lws[2] * sizeof(float)), nullptr);
kernel->setArg(idx++, static_cast<int>(input->dim(1)));
kernel->setArg(idx++, static_cast<int>(input->dim(2)));
kernel->setArg(idx++, static_cast<short>(RoundUpDiv4(input->dim(3))));
kernel->setArg(idx++, relux_max_limit);
}
cl::Event event;
cl_int error = runtime->command_queue().enqueueNDRangeKernel(
*kernel, cl::NullRange, cl::NDRange(gws[0], gws[1], gws[2]),
cl::NDRange(lws[0], lws[1], lws[2]), nullptr, &event);
MACE_CHECK(error == CL_SUCCESS) << "Error code: " << error;
if (future != nullptr) {
future->wait_fn = [runtime, event](CallStats *stats) {
event.wait();
if (stats != nullptr) {
runtime->GetCallStats(event, stats);
}
};
}
}
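The setArg(idx++, lws[0] * lws[1] * lws[2] * sizeof(float), nullptr) call above is the standard OpenCL idiom for a __local kernel parameter: passing a byte size with a null pointer makes the runtime reserve that much work-group-local memory for the argument, here one float slot per work-item of the group.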
template <typename T>
void FCWTXKernel(cl::Kernel *kernel,
const Tensor *input,
const Tensor *weight,
const Tensor *bias,
Tensor *output,
const ActivationType activation,
std::vector<uint32_t> &gws,
std::vector<uint32_t> &lws,
const float relux_max_limit,
StatsFuture *future) {
if (kernel->get() == nullptr) {
auto runtime = OpenCLRuntime::Global();
std::set<std::string> built_options;
auto dt = DataTypeToEnum<T>::value;
......@@ -38,7 +112,7 @@ void FullyConnectedFunctor<DeviceType::OPENCL, T>::operator()(
if (bias != nullptr) {
built_options.emplace("-DBIAS");
}
-    switch (activation_) {
+    switch (activation) {
case NOOP:
break;
case RELU:
......@@ -54,33 +128,61 @@ void FullyConnectedFunctor<DeviceType::OPENCL, T>::operator()(
built_options.emplace("-DUSE_SIGMOID");
break;
default:
LOG(FATAL) << "Unknown activation type: " << activation_;
LOG(FATAL) << "Unknown activation type: " << activation;
}
-    kernel_ =
+    *kernel =
         runtime->BuildKernel("fully_connected", kernel_name, built_options);
uint32_t idx = 0;
-    kernel_.setArg(idx++, *(input->opencl_image()));
-    kernel_.setArg(idx++, *(weight->opencl_image()));
+    kernel->setArg(idx++, *(input->opencl_image()));
+    kernel->setArg(idx++, *(weight->opencl_image()));
if (bias != nullptr) {
-      kernel_.setArg(idx++, *(bias->opencl_image()));
+      kernel->setArg(idx++, *(bias->opencl_image()));
}
-    kernel_.setArg(idx++, *(output->opencl_image()));
-    kernel_.setArg(idx++, static_cast<int>(input->dim(1)));
-    kernel_.setArg(idx++, static_cast<int>(input->dim(2)));
-    kernel_.setArg(idx++, static_cast<int>(input->dim(3)));
+    kernel->setArg(idx++, *(output->opencl_image()));
+    kernel->setArg(idx++, static_cast<int>(input->dim(1)));
+    kernel->setArg(idx++, static_cast<int>(input->dim(2)));
+    kernel->setArg(idx++, static_cast<int>(input->dim(3)));
// FIXME handle flexible data type: half not supported
-    kernel_.setArg(idx++, relux_max_limit_);
+    kernel->setArg(idx++, relux_max_limit);
const index_t batch = output->dim(0);
const index_t output_size = output->dim(3);
const index_t output_blocks = RoundUpDiv4(output_size);
gws = {
static_cast<uint32_t>(batch), static_cast<uint32_t>(output_blocks),
};
lws = {16, 64, 1};
}
-  const uint32_t gws[2] = {
-      static_cast<uint32_t>(batch), static_cast<uint32_t>(output_blocks),
-  };
-  const std::vector<uint32_t> lws = {16, 64, 1};
std::stringstream ss;
ss << "fc_opencl_kernel_" << output->dim(0) << "_" << output->dim(1) << "_"
<< output->dim(2) << "_" << output->dim(3);
-  TuningOrRun2DKernel(kernel_, ss.str(), gws, lws, future);
+  TuningOrRun2DKernel(*kernel, ss.str(), gws.data(), lws, future);
}
template <typename T>
void FullyConnectedFunctor<DeviceType::OPENCL, T>::operator()(
const Tensor *input,
const Tensor *weight,
const Tensor *bias,
Tensor *output,
StatsFuture *future) {
std::vector<index_t> output_shape = {input->dim(0), 1, 1, weight->dim(0)};
std::vector<size_t> output_image_shape;
CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL, output_image_shape);
output->ResizeImage(output_shape, output_image_shape);
if (weight_type_ == BufferType::WEIGHT_HEIGHT) {
FCWTXKernel<T>(&kernel_, input, weight, bias, output,
activation_, gws_, lws_, relux_max_limit_, future);
} else {
FCWXKernel<T>(&kernel_, input, weight, bias, output,
activation_, gws_, lws_, relux_max_limit_, future);
}
};
template struct FullyConnectedFunctor<DeviceType::OPENCL, float>;
......
......@@ -84,6 +84,15 @@ void CalWeightHeightImageShape(const std::vector<index_t> &shape, /* HW */
image_shape[1] = RoundUpDiv4(shape[0]);
}
// [(W + 3) / 4, H]
void CalWeightWidthImageShape(const std::vector<index_t> &shape, /* HW */
std::vector<size_t> &image_shape) {
MACE_CHECK(shape.size() == 2);
image_shape.resize(2);
image_shape[0] = RoundUpDiv4(shape[1]);
image_shape[1] = shape[0];
}
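For example, under WEIGHT_WIDTH an {H, W} weight of {1000, 2048} maps to a 512 x 1000 image, packing four consecutive width elements into each pixel. The same arithmetic as a standalone sketch (plain C, not the MACE helper):

    #include <stdio.h>

    static size_t round_up_div4(size_t x) { return (x + 3) / 4; }

    int main(void) {
      size_t shape[2] = {1000, 2048};            /* {H, W} weight */
      size_t image_w = round_up_div4(shape[1]);  /* (2048 + 3) / 4 = 512 */
      size_t image_h = shape[0];                 /* 1000 */
      printf("image: %zu x %zu\n", image_w, image_h);
      return 0;
    }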
void CalImage2DShape(const std::vector<index_t> &shape, /* NHWC */
const BufferType type,
std::vector<size_t> &image_shape) {
......@@ -112,6 +121,9 @@ void CalImage2DShape(const std::vector<index_t> &shape, /* NHWC */
case WEIGHT_HEIGHT:
CalWeightHeightImageShape(shape, image_shape);
break;
case WEIGHT_WIDTH:
CalWeightWidthImageShape(shape, image_shape);
break;
default:
LOG(FATAL) << "Mace not supported yet.";
}
......
......@@ -25,6 +25,7 @@ enum BufferType {
WINOGRAD_FILTER = 5,
DW_CONV2D_FILTER = 6,
WEIGHT_HEIGHT = 7,
WEIGHT_WIDTH = 8,
};
void CalImage2DShape(const std::vector<index_t> &shape, /* NHWC */
......
......@@ -15,7 +15,11 @@ class FullyConnectedOp : public Operator<D, T> {
public:
FullyConnectedOp(const OperatorDef &operator_def, Workspace *ws)
: Operator<D, T>(operator_def, ws),
functor_(kernels::StringToActivationType(
functor_(static_cast<kernels::BufferType>(
OperatorBase::GetSingleArgument<int>(
"weight_type", static_cast<int>(
kernels::WEIGHT_WIDTH))),
kernels::StringToActivationType(
OperatorBase::GetSingleArgument<std::string>("activation",
"NOOP")),
OperatorBase::GetSingleArgument<float>("max_limit", 0.0f)) {}
......
......@@ -22,10 +22,18 @@ static void FCBenchmark(
net.AddRandomInput<D, float>("Bias", {out_channel});
if (D == DeviceType::OPENCL) {
const int width_size = height * width * channel;
kernels::BufferType weight_type = kernels::BufferType::WEIGHT_HEIGHT;
// if (width_size > 16384) {
BufferToImage<D, T>(net, "Weight", "WeightImage",
kernels::BufferType::WEIGHT_WIDTH);
weight_type = kernels::BufferType::WEIGHT_WIDTH;
// } else {
// BufferToImage<D, T>(net, "Weight", "WeightImage",
// kernels::BufferType::WEIGHT_HEIGHT);
// }
BufferToImage<D, T>(net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL);
BufferToImage<D, T>(net, "Weight", "WeightImage",
kernels::BufferType::WEIGHT_HEIGHT);
BufferToImage<D, T>(net, "Bias", "BiasImage",
kernels::BufferType::ARGUMENT);
......@@ -34,6 +42,7 @@ static void FCBenchmark(
.Input("WeightImage")
.Input("BiasImage")
.Output("OutputImage")
.AddIntArg("weight_type", static_cast<int>(weight_type))
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
} else {
......@@ -78,4 +87,6 @@ static void FCBenchmark(
BM_FC(1, 16, 16, 32, 32);
BM_FC(1, 8, 8, 32, 1000);
BM_FC(1, 2, 2, 512, 2);
BM_FC(1, 7, 7, 512, 4096);
} // namespace mace
......@@ -39,6 +39,7 @@ void Simple(const std::vector<index_t> &input_shape,
.Input("WeightImage")
.Input("BiasImage")
.Output("OutputImage")
.AddIntArg("weight_type", kernels::BufferType::WEIGHT_HEIGHT)
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
......@@ -147,6 +148,7 @@ void Complex(const index_t batch,
.Input("WeightImage")
.Input("BiasImage")
.Output("OutputImage")
.AddIntArg("weight_type", kernels::BufferType::WEIGHT_HEIGHT)
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
......@@ -183,4 +185,75 @@ TEST_F(FullyConnectedOpTest, OPENCLHalfUnAlignedWithBatch) {
Complex<half>(16, 13, 12, 31, 113);
Complex<half>(31, 21, 11, 23, 103);
}
template <typename T>
void TestWeightWidthFormat(const index_t batch,
const index_t height,
const index_t width,
const index_t channels,
const index_t out_channel) {
srand(time(NULL));
// Construct graph
OpsTestNet net;
OpDefBuilder("FC", "FullyConnectedTest")
.Input("Input")
.Input("Weight")
.Input("Bias")
.Output("Output")
.Finalize(net.NewOperatorDef());
// Add input data
net.AddRandomInput<DeviceType::OPENCL, float>(
"Input", {batch, height, width, channels});
net.AddRandomInput<DeviceType::OPENCL, float>(
"Weight", {out_channel, height * width * channels});
net.AddRandomInput<DeviceType::OPENCL, float>("Bias", {out_channel});
// run cpu
net.RunOp();
// Check
Tensor expected;
expected.Copy(*net.GetOutput("Output"));
// Run on opencl
BufferToImage<DeviceType::OPENCL, T>(net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL);
BufferToImage<DeviceType::OPENCL, T>(net, "Weight", "WeightImage",
kernels::BufferType::WEIGHT_WIDTH);
BufferToImage<DeviceType::OPENCL, float>(net, "Bias", "BiasImage",
kernels::BufferType::ARGUMENT);
OpDefBuilder("FC", "FullyConnectedTest")
.Input("InputImage")
.Input("WeightImage")
.Input("BiasImage")
.Output("OutputImage")
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
// Run on opencl
net.RunOp(DeviceType::OPENCL);
ImageToBuffer<DeviceType::OPENCL, float>(net, "OutputImage", "OPENCLOutput",
kernels::BufferType::IN_OUT_CHANNEL);
if (DataTypeToEnum<T>::value == DataType::DT_HALF) {
ExpectTensorNear<float>(expected, *net.GetOutput("OPENCLOutput"), 1);
} else {
ExpectTensorNear<float>(expected, *net.GetOutput("OPENCLOutput"), 1e-2);
}
}
TEST_F(FullyConnectedOpTest, OPENCLWidthFormatAligned) {
TestWeightWidthFormat<float>(1, 7, 7, 32, 16);
TestWeightWidthFormat<float>(1, 7, 7, 512, 128);
TestWeightWidthFormat<float>(1, 1, 1, 2048, 1024);
}
TEST_F(FullyConnectedOpTest, OPENCLHalfWidthFormatAligned) {
TestWeightWidthFormat<float>(1, 2, 2, 512, 2);
TestWeightWidthFormat<half>(1, 11, 11, 32, 16);
TestWeightWidthFormat<half>(1, 16, 32, 32, 32);
}
}