From 177dec94c51b4cfcc92e1f5e1fdf5c6eb03172f0 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 20 Sep 2021 19:43:59 +0800 Subject: [PATCH] feat(mgb/opr): add bgr2gray mode for cvtcolor opr GitOrigin-RevId: d50415b236080a13d31c43158f9d03e2aef48d59 --- dnn/src/arm_common/cvt_color/opr_impl.cpp | 23 +++++- dnn/src/common/cv/cvt_color.h | 1 - dnn/src/cuda/cvt_color/cvt_color.cu | 73 +++++++++++++++++++ dnn/src/naive/cvt_color/opr_impl.cpp | 20 +++++ dnn/src/x86/cvt_color/opr_impl.cpp | 45 ++++++++++++ dnn/test/common/cvt_color.h | 6 ++ .../test/unit/functional/test_functional.py | 7 ++ 7 files changed, 170 insertions(+), 5 deletions(-) diff --git a/dnn/src/arm_common/cvt_color/opr_impl.cpp b/dnn/src/arm_common/cvt_color/opr_impl.cpp index 99cd4020..8ec548bb 100644 --- a/dnn/src/arm_common/cvt_color/opr_impl.cpp +++ b/dnn/src/arm_common/cvt_color/opr_impl.cpp @@ -748,11 +748,16 @@ void cvt_BT601_yuv_transform(const Mat8u& src, Mat8u& dst) { } // namespace +template void cvt_rgb2gray_32f_neon(const Mat32f& src, Mat32f& dst) { static const float coef[] = {0.299f, 0.587f, 0.114f}; // load coef into neon types - const float32x4_t v_cr(vdupq_n_f32(coef[0])), v_cg(vdupq_n_f32(coef[1])), - v_cb(vdupq_n_f32(coef[2])); + float coef_c0 = rgb ? coef[0] : coef[2]; + float coef_c1 = coef[1]; + float coef_c2 = rgb ? coef[2] : coef[0]; + + const float32x4_t v_cr(vdupq_n_f32(coef_c0)), v_cg(vdupq_n_f32(coef_c1)), + v_cb(vdupq_n_f32(coef_c2)); #define EXPAND(offset) \ v_src = vld3q_f32(psrc + offset * 3); \ @@ -796,7 +801,7 @@ void cvt_rgb2gray_32f_neon(const Mat32f& src, Mat32f& dst) { } // loop over left pixels for (; psrc < pend; psrc += 3, pdst += 1) { - *pdst = psrc[0] * coef[0] + psrc[1] * coef[1] + psrc[2] * coef[2]; + *pdst = psrc[0] * coef_c0 + psrc[1] * coef_c1 + psrc[2] * coef_c2; } } #undef EXPAND @@ -1187,7 +1192,7 @@ void cvt_rgb2gray(const Mat32f& src, Mat32f& dst) { megdnn_assert(src.rows() == dst.rows()); megdnn_assert(src.cols() == dst.cols()); - return cvt_rgb2gray_32f_neon(src, dst); + return cvt_rgb2gray_32f_neon(src, dst); } // gray2rgb @@ -1381,6 +1386,16 @@ void cvt_bgr2gray(const Mat8u& src, Mat8u& dst) { } } +template <> +void cvt_bgr2gray(const Mat32f& src, Mat32f& dst) { + megdnn_assert(src.channels() == 3); + megdnn_assert(dst.channels() == 1); + megdnn_assert(src.rows() == dst.rows()); + megdnn_assert(src.cols() == dst.cols()); + + return cvt_rgb2gray_32f_neon(src, dst); +} + template <> void cvt_bgr2rgb(const Mat8u& src, Mat8u& dst) { return cvt_rgb2bgr(src, dst); diff --git a/dnn/src/common/cv/cvt_color.h b/dnn/src/common/cv/cvt_color.h index d6d88ba8..377478bb 100644 --- a/dnn/src/common/cv/cvt_color.h +++ b/dnn/src/common/cv/cvt_color.h @@ -45,7 +45,6 @@ _cb(cvt_rgba2bgr, float) \ _cb(cvt_rgba2gray, float) \ _cb(cvt_rgb2bgr, float) \ - _cb(cvt_bgr2gray, float) \ _cb(cvt_bgr2rgb, float) \ _cb(cvt_yuv2gray_nv21, float) \ _cb(cvt_yuv2rgb_nv21, float) \ diff --git a/dnn/src/cuda/cvt_color/cvt_color.cu b/dnn/src/cuda/cvt_color/cvt_color.cu index 9bb1de7e..012c13ce 100644 --- a/dnn/src/cuda/cvt_color/cvt_color.cu +++ b/dnn/src/cuda/cvt_color/cvt_color.cu @@ -145,6 +145,73 @@ __global__ void cvt_rgb2gray_32f_kernel(const float* src, float* dst, } } + +__global__ void cvt_bgr2gray_8u_kernel(const uchar* src, uchar* dst, + const size_t rows, const size_t cols, + const size_t src_step, + const size_t dst_step) { + size_t t = blockIdx.x * blockDim.x + threadIdx.x; + + if (t < (rows * cols) / U8_PROCESS_PER_THREADS_X) { + size_t offset = t * U8_PROCESS_PER_THREADS_X; + src += 3 * offset; + dst += 1 * offset; + + uchar temp_des[4]; + uchar temp_src[12]; + *((uint3*)temp_src) = *((uint3*)src); + + temp_des[0] = (temp_src[0] * 1868 + temp_src[1] * 9617 + + temp_src[2] * 4899 + (1 << 13)) >> + 14; + temp_des[1] = (temp_src[3] * 1868 + temp_src[4] * 9617 + + temp_src[5] * 4899 + (1 << 13)) >> + 14; + temp_des[2] = (temp_src[6] * 1868 + temp_src[7] * 9617 + + temp_src[8] * 4899 + (1 << 13)) >> + 14; + temp_des[3] = (temp_src[9] * 1868 + temp_src[10] * 9617 + + temp_src[11] * 4899 + (1 << 13)) >> + 14; + + *((uint32_t*)dst) = *((uint32_t*)temp_des); + } else if (t == (rows * cols) / U8_PROCESS_PER_THREADS_X) { + size_t rest = (rows * cols) % U8_PROCESS_PER_THREADS_X; + if (rest != 0) { + size_t offset = t * U8_PROCESS_PER_THREADS_X; + src += 3 * offset; + dst += 1 * offset; + + for (int i = 0; i < rest; i++, src += 3, dst += 1) + dst[0] = (src[0] * 1868 + src[1] * 9617 + src[2] * 4899 + + (1 << 13)) >> + 14; + } + } +} + +__global__ void cvt_bgr2gray_32f_kernel(const float* src, float* dst, + const size_t rows, const size_t cols, + const size_t src_step, + const size_t dst_step) { + size_t t = blockIdx.x * blockDim.x + threadIdx.x; + + if (t < rows * cols) { + size_t offset = t; + src += offset * 3; + dst += offset * 1; + + float temp_src[3], temp_dst; + *((float3*)temp_src) = *((float3*)src); + + temp_dst = temp_src[0] * 0.114f + temp_src[1] * 0.587f + + temp_src[2] * 0.299f; + + dst[0] = temp_dst; + } +} + + __global__ void cvt_gray2rgb_8u_kernel(const uchar* src, uchar* dst, const size_t rows, const size_t cols, const size_t src_step, @@ -683,6 +750,9 @@ void cvt_color_8u_proxy(const uchar* src, uchar* dst, const size_t src_rows, case CvtColor::Mode::RGB2GRAY: CALL_CVT_OPR_8U_KERNEL(rgb2gray) break; + case CvtColor::Mode::BGR2GRAY: + CALL_CVT_OPR_8U_KERNEL(bgr2gray) + break; case CvtColor::Mode::RGB2YUV: CALL_CVT_OPR_8U_KERNEL(rgb2yuv) break; @@ -739,6 +809,9 @@ void cvt_color_32f_proxy(const float* src, float* dst, const size_t src_rows, case CvtColor::Mode::RGB2GRAY: CALL_CVT_OPR_32F_KERNEL(rgb2gray) break; + case CvtColor::Mode::BGR2GRAY: + CALL_CVT_OPR_32F_KERNEL(bgr2gray) + break; case CvtColor::Mode::RGB2YUV: CALL_CVT_OPR_32F_KERNEL(rgb2yuv) break; diff --git a/dnn/src/naive/cvt_color/opr_impl.cpp b/dnn/src/naive/cvt_color/opr_impl.cpp index c676ca51..2f6392c7 100644 --- a/dnn/src/naive/cvt_color/opr_impl.cpp +++ b/dnn/src/naive/cvt_color/opr_impl.cpp @@ -684,6 +684,26 @@ void cvt_bgr2gray(const Mat8u& src, Mat8u& dst) { } } +template <> +void cvt_bgr2gray(const Mat32f& src, Mat32f& dst) { + megdnn_assert(src.channels() == 3); + megdnn_assert(dst.channels() == 1); + megdnn_assert(src.rows() == dst.rows()); + megdnn_assert(src.cols() == dst.cols()); + + const float coef_r = 0.299f, coef_g = 0.587f, coef_b = 0.114f; + for (size_t r = 0; r < src.rows(); ++r) { + for (size_t c = 0; c < src.cols(); ++c) { + float B = src.at(r, c, 0); + float G = src.at(r, c, 1); + float R = src.at(r, c, 2); + float& Y = dst.at(r, c, 0); + Y = R * coef_r + G * coef_g + B * coef_b; + } + } +} + + template <> void cvt_bgr2rgb(const Mat8u& src, Mat8u& dst) { return cvt_rgb2bgr(src, dst); diff --git a/dnn/src/x86/cvt_color/opr_impl.cpp b/dnn/src/x86/cvt_color/opr_impl.cpp index 36ffba84..b25cdfa1 100644 --- a/dnn/src/x86/cvt_color/opr_impl.cpp +++ b/dnn/src/x86/cvt_color/opr_impl.cpp @@ -1311,6 +1311,41 @@ void cvt_rgb2gray_32f_SSE_4_2(const Mat32f& src, Mat32f& dst) { } } +MEGDNN_ATTRIBUTE_TARGET("sse4.2") +void cvt_bgr2gray_32f_SSE_4_2(const Mat32f& src, Mat32f& dst) { + const float coef_r = 0.299f, coef_g = 0.587f, coef_b = 0.114f; + __m128 v_coef_r = _mm_set1_ps(coef_r); + __m128 v_coef_g = _mm_set1_ps(coef_g); + __m128 v_coef_b = _mm_set1_ps(coef_b); + + for (size_t r = 0; r < src.rows(); ++r) { + const float* psrc = src.ptr(r); + float* pdst = dst.ptr(r); + const float* const pend = psrc + src.cols() * 3; + __m128 v_r, v_g, v_b, ans; + for (; psrc <= pend - 4 * 3; psrc += 4 * 3, pdst += 4) { + v_b = _mm_set_ps(psrc[9], psrc[6], psrc[3], psrc[0]); + v_b = _mm_mul_ps(v_b, v_coef_b); + + v_g = _mm_set_ps(psrc[10], psrc[7], psrc[4], psrc[1]); + v_g = _mm_mul_ps(v_g, v_coef_g); + + v_r = _mm_set_ps(psrc[11], psrc[8], psrc[5], psrc[2]); + v_r = _mm_mul_ps(v_r, v_coef_r); + + ans = _mm_add_ps(v_r, _mm_add_ps(v_g, v_b)); + + _mm_storeu_ps(pdst, ans); + } + + for (; psrc < pend; psrc += 3, pdst += 1) { + pdst[0] = psrc[1] * coef_g + psrc[0] * coef_b + psrc[2] * coef_r; + } + } +} + + + MEGDNN_ATTRIBUTE_TARGET("sse4.2") void cvt_rgba2rgb_8u_SSE_4_2(const Mat8u& src, Mat8u& dst) { __m128i dst_data0, dst_data1, dst_data2; @@ -1705,6 +1740,16 @@ void cvt_bgr2gray(const Mat8u& src, Mat8u& dst) { } } +template <> +void cvt_bgr2gray(const Mat32f& src, Mat32f& dst) { + megdnn_assert(src.channels() == 3); + megdnn_assert(dst.channels() == 1); + megdnn_assert(src.rows() == dst.rows()); + megdnn_assert(src.cols() == dst.cols()); + + return cvt_bgr2gray_32f_SSE_4_2(src, dst); +} + template <> void cvt_bgr2rgb(const Mat8u& src, Mat8u& dst) { return cvt_rgb2bgr(src, dst); diff --git a/dnn/test/common/cvt_color.h b/dnn/test/common/cvt_color.h index 70b41342..eb07e23a 100644 --- a/dnn/test/common/cvt_color.h +++ b/dnn/test/common/cvt_color.h @@ -133,6 +133,9 @@ inline std::vector get_cuda_args() { for (size_t i = 2; i <= 10; ++i) { for (size_t j = 2; j <= 10; ++j) { cur_param.mode = Mode::RGB2GRAY; + args.emplace_back(cur_param, TensorShape{1, i, j, 3}, + dtype::Uint8()); + cur_param.mode = Mode::BGR2GRAY; args.emplace_back(cur_param, TensorShape{1, i, j, 3}, dtype::Uint8()); cur_param.mode = Mode::RGB2YUV; @@ -146,6 +149,9 @@ inline std::vector get_cuda_args() { dtype::Uint8()); // float32 test cur_param.mode = Mode::RGB2GRAY; + args.emplace_back(cur_param, TensorShape{1, i, j, 3}, + dtype::Float32()); + cur_param.mode = Mode::BGR2GRAY; args.emplace_back(cur_param, TensorShape{1, i, j, 3}, dtype::Float32()); cur_param.mode = Mode::RGB2YUV; diff --git a/imperative/python/test/unit/functional/test_functional.py b/imperative/python/test/unit/functional/test_functional.py index 4771562e..364cffd2 100644 --- a/imperative/python/test/unit/functional/test_functional.py +++ b/imperative/python/test/unit/functional/test_functional.py @@ -1057,12 +1057,19 @@ def test_cvt_color(): def rgb2gray(rgb): return np.dot(rgb[..., :3], [0.299, 0.587, 0.114]) + def bgr2gray(bgr): + return np.dot(bgr[..., :3], [0.114, 0.587, 0.299]) + inp = np.random.randn(3, 3, 3, 3).astype(np.float32) out = np.expand_dims(rgb2gray(inp), 3).astype(np.float32) x = tensor(inp) y = F.vision.cvt_color(x, mode="RGB2GRAY") np.testing.assert_allclose(y.numpy(), out, atol=1e-5) + out1 = np.expand_dims(bgr2gray(inp), 3).astype(np.float32) + y1 = F.vision.cvt_color(x, mode="BGR2GRAY") + np.testing.assert_allclose(y1.numpy(), out1, atol=1e-5) + @pytest.mark.parametrize("val", [2, [2,], [2, 3]]) def test_ones(val): -- GitLab