提交 d8425d88 编写于 作者: M mrquorr

finished for one sample

Finished with several samples support, need regression testing

Gave a more relevant name to function (getVotes)

Finished implicit implementation

Removed printf, finished regresion testing

Fixed conversion warning

Finished test for Rtrees

Fixed documentation

Initialized variable

Added doxygen documentation

Added parameter name
上级 ec47a0a6
......@@ -1164,6 +1164,17 @@ public:
*/
CV_WRAP virtual Mat getVarImportance() const = 0;
/** Returns the result of each individual tree in the forest.
In case the model is a regression problem, the method will return each of the trees'
results for each of the sample cases. If the model is a classifier, it will return
a Mat with samples + 1 rows, where the first row gives the class number and the
following rows return the votes each class had for each sample.
@param samples Array containg the samples for which votes will be calculated.
@param results Array where the result of the calculation will be written.
@param flags Flags for defining the type of RTrees.
*/
CV_WRAP void getVotes(InputArray samples, OutputArray results, int flags) const;
/** Creates the empty model.
Use StatModel::train to train the model, StatModel::train to create and train the model,
Algorithm::load to load the pre-trained model.
......
......@@ -349,6 +349,60 @@ public:
}
}
void getVotes( InputArray input, OutputArray output, int flags ) const
{
CV_Assert( !roots.empty() );
int nclasses = (int)classLabels.size(), ntrees = (int)roots.size();
Mat samples = input.getMat(), results;
int i, j, nsamples = samples.rows;
int predictType = flags & PREDICT_MASK;
if( predictType == PREDICT_AUTO )
{
predictType = !_isClassifier || (classLabels.size() == 2 && (flags & RAW_OUTPUT) != 0) ?
PREDICT_SUM : PREDICT_MAX_VOTE;
}
if( predictType == PREDICT_SUM )
{
output.create(nsamples, ntrees, CV_32F);
results = output.getMat();
for( i = 0; i < nsamples; i++ )
{
for( j = 0; j < ntrees; j++ )
{
float val = predictTrees( Range(j, j+1), samples.row(i), flags);
results.at<float> (i, j) = val;
}
}
} else
{
vector<int> votes;
output.create(nsamples+1, nclasses, CV_32S);
results = output.getMat();
for ( j = 0; j < nclasses; j++)
{
results.at<int> (0, j) = classLabels[j];
}
for( i = 0; i < nsamples; i++ )
{
votes.clear();
for( j = 0; j < ntrees; j++ )
{
int val = (int)predictTrees( Range(j, j+1), samples.row(i), flags);
votes.push_back(val);
}
for ( j = 0; j < nclasses; j++)
{
results.at<int> (i+1, j) = (int)std::count(votes.begin(), votes.end(), classLabels[j]);
}
}
}
}
RTreeParams rparams;
double oobError;
vector<float> varImportance;
......@@ -401,6 +455,11 @@ public:
impl.read(fn);
}
void getVotes_( InputArray samples, OutputArray results, int flags ) const
{
impl.getVotes(samples, results, flags);
}
Mat getVarImportance() const { return Mat_<float>(impl.varImportance, true); }
int getVarCount() const { return impl.getVarCount(); }
......@@ -427,6 +486,14 @@ Ptr<RTrees> RTrees::load(const String& filepath, const String& nodeName)
return Algorithm::load<RTrees>(filepath, nodeName);
}
void RTrees::getVotes(InputArray input, OutputArray output, int flags) const
{
const RTreesImpl* this_ = dynamic_cast<const RTreesImpl*>(this);
if(!this_)
CV_Error(Error::StsNotImplemented, "the class is not RTreesImpl");
return this_->getVotes_(input, output, flags);
}
}}
// End of file.
......@@ -172,4 +172,49 @@ TEST(ML_NBAYES, regression_5911)
EXPECT_EQ(sum(P1 == P3)[0], 255 * P3.total());
}
TEST(ML_RTrees, getVotes)
{
int n = 12;
int count, i;
int label_size = 3;
int predicted_class = 0;
int max_votes = -1;
int val;
// RTrees for classification
Ptr<ml::RTrees> rt = cv::ml::RTrees::create();
//data
Mat data(n, 4, CV_32F);
randu(data, 0, 10);
//labels
Mat labels = (Mat_<int>(n,1) << 0,0,0,0, 1,1,1,1, 2,2,2,2);
rt->train(data, ml::ROW_SAMPLE, labels);
//run function
Mat test(1, 4, CV_32F);
Mat result;
randu(test, 0, 10);
rt->getVotes(test, result, 0);
//count vote amount and find highest vote
count = 0;
const int* result_row = result.ptr<int>(1);
for( i = 0; i < label_size; i++ )
{
val = result_row[i];
//predicted_class = max_votes < val? i;
if( max_votes < val )
{
max_votes = val;
predicted_class = i;
}
count += val;
}
EXPECT_EQ(count, (int)rt->getRoots().size());
EXPECT_EQ(result.at<float>(0, predicted_class), rt->predict(test));
}
/* End of file. */
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册