未验证 提交 475187cc 编写于 作者: Y Yuan Shuai 提交者: GitHub

[LITE][OPENCL] Change fp32 fc to fp16's. test=develop (#3173)

* [LITE][OPENCL] Change fp32 fc to fp16's. test=develop

* fix act in conv3x3opt opencl kernel. test=develop
上级 27e90303
......@@ -255,7 +255,7 @@ void fc_gemv_1x4(__global const CL_DTYPE* a,
const int col = get_global_id(0) << 2; // gws[0]: [0, N >> 2) height of B == N
if (col + 3 < N) {
CL_DTYPE4 c0 = 0.0f;
half4 c0 = 0.0f;
if (bias) {
c0.x = bias[col];
c0.y = bias[col+1];
......@@ -266,11 +266,12 @@ void fc_gemv_1x4(__global const CL_DTYPE* a,
// main loop of K
int p = 0;
for (; p < K - 3; p += 4) {
CL_DTYPE4 a0 = vload4(0, a + p);
CL_DTYPE4 b0 = vload4(0, b + p * N + col);
CL_DTYPE4 b1 = vload4(0, b + (p+1) * N + col);
CL_DTYPE4 b2 = vload4(0, b + (p+2) * N + col);
CL_DTYPE4 b3 = vload4(0, b + (p+3) * N + col);
half4 a0 = convert_half4(vload4(0, a + p));
half4 b0 = convert_half4(vload4(0, b + p * N + col));
half4 b1 = convert_half4(vload4(0, b + (p+1) * N + col));
half4 b2 = convert_half4(vload4(0, b + (p+2) * N + col));
half4 b3 = convert_half4(vload4(0, b + (p+3) * N + col));
c0 += a0.x * b0;
c0 += a0.y * b1;
......@@ -279,21 +280,21 @@ void fc_gemv_1x4(__global const CL_DTYPE* a,
}
// compute left K
CL_DTYPE4 b2 = 0.0f,
b1 = 0.0f,
b0 = 0.0f,
a0 = 0.0f;
half4 b2 = 0.0f,
b1 = 0.0f,
b0 = 0.0f,
a0 = 0.0f;
switch (K - p) {
case 3: {
b2 = vload4(0, b + (p+2) * N + col);
b2 = convert_half4(vload4(0, b + (p+2) * N + col));
a0.z = a[p + 2];
}
case 2: {
b1 = vload4(0, b + (p+1) * N + col);
b1 = convert_half4(vload4(0, b + (p+1) * N + col));
a0.y = a[p + 1];
}
case 1: {
b0 = vload4(0, b + (p) * N + col);
b0 = convert_half4(vload4(0, b + (p) * N + col));
a0.x = a[p];
}
}
......@@ -317,7 +318,7 @@ void fc_gemv_1x4(__global const CL_DTYPE* a,
}
#else
if (col % 4 == 0) {
vstore4(c0, 0, c + col);
vstore4(convert_float4(c0), 0, c + col);
} else {
switch (col % 4) {
case 3:
......@@ -332,10 +333,10 @@ void fc_gemv_1x4(__global const CL_DTYPE* a,
} else {
const int left_col = N - col;
for (int col_offset = 0; col_offset < left_col; ++col_offset) {
CL_DTYPE c0 = bias ? bias[col] : 0;
half c0 = bias ? bias[col] : 0;
for (int p = 0; p < K; ++p) {
CL_DTYPE b0 = *(b + p * N + col + col_offset);
CL_DTYPE a0 = *(a + p);
half b0 = *(b + p * N + col + col_offset);
half a0 = *(a + p);
c0 += a0 * b0;
}
#ifdef RELU
......@@ -362,18 +363,18 @@ void fc_gemm_4x4(__global const CL_DTYPE* a,
const int col = get_global_id(1) << 2; // id: [0, N>>2) width of out == N
if (row+3 < M && col+3 < N) {
CL_DTYPE bias0 = bias ? bias[col] : 0,
bias1 = bias ? bias[col+1] : 0,
bias2 = bias ? bias[col+2] : 0,
bias3 = bias ? bias[col+3] : 0;
CL_COMPUTE_DTYPE bias0 = bias ? bias[col] : 0,
bias1 = bias ? bias[col+1] : 0,
bias2 = bias ? bias[col+2] : 0,
bias3 = bias ? bias[col+3] : 0;
CL_DTYPE c00 = bias0, c01 = bias1, c02 = bias2, c03 = bias3,
c10 = bias0, c11 = bias1, c12 = bias2, c13 = bias3,
c20 = bias0, c21 = bias1, c22 = bias2, c23 = bias3,
c30 = bias0, c31 = bias1, c32 = bias2, c33 = bias3;
CL_COMPUTE_DTYPE c00 = bias0, c01 = bias1, c02 = bias2, c03 = bias3,
c10 = bias0, c11 = bias1, c12 = bias2, c13 = bias3,
c20 = bias0, c21 = bias1, c22 = bias2, c23 = bias3,
c30 = bias0, c31 = bias1, c32 = bias2, c33 = bias3;
for (int p = 0; p < K; ++p) {
CL_DTYPE
CL_COMPUTE_DTYPE
a00 = *(a + row * K + p),
a10 = *(a + (row + 1) * K + p),
a20 = *(a + (row + 2) * K + p),
......@@ -403,7 +404,7 @@ void fc_gemm_4x4(__global const CL_DTYPE* a,
} else {
for (int cidx = col; cidx < N; ++cidx) {
for (int ridx = row; ridx < M; ++ridx) {
CL_DTYPE a0, b0, c0 = bias ? bias[cidx] : 0;
CL_COMPUTE_DTYPE a0, b0, c0 = bias ? bias[cidx] : 0;
for (int p = 0; p < K; ++p) {
a0 = *(a + ridx * K + p);
b0 = *(b + p * N + cidx),
......
......@@ -188,13 +188,12 @@ __kernel void conv2d_3x3(__private const int item_ch,
}
}
#ifdef RELU
output[0] = activation_type4(output[0]);
output[1] = activation_type4(output[1]);
output[2] = activation_type4(output[2]);
output[3] = activation_type4(output[3]);
output[4] = activation_type4(output[4]);
#endif
WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, (int2)(out_w_base_id + out_w_id0, item_h_id),
output[0]);
if (out_w_id1 < out_w) {
......
......@@ -57,6 +57,7 @@ class FcCompute
global_work_size_ = cl::NDRange{static_cast<size_t>((m_ + 3) / 4),
static_cast<size_t>((n_ + 3) / 4)};
}
VLOG(1) << "kernel_func_name_:" << kernel_func_name_;
if (param.activation_type == "relu") {
build_options_ += "-DRELU";
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册