op_param.h 18.1 KB
Newer Older
W
wangliu 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
朔-望's avatar
朔-望 已提交
14

15
#pragma once
朔-望's avatar
朔-望 已提交
16

E
eclipsess 已提交
17
#include <string>
W
wangliu 已提交
18
#include <vector>
L
liuruilong 已提交
19
#include "common/log.h"
朔-望's avatar
朔-望 已提交
20 21 22 23 24 25 26
#include "common/type_define.h"
#include "framework/lod_tensor.h"
#include "framework/scope.h"
#include "framework/tensor.h"
#include "framework/variable.h"

namespace paddle_mobile {
朔-望's avatar
朔-望 已提交
27 28
namespace operators {

W
wangliu 已提交
29 30 31 32 33 34 35
// Bring the framework and standard-library types used by the parameter
// classes below into paddle_mobile::operators so they can be written
// unqualified.
using framework::Attribute;
using framework::AttributeMap;
using framework::LoDTensor;
using framework::Scope;
using framework::Tensor;
using std::string;
using std::vector;
朔-望's avatar
朔-望 已提交
36 37

class OpParam : PaddleMobileObject {
朔-望's avatar
朔-望 已提交
38
 protected:
39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70
  template <typename T>
  static T *InputFrom(const VariableNameMap &inputs, const Scope &scope) {
    return GetVarValue<T>("Input", inputs, scope);
  }

  template <typename T>
  static T *InputXFrom(const VariableNameMap &inputs, const Scope &scope) {
    return GetVarValue<T>("X", inputs, scope);
  }

  template <typename T>
  static T *InputYFrom(const VariableNameMap &inputs, const Scope &scope) {
    return GetVarValue<T>("Y", inputs, scope);
  }

  template <typename T>
  static T *InputBiasFrom(const VariableNameMap &inputs, const Scope &scope) {
    return GetVarValue<T>("Bias", inputs, scope);
  }
  template <typename T>
  static T *InputVarianceFrom(const VariableNameMap &inputs,
                              const Scope &scope) {
    return GetVarValue<T>("Variance", inputs, scope);
  }
  template <typename T>
  static T *InputMeanFrom(const VariableNameMap &inputs, const Scope &scope) {
    return GetVarValue<T>("Mean", inputs, scope);
  }
  template <typename T>
  static T *InputScaleFrom(const VariableNameMap &inputs, const Scope &scope) {
    return GetVarValue<T>("Scale", inputs, scope);
  }
E
eclipsess 已提交
71 72 73 74
  template <typename T>
  static T *InputImageFrom(const VariableNameMap &inputs, const Scope &scope) {
    return GetVarValue<T>("Image", inputs, scope);
  }
E
eclipsess 已提交
75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90
  template <typename T>
  static T *InputPriorBoxFrom(const VariableNameMap &inputs,
                              const Scope &scope) {
    return GetVarValue<T>("PriorBox", inputs, scope);
  }
  template <typename T>
  static T *InputPriorBoxVarFrom(const VariableNameMap &inputs,
                                 const Scope &scope) {
    return GetVarValue<T>("PriorBoxVar", inputs, scope);
  }
  // LoDTensor but now use Tensor
  template <typename T>
  static T *InputTargetBoxFrom(const VariableNameMap &inputs,
                               const Scope &scope) {
    return GetVarValue<T>("TargetBox", inputs, scope);
  }
91

E
eclipsess 已提交
92 93 94 95 96 97 98 99 100 101
  template <typename T>
  static T *InputBBoxesFrom(const VariableNameMap &inputs, const Scope &scope) {
    return GetVarValue<T>("BBoxes", inputs, scope);
  }

  template <typename T>
  static T *InputScoresFrom(const VariableNameMap &inputs, const Scope &scope) {
    return GetVarValue<T>("Scores", inputs, scope);
  }

102
  template <typename T>
W
wangliu 已提交
103 104
  static vector<T *> InputMultiFrom(const VariableNameMap &inputs,
                                    const Scope &scope) {
105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122
    return GetMultiVarValue<T>("X", inputs, scope);
  }

  template <typename T>
  static T *OutputFrom(const VariableNameMap &outputs, const Scope &scope) {
    return GetVarValue<T>("Output", outputs, scope);
  }

  template <typename T>
  static T *OutFrom(const VariableNameMap &outputs, const Scope &scope) {
    return GetVarValue<T>("Out", outputs, scope);
  }

  template <typename T>
  static T *OutputYFrom(const VariableNameMap &outputs, const Scope &scope) {
    return GetVarValue<T>("Y", outputs, scope);
  }

E
eclipsess 已提交
123 124 125 126 127 128
  template <typename T>
  static T *OutputBoxesFrom(const VariableNameMap &outputs,
                            const Scope &scope) {
    return GetVarValue<T>("Boxes", outputs, scope);
  }

E
eclipsess 已提交
129 130 131 132 133
  template <typename T>
  static T *OutputBoxFrom(const VariableNameMap &outputs, const Scope &scope) {
    return GetVarValue<T>("OutputBox", outputs, scope);
  }

E
eclipsess 已提交
134 135 136 137 138 139
  template <typename T>
  static T *OutputVariancesFrom(const VariableNameMap &outputs,
                                const Scope &scope) {
    return GetVarValue<T>("Variances", outputs, scope);
  }

140 141 142 143 144 145 146 147 148 149 150
  template <typename T>
  static T *MidOutFrom(const VariableNameMap &outputs, const Scope &scope) {
    return GetVarValue<T>("MidOut", outputs, scope);
  }

  template <typename T>
  static T *FilterFrom(const VariableNameMap &inputs, const Scope &scope) {
    return GetVarValue<T>("Filter", inputs, scope);
  }

  template <typename T>
W
wangliu 已提交
151
  static const T GetAttr(const string &key, const AttributeMap &map) {
152 153 154 155
    return ((Attribute)map.at(key)).Get<T>();
  }

  template <typename T>
W
wangliu 已提交
156
  static T *GetVarValue(const string &key, const VariableNameMap &var_map,
157 158 159 160 161 162 163 164 165
                        const Scope &scope) {
    auto var_vec = var_map.at(key);
    if (!var_vec.empty()) {
      //      std::cout << " get var value -- " << var_vec[0] <<
      //      std::endl;
      auto var = scope.FindVar(var_vec[0]);
      return var->GetMutable<T>();
    } else {
      return nullptr;
朔-望's avatar
朔-望 已提交
166
    }
167
  }
朔-望's avatar
朔-望 已提交
168

169
  template <typename T>
W
wangliu 已提交
170 171 172
  static vector<T *> GetMultiVarValue(const string &key,
                                      const VariableNameMap &var_map,
                                      const Scope &scope) {
173 174
    auto var_vecs = var_map.at(key);
    assert(var_vecs.size() > 1);
W
wangliu 已提交
175
    vector<T *> var_res;
176 177 178
    for (auto &var_vec : var_vecs) {
      auto var = scope.FindVar(var_vec);
      var_res.push_back(var->GetMutable<T>());
朔-望's avatar
朔-望 已提交
179
    }
180 181
    return var_res;
  }
朔-望's avatar
朔-望 已提交
182 183 184
};

class ConvParam : OpParam {
朔-望's avatar
朔-望 已提交
185
 public:
186 187 188
  ConvParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
            const framework::AttributeMap &attrs,
            const framework::Scope &scope) {
W
wangliu 已提交
189 190 191 192 193 194
    filter_ = FilterFrom<LoDTensor>(inputs, scope);
    input_ = InputFrom<Tensor>(inputs, scope);
    output_ = OutputFrom<Tensor>(outputs, scope);
    strides_ = GetAttr<vector<int>>("strides", attrs);
    paddings_ = GetAttr<vector<int>>("paddings", attrs);
    dilations_ = GetAttr<vector<int>>("dilations", attrs);
195 196
    groups = GetAttr<int>("groups", attrs);
  }
朔-望's avatar
朔-望 已提交
197

198
  const Tensor *Input() const { return input_; }
朔-望's avatar
朔-望 已提交
199

200
  const LoDTensor *Filter() const { return filter_; }
朔-望's avatar
朔-望 已提交
201

202
  Tensor *Output() const { return output_; }
朔-望's avatar
朔-望 已提交
203

W
wangliu 已提交
204
  const vector<int> &Strides() const { return strides_; }
朔-望's avatar
朔-望 已提交
205

W
wangliu 已提交
206
  const vector<int> &Paddings() const { return paddings_; }
朔-望's avatar
朔-望 已提交
207

W
wangliu 已提交
208
  const vector<int> &Dilations() const { return dilations_; }
朔-望's avatar
朔-望 已提交
209

210
  const int &Groups() const { return groups; }
朔-望's avatar
朔-望 已提交
211

朔-望's avatar
朔-望 已提交
212
 private:
213 214 215
  Tensor *input_;
  Tensor *output_;
  LoDTensor *filter_;
W
wangliu 已提交
216 217 218
  vector<int> strides_;
  vector<int> paddings_;
  vector<int> dilations_;
219
  int groups;
朔-望's avatar
朔-望 已提交
220 221 222 223 224
};

// Declares operator<< for a Print object and a ConvParam; the definition is
// not part of this header. Presumably used to log the parameter values --
// confirm against the implementation.
Print &operator<<(Print &printer, const ConvParam &conv_param);

class ElementwiseAddParam : OpParam {
朔-望's avatar
朔-望 已提交
225
 public:
226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243
  ElementwiseAddParam(const VariableNameMap &inputs,
                      const VariableNameMap &outputs,
                      const framework::AttributeMap &attrs,
                      const framework::Scope &scope) {
    input_x_ = InputXFrom<framework::Tensor>(inputs, scope);
    input_y_ = InputYFrom<framework::Tensor>(inputs, scope);
    out_ = OutFrom<framework::Tensor>(outputs, scope);
    axis_ = GetAttr<int>("axis", attrs);
  }

  const Tensor *InputX() const { return input_x_; }

  const Tensor *InputY() const { return input_y_; }

  Tensor *Out() const { return out_; }

  const int &Axis() const { return axis_; }

朔-望's avatar
朔-望 已提交
244
 private:
245 246 247 248
  Tensor *input_x_;
  Tensor *input_y_;
  Tensor *out_;
  int axis_;
朔-望's avatar
朔-望 已提交
249 250 251
};

class MulParam : OpParam {
朔-望's avatar
朔-望 已提交
252
 public:
253 254 255 256 257 258 259 260 261
  MulParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
           const framework::AttributeMap &attrs,
           const framework::Scope &scope) {
    input_x_ = InputXFrom<framework::Tensor>(inputs, scope);
    input_y_ = InputYFrom<framework::Tensor>(inputs, scope);
    out_ = OutFrom<framework::Tensor>(outputs, scope);
    x_num_col_dims_ = GetAttr<int>("x_num_col_dims", attrs);
    y_num_col_dims_ = GetAttr<int>("y_num_col_dims", attrs);
  }
朔-望's avatar
朔-望 已提交
262

263
  const Tensor *InputX() const { return input_x_; }
朔-望's avatar
朔-望 已提交
264

265
  const Tensor *InputY() const { return input_y_; }
朔-望's avatar
朔-望 已提交
266

267
  Tensor *Out() const { return out_; }
朔-望's avatar
朔-望 已提交
268

269
  const int &XNumColDims() const { return x_num_col_dims_; }
朔-望's avatar
朔-望 已提交
270

271
  const int &YNumColDims() const { return y_num_col_dims_; }
朔-望's avatar
朔-望 已提交
272

朔-望's avatar
朔-望 已提交
273
 private:
274 275 276 277 278
  Tensor *input_x_;
  Tensor *input_y_;
  Tensor *out_;
  int x_num_col_dims_;
  int y_num_col_dims_;
朔-望's avatar
朔-望 已提交
279 280 281
};

class ConcatParam : public OpParam {
朔-望's avatar
朔-望 已提交
282
 public:
283 284 285 286 287 288 289
  ConcatParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
              const framework::AttributeMap &attrs,
              const framework::Scope &scope) {
    inputs_ = InputMultiFrom<framework::Tensor>(inputs, scope);
    out_ = OutFrom<framework::Tensor>(outputs, scope);
    axis_ = GetAttr<int>("axis", attrs);
  }
朔-望's avatar
朔-望 已提交
290

W
wangliu 已提交
291
  vector<Tensor *> Inputs() const { return inputs_; }
朔-望's avatar
朔-望 已提交
292

293
  Tensor *Out() const { return out_; }
朔-望's avatar
朔-望 已提交
294

295
  const int &Axis() const { return axis_; }
朔-望's avatar
朔-望 已提交
296

朔-望's avatar
朔-望 已提交
297
 private:
W
wangliu 已提交
298
  vector<Tensor *> inputs_;
299 300
  Tensor *out_;
  int axis_;
朔-望's avatar
朔-望 已提交
301 302
};

E
eclipsess 已提交
303
class LrnParam : public OpParam {
朔-望's avatar
朔-望 已提交
304
 public:
305 306 307 308 309 310 311 312 313 314
  LrnParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
           const framework::AttributeMap &attrs,
           const framework::Scope &scope) {
    input_x_ = InputXFrom<framework::Tensor>(inputs, scope);
    out_ = OutFrom<framework::Tensor>(outputs, scope);
    mid_out_ = MidOutFrom<framework::Tensor>(outputs, scope);
    n_ = GetAttr<int>("n", attrs);
    alpha_ = GetAttr<float>("alpha", attrs);
    beta_ = GetAttr<float>("beta", attrs);
    k_ = GetAttr<float>("k", attrs);
W
wangliu 已提交
315
    data_format_ = GetAttr<string>("data_format", attrs);
316
  }
E
eclipsess 已提交
317

318
  const Tensor *InputX() const { return input_x_; }
E
eclipsess 已提交
319

320
  Tensor *Out() const { return out_; }
E
eclipsess 已提交
321

322
  Tensor *MidOut() const { return mid_out_; }
E
eclipsess 已提交
323

324
  const int &N() const { return n_; }
E
eclipsess 已提交
325

326
  const float &Alpha() const { return alpha_; }
E
eclipsess 已提交
327

328
  const float &Beta() const { return beta_; }
E
eclipsess 已提交
329

330
  const float &K() const { return k_; }
E
eclipsess 已提交
331

W
wangliu 已提交
332
  const string &DataFormat() const { return data_format_; }
E
eclipsess 已提交
333

朔-望's avatar
朔-望 已提交
334
 private:
335 336 337 338 339 340 341
  Tensor *input_x_;
  Tensor *out_;
  Tensor *mid_out_;
  int n_;
  float alpha_;
  float beta_;
  float k_;
W
wangliu 已提交
342
  string data_format_;
E
eclipsess 已提交
343
};
E
eclipsess 已提交
344
class BatchNormParam : OpParam {
朔-望's avatar
朔-望 已提交
345
 public:
346 347 348 349 350 351 352 353 354 355 356 357 358
  BatchNormParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
                 const framework::AttributeMap &attrs,
                 const framework::Scope &scope) {
    input_x_ = InputXFrom<framework::Tensor>(inputs, scope);
    output_y_ = OutputYFrom<framework::Tensor>(outputs, scope);
    input_bias_ = InputBiasFrom<framework::Tensor>(inputs, scope);
    input_mean_ = InputMeanFrom<framework::Tensor>(inputs, scope);
    input_scale_ = InputScaleFrom<framework::Tensor>(inputs, scope);
    input_variance_ = InputVarianceFrom<framework::Tensor>(inputs, scope);
    epsilon_ = GetAttr<float>("epsilon", attrs);
    momentum_ = GetAttr<float>("momentum", attrs);
    is_test_ = GetAttr<bool>("is_test", attrs);
  }
E
eclipsess 已提交
359

360
  const Tensor *InputX() const { return input_x_; }
E
eclipsess 已提交
361

362
  Tensor *OutputY() const { return output_y_; }
E
eclipsess 已提交
363

364
  const Tensor *InputBias() const { return input_bias_; }
E
eclipsess 已提交
365

366
  const Tensor *InputMean() const { return input_mean_; }
E
eclipsess 已提交
367

368
  const Tensor *InputScale() const { return input_scale_; }
E
eclipsess 已提交
369

370
  const Tensor *InputVariance() const { return input_variance_; }
E
eclipsess 已提交
371

372
  const float &Epsilon() const { return epsilon_; }
E
eclipsess 已提交
373

374
  const float &Momentum() const { return momentum_; }
E
eclipsess 已提交
375

376
  const bool &IsTest() const { return is_test_; }
E
eclipsess 已提交
377

W
wangliu 已提交
378
  const string &DataFormat() const { return data_format_; }
E
eclipsess 已提交
379

朔-望's avatar
朔-望 已提交
380
 private:
381 382 383 384 385 386 387 388 389
  Tensor *input_x_;
  Tensor *output_y_;
  Tensor *input_bias_;
  Tensor *input_mean_;
  Tensor *input_scale_;
  Tensor *input_variance_;
  float epsilon_;
  float momentum_;
  bool is_test_;
W
wangliu 已提交
390
  string data_format_;
E
eclipsess 已提交
391
};
392
class PoolParam : public OpParam {
朔-望's avatar
朔-望 已提交
393
 public:
394 395 396 397 398 399
  PoolParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
            const framework::AttributeMap &attrs,
            const framework::Scope &scope) {
    input_ = InputXFrom<framework::Tensor>(inputs, scope);

    output_ = OutFrom<framework::Tensor>(outputs, scope);
W
wangliu 已提交
400 401 402 403
    pooling_type_ = GetAttr<string>("pooling_type", attrs);
    ksize_ = GetAttr<vector<int>>("ksize", attrs);
    strides_ = GetAttr<vector<int>>("strides", attrs);
    paddings_ = GetAttr<vector<int>>("paddings", attrs);
404 405 406
    ceil_mode_ = GetAttr<bool>("ceil_mode", attrs);
    gloabal_pooling_ = GetAttr<bool>("global_pooling", attrs);
  }
407

408
  const Tensor *Input() const { return input_; }
409

410
  Tensor *Output() const { return output_; }
411

W
wangliu 已提交
412
  const string &PoolingType() const { return pooling_type_; }
413

W
wangliu 已提交
414
  const vector<int> &Ksize() const { return ksize_; }
415

W
wangliu 已提交
416
  const vector<int> &Strides() const { return strides_; }
417

W
wangliu 已提交
418
  const vector<int> &Paddings() const { return paddings_; }
419

420
  bool isCeilMode() const { return ceil_mode_; }
421

422
  bool isGlobalPooling() const { return gloabal_pooling_; }
423

朔-望's avatar
朔-望 已提交
424
 private:
425 426
  Tensor *input_;
  Tensor *output_;
W
wangliu 已提交
427 428 429 430
  string pooling_type_;
  vector<int> ksize_;
  vector<int> strides_;
  vector<int> paddings_;
431 432
  bool ceil_mode_;
  bool gloabal_pooling_ = false;
433 434
};

E
eclipsess 已提交
435 436 437 438 439 440 441 442 443
// Parameter bundle for the prior-box operator: the "Input" and "Image"
// inputs, the "Boxes" and "Variances" outputs, and the "min_sizes",
// "max_sizes", "aspect_ratios", "variances", "flip", "clip", "step_w",
// "step_h" and "offset" attributes.
class PriorBoxParam : public OpParam {
 public:
  // Looks up every operand in `scope` and reads the attributes from `attrs`.
  PriorBoxParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
                const framework::AttributeMap &attrs,
                const framework::Scope &scope)
      : input_(InputFrom<Tensor>(inputs, scope)),
        input_image_(InputImageFrom<Tensor>(inputs, scope)),
        output_boxes_(OutputBoxesFrom<Tensor>(outputs, scope)),
        output_variances_(OutputVariancesFrom<Tensor>(outputs, scope)),
        min_sizes_(GetAttr<vector<float>>("min_sizes", attrs)),
        max_sizes_(GetAttr<vector<float>>("max_sizes", attrs)),
        aspect_ratios_(GetAttr<vector<float>>("aspect_ratios", attrs)),
        variances_(GetAttr<vector<float>>("variances", attrs)),
        flip_(GetAttr<bool>("flip", attrs)),
        clip_(GetAttr<bool>("clip", attrs)),
        step_w_(GetAttr<float>("step_w", attrs)),
        step_h_(GetAttr<float>("step_h", attrs)),
        offset_(GetAttr<float>("offset", attrs)) {}

  // Tensor found under the "Input" slot.
  const Tensor *Input() const { return input_; }

  // Tensor found under the "Image" slot.
  const Tensor *InputImage() const { return input_image_; }

  // Tensor found under the "Boxes" slot.
  Tensor *OutputBoxes() const { return output_boxes_; }

  // Tensor found under the "Variances" slot.
  Tensor *OutputVariances() const { return output_variances_; }

  // Value of the "min_sizes" attribute.
  const vector<float> &MinSizes() const { return min_sizes_; }

  // Value of the "max_sizes" attribute.
  const vector<float> &MaxSizes() const { return max_sizes_; }

  // Value of the "aspect_ratios" attribute.
  const vector<float> &AspectRatios() const { return aspect_ratios_; }

  // Value of the "variances" attribute.
  const vector<float> &Variances() const { return variances_; }

  // Value of the "flip" attribute.
  const bool &Flip() const { return flip_; }

  // Value of the "clip" attribute.
  const bool &Clip() const { return clip_; }

  // Value of the "step_w" attribute.
  const float &StepW() const { return step_w_; }

  // Value of the "step_h" attribute.
  const float &StepH() const { return step_h_; }

  // Value of the "offset" attribute.
  const float &Offset() const { return offset_; }

 private:
  Tensor *input_;
  Tensor *input_image_;
  Tensor *output_boxes_;
  Tensor *output_variances_;
  vector<float> min_sizes_;
  vector<float> max_sizes_;
  vector<float> aspect_ratios_;
  vector<float> variances_;
  bool flip_;
  bool clip_;
  float step_w_;
  float step_h_;
  float offset_;
};
E
eclipsess 已提交
495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523

// Parameter bundle for the box-coder operator: the "PriorBox", "PriorBoxVar"
// and "TargetBox" inputs, the "OutputBox" output and the "code_type"
// attribute.
class BoxCoderParam : public OpParam {
 public:
  // Looks up every operand in `scope` and reads "code_type" from `attrs`.
  BoxCoderParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
                const framework::AttributeMap &attrs,
                const framework::Scope &scope)
      : input_priorbox_(InputPriorBoxFrom<Tensor>(inputs, scope)),
        input_priorboxvar_(InputPriorBoxVarFrom<Tensor>(inputs, scope)),
        input_targetbox_(InputTargetBoxFrom<Tensor>(inputs, scope)),
        output_box_(OutputBoxFrom<Tensor>(outputs, scope)),
        code_type_(GetAttr<std::string>("code_type", attrs)) {}

  // Tensor found under the "PriorBox" slot.
  const Tensor *InputPriorBox() const { return input_priorbox_; }

  // Tensor found under the "PriorBoxVar" slot.
  const Tensor *InputPriorBoxVar() const { return input_priorboxvar_; }

  // Tensor found under the "TargetBox" slot.
  const Tensor *InputTargetBox() const { return input_targetbox_; }

  // Tensor found under the "OutputBox" slot.
  Tensor *OutputBox() const { return output_box_; }

  // Value of the "code_type" attribute.
  const std::string &CodeType() const { return code_type_; }

 private:
  Tensor *input_priorbox_;
  Tensor *input_priorboxvar_;
  Tensor *input_targetbox_;
  Tensor *output_box_;
  std::string code_type_;
};
W
wangliu 已提交
524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539

// Parameter bundle for the softmax operator: the "X" input and the "Out"
// output. `attrs` is accepted for signature uniformity but not read.
class SoftmaxParam : public OpParam {
 public:
  // Looks up both operands in `scope`.
  SoftmaxParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
               const framework::AttributeMap &attrs,
               const framework::Scope &scope)
      : input_x_(InputXFrom<Tensor>(inputs, scope)),
        out_(OutFrom<Tensor>(outputs, scope)) {}

  // Tensor found under the "X" slot.
  const Tensor *InputX() const { return input_x_; }

  // Tensor found under the "Out" slot.
  Tensor *Out() const { return out_; }

 private:
  Tensor *input_x_;
  Tensor *out_;
};
E
eclipsess 已提交
540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584
class MultiClassNMSParam : public OpParam {
 public:
  MultiClassNMSParam(const VariableNameMap &inputs,
                     const VariableNameMap &outputs, const AttributeMap &attrs,
                     const Scope &scope) {
    input_bboxes_ = InputBBoxesFrom<Tensor>(inputs, scope);
    input_scores_ = InputScoresFrom<Tensor>(inputs, scope);
    out_ = OutFrom<Tensor>(outputs, scope);
    background_label_ = GetAttr<int>("background_label", attrs);
    nms_top_k_ = GetAttr<int>("nms_top_k", attrs);
    keep_top_k_ = GetAttr<int>("keep_top_k", attrs);
    nms_threshold_ = GetAttr<float>("nms_threshold", attrs);
    nms_eta_ = GetAttr<float>("nms_eta", attrs);
    score_threshold_ = GetAttr<float>("score_threshold", attrs);
  }

  const Tensor *InputBBoxes() const { return input_bboxes_; }

  const Tensor *InputScores() const { return input_scores_; }

  Tensor *Out() const { return out_; }

  const int &BackGroundLabel() const { return background_label_; }

  const int &NMSTopK() const { return nms_top_k_; }

  const int &KeepTopK() const { return keep_top_k_; }

  const float &NMSThreshold() const { return nms_threshold_; }

  const float &NMSEta() const { return nms_eta_; }

  const float &ScoreThreshold() const { return score_threshold_; }

 private:
  Tensor *input_bboxes_;
  Tensor *input_scores_;
  Tensor *out_;
  int background_label_;
  int nms_top_k_;
  int keep_top_k_;
  float nms_threshold_;
  float nms_eta_;
  float score_threshold_;
};
W
wangliu 已提交
585

朔-望's avatar
朔-望 已提交
586 587
}  // namespace operators
}  // namespace paddle_mobile