diff --git a/modules/dnn/src/layers/mvn_layer.cpp b/modules/dnn/src/layers/mvn_layer.cpp index a74bc0e14e280c6919460df7028bbac294c85365..c911b741b47b1acc2f4817996077d99e16f6049f 100644 --- a/modules/dnn/src/layers/mvn_layer.cpp +++ b/modules/dnn/src/layers/mvn_layer.cpp @@ -93,6 +93,67 @@ public: } #ifdef HAVE_OPENCL + bool fast_forward_ocl(std::vector &inputs, std::vector &outputs) + { + if( fuse_batch_norm && scale.empty()) + { + bnorm->getScaleShift(scale, shift); + bnorm_weight = scale.getUMat(ACCESS_READ); + bnorm_bias = shift.getUMat(ACCESS_READ); + } + + int splitDim = (acrossChannels) ? 1 : 2; + for (size_t inpIdx = 0; inpIdx < inputs.size(); inpIdx++) + { + UMat &inpMat = inputs[inpIdx]; + UMat &outMat = outputs[inpIdx]; + int newRows = total(shape(inpMat), 0, splitDim); + + MatShape s = shape(newRows, inpMat.total() / newRows); + UMat oneMat = UMat::ones(s[1], 1, CV_32F); + UMat meanMat = UMat(s[0], 1, CV_32F); + UMat tmpMat = UMat(s[0], s[1], CV_32F); + float alpha = 1.0f / s[1]; + + String buildopt = "-DNUM=4"; + ocl::Kernel k("mean_fuse4", ocl::dnn::mvn_oclsrc, buildopt); + size_t localsize[] = { 128 }; + size_t globalsize[] = { (size_t)s[0] / 4 * localsize[0] }; + + int argId = 0; + k.set(argId++, ocl::KernelArg::PtrReadOnly(inpMat)); + k.set(argId++, (int)s[1]); + k.set(argId++, alpha); + k.set(argId++, ocl::KernelArg::PtrWriteOnly(meanMat)); + k.set(argId++, ocl::KernelArg::PtrWriteOnly(tmpMat)); + k.set(argId++, NULL, localsize[0] * sizeof(cl_float4)); + bool ret = k.run(1, globalsize, localsize, false); + if (!ret) + return false; + + buildopt += format(" %s %s", (fuse_batch_norm) ? "-DFUSE_BATCH_NORM" : "", + (fuse_relu) ? "-DFUSE_RELU" : ""); + + ocl::Kernel k1("mvn_fuse4", ocl::dnn::mvn_oclsrc, buildopt); + argId = 0; + k1.set(argId++, ocl::KernelArg::PtrReadOnly(tmpMat)); + k1.set(argId++, ocl::KernelArg::PtrReadOnly(inpMat)); + k1.set(argId++, ocl::KernelArg::PtrReadOnly(meanMat)); + k1.set(argId++, (int)s[1]); + k1.set(argId++, (float)alpha); + k1.set(argId++, (float)eps); + k1.set(argId++, (float)relu_slope); + k1.set(argId++, ocl::KernelArg::PtrReadOnly(bnorm_weight)); + k1.set(argId++, ocl::KernelArg::PtrReadOnly(bnorm_bias)); + k1.set(argId++, ocl::KernelArg::PtrWriteOnly(outMat)); + k1.set(argId++, NULL, localsize[0] * sizeof(cl_float4)); + ret = k1.run(1, globalsize, localsize, false); + if (!ret) + return false; + } + return true; + } + bool forward_ocl(InputArrayOfArrays inputs_, OutputArrayOfArrays outputs_, OutputArrayOfArrays internals_) { std::vector inputs; @@ -101,6 +162,15 @@ public: inputs_.getUMatVector(inputs); outputs_.getUMatVector(outputs); + int splitDim = (acrossChannels) ? 1 : 2; + int row_size = total(shape(inputs[0]), 0, splitDim); + int plane_size = total(shape(inputs[0]), splitDim); + if (normVariance && (row_size % 4 == 0) && (plane_size % 4 == 0)) + { + bool ret = fast_forward_ocl(inputs, outputs); + return ret; + } + if( fuse_batch_norm && scale.empty()) { bnorm->getScaleShift(scale, shift); @@ -112,11 +182,7 @@ public: { UMat &inpMat = inputs[inpIdx]; UMat &outMat = outputs[inpIdx]; - - int splitDim = (acrossChannels) ? 1 : 2; - int i, newRows = 1; - for( i = 0; i < splitDim; i++ ) - newRows *= inpMat.size[i]; + int newRows = total(shape(inpMat), 0, splitDim); MatShape s = shape(newRows, inpMat.total() / newRows); UMat oneMat = UMat::ones(s[1], 1, CV_32F); diff --git a/modules/dnn/src/opencl/mvn.cl b/modules/dnn/src/opencl/mvn.cl index cc059eeb1ad2970ef66c723e6348d4b571a37789..9f8ab574cad6abe0f3553a9677c5b84f68d8aa88 100644 --- a/modules/dnn/src/opencl/mvn.cl +++ b/modules/dnn/src/opencl/mvn.cl @@ -50,18 +50,24 @@ #define vec_type Dtype8 #define CALC_MEAN calc_mean8 #define MVN mvn8 + #define MEAN_FUSE mean_fuse8 + #define MVN_FUSE mvn_fuse8 #elif NUM == 4 #define load(src, index) vload4(0, src + index) #define store(vec, dst, index) vstore4(vec, 0, dst + index) #define vec_type Dtype4 #define CALC_MEAN calc_mean4 #define MVN mvn4 + #define MEAN_FUSE mean_fuse4 + #define MVN_FUSE mvn_fuse4 #elif NUM == 1 #define load(src, index) src[index] #define store(vec, dst, index) dst[index] = vec #define vec_type Dtype #define CALC_MEAN calc_mean1 #define MVN mvn1 + #define MEAN_FUSE mean_fuse1 + #define MVN_FUSE mvn_fuse1 #endif __kernel void CALC_MEAN(__global const Dtype* src, @@ -128,3 +134,177 @@ __kernel void MVN(__global const Dtype* src, store(dst_vec, dst, index); } + +__kernel void MEAN_FUSE(__global const Dtype * A, + unsigned int A_col_size, + float alpha, + __global Dtype4 * result, + __global Dtype * B, + __local Dtype4 * work) +{ + unsigned int row_gid = get_group_id(0); + unsigned int lid = get_local_id(0); + const __global Dtype *src0_read = A + row_gid * 4 * A_col_size; + __global Dtype *dst0_read = B + row_gid * 4 * A_col_size; + Dtype4 dot0, dot1, dot2, dot3; + dot0 = dot1 = dot2 = dot3 = (Dtype4)(0.f); + + unsigned int i = lid; + const Dtype4 b0 = (Dtype4)1.f; + while( i < A_col_size / 4) + { + const Dtype4 a0 = vload4(i, src0_read); + const Dtype4 a1 = vload4(i, src0_read + A_col_size); + const Dtype4 a2 = vload4(i, src0_read + 2 * A_col_size); + const Dtype4 a3 = vload4(i, src0_read + 3 * A_col_size); + + dot0 += a0; + dot1 += a1; + dot2 += a2; + dot3 += a3; + + i += get_local_size(0); + } + + work[lid].s0 = dot(dot0, b0); + work[lid].s1 = dot(dot1, b0); + work[lid].s2 = dot(dot2, b0); + work[lid].s3 = dot(dot3, b0); + + for(unsigned int stride=get_local_size(0)/2 ; stride>0 ; stride>>=1) + { + barrier(CLK_LOCAL_MEM_FENCE); + if(lid < stride) + work[lid] += work[lid+stride]; + } + barrier(CLK_LOCAL_MEM_FENCE); + + if(lid == 0) + { + result[row_gid] = alpha * work[0]; + } + + Dtype4 sum = work[0] * alpha; + i = lid; + while( i < A_col_size / 4) + { + const Dtype4 a0 = vload4(i, src0_read); + const Dtype4 a1 = vload4(i, src0_read + A_col_size); + const Dtype4 a2 = vload4(i, src0_read + 2 * A_col_size); + const Dtype4 a3 = vload4(i, src0_read + 3 * A_col_size); + + dot0 = native_powr(a0 - (Dtype4)sum.x, 2); + dot1 = native_powr(a1 - (Dtype4)sum.y, 2); + dot2 = native_powr(a2 - (Dtype4)sum.z, 2); + dot3 = native_powr(a3 - (Dtype4)sum.w, 2); + + vstore4(dot0, i, dst0_read); + vstore4(dot1, i, dst0_read + A_col_size); + vstore4(dot2, i, dst0_read + 2 * A_col_size); + vstore4(dot3, i, dst0_read + 3 * A_col_size); + + i += get_local_size(0); + } +} + +__kernel void MVN_FUSE(__global const Dtype * tmp, + __global const Dtype * A, + __global const Dtype4 * mean, + unsigned int A_col_size, + const float alpha_val, + const float eps, + const float relu_slope, + __global const Dtype4 * bnorm_weight, + __global const Dtype4 * bnorm_bias, + __global Dtype * B, + __local Dtype4 * work) +{ + unsigned int row_gid = get_group_id(0); + unsigned int lid = get_local_id(0); + const __global Dtype *src0_read = tmp + row_gid * 4 * A_col_size; + const __global Dtype *src1_read = A + row_gid * 4 * A_col_size; + __global Dtype *dst0_read = B + row_gid * 4 * A_col_size; + Dtype4 dot0, dot1, dot2, dot3; + dot0 = dot1 = dot2 = dot3 = (Dtype4)(0.f); + + unsigned int i = lid; + const Dtype4 b0 = (Dtype4)1.f; + while( i < A_col_size / 4) + { + const Dtype4 a0 = vload4(i, src0_read); + const Dtype4 a1 = vload4(i, src0_read + A_col_size); + const Dtype4 a2 = vload4(i, src0_read + 2 * A_col_size); + const Dtype4 a3 = vload4(i, src0_read + 3 * A_col_size); + + dot0 += a0; + dot1 += a1; + dot2 += a2; + dot3 += a3; + + i += get_local_size(0); + } + + work[lid].s0 = dot(dot0, b0); + work[lid].s1 = dot(dot1, b0); + work[lid].s2 = dot(dot2, b0); + work[lid].s3 = dot(dot3, b0); + + for(unsigned int stride=get_local_size(0)/2 ; stride>0 ; stride>>=1) + { + barrier(CLK_LOCAL_MEM_FENCE); + if(lid < stride) + work[lid] += work[lid+stride]; + } + barrier(CLK_LOCAL_MEM_FENCE); + + Dtype4 mean_val = mean[row_gid]; + Dtype4 dev_val = sqrt(work[0] * alpha_val) + (Dtype4)eps; + Dtype4 alpha = (Dtype4)1.f / dev_val; + + Dtype4 w = (Dtype4)1.f; + Dtype4 b = (Dtype4)0.f; +#ifdef FUSE_BATCH_NORM + w = bnorm_weight[row_gid]; + b = bnorm_bias[row_gid]; +#endif + + i = lid; + while( i < A_col_size / 4) + { + const Dtype4 a0 = vload4(i, src1_read); + const Dtype4 a1 = vload4(i, src1_read + A_col_size); + const Dtype4 a2 = vload4(i, src1_read + 2 * A_col_size); + const Dtype4 a3 = vload4(i, src1_read + 3 * A_col_size); + + dot0 = (a0 - (Dtype4)mean_val.x) * alpha.x; + dot1 = (a1 - (Dtype4)mean_val.y) * alpha.y; + dot2 = (a2 - (Dtype4)mean_val.z) * alpha.z; + dot3 = (a3 - (Dtype4)mean_val.w) * alpha.w; + + dot0 = dot0 * w.x + (Dtype4)b.x; + dot1 = dot1 * w.y + (Dtype4)b.y; + dot2 = dot2 * w.z + (Dtype4)b.z; + dot3 = dot3 * w.w + (Dtype4)b.w; + +#ifdef FUSE_RELU + Dtype4 new0 = dot0 * relu_slope; + dot0 = select(new0, dot0, dot0 > (Dtype4)0.f); + + Dtype4 new1 = dot1 * relu_slope; + dot1 = select(new1, dot1, dot1 > (Dtype4)0.f); + + Dtype4 new2 = dot2 * relu_slope; + dot2 = select(new2, dot2, dot2 > (Dtype4)0.f); + + Dtype4 new3 = dot3 * relu_slope; + dot3 = select(new3, dot3, dot3 > (Dtype4)0.f); +#endif + + vstore4(dot0, i, dst0_read); + vstore4(dot1, i, dst0_read + A_col_size); + vstore4(dot2, i, dst0_read + 2 * A_col_size); + vstore4(dot3, i, dst0_read + 3 * A_col_size); + + i += get_local_size(0); + } +}