提交 eab6744a 作者: Alexander Alekhin

dnn(ocl): use compile-time LOCAL_SIZE parameter

Use the compile-time LOCAL_SIZE define instead of get_local_size(0) and dynamic local memory allocation.
上级 f46cd9db
...@@ -138,9 +138,12 @@ public: ...@@ -138,9 +138,12 @@ public:
UMat& bnorm_weight = umat_scale; UMat& bnorm_weight = umat_scale;
UMat& bnorm_bias = umat_shift; UMat& bnorm_bias = umat_shift;
const unsigned LOCAL_SIZE = 128;
bool use_half = (inputs[0].depth() == CV_16S); bool use_half = (inputs[0].depth() == CV_16S);
String opts = format(" -DT=%s -DT4=%s -Dconvert_T=%s", use_half ? "half" : "float", String opts = format(" -DT=%s -DT4=%s -Dconvert_T=%s -DLOCAL_SIZE=%u", use_half ? "half" : "float",
use_half ? "half4" : "float4", use_half ? "convert_half4" : "convert_float4"); use_half ? "half4" : "float4", use_half ? "convert_half4" : "convert_float4",
LOCAL_SIZE
);
int splitDim = (acrossChannels) ? 1 : 2; int splitDim = (acrossChannels) ? 1 : 2;
for (size_t inpIdx = 0; inpIdx < inputs.size(); inpIdx++) for (size_t inpIdx = 0; inpIdx < inputs.size(); inpIdx++)
...@@ -155,8 +158,8 @@ public: ...@@ -155,8 +158,8 @@ public:
float alpha = 1.0f / s[1]; float alpha = 1.0f / s[1];
String buildopt = "-DNUM=4" + opts; String buildopt = "-DNUM=4" + opts;
ocl::Kernel k("mean_fuse4", ocl::dnn::mvn_oclsrc, buildopt); ocl::Kernel k("mean_fuse4", ocl::dnn::mvn_oclsrc, buildopt + " -DKERNEL_MEAN_FUSE");
size_t localsize[] = { 128 }; size_t localsize[] = { LOCAL_SIZE };
size_t globalsize[] = { (size_t)s[0] / 4 * localsize[0] }; size_t globalsize[] = { (size_t)s[0] / 4 * localsize[0] };
int argId = 0; int argId = 0;
...@@ -165,7 +168,6 @@ public: ...@@ -165,7 +168,6 @@ public:
k.set(argId++, alpha); k.set(argId++, alpha);
k.set(argId++, ocl::KernelArg::PtrWriteOnly(meanMat)); k.set(argId++, ocl::KernelArg::PtrWriteOnly(meanMat));
k.set(argId++, ocl::KernelArg::PtrWriteOnly(tmpMat)); k.set(argId++, ocl::KernelArg::PtrWriteOnly(tmpMat));
k.set(argId++, NULL, localsize[0] * sizeof(cl_float4));
bool ret = k.run(1, globalsize, localsize, false); bool ret = k.run(1, globalsize, localsize, false);
if (!ret) if (!ret)
return false; return false;
...@@ -173,7 +175,7 @@ public: ...@@ -173,7 +175,7 @@ public:
buildopt += format(" %s %s", (fuse_batch_norm) ? "-DFUSE_BATCH_NORM" : "", buildopt += format(" %s %s", (fuse_batch_norm) ? "-DFUSE_BATCH_NORM" : "",
(fuse_relu) ? "-DFUSE_RELU" : ""); (fuse_relu) ? "-DFUSE_RELU" : "");
ocl::Kernel k1("mvn_fuse4", ocl::dnn::mvn_oclsrc, buildopt); ocl::Kernel k1("mvn_fuse4", ocl::dnn::mvn_oclsrc, buildopt + " -DKERNEL_MVN_FUSE");
argId = 0; argId = 0;
k1.set(argId++, ocl::KernelArg::PtrReadOnly(tmpMat)); k1.set(argId++, ocl::KernelArg::PtrReadOnly(tmpMat));
k1.set(argId++, ocl::KernelArg::PtrReadOnly(inpMat)); k1.set(argId++, ocl::KernelArg::PtrReadOnly(inpMat));
...@@ -185,7 +187,6 @@ public: ...@@ -185,7 +187,6 @@ public:
k1.set(argId++, ocl::KernelArg::PtrReadOnly(bnorm_weight)); k1.set(argId++, ocl::KernelArg::PtrReadOnly(bnorm_weight));
k1.set(argId++, ocl::KernelArg::PtrReadOnly(bnorm_bias)); k1.set(argId++, ocl::KernelArg::PtrReadOnly(bnorm_bias));
k1.set(argId++, ocl::KernelArg::PtrWriteOnly(outMat)); k1.set(argId++, ocl::KernelArg::PtrWriteOnly(outMat));
k1.set(argId++, NULL, localsize[0] * sizeof(cl_float4));
ret = k1.run(1, globalsize, localsize, false); ret = k1.run(1, globalsize, localsize, false);
if (!ret) if (!ret)
return false; return false;
...@@ -243,7 +244,7 @@ public: ...@@ -243,7 +244,7 @@ public:
if (normVariance) if (normVariance)
{ {
String kname = format("calc_mean%d", number); String kname = format("calc_mean%d", number);
ocl::Kernel kernel(kname.c_str(), ocl::dnn::mvn_oclsrc, buildopt); ocl::Kernel kernel(kname.c_str(), ocl::dnn::mvn_oclsrc, buildopt + " -DKERNEL_MEAN");
if (kernel.empty()) if (kernel.empty())
return false; return false;
...@@ -263,7 +264,7 @@ public: ...@@ -263,7 +264,7 @@ public:
} }
String kname = format("mvn%d", number); String kname = format("mvn%d", number);
buildopt += format("%s%s%s", (normVariance) ? " -DNORM_VARIANCE" : "", buildopt += format("%s%s%s -DKERNEL_MVN", (normVariance) ? " -DNORM_VARIANCE" : "",
(fuse_batch_norm) ? " -DFUSE_BATCH_NORM" : "", (fuse_batch_norm) ? " -DFUSE_BATCH_NORM" : "",
(fuse_relu) ? " -DFUSE_RELU" : ""); (fuse_relu) ? " -DFUSE_RELU" : "");
ocl::Kernel kernel1(kname.c_str(), ocl::dnn::mvn_oclsrc, buildopt); ocl::Kernel kernel1(kname.c_str(), ocl::dnn::mvn_oclsrc, buildopt);
......
...@@ -74,6 +74,8 @@ ...@@ -74,6 +74,8 @@
#define MVN_FUSE mvn_fuse1 #define MVN_FUSE mvn_fuse1
#endif #endif
#ifdef KERNEL_MEAN
__kernel void CALC_MEAN(__global const Dtype* src, __kernel void CALC_MEAN(__global const Dtype* src,
const int rows, const int rows,
const int cols, const int cols,
...@@ -94,6 +96,8 @@ __kernel void CALC_MEAN(__global const Dtype* src, ...@@ -94,6 +96,8 @@ __kernel void CALC_MEAN(__global const Dtype* src,
store(dst_vec, dst, index); store(dst_vec, dst, index);
} }
#elif defined KERNEL_MVN
__kernel void MVN(__global const Dtype* src, __kernel void MVN(__global const Dtype* src,
const int rows, const int rows,
const int cols, const int cols,
...@@ -140,12 +144,13 @@ __kernel void MVN(__global const Dtype* src, ...@@ -140,12 +144,13 @@ __kernel void MVN(__global const Dtype* src,
store(dst_vec, dst, index); store(dst_vec, dst, index);
} }
#elif defined KERNEL_MEAN_FUSE
__kernel void MEAN_FUSE(__global const T * A, __kernel void MEAN_FUSE(__global const T * A,
unsigned int A_col_size, unsigned int A_col_size,
float alpha, float alpha,
__global T4 * mean, __global T4 * mean,
__global Dtype * tmp, __global Dtype * tmp)
__local Dtype4 * work)
{ {
unsigned int row_gid = get_group_id(0); unsigned int row_gid = get_group_id(0);
unsigned int lid = get_local_id(0); unsigned int lid = get_local_id(0);
...@@ -168,15 +173,16 @@ __kernel void MEAN_FUSE(__global const T * A, ...@@ -168,15 +173,16 @@ __kernel void MEAN_FUSE(__global const T * A,
dot2 += convert_float4(a2); dot2 += convert_float4(a2);
dot3 += convert_float4(a3); dot3 += convert_float4(a3);
i += get_local_size(0); i += LOCAL_SIZE;
} }
__local Dtype4 work[LOCAL_SIZE];
work[lid].s0 = dot(dot0, b0); work[lid].s0 = dot(dot0, b0);
work[lid].s1 = dot(dot1, b0); work[lid].s1 = dot(dot1, b0);
work[lid].s2 = dot(dot2, b0); work[lid].s2 = dot(dot2, b0);
work[lid].s3 = dot(dot3, b0); work[lid].s3 = dot(dot3, b0);
for(unsigned int stride=get_local_size(0)/2 ; stride>0 ; stride>>=1) for(unsigned int stride=LOCAL_SIZE/2 ; stride>0 ; stride>>=1)
{ {
barrier(CLK_LOCAL_MEM_FENCE); barrier(CLK_LOCAL_MEM_FENCE);
if(lid < stride) if(lid < stride)
...@@ -212,10 +218,12 @@ __kernel void MEAN_FUSE(__global const T * A, ...@@ -212,10 +218,12 @@ __kernel void MEAN_FUSE(__global const T * A,
vstore4(dot2, i, dst0_read + 2 * A_col_size); vstore4(dot2, i, dst0_read + 2 * A_col_size);
vstore4(dot3, i, dst0_read + 3 * A_col_size); vstore4(dot3, i, dst0_read + 3 * A_col_size);
i += get_local_size(0); i += LOCAL_SIZE;
} }
} }
#elif defined KERNEL_MVN_FUSE
__kernel void MVN_FUSE(__global const Dtype * tmp, __kernel void MVN_FUSE(__global const Dtype * tmp,
__global const T * A, __global const T * A,
__global const T4 * mean, __global const T4 * mean,
...@@ -225,8 +233,7 @@ __kernel void MVN_FUSE(__global const Dtype * tmp, ...@@ -225,8 +233,7 @@ __kernel void MVN_FUSE(__global const Dtype * tmp,
const float relu_slope, const float relu_slope,
__global const Dtype4 * bnorm_weight, __global const Dtype4 * bnorm_weight,
__global const Dtype4 * bnorm_bias, __global const Dtype4 * bnorm_bias,
__global T * B, __global T * B)
__local Dtype4 * work)
{ {
unsigned int row_gid = get_group_id(0); unsigned int row_gid = get_group_id(0);
unsigned int lid = get_local_id(0); unsigned int lid = get_local_id(0);
...@@ -250,15 +257,16 @@ __kernel void MVN_FUSE(__global const Dtype * tmp, ...@@ -250,15 +257,16 @@ __kernel void MVN_FUSE(__global const Dtype * tmp,
dot2 += a2; dot2 += a2;
dot3 += a3; dot3 += a3;
i += get_local_size(0); i += LOCAL_SIZE;
} }
__local Dtype4 work[LOCAL_SIZE];
work[lid].s0 = dot(dot0, b0); work[lid].s0 = dot(dot0, b0);
work[lid].s1 = dot(dot1, b0); work[lid].s1 = dot(dot1, b0);
work[lid].s2 = dot(dot2, b0); work[lid].s2 = dot(dot2, b0);
work[lid].s3 = dot(dot3, b0); work[lid].s3 = dot(dot3, b0);
for(unsigned int stride=get_local_size(0)/2 ; stride>0 ; stride>>=1) for(unsigned int stride=LOCAL_SIZE/2 ; stride>0 ; stride>>=1)
{ {
barrier(CLK_LOCAL_MEM_FENCE); barrier(CLK_LOCAL_MEM_FENCE);
if(lid < stride) if(lid < stride)
...@@ -314,6 +322,10 @@ __kernel void MVN_FUSE(__global const Dtype * tmp, ...@@ -314,6 +322,10 @@ __kernel void MVN_FUSE(__global const Dtype * tmp,
vstore4(convert_T(dot2), i, dst0_read + 2 * A_col_size); vstore4(convert_T(dot2), i, dst0_read + 2 * A_col_size);
vstore4(convert_T(dot3), i, dst0_read + 3 * A_col_size); vstore4(convert_T(dot3), i, dst0_read + 3 * A_col_size);
i += get_local_size(0); i += LOCAL_SIZE;
} }
} }
#else
#error "Configuration error!"
#endif
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册