提交 177dec94 编写于 作者: M Megvii Engine Team

feat(mgb/opr): add bgr2gray mode for cvtcolor opr

GitOrigin-RevId: d50415b236080a13d31c43158f9d03e2aef48d59
上级 000517c6
...@@ -748,11 +748,16 @@ void cvt_BT601_yuv_transform(const Mat8u& src, Mat8u& dst) { ...@@ -748,11 +748,16 @@ void cvt_BT601_yuv_transform(const Mat8u& src, Mat8u& dst) {
} // namespace } // namespace
template<bool rgb = true>
void cvt_rgb2gray_32f_neon(const Mat32f& src, Mat32f& dst) { void cvt_rgb2gray_32f_neon(const Mat32f& src, Mat32f& dst) {
static const float coef[] = {0.299f, 0.587f, 0.114f}; static const float coef[] = {0.299f, 0.587f, 0.114f};
// load coef into neon types // load coef into neon types
const float32x4_t v_cr(vdupq_n_f32(coef[0])), v_cg(vdupq_n_f32(coef[1])), float coef_c0 = rgb ? coef[0] : coef[2];
v_cb(vdupq_n_f32(coef[2])); float coef_c1 = coef[1];
float coef_c2 = rgb ? coef[2] : coef[0];
const float32x4_t v_cr(vdupq_n_f32(coef_c0)), v_cg(vdupq_n_f32(coef_c1)),
v_cb(vdupq_n_f32(coef_c2));
#define EXPAND(offset) \ #define EXPAND(offset) \
v_src = vld3q_f32(psrc + offset * 3); \ v_src = vld3q_f32(psrc + offset * 3); \
...@@ -796,7 +801,7 @@ void cvt_rgb2gray_32f_neon(const Mat32f& src, Mat32f& dst) { ...@@ -796,7 +801,7 @@ void cvt_rgb2gray_32f_neon(const Mat32f& src, Mat32f& dst) {
} }
// loop over left pixels // loop over left pixels
for (; psrc < pend; psrc += 3, pdst += 1) { for (; psrc < pend; psrc += 3, pdst += 1) {
*pdst = psrc[0] * coef[0] + psrc[1] * coef[1] + psrc[2] * coef[2]; *pdst = psrc[0] * coef_c0 + psrc[1] * coef_c1 + psrc[2] * coef_c2;
} }
} }
#undef EXPAND #undef EXPAND
...@@ -1187,7 +1192,7 @@ void cvt_rgb2gray<float>(const Mat32f& src, Mat32f& dst) { ...@@ -1187,7 +1192,7 @@ void cvt_rgb2gray<float>(const Mat32f& src, Mat32f& dst) {
megdnn_assert(src.rows() == dst.rows()); megdnn_assert(src.rows() == dst.rows());
megdnn_assert(src.cols() == dst.cols()); megdnn_assert(src.cols() == dst.cols());
return cvt_rgb2gray_32f_neon(src, dst); return cvt_rgb2gray_32f_neon<true>(src, dst);
} }
// gray2rgb // gray2rgb
...@@ -1381,6 +1386,16 @@ void cvt_bgr2gray<uchar>(const Mat8u& src, Mat8u& dst) { ...@@ -1381,6 +1386,16 @@ void cvt_bgr2gray<uchar>(const Mat8u& src, Mat8u& dst) {
} }
} }
// BGR -> gray conversion for float32 images (float specialization).
// Validates that src has 3 channels, dst has 1 channel and that both images
// have the same number of rows and columns, then forwards to the NEON
// implementation shared with rgb2gray.
template <>
void cvt_bgr2gray<float>(const Mat32f& src, Mat32f& dst) {
megdnn_assert(src.channels() == 3);
megdnn_assert(dst.channels() == 1);
megdnn_assert(src.rows() == dst.rows());
megdnn_assert(src.cols() == dst.cols());
// <false> clears the `rgb` template flag of cvt_rgb2gray_32f_neon, which
// swaps the weights applied to channel 0 and channel 2.
return cvt_rgb2gray_32f_neon<false>(src, dst);
}
template <> template <>
void cvt_bgr2rgb<uchar>(const Mat8u& src, Mat8u& dst) { void cvt_bgr2rgb<uchar>(const Mat8u& src, Mat8u& dst) {
return cvt_rgb2bgr<uchar>(src, dst); return cvt_rgb2bgr<uchar>(src, dst);
......
...@@ -45,7 +45,6 @@ ...@@ -45,7 +45,6 @@
_cb(cvt_rgba2bgr, float) \ _cb(cvt_rgba2bgr, float) \
_cb(cvt_rgba2gray, float) \ _cb(cvt_rgba2gray, float) \
_cb(cvt_rgb2bgr, float) \ _cb(cvt_rgb2bgr, float) \
_cb(cvt_bgr2gray, float) \
_cb(cvt_bgr2rgb, float) \ _cb(cvt_bgr2rgb, float) \
_cb(cvt_yuv2gray_nv21, float) \ _cb(cvt_yuv2gray_nv21, float) \
_cb(cvt_yuv2rgb_nv21, float) \ _cb(cvt_yuv2rgb_nv21, float) \
......
...@@ -145,6 +145,73 @@ __global__ void cvt_rgb2gray_32f_kernel(const float* src, float* dst, ...@@ -145,6 +145,73 @@ __global__ void cvt_rgb2gray_32f_kernel(const float* src, float* dst,
} }
} }
// BGR -> gray conversion kernel for uint8 pixels.
//
// Pixels are addressed linearly: 3 bytes per source pixel, 1 byte per
// destination pixel. src_step and dst_step are accepted but never read.
// Threads whose index is below (rows * cols) / U8_PROCESS_PER_THREADS_X each
// convert one full group of pixels; the thread whose index equals that
// quotient converts the remaining (rows * cols) % U8_PROCESS_PER_THREADS_X
// pixels one by one; all other threads do nothing.
//
// Gray is computed in fixed point as
//   (B * 1868 + G * 9617 + R * 4899 + (1 << 13)) >> 14
// i.e. the weights are scaled by 2^14 and 1 << 13 rounds to nearest.
__global__ void cvt_bgr2gray_8u_kernel(const uchar* src, uchar* dst,
                                       const size_t rows, const size_t cols,
                                       const size_t src_step,
                                       const size_t dst_step) {
    const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
    const size_t full_groups = (rows * cols) / U8_PROCESS_PER_THREADS_X;
    if (tid > full_groups) {
        return;
    }

    // Move both pointers to the first pixel owned by this thread.
    const size_t first_pixel = tid * U8_PROCESS_PER_THREADS_X;
    src += 3 * first_pixel;
    dst += 1 * first_pixel;

    if (tid < full_groups) {
        // NOTE(review): the fixed 4 / 12 buffer sizes imply that
        // U8_PROCESS_PER_THREADS_X is 4 -- confirm against its definition.
        uchar gray[4];
        uchar bgr[12];
        // Fetch the 12 source bytes of the group with a single wide load.
        *((uint3*)bgr) = *((uint3*)src);
#pragma unroll
        for (int k = 0; k < 4; ++k) {
            gray[k] = (bgr[3 * k] * 1868 + bgr[3 * k + 1] * 9617 +
                       bgr[3 * k + 2] * 4899 + (1 << 13)) >>
                      14;
        }
        // Write the four gray bytes back with a single 32-bit store.
        *((uint32_t*)dst) = *((uint32_t*)gray);
        return;
    }

    // tid == full_groups: convert the pixels that do not fill a whole group.
    const size_t leftover = (rows * cols) % U8_PROCESS_PER_THREADS_X;
    for (size_t k = 0; k < leftover; ++k, src += 3, dst += 1) {
        dst[0] = (src[0] * 1868 + src[1] * 9617 + src[2] * 4899 +
                  (1 << 13)) >>
                 14;
    }
}
// BGR -> gray conversion kernel for float32 pixels, one pixel per thread.
//
// Pixels are addressed linearly (3 floats per source pixel, 1 float per
// destination pixel); src_step and dst_step are accepted but never read.
// Threads with an index >= rows * cols do nothing.
__global__ void cvt_bgr2gray_32f_kernel(const float* src, float* dst,
                                        const size_t rows, const size_t cols,
                                        const size_t src_step,
                                        const size_t dst_step) {
    const size_t pixel = blockIdx.x * blockDim.x + threadIdx.x;
    if (pixel >= rows * cols) {
        return;
    }

    // Load the three channel values of this pixel in one go.
    float bgr[3];
    *((float3*)bgr) = *((float3*)(src + pixel * 3));

    // Channel 0 (B) is weighted 0.114, channel 1 (G) 0.587, channel 2 (R)
    // 0.299; the summation order is kept as B, G, R.
    dst[pixel] = bgr[0] * 0.114f + bgr[1] * 0.587f + bgr[2] * 0.299f;
}
__global__ void cvt_gray2rgb_8u_kernel(const uchar* src, uchar* dst, __global__ void cvt_gray2rgb_8u_kernel(const uchar* src, uchar* dst,
const size_t rows, const size_t cols, const size_t rows, const size_t cols,
const size_t src_step, const size_t src_step,
...@@ -683,6 +750,9 @@ void cvt_color_8u_proxy(const uchar* src, uchar* dst, const size_t src_rows, ...@@ -683,6 +750,9 @@ void cvt_color_8u_proxy(const uchar* src, uchar* dst, const size_t src_rows,
case CvtColor::Mode::RGB2GRAY: case CvtColor::Mode::RGB2GRAY:
CALL_CVT_OPR_8U_KERNEL(rgb2gray) CALL_CVT_OPR_8U_KERNEL(rgb2gray)
break; break;
case CvtColor::Mode::BGR2GRAY:
CALL_CVT_OPR_8U_KERNEL(bgr2gray)
break;
case CvtColor::Mode::RGB2YUV: case CvtColor::Mode::RGB2YUV:
CALL_CVT_OPR_8U_KERNEL(rgb2yuv) CALL_CVT_OPR_8U_KERNEL(rgb2yuv)
break; break;
...@@ -739,6 +809,9 @@ void cvt_color_32f_proxy(const float* src, float* dst, const size_t src_rows, ...@@ -739,6 +809,9 @@ void cvt_color_32f_proxy(const float* src, float* dst, const size_t src_rows,
case CvtColor::Mode::RGB2GRAY: case CvtColor::Mode::RGB2GRAY:
CALL_CVT_OPR_32F_KERNEL(rgb2gray) CALL_CVT_OPR_32F_KERNEL(rgb2gray)
break; break;
case CvtColor::Mode::BGR2GRAY:
CALL_CVT_OPR_32F_KERNEL(bgr2gray)
break;
case CvtColor::Mode::RGB2YUV: case CvtColor::Mode::RGB2YUV:
CALL_CVT_OPR_32F_KERNEL(rgb2yuv) CALL_CVT_OPR_32F_KERNEL(rgb2yuv)
break; break;
......
...@@ -684,6 +684,26 @@ void cvt_bgr2gray<uchar>(const Mat8u& src, Mat8u& dst) { ...@@ -684,6 +684,26 @@ void cvt_bgr2gray<uchar>(const Mat8u& src, Mat8u& dst) {
} }
} }
// Scalar BGR -> gray conversion for float32 images (float specialization).
//
// Requires a 3-channel src and a 1-channel dst with identical rows/cols.
// Every output value is R * 0.299 + G * 0.587 + B * 0.114, where B, G and R
// are read from channels 0, 1 and 2 of src respectively.
template <>
void cvt_bgr2gray<float>(const Mat32f& src, Mat32f& dst) {
    megdnn_assert(src.channels() == 3);
    megdnn_assert(dst.channels() == 1);
    megdnn_assert(src.rows() == dst.rows());
    megdnn_assert(src.cols() == dst.cols());

    const float weight_r = 0.299f, weight_g = 0.587f, weight_b = 0.114f;
    for (size_t row = 0; row < src.rows(); ++row) {
        for (size_t col = 0; col < src.cols(); ++col) {
            // Source channel order is B, G, R.
            const float blue = src.at(row, col, 0);
            const float green = src.at(row, col, 1);
            const float red = src.at(row, col, 2);
            dst.at(row, col, 0) =
                    red * weight_r + green * weight_g + blue * weight_b;
        }
    }
}
template <> template <>
void cvt_bgr2rgb<uchar>(const Mat8u& src, Mat8u& dst) { void cvt_bgr2rgb<uchar>(const Mat8u& src, Mat8u& dst) {
return cvt_rgb2bgr<uchar>(src, dst); return cvt_rgb2bgr<uchar>(src, dst);
......
...@@ -1311,6 +1311,41 @@ void cvt_rgb2gray_32f_SSE_4_2(const Mat32f& src, Mat32f& dst) { ...@@ -1311,6 +1311,41 @@ void cvt_rgb2gray_32f_SSE_4_2(const Mat32f& src, Mat32f& dst) {
} }
} }
// BGR -> gray conversion for float32 images using 128-bit SSE intrinsics.
//
// Each row is processed independently: four pixels (12 floats) at a time in
// the vector loop, then any remaining pixels of the row in a scalar loop.
// Channel 0 (B) is weighted 0.114, channel 1 (G) 0.587 and channel 2 (R)
// 0.299. No shape checks are performed here.
MEGDNN_ATTRIBUTE_TARGET("sse4.2")
void cvt_bgr2gray_32f_SSE_4_2(const Mat32f& src, Mat32f& dst) {
    const float weight_r = 0.299f, weight_g = 0.587f, weight_b = 0.114f;
    const __m128 vw_r = _mm_set1_ps(weight_r);
    const __m128 vw_g = _mm_set1_ps(weight_g);
    const __m128 vw_b = _mm_set1_ps(weight_b);

    for (size_t row = 0; row < src.rows(); ++row) {
        const float* in = src.ptr(row);
        float* out = dst.ptr(row);
        const float* const in_end = in + src.cols() * 3;

        // Vector part: runs while at least four whole pixels remain.
        while (in_end - in >= 4 * 3) {
            // _mm_set_ps takes lanes from high to low, hence the reversed
            // index order; this de-interleaves B, G and R of four pixels.
            const __m128 blue = _mm_mul_ps(
                    _mm_set_ps(in[9], in[6], in[3], in[0]), vw_b);
            const __m128 green = _mm_mul_ps(
                    _mm_set_ps(in[10], in[7], in[4], in[1]), vw_g);
            const __m128 red = _mm_mul_ps(
                    _mm_set_ps(in[11], in[8], in[5], in[2]), vw_r);
            _mm_storeu_ps(out, _mm_add_ps(red, _mm_add_ps(green, blue)));
            in += 4 * 3;
            out += 4;
        }

        // Scalar tail: fewer than four pixels left in this row.
        while (in < in_end) {
            out[0] = in[1] * weight_g + in[0] * weight_b + in[2] * weight_r;
            in += 3;
            out += 1;
        }
    }
}
MEGDNN_ATTRIBUTE_TARGET("sse4.2") MEGDNN_ATTRIBUTE_TARGET("sse4.2")
void cvt_rgba2rgb_8u_SSE_4_2(const Mat8u& src, Mat8u& dst) { void cvt_rgba2rgb_8u_SSE_4_2(const Mat8u& src, Mat8u& dst) {
__m128i dst_data0, dst_data1, dst_data2; __m128i dst_data0, dst_data1, dst_data2;
...@@ -1705,6 +1740,16 @@ void cvt_bgr2gray<uchar>(const Mat8u& src, Mat8u& dst) { ...@@ -1705,6 +1740,16 @@ void cvt_bgr2gray<uchar>(const Mat8u& src, Mat8u& dst) {
} }
} }
// BGR -> gray conversion for float32 images (float specialization).
// Validates that src has 3 channels, dst has 1 channel and that both images
// have the same number of rows and columns, then forwards to the SSE
// implementation cvt_bgr2gray_32f_SSE_4_2.
template <>
void cvt_bgr2gray<float>(const Mat32f& src, Mat32f& dst) {
megdnn_assert(src.channels() == 3);
megdnn_assert(dst.channels() == 1);
megdnn_assert(src.rows() == dst.rows());
megdnn_assert(src.cols() == dst.cols());
return cvt_bgr2gray_32f_SSE_4_2(src, dst);
}
template <> template <>
void cvt_bgr2rgb<uchar>(const Mat8u& src, Mat8u& dst) { void cvt_bgr2rgb<uchar>(const Mat8u& src, Mat8u& dst) {
return cvt_rgb2bgr<uchar>(src, dst); return cvt_rgb2bgr<uchar>(src, dst);
......
...@@ -133,6 +133,9 @@ inline std::vector<TestArg> get_cuda_args() { ...@@ -133,6 +133,9 @@ inline std::vector<TestArg> get_cuda_args() {
for (size_t i = 2; i <= 10; ++i) { for (size_t i = 2; i <= 10; ++i) {
for (size_t j = 2; j <= 10; ++j) { for (size_t j = 2; j <= 10; ++j) {
cur_param.mode = Mode::RGB2GRAY; cur_param.mode = Mode::RGB2GRAY;
args.emplace_back(cur_param, TensorShape{1, i, j, 3},
dtype::Uint8());
cur_param.mode = Mode::BGR2GRAY;
args.emplace_back(cur_param, TensorShape{1, i, j, 3}, args.emplace_back(cur_param, TensorShape{1, i, j, 3},
dtype::Uint8()); dtype::Uint8());
cur_param.mode = Mode::RGB2YUV; cur_param.mode = Mode::RGB2YUV;
...@@ -146,6 +149,9 @@ inline std::vector<TestArg> get_cuda_args() { ...@@ -146,6 +149,9 @@ inline std::vector<TestArg> get_cuda_args() {
dtype::Uint8()); dtype::Uint8());
// float32 test // float32 test
cur_param.mode = Mode::RGB2GRAY; cur_param.mode = Mode::RGB2GRAY;
args.emplace_back(cur_param, TensorShape{1, i, j, 3},
dtype::Float32());
cur_param.mode = Mode::BGR2GRAY;
args.emplace_back(cur_param, TensorShape{1, i, j, 3}, args.emplace_back(cur_param, TensorShape{1, i, j, 3},
dtype::Float32()); dtype::Float32());
cur_param.mode = Mode::RGB2YUV; cur_param.mode = Mode::RGB2YUV;
......
...@@ -1057,12 +1057,19 @@ def test_cvt_color(): ...@@ -1057,12 +1057,19 @@ def test_cvt_color():
def rgb2gray(rgb): def rgb2gray(rgb):
return np.dot(rgb[..., :3], [0.299, 0.587, 0.114]) return np.dot(rgb[..., :3], [0.299, 0.587, 0.114])
def bgr2gray(bgr):
    """Reference BGR -> gray conversion used to check the operator output.

    Takes the first three entries along the last axis of ``bgr`` and returns
    their weighted sum with weights 0.114, 0.587 and 0.299 (in that order),
    which removes the last axis.
    """
    weights = [0.114, 0.587, 0.299]
    return np.dot(bgr[..., :3], weights)
inp = np.random.randn(3, 3, 3, 3).astype(np.float32) inp = np.random.randn(3, 3, 3, 3).astype(np.float32)
out = np.expand_dims(rgb2gray(inp), 3).astype(np.float32) out = np.expand_dims(rgb2gray(inp), 3).astype(np.float32)
x = tensor(inp) x = tensor(inp)
y = F.vision.cvt_color(x, mode="RGB2GRAY") y = F.vision.cvt_color(x, mode="RGB2GRAY")
np.testing.assert_allclose(y.numpy(), out, atol=1e-5) np.testing.assert_allclose(y.numpy(), out, atol=1e-5)
out1 = np.expand_dims(bgr2gray(inp), 3).astype(np.float32)
y1 = F.vision.cvt_color(x, mode="BGR2GRAY")
np.testing.assert_allclose(y1.numpy(), out1, atol=1e-5)
@pytest.mark.parametrize("val", [2, [2,], [2, 3]]) @pytest.mark.parametrize("val", [2, [2,], [2, 3]])
def test_ones(val): def test_ones(val):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册