提交 4d9d7e6d 编写于 作者: V Vadim Pisarevsky

Merge pull request #3160 from akarsakov:ocl_dft_double_support

...@@ -1802,11 +1802,14 @@ private: ...@@ -1802,11 +1802,14 @@ private:
String buildOptions; String buildOptions;
int thread_count; int thread_count;
int dft_size; int dft_size;
int dft_depth;
bool status; bool status;
public: public:
OCL_FftPlan(int _size) : dft_size(_size), status(true) OCL_FftPlan(int _size, int _depth) : dft_size(_size), dft_depth(_depth), status(true)
{ {
CV_Assert( dft_depth == CV_32F || dft_depth == CV_64F );
int min_radix; int min_radix;
std::vector<int> radixes, blocks; std::vector<int> radixes, blocks;
ocl_getRadixes(dft_size, radixes, blocks, min_radix); ocl_getRadixes(dft_size, radixes, blocks, min_radix);
...@@ -1832,31 +1835,15 @@ public: ...@@ -1832,31 +1835,15 @@ public:
n *= radix; n *= radix;
} }
twiddles.create(1, twiddle_size, CV_32FC2); twiddles.create(1, twiddle_size, CV_MAKE_TYPE(dft_depth, 2));
Mat tw = twiddles.getMat(ACCESS_WRITE); if (dft_depth == CV_32F)
float* ptr = tw.ptr<float>(); fillRadixTable<float>(twiddles, radixes);
int ptr_index = 0; else
fillRadixTable<double>(twiddles, radixes);
n = 1;
for (size_t i=0; i<radixes.size(); i++)
{
int radix = radixes[i];
n *= radix;
for (int j=1; j<radix; j++)
{
double theta = -CV_2PI*j/n;
for (int k=0; k<(n/radix); k++)
{
ptr[ptr_index++] = (float) cos(k*theta);
ptr[ptr_index++] = (float) sin(k*theta);
}
}
}
buildOptions = format("-D LOCAL_SIZE=%d -D kercn=%d -D RADIX_PROCESS=%s", buildOptions = format("-D LOCAL_SIZE=%d -D kercn=%d -D FT=%s -D CT=%s%s -D RADIX_PROCESS=%s",
dft_size, min_radix, radix_processing.c_str()); dft_size, min_radix, ocl::typeToStr(dft_depth), ocl::typeToStr(CV_MAKE_TYPE(dft_depth, 2)),
dft_depth == CV_64F ? " -D DOUBLE_SUPPORT" : "", radix_processing.c_str());
} }
bool enqueueTransform(InputArray _src, OutputArray _dst, int num_dfts, int flags, int fftType, bool rows = true) const bool enqueueTransform(InputArray _src, OutputArray _dst, int num_dfts, int flags, int fftType, bool rows = true) const
...@@ -1913,7 +1900,7 @@ public: ...@@ -1913,7 +1900,7 @@ public:
if (k.empty()) if (k.empty())
return false; return false;
k.args(ocl::KernelArg::ReadOnly(src), ocl::KernelArg::WriteOnly(dst), ocl::KernelArg::PtrReadOnly(twiddles), thread_count, num_dfts); k.args(ocl::KernelArg::ReadOnly(src), ocl::KernelArg::WriteOnly(dst), ocl::KernelArg::ReadOnlyNoSize(twiddles), thread_count, num_dfts);
return k.run(2, globalsize, localsize, false); return k.run(2, globalsize, localsize, false);
} }
...@@ -1986,6 +1973,32 @@ private: ...@@ -1986,6 +1973,32 @@ private:
min_radix = min(min_radix, block*radix); min_radix = min(min_radix, block*radix);
} }
} }
template <typename T>
static void fillRadixTable(UMat twiddles, const std::vector<int>& radixes)
{
Mat tw = twiddles.getMat(ACCESS_WRITE);
T* ptr = tw.ptr<T>();
int ptr_index = 0;
int n = 1;
for (size_t i=0; i<radixes.size(); i++)
{
int radix = radixes[i];
n *= radix;
for (int j=1; j<radix; j++)
{
double theta = -CV_2PI*j/n;
for (int k=0; k<(n/radix); k++)
{
ptr[ptr_index++] = (T) cos(k*theta);
ptr[ptr_index++] = (T) sin(k*theta);
}
}
}
}
}; };
class OCL_FftPlanCache class OCL_FftPlanCache
...@@ -1997,17 +2010,18 @@ public: ...@@ -1997,17 +2010,18 @@ public:
return planCache; return planCache;
} }
Ptr<OCL_FftPlan> getFftPlan(int dft_size) Ptr<OCL_FftPlan> getFftPlan(int dft_size, int depth)
{ {
std::map<int, Ptr<OCL_FftPlan> >::iterator f = planStorage.find(dft_size); int key = (dft_size << 16) | (depth & 0xFFFF);
std::map<int, Ptr<OCL_FftPlan> >::iterator f = planStorage.find(key);
if (f != planStorage.end()) if (f != planStorage.end())
{ {
return f->second; return f->second;
} }
else else
{ {
Ptr<OCL_FftPlan> newPlan = Ptr<OCL_FftPlan>(new OCL_FftPlan(dft_size)); Ptr<OCL_FftPlan> newPlan = Ptr<OCL_FftPlan>(new OCL_FftPlan(dft_size, depth));
planStorage[dft_size] = newPlan; planStorage[key] = newPlan;
return newPlan; return newPlan;
} }
} }
...@@ -2027,21 +2041,25 @@ protected: ...@@ -2027,21 +2041,25 @@ protected:
static bool ocl_dft_rows(InputArray _src, OutputArray _dst, int nonzero_rows, int flags, int fftType) static bool ocl_dft_rows(InputArray _src, OutputArray _dst, int nonzero_rows, int flags, int fftType)
{ {
Ptr<OCL_FftPlan> plan = OCL_FftPlanCache::getInstance().getFftPlan(_src.cols()); int type = _src.type(), depth = CV_MAT_DEPTH(type);
Ptr<OCL_FftPlan> plan = OCL_FftPlanCache::getInstance().getFftPlan(_src.cols(), depth);
return plan->enqueueTransform(_src, _dst, nonzero_rows, flags, fftType, true); return plan->enqueueTransform(_src, _dst, nonzero_rows, flags, fftType, true);
} }
static bool ocl_dft_cols(InputArray _src, OutputArray _dst, int nonzero_cols, int flags, int fftType) static bool ocl_dft_cols(InputArray _src, OutputArray _dst, int nonzero_cols, int flags, int fftType)
{ {
Ptr<OCL_FftPlan> plan = OCL_FftPlanCache::getInstance().getFftPlan(_src.rows()); int type = _src.type(), depth = CV_MAT_DEPTH(type);
Ptr<OCL_FftPlan> plan = OCL_FftPlanCache::getInstance().getFftPlan(_src.rows(), depth);
return plan->enqueueTransform(_src, _dst, nonzero_cols, flags, fftType, false); return plan->enqueueTransform(_src, _dst, nonzero_cols, flags, fftType, false);
} }
static bool ocl_dft(InputArray _src, OutputArray _dst, int flags, int nonzero_rows) static bool ocl_dft(InputArray _src, OutputArray _dst, int flags, int nonzero_rows)
{ {
int type = _src.type(), cn = CV_MAT_CN(type); int type = _src.type(), cn = CV_MAT_CN(type), depth = CV_MAT_DEPTH(type);
Size ssize = _src.size(); Size ssize = _src.size();
if ( !(type == CV_32FC1 || type == CV_32FC2) ) bool doubleSupport = ocl::Device::getDefault().doubleFPConfig() > 0;
if ( !((cn == 1 || cn == 2) && (depth == CV_32F || (depth == CV_64F && doubleSupport))) )
return false; return false;
// if is not a multiplication of prime numbers { 2, 3, 5 } // if is not a multiplication of prime numbers { 2, 3, 5 }
...@@ -2082,7 +2100,7 @@ static bool ocl_dft(InputArray _src, OutputArray _dst, int flags, int nonzero_ro ...@@ -2082,7 +2100,7 @@ static bool ocl_dft(InputArray _src, OutputArray _dst, int flags, int nonzero_ro
if (fftType == C2C || fftType == R2C) if (fftType == C2C || fftType == R2C)
{ {
// complex output // complex output
_dst.create(src.size(), CV_32FC2); _dst.create(src.size(), CV_MAKETYPE(depth, 2));
output = _dst.getUMat(); output = _dst.getUMat();
} }
else else
...@@ -2090,13 +2108,13 @@ static bool ocl_dft(InputArray _src, OutputArray _dst, int flags, int nonzero_ro ...@@ -2090,13 +2108,13 @@ static bool ocl_dft(InputArray _src, OutputArray _dst, int flags, int nonzero_ro
// real output // real output
if (is1d) if (is1d)
{ {
_dst.create(src.size(), CV_32FC1); _dst.create(src.size(), CV_MAKETYPE(depth, 1));
output = _dst.getUMat(); output = _dst.getUMat();
} }
else else
{ {
_dst.create(src.size(), CV_32FC1); _dst.create(src.size(), CV_MAKETYPE(depth, 1));
output.create(src.size(), CV_32FC2); output.create(src.size(), CV_MAKETYPE(depth, 2));
} }
} }
......
此差异已折叠。
...@@ -62,7 +62,7 @@ namespace ocl { ...@@ -62,7 +62,7 @@ namespace ocl {
//////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////
// Dft // Dft
PARAM_TEST_CASE(Dft, cv::Size, OCL_FFT_TYPE, bool, bool, bool, bool) PARAM_TEST_CASE(Dft, cv::Size, OCL_FFT_TYPE, MatDepth, bool, bool, bool, bool)
{ {
cv::Size dft_size; cv::Size dft_size;
int dft_flags, depth, cn, dft_type; int dft_flags, depth, cn, dft_type;
...@@ -76,7 +76,7 @@ PARAM_TEST_CASE(Dft, cv::Size, OCL_FFT_TYPE, bool, bool, bool, bool) ...@@ -76,7 +76,7 @@ PARAM_TEST_CASE(Dft, cv::Size, OCL_FFT_TYPE, bool, bool, bool, bool)
{ {
dft_size = GET_PARAM(0); dft_size = GET_PARAM(0);
dft_type = GET_PARAM(1); dft_type = GET_PARAM(1);
depth = CV_32F; depth = GET_PARAM(2);
dft_flags = 0; dft_flags = 0;
switch (dft_type) switch (dft_type)
...@@ -87,13 +87,13 @@ PARAM_TEST_CASE(Dft, cv::Size, OCL_FFT_TYPE, bool, bool, bool, bool) ...@@ -87,13 +87,13 @@ PARAM_TEST_CASE(Dft, cv::Size, OCL_FFT_TYPE, bool, bool, bool, bool)
case C2C: dft_flags |= cv::DFT_COMPLEX_OUTPUT; cn = 2; break; case C2C: dft_flags |= cv::DFT_COMPLEX_OUTPUT; cn = 2; break;
} }
if (GET_PARAM(2))
dft_flags |= cv::DFT_INVERSE;
if (GET_PARAM(3)) if (GET_PARAM(3))
dft_flags |= cv::DFT_ROWS; dft_flags |= cv::DFT_INVERSE;
if (GET_PARAM(4)) if (GET_PARAM(4))
dft_flags |= cv::DFT_ROWS;
if (GET_PARAM(5))
dft_flags |= cv::DFT_SCALE; dft_flags |= cv::DFT_SCALE;
hint = GET_PARAM(5); hint = GET_PARAM(6);
is1d = (dft_flags & DFT_ROWS) != 0 || dft_size.height == 1; is1d = (dft_flags & DFT_ROWS) != 0 || dft_size.height == 1;
} }
...@@ -177,6 +177,7 @@ OCL_INSTANTIATE_TEST_CASE_P(OCL_ImgProc, MulSpectrums, testing::Combine(Bool(), ...@@ -177,6 +177,7 @@ OCL_INSTANTIATE_TEST_CASE_P(OCL_ImgProc, MulSpectrums, testing::Combine(Bool(),
OCL_INSTANTIATE_TEST_CASE_P(Core, Dft, Combine(Values(cv::Size(45, 72), cv::Size(36, 36), cv::Size(512, 1), cv::Size(1280, 768)), OCL_INSTANTIATE_TEST_CASE_P(Core, Dft, Combine(Values(cv::Size(45, 72), cv::Size(36, 36), cv::Size(512, 1), cv::Size(1280, 768)),
Values((OCL_FFT_TYPE) R2C, (OCL_FFT_TYPE) C2C, (OCL_FFT_TYPE) R2R, (OCL_FFT_TYPE) C2R), Values((OCL_FFT_TYPE) R2C, (OCL_FFT_TYPE) C2C, (OCL_FFT_TYPE) R2R, (OCL_FFT_TYPE) C2R),
Values(CV_32F, CV_64F),
Bool(), // DFT_INVERSE Bool(), // DFT_INVERSE
Bool(), // DFT_ROWS Bool(), // DFT_ROWS
Bool(), // DFT_SCALE Bool(), // DFT_SCALE
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册