Commit e19514a9 authored by liuqi

Support general conv for OpenCL.

Parent 5d5d06c2
......@@ -139,6 +139,7 @@ const std::map<std::string, std::string>
OpenCLRuntime::program_map_ = {
{"addn", "addn.cl"},
{"batch_norm", "batch_norm.cl"},
{"conv_2d", "conv_2d.cl"},
{"conv_2d_1x1", "conv_2d_1x1.cl"},
{"conv_2d_3x3", "conv_2d_3x3.cl"},
{"depthwise_conv_3x3", "depthwise_conv_3x3.cl"},
......
#include <common.h>
__kernel void conv_2d(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] */
                      __read_only image2d_t filter, /* cout%4 * cin * kw * kh, cout/4 */
#ifdef BIAS
                      __read_only image2d_t bias, /* cout%4 * cout/4 */
#endif
                      __write_only image2d_t output,
                      __private const int in_height,
                      __private const int in_width,
                      __private const int in_ch_blks,
                      __private const int out_height,
                      __private const int out_width,
                      __private const int filter_height,
                      __private const int filter_width,
                      __private const int padding_top,
                      __private const int padding_left) {
  const int out_ch_blk = get_global_id(0);
  const int out_w_blk = get_global_id(1);
  const int out_w_blks = get_global_size(1);
  const int out_hb = get_global_id(2);
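  // Each work-item accumulates one block of 4 output channels (out_ch_blk) at up
  // to four output width positions spaced out_w_blks apart, for one (batch, row).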
  const int rounded_in_ch = in_ch_blks * 4;
  const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
#ifdef BIAS
  DATA_TYPE4 out0 =
      READ_IMAGET(bias, sampler, (int2)(out_ch_blk, 0));
  DATA_TYPE4 out1 = out0;
  DATA_TYPE4 out2 = out0;
  DATA_TYPE4 out3 = out0;
#else
  DATA_TYPE4 out0 = 0;
  DATA_TYPE4 out1 = 0;
  DATA_TYPE4 out2 = 0;
  DATA_TYPE4 out3 = 0;
#endif
#if STRIDE == 1
  int in_width0 = out_w_blk - padding_left;
  int in_width1 = in_width0 + out_w_blks;
  int in_width2 = in_width1 + out_w_blks;
  int in_width3 = in_width2 + out_w_blks;
  const int height_idx = (out_hb % out_height) - padding_top;
#else
  int in_width0 = out_w_blk * 2 - padding_left;
  int in_width1 = (out_w_blk + out_w_blks) * 2 - padding_left;
  int in_width2 = (out_w_blk + 2 * out_w_blks) * 2 - padding_left;
  int in_width3 = (out_w_blk + 3 * out_w_blks) * 2 - padding_left;
  const int height_idx = (out_hb % out_height) * 2 - padding_top;
#endif
  const int batch_idx = (out_hb / out_height) * in_height;
  DATA_TYPE4 in0, in1, in2, in3;
  DATA_TYPE4 weights0, weights1, weights2, weights3;
  int in_idx, in_width_idx;
  // Unrolling this loop hurts performance
  for (short in_ch_blk = 0; in_ch_blk < in_ch_blks; ++in_ch_blk) {
    for (short hb_idx = 0; hb_idx < filter_height; ++hb_idx) {
      for (short width_idx = 0; width_idx < filter_width; ++width_idx) {
        in_idx = in_ch_blk * in_width;
        int in_hb_value = height_idx + hb_idx;
        in_hb_value = select(in_hb_value + batch_idx,
                             -1,
                             (in_hb_value < 0 || in_hb_value >= in_height));
        int in_width_value;
#define READ_INPUT(i) \
  in_width_value = in_width##i + width_idx; \
  in_width_value = select(in_idx + in_width_value, \
                          -1, \
                          (in_width_value < 0 || in_width_value >= in_width)); \
  in##i = READ_IMAGET(input, sampler, (int2)(in_width_value, in_hb_value));
        READ_INPUT(0);
        READ_INPUT(1);
        READ_INPUT(2);
        READ_INPUT(3);
#undef READ_INPUT
        int filter_idx = (in_ch_blk << 2) + (hb_idx * filter_width + width_idx) * rounded_in_ch;
        weights0 = READ_IMAGET(filter, sampler, (int2)(filter_idx + 0, out_ch_blk));
        weights1 = READ_IMAGET(filter, sampler, (int2)(filter_idx + 1, out_ch_blk));
        weights2 = READ_IMAGET(filter, sampler, (int2)(filter_idx + 2, out_ch_blk));
        weights3 = READ_IMAGET(filter, sampler, (int2)(filter_idx + 3, out_ch_blk));
        // Will prefetching into L2 improve performance? How to prefetch image data?
        // Interleaving load and mul does not improve performance as expected
        out0 += in0.x * weights0;
        out0 += in0.y * weights1;
        out0 += in0.z * weights2;
        out0 += in0.w * weights3;
        out1 += in1.x * weights0;
        out1 += in1.y * weights1;
        out1 += in1.z * weights2;
        out1 += in1.w * weights3;
        out2 += in2.x * weights0;
        out2 += in2.y * weights1;
        out2 += in2.z * weights2;
        out2 += in2.w * weights3;
        out3 += in3.x * weights0;
        out3 += in3.y * weights1;
        out3 += in3.z * weights2;
        out3 += in3.w * weights3;
      }
    }
  }
#ifdef FUSED_RELU
  // TODO: support relux (relu with an upper bound)
  out0 = fmax(out0, 0);
  out1 = fmax(out1, 0);
  out2 = fmax(out2, 0);
  out3 = fmax(out3, 0);
#endif
  const int out_x_base = out_ch_blk * out_width;
  int w = out_w_blk;
  WRITE_IMAGET(output,
               (int2)(out_x_base + w, out_hb),
               out0);
  w += out_w_blks;
  if (w >= out_width) return;
  WRITE_IMAGET(output,
               (int2)(out_x_base + w, out_hb),
               out1);
  w += out_w_blks;
  if (w >= out_width) return;
  WRITE_IMAGET(output,
               (int2)(out_x_base + w, out_hb),
               out2);
  w += out_w_blks;
  if (w >= out_width) return;
  WRITE_IMAGET(output,
               (int2)(out_x_base + w, out_hb),
               out3);
}
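For reference, the image layouts noted in the kernel signature translate into a simple coordinate mapping on the host side. The sketch below is illustrative only (the struct and helper names are not MACE APIs); it shows where an NHWC input element lives in the 2D input image that conv_2d samples:

#include <cstdint>

// Illustrative helper, not part of the commit.
struct ImageCoord { int64_t x; int64_t y; };

// Maps an NHWC input element (b, h, w, c) to the 2D image coordinate read by
// conv_2d above: x = (c / 4) * width + w, y = b * height + h, with the
// remaining c % 4 channels packed into the RGBA components of the pixel.
inline ImageCoord NhwcToInputImageCoord(int64_t b, int64_t h, int64_t w, int64_t c,
                                        int64_t height, int64_t width) {
  return ImageCoord{(c / 4) * width + w, b * height + h};
}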
......@@ -28,6 +28,11 @@ extern void Conv2dOpenclK3x3S2(const Tensor *input, const Tensor *filter,
const int *padding, const DataType dt,
Tensor *output);
extern void Conv2dOpencl(const Tensor *input, const Tensor *filter,
const Tensor *bias, const bool fused_relu,
const uint32_t stride, const int *padding,
const DataType dt, Tensor *output);
template<typename T>
void Conv2dFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
const Tensor *filter,
......@@ -47,17 +52,13 @@ void Conv2dFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
index_t kernel_h = filter->dim(0);
index_t kernel_w = filter->dim(1);
if (kernel_h != kernel_w || kernel_h > 5 || strides_[0] != strides_[1] ||
strides_[0] > 2 || dilations_[0] != 1 || dilations_[1] != 1 ||
selector[kernel_h - 1][strides_[0] - 1] == nullptr) {
if (!input->is_image() || strides_[0] != strides_[1] ||
strides_[0] > 2 || dilations_[0] != 1 || dilations_[1] != 1) {
LOG(WARNING) << "OpenCL conv2d kernel with "
<< "filter" << kernel_h << "x" << kernel_w << ","
<< " stride " << strides_[0] << "x" << strides_[1]
<< " is not implemented yet, using slow version";
// TODO(heliangliang) The CPU/NEON kernel should map the buffer
Conv2dFunctor<DeviceType::CPU, T>(strides_, paddings_, dilations_)(
input, filter, bias, output);
return;
MACE_NOT_IMPLEMENTED;
}
std::vector<index_t> output_shape(4);
......@@ -66,16 +67,18 @@ void Conv2dFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
input->shape().data(), filter->shape().data(), dilations_,
strides_, paddings_, output_shape.data(), paddings.data());
if (input->is_image()) {
std::vector<size_t> output_image_shape;
CalImage2DShape(output_shape, BufferType::IN_OUT, output_image_shape);
output->ResizeImage(output_shape, output_image_shape);
} else {
output->Resize(output_shape);
}
if (kernel_h == kernel_w && kernel_h <= 5 &&
selector[kernel_h - 1][strides_[0] - 1] != nullptr) {
auto conv2d_func = selector[kernel_h - 1][strides_[0] - 1];
conv2d_func(input, filter, bias, false, paddings.data(), DataTypeToEnum<T>::value, output);
} else {
Conv2dOpencl(input, filter, bias, false, strides_[0], paddings.data(), DataTypeToEnum<T>::value, output);
}
}
template
......
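The dispatch in the functor above keeps the specialized kernels when the selector table has an entry for the requested kernel size and stride, and otherwise routes to the new general Conv2dOpencl path. A minimal sketch of that selection rule (the function-pointer type and helper name are simplified for illustration; the real signatures match the extern declarations above):

#include <cstdint>

// Simplified for illustration; the real kernels take the tensors, fused_relu
// flag, padding, data type and output tensor.
using Conv2dKernelFn = void (*)();

// selector[kernel_size - 1][stride - 1] holds a specialized kernel or nullptr.
// A nullptr entry (or an unsupported size/stride) routes to the general kernel.
inline bool UseSpecializedKernel(Conv2dKernelFn selector[5][2],
                                 int64_t kernel_h, int64_t kernel_w, int64_t stride) {
  return kernel_h == kernel_w && kernel_h <= 5 && stride <= 2 &&
         selector[kernel_h - 1][stride - 1] != nullptr;
}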
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/core/common.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/kernels/conv_2d.h"
#include "mace/kernels/opencl/helper.h"
#include "mace/utils/utils.h"
namespace mace {
namespace kernels {
void Conv2dOpencl(const Tensor *input, const Tensor *filter,
                  const Tensor *bias, const bool fused_relu,
                  const uint32_t stride, const int *padding,
                  const DataType dt, Tensor *output) {
  const index_t batch = output->dim(0);
  const index_t height = output->dim(1);
  const index_t width = output->dim(2);
  const index_t channels = output->dim(3);
  const index_t input_channels = input->dim(3);
  const index_t channel_blocks = RoundUpDiv4(channels);
  const index_t input_channel_blocks = RoundUpDiv4(input_channels);
  const index_t width_blocks = RoundUpDiv4(width);
  std::set<std::string> built_options;
  built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
  built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
  built_options.emplace(bias != nullptr ? "-DBIAS" : "");
  built_options.emplace("-DSTRIDE=" + ToString(stride));
  if (fused_relu) {
    built_options.emplace("-DFUSED_RELU");
  }
  auto runtime = OpenCLRuntime::Get();
  auto program = runtime->program();
  auto conv_2d_kernel = runtime->BuildKernel("conv_2d", "conv_2d", built_options);
  const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(conv_2d_kernel);
  uint32_t idx = 0;
  conv_2d_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(input->buffer())));
  conv_2d_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(filter->buffer())));
  if (bias != nullptr) {
    conv_2d_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(bias->buffer())));
  }
  conv_2d_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(output->buffer())));
  conv_2d_kernel.setArg(idx++, static_cast<int>(input->dim(1)));
  conv_2d_kernel.setArg(idx++, static_cast<int>(input->dim(2)));
  conv_2d_kernel.setArg(idx++, static_cast<int>(input_channel_blocks));
  conv_2d_kernel.setArg(idx++, static_cast<int>(height));
  conv_2d_kernel.setArg(idx++, static_cast<int>(width));
  conv_2d_kernel.setArg(idx++, static_cast<int>(filter->dim(0)));
  conv_2d_kernel.setArg(idx++, static_cast<int>(filter->dim(1)));
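  // `padding` holds the total padding of each dimension; the kernel expects the
  // top / left half of it.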
  conv_2d_kernel.setArg(idx++, padding[0] / 2);
  conv_2d_kernel.setArg(idx++, padding[1] / 2);
  auto command_queue = runtime->command_queue();
  cl_int error;
  error = command_queue.enqueueNDRangeKernel(
      conv_2d_kernel, cl::NullRange,
      cl::NDRange(static_cast<uint32_t>(channel_blocks), static_cast<uint32_t>(width_blocks),
                  static_cast<uint32_t>(height * batch)),
      cl::NDRange(16, 16, 4),
      NULL, OpenCLRuntime::Get()->GetDefaultEvent());
  MACE_CHECK(error == CL_SUCCESS, error);
}
} // namespace kernels
} // namespace mace
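The NDRange above mirrors the blocking scheme inside the kernel: one work-item per 4-channel output block, per group of four output columns, per output row of each batch image. A minimal sketch of that calculation, assuming RoundUpDiv4 is the usual ceil-divide-by-4 helper (the function name below is illustrative):

#include <cstdint>

// Assumed behaviour of RoundUpDiv4 (ceil division by 4).
inline int64_t RoundUpDiv4(int64_t v) { return (v + 3) / 4; }

// Global work size used by Conv2dOpencl for an output of shape
// [batch, height, width, channels]:
//   gws[0] = number of 4-channel output blocks
//   gws[1] = number of 4-column groups along the output width
//   gws[2] = output rows across all batch images
inline void Conv2dGlobalWorkSize(int64_t batch, int64_t height, int64_t width,
                                 int64_t channels, uint32_t gws[3]) {
  gws[0] = static_cast<uint32_t>(RoundUpDiv4(channels));
  gws[1] = static_cast<uint32_t>(RoundUpDiv4(width));
  gws[2] = static_cast<uint32_t>(height * batch);
}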
......@@ -118,14 +118,13 @@ void TestDiffTypeBidirectionTransform(const int type, const std::vector<index_t>
.Input("B2IOutput")
.Output("I2BOutput")
.AddIntArg("buffer_type", type)
.AddIntArg("T", DataTypeToEnum<T>::value)
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
// Check
ExpectTensorNear<float, T>(*net.GetOutput("Input"), *net.GetOutput("I2BOutput"), 1e-2);
ExpectTensorNear<float>(*net.GetOutput("Input"), *net.GetOutput("I2BOutput"), 1e-3);
}
TEST(BufferToImageTest, ArgFloatToHalfSmall) {
......
......@@ -558,18 +558,20 @@ TEST_F(Conv2dOpTest, OPENCLUnalignedConvNxNS12) {
}
template<DeviceType D>
static void TestHalfComplexConvNxNS12(const std::vector<index_t> &shape) {
static void TestHalfComplexConvNxNS12(const std::vector<index_t> &input_shape,
const std::vector<index_t> &filter_shape) {
testing::internal::LogToStderr();
auto func = [&](int kernel_h, int kernel_w, int stride_h, int stride_w,
Padding type) {
srand(time(NULL));
auto func = [&](int stride_h, int stride_w, Padding padding) {
// generate random input
index_t batch = 3 + (rand() % 10);
index_t height = shape[0];
index_t width = shape[1];
index_t input_channels = shape[2] + (rand() % 10);
index_t output_channels = shape[3] + (rand() % 10);
index_t height = input_shape[0];
index_t width = input_shape[1];
index_t kernel_h = filter_shape[0];
index_t kernel_w = filter_shape[1];
index_t input_channels = filter_shape[2] + (rand() % 10);
index_t output_channels = filter_shape[3] + (rand() % 10);
// Construct graph
OpsTestNet net;
OpDefBuilder("Conv2D", "Conv2dTest")
......@@ -578,7 +580,7 @@ static void TestHalfComplexConvNxNS12(const std::vector<index_t> &shape) {
.Input("Bias")
.Output("Output")
.AddIntsArg("strides", {stride_h, stride_w})
.AddIntArg("padding", type)
.AddIntArg("padding", padding)
.AddIntsArg("dilations", {1, 1})
.Finalize(net.NewOperatorDef());
......@@ -611,7 +613,7 @@ static void TestHalfComplexConvNxNS12(const std::vector<index_t> &shape) {
.Input("BiasImage")
.Output("OutputImage")
.AddIntsArg("strides", {stride_h, stride_w})
.AddIntArg("padding", type)
.AddIntArg("padding", padding)
.AddIntsArg("dilations", {1, 1})
.AddIntArg("T", static_cast<int>(DataType::DT_HALF))
.Finalize(net.NewOperatorDef());
......@@ -620,20 +622,46 @@ static void TestHalfComplexConvNxNS12(const std::vector<index_t> &shape) {
ImageToBuffer<D, float>(net, "OutputImage", "OPENCLOutput", kernels::BufferType::IN_OUT);
ExpectTensorNear<float>(expected, *net.GetOutput("OPENCLOutput"), 0.2);
ExpectTensorNear<float>(expected, *net.GetOutput("OPENCLOutput"), 0.5);
};
for (int kernel_size : {1, 3}) {
for (int stride : {1, 2}) {
func(kernel_size, kernel_size, stride, stride, VALID);
}
func(stride, stride, VALID);
func(stride, stride, SAME);
}
}
TEST_F(Conv2dOpTest, OPENCLHalfAlignedConvNxNS12) {
TestHalfComplexConvNxNS12<DeviceType::OPENCL>({32, 32, 32, 64});
TEST_F(Conv2dOpTest, OPENCLHalfAlignedConv1x1S12) {
TestHalfComplexConvNxNS12<DeviceType::OPENCL>({32, 32},
{1, 1, 32, 64});
}
TEST_F(Conv2dOpTest, OPENCLHalfAlignedConv3x3S12) {
TestHalfComplexConvNxNS12<DeviceType::OPENCL>({32, 32},
{3, 3, 32, 64});
}
TEST_F(Conv2dOpTest, OPENCLHalfAlignedConv15x1S12) {
TestHalfComplexConvNxNS12<DeviceType::OPENCL>({32, 32},
{15, 1, 256, 2});
}
TEST_F(Conv2dOpTest, OPENCLHalfAlignedConv1x15S12) {
TestHalfComplexConvNxNS12<DeviceType::OPENCL>({32, 32},
{1, 15, 256, 2});
}
TEST_F(Conv2dOpTest, OPENCLHalfAlignedConv7x7S12) {
TestHalfComplexConvNxNS12<DeviceType::OPENCL>({32, 32},
{7, 7, 3, 64});
}
TEST_F(Conv2dOpTest, OPENCLHalfUnalignedConv1x1S12) {
TestHalfComplexConvNxNS12<DeviceType::OPENCL>({107, 113},
{1, 1, 5, 7});
}
TEST_F(Conv2dOpTest, OPENCLHalfUnalignedConvNxNS12) {
TestHalfComplexConvNxNS12<DeviceType::OPENCL>({107, 113, 5, 7});
TEST_F(Conv2dOpTest, OPENCLHalfUnalignedConv3x3S12) {
TestHalfComplexConvNxNS12<DeviceType::OPENCL>({107, 113},
{3, 3, 5, 7});
}