提交 4d9d7e6d 编写于 作者: V Vadim Pisarevsky

Merge pull request #3160 from akarsakov:ocl_dft_double_support

......@@ -1802,11 +1802,14 @@ private:
String buildOptions;
int thread_count;
int dft_size;
int dft_depth;
bool status;
public:
OCL_FftPlan(int _size) : dft_size(_size), status(true)
OCL_FftPlan(int _size, int _depth) : dft_size(_size), dft_depth(_depth), status(true)
{
CV_Assert( dft_depth == CV_32F || dft_depth == CV_64F );
int min_radix;
std::vector<int> radixes, blocks;
ocl_getRadixes(dft_size, radixes, blocks, min_radix);
......@@ -1832,31 +1835,15 @@ public:
n *= radix;
}
twiddles.create(1, twiddle_size, CV_32FC2);
Mat tw = twiddles.getMat(ACCESS_WRITE);
float* ptr = tw.ptr<float>();
int ptr_index = 0;
n = 1;
for (size_t i=0; i<radixes.size(); i++)
{
int radix = radixes[i];
n *= radix;
for (int j=1; j<radix; j++)
{
double theta = -CV_2PI*j/n;
for (int k=0; k<(n/radix); k++)
{
ptr[ptr_index++] = (float) cos(k*theta);
ptr[ptr_index++] = (float) sin(k*theta);
}
}
}
twiddles.create(1, twiddle_size, CV_MAKE_TYPE(dft_depth, 2));
if (dft_depth == CV_32F)
fillRadixTable<float>(twiddles, radixes);
else
fillRadixTable<double>(twiddles, radixes);
buildOptions = format("-D LOCAL_SIZE=%d -D kercn=%d -D RADIX_PROCESS=%s",
dft_size, min_radix, radix_processing.c_str());
buildOptions = format("-D LOCAL_SIZE=%d -D kercn=%d -D FT=%s -D CT=%s%s -D RADIX_PROCESS=%s",
dft_size, min_radix, ocl::typeToStr(dft_depth), ocl::typeToStr(CV_MAKE_TYPE(dft_depth, 2)),
dft_depth == CV_64F ? " -D DOUBLE_SUPPORT" : "", radix_processing.c_str());
}
bool enqueueTransform(InputArray _src, OutputArray _dst, int num_dfts, int flags, int fftType, bool rows = true) const
......@@ -1913,7 +1900,7 @@ public:
if (k.empty())
return false;
k.args(ocl::KernelArg::ReadOnly(src), ocl::KernelArg::WriteOnly(dst), ocl::KernelArg::PtrReadOnly(twiddles), thread_count, num_dfts);
k.args(ocl::KernelArg::ReadOnly(src), ocl::KernelArg::WriteOnly(dst), ocl::KernelArg::ReadOnlyNoSize(twiddles), thread_count, num_dfts);
return k.run(2, globalsize, localsize, false);
}
......@@ -1986,6 +1973,32 @@ private:
min_radix = min(min_radix, block*radix);
}
}
template <typename T>
static void fillRadixTable(UMat twiddles, const std::vector<int>& radixes)
{
Mat tw = twiddles.getMat(ACCESS_WRITE);
T* ptr = tw.ptr<T>();
int ptr_index = 0;
int n = 1;
for (size_t i=0; i<radixes.size(); i++)
{
int radix = radixes[i];
n *= radix;
for (int j=1; j<radix; j++)
{
double theta = -CV_2PI*j/n;
for (int k=0; k<(n/radix); k++)
{
ptr[ptr_index++] = (T) cos(k*theta);
ptr[ptr_index++] = (T) sin(k*theta);
}
}
}
}
};
class OCL_FftPlanCache
......@@ -1997,17 +2010,18 @@ public:
return planCache;
}
Ptr<OCL_FftPlan> getFftPlan(int dft_size)
Ptr<OCL_FftPlan> getFftPlan(int dft_size, int depth)
{
std::map<int, Ptr<OCL_FftPlan> >::iterator f = planStorage.find(dft_size);
int key = (dft_size << 16) | (depth & 0xFFFF);
std::map<int, Ptr<OCL_FftPlan> >::iterator f = planStorage.find(key);
if (f != planStorage.end())
{
return f->second;
}
else
{
Ptr<OCL_FftPlan> newPlan = Ptr<OCL_FftPlan>(new OCL_FftPlan(dft_size));
planStorage[dft_size] = newPlan;
Ptr<OCL_FftPlan> newPlan = Ptr<OCL_FftPlan>(new OCL_FftPlan(dft_size, depth));
planStorage[key] = newPlan;
return newPlan;
}
}
......@@ -2027,21 +2041,25 @@ protected:
static bool ocl_dft_rows(InputArray _src, OutputArray _dst, int nonzero_rows, int flags, int fftType)
{
Ptr<OCL_FftPlan> plan = OCL_FftPlanCache::getInstance().getFftPlan(_src.cols());
int type = _src.type(), depth = CV_MAT_DEPTH(type);
Ptr<OCL_FftPlan> plan = OCL_FftPlanCache::getInstance().getFftPlan(_src.cols(), depth);
return plan->enqueueTransform(_src, _dst, nonzero_rows, flags, fftType, true);
}
static bool ocl_dft_cols(InputArray _src, OutputArray _dst, int nonzero_cols, int flags, int fftType)
{
Ptr<OCL_FftPlan> plan = OCL_FftPlanCache::getInstance().getFftPlan(_src.rows());
int type = _src.type(), depth = CV_MAT_DEPTH(type);
Ptr<OCL_FftPlan> plan = OCL_FftPlanCache::getInstance().getFftPlan(_src.rows(), depth);
return plan->enqueueTransform(_src, _dst, nonzero_cols, flags, fftType, false);
}
static bool ocl_dft(InputArray _src, OutputArray _dst, int flags, int nonzero_rows)
{
int type = _src.type(), cn = CV_MAT_CN(type);
int type = _src.type(), cn = CV_MAT_CN(type), depth = CV_MAT_DEPTH(type);
Size ssize = _src.size();
if ( !(type == CV_32FC1 || type == CV_32FC2) )
bool doubleSupport = ocl::Device::getDefault().doubleFPConfig() > 0;
if ( !((cn == 1 || cn == 2) && (depth == CV_32F || (depth == CV_64F && doubleSupport))) )
return false;
// if is not a multiplication of prime numbers { 2, 3, 5 }
......@@ -2082,7 +2100,7 @@ static bool ocl_dft(InputArray _src, OutputArray _dst, int flags, int nonzero_ro
if (fftType == C2C || fftType == R2C)
{
// complex output
_dst.create(src.size(), CV_32FC2);
_dst.create(src.size(), CV_MAKETYPE(depth, 2));
output = _dst.getUMat();
}
else
......@@ -2090,13 +2108,13 @@ static bool ocl_dft(InputArray _src, OutputArray _dst, int flags, int nonzero_ro
// real output
if (is1d)
{
_dst.create(src.size(), CV_32FC1);
_dst.create(src.size(), CV_MAKETYPE(depth, 1));
output = _dst.getUMat();
}
else
{
_dst.create(src.size(), CV_32FC1);
output.create(src.size(), CV_32FC2);
_dst.create(src.size(), CV_MAKETYPE(depth, 1));
output.create(src.size(), CV_MAKETYPE(depth, 2));
}
}
......
此差异已折叠。
......@@ -62,7 +62,7 @@ namespace ocl {
////////////////////////////////////////////////////////////////////////////
// Dft
PARAM_TEST_CASE(Dft, cv::Size, OCL_FFT_TYPE, bool, bool, bool, bool)
PARAM_TEST_CASE(Dft, cv::Size, OCL_FFT_TYPE, MatDepth, bool, bool, bool, bool)
{
cv::Size dft_size;
int dft_flags, depth, cn, dft_type;
......@@ -76,7 +76,7 @@ PARAM_TEST_CASE(Dft, cv::Size, OCL_FFT_TYPE, bool, bool, bool, bool)
{
dft_size = GET_PARAM(0);
dft_type = GET_PARAM(1);
depth = CV_32F;
depth = GET_PARAM(2);
dft_flags = 0;
switch (dft_type)
......@@ -87,13 +87,13 @@ PARAM_TEST_CASE(Dft, cv::Size, OCL_FFT_TYPE, bool, bool, bool, bool)
case C2C: dft_flags |= cv::DFT_COMPLEX_OUTPUT; cn = 2; break;
}
if (GET_PARAM(2))
dft_flags |= cv::DFT_INVERSE;
if (GET_PARAM(3))
dft_flags |= cv::DFT_ROWS;
dft_flags |= cv::DFT_INVERSE;
if (GET_PARAM(4))
dft_flags |= cv::DFT_ROWS;
if (GET_PARAM(5))
dft_flags |= cv::DFT_SCALE;
hint = GET_PARAM(5);
hint = GET_PARAM(6);
is1d = (dft_flags & DFT_ROWS) != 0 || dft_size.height == 1;
}
......@@ -177,6 +177,7 @@ OCL_INSTANTIATE_TEST_CASE_P(OCL_ImgProc, MulSpectrums, testing::Combine(Bool(),
OCL_INSTANTIATE_TEST_CASE_P(Core, Dft, Combine(Values(cv::Size(45, 72), cv::Size(36, 36), cv::Size(512, 1), cv::Size(1280, 768)),
Values((OCL_FFT_TYPE) R2C, (OCL_FFT_TYPE) C2C, (OCL_FFT_TYPE) R2R, (OCL_FFT_TYPE) C2R),
Values(CV_32F, CV_64F),
Bool(), // DFT_INVERSE
Bool(), // DFT_ROWS
Bool(), // DFT_SCALE
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册