提交 eab6744a 编写于 作者: Alexander Alekhin

dnn(ocl): use compile-time LOCAL_SIZE parameter

instead of get_local_size(0) and dynamic local memory allocation
上级 f46cd9db
......@@ -138,9 +138,12 @@ public:
UMat& bnorm_weight = umat_scale;
UMat& bnorm_bias = umat_shift;
const unsigned LOCAL_SIZE = 128;
bool use_half = (inputs[0].depth() == CV_16S);
String opts = format(" -DT=%s -DT4=%s -Dconvert_T=%s", use_half ? "half" : "float",
use_half ? "half4" : "float4", use_half ? "convert_half4" : "convert_float4");
String opts = format(" -DT=%s -DT4=%s -Dconvert_T=%s -DLOCAL_SIZE=%u", use_half ? "half" : "float",
use_half ? "half4" : "float4", use_half ? "convert_half4" : "convert_float4",
LOCAL_SIZE
);
int splitDim = (acrossChannels) ? 1 : 2;
for (size_t inpIdx = 0; inpIdx < inputs.size(); inpIdx++)
......@@ -155,8 +158,8 @@ public:
float alpha = 1.0f / s[1];
String buildopt = "-DNUM=4" + opts;
ocl::Kernel k("mean_fuse4", ocl::dnn::mvn_oclsrc, buildopt);
size_t localsize[] = { 128 };
ocl::Kernel k("mean_fuse4", ocl::dnn::mvn_oclsrc, buildopt + " -DKERNEL_MEAN_FUSE");
size_t localsize[] = { LOCAL_SIZE };
size_t globalsize[] = { (size_t)s[0] / 4 * localsize[0] };
int argId = 0;
......@@ -165,7 +168,6 @@ public:
k.set(argId++, alpha);
k.set(argId++, ocl::KernelArg::PtrWriteOnly(meanMat));
k.set(argId++, ocl::KernelArg::PtrWriteOnly(tmpMat));
k.set(argId++, NULL, localsize[0] * sizeof(cl_float4));
bool ret = k.run(1, globalsize, localsize, false);
if (!ret)
return false;
......@@ -173,7 +175,7 @@ public:
buildopt += format(" %s %s", (fuse_batch_norm) ? "-DFUSE_BATCH_NORM" : "",
(fuse_relu) ? "-DFUSE_RELU" : "");
ocl::Kernel k1("mvn_fuse4", ocl::dnn::mvn_oclsrc, buildopt);
ocl::Kernel k1("mvn_fuse4", ocl::dnn::mvn_oclsrc, buildopt + " -DKERNEL_MVN_FUSE");
argId = 0;
k1.set(argId++, ocl::KernelArg::PtrReadOnly(tmpMat));
k1.set(argId++, ocl::KernelArg::PtrReadOnly(inpMat));
......@@ -185,7 +187,6 @@ public:
k1.set(argId++, ocl::KernelArg::PtrReadOnly(bnorm_weight));
k1.set(argId++, ocl::KernelArg::PtrReadOnly(bnorm_bias));
k1.set(argId++, ocl::KernelArg::PtrWriteOnly(outMat));
k1.set(argId++, NULL, localsize[0] * sizeof(cl_float4));
ret = k1.run(1, globalsize, localsize, false);
if (!ret)
return false;
......@@ -243,7 +244,7 @@ public:
if (normVariance)
{
String kname = format("calc_mean%d", number);
ocl::Kernel kernel(kname.c_str(), ocl::dnn::mvn_oclsrc, buildopt);
ocl::Kernel kernel(kname.c_str(), ocl::dnn::mvn_oclsrc, buildopt + " -DKERNEL_MEAN");
if (kernel.empty())
return false;
......@@ -263,7 +264,7 @@ public:
}
String kname = format("mvn%d", number);
buildopt += format("%s%s%s", (normVariance) ? " -DNORM_VARIANCE" : "",
buildopt += format("%s%s%s -DKERNEL_MVN", (normVariance) ? " -DNORM_VARIANCE" : "",
(fuse_batch_norm) ? " -DFUSE_BATCH_NORM" : "",
(fuse_relu) ? " -DFUSE_RELU" : "");
ocl::Kernel kernel1(kname.c_str(), ocl::dnn::mvn_oclsrc, buildopt);
......
......@@ -74,6 +74,8 @@
#define MVN_FUSE mvn_fuse1
#endif
#ifdef KERNEL_MEAN
__kernel void CALC_MEAN(__global const Dtype* src,
const int rows,
const int cols,
......@@ -94,6 +96,8 @@ __kernel void CALC_MEAN(__global const Dtype* src,
store(dst_vec, dst, index);
}
#elif defined KERNEL_MVN
__kernel void MVN(__global const Dtype* src,
const int rows,
const int cols,
......@@ -140,12 +144,13 @@ __kernel void MVN(__global const Dtype* src,
store(dst_vec, dst, index);
}
#elif defined KERNEL_MEAN_FUSE
__kernel void MEAN_FUSE(__global const T * A,
unsigned int A_col_size,
float alpha,
__global T4 * mean,
__global Dtype * tmp,
__local Dtype4 * work)
__global Dtype * tmp)
{
unsigned int row_gid = get_group_id(0);
unsigned int lid = get_local_id(0);
......@@ -168,15 +173,16 @@ __kernel void MEAN_FUSE(__global const T * A,
dot2 += convert_float4(a2);
dot3 += convert_float4(a3);
i += get_local_size(0);
i += LOCAL_SIZE;
}
__local Dtype4 work[LOCAL_SIZE];
work[lid].s0 = dot(dot0, b0);
work[lid].s1 = dot(dot1, b0);
work[lid].s2 = dot(dot2, b0);
work[lid].s3 = dot(dot3, b0);
for(unsigned int stride=get_local_size(0)/2 ; stride>0 ; stride>>=1)
for(unsigned int stride=LOCAL_SIZE/2 ; stride>0 ; stride>>=1)
{
barrier(CLK_LOCAL_MEM_FENCE);
if(lid < stride)
......@@ -212,10 +218,12 @@ __kernel void MEAN_FUSE(__global const T * A,
vstore4(dot2, i, dst0_read + 2 * A_col_size);
vstore4(dot3, i, dst0_read + 3 * A_col_size);
i += get_local_size(0);
i += LOCAL_SIZE;
}
}
#elif defined KERNEL_MVN_FUSE
__kernel void MVN_FUSE(__global const Dtype * tmp,
__global const T * A,
__global const T4 * mean,
......@@ -225,8 +233,7 @@ __kernel void MVN_FUSE(__global const Dtype * tmp,
const float relu_slope,
__global const Dtype4 * bnorm_weight,
__global const Dtype4 * bnorm_bias,
__global T * B,
__local Dtype4 * work)
__global T * B)
{
unsigned int row_gid = get_group_id(0);
unsigned int lid = get_local_id(0);
......@@ -250,15 +257,16 @@ __kernel void MVN_FUSE(__global const Dtype * tmp,
dot2 += a2;
dot3 += a3;
i += get_local_size(0);
i += LOCAL_SIZE;
}
__local Dtype4 work[LOCAL_SIZE];
work[lid].s0 = dot(dot0, b0);
work[lid].s1 = dot(dot1, b0);
work[lid].s2 = dot(dot2, b0);
work[lid].s3 = dot(dot3, b0);
for(unsigned int stride=get_local_size(0)/2 ; stride>0 ; stride>>=1)
for(unsigned int stride=LOCAL_SIZE/2 ; stride>0 ; stride>>=1)
{
barrier(CLK_LOCAL_MEM_FENCE);
if(lid < stride)
......@@ -314,6 +322,10 @@ __kernel void MVN_FUSE(__global const Dtype * tmp,
vstore4(convert_T(dot2), i, dst0_read + 2 * A_col_size);
vstore4(convert_T(dot3), i, dst0_read + 3 * A_col_size);
i += get_local_size(0);
i += LOCAL_SIZE;
}
}
#else
#error "Configuration error!"
#endif
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请注册