提交 97092e31 编写于 作者: C Cosmin Boaca

Added result_probabilities parameter to CvNormalBayesClassifier::predict method. Issue #3401

上级 fff5a6c0
......@@ -52,12 +52,12 @@ CvNormalBayesClassifier::predict
--------------------------------
Predicts the response for sample(s).
.. ocv:function:: float CvNormalBayesClassifier::predict( const Mat& samples, Mat* results=0 ) const
.. ocv:function:: float CvNormalBayesClassifier::predict( const Mat& samples, Mat* results=0, Mat* results_prob=0 ) const
.. ocv:function:: float CvNormalBayesClassifier::predict( const CvMat* samples, CvMat* results=0 ) const
.. ocv:function:: float CvNormalBayesClassifier::predict( const CvMat* samples, CvMat* results=0, CvMat* results_prob=0 ) const
.. ocv:pyfunction:: cv2.NormalBayesClassifier.predict(samples) -> retval, results
The method estimates the most probable classes for input vectors. Input vectors (one or more) are stored as rows of the matrix ``samples``. In case of multiple input vectors, there should be one output vector ``results``. The predicted class for a single input vector is returned by the method.
The method estimates the most probable classes for input vectors. Input vectors (one or more) are stored as rows of the matrix ``samples``. In case of multiple input vectors, there should be one output vector ``results``. The predicted class for a single input vector is returned by the method. The vector ``results_prob`` contains the output probabilities coresponding to each element of ``result``.
The function is parallelized with the TBB library.
......@@ -201,7 +201,7 @@ public:
virtual bool train( const CvMat* trainData, const CvMat* responses,
const CvMat* varIdx = 0, const CvMat* sampleIdx=0, bool update=false );
virtual float predict( const CvMat* samples, CV_OUT CvMat* results=0 ) const;
virtual float predict( const CvMat* samples, CV_OUT CvMat* results=0, CV_OUT CvMat* results_prob=0 ) const;
CV_WRAP virtual void clear();
CV_WRAP CvNormalBayesClassifier( const cv::Mat& trainData, const cv::Mat& responses,
......@@ -209,7 +209,7 @@ public:
CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses,
const cv::Mat& varIdx = cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(),
bool update=false );
CV_WRAP virtual float predict( const cv::Mat& samples, CV_OUT cv::Mat* results=0 ) const;
CV_WRAP virtual float predict( const cv::Mat& samples, CV_OUT cv::Mat* results=0, CV_OUT cv::Mat* results_prob=0 ) const;
virtual void write( CvFileStorage* storage, const char* name ) const;
virtual void read( CvFileStorage* storage, CvFileNode* node );
......
......@@ -282,7 +282,7 @@ bool CvNormalBayesClassifier::train( const CvMat* _train_data, const CvMat* _res
struct predict_body : cv::ParallelLoopBody {
predict_body(CvMat* _c, CvMat** _cov_rotate_mats, CvMat** _inv_eigen_values, CvMat** _avg,
const CvMat* _samples, const int* _vidx, CvMat* _cls_labels,
CvMat* _results, float* _value, int _var_count1
CvMat* _results, float* _value, int _var_count1, CvMat* _results_prob
)
{
c = _c;
......@@ -295,6 +295,7 @@ struct predict_body : cv::ParallelLoopBody {
results = _results;
value = _value;
var_count1 = _var_count1;
results_prob = _results_prob;
}
CvMat* c;
......@@ -305,6 +306,7 @@ struct predict_body : cv::ParallelLoopBody {
const int* vidx;
CvMat* cls_labels;
CvMat* results_prob;
CvMat* results;
float* value;
int var_count1;
......@@ -313,15 +315,21 @@ struct predict_body : cv::ParallelLoopBody {
{
int cls = -1;
int rtype = 0, rstep = 0;
int rtype = 0, rstep = 0, rptype = 0, rpstep = 0;
int nclasses = cls_labels->cols;
int _var_count = avg[0]->cols;
double probability = 0;
if (results)
{
rtype = CV_MAT_TYPE(results->type);
rstep = CV_IS_MAT_CONT(results->type) ? 1 : results->step/CV_ELEM_SIZE(rtype);
}
if (results_prob)
{
rptype = CV_MAT_TYPE(results_prob->type);
rpstep = CV_IS_MAT_CONT(results_prob->type) ? 1 : results_prob->step/CV_ELEM_SIZE(rptype);
}
// allocate memory and initializing headers for calculating
cv::AutoBuffer<double> buffer(nclasses + var_count1);
CvMat diff = cvMat( 1, var_count1, CV_64FC1, &buffer[0] );
......@@ -333,7 +341,6 @@ struct predict_body : cv::ParallelLoopBody {
for(int i = 0; i < nclasses; i++ )
{
double cur = c->data.db[i];
CvMat* u = cov_rotate_mats[i];
CvMat* w = inv_eigen_values[i];
......@@ -358,6 +365,7 @@ struct predict_body : cv::ParallelLoopBody {
opt = cur;
}
/* probability = exp( -0.5 * cur ) */
probability = exp( -0.5 * cur );
}
ival = cls_labels->data.i[cls];
......@@ -368,6 +376,13 @@ struct predict_body : cv::ParallelLoopBody {
else
results->data.fl[k*rstep] = (float)ival;
}
if ( results_prob )
{
if ( rptype == CV_32FC1 )
results_prob->data.fl[k*rpstep] = (float)probability;
else
results_prob->data.db[k*rpstep] = probability;
}
if( k == 0 )
*value = (float)ival;
}
......@@ -375,7 +390,7 @@ struct predict_body : cv::ParallelLoopBody {
};
float CvNormalBayesClassifier::predict( const CvMat* samples, CvMat* results ) const
float CvNormalBayesClassifier::predict( const CvMat* samples, CvMat* results, CvMat* results_prob ) const
{
float value = 0;
......@@ -397,11 +412,21 @@ float CvNormalBayesClassifier::predict( const CvMat* samples, CvMat* results ) c
"with the number of elements = number of rows in the input matrix" );
}
if( results_prob )
{
if( !CV_IS_MAT(results_prob) || (CV_MAT_TYPE(results_prob->type) != CV_32FC1 &&
CV_MAT_TYPE(results_prob->type) != CV_64FC1) ||
(results_prob->cols != 1 && results_prob->rows != 1) ||
results_prob->cols + results_prob->rows - 1 != samples->rows )
CV_Error( CV_StsBadArg, "The output array must be double or float vector "
"with the number of elements = number of rows in the input matrix" );
}
const int* vidx = var_idx ? var_idx->data.i : 0;
cv::parallel_for_(cv::Range(0, samples->rows),
predict_body(c, cov_rotate_mats, inv_eigen_values, avg, samples,
vidx, cls_labels, results, &value, var_count));
vidx, cls_labels, results, &value, var_count, results_prob));
return value;
}
......@@ -608,9 +633,9 @@ bool CvNormalBayesClassifier::train( const Mat& _train_data, const Mat& _respons
sidx.data.ptr ? &sidx : 0, update);
}
float CvNormalBayesClassifier::predict( const Mat& _samples, Mat* _results ) const
float CvNormalBayesClassifier::predict( const Mat& _samples, Mat* _results, Mat* _results_prob ) const
{
CvMat samples = _samples, results, *presults = 0;
CvMat samples = _samples, results, *presults = 0, results_prob, *presults_prob = 0;
if( _results )
{
......@@ -621,7 +646,16 @@ float CvNormalBayesClassifier::predict( const Mat& _samples, Mat* _results ) con
presults = &(results = *_results);
}
return predict(&samples, presults);
if( _results_prob )
{
if( !(_results_prob->data && _results_prob->type() == CV_64F &&
(_results_prob->cols == 1 || _results_prob->rows == 1) &&
_results_prob->cols + _results_prob->rows - 1 == _samples.rows) )
_results_prob->create(_samples.rows, 1, CV_64F);
presults_prob = &(results_prob = *_results_prob);
}
return predict(&samples, presults, presults_prob);
}
/* End of file. */
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册