提交 97092e31 编写于 作者: C Cosmin Boaca

Added result_probabilities parameter to CvNormalBayesClassifier::predict method. Issue #3401

上级 fff5a6c0
...@@ -52,12 +52,12 @@ CvNormalBayesClassifier::predict ...@@ -52,12 +52,12 @@ CvNormalBayesClassifier::predict
-------------------------------- --------------------------------
Predicts the response for sample(s). Predicts the response for sample(s).
.. ocv:function:: float CvNormalBayesClassifier::predict( const Mat& samples, Mat* results=0 ) const .. ocv:function:: float CvNormalBayesClassifier::predict( const Mat& samples, Mat* results=0, Mat* results_prob=0 ) const
.. ocv:function:: float CvNormalBayesClassifier::predict( const CvMat* samples, CvMat* results=0 ) const .. ocv:function:: float CvNormalBayesClassifier::predict( const CvMat* samples, CvMat* results=0, CvMat* results_prob=0 ) const
.. ocv:pyfunction:: cv2.NormalBayesClassifier.predict(samples) -> retval, results .. ocv:pyfunction:: cv2.NormalBayesClassifier.predict(samples) -> retval, results
The method estimates the most probable classes for input vectors. Input vectors (one or more) are stored as rows of the matrix ``samples``. In case of multiple input vectors, there should be one output vector ``results``. The predicted class for a single input vector is returned by the method. The method estimates the most probable classes for input vectors. Input vectors (one or more) are stored as rows of the matrix ``samples``. In case of multiple input vectors, there should be one output vector ``results``. The predicted class for a single input vector is returned by the method. The vector ``results_prob`` contains the output probabilities coresponding to each element of ``result``.
The function is parallelized with the TBB library. The function is parallelized with the TBB library.
...@@ -201,7 +201,7 @@ public: ...@@ -201,7 +201,7 @@ public:
virtual bool train( const CvMat* trainData, const CvMat* responses, virtual bool train( const CvMat* trainData, const CvMat* responses,
const CvMat* varIdx = 0, const CvMat* sampleIdx=0, bool update=false ); const CvMat* varIdx = 0, const CvMat* sampleIdx=0, bool update=false );
virtual float predict( const CvMat* samples, CV_OUT CvMat* results=0 ) const; virtual float predict( const CvMat* samples, CV_OUT CvMat* results=0, CV_OUT CvMat* results_prob=0 ) const;
CV_WRAP virtual void clear(); CV_WRAP virtual void clear();
CV_WRAP CvNormalBayesClassifier( const cv::Mat& trainData, const cv::Mat& responses, CV_WRAP CvNormalBayesClassifier( const cv::Mat& trainData, const cv::Mat& responses,
...@@ -209,7 +209,7 @@ public: ...@@ -209,7 +209,7 @@ public:
CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses, CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses,
const cv::Mat& varIdx = cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varIdx = cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(),
bool update=false ); bool update=false );
CV_WRAP virtual float predict( const cv::Mat& samples, CV_OUT cv::Mat* results=0 ) const; CV_WRAP virtual float predict( const cv::Mat& samples, CV_OUT cv::Mat* results=0, CV_OUT cv::Mat* results_prob=0 ) const;
virtual void write( CvFileStorage* storage, const char* name ) const; virtual void write( CvFileStorage* storage, const char* name ) const;
virtual void read( CvFileStorage* storage, CvFileNode* node ); virtual void read( CvFileStorage* storage, CvFileNode* node );
......
...@@ -282,7 +282,7 @@ bool CvNormalBayesClassifier::train( const CvMat* _train_data, const CvMat* _res ...@@ -282,7 +282,7 @@ bool CvNormalBayesClassifier::train( const CvMat* _train_data, const CvMat* _res
struct predict_body : cv::ParallelLoopBody { struct predict_body : cv::ParallelLoopBody {
predict_body(CvMat* _c, CvMat** _cov_rotate_mats, CvMat** _inv_eigen_values, CvMat** _avg, predict_body(CvMat* _c, CvMat** _cov_rotate_mats, CvMat** _inv_eigen_values, CvMat** _avg,
const CvMat* _samples, const int* _vidx, CvMat* _cls_labels, const CvMat* _samples, const int* _vidx, CvMat* _cls_labels,
CvMat* _results, float* _value, int _var_count1 CvMat* _results, float* _value, int _var_count1, CvMat* _results_prob
) )
{ {
c = _c; c = _c;
...@@ -295,6 +295,7 @@ struct predict_body : cv::ParallelLoopBody { ...@@ -295,6 +295,7 @@ struct predict_body : cv::ParallelLoopBody {
results = _results; results = _results;
value = _value; value = _value;
var_count1 = _var_count1; var_count1 = _var_count1;
results_prob = _results_prob;
} }
CvMat* c; CvMat* c;
...@@ -305,6 +306,7 @@ struct predict_body : cv::ParallelLoopBody { ...@@ -305,6 +306,7 @@ struct predict_body : cv::ParallelLoopBody {
const int* vidx; const int* vidx;
CvMat* cls_labels; CvMat* cls_labels;
CvMat* results_prob;
CvMat* results; CvMat* results;
float* value; float* value;
int var_count1; int var_count1;
...@@ -313,15 +315,21 @@ struct predict_body : cv::ParallelLoopBody { ...@@ -313,15 +315,21 @@ struct predict_body : cv::ParallelLoopBody {
{ {
int cls = -1; int cls = -1;
int rtype = 0, rstep = 0; int rtype = 0, rstep = 0, rptype = 0, rpstep = 0;
int nclasses = cls_labels->cols; int nclasses = cls_labels->cols;
int _var_count = avg[0]->cols; int _var_count = avg[0]->cols;
double probability = 0;
if (results) if (results)
{ {
rtype = CV_MAT_TYPE(results->type); rtype = CV_MAT_TYPE(results->type);
rstep = CV_IS_MAT_CONT(results->type) ? 1 : results->step/CV_ELEM_SIZE(rtype); rstep = CV_IS_MAT_CONT(results->type) ? 1 : results->step/CV_ELEM_SIZE(rtype);
} }
if (results_prob)
{
rptype = CV_MAT_TYPE(results_prob->type);
rpstep = CV_IS_MAT_CONT(results_prob->type) ? 1 : results_prob->step/CV_ELEM_SIZE(rptype);
}
// allocate memory and initializing headers for calculating // allocate memory and initializing headers for calculating
cv::AutoBuffer<double> buffer(nclasses + var_count1); cv::AutoBuffer<double> buffer(nclasses + var_count1);
CvMat diff = cvMat( 1, var_count1, CV_64FC1, &buffer[0] ); CvMat diff = cvMat( 1, var_count1, CV_64FC1, &buffer[0] );
...@@ -333,7 +341,6 @@ struct predict_body : cv::ParallelLoopBody { ...@@ -333,7 +341,6 @@ struct predict_body : cv::ParallelLoopBody {
for(int i = 0; i < nclasses; i++ ) for(int i = 0; i < nclasses; i++ )
{ {
double cur = c->data.db[i]; double cur = c->data.db[i];
CvMat* u = cov_rotate_mats[i]; CvMat* u = cov_rotate_mats[i];
CvMat* w = inv_eigen_values[i]; CvMat* w = inv_eigen_values[i];
...@@ -358,6 +365,7 @@ struct predict_body : cv::ParallelLoopBody { ...@@ -358,6 +365,7 @@ struct predict_body : cv::ParallelLoopBody {
opt = cur; opt = cur;
} }
/* probability = exp( -0.5 * cur ) */ /* probability = exp( -0.5 * cur ) */
probability = exp( -0.5 * cur );
} }
ival = cls_labels->data.i[cls]; ival = cls_labels->data.i[cls];
...@@ -368,6 +376,13 @@ struct predict_body : cv::ParallelLoopBody { ...@@ -368,6 +376,13 @@ struct predict_body : cv::ParallelLoopBody {
else else
results->data.fl[k*rstep] = (float)ival; results->data.fl[k*rstep] = (float)ival;
} }
if ( results_prob )
{
if ( rptype == CV_32FC1 )
results_prob->data.fl[k*rpstep] = (float)probability;
else
results_prob->data.db[k*rpstep] = probability;
}
if( k == 0 ) if( k == 0 )
*value = (float)ival; *value = (float)ival;
} }
...@@ -375,7 +390,7 @@ struct predict_body : cv::ParallelLoopBody { ...@@ -375,7 +390,7 @@ struct predict_body : cv::ParallelLoopBody {
}; };
float CvNormalBayesClassifier::predict( const CvMat* samples, CvMat* results ) const float CvNormalBayesClassifier::predict( const CvMat* samples, CvMat* results, CvMat* results_prob ) const
{ {
float value = 0; float value = 0;
...@@ -397,11 +412,21 @@ float CvNormalBayesClassifier::predict( const CvMat* samples, CvMat* results ) c ...@@ -397,11 +412,21 @@ float CvNormalBayesClassifier::predict( const CvMat* samples, CvMat* results ) c
"with the number of elements = number of rows in the input matrix" ); "with the number of elements = number of rows in the input matrix" );
} }
if( results_prob )
{
if( !CV_IS_MAT(results_prob) || (CV_MAT_TYPE(results_prob->type) != CV_32FC1 &&
CV_MAT_TYPE(results_prob->type) != CV_64FC1) ||
(results_prob->cols != 1 && results_prob->rows != 1) ||
results_prob->cols + results_prob->rows - 1 != samples->rows )
CV_Error( CV_StsBadArg, "The output array must be double or float vector "
"with the number of elements = number of rows in the input matrix" );
}
const int* vidx = var_idx ? var_idx->data.i : 0; const int* vidx = var_idx ? var_idx->data.i : 0;
cv::parallel_for_(cv::Range(0, samples->rows), cv::parallel_for_(cv::Range(0, samples->rows),
predict_body(c, cov_rotate_mats, inv_eigen_values, avg, samples, predict_body(c, cov_rotate_mats, inv_eigen_values, avg, samples,
vidx, cls_labels, results, &value, var_count)); vidx, cls_labels, results, &value, var_count, results_prob));
return value; return value;
} }
...@@ -608,9 +633,9 @@ bool CvNormalBayesClassifier::train( const Mat& _train_data, const Mat& _respons ...@@ -608,9 +633,9 @@ bool CvNormalBayesClassifier::train( const Mat& _train_data, const Mat& _respons
sidx.data.ptr ? &sidx : 0, update); sidx.data.ptr ? &sidx : 0, update);
} }
float CvNormalBayesClassifier::predict( const Mat& _samples, Mat* _results ) const float CvNormalBayesClassifier::predict( const Mat& _samples, Mat* _results, Mat* _results_prob ) const
{ {
CvMat samples = _samples, results, *presults = 0; CvMat samples = _samples, results, *presults = 0, results_prob, *presults_prob = 0;
if( _results ) if( _results )
{ {
...@@ -621,7 +646,16 @@ float CvNormalBayesClassifier::predict( const Mat& _samples, Mat* _results ) con ...@@ -621,7 +646,16 @@ float CvNormalBayesClassifier::predict( const Mat& _samples, Mat* _results ) con
presults = &(results = *_results); presults = &(results = *_results);
} }
return predict(&samples, presults); if( _results_prob )
{
if( !(_results_prob->data && _results_prob->type() == CV_64F &&
(_results_prob->cols == 1 || _results_prob->rows == 1) &&
_results_prob->cols + _results_prob->rows - 1 == _samples.rows) )
_results_prob->create(_samples.rows, 1, CV_64F);
presults_prob = &(results_prob = *_results_prob);
}
return predict(&samples, presults, presults_prob);
} }
/* End of file. */ /* End of file. */
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册