未验证 提交 c31164bf 编写于 作者: D Danny 提交者: GitHub

Merge pull request #18126 from danielenricocahall:add-oob-error-sample-weighting

Account for sample weights in calculating OOB Error

* account for sample weights in oob error calculation

* redefine oob error functions

* fix ABI compatibility
上级 3835ab39
......@@ -1294,6 +1294,15 @@ public:
*/
CV_WRAP void getVotes(InputArray samples, OutputArray results, int flags) const;
/** Returns the OOB error value, computed at the training stage when calcOOBError is set to true.
* If this flag was set to false, 0 is returned. The OOB error is also scaled by sample weighting.
*/
#if CV_VERSION_MAJOR == 3
CV_WRAP double getOOBError() const;
#else
/*CV_WRAP*/ virtual double getOOBError() const = 0;
#endif
/** Creates the empty model.
Use StatModel::train to train the model, StatModel::train to create and train the model,
Algorithm::load to load the pre-trained model.
......
......@@ -216,13 +216,14 @@ public:
sample = Mat( nallvars, 1, CV_32F, psamples + sstep0*w->sidx[j], sstep1*sizeof(psamples[0]) );
double val = predictTrees(Range(treeidx, treeidx+1), sample, predictFlags);
double sample_weight = w->sample_weights[w->sidx[j]];
if( !_isClassifier )
{
oobres[j] += val;
oobcount[j]++;
double true_val = w->ord_responses[w->sidx[j]];
double a = oobres[j]/oobcount[j] - true_val;
oobError += a*a;
oobError += sample_weight * a*a;
val = (val - true_val)/max_response;
ncorrect_responses += std::exp( -val*val );
}
......@@ -237,7 +238,7 @@ public:
if( votes[best_class] < votes[k] )
best_class = k;
int diff = best_class != w->cat_responses[w->sidx[j]];
oobError += diff;
oobError += sample_weight * diff;
ncorrect_responses += diff == 0;
}
}
......@@ -421,6 +422,10 @@ public:
}
}
double getOOBError() const {
return oobError;
}
RTreeParams rparams;
double oobError;
vector<float> varImportance;
......@@ -505,6 +510,12 @@ public:
const vector<Node>& getNodes() const CV_OVERRIDE { return impl.getNodes(); }
const vector<Split>& getSplits() const CV_OVERRIDE { return impl.getSplits(); }
const vector<int>& getSubsets() const CV_OVERRIDE { return impl.getSubsets(); }
#if CV_VERSION_MAJOR == 3
double getOOBError_() const { return impl.getOOBError(); }
#else
double getOOBError() const CV_OVERRIDE { return impl.getOOBError(); }
#endif
DTreesImplForRTrees impl;
};
......@@ -532,6 +543,17 @@ void RTrees::getVotes(InputArray input, OutputArray output, int flags) const
return this_->getVotes_(input, output, flags);
}
#if CV_VERSION_MAJOR == 3
double RTrees::getOOBError() const
{
CV_TRACE_FUNCTION();
const RTreesImpl* this_ = dynamic_cast<const RTreesImpl*>(this);
if(!this_)
CV_Error(Error::StsNotImplemented, "the class is not RTreesImpl");
return this_->getOOBError_();
}
#endif
}}
// End of file.
......@@ -51,4 +51,50 @@ TEST(ML_RTrees, getVotes)
EXPECT_EQ(result.at<float>(0, predicted_class), rt->predict(test));
}
TEST(ML_RTrees, 11142_sample_weights_regression)
{
int n = 3;
// RTrees for regression
Ptr<ml::RTrees> rt = cv::ml::RTrees::create();
//simple regression problem of x -> 2x
Mat data = (Mat_<float>(n,1) << 1, 2, 3);
Mat values = (Mat_<float>(n,1) << 2, 4, 6);
Mat weights = (Mat_<float>(n, 1) << 10, 10, 10);
Ptr<TrainData> trainData = TrainData::create(data, ml::ROW_SAMPLE, values);
rt->train(trainData);
double error_without_weights = round(rt->getOOBError());
rt->clear();
Ptr<TrainData> trainDataWithWeights = TrainData::create(data, ml::ROW_SAMPLE, values, Mat(), Mat(), weights );
rt->train(trainDataWithWeights);
double error_with_weights = round(rt->getOOBError());
// error with weights should be larger than error without weights
EXPECT_GE(error_with_weights, error_without_weights);
}
TEST(ML_RTrees, 11142_sample_weights_classification)
{
int n = 12;
// RTrees for classification
Ptr<ml::RTrees> rt = cv::ml::RTrees::create();
Mat data(n, 4, CV_32F);
randu(data, 0, 10);
Mat labels = (Mat_<int>(n,1) << 0,0,0,0, 1,1,1,1, 2,2,2,2);
Mat weights = (Mat_<float>(n, 1) << 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10);
rt->train(data, ml::ROW_SAMPLE, labels);
rt->clear();
double error_without_weights = round(rt->getOOBError());
Ptr<TrainData> trainDataWithWeights = TrainData::create(data, ml::ROW_SAMPLE, labels, Mat(), Mat(), weights );
rt->train(data, ml::ROW_SAMPLE, labels);
double error_with_weights = round(rt->getOOBError());
std::cout << error_without_weights << std::endl;
std::cout << error_with_weights << std::endl;
// error with weights should be larger than error without weights
EXPECT_GE(error_with_weights, error_without_weights);
}
}} // namespace
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册