提交 88bd1f1d 编写于 作者: I Ilya Lysenkov

Added plot data generation for visual descriptors comparison in the evaluation framework

上级 f6f634ba
......@@ -1503,6 +1503,25 @@ struct CV_EXPORTS L2
}
};
/****************************************************************************************\
* DescriptorMatching *
\****************************************************************************************/
/*
* Struct for matching: match index and distance between descriptors
*/
struct DescriptorMatching
{
int index;
float distance;
//less is better
bool operator<( const DescriptorMatching &m) const
{
return distance < m.distance;
}
};
/****************************************************************************************\
* DescriptorMatcher *
\****************************************************************************************/
......@@ -1545,6 +1564,28 @@ public:
void match( const Mat& query, const Mat& mask,
vector<int>& matches ) const;
/*
* Find the best match for each descriptor from a query set
*
* query The query set of descriptors
* matchings Matchings of the closest matches from the training set
*/
void match( const Mat& query, vector<DescriptorMatching>& matchings ) const;
/*
* Find the best matches between two descriptor sets, with constraints
* on which pairs of descriptors can be matched.
*
* The mask describes which descriptors can be matched. descriptors_1[i]
* can be matched with descriptors_2[j] only if mask.at<char>(i,j) is non-zero.
*
* query The query set of descriptors
* mask Mask specifying permissible matches.
* matchings Matchings of the closest matches from the training set
*/
void match( const Mat& query, const Mat& mask,
vector<DescriptorMatching>& matchings ) const;
/*
* Find the best keypoint matches for small view changes.
*
......@@ -1574,6 +1615,13 @@ protected:
virtual void matchImpl( const Mat& descriptors_1, const Mat& descriptors_2,
const Mat& mask, vector<int>& matches ) const = 0;
/*
* Find matches; match() calls this. Must be implemented by the subclass.
* The mask may be empty.
*/
virtual void matchImpl( const Mat& descriptors_1, const Mat& descriptors_2,
const Mat& mask, vector<DescriptorMatching>& matches ) const = 0;
static bool possibleMatch( const Mat& mask, int index_1, int index_2 )
{
return mask.empty() || mask.at<char>(index_1, index_2);
......@@ -1609,6 +1657,18 @@ inline void DescriptorMatcher::match( const Mat& query, const Mat& mask,
matchImpl( query, train, mask, matches );
}
inline void DescriptorMatcher::match( const Mat& query, vector<DescriptorMatching>& matches ) const
{
matchImpl( query, train, Mat(), matches );
}
inline void DescriptorMatcher::match( const Mat& query, const Mat& mask,
vector<DescriptorMatching>& matches ) const
{
matchImpl( query, train, mask, matches );
}
inline void DescriptorMatcher::clear()
{
train.release();
......@@ -1633,12 +1693,28 @@ protected:
virtual void matchImpl( const Mat& descriptors_1, const Mat& descriptors_2,
const Mat& mask, vector<int>& matches ) const;
virtual void matchImpl( const Mat& descriptors_1, const Mat& descriptors_2,
const Mat& mask, vector<DescriptorMatching>& matches ) const;
Distance distance;
};
template<class Distance>
void BruteForceMatcher<Distance>::matchImpl( const Mat& descriptors_1, const Mat& descriptors_2,
const Mat& mask, vector<int>& matches ) const
{
vector<DescriptorMatching> matchings;
matchImpl( descriptors_1, descriptors_2, mask, matchings);
matches.resize( matchings.size() );
for( size_t i=0;i<matchings.size();i++)
{
matches[i] = matchings[i].index;
}
}
template<class Distance>
void BruteForceMatcher<Distance>::matchImpl( const Mat& descriptors_1, const Mat& descriptors_2,
const Mat& mask, vector<DescriptorMatching>& matches ) const
{
typedef typename Distance::ValueType ValueType;
typedef typename Distance::ResultType DistanceType;
......@@ -1650,8 +1726,7 @@ void BruteForceMatcher<Distance>::matchImpl( const Mat& descriptors_1, const Mat
assert( DataType<ValueType>::type == descriptors_2.type() || descriptors_2.empty() );
int dimension = descriptors_1.cols;
matches.clear();
matches.reserve(descriptors_1.rows);
matches.resize(descriptors_1.rows);
for( int i = 0; i < descriptors_1.rows; i++ )
{
......@@ -1674,7 +1749,12 @@ void BruteForceMatcher<Distance>::matchImpl( const Mat& descriptors_1, const Mat
}
if( matchIndex != -1 )
matches.push_back( matchIndex );
{
DescriptorMatching matching;
matching.index = matchIndex;
matching.distance = matchDistance;
matches[i] = matching;
}
}
}
......@@ -1742,6 +1822,12 @@ public:
// indices A vector to be filled with keypoint class indices
virtual void match( const Mat& image, vector<KeyPoint>& points, vector<int>& indices ) = 0;
// Matches test keypoints to the training set
// image The source image
// points Test keypoints from the source image
// matchings A vector to be filled with keypoint matchings
virtual void match( const Mat& image, vector<KeyPoint>& points, vector<DescriptorMatching>& matchings ) {};
// Clears keypoints storing in collection
virtual void clear();
......@@ -1816,6 +1902,8 @@ public:
// loaded with DescriptorOneWay::Initialize, kd tree is used for finding minimum distances.
virtual void match( const Mat& image, vector<KeyPoint>& points, vector<int>& indices );
virtual void match( const Mat& image, vector<KeyPoint>& points, vector<DescriptorMatching>& matchings );
// Classify a set of keypoints. The same as match, but returns point classes rather than indices
virtual void classify( const Mat& image, vector<KeyPoint>& points );
......@@ -1944,6 +2032,8 @@ public:
virtual void match( const Mat& image, vector<KeyPoint>& keypoints, vector<int>& indices );
virtual void match( const Mat& image, vector<KeyPoint>& points, vector<DescriptorMatching>& matchings );
virtual void classify( const Mat& image, vector<KeyPoint>& keypoints );
virtual void clear ();
......@@ -2000,6 +2090,14 @@ public:
matcher.match( descriptors, keypointIndices );
};
virtual void match( const Mat& image, vector<KeyPoint>& points, vector<DescriptorMatching>& matchings )
{
Mat descriptors;
extractor.compute( image, points, descriptors );
matcher.match( descriptors, matchings );
}
virtual void clear()
{
GenericDescriptorMatch::clear();
......
......@@ -44,6 +44,8 @@
using namespace std;
using namespace cv;
//#define _KDTREE
/****************************************************************************************\
* DescriptorExtractor *
\****************************************************************************************/
......@@ -332,15 +334,27 @@ void OneWayDescriptorMatch::add( KeyPointCollection& keypoints )
void OneWayDescriptorMatch::match( const Mat& image, vector<KeyPoint>& points, vector<int>& indices)
{
vector<DescriptorMatching> matchings( points.size() );
indices.resize(points.size());
match( image, points, matchings );
for( size_t i = 0; i < points.size(); i++ )
indices[i] = matchings[i].index;
}
void OneWayDescriptorMatch::match( const Mat& image, vector<KeyPoint>& points, vector<DescriptorMatching>& matchings )
{
matchings.resize( points.size() );
IplImage _image = image;
for( size_t i = 0; i < points.size(); i++ )
{
int descIdx = -1;
int poseIdx = -1;
float distance;
base->FindDescriptor( &_image, points[i].pt, descIdx, poseIdx, distance );
indices[i] = descIdx;
DescriptorMatching matching;
matching.index = -1;
base->FindDescriptor( &_image, points[i].pt, matching.index, poseIdx, matching.distance );
matchings[i] = matching;
}
}
......@@ -631,6 +645,21 @@ void FernDescriptorMatch::match( const Mat& image, vector<KeyPoint>& keypoints,
}
}
void FernDescriptorMatch::match( const Mat& image, vector<KeyPoint>& keypoints, vector<DescriptorMatching>& matchings )
{
trainFernClassifier();
matchings.resize( keypoints.size() );
vector<float> signature( (size_t)classifier->getClassCount() );
for( size_t pi = 0; pi < keypoints.size(); pi++ )
{
calcBestProbAndMatchIdx( image, keypoints[pi].pt, matchings[pi].distance, matchings[pi].index, signature );
//matching[pi].distance is log of probability so we need to transform it
matchings[pi].distance = -matchings[pi].distance;
}
}
void FernDescriptorMatch::classify( const Mat& image, vector<KeyPoint>& keypoints )
{
trainFernClassifier();
......
......@@ -549,9 +549,9 @@ inline float precision( int correctMatchCount, int falseMatchCount )
}
void evaluateDescriptors( const vector<EllipticKeyPoint>& keypoints1, const vector<EllipticKeyPoint>& keypoints2,
const vector<int>& matches1to2,
vector< pair<DescriptorMatching, int> >& matches1to2,
const Mat& img1, const Mat& img2, const Mat& H1to2,
int& correctMatchCount, int& falseMatchCount, int& correspondenceCount )
int &correctMatchCount, int &falseMatchCount, vector<int> &matchStatuses, int& correspondenceCount )
{
assert( !keypoints1.empty() && !keypoints2.empty() && !matches1to2.empty() );
assert( keypoints1.size() == matches1to2.size() );
......@@ -564,17 +564,27 @@ void evaluateDescriptors( const vector<EllipticKeyPoint>& keypoints1, const vect
repeatability, correspCount,
&thresholdedOverlapMask );
correspondenceCount = thresholdedOverlapMask.nzcount();
correctMatchCount = falseMatchCount = 0;
matchStatuses.resize( matches1to2.size() );
correctMatchCount = 0;
falseMatchCount = 0;
//the nearest descriptors should be examined first
std::sort( matches1to2.begin(), matches1to2.end() );
for( size_t i1 = 0; i1 < matches1to2.size(); i1++ )
{
int i2 = matches1to2[i1];
int i2 = matches1to2[i1].first.index;
if( i2 > 0 )
{
if( thresholdedOverlapMask(i1, i2) )
matchStatuses[i2] = thresholdedOverlapMask(matches1to2[i1].second, i2);
if( matchStatuses[i2] )
correctMatchCount++;
else
falseMatchCount++;
}
else
matchStatuses[i2] = -1;
}
}
......@@ -615,11 +625,16 @@ class BaseQualityTest : public CvTest
{
public:
BaseQualityTest( const char* _algName, const char* _testName, const char* _testFuncs ) :
CvTest( _testName, _testFuncs ), algName(_algName) {}
CvTest( _testName, _testFuncs ), algName(_algName)
{
//TODO: change this
isWriteGraphicsData = true;
}
protected:
virtual string getRunParamsFilename() const = 0;
virtual string getResultsFilename() const = 0;
virtual string getPlotPath() const = 0;
virtual void validQualityClear( int datasetIdx ) = 0;
virtual void calcQualityClear( int datasetIdx ) = 0;
......@@ -650,9 +665,11 @@ protected:
virtual void processResults();
virtual int processResults( int datasetIdx, int caseIdx ) = 0;
void writeAllPlotData() const;
virtual void writePlotData( const string &filename, int datasetIdx ) const {};
string algName;
bool isWriteParams, isWriteResults;
bool isWriteParams, isWriteResults, isWriteGraphicsData;
};
void BaseQualityTest::readAllDatasetsRunParams()
......@@ -811,6 +828,8 @@ void BaseQualityTest::processResults()
{
if( isWriteParams )
writeAllDatasetsRunParams();
if( isWriteGraphicsData )
writeAllPlotData();
int res = CvTS::OK;
if( isWriteResults )
......@@ -838,6 +857,18 @@ void BaseQualityTest::processResults()
ts->set_failed_test_info( res );
}
void BaseQualityTest::writeAllPlotData() const
{
for( int di = 0; di < DATASETS_COUNT; di++ )
{
stringstream stream;
stream << getPlotPath() << algName << "_" << DATASET_NAMES[di] << ".csv";
string filename;
stream >> filename;
writePlotData( filename, di );
}
}
void BaseQualityTest::run ( int )
{
readAlgorithm ();
......@@ -904,6 +935,7 @@ protected:
virtual string getRunParamsFilename() const;
virtual string getResultsFilename() const;
virtual string getPlotPath() const;
virtual void validQualityClear( int datasetIdx );
virtual void calcQualityClear( int datasetIdx );
......@@ -961,6 +993,11 @@ string DetectorQualityTest::getResultsFilename() const
return string(ts->get_data_path()) + DETECTORS_DIR + algName + RES_POSTFIX;
}
string DetectorQualityTest::getPlotPath() const
{
return string(ts->get_data_path()) + DETECTORS_DIR + "plots/";
}
void DetectorQualityTest::validQualityClear( int datasetIdx )
{
validQuality[datasetIdx].clear();
......@@ -1253,6 +1290,7 @@ public:
{
validQuality.resize(DATASETS_COUNT);
calcQuality.resize(DATASETS_COUNT);
calcDatasetQuality.resize(DATASETS_COUNT);
commRunParams.resize(DATASETS_COUNT);
commRunParamsDefault.projectKeypointsFrom1Image = true;
......@@ -1267,6 +1305,7 @@ protected:
virtual string getRunParamsFilename() const;
virtual string getResultsFilename() const;
virtual string getPlotPath() const;
virtual void validQualityClear( int datasetIdx );
virtual void calcQualityClear( int datasetIdx );
......@@ -1289,6 +1328,8 @@ protected:
virtual int processResults( int datasetIdx, int caseIdx );
virtual void writePlotData( const string &filename, int di ) const;
struct Quality
{
float recall;
......@@ -1296,6 +1337,7 @@ protected:
};
vector<vector<Quality> > validQuality;
vector<vector<Quality> > calcQuality;
vector<vector<Quality> > calcDatasetQuality;
struct CommonRunParams
{
......@@ -1322,6 +1364,11 @@ string DescriptorQualityTest::getResultsFilename() const
return string(ts->get_data_path()) + DESCRIPTORS_DIR + algName + RES_POSTFIX;
}
string DescriptorQualityTest::getPlotPath() const
{
return string(ts->get_data_path()) + DESCRIPTORS_DIR + "plots/";
}
void DescriptorQualityTest::validQualityClear( int datasetIdx )
{
validQuality[datasetIdx].clear();
......@@ -1408,6 +1455,16 @@ void DescriptorQualityTest::setDefaultDatasetRunParams( int datasetIdx )
commRunParams[datasetIdx].keypontsFilename = "surf_" + DATASET_NAMES[datasetIdx] + ".xml.gz";
}
void DescriptorQualityTest::writePlotData( const string &filename, int di ) const
{
FILE *file = fopen (filename.c_str(),"w");
size_t size = calcDatasetQuality[di].size();
for (size_t i=0;i<size;i++)
{
fprintf( file, "%f, %f\n", 1 - calcDatasetQuality[di][i].precision, calcDatasetQuality[di][i].recall);
}
fclose( file );
}
void DescriptorQualityTest::readAlgorithm( )
{
......@@ -1478,6 +1535,10 @@ void DescriptorQualityTest::runDatasetTest (const vector<Mat> &imgs, const vecto
transformToEllipticKeyPoints( keypoints1, ekeypoints1 );
int progressCount = DATASETS_COUNT*TEST_CASE_COUNT;
vector< pair<DescriptorMatching, int> > allMatchings;
vector<int> allMatchStatuses;
size_t matchingIndex = 0;
int allCorrespCount = 0;
for( int ci = 0; ci < TEST_CASE_COUNT; ci++ )
{
progress = update_progress( progress, di*TEST_CASE_COUNT + ci, progressCount, 0 );
......@@ -1494,16 +1555,50 @@ void DescriptorQualityTest::runDatasetTest (const vector<Mat> &imgs, const vecto
readKeypoints( keypontsFS, keypoints2, ci+1 );
transformToEllipticKeyPoints( keypoints2, ekeypoints2 );
descMatch->add( imgs[ci+1], keypoints2 );
vector<int> matches1to2;
descMatch->match( imgs[0], keypoints1, matches1to2 );
vector<DescriptorMatching> matchings1to2;
descMatch->match( imgs[0], keypoints1, matchings1to2 );
vector< pair<DescriptorMatching, int> > matchings (matchings1to2.size());
for( size_t i=0;i<matchings1to2.size();i++ )
matchings[i] = pair<DescriptorMatching, int>( matchings1to2[i], i);
// TODO if( commRunParams[di].matchFilter )
int correctMatchCount, falseMatchCount, correspCount;
evaluateDescriptors( ekeypoints1, ekeypoints2, matches1to2, imgs[0], imgs[ci+1], Hs[ci],
correctMatchCount, falseMatchCount, correspCount );
int correspCount;
int correctMatchCount = 0, falseMatchCount = 0;
vector<int> matchStatuses;
evaluateDescriptors( ekeypoints1, ekeypoints2, matchings, imgs[0], imgs[ci+1], Hs[ci],
correctMatchCount, falseMatchCount, matchStatuses, correspCount );
for( size_t i=0;i<matchings.size();i++ )
matchings[i].second += matchingIndex;
matchingIndex += matchings.size();
allCorrespCount += correspCount;
//TODO: use merge
std::copy( matchings.begin(), matchings.end(), std::back_inserter( allMatchings ) );
std::copy( matchStatuses.begin(), matchStatuses.end(), std::back_inserter( allMatchStatuses ) );
printf ("%d %d %d \n", correctMatchCount, falseMatchCount, correspCount );
calcQuality[di][ci].recall = recall( correctMatchCount, correspCount );
calcQuality[di][ci].precision = precision( correctMatchCount, falseMatchCount );
descMatch->clear ();
}
std::sort( allMatchings.begin(), allMatchings.end() );
calcDatasetQuality[di].resize( allMatchings.size() );
int correctMatchCount = 0, falseMatchCount = 0;
for( size_t i=0;i<allMatchings.size();i++)
{
if( allMatchStatuses[ allMatchings[i].second ] )
correctMatchCount++;
else
falseMatchCount++;
calcDatasetQuality[di][i].recall = recall( correctMatchCount, allCorrespCount );
calcDatasetQuality[di][i].precision = precision( correctMatchCount, falseMatchCount );
}
}
int DescriptorQualityTest::processResults( int datasetIdx, int caseIdx )
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册