/*M///////////////////////////////////////////////////////////////////////////////////////
//
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
//  By downloading, copying, installing or using the software you agree to this license.
//  If you do not agree to this license, do not download, install,
//  copy or use the software.
//
//
//                           License Agreement
//                For Open Source Computer Vision Library
//
// Copyright (C) 2000, Intel Corporation, all rights reserved.
// Copyright (C) 2014, Itseez Inc, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//   * Redistribution's of source code must retain the above copyright notice,
//     this list of conditions and the following disclaimer.
//
//   * Redistribution's in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
//   * The name of the copyright holders may not be used to endorse or promote products
//     derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/

#include "precomp.hpp"

#include <stdarg.h>
#include <ctype.h>

/****************************************************************************************\
                                COPYRIGHT NOTICE
                                ----------------

  The code has been derived from libsvm library (version 2.6)
  (http://www.csie.ntu.edu.tw/~cjlin/libsvm).

  Here is the original copyright:
------------------------------------------------------------------------------------------
    Copyright (c) 2000-2003 Chih-Chung Chang and Chih-Jen Lin
    All rights reserved.

    Redistribution and use in source and binary forms, with or without
    modification, are permitted provided that the following conditions
    are met:

    1. Redistributions of source code must retain the above copyright
    notice, this list of conditions and the following disclaimer.

    2. Redistributions in binary form must reproduce the above copyright
    notice, this list of conditions and the following disclaimer in the
    documentation and/or other materials provided with the distribution.

    3. Neither name of copyright holders nor the names of its contributors
    may be used to endorse or promote products derived from this software
    without specific prior written permission.


    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE REGENTS OR
    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
\****************************************************************************************/

namespace cv { namespace ml {

typedef float Qfloat;
const int QFLOAT_TYPE = DataDepth<Qfloat>::value;

// Param Grid
static void checkParamGrid(const ParamGrid& pg)
{
    if( pg.minVal > pg.maxVal )
        CV_Error( CV_StsBadArg, "Lower bound of the grid must be less than the upper one" );
    if( pg.minVal < DBL_EPSILON )
        CV_Error( CV_StsBadArg, "Lower bound of the grid must be positive" );
    if( pg.logStep < 1. + FLT_EPSILON )
        CV_Error( CV_StsBadArg, "Grid step must be greater than 1" );
}
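
// A grid describes a geometric progression of candidate values:
// minVal, minVal*logStep, minVal*logStep^2, ..., up to maxVal -- hence
// the requirements above that minVal be positive and logStep exceed 1.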

// SVM training parameters
struct SvmParams
{
    int         svmType;
    int         kernelType;
    double      gamma;
    double      coef0;
    double      degree;
    double      C;
    double      nu;
    double      p;
    Mat         classWeights;
    TermCriteria termCrit;
    SvmParams()
    {
        svmType = SVM::C_SVC;
        kernelType = SVM::RBF;
        degree = 0;
        gamma = 1;
        coef0 = 0;
        C = 1;
        nu = 0;
        p = 0;
        termCrit = TermCriteria( CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 1000, FLT_EPSILON );
    }
    SvmParams( int _svmType, int _kernelType,
            double _degree, double _gamma, double _coef0,
            double _Con, double _nu, double _p,
            const Mat& _classWeights, TermCriteria _termCrit )
    {
        svmType = _svmType;
        kernelType = _kernelType;
        degree = _degree;
        gamma = _gamma;
        coef0 = _coef0;
        C = _Con;
        nu = _nu;
        p = _p;
        classWeights = _classWeights;
        termCrit = _termCrit;
    }

};

/////////////////////////////////////// SVM kernel ///////////////////////////////////////
class SVMKernelImpl CV_FINAL : public SVM::Kernel
{
public:
    SVMKernelImpl( const SvmParams& _params = SvmParams() )
    {
        params = _params;
    }

    int getType() const CV_OVERRIDE
    {
        return params.kernelType;
    }
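    // Shared helper for the linear, polynomial and sigmoid kernels:
    // computes results[j] = alpha*<vecs[j], another> + beta for each row,
    // with the inner loop unrolled by four.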
    void calc_non_rbf_base( int vcount, int var_count, const float* vecs,
                            const float* another, Qfloat* results,
                            double alpha, double beta )
    {
        int j, k;
        for( j = 0; j < vcount; j++ )
        {
            const float* sample = &vecs[j*var_count];
            double s = 0;
            for( k = 0; k <= var_count - 4; k += 4 )
                s += sample[k]*another[k] + sample[k+1]*another[k+1] +
                sample[k+2]*another[k+2] + sample[k+3]*another[k+3];
            for( ; k < var_count; k++ )
                s += sample[k]*another[k];
            results[j] = (Qfloat)(s*alpha + beta);
        }
    }

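    /// Linear kernel: K(x,y) = <x,y>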
    void calc_linear( int vcount, int var_count, const float* vecs,
                      const float* another, Qfloat* results )
    {
        calc_non_rbf_base( vcount, var_count, vecs, another, results, 1, 0 );
    }

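    /// Polynomial kernel: K(x,y) = (gamma*<x,y> + coef0)^degree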
    void calc_poly( int vcount, int var_count, const float* vecs,
                    const float* another, Qfloat* results )
    {
        Mat R( 1, vcount, QFLOAT_TYPE, results );
        calc_non_rbf_base( vcount, var_count, vecs, another, results, params.gamma, params.coef0 );
        if( vcount > 0 )
            pow( R, params.degree, R );
    }

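    /// Sigmoid (tanh-based) kernel, computed through calc_non_rbf_base with
    /// alpha = -2*gamma, beta = -2*coef0 and a transform that uses
    /// exp(-|t|) to avoid overflow for large arguments.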
    void calc_sigmoid( int vcount, int var_count, const float* vecs,
                       const float* another, Qfloat* results )
    {
        int j;
        calc_non_rbf_base( vcount, var_count, vecs, another, results,
                          -2*params.gamma, -2*params.coef0 );
        // TODO: speedup this
        for( j = 0; j < vcount; j++ )
        {
            Qfloat t = results[j];
            Qfloat e = std::exp(-std::abs(t));
            if( t > 0 )
                results[j] = (Qfloat)((1. - e)/(1. + e));
            else
                results[j] = (Qfloat)((e - 1.)/(e + 1.));
        }
    }


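    /// Gaussian (RBF) kernel: K(x,y) = exp(-gamma*||x - y||^2)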
    void calc_rbf( int vcount, int var_count, const float* vecs,
                   const float* another, Qfloat* results )
    {
        double gamma = -params.gamma;
        int j, k;
        for( j = 0; j < vcount; j++ )
        {
            const float* sample = &vecs[j*var_count];
            double s = 0;
            for( k = 0; k <= var_count - 4; k += 4 )
            {
                double t0 = sample[k] - another[k];
                double t1 = sample[k+1] - another[k+1];

                s += t0*t0 + t1*t1;

                t0 = sample[k+2] - another[k+2];
                t1 = sample[k+3] - another[k+3];

                s += t0*t0 + t1*t1;
            }

            for( ; k < var_count; k++ )
            {
                double t0 = sample[k] - another[k];
                s += t0*t0;
            }
            results[j] = (Qfloat)(s*gamma);
        }
        if( vcount > 0 )
        {
            Mat R( 1, vcount, QFLOAT_TYPE, results );
            exp( R, R );
        }
    }

    /// Histogram intersection kernel
    void calc_intersec( int vcount, int var_count, const float* vecs,
                        const float* another, Qfloat* results )
    {
        int j, k;
        for( j = 0; j < vcount; j++ )
        {
            const float* sample = &vecs[j*var_count];
            double s = 0;
            for( k = 0; k <= var_count - 4; k += 4 )
                s += std::min(sample[k],another[k]) + std::min(sample[k+1],another[k+1]) +
                std::min(sample[k+2],another[k+2]) + std::min(sample[k+3],another[k+3]);
            for( ; k < var_count; k++ )
                s += std::min(sample[k],another[k]);
            results[j] = (Qfloat)(s);
        }
    }

    /// Exponential chi2 kernel
    void calc_chi2( int vcount, int var_count, const float* vecs,
                    const float* another, Qfloat* results )
    {
        Mat R( 1, vcount, QFLOAT_TYPE, results );
        double gamma = -params.gamma;
        int j, k;
        for( j = 0; j < vcount; j++ )
        {
            const float* sample = &vecs[j*var_count];
            double chi2 = 0;
            for( k = 0; k < var_count; k++ )
            {
                double d = sample[k]-another[k];
                double divisor = sample[k]+another[k];
                /// if divisor == 0, the chi2 term would be zero,
                /// but computing it directly would divide by zero
                if (divisor != 0)
                {
                    chi2 += d*d/divisor;
                }
            }
            results[j] = (Qfloat) (gamma*chi2);
        }
        if( vcount > 0 )
            exp( R, R );
    }

    void calc( int vcount, int var_count, const float* vecs,
               const float* another, Qfloat* results ) CV_OVERRIDE
    {
        switch( params.kernelType )
        {
        case SVM::LINEAR:
            calc_linear(vcount, var_count, vecs, another, results);
            break;
        case SVM::RBF:
            calc_rbf(vcount, var_count, vecs, another, results);
            break;
        case SVM::POLY:
            calc_poly(vcount, var_count, vecs, another, results);
            break;
        case SVM::SIGMOID:
            calc_sigmoid(vcount, var_count, vecs, another, results);
            break;
        case SVM::CHI2:
            calc_chi2(vcount, var_count, vecs, another, results);
            break;
        case SVM::INTER:
            calc_intersec(vcount, var_count, vecs, another, results);
            break;
        default:
            CV_Error(CV_StsBadArg, "Unknown kernel type");
        }
        const Qfloat max_val = (Qfloat)(FLT_MAX*1e-3);
        for( int j = 0; j < vcount; j++ )
        {
            if( results[j] > max_val )
                results[j] = max_val;
        }
    }

    SvmParams params;
};


/////////////////////////////////////////////////////////////////////////

static void sortSamplesByClasses( const Mat& _samples, const Mat& _responses,
                           vector<int>& sidx_all, vector<int>& class_ranges )
{
    int i, nsamples = _samples.rows;
    CV_Assert( _responses.isContinuous() && _responses.checkVector(1, CV_32S) == nsamples );

    setRangeVector(sidx_all, nsamples);

    const int* rptr = _responses.ptr<int>();
    std::sort(sidx_all.begin(), sidx_all.end(), cmp_lt_idx<int>(rptr));
    class_ranges.clear();
    class_ranges.push_back(0);
    for( i = 0; i < nsamples; i++ )
    {
        if( i == nsamples-1 || rptr[sidx_all[i]] != rptr[sidx_all[i+1]] )
            class_ranges.push_back(i+1);
    }
}
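
// Example: for responses {1, 0, 1, 0, 2}, sidx_all groups the indices by
// class value (e.g. {1, 3, 0, 2, 4}; the order within a class is unspecified)
// and class_ranges becomes {0, 2, 4, 5}; the samples of class k are
// sidx_all[class_ranges[k]] .. sidx_all[class_ranges[k+1]-1].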

//////////////////////// SVM implementation //////////////////////////////

Ptr<ParamGrid> SVM::getDefaultGridPtr( int param_id)
{
  ParamGrid grid = getDefaultGrid(param_id); // this is not a nice solution..
  return makePtr<ParamGrid>(grid.minVal, grid.maxVal, grid.logStep);
}

ParamGrid SVM::getDefaultGrid( int param_id )
{
    ParamGrid grid;
    if( param_id == SVM::C )
    {
        grid.minVal = 0.1;
        grid.maxVal = 500;
        grid.logStep = 5; // total iterations = 5
    }
    else if( param_id == SVM::GAMMA )
    {
        grid.minVal = 1e-5;
        grid.maxVal = 0.6;
        grid.logStep = 15; // total iterations = 4
    }
    else if( param_id == SVM::P )
    {
        grid.minVal = 0.01;
        grid.maxVal = 100;
        grid.logStep = 7; // total iterations = 4
    }
    else if( param_id == SVM::NU )
    {
        grid.minVal = 0.01;
        grid.maxVal = 0.2;
        grid.logStep = 3; // total iterations = 3
    }
    else if( param_id == SVM::COEF )
    {
        grid.minVal = 0.1;
        grid.maxVal = 300;
        grid.logStep = 14; // total iterations = 3
    }
    else if( param_id == SVM::DEGREE )
    {
        grid.minVal = 0.01;
        grid.maxVal = 4;
        grid.logStep = 7; // total iterations = 3
    }
    else
        cvError( CV_StsBadArg, "SVM::getDefaultGrid", "Invalid type of parameter "
                "(use one of SVM::C, SVM::GAMMA et al.)", __FILE__, __LINE__ );
    return grid;
}
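
// These grids are consumed by SVM::trainAuto, which walks each one
// geometrically: the candidates are minVal, minVal*logStep, minVal*logStep^2,
// ... while they do not exceed maxVal; the "total iterations" notes above
// give the approximate number of candidates per parameter.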


class SVMImpl CV_FINAL : public SVM
{
public:
    struct DecisionFunc
    {
        DecisionFunc(double _rho, int _ofs) : rho(_rho), ofs(_ofs) {}
        DecisionFunc() : rho(0.), ofs(0) {}
        double rho;
        int ofs;
    };

    // Generalized SMO+SVMlight algorithm
    // Solves:
    //
    //  min [0.5(\alpha^T Q \alpha) + b^T \alpha]
    //
    //      y^T \alpha = \delta
    //      y_i = +1 or -1
    //      0 <= alpha_i <= Cp for y_i = 1
    //      0 <= alpha_i <= Cn for y_i = -1
    //
    // Given:
    //
    //  Q, b, y, Cp, Cn, and an initial feasible point \alpha
    //  l is the size of vectors and matrices
    //  eps is the stopping criterion
    //
    // solution will be put in \alpha, objective value will be put in obj
    //
    class Solver
    {
    public:
        enum { MIN_CACHE_SIZE = (40 << 20) /* 40Mb */, MAX_CACHE_SIZE = (500 << 20) /* 500Mb */ };

        typedef bool (Solver::*SelectWorkingSet)( int& i, int& j );
        typedef Qfloat* (Solver::*GetRow)( int i, Qfloat* row, Qfloat* dst, bool existed );
        typedef void (Solver::*CalcRho)( double& rho, double& r );

        struct KernelRow
        {
            KernelRow() { idx = -1; prev = next = 0; }
            KernelRow(int _idx, int _prev, int _next) : idx(_idx), prev(_prev), next(_next) {}
            int idx;
            int prev;
            int next;
        };

        struct SolutionInfo
        {
            SolutionInfo() { obj = rho = upper_bound_p = upper_bound_n = r = 0; }
            double obj;
            double rho;
            double upper_bound_p;
            double upper_bound_n;
            double r;   // for Solver_NU
        };

        void clear()
        {
            alpha_vec = 0;
            select_working_set_func = 0;
            calc_rho_func = 0;
            get_row_func = 0;
            lru_cache.clear();
        }

        Solver( const Mat& _samples, const vector<schar>& _y,
                vector<double>& _alpha, const vector<double>& _b,
                double _Cp, double _Cn,
                const Ptr<SVM::Kernel>& _kernel, GetRow _get_row,
                SelectWorkingSet _select_working_set, CalcRho _calc_rho,
                TermCriteria _termCrit )
        {
            clear();

            samples = _samples;
            sample_count = samples.rows;
            var_count = samples.cols;

            y_vec = _y;
            alpha_vec = &_alpha;
            alpha_count = (int)alpha_vec->size();
            b_vec = _b;
            kernel = _kernel;

            C[0] = _Cn;
            C[1] = _Cp;
            eps = _termCrit.epsilon;
            max_iter = _termCrit.maxCount;

            G_vec.resize(alpha_count);
            alpha_status_vec.resize(alpha_count);
            buf[0].resize(sample_count*2);
            buf[1].resize(sample_count*2);

            select_working_set_func = _select_working_set;
            CV_Assert(select_working_set_func != 0);

            calc_rho_func = _calc_rho;
            CV_Assert(calc_rho_func != 0);

            get_row_func = _get_row;
            CV_Assert(get_row_func != 0);

            // assume that for large training sets ~25% of Q matrix is used
            int64 csize = (int64)sample_count*sample_count/4;
            csize = std::max(csize, (int64)(MIN_CACHE_SIZE/sizeof(Qfloat)) );
            csize = std::min(csize, (int64)(MAX_CACHE_SIZE/sizeof(Qfloat)) );
            max_cache_size = (int)((csize + sample_count-1)/sample_count);
            max_cache_size = std::min(std::max(max_cache_size, 1), sample_count);
            cache_size = 0;

            lru_cache.clear();
            lru_cache.resize(sample_count+1, KernelRow(-1, 0, 0));
            lru_first = lru_last = 0;
            lru_cache_data.create(max_cache_size, sample_count, QFLOAT_TYPE);
        }
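
        // Sizing example: with sample_count = 10000 the initial estimate is
        // 10000*10000/4 = 25e6 Qfloats (already within the 40MB..500MB clamp),
        // so max_cache_size = 25e6/10000 = 2500 cached kernel rows.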
        Qfloat* get_row_base( int i, bool* _existed )
        {
            int i1 = i < sample_count ? i : i - sample_count;
            KernelRow& kr = lru_cache[i1+1];
            if( _existed )
                *_existed = kr.idx >= 0;
            if( kr.idx < 0 )
            {
                if( cache_size < max_cache_size )
                {
                    kr.idx = cache_size;
                    cache_size++;
                    if (!lru_last)
                        lru_last = i1+1;
                }
                else
                {
                    KernelRow& last = lru_cache[lru_last];
                    kr.idx = last.idx;
                    last.idx = -1;
                    lru_cache[last.prev].next = 0;
                    lru_last = last.prev;
                    last.prev = 0;
                    last.next = 0;
                }
                kernel->calc( sample_count, var_count, samples.ptr<float>(),
                              samples.ptr<float>(i1), lru_cache_data.ptr<Qfloat>(kr.idx) );
            }
            else
            {
                if( kr.next )
                    lru_cache[kr.next].prev = kr.prev;
                else
                    lru_last = kr.prev;
                if( kr.prev )
                    lru_cache[kr.prev].next = kr.next;
                else
                    lru_first = kr.next;
            }
            if (lru_first)
                lru_cache[lru_first].prev = i1+1;
            kr.next = lru_first;
            kr.prev = 0;
            lru_first = i1+1;

            return lru_cache_data.ptr<Qfloat>(kr.idx);
        }

        Qfloat* get_row_svc( int i, Qfloat* row, Qfloat*, bool existed )
        {
            if( !existed )
            {
                const schar* _y = &y_vec[0];
                int j, len = sample_count;
                if( _y[i] > 0 )
                {
                    for( j = 0; j < len; j++ )
                        row[j] = _y[j]*row[j];
                }
                else
                {
                    for( j = 0; j < len; j++ )
                        row[j] = -_y[j]*row[j];
                }
            }
            return row;
        }

        Qfloat* get_row_one_class( int, Qfloat* row, Qfloat*, bool )
        {
            return row;
        }
        Qfloat* get_row_svr( int i, Qfloat* row, Qfloat* dst, bool )
        {
            int j, len = sample_count;
            Qfloat* dst_pos = dst;
            Qfloat* dst_neg = dst + len;
            if( i >= len )
                std::swap(dst_pos, dst_neg);
            for( j = 0; j < len; j++ )
            {
                Qfloat t = row[j];
                dst_pos[j] = t;
                dst_neg[j] = -t;
            }
            return dst;
        }
        Qfloat* get_row( int i, float* dst )
        {
            bool existed = false;
            float* row = get_row_base( i, &existed );
            return (this->*get_row_func)( i, row, dst, existed );
        }
        #undef is_upper_bound
        #define is_upper_bound(i) (alpha_status[i] > 0)
        #undef is_lower_bound
        #define is_lower_bound(i) (alpha_status[i] < 0)
        #undef is_free
        #define is_free(i) (alpha_status[i] == 0)
        #undef get_C
        #define get_C(i) (C[y[i]>0])
        #undef update_alpha_status
        #define update_alpha_status(i) \
            alpha_status[i] = (schar)(alpha[i] >= get_C(i) ? 1 : alpha[i] <= 0 ? -1 : 0)
        #undef reconstruct_gradient
        #define reconstruct_gradient() /* empty for now */
        bool solve_generic( SolutionInfo& si )
        {
            const schar* y = &y_vec[0];
            double* alpha = &alpha_vec->at(0);
            schar* alpha_status = &alpha_status_vec[0];
            double* G = &G_vec[0];
            double* b = &b_vec[0];
            int iter = 0;
            int i, j, k;
            // 1. initialize gradient and alpha status
            for( i = 0; i < alpha_count; i++ )
            {
                update_alpha_status(i);
                G[i] = b[i];
                if( fabs(G[i]) > 1e200 )
                    return false;
            }

            for( i = 0; i < alpha_count; i++ )
            {
                if( !is_lower_bound(i) )
                {
                    const Qfloat *Q_i = get_row( i, &buf[0][0] );
                    double alpha_i = alpha[i];

                    for( j = 0; j < alpha_count; j++ )
                        G[j] += alpha_i*Q_i[j];
                }
            }

            // 2. optimization loop
            for(;;)
            {
                const Qfloat *Q_i, *Q_j;
                double C_i, C_j;
                double old_alpha_i, old_alpha_j, alpha_i, alpha_j;
                double delta_alpha_i, delta_alpha_j;
        #ifdef _DEBUG
                for( i = 0; i < alpha_count; i++ )
                {
                    if( fabs(G[i]) > 1e+300 )
                        return false;
                    if( fabs(alpha[i]) > 1e16 )
                        return false;
                }
        #endif
                if( (this->*select_working_set_func)( i, j ) != 0 || iter++ >= max_iter )
                    break;
                Q_i = get_row( i, &buf[0][0] );
                Q_j = get_row( j, &buf[1][0] );
                C_i = get_C(i);
                C_j = get_C(j);
                alpha_i = old_alpha_i = alpha[i];
                alpha_j = old_alpha_j = alpha[j];
                if( y[i] != y[j] )
                {
                    double denom = Q_i[i]+Q_j[j]+2*Q_i[j];
                    double delta = (-G[i]-G[j])/MAX(fabs(denom),FLT_EPSILON);
                    double diff = alpha_i - alpha_j;
                    alpha_i += delta;
                    alpha_j += delta;
                    if( diff > 0 && alpha_j < 0 )
                    {
                        alpha_j = 0;
                        alpha_i = diff;
                    }
                    else if( diff <= 0 && alpha_i < 0 )
                    {
                        alpha_i = 0;
                        alpha_j = -diff;
                    }
                    if( diff > C_i - C_j && alpha_i > C_i )
                    {
                        alpha_i = C_i;
                        alpha_j = C_i - diff;
                    }
                    else if( diff <= C_i - C_j && alpha_j > C_j )
                    {
                        alpha_j = C_j;
                        alpha_i = C_j + diff;
                    }
                }
                else
                {
                    double denom = Q_i[i]+Q_j[j]-2*Q_i[j];
                    double delta = (G[i]-G[j])/MAX(fabs(denom),FLT_EPSILON);
                    double sum = alpha_i + alpha_j;
                    alpha_i -= delta;
                    alpha_j += delta;
                    if( sum > C_i && alpha_i > C_i )
                    {
                        alpha_i = C_i;
                        alpha_j = sum - C_i;
                    }
                    else if( sum <= C_i && alpha_j < 0)
                    {
                        alpha_j = 0;
                        alpha_i = sum;
                    }
                    if( sum > C_j && alpha_j > C_j )
                    {
                        alpha_j = C_j;
                        alpha_i = sum - C_j;
                    }
                    else if( sum <= C_j && alpha_i < 0 )
                    {
                        alpha_i = 0;
                        alpha_j = sum;
                    }
                }
                // update alpha
                alpha[i] = alpha_i;
                alpha[j] = alpha_j;
                update_alpha_status(i);
                update_alpha_status(j);
                // update G
                delta_alpha_i = alpha_i - old_alpha_i;
                delta_alpha_j = alpha_j - old_alpha_j;
                for( k = 0; k < alpha_count; k++ )
                    G[k] += Q_i[k]*delta_alpha_i + Q_j[k]*delta_alpha_j;
            }

            // calculate rho
            (this->*calc_rho_func)( si.rho, si.r );
            // calculate objective value
            for( i = 0, si.obj = 0; i < alpha_count; i++ )
                si.obj += alpha[i] * (G[i] + b[i]);
            si.obj *= 0.5;
            si.upper_bound_p = C[1];
            si.upper_bound_n = C[0];
            return true;
        }

        // return true if already optimal, false otherwise
        bool select_working_set( int& out_i, int& out_j )
        {
            // return i,j which maximize -grad(f)^T d , under constraint
            // if alpha_i == C, d != +1
            // if alpha_i == 0, d != -1
            double Gmax1 = -DBL_MAX;        // max { -grad(f)_i * d | y_i*d = +1 }
            int Gmax1_idx = -1;
            double Gmax2 = -DBL_MAX;        // max { -grad(f)_i * d | y_i*d = -1 }
            int Gmax2_idx = -1;
            const schar* y = &y_vec[0];
            const schar* alpha_status = &alpha_status_vec[0];
            const double* G = &G_vec[0];

            for( int i = 0; i < alpha_count; i++ )
            {
                double t;

                if( y[i] > 0 )    // y = +1
                {
                    if( !is_upper_bound(i) && (t = -G[i]) > Gmax1 )  // d = +1
                    {
                        Gmax1 = t;
                        Gmax1_idx = i;
                    }
                    if( !is_lower_bound(i) && (t = G[i]) > Gmax2 )  // d = -1
                    {
                        Gmax2 = t;
                        Gmax2_idx = i;
                    }
                }
                else        // y = -1
                {
                    if( !is_upper_bound(i) && (t = -G[i]) > Gmax2 )  // d = +1
                    {
                        Gmax2 = t;
                        Gmax2_idx = i;
                    }
                    if( !is_lower_bound(i) && (t = G[i]) > Gmax1 )  // d = -1
                    {
                        Gmax1 = t;
                        Gmax1_idx = i;
                    }
                }
            }

            out_i = Gmax1_idx;
            out_j = Gmax2_idx;

            return Gmax1 + Gmax2 < eps;
        }
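
        // This is the classical maximal-violating-pair selection of SMO:
        // Gmax1 + Gmax2 < eps means no pair violates the KKT conditions
        // by more than eps, i.e. the current solution is already optimal.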

        void calc_rho( double& rho, double& r )
        {
            int nr_free = 0;
            double ub = DBL_MAX, lb = -DBL_MAX, sum_free = 0;
            const schar* y = &y_vec[0];
            const schar* alpha_status = &alpha_status_vec[0];
            const double* G = &G_vec[0];

            for( int i = 0; i < alpha_count; i++ )
            {
                double yG = y[i]*G[i];

                if( is_lower_bound(i) )
                {
                    if( y[i] > 0 )
                        ub = MIN(ub,yG);
                    else
                        lb = MAX(lb,yG);
                }
                else if( is_upper_bound(i) )
                {
                    if( y[i] < 0)
                        ub = MIN(ub,yG);
                    else
                        lb = MAX(lb,yG);
                }
                else
                {
                    ++nr_free;
                    sum_free += yG;
                }
            }

            rho = nr_free > 0 ? sum_free/nr_free : (ub + lb)*0.5;
            r = 0;
        }

        bool select_working_set_nu_svm( int& out_i, int& out_j )
        {
            // return i,j which maximize -grad(f)^T d , under constraint
            // if alpha_i == C, d != +1
            // if alpha_i == 0, d != -1
            double Gmax1 = -DBL_MAX;    // max { -grad(f)_i * d | y_i = +1, d = +1 }
            int Gmax1_idx = -1;
            double Gmax2 = -DBL_MAX;    // max { -grad(f)_i * d | y_i = +1, d = -1 }
            int Gmax2_idx = -1;
            double Gmax3 = -DBL_MAX;    // max { -grad(f)_i * d | y_i = -1, d = +1 }
            int Gmax3_idx = -1;
            double Gmax4 = -DBL_MAX;    // max { -grad(f)_i * d | y_i = -1, d = -1 }
            int Gmax4_idx = -1;
            const schar* y = &y_vec[0];
            const schar* alpha_status = &alpha_status_vec[0];
            const double* G = &G_vec[0];
            for( int i = 0; i < alpha_count; i++ )
            {
                double t;

                if( y[i] > 0 )    // y == +1
                {
                    if( !is_upper_bound(i) && (t = -G[i]) > Gmax1 )  // d = +1
                    {
                        Gmax1 = t;
                        Gmax1_idx = i;
                    }
                    if( !is_lower_bound(i) && (t = G[i]) > Gmax2 )  // d = -1
                    {
                        Gmax2 = t;
                        Gmax2_idx = i;
                    }
                }
                else        // y == -1
                {
                    if( !is_upper_bound(i) && (t = -G[i]) > Gmax3 )  // d = +1
                    {
                        Gmax3 = t;
                        Gmax3_idx = i;
                    }
                    if( !is_lower_bound(i) && (t = G[i]) > Gmax4 )  // d = -1
                    {
                        Gmax4 = t;
                        Gmax4_idx = i;
                    }
                }
            }

            if( MAX(Gmax1 + Gmax2, Gmax3 + Gmax4) < eps )
                return true;

            if( Gmax1 + Gmax2 > Gmax3 + Gmax4 )
            {
                out_i = Gmax1_idx;
                out_j = Gmax2_idx;
            }
            else
            {
                out_i = Gmax3_idx;
                out_j = Gmax4_idx;
            }
            return false;
        }

        void calc_rho_nu_svm( double& rho, double& r )
        {
            int nr_free1 = 0, nr_free2 = 0;
            double ub1 = DBL_MAX, ub2 = DBL_MAX;
            double lb1 = -DBL_MAX, lb2 = -DBL_MAX;
            double sum_free1 = 0, sum_free2 = 0;

            const schar* y = &y_vec[0];
            const schar* alpha_status = &alpha_status_vec[0];
            const double* G = &G_vec[0];

            for( int i = 0; i < alpha_count; i++ )
            {
                double G_i = G[i];
                if( y[i] > 0 )
                {
                    if( is_lower_bound(i) )
                        ub1 = MIN( ub1, G_i );
                    else if( is_upper_bound(i) )
                        lb1 = MAX( lb1, G_i );
                    else
                    {
                        ++nr_free1;
                        sum_free1 += G_i;
                    }
                }
                else
                {
                    if( is_lower_bound(i) )
                        ub2 = MIN( ub2, G_i );
                    else if( is_upper_bound(i) )
                        lb2 = MAX( lb2, G_i );
                    else
                    {
                        ++nr_free2;
                        sum_free2 += G_i;
                    }
                }
            }
            double r1 = nr_free1 > 0 ? sum_free1/nr_free1 : (ub1 + lb1)*0.5;
            double r2 = nr_free2 > 0 ? sum_free2/nr_free2 : (ub2 + lb2)*0.5;
            rho = (r1 - r2)*0.5;
            r = (r1 + r2)*0.5;
        }

        /*
        ///////////////////////// construct and solve various formulations ///////////////////////
        */
        static bool solve_c_svc( const Mat& _samples, const vector<schar>& _y,
                                 double _Cp, double _Cn, const Ptr<SVM::Kernel>& _kernel,
                                 vector<double>& _alpha, SolutionInfo& _si, TermCriteria termCrit )
        {
            int sample_count = _samples.rows;
            _alpha.assign(sample_count, 0.);
            vector<double> _b(sample_count, -1.);
            Solver solver( _samples, _y, _alpha, _b, _Cp, _Cn, _kernel,
                           &Solver::get_row_svc,
                           &Solver::select_working_set,
                           &Solver::calc_rho,
                           termCrit );
            if( !solver.solve_generic( _si ))
                return false;
            for( int i = 0; i < sample_count; i++ )
                _alpha[i] *= _y[i];
            return true;
        }
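
        // Note: on return the alphas are premultiplied by the labels, so the
        // decision function is simply sum_i alpha[i]*K(sv_i, x) - rho.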


        static bool solve_nu_svc( const Mat& _samples, const vector<schar>& _y,
                                  double nu, const Ptr<SVM::Kernel>& _kernel,
                                  vector<double>& _alpha, SolutionInfo& _si,
                                  TermCriteria termCrit )
        {
            int sample_count = _samples.rows;
            _alpha.resize(sample_count);
            vector<double> _b(sample_count, 0.);
            double sum_pos = nu * sample_count * 0.5;
            double sum_neg = nu * sample_count * 0.5;
            for( int i = 0; i < sample_count; i++ )
            {
                double a;
                if( _y[i] > 0 )
                {
                    a = std::min(1.0, sum_pos);
                    sum_pos -= a;
                }
                else
                {
                    a = std::min(1.0, sum_neg);
                    sum_neg -= a;
                }
                _alpha[i] = a;
            }
            Solver solver( _samples, _y, _alpha, _b, 1., 1., _kernel,
                           &Solver::get_row_svc,
                           &Solver::select_working_set_nu_svm,
                           &Solver::calc_rho_nu_svm,
                           termCrit );
            if( !solver.solve_generic( _si ))
                return false;
            double inv_r = 1./_si.r;
            for( int i = 0; i < sample_count; i++ )
                _alpha[i] *= _y[i]*inv_r;
            _si.rho *= inv_r;
            _si.obj *= (inv_r*inv_r);
            _si.upper_bound_p = inv_r;
            _si.upper_bound_n = inv_r;
            return true;
        }
        static bool solve_one_class( const Mat& _samples, double nu,
                                     const Ptr<SVM::Kernel>& _kernel,
                                     vector<double>& _alpha, SolutionInfo& _si,
                                     TermCriteria termCrit )
        {
            int sample_count = _samples.rows;
            vector<schar> _y(sample_count, 1);
            vector<double> _b(sample_count, 0.);
            int i, n = cvRound( nu*sample_count );
            _alpha.resize(sample_count);
            for( i = 0; i < sample_count; i++ )
                _alpha[i] = i < n ? 1 : 0;
            if( n < sample_count )
                _alpha[n] = nu * sample_count - n;
            else
                _alpha[n-1] = nu * sample_count - (n-1);
            Solver solver( _samples, _y, _alpha, _b, 1., 1., _kernel,
                           &Solver::get_row_one_class,
                           &Solver::select_working_set,
                           &Solver::calc_rho,
                           termCrit );
            return solver.solve_generic(_si);
        }
        static bool solve_eps_svr( const Mat& _samples, const vector<float>& _yf,
                                   double p, double C, const Ptr<SVM::Kernel>& _kernel,
                                   vector<double>& _alpha, SolutionInfo& _si,
                                   TermCriteria termCrit )
        {
            int sample_count = _samples.rows;
            int alpha_count = sample_count*2;

            CV_Assert( (int)_yf.size() == sample_count );

            _alpha.assign(alpha_count, 0.);
            vector<schar> _y(alpha_count);
            vector<double> _b(alpha_count);
            for( int i = 0; i < sample_count; i++ )
            {
                _b[i] = p - _yf[i];
                _y[i] = 1;
                _b[i+sample_count] = p + _yf[i];
                _y[i+sample_count] = -1;
            }
            Solver solver( _samples, _y, _alpha, _b, C, C, _kernel,
                           &Solver::get_row_svr,
                           &Solver::select_working_set,
                           &Solver::calc_rho,
                           termCrit );
            if( !solver.solve_generic( _si ))
                return false;
            for( int i = 0; i < sample_count; i++ )
                _alpha[i] -= _alpha[i+sample_count];

            return true;
        }


        static bool solve_nu_svr( const Mat& _samples, const vector<float>& _yf,
                                  double nu, double C, const Ptr<SVM::Kernel>& _kernel,
                                  vector<double>& _alpha, SolutionInfo& _si,
                                  TermCriteria termCrit )
        {
            int sample_count = _samples.rows;
            int alpha_count = sample_count*2;
            double sum = C * nu * sample_count * 0.5;
            CV_Assert( (int)_yf.size() == sample_count );
            _alpha.resize(alpha_count);
            vector<schar> _y(alpha_count);
            vector<double> _b(alpha_count);
            for( int i = 0; i < sample_count; i++ )
            {
                _alpha[i] = _alpha[i + sample_count] = std::min(sum, C);
                sum -= _alpha[i];
                _b[i] = -_yf[i];
                _y[i] = 1;
                _b[i + sample_count] = _yf[i];
                _y[i + sample_count] = -1;
            }
            Solver solver( _samples, _y, _alpha, _b, 1., 1., _kernel,
                           &Solver::get_row_svr,
                           &Solver::select_working_set_nu_svm,
                           &Solver::calc_rho_nu_svm,
                           termCrit );
            if( !solver.solve_generic( _si ))
                return false;
            for( int i = 0; i < sample_count; i++ )
                _alpha[i] -= _alpha[i+sample_count];
            return true;
        }
        int sample_count;
        int var_count;
        int cache_size;
        int max_cache_size;
        Mat samples;
        SvmParams params;
        vector<KernelRow> lru_cache;
        int lru_first;
        int lru_last;
        Mat lru_cache_data;

        int alpha_count;

        vector<double> G_vec;
        vector<double>* alpha_vec;
        vector<schar> y_vec;
        // -1 - lower bound, 0 - free, 1 - upper bound
        vector<schar> alpha_status_vec;
        vector<double> b_vec;

        vector<Qfloat> buf[2];
        double eps;
        int max_iter;
        double C[2];  // C[0] == Cn, C[1] == Cp
        Ptr<SVM::Kernel> kernel;
        SelectWorkingSet select_working_set_func;
        CalcRho calc_rho_func;
        GetRow get_row_func;
    };

    //////////////////////////////////////////////////////////////////////////////////////////
    SVMImpl()
    {
        clear();
        checkParams();
    }

    ~SVMImpl()
    {
        clear();
    }

    void clear() CV_OVERRIDE
    {
        decision_func.clear();
        df_alpha.clear();
        df_index.clear();
        sv.release();
        uncompressed_sv.release();
    }

    Mat getUncompressedSupportVectors_() const
    {
        return uncompressed_sv;
    }

    Mat getSupportVectors() const CV_OVERRIDE
    {
        return sv;
    }
    inline int getType() const CV_OVERRIDE { return params.svmType; }
    inline void setType(int val) CV_OVERRIDE { params.svmType = val; }
    inline double getGamma() const CV_OVERRIDE { return params.gamma; }
    inline void setGamma(double val) CV_OVERRIDE { params.gamma = val; }
    inline double getCoef0() const CV_OVERRIDE { return params.coef0; }
    inline void setCoef0(double val) CV_OVERRIDE { params.coef0 = val; }
    inline double getDegree() const CV_OVERRIDE { return params.degree; }
    inline void setDegree(double val) CV_OVERRIDE { params.degree = val; }
    inline double getC() const CV_OVERRIDE { return params.C; }
    inline void setC(double val) CV_OVERRIDE { params.C = val; }
    inline double getNu() const CV_OVERRIDE { return params.nu; }
    inline void setNu(double val) CV_OVERRIDE { params.nu = val; }
    inline double getP() const CV_OVERRIDE { return params.p; }
    inline void setP(double val) CV_OVERRIDE { params.p = val; }
    inline cv::Mat getClassWeights() const CV_OVERRIDE { return params.classWeights; }
    inline void setClassWeights(const cv::Mat& val) CV_OVERRIDE { params.classWeights = val; }
    inline cv::TermCriteria getTermCriteria() const CV_OVERRIDE { return params.termCrit; }
    inline void setTermCriteria(const cv::TermCriteria& val) CV_OVERRIDE { params.termCrit = val; }

    int getKernelType() const CV_OVERRIDE { return params.kernelType; }
    void setKernel(int kernelType) CV_OVERRIDE
    {
        params.kernelType = kernelType;
        if (kernelType != CUSTOM)
            kernel = makePtr<SVMKernelImpl>(params);
    }

    void setCustomKernel(const Ptr<Kernel> &_kernel) CV_OVERRIDE
    {
        params.kernelType = CUSTOM;
        kernel = _kernel;
    }
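
    // Typical use of this implementation through the public cv::ml::SVM
    // interface (a sketch; `samples` is a CV_32F row-sample matrix and
    // `labels` a CV_32S column of class labels):
    //
    //   Ptr<SVM> svm = SVM::create();
    //   svm->setType(SVM::C_SVC);
    //   svm->setKernel(SVM::RBF);
    //   svm->setGamma(0.5);
    //   svm->setC(2.5);
    //   svm->train(samples, ROW_SAMPLE, labels);
    //   float prediction = svm->predict(testSample);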
    void checkParams()
    {
        int kernelType = params.kernelType;
        if (kernelType != CUSTOM)
        {
            if( kernelType != LINEAR && kernelType != POLY &&
                kernelType != SIGMOID && kernelType != RBF &&
                kernelType != INTER && kernelType != CHI2)
                CV_Error( CV_StsBadArg, "Unknown/unsupported kernel type" );

            if( kernelType == LINEAR )
                params.gamma = 1;
            else if( params.gamma <= 0 )
                CV_Error( CV_StsOutOfRange, "gamma parameter of the kernel must be positive" );

            if( kernelType != SIGMOID && kernelType != POLY )
                params.coef0 = 0;
            else if( params.coef0 < 0 )
                CV_Error( CV_StsOutOfRange, "The kernel parameter <coef0> must be positive or zero" );

            if( kernelType != POLY )
                params.degree = 0;
            else if( params.degree <= 0 )
                CV_Error( CV_StsOutOfRange, "The kernel parameter <degree> must be positive" );
            kernel = makePtr<SVMKernelImpl>(params);
        }
        else
        {
            if (!kernel)
                CV_Error( CV_StsBadArg, "Custom kernel is not set" );
        }

        int svmType = params.svmType;

        if( svmType != C_SVC && svmType != NU_SVC &&
            svmType != ONE_CLASS && svmType != EPS_SVR &&
            svmType != NU_SVR )
            CV_Error( CV_StsBadArg, "Unknown/unsupported SVM type" );
        if( svmType == ONE_CLASS || svmType == NU_SVC )
            params.C = 0;
        else if( params.C <= 0 )
            CV_Error( CV_StsOutOfRange, "The parameter C must be positive" );
        if( svmType == C_SVC || svmType == EPS_SVR )
            params.nu = 0;
        else if( params.nu <= 0 || params.nu >= 1 )
            CV_Error( CV_StsOutOfRange, "The parameter nu must be between 0 and 1" );
        if( svmType != EPS_SVR )
            params.p = 0;
        else if( params.p <= 0 )
            CV_Error( CV_StsOutOfRange, "The parameter p must be positive" );
        if( svmType != C_SVC )
            params.classWeights.release();
        if( !(params.termCrit.type & TermCriteria::EPS) )
            params.termCrit.epsilon = DBL_EPSILON;
        params.termCrit.epsilon = std::max(params.termCrit.epsilon, DBL_EPSILON);
        if( !(params.termCrit.type & TermCriteria::COUNT) )
            params.termCrit.maxCount = INT_MAX;
        params.termCrit.maxCount = std::max(params.termCrit.maxCount, 1);
    }

    void setParams( const SvmParams& _params)
    {
        params = _params;
        checkParams();
    }

    int getSVCount(int i) const
    {
        return (i < (int)(decision_func.size()-1) ? decision_func[i+1].ofs :
                (int)df_index.size()) - decision_func[i].ofs;
    }

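    // decision_func[i].ofs is the start of the i-th binary classifier's
    // entries in df_index/df_alpha; its support-vector count is the distance
    // to the next function's ofs (or to the end of df_index for the last one).
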
    bool do_train( const Mat& _samples, const Mat& _responses )
    {
        int svmType = params.svmType;
        int i, j, k, sample_count = _samples.rows;
        vector<double> _alpha;
        Solver::SolutionInfo sinfo;
        CV_Assert( _samples.type() == CV_32F );
        var_count = _samples.cols;
        if( svmType == ONE_CLASS || svmType == EPS_SVR || svmType == NU_SVR )
        {
            int sv_count = 0;
            decision_func.clear();
            vector<float> _yf;
            if( !_responses.empty() )
                _responses.convertTo(_yf, CV_32F);
            bool ok =
            svmType == ONE_CLASS ? Solver::solve_one_class( _samples, params.nu, kernel, _alpha, sinfo, params.termCrit ) :
            svmType == EPS_SVR ? Solver::solve_eps_svr( _samples, _yf, params.p, params.C, kernel, _alpha, sinfo, params.termCrit ) :
            svmType == NU_SVR ? Solver::solve_nu_svr( _samples, _yf, params.nu, params.C, kernel, _alpha, sinfo, params.termCrit ) : false;
            if( !ok )
                return false;
            for( i = 0; i < sample_count; i++ )
                sv_count += fabs(_alpha[i]) > 0;

            CV_Assert(sv_count != 0);

            sv.create(sv_count, _samples.cols, CV_32F);
            df_alpha.resize(sv_count);
            df_index.resize(sv_count);
            for( i = k = 0; i < sample_count; i++ )
            {
                if( std::abs(_alpha[i]) > 0 )
                {
                    _samples.row(i).copyTo(sv.row(k));
                    df_alpha[k] = _alpha[i];
                    df_index[k] = k;
                    k++;
                }
            }
            decision_func.push_back(DecisionFunc(sinfo.rho, 0));
        }
        else
        {
            int class_count = (int)class_labels.total();
            vector<int> svidx, sidx, sidx_all, sv_tab(sample_count, 0);
            Mat temp_samples, class_weights;
            vector<int> class_ranges;
            vector<schar> temp_y;
            double nu = params.nu;
            CV_Assert( svmType == C_SVC || svmType == NU_SVC );

            if( svmType == C_SVC && !params.classWeights.empty() )
            {
                const Mat cw = params.classWeights;
                if( (cw.cols != 1 && cw.rows != 1) ||
                    (int)cw.total() != class_count ||
                    (cw.type() != CV_32F && cw.type() != CV_64F) )
                    CV_Error( CV_StsBadArg, "params.class_weights must be 1d floating-point vector "
                        "containing as many elements as the number of classes" );
                cw.convertTo(class_weights, CV_64F, params.C);
                //normalize(cw, class_weights, params.C, 0, NORM_L1, CV_64F);
            }
            decision_func.clear();
            df_alpha.clear();
            df_index.clear();

            sortSamplesByClasses( _samples, _responses, sidx_all, class_ranges );

            // check that during cross-validation there were samples from all the classes
            if( class_ranges[class_count] <= 0 )
                CV_Error( CV_StsBadArg, "During cross-validation one or more of the classes has "
                "fallen out of the sample. Try to reduce <Params::k_fold>" );

            if( svmType == NU_SVC )
            {
                // check if nu is feasible
                for( i = 0; i < class_count; i++ )
                {
                    int ci = class_ranges[i+1] - class_ranges[i];
                    for( j = i+1; j< class_count; j++ )
                    {
                        int cj = class_ranges[j+1] - class_ranges[j];
                        if( nu*(ci + cj)*0.5 > std::min( ci, cj ) )
                            // TODO: add some diagnostic
                            return false;
                    }
                }
            }
            size_t samplesize = _samples.cols*_samples.elemSize();
            // train n*(n-1)/2 classifiers
            for( i = 0; i < class_count; i++ )
            {
                for( j = i+1; j < class_count; j++ )
                {
                    int si = class_ranges[i], ci = class_ranges[i+1] - si;
                    int sj = class_ranges[j], cj = class_ranges[j+1] - sj;
                    double Cp = params.C, Cn = Cp;
                    temp_samples.create(ci + cj, _samples.cols, _samples.type());
                    sidx.resize(ci + cj);
                    temp_y.resize(ci + cj);
                    // form input for the binary classification problem
                    for( k = 0; k < ci+cj; k++ )
                    {
                        int idx = k < ci ? si+k : sj+k-ci;
                        memcpy(temp_samples.ptr(k), _samples.ptr(sidx_all[idx]), samplesize);
                        sidx[k] = sidx_all[idx];
                        temp_y[k] = k < ci ? 1 : -1;
                    }

                    if( !class_weights.empty() )
                    {
                        Cp = class_weights.at<double>(i);
                        Cn = class_weights.at<double>(j);
                    }

                    DecisionFunc df;
                    bool ok = params.svmType == C_SVC ?
                                Solver::solve_c_svc( temp_samples, temp_y, Cp, Cn,
                                                     kernel, _alpha, sinfo, params.termCrit ) :
                              params.svmType == NU_SVC ?
                                Solver::solve_nu_svc( temp_samples, temp_y, params.nu,
                                                      kernel, _alpha, sinfo, params.termCrit ) :
                              false;
                    if( !ok )
                        return false;
                    df.rho = sinfo.rho;
                    df.ofs = (int)df_index.size();
                    decision_func.push_back(df);

                    for( k = 0; k < ci + cj; k++ )
                    {
                        if( std::abs(_alpha[k]) > 0 )
                        {
                            int idx = k < ci ? si+k : sj+k-ci;
                            sv_tab[sidx_all[idx]] = 1;
                            df_index.push_back(sidx_all[idx]);
                            df_alpha.push_back(_alpha[k]);
                        }
                    }
                }
            }

            // enumerate the support vectors: sv_tab[i] becomes the 1-based
            // index of sample i among the support vectors (0 means not a SV)
            for( i = 0, k = 0; i < sample_count; i++ )
            {
                if( sv_tab[i] )
                    sv_tab[i] = ++k;
            }

            int sv_total = k;
            sv.create(sv_total, _samples.cols, _samples.type());

            for( i = 0; i < sample_count; i++ )
            {
                if( !sv_tab[i] )
                    continue;
                memcpy(sv.ptr(sv_tab[i]-1), _samples.ptr(i), samplesize);
            }

            // remap df_index entries from sample indices to indices into <sv>
            int n = (int)df_index.size();
            for( i = 0; i < n; i++ )
            {
                CV_Assert( sv_tab[df_index[i]] > 0 );
                df_index[i] = sv_tab[df_index[i]] - 1;
            }
        }

        optimize_linear_svm();

        return true;
    }

    void optimize_linear_svm()
    {
        // we optimize only linear SVMs: compress all the support vectors
        // of each decision function into a single vector.
        if( params.kernelType != LINEAR )
            return;

        int i, df_count = (int)decision_func.size();

        for( i = 0; i < df_count; i++ )
        {
            if( getSVCount(i) != 1 )
                break;
        }

        // if every decision function already uses a single support vector,
        // the model is compressed; nothing to do then.
        if( i == df_count )
            return;

        AutoBuffer<double> vbuf(var_count);
        double* v = vbuf.data();
        Mat new_sv(df_count, var_count, CV_32F);

        vector<DecisionFunc> new_df;
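
        // for a linear kernel, sum_j alpha_j*K(sv_j, x) - rho reduces to
        // <w, x> - rho with w = sum_j alpha_j*sv_j, so each decision function
        // can be represented by the single compressed "support vector" w: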
        for( i = 0; i < df_count; i++ )
        {
            float* dst = new_sv.ptr<float>(i);
            memset(v, 0, var_count*sizeof(v[0]));
            int j, k, sv_count = getSVCount(i);
            const DecisionFunc& df = decision_func[i];
            const int* sv_index = &df_index[df.ofs];
            const double* sv_alpha = &df_alpha[df.ofs];
            for( j = 0; j < sv_count; j++ )
            {
                const float* src = sv.ptr<float>(sv_index[j]);
                double a = sv_alpha[j];
                for( k = 0; k < var_count; k++ )
                    v[k] += src[k]*a;
            }
            for( k = 0; k < var_count; k++ )
                dst[k] = (float)v[k];
            new_df.push_back(DecisionFunc(df.rho, i));
        }

        setRangeVector(df_index, df_count);
        df_alpha.assign(df_count, 1.);
        sv.copyTo(uncompressed_sv);
        std::swap(sv, new_sv);
        std::swap(decision_func, new_df);
    }

    bool train( const Ptr<TrainData>& data, int ) CV_OVERRIDE
    {
        clear();

        checkParams();

        int svmType = params.svmType;
        Mat samples = data->getTrainSamples();
        Mat responses;

        if( svmType == C_SVC || svmType == NU_SVC )
        {
            responses = data->getTrainNormCatResponses();
            if( responses.empty() )
                CV_Error(CV_StsBadArg, "in the case of a classification problem the responses must be categorical; "
                                       "either specify varType when creating TrainData, or pass integer responses");
            class_labels = data->getClassLabels();
        }
        else
            responses = data->getTrainResponses();

        if( !do_train( samples, responses ))
        {
            clear();
            return false;
        }

        return true;
    }

    class TrainAutoBody : public ParallelLoopBody
    {
    public:
        TrainAutoBody(const vector<SvmParams>& _parameters,
                      const cv::Mat& _samples,
                      const cv::Mat& _responses,
                      const cv::Mat& _labels,
                      const vector<int>& _sidx,
                      bool _is_classification,
                      int _k_fold,
                      std::vector<double>& _result) :
        parameters(_parameters), samples(_samples), responses(_responses), labels(_labels),
        sidx(_sidx), is_classification(_is_classification), k_fold(_k_fold), result(_result)
        {}

        void operator()( const cv::Range& range ) const CV_OVERRIDE
        {
            int sample_count = samples.rows;
            int var_count_ = samples.cols;
            size_t sample_size = var_count_*samples.elemSize();

            int test_sample_count = (sample_count + k_fold/2)/k_fold;
            int train_sample_count = sample_count - test_sample_count;

            // Use a local instance
            cv::Ptr<SVMImpl> svm = makePtr<SVMImpl>();
            svm->class_labels = labels;

            int rtype = responses.type();

            Mat temp_train_samples(train_sample_count, var_count_, CV_32F);
            Mat temp_test_samples(test_sample_count, var_count_, CV_32F);
            Mat temp_train_responses(train_sample_count, 1, rtype);
            Mat temp_test_responses;

            for( int p = range.start; p < range.end; p++ )
            {
                svm->setParams(parameters[p]);

                double error = 0;
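                // rotating k-fold scheme: each round holds out one contiguous
                // window of the permuted indices (with wrap-around) for
                // testing and trains on the remaining samples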
                for( int k = 0; k < k_fold; k++ )
                {
                    int start = (k*sample_count + k_fold/2)/k_fold;
                    for( int i = 0; i < train_sample_count; i++ )
                    {
                        int j = sidx[(i+start)%sample_count];
                        memcpy(temp_train_samples.ptr(i), samples.ptr(j), sample_size);
                        if( is_classification )
                            temp_train_responses.at<int>(i) = responses.at<int>(j);
                        else if( !responses.empty() )
                            temp_train_responses.at<float>(i) = responses.at<float>(j);
                    }

                    // Train the SVM on <train_sample_count> samples
                    if( !svm->do_train( temp_train_samples, temp_train_responses ))
                        continue;

                    for( int i = 0; i < test_sample_count; i++ )
                    {
                        int j = sidx[(i+start+train_sample_count) % sample_count];
                        memcpy(temp_test_samples.ptr(i), samples.ptr(j), sample_size);
                    }

                    svm->predict(temp_test_samples, temp_test_responses, 0);
                    for( int i = 0; i < test_sample_count; i++ )
                    {
                        float val = temp_test_responses.at<float>(i);
                        int j = sidx[(i+start+train_sample_count) % sample_count];
                        if( is_classification )
                            error += (float)(val != responses.at<int>(j));
                        else
                        {
                            val -= responses.at<float>(j);
                            error += val*val;
                        }
                    }
                }

                result[p] = error;
            }
        }

    private:
        const vector<SvmParams>& parameters;
        const cv::Mat& samples;
        const cv::Mat& responses;
        const cv::Mat& labels;
        const vector<int>& sidx;
        bool is_classification;
        int k_fold;
        std::vector<double>& result;
    };

    bool trainAuto( const Ptr<TrainData>& data, int k_fold,
                    ParamGrid C_grid, ParamGrid gamma_grid, ParamGrid p_grid,
                    ParamGrid nu_grid, ParamGrid coef_grid, ParamGrid degree_grid,
                    bool balanced ) CV_OVERRIDE
    {
        checkParams();

        int svmType = params.svmType;
        RNG rng((uint64)-1);

        if( svmType == ONE_CLASS )
            // current implementation of "auto" svm does not support the 1-class case.
            return train( data, 0 );

        clear();

        CV_Assert( k_fold >= 2 );

        // All the parameters except, possibly, <coef0> are positive.
        // <coef0> is nonnegative
        #define CHECK_GRID(grid, param) \
        if( grid.logStep <= 1 ) \
        { \
            grid.minVal = grid.maxVal = params.param; \
            grid.logStep = 10; \
        } \
        else \
            checkParamGrid(grid)
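        // (a degenerate grid, logStep <= 1, pins the parameter to its current
        // value, effectively excluding it from the search)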

        CHECK_GRID(C_grid, C);
        CHECK_GRID(gamma_grid, gamma);
        CHECK_GRID(p_grid, p);
        CHECK_GRID(nu_grid, nu);
        CHECK_GRID(coef_grid, coef0);
        CHECK_GRID(degree_grid, degree);

        // parameters not used by the chosen SVM type/kernel are pinned:
        if( params.kernelType != POLY )
            degree_grid.minVal = degree_grid.maxVal = params.degree;
        if( params.kernelType == LINEAR )
            gamma_grid.minVal = gamma_grid.maxVal = params.gamma;
        if( params.kernelType != POLY && params.kernelType != SIGMOID )
            coef_grid.minVal = coef_grid.maxVal = params.coef0;
        if( svmType == NU_SVC || svmType == ONE_CLASS )
            C_grid.minVal = C_grid.maxVal = params.C;
        if( svmType == C_SVC || svmType == EPS_SVR )
            nu_grid.minVal = nu_grid.maxVal = params.nu;
        if( svmType != EPS_SVR )
            p_grid.minVal = p_grid.maxVal = params.p;

        Mat samples = data->getTrainSamples();
        Mat responses;
        bool is_classification = false;
        Mat class_labels0;
        int class_count = (int)class_labels.total();

        if( svmType == C_SVC || svmType == NU_SVC )
        {
            responses = data->getTrainNormCatResponses();
            class_labels = data->getClassLabels();
            class_count = (int)class_labels.total();
            is_classification = true;

            vector<int> temp_class_labels;
            setRangeVector(temp_class_labels, class_count);

            // temporarily replace class labels with 0, 1, ..., NCLASSES-1
            class_labels0 = class_labels;
            class_labels = Mat(temp_class_labels).clone();
        }
        else
            responses = data->getTrainResponses();

        CV_Assert(samples.type() == CV_32F);

        int sample_count = samples.rows;
        var_count = samples.cols;

        vector<int> sidx;
        setRangeVector(sidx, sample_count);

        // randomly permute training samples
        for( int i = 0; i < sample_count; i++ )
        {
            int i1 = rng.uniform(0, sample_count);
            int i2 = rng.uniform(0, sample_count);
            std::swap(sidx[i1], sidx[i2]);
        }

        if( is_classification && class_count == 2 && balanced )
        {
            // reshuffle the training set in such a way that
            // instances of each class are divided more or less evenly
            // between the k_fold parts.
            vector<int> sidx0, sidx1;

            for( int i = 0; i < sample_count; i++ )
            {
                if( responses.at<int>(sidx[i]) == 0 )
                    sidx0.push_back(sidx[i]);
                else
                    sidx1.push_back(sidx[i]);
            }

            int n0 = (int)sidx0.size(), n1 = (int)sidx1.size();
            int a0 = 0, a1 = 0;
            sidx.clear();
            for( int k = 0; k < k_fold; k++ )
            {
                int b0 = ((k+1)*n0 + k_fold/2)/k_fold, b1 = ((k+1)*n1 + k_fold/2)/k_fold;
                int a = (int)sidx.size(), b = a + (b0 - a0) + (b1 - a1);
                for( int i = a0; i < b0; i++ )
                    sidx.push_back(sidx0[i]);
                for( int i = a1; i < b1; i++ )
                    sidx.push_back(sidx1[i]);
                for( int i = 0; i < (b - a); i++ )
                {
                    int i1 = rng.uniform(a, b);
                    int i2 = rng.uniform(a, b);
                    std::swap(sidx[i1], sidx[i2]);
                }
                a0 = b0; a1 = b1;
            }
        }

        // If grid.minVal == grid.maxVal, this will allow one and only one pass through the loop with params.var = grid.minVal.
        #define FOR_IN_GRID(var, grid) \
            for( params.var = grid.minVal; params.var == grid.minVal || params.var < grid.maxVal; params.var = (grid.minVal == grid.maxVal) ? grid.maxVal + 1 : params.var * grid.logStep )
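        // (each value is the previous one multiplied by logStep, so the grid
        // is walked geometrically: minVal, minVal*logStep, minVal*logStep^2, ...)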

        // Create the list of parameters to test
        std::vector<SvmParams> parameters;
        FOR_IN_GRID(C, C_grid)
        FOR_IN_GRID(gamma, gamma_grid)
        FOR_IN_GRID(p, p_grid)
        FOR_IN_GRID(nu, nu_grid)
        FOR_IN_GRID(coef0, coef_grid)
        FOR_IN_GRID(degree, degree_grid)
        {
            parameters.push_back(params);
        }

        std::vector<double> result(parameters.size());
        TrainAutoBody invoker(parameters, samples, responses, class_labels, sidx,
                              is_classification, k_fold, result);
        parallel_for_(cv::Range(0,(int)parameters.size()), invoker);

        // Extract the best parameters
        SvmParams best_params = params;
        double min_error = FLT_MAX;
        for( int i = 0; i < (int)result.size(); i++ )
        {
            if( result[i] < min_error )
            {
                min_error   = result[i];
                best_params = parameters[i];
            }
        }

        class_labels = class_labels0;
        setParams(best_params);
        return do_train( samples, responses );
    }

    struct PredictBody : ParallelLoopBody
    {
        PredictBody( const SVMImpl* _svm, const Mat& _samples, Mat& _results, bool _returnDFVal )
        {
            svm = _svm;
            results = &_results;
            samples = &_samples;
            returnDFVal = _returnDFVal;
        }

        void operator()(const Range& range) const CV_OVERRIDE
        {
            int svmType = svm->params.svmType;
            int sv_total = svm->sv.rows;
            int class_count = !svm->class_labels.empty() ? (int)svm->class_labels.total() : svmType == ONE_CLASS ? 1 : 0;

            AutoBuffer<float> _buffer(sv_total + (class_count+1)*2);
            float* buffer = _buffer.data();

            int i, j, dfi, k, si;

            if( svmType == EPS_SVR || svmType == NU_SVR || svmType == ONE_CLASS )
            {
                for( si = range.start; si < range.end; si++ )
                {
                    const float* row_sample = samples->ptr<float>(si);
                    svm->kernel->calc( sv_total, svm->var_count, svm->sv.ptr<float>(), row_sample, buffer );

                    const SVMImpl::DecisionFunc* df = &svm->decision_func[0];
                    double sum = -df->rho;
                    for( i = 0; i < sv_total; i++ )
                        sum += buffer[i]*svm->df_alpha[i];
                    float result = svm->params.svmType == ONE_CLASS && !returnDFVal ? (float)(sum > 0) : (float)sum;
                    results->at<float>(si) = result;
                }
            }
            else if( svmType == C_SVC || svmType == NU_SVC )
            {
                int* vote = (int*)(buffer + sv_total);

                for( si = range.start; si < range.end; si++ )
                {
                    svm->kernel->calc( sv_total, svm->var_count, svm->sv.ptr<float>(),
                                       samples->ptr<float>(si), buffer );
                    double sum = 0.;

                    memset( vote, 0, class_count*sizeof(vote[0]));
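
                    // one-vs-one voting: evaluate every pairwise decision
                    // function and give a vote to its winning class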
                    for( i = dfi = 0; i < class_count; i++ )
                    {
                        for( j = i+1; j < class_count; j++, dfi++ )
                        {
                            const DecisionFunc& df = svm->decision_func[dfi];
                            sum = -df.rho;
                            int sv_count = svm->getSVCount(dfi);
                            const double* alpha = &svm->df_alpha[df.ofs];
                            const int* sv_index = &svm->df_index[df.ofs];
                            for( k = 0; k < sv_count; k++ )
                                sum += alpha[k]*buffer[sv_index[k]];

                            vote[sum > 0 ? i : j]++;
                        }
                    }

                    for( i = 1, k = 0; i < class_count; i++ )
                    {
                        if( vote[i] > vote[k] )
                            k = i;
                    }
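                    // with two classes there is a single decision function,
                    // so its raw value can be returned directly if requested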
                    float result = returnDFVal && class_count == 2 ?
                        (float)sum : (float)(svm->class_labels.at<int>(k));
                    results->at<float>(si) = result;
                }
            }
            else
                CV_Error( CV_StsBadArg, "INTERNAL ERROR: Unknown SVM type, "
                         "the SVM structure is probably corrupted" );
        }

        const SVMImpl* svm;
        const Mat* samples;
        Mat* results;
        bool returnDFVal;
    };

    bool trainAuto_(InputArray samples, int layout,
            InputArray responses, int kfold, Ptr<ParamGrid> Cgrid,
            Ptr<ParamGrid> gammaGrid, Ptr<ParamGrid> pGrid, Ptr<ParamGrid> nuGrid,
            Ptr<ParamGrid> coeffGrid, Ptr<ParamGrid> degreeGrid, bool balanced)
    {
        Ptr<TrainData> data = TrainData::create(samples, layout, responses);
        return this->trainAuto(
                data, kfold,
                *Cgrid.get(),
                *gammaGrid.get(),
                *pGrid.get(),
                *nuGrid.get(),
                *coeffGrid.get(),
                *degreeGrid.get(),
                balanced);
    }


    float predict( InputArray _samples, OutputArray _results, int flags ) const CV_OVERRIDE
    {
        float result = 0;
        Mat samples = _samples.getMat(), results;
        int nsamples = samples.rows;
        bool returnDFVal = (flags & RAW_OUTPUT) != 0;

        CV_Assert( samples.cols == var_count && samples.type() == CV_32F );

        if( _results.needed() )
        {
            _results.create( nsamples, 1, samples.type() );
            results = _results.getMat();
        }
        else
        {
            CV_Assert( nsamples == 1 );
            results = Mat(1, 1, CV_32F, &result);
        }

        PredictBody invoker(this, samples, results, returnDFVal);
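        // small batches are predicted serially (presumably to avoid
        // parallel_for_ scheduling overhead); larger ones run in parallel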
        if( nsamples < 10 )
            invoker(Range(0, nsamples));
        else
            parallel_for_(Range(0, nsamples), invoker);
        return result;
    }

    double getDecisionFunction(int i, OutputArray _alpha, OutputArray _svidx ) const CV_OVERRIDE
    {
        CV_Assert( 0 <= i && i < (int)decision_func.size());
        const DecisionFunc& df = decision_func[i];
        int count = getSVCount(i);
        Mat(1, count, CV_64F, (double*)&df_alpha[df.ofs]).copyTo(_alpha);
        Mat(1, count, CV_32S, (int*)&df_index[df.ofs]).copyTo(_svidx);
        return df.rho;
    }

    void write_params( FileStorage& fs ) const
    {
        int svmType = params.svmType;
        int kernelType = params.kernelType;

        String svm_type_str =
            svmType == C_SVC ? "C_SVC" :
            svmType == NU_SVC ? "NU_SVC" :
            svmType == ONE_CLASS ? "ONE_CLASS" :
            svmType == EPS_SVR ? "EPS_SVR" :
            svmType == NU_SVR ? "NU_SVR" : format("Unknown_%d", svmType);
        String kernel_type_str =
            kernelType == LINEAR ? "LINEAR" :
            kernelType == POLY ? "POLY" :
            kernelType == RBF ? "RBF" :
            kernelType == SIGMOID ? "SIGMOID" :
            kernelType == CHI2 ? "CHI2" :
            kernelType == INTER ? "INTER" : format("Unknown_%d", kernelType);

        fs << "svmType" << svm_type_str;

        // save kernel
        fs << "kernel" << "{" << "type" << kernel_type_str;

        if( kernelType == POLY )
            fs << "degree" << params.degree;

        if( kernelType != LINEAR )
            fs << "gamma" << params.gamma;

        if( kernelType == POLY || kernelType == SIGMOID )
            fs << "coef0" << params.coef0;

        fs << "}";

        if( svmType == C_SVC || svmType == EPS_SVR || svmType == NU_SVR )
            fs << "C" << params.C;

        if( svmType == NU_SVC || svmType == ONE_CLASS || svmType == NU_SVR )
            fs << "nu" << params.nu;

        if( svmType == EPS_SVR )
            fs << "p" << params.p;

        fs << "term_criteria" << "{:";
        if( params.termCrit.type & TermCriteria::EPS )
            fs << "epsilon" << params.termCrit.epsilon;
        if( params.termCrit.type & TermCriteria::COUNT )
            fs << "iterations" << params.termCrit.maxCount;
        fs << "}";
    }

    bool isTrained() const CV_OVERRIDE
    {
        return !sv.empty();
    }

    bool isClassifier() const CV_OVERRIDE
    {
        return params.svmType == C_SVC || params.svmType == NU_SVC || params.svmType == ONE_CLASS;
    }

    int getVarCount() const CV_OVERRIDE
    {
        return var_count;
    }

    String getDefaultName() const CV_OVERRIDE
    {
        return "opencv_ml_svm";
    }

    void write( FileStorage& fs ) const CV_OVERRIDE
    {
        int class_count = !class_labels.empty() ? (int)class_labels.total() :
                          params.svmType == ONE_CLASS ? 1 : 0;
        if( !isTrained() )
            CV_Error( CV_StsBadArg, "the SVM model is not trained: train it before saving" );

        writeFormat(fs);
        write_params( fs );

        fs << "var_count" << var_count;

        if( class_count > 0 )
        {
            fs << "class_count" << class_count;

            if( !class_labels.empty() )
                fs << "class_labels" << class_labels;

            if( !params.classWeights.empty() )
                fs << "class_weights" << params.classWeights;
        }

        // write the joint collection of support vectors
        int i, sv_total = sv.rows;
        fs << "sv_total" << sv_total;
        fs << "support_vectors" << "[";
        for( i = 0; i < sv_total; i++ )
        {
            fs << "[:";
            fs.writeRaw("f", sv.ptr(i), sv.cols*sv.elemSize());
            fs << "]";
        }
        fs << "]";

        if ( !uncompressed_sv.empty() )
        {
            // write the joint collection of uncompressed support vectors
            int uncompressed_sv_total = uncompressed_sv.rows;
            fs << "uncompressed_sv_total" << uncompressed_sv_total;
            fs << "uncompressed_support_vectors" << "[";
            for( i = 0; i < uncompressed_sv_total; i++ )
            {
                fs << "[:";
                fs.writeRaw("f", uncompressed_sv.ptr(i), uncompressed_sv.cols*uncompressed_sv.elemSize());
                fs << "]";
            }
            fs << "]";
        }

        // write decision functions
        int df_count = (int)decision_func.size();

        fs << "decision_functions" << "[";
        for( i = 0; i < df_count; i++ )
        {
            const DecisionFunc& df = decision_func[i];
            int sv_count = getSVCount(i);
            fs << "{" << "sv_count" << sv_count
               << "rho" << df.rho
               << "alpha" << "[:";
            fs.writeRaw("d", (const uchar*)&df_alpha[df.ofs], sv_count*sizeof(df_alpha[0]));
            fs << "]";
            if( class_count >= 2 )
            {
                fs << "index" << "[:";
                fs.writeRaw("i", (const uchar*)&df_index[df.ofs], sv_count*sizeof(df_index[0]));
                fs << "]";
            }
            else
                CV_Assert( sv_count == sv_total );
            fs << "}";
        }
        fs << "]";
    }

    void read_params( const FileNode& fn )
    {
        SvmParams _params;

        // check for old naming
        String svm_type_str = (String)(fn["svm_type"].empty() ? fn["svmType"] : fn["svm_type"]);
        int svmType =
            svm_type_str == "C_SVC" ? C_SVC :
            svm_type_str == "NU_SVC" ? NU_SVC :
            svm_type_str == "ONE_CLASS" ? ONE_CLASS :
            svm_type_str == "EPS_SVR" ? EPS_SVR :
            svm_type_str == "NU_SVR" ? NU_SVR : -1;

        if( svmType < 0 )
            CV_Error( CV_StsParseError, "Missing or invalid SVM type" );

        FileNode kernel_node = fn["kernel"];
        if( kernel_node.empty() )
            CV_Error( CV_StsParseError, "SVM kernel tag is not found" );

        String kernel_type_str = (String)kernel_node["type"];
        int kernelType =
            kernel_type_str == "LINEAR" ? LINEAR :
            kernel_type_str == "POLY" ? POLY :
            kernel_type_str == "RBF" ? RBF :
            kernel_type_str == "SIGMOID" ? SIGMOID :
            kernel_type_str == "CHI2" ? CHI2 :
            kernel_type_str == "INTER" ? INTER : CUSTOM;

        if( kernelType == CUSTOM )
            CV_Error( CV_StsParseError, "Invalid SVM kernel type (or custom kernel)" );

        _params.svmType = svmType;
        _params.kernelType = kernelType;
        _params.degree = (double)kernel_node["degree"];
        _params.gamma = (double)kernel_node["gamma"];
        _params.coef0 = (double)kernel_node["coef0"];

        _params.C = (double)fn["C"];
        _params.nu = (double)fn["nu"];
        _params.p = (double)fn["p"];
        _params.classWeights = Mat();

        FileNode tcnode = fn["term_criteria"];
        if( !tcnode.empty() )
        {
            _params.termCrit.epsilon = (double)tcnode["epsilon"];
            _params.termCrit.maxCount = (int)tcnode["iterations"];
            _params.termCrit.type = (_params.termCrit.epsilon > 0 ? TermCriteria::EPS : 0) +
                                   (_params.termCrit.maxCount > 0 ? TermCriteria::COUNT : 0);
        }
        else
            _params.termCrit = TermCriteria( TermCriteria::EPS + TermCriteria::COUNT, 1000, FLT_EPSILON );

        setParams( _params );
    }

    void read( const FileNode& fn ) CV_OVERRIDE
    {
        clear();

        // read SVM parameters
        read_params( fn );

        // and top-level data
        int i, sv_total = (int)fn["sv_total"];
        var_count = (int)fn["var_count"];
        int class_count = (int)fn["class_count"];

        if( sv_total <= 0 || var_count <= 0 )
            CV_Error( CV_StsParseError, "SVM model data is invalid, check sv_count, var_* and class_count tags" );

        FileNode m = fn["class_labels"];
        if( !m.empty() )
            m >> class_labels;
        m = fn["class_weights"];
        if( !m.empty() )
            m >> params.classWeights;

        if( class_count > 1 && (class_labels.empty() || (int)class_labels.total() != class_count))
            CV_Error( CV_StsParseError, "Array of class labels is missing or invalid" );

        // read support vectors
        FileNode sv_node = fn["support_vectors"];

        CV_Assert((int)sv_node.size() == sv_total);

        sv.create(sv_total, var_count, CV_32F);
        FileNodeIterator sv_it = sv_node.begin();
        for( i = 0; i < sv_total; i++, ++sv_it )
        {
            (*sv_it).readRaw("f", sv.ptr(i), var_count*sv.elemSize());
        }

        int uncompressed_sv_total = (int)fn["uncompressed_sv_total"];

        if( uncompressed_sv_total > 0 )
        {
            // read uncompressed support vectors
            FileNode uncompressed_sv_node = fn["uncompressed_support_vectors"];

            CV_Assert((int)uncompressed_sv_node.size() == uncompressed_sv_total);
            uncompressed_sv.create(uncompressed_sv_total, var_count, CV_32F);

            FileNodeIterator uncompressed_sv_it = uncompressed_sv_node.begin();
            for( i = 0; i < uncompressed_sv_total; i++, ++uncompressed_sv_it )
            {
                (*uncompressed_sv_it).readRaw("f", uncompressed_sv.ptr(i), var_count*uncompressed_sv.elemSize());
            }
        }

        // read decision functions
        int df_count = class_count > 1 ? class_count*(class_count-1)/2 : 1;
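        // (one decision function per unordered class pair for classification
        // models, a single one for one-class and regression models)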
        FileNode df_node = fn["decision_functions"];

        CV_Assert((int)df_node.size() == df_count);

        FileNodeIterator df_it = df_node.begin();
        for( i = 0; i < df_count; i++, ++df_it )
        {
            FileNode dfi = *df_it;
            DecisionFunc df;
            int sv_count = (int)dfi["sv_count"];
            int ofs = (int)df_index.size();
            df.rho = (double)dfi["rho"];
            df.ofs = ofs;
            df_index.resize(ofs + sv_count);
            df_alpha.resize(ofs + sv_count);
            dfi["alpha"].readRaw("d", (uchar*)&df_alpha[ofs], sv_count*sizeof(df_alpha[0]));
            if( class_count >= 2 )
                dfi["index"].readRaw("i", (uchar*)&df_index[ofs], sv_count*sizeof(df_index[0]));
            decision_func.push_back(df);
        }
        if( class_count < 2 )
            setRangeVector(df_index, sv_total);
        if( (int)fn["optimize_linear"] != 0 )
            optimize_linear_svm();
    }

    SvmParams params;
    Mat class_labels;
    int var_count;
    Mat sv, uncompressed_sv;
    vector<DecisionFunc> decision_func;
    vector<double> df_alpha;
    vector<int> df_index;

    Ptr<Kernel> kernel;
};


Ptr<SVM> SVM::create()
{
    return makePtr<SVMImpl>();
}

Ptr<SVM> SVM::load(const String& filepath)
{
    FileStorage fs;
    fs.open(filepath, FileStorage::READ);
    CV_Assert(fs.isOpened());

    Ptr<SVM> svm = makePtr<SVMImpl>();

    ((SVMImpl*)svm.get())->read(fs.getFirstTopLevelNode());
    return svm;
}
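
// Typical usage (a minimal sketch; the model file name is hypothetical
// and error handling is omitted):
//     Ptr<SVM> svm = SVM::load("svm_model.xml");
//     Mat query(1, svm->getVarCount(), CV_32F, Scalar::all(0));
//     float response = svm->predict(query);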

Mat SVM::getUncompressedSupportVectors() const
{
    const SVMImpl* this_ = dynamic_cast<const SVMImpl*>(this);
    if(!this_)
        CV_Error(Error::StsNotImplemented, "the class is not SVMImpl");
    return this_->getUncompressedSupportVectors_();
}

bool SVM::trainAuto(InputArray samples, int layout,
            InputArray responses, int kfold, Ptr<ParamGrid> Cgrid,
            Ptr<ParamGrid> gammaGrid, Ptr<ParamGrid> pGrid, Ptr<ParamGrid> nuGrid,
            Ptr<ParamGrid> coeffGrid, Ptr<ParamGrid> degreeGrid, bool balanced)
{
  SVMImpl* this_ = dynamic_cast<SVMImpl*>(this);
  if (!this_) {
    CV_Error(Error::StsNotImplemented, "the class is not SVMImpl");
  }
  return this_->trainAuto_(samples, layout, responses,
    kfold, Cgrid, gammaGrid, pGrid, nuGrid, coeffGrid, degreeGrid, balanced);
}

}
}

/* End of file. */