op_param.h 32.2 KB
Newer Older
W
wangliu 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
朔-望's avatar
朔-望 已提交
14

15
#pragma once
朔-望's avatar
朔-望 已提交
16

E
eclipsess 已提交
17
#include <string>
W
wangliu 已提交
18
#include <vector>
L
liuruilong 已提交
19
#include "common/log.h"
朔-望's avatar
朔-望 已提交
20 21 22 23 24 25 26
#include "common/type_define.h"
#include "framework/lod_tensor.h"
#include "framework/scope.h"
#include "framework/tensor.h"
#include "framework/variable.h"

namespace paddle_mobile {
朔-望's avatar
朔-望 已提交
27 28
namespace operators {

W
wangliu 已提交
29 30 31 32 33 34 35
using framework::Attribute;
using framework::AttributeMap;
using framework::LoDTensor;
using framework::Scope;
using framework::Tensor;
using std::string;
using std::vector;
朔-望's avatar
朔-望 已提交
36

L
liuruilong 已提交
37
class OpParam {
朔-望's avatar
朔-望 已提交
38
 protected:
39 40 41 42 43 44 45 46 47 48 49 50 51 52 53
  template <typename T>
  static T *InputFrom(const VariableNameMap &inputs, const Scope &scope) {
    return GetVarValue<T>("Input", inputs, scope);
  }

  template <typename T>
  static T *InputXFrom(const VariableNameMap &inputs, const Scope &scope) {
    return GetVarValue<T>("X", inputs, scope);
  }

  template <typename T>
  static T *InputYFrom(const VariableNameMap &inputs, const Scope &scope) {
    return GetVarValue<T>("Y", inputs, scope);
  }

E
eclipsess 已提交
54 55 56 57 58
  template <typename T>
  static T *InputZFrom(const VariableNameMap &inputs, const Scope &scope) {
    return GetVarValue<T>("Z", inputs, scope);
  }

59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75
  template <typename T>
  static T *InputBiasFrom(const VariableNameMap &inputs, const Scope &scope) {
    return GetVarValue<T>("Bias", inputs, scope);
  }
  template <typename T>
  static T *InputVarianceFrom(const VariableNameMap &inputs,
                              const Scope &scope) {
    return GetVarValue<T>("Variance", inputs, scope);
  }
  template <typename T>
  static T *InputMeanFrom(const VariableNameMap &inputs, const Scope &scope) {
    return GetVarValue<T>("Mean", inputs, scope);
  }
  template <typename T>
  static T *InputScaleFrom(const VariableNameMap &inputs, const Scope &scope) {
    return GetVarValue<T>("Scale", inputs, scope);
  }
E
eclipsess 已提交
76 77 78 79
  template <typename T>
  static T *InputImageFrom(const VariableNameMap &inputs, const Scope &scope) {
    return GetVarValue<T>("Image", inputs, scope);
  }
E
eclipsess 已提交
80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95
  template <typename T>
  static T *InputPriorBoxFrom(const VariableNameMap &inputs,
                              const Scope &scope) {
    return GetVarValue<T>("PriorBox", inputs, scope);
  }
  template <typename T>
  static T *InputPriorBoxVarFrom(const VariableNameMap &inputs,
                                 const Scope &scope) {
    return GetVarValue<T>("PriorBoxVar", inputs, scope);
  }
  // LoDTensor but now use Tensor
  template <typename T>
  static T *InputTargetBoxFrom(const VariableNameMap &inputs,
                               const Scope &scope) {
    return GetVarValue<T>("TargetBox", inputs, scope);
  }
96

E
eclipsess 已提交
97 98 99 100 101 102 103 104 105 106
  template <typename T>
  static T *InputBBoxesFrom(const VariableNameMap &inputs, const Scope &scope) {
    return GetVarValue<T>("BBoxes", inputs, scope);
  }

  template <typename T>
  static T *InputScoresFrom(const VariableNameMap &inputs, const Scope &scope) {
    return GetVarValue<T>("Scores", inputs, scope);
  }

E
eclipsess 已提交
107 108 109 110
  template <typename T>
  static T *InputShapeFrom(const VariableNameMap &inputs, const Scope &scope) {
    return GetVarValue<T>("Shape", inputs, scope);
  }
E
eclipsess 已提交
111

112
  template <typename T>
W
wangliu 已提交
113 114
  static vector<T *> InputMultiFrom(const VariableNameMap &inputs,
                                    const Scope &scope) {
115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132
    return GetMultiVarValue<T>("X", inputs, scope);
  }

  template <typename T>
  static T *OutputFrom(const VariableNameMap &outputs, const Scope &scope) {
    return GetVarValue<T>("Output", outputs, scope);
  }

  template <typename T>
  static T *OutFrom(const VariableNameMap &outputs, const Scope &scope) {
    return GetVarValue<T>("Out", outputs, scope);
  }

  template <typename T>
  static T *OutputYFrom(const VariableNameMap &outputs, const Scope &scope) {
    return GetVarValue<T>("Y", outputs, scope);
  }

E
eclipsess 已提交
133 134 135 136 137 138
  template <typename T>
  static T *OutputBoxesFrom(const VariableNameMap &outputs,
                            const Scope &scope) {
    return GetVarValue<T>("Boxes", outputs, scope);
  }

E
eclipsess 已提交
139 140 141 142 143
  template <typename T>
  static T *OutputBoxFrom(const VariableNameMap &outputs, const Scope &scope) {
    return GetVarValue<T>("OutputBox", outputs, scope);
  }

E
eclipsess 已提交
144 145 146 147 148 149
  template <typename T>
  static T *OutputVariancesFrom(const VariableNameMap &outputs,
                                const Scope &scope) {
    return GetVarValue<T>("Variances", outputs, scope);
  }

150 151 152 153 154 155 156 157 158 159 160
  template <typename T>
  static T *MidOutFrom(const VariableNameMap &outputs, const Scope &scope) {
    return GetVarValue<T>("MidOut", outputs, scope);
  }

  template <typename T>
  static T *FilterFrom(const VariableNameMap &inputs, const Scope &scope) {
    return GetVarValue<T>("Filter", inputs, scope);
  }

  template <typename T>
W
wangliu 已提交
161
  static const T GetAttr(const string &key, const AttributeMap &map) {
162 163 164 165
    return ((Attribute)map.at(key)).Get<T>();
  }

  template <typename T>
W
wangliu 已提交
166
  static T *GetVarValue(const string &key, const VariableNameMap &var_map,
167
                        const Scope &scope) {
W
wangliu 已提交
168 169
    PADDLE_MOBILE_ENFORCE(var_map.count(key) > 0,
                          "%s is not contained in var_map", key.c_str())
170 171 172 173 174 175
    auto var_vec = var_map.at(key);
    if (!var_vec.empty()) {
      auto var = scope.FindVar(var_vec[0]);
      return var->GetMutable<T>();
    } else {
      return nullptr;
朔-望's avatar
朔-望 已提交
176
    }
177
  }
朔-望's avatar
朔-望 已提交
178

179
  template <typename T>
W
wangliu 已提交
180 181 182
  static vector<T *> GetMultiVarValue(const string &key,
                                      const VariableNameMap &var_map,
                                      const Scope &scope) {
183 184
    auto var_vecs = var_map.at(key);
    assert(var_vecs.size() > 1);
W
wangliu 已提交
185
    vector<T *> var_res;
186 187 188
    for (auto &var_vec : var_vecs) {
      auto var = scope.FindVar(var_vec);
      var_res.push_back(var->GetMutable<T>());
朔-望's avatar
朔-望 已提交
189
    }
190 191
    return var_res;
  }
朔-望's avatar
朔-望 已提交
192 193
};

L
liuruilong 已提交
194
#ifdef CONV_OP
朔-望's avatar
朔-望 已提交
195
class ConvParam : OpParam {
朔-望's avatar
朔-望 已提交
196
 public:
197
  ConvParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
198
            const AttributeMap &attrs, const Scope &scope) {
W
wangliu 已提交
199
    filter_ = FilterFrom<LoDTensor>(inputs, scope);
W
wangliu 已提交
200 201
    input_ = InputFrom<LoDTensor>(inputs, scope);
    output_ = OutputFrom<LoDTensor>(outputs, scope);
W
wangliu 已提交
202 203 204
    strides_ = GetAttr<vector<int>>("strides", attrs);
    paddings_ = GetAttr<vector<int>>("paddings", attrs);
    dilations_ = GetAttr<vector<int>>("dilations", attrs);
205 206
    groups = GetAttr<int>("groups", attrs);
  }
朔-望's avatar
朔-望 已提交
207

208
  const Tensor *Input() const { return input_; }
朔-望's avatar
朔-望 已提交
209

E
eclipsess 已提交
210
  const Tensor *Filter() const { return filter_; }
朔-望's avatar
朔-望 已提交
211

212
  Tensor *Output() const { return output_; }
朔-望's avatar
朔-望 已提交
213

W
wangliu 已提交
214
  const vector<int> &Strides() const { return strides_; }
朔-望's avatar
朔-望 已提交
215

W
wangliu 已提交
216
  const vector<int> &Paddings() const { return paddings_; }
朔-望's avatar
朔-望 已提交
217

W
wangliu 已提交
218
  const vector<int> &Dilations() const { return dilations_; }
朔-望's avatar
朔-望 已提交
219

220
  const int &Groups() const { return groups; }
朔-望's avatar
朔-望 已提交
221

朔-望's avatar
朔-望 已提交
222
 private:
223 224
  Tensor *input_;
  Tensor *output_;
E
eclipsess 已提交
225
  Tensor *filter_;
W
wangliu 已提交
226 227 228
  vector<int> strides_;
  vector<int> paddings_;
  vector<int> dilations_;
229
  int groups;
朔-望's avatar
朔-望 已提交
230 231 232
};

Print &operator<<(Print &printer, const ConvParam &conv_param);
L
liuruilong 已提交
233
#endif
朔-望's avatar
朔-望 已提交
234

L
liuruilong 已提交
235
#ifdef ELEMENTWISEADD_OP
朔-望's avatar
朔-望 已提交
236
class ElementwiseAddParam : OpParam {
朔-望's avatar
朔-望 已提交
237
 public:
238
  ElementwiseAddParam(const VariableNameMap &inputs,
239 240 241 242 243
                      const VariableNameMap &outputs, const AttributeMap &attrs,
                      const Scope &scope) {
    input_x_ = InputXFrom<LoDTensor>(inputs, scope);
    input_y_ = InputYFrom<LoDTensor>(inputs, scope);
    out_ = OutFrom<LoDTensor>(outputs, scope);
244 245 246 247 248 249 250 251 252 253 254
    axis_ = GetAttr<int>("axis", attrs);
  }

  const Tensor *InputX() const { return input_x_; }

  const Tensor *InputY() const { return input_y_; }

  Tensor *Out() const { return out_; }

  const int &Axis() const { return axis_; }

朔-望's avatar
朔-望 已提交
255
 private:
256 257 258 259
  Tensor *input_x_;
  Tensor *input_y_;
  Tensor *out_;
  int axis_;
朔-望's avatar
朔-望 已提交
260 261
};

L
liuruilong 已提交
262 263 264
#endif

#ifdef MUL_OP
朔-望's avatar
朔-望 已提交
265
class MulParam : OpParam {
朔-望's avatar
朔-望 已提交
266
 public:
267
  MulParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
268 269 270 271
           const AttributeMap &attrs, const Scope &scope) {
    input_x_ = InputXFrom<LoDTensor>(inputs, scope);
    input_y_ = InputYFrom<LoDTensor>(inputs, scope);
    out_ = OutFrom<LoDTensor>(outputs, scope);
272 273 274
    x_num_col_dims_ = GetAttr<int>("x_num_col_dims", attrs);
    y_num_col_dims_ = GetAttr<int>("y_num_col_dims", attrs);
  }
朔-望's avatar
朔-望 已提交
275

276
  const Tensor *InputX() const { return input_x_; }
朔-望's avatar
朔-望 已提交
277

278
  const Tensor *InputY() const { return input_y_; }
朔-望's avatar
朔-望 已提交
279

280
  Tensor *Out() const { return out_; }
朔-望's avatar
朔-望 已提交
281

282
  const int &XNumColDims() const { return x_num_col_dims_; }
朔-望's avatar
朔-望 已提交
283

284
  const int &YNumColDims() const { return y_num_col_dims_; }
朔-望's avatar
朔-望 已提交
285

朔-望's avatar
朔-望 已提交
286
 private:
287 288 289 290 291
  Tensor *input_x_;
  Tensor *input_y_;
  Tensor *out_;
  int x_num_col_dims_;
  int y_num_col_dims_;
朔-望's avatar
朔-望 已提交
292
};
L
liuruilong 已提交
293
#endif
朔-望's avatar
朔-望 已提交
294

L
liuruilong 已提交
295
#ifdef CONCAT_OP
朔-望's avatar
朔-望 已提交
296
class ConcatParam : public OpParam {
朔-望's avatar
朔-望 已提交
297
 public:
298
  ConcatParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
299
              const AttributeMap &attrs, const Scope &scope) {
W
wangliu 已提交
300
    inputs_ = InputMultiFrom<LoDTensor>(inputs, scope);
301
    out_ = OutFrom<LoDTensor>(outputs, scope);
302 303
    axis_ = GetAttr<int>("axis", attrs);
  }
朔-望's avatar
朔-望 已提交
304

W
wangliu 已提交
305
  vector<LoDTensor *> Inputs() const { return inputs_; }
朔-望's avatar
朔-望 已提交
306

307
  Tensor *Out() const { return out_; }
朔-望's avatar
朔-望 已提交
308

309
  const int &Axis() const { return axis_; }
朔-望's avatar
朔-望 已提交
310

朔-望's avatar
朔-望 已提交
311
 private:
W
wangliu 已提交
312
  vector<LoDTensor *> inputs_;
313 314
  Tensor *out_;
  int axis_;
朔-望's avatar
朔-望 已提交
315
};
L
liuruilong 已提交
316
#endif
朔-望's avatar
朔-望 已提交
317

L
liuruilong 已提交
318
#ifdef LRN_OP
E
eclipsess 已提交
319
class LrnParam : public OpParam {
朔-望's avatar
朔-望 已提交
320
 public:
321
  LrnParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
322 323 324 325
           const AttributeMap &attrs, const Scope &scope) {
    input_x_ = InputXFrom<LoDTensor>(inputs, scope);
    out_ = OutFrom<LoDTensor>(outputs, scope);
    mid_out_ = MidOutFrom<LoDTensor>(outputs, scope);
326 327 328 329
    n_ = GetAttr<int>("n", attrs);
    alpha_ = GetAttr<float>("alpha", attrs);
    beta_ = GetAttr<float>("beta", attrs);
    k_ = GetAttr<float>("k", attrs);
W
wangliu 已提交
330
    data_format_ = GetAttr<string>("data_format", attrs);
331
  }
E
eclipsess 已提交
332

333
  const Tensor *InputX() const { return input_x_; }
E
eclipsess 已提交
334

335
  Tensor *Out() const { return out_; }
E
eclipsess 已提交
336

337
  Tensor *MidOut() const { return mid_out_; }
E
eclipsess 已提交
338

339
  const int &N() const { return n_; }
E
eclipsess 已提交
340

341
  const float &Alpha() const { return alpha_; }
E
eclipsess 已提交
342

343
  const float &Beta() const { return beta_; }
E
eclipsess 已提交
344

345
  const float &K() const { return k_; }
E
eclipsess 已提交
346

W
wangliu 已提交
347
  const string &DataFormat() const { return data_format_; }
E
eclipsess 已提交
348

朔-望's avatar
朔-望 已提交
349
 private:
350 351 352 353 354 355 356
  Tensor *input_x_;
  Tensor *out_;
  Tensor *mid_out_;
  int n_;
  float alpha_;
  float beta_;
  float k_;
W
wangliu 已提交
357
  string data_format_;
E
eclipsess 已提交
358
};
L
liuruilong 已提交
359 360 361
#endif

#ifdef BATCHNORM_OP
E
eclipsess 已提交
362
class BatchNormParam : OpParam {
朔-望's avatar
朔-望 已提交
363
 public:
364
  BatchNormParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
365 366 367 368 369 370 371
                 const AttributeMap &attrs, const Scope &scope) {
    input_x_ = InputXFrom<LoDTensor>(inputs, scope);
    output_y_ = OutputYFrom<LoDTensor>(outputs, scope);
    input_bias_ = InputBiasFrom<LoDTensor>(inputs, scope);
    input_mean_ = InputMeanFrom<LoDTensor>(inputs, scope);
    input_scale_ = InputScaleFrom<LoDTensor>(inputs, scope);
    input_variance_ = InputVarianceFrom<LoDTensor>(inputs, scope);
372 373
    epsilon_ = GetAttr<float>("epsilon", attrs);
    momentum_ = GetAttr<float>("momentum", attrs);
L
liuruilong 已提交
374
//    is_test_ = GetAttr<bool>("is_test", attrs);
375
  }
E
eclipsess 已提交
376

377
  const Tensor *InputX() const { return input_x_; }
E
eclipsess 已提交
378

379
  Tensor *OutputY() const { return output_y_; }
E
eclipsess 已提交
380

381
  const Tensor *InputBias() const { return input_bias_; }
E
eclipsess 已提交
382

383
  const Tensor *InputMean() const { return input_mean_; }
E
eclipsess 已提交
384

385
  const Tensor *InputScale() const { return input_scale_; }
E
eclipsess 已提交
386

387
  const Tensor *InputVariance() const { return input_variance_; }
E
eclipsess 已提交
388

389
  const float &Epsilon() const { return epsilon_; }
E
eclipsess 已提交
390

391
  const float &Momentum() const { return momentum_; }
E
eclipsess 已提交
392

393
  const bool &IsTest() const { return is_test_; }
E
eclipsess 已提交
394

W
wangliu 已提交
395
  const string &DataFormat() const { return data_format_; }
E
eclipsess 已提交
396

朔-望's avatar
朔-望 已提交
397
 private:
398 399 400 401 402 403 404 405 406
  Tensor *input_x_;
  Tensor *output_y_;
  Tensor *input_bias_;
  Tensor *input_mean_;
  Tensor *input_scale_;
  Tensor *input_variance_;
  float epsilon_;
  float momentum_;
  bool is_test_;
W
wangliu 已提交
407
  string data_format_;
E
eclipsess 已提交
408
};
L
liuruilong 已提交
409 410 411
#endif

#ifdef POOL_OP
412
class PoolParam : public OpParam {
朔-望's avatar
朔-望 已提交
413
 public:
414
  PoolParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
415 416
            const AttributeMap &attrs, const Scope &scope) {
    input_ = InputXFrom<LoDTensor>(inputs, scope);
417

418
    output_ = OutFrom<LoDTensor>(outputs, scope);
W
wangliu 已提交
419 420 421 422
    pooling_type_ = GetAttr<string>("pooling_type", attrs);
    ksize_ = GetAttr<vector<int>>("ksize", attrs);
    strides_ = GetAttr<vector<int>>("strides", attrs);
    paddings_ = GetAttr<vector<int>>("paddings", attrs);
423 424 425
    ceil_mode_ = GetAttr<bool>("ceil_mode", attrs);
    gloabal_pooling_ = GetAttr<bool>("global_pooling", attrs);
  }
426

427
  const Tensor *Input() const { return input_; }
428

429
  Tensor *Output() const { return output_; }
430

W
wangliu 已提交
431
  const string &PoolingType() const { return pooling_type_; }
432

W
wangliu 已提交
433
  const vector<int> &Ksize() const { return ksize_; }
434

W
wangliu 已提交
435
  const vector<int> &Strides() const { return strides_; }
436

W
wangliu 已提交
437
  const vector<int> &Paddings() const { return paddings_; }
438

439
  bool isCeilMode() const { return ceil_mode_; }
440

441
  bool isGlobalPooling() const { return gloabal_pooling_; }
442

朔-望's avatar
朔-望 已提交
443
 private:
444 445
  Tensor *input_;
  Tensor *output_;
W
wangliu 已提交
446 447 448 449
  string pooling_type_;
  vector<int> ksize_;
  vector<int> strides_;
  vector<int> paddings_;
450 451
  bool ceil_mode_;
  bool gloabal_pooling_ = false;
452 453
};

L
liuruilong 已提交
454 455 456
#endif

#ifdef PRIORBOX_OP
E
eclipsess 已提交
457 458 459
class PriorBoxParam : public OpParam {
 public:
  PriorBoxParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
460 461 462 463 464
                const AttributeMap &attrs, const Scope &scope) {
    input_ = InputFrom<LoDTensor>(inputs, scope);
    input_image_ = InputImageFrom<LoDTensor>(inputs, scope);
    output_boxes_ = OutputBoxesFrom<LoDTensor>(outputs, scope);
    output_variances_ = OutputVariancesFrom<LoDTensor>(outputs, scope);
W
wangliu 已提交
465 466 467 468
    min_sizes_ = GetAttr<vector<float>>("min_sizes", attrs);
    max_sizes_ = GetAttr<vector<float>>("max_sizes", attrs);
    aspect_ratios_ = GetAttr<vector<float>>("aspect_ratios", attrs);
    variances_ = GetAttr<vector<float>>("variances", attrs);
E
eclipsess 已提交
469 470 471 472 473 474 475 476 477 478 479 480 481 482
    flip_ = GetAttr<bool>("flip", attrs);
    clip_ = GetAttr<bool>("clip", attrs);
    step_w_ = GetAttr<float>("step_w", attrs);
    step_h_ = GetAttr<float>("step_h", attrs);
    offset_ = GetAttr<float>("offset", attrs);
  }
  const Tensor *Input() const { return input_; }

  const Tensor *InputImage() const { return input_image_; }

  Tensor *OutputBoxes() const { return output_boxes_; }

  Tensor *OutputVariances() const { return output_variances_; }

W
wangliu 已提交
483
  const vector<float> &MinSizes() const { return min_sizes_; }
E
eclipsess 已提交
484

W
wangliu 已提交
485
  const vector<float> &MaxSizes() const { return max_sizes_; }
E
eclipsess 已提交
486

W
wangliu 已提交
487
  const vector<float> &AspectRatios() const { return aspect_ratios_; }
E
eclipsess 已提交
488

W
wangliu 已提交
489
  const vector<float> &Variances() const { return variances_; }
E
eclipsess 已提交
490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505

  const bool &Flip() const { return flip_; }

  const bool &Clip() const { return clip_; }

  const float &StepW() const { return step_w_; }

  const float &StepH() const { return step_h_; }

  const float &Offset() const { return offset_; }

 private:
  Tensor *input_;
  Tensor *input_image_;
  Tensor *output_boxes_;
  Tensor *output_variances_;
W
wangliu 已提交
506 507 508 509
  vector<float> min_sizes_;
  vector<float> max_sizes_;
  vector<float> aspect_ratios_;
  vector<float> variances_;
E
eclipsess 已提交
510 511 512 513 514 515
  bool flip_;
  bool clip_;
  float step_w_;
  float step_h_;
  float offset_;
};
L
liuruilong 已提交
516
#endif
E
eclipsess 已提交
517

L
liuruilong 已提交
518
#ifdef BOXCODER_OP
E
eclipsess 已提交
519 520 521
class BoxCoderParam : public OpParam {
 public:
  BoxCoderParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
522 523 524 525 526
                const AttributeMap &attrs, const Scope &scope) {
    input_priorbox_ = InputPriorBoxFrom<LoDTensor>(inputs, scope);
    input_priorboxvar_ = InputPriorBoxVarFrom<LoDTensor>(inputs, scope);
    input_targetbox_ = InputTargetBoxFrom<LoDTensor>(inputs, scope);
    output_box_ = OutputBoxFrom<LoDTensor>(outputs, scope);
E
eclipsess 已提交
527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545
    code_type_ = GetAttr<std::string>("code_type", attrs);
  }
  const Tensor *InputPriorBox() const { return input_priorbox_; }

  const Tensor *InputPriorBoxVar() const { return input_priorboxvar_; }

  const Tensor *InputTargetBox() const { return input_targetbox_; }

  Tensor *OutputBox() const { return output_box_; }

  const std::string &CodeType() const { return code_type_; }

 private:
  Tensor *input_priorbox_;
  Tensor *input_priorboxvar_;
  Tensor *input_targetbox_;
  Tensor *output_box_;
  std::string code_type_;
};
L
liuruilong 已提交
546
#endif
W
wangliu 已提交
547

L
liuruilong 已提交
548
#ifdef SOFTMAX_OP
W
wangliu 已提交
549 550 551
class SoftmaxParam : public OpParam {
 public:
  SoftmaxParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
552 553 554
               const AttributeMap &attrs, const Scope &scope) {
    input_x_ = InputXFrom<LoDTensor>(inputs, scope);
    out_ = OutFrom<LoDTensor>(outputs, scope);
W
wangliu 已提交
555 556 557 558 559 560 561 562
  }
  const Tensor *InputX() const { return input_x_; }
  Tensor *Out() const { return out_; }

 private:
  Tensor *input_x_;
  Tensor *out_;
};
L
liuruilong 已提交
563
#endif
W
wangliu 已提交
564

L
liuruilong 已提交
565
#ifdef SIGMOID_OP
W
wangliu 已提交
566 567 568
class SigmoidParam : public OpParam {
 public:
  SigmoidParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
569 570 571
               const AttributeMap &attrs, const Scope &scope) {
    input_x_ = InputXFrom<LoDTensor>(inputs, scope);
    out_ = OutFrom<LoDTensor>(outputs, scope);
W
wangliu 已提交
572 573 574 575 576 577 578 579
  }
  const Tensor *InputX() const { return input_x_; }
  Tensor *Out() const { return out_; }

 private:
  Tensor *input_x_;
  Tensor *out_;
};
L
liuruilong 已提交
580 581 582
#endif

#ifdef MULTICLASSNMS_OP
E
eclipsess 已提交
583 584 585 586 587
class MultiClassNMSParam : public OpParam {
 public:
  MultiClassNMSParam(const VariableNameMap &inputs,
                     const VariableNameMap &outputs, const AttributeMap &attrs,
                     const Scope &scope) {
W
wangliu 已提交
588 589 590
    input_bboxes_ = InputBBoxesFrom<LoDTensor>(inputs, scope);
    input_scores_ = InputScoresFrom<LoDTensor>(inputs, scope);
    out_ = OutFrom<LoDTensor>(outputs, scope);
E
eclipsess 已提交
591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627
    background_label_ = GetAttr<int>("background_label", attrs);
    nms_top_k_ = GetAttr<int>("nms_top_k", attrs);
    keep_top_k_ = GetAttr<int>("keep_top_k", attrs);
    nms_threshold_ = GetAttr<float>("nms_threshold", attrs);
    nms_eta_ = GetAttr<float>("nms_eta", attrs);
    score_threshold_ = GetAttr<float>("score_threshold", attrs);
  }

  const Tensor *InputBBoxes() const { return input_bboxes_; }

  const Tensor *InputScores() const { return input_scores_; }

  Tensor *Out() const { return out_; }

  const int &BackGroundLabel() const { return background_label_; }

  const int &NMSTopK() const { return nms_top_k_; }

  const int &KeepTopK() const { return keep_top_k_; }

  const float &NMSThreshold() const { return nms_threshold_; }

  const float &NMSEta() const { return nms_eta_; }

  const float &ScoreThreshold() const { return score_threshold_; }

 private:
  Tensor *input_bboxes_;
  Tensor *input_scores_;
  Tensor *out_;
  int background_label_;
  int nms_top_k_;
  int keep_top_k_;
  float nms_threshold_;
  float nms_eta_;
  float score_threshold_;
};
L
liuruilong 已提交
628
#endif
W
wangliu 已提交
629

L
liuruilong 已提交
630 631 632
class FeedParam : public OpParam {
 public:
  FeedParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
633 634 635
            const AttributeMap &attrs, Scope &scope) {
    input_x_ = InputXFrom<LoDTensor>(inputs, scope);
    out_ = OutFrom<LoDTensor>(outputs, scope);
W
wangliu 已提交
636 637
    auto var = scope.Var("batch_size");
    batch_size = var->GetValue<int>();
L
liuruilong 已提交
638 639 640
  }
  const Tensor *InputX() const { return input_x_; }
  Tensor *Out() const { return out_; }
W
wangliu 已提交
641
  const int BatchSize() const { return batch_size; }
L
liuruilong 已提交
642

L
liuruilong 已提交
643 644 645
 private:
  Tensor *input_x_;
  Tensor *out_;
W
wangliu 已提交
646
  int batch_size;
L
liuruilong 已提交
647 648 649 650 651
};

class FetchParam : public OpParam {
 public:
  FetchParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
652 653 654
             const AttributeMap &attrs, const Scope &scope) {
    input_x_ = InputXFrom<LoDTensor>(inputs, scope);
    out_ = OutFrom<LoDTensor>(outputs, scope);
L
liuruilong 已提交
655 656 657
  }
  const Tensor *InputX() const { return input_x_; }
  Tensor *Out() const { return out_; }
L
liuruilong 已提交
658

L
liuruilong 已提交
659 660 661 662 663
 private:
  Tensor *input_x_;
  Tensor *out_;
};

L
liuruilong 已提交
664
#ifdef TRANSPOSE_OP
E
eclipsess 已提交
665 666 667 668
class TransposeParam : public OpParam {
 public:
  TransposeParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
                 const AttributeMap &attrs, const Scope &scope) {
W
wangliu 已提交
669 670
    input_x_ = InputXFrom<LoDTensor>(inputs, scope);
    out_ = OutFrom<LoDTensor>(outputs, scope);
E
eclipsess 已提交
671 672 673 674 675 676 677 678 679 680 681 682 683 684
    axis_ = GetAttr<vector<int>>("axis", attrs);
  }

  const Tensor *InputX() const { return input_x_; }

  Tensor *Out() const { return out_; }

  const vector<int> &Axis() const { return axis_; }

 private:
  Tensor *input_x_;
  Tensor *out_;
  vector<int> axis_;
};
L
liuruilong 已提交
685
#endif
E
eclipsess 已提交
686

L
liuruilong 已提交
687
#ifdef RESHAPE_OP
E
eclipsess 已提交
688 689 690 691
class ReshapeParam : public OpParam {
 public:
  ReshapeParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
               const AttributeMap &attrs, const Scope &scope) {
W
wangliu 已提交
692 693 694
    input_x_ = InputXFrom<LoDTensor>(inputs, scope);
    input_shape_ = InputShapeFrom<LoDTensor>(inputs, scope);
    out_ = OutFrom<LoDTensor>(outputs, scope);
E
eclipsess 已提交
695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715
    shape_ = GetAttr<vector<int>>("shape", attrs);
    inplace_ = GetAttr<bool>("inplace", attrs);
  }

  const Tensor *InputX() const { return input_x_; }

  const Tensor *InputShape() const { return input_shape_; }

  Tensor *Out() const { return out_; }

  const vector<int> &Shape() const { return shape_; }

  const bool &Inplace() const { return inplace_; }

 private:
  Tensor *input_x_;
  Tensor *input_shape_;
  Tensor *out_;
  vector<int> shape_;
  bool inplace_;
};
L
liuruilong 已提交
716
#endif
E
eclipsess 已提交
717

T
Tian 已提交
718
#ifdef SCALE_OP
I
itminner 已提交
719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754
class ScaleParam : public OpParam {
 public:
  ScaleParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
             const AttributeMap &attrs, const Scope &scope) {
    input_x_ = InputXFrom<LoDTensor>(inputs, scope);
    input_bias_ = InputBiasFrom<framework::LoDTensor>(inputs, scope);
    out_ = OutFrom<LoDTensor>(outputs, scope);
    inplace_ = GetAttr<bool>("inplace", attrs);
    has_bias_ = GetAttr<bool>("has_bias", attrs);
    scales_ = GetAttr<vector<float>>("scales", attrs);
    biases_ = GetAttr<vector<float>>("biases", attrs);
  }

  const Tensor *InputX() const { return input_x_; }

  const Tensor *InputBias() const { return input_bias_; }

  Tensor *Out() const { return out_; }

  const bool &Inplace() const { return inplace_; }

  const bool &HasBias() const { return has_bias_; }

  const vector<float> &Scales() const { return scales_; }

  const vector<float> &Biases() const { return biases_; }

 private:
  Tensor *input_x_;
  Tensor *input_bias_;
  Tensor *out_;
  bool inplace_;
  bool has_bias_;
  vector<float> scales_;
  vector<float> biases_;
};
T
Tian 已提交
755 756 757
#endif

#ifdef SLICE_OP
I
itminner 已提交
758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789
class SliceParam : public OpParam {
 public:
  SliceParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
             const AttributeMap &attrs, const Scope &scope) {
    input_x_ = InputXFrom<LoDTensor>(inputs, scope);
    input_shape_ = InputShapeFrom<LoDTensor>(inputs, scope);
    out_ = OutFrom<LoDTensor>(outputs, scope);
    axis_ = GetAttr<int>("axis", attrs);
    slice_points_ = GetAttr<vector<int>>("slice_points", attrs);
    inplace_ = GetAttr<bool>("inplace", attrs);
  }

  const Tensor *InputX() const { return input_x_; }

  const Tensor *InputShape() const { return input_shape_; }

  Tensor *Out() const { return out_; }

  const int &Axis() const { return axis_; }

  const vector<int> &SlicePoints() const { return slice_points_; }

  const bool &Inplace() const { return inplace_; }

 private:
  Tensor *input_x_;
  Tensor *input_shape_;
  Tensor *out_;
  int axis_;
  vector<int> slice_points_;
  bool inplace_;
};
T
Tian 已提交
790 791 792 793
#endif

#ifdef RESIZE_OP
class ResizeParam : public OpParam {
I
itminner 已提交
794 795 796 797 798 799 800 801 802 803 804 805
 public:
  ResizeParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
              const AttributeMap &attrs, const Scope &scope) {
    input_x_ = InputXFrom<LoDTensor>(inputs, scope);
    input_shape_ = InputShapeFrom<LoDTensor>(inputs, scope);
    out_ = OutFrom<LoDTensor>(outputs, scope);
    is_pyramid_test_ = GetAttr<bool>("is_pyramid_test", attrs);
    height_ = GetAttr<int>("height", attrs);
    width_ = GetAttr<int>("width", attrs);
    out_height_scale_ = GetAttr<float>("out_height_scale", attrs);
    out_width_scale_ = GetAttr<float>("out_width_scale", attrs);
  }
T
Tian 已提交
806

I
itminner 已提交
807
  const Tensor *InputX() const { return input_x_; }
T
Tian 已提交
808

I
itminner 已提交
809
  const Tensor *InputShape() const { return input_shape_; }
T
Tian 已提交
810

I
itminner 已提交
811
  Tensor *Out() const { return out_; }
T
Tian 已提交
812

I
itminner 已提交
813
  const bool &IsPyramidTest() const { return is_pyramid_test_; }
T
Tian 已提交
814

I
itminner 已提交
815
  const int &Height() const { return height_; }
T
Tian 已提交
816

I
itminner 已提交
817
  const int &Width() const { return width_; }
T
Tian 已提交
818

I
itminner 已提交
819
  const float &OutHeightScale() const { return out_height_scale_; }
T
Tian 已提交
820

I
itminner 已提交
821
  const float &OutWidthScale() const { return out_width_scale_; }
T
Tian 已提交
822

I
itminner 已提交
823 824 825 826 827 828 829 830 831
 private:
  Tensor *input_x_;
  Tensor *input_shape_;
  Tensor *out_;
  bool is_pyramid_test_;
  int height_;
  int width_;
  float out_height_scale_;
  float out_width_scale_;
T
Tian 已提交
832 833 834
};
#endif

L
liuruilong 已提交
835
#ifdef RELU_OP
L
liuruilong 已提交
836 837 838
/*
 * @b The op layer instantiates this param and passes it to the kernel layer.
 * */
E
eclipsess 已提交
839 840 841 842
class ReluParam : public OpParam {
 public:
  ReluParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
            const AttributeMap &attrs, const Scope &scope) {
W
wangliu 已提交
843 844
    input_x_ = InputXFrom<LoDTensor>(inputs, scope);
    out_ = OutFrom<LoDTensor>(outputs, scope);
E
eclipsess 已提交
845 846 847 848 849 850 851 852 853 854
  }

  const Tensor *InputX() const { return input_x_; }

  Tensor *Out() const { return out_; }

 private:
  Tensor *input_x_;
  Tensor *out_;
};
L
liuruilong 已提交
855
#endif
E
eclipsess 已提交
856

T
Tian 已提交
857 858
#ifdef PRELU_OP
class PReluParam : public OpParam {
I
itminner 已提交
859 860 861 862 863 864 865
 public:
  PReluParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
             const AttributeMap &attrs, const Scope &scope) {
    input_x_ = InputXFrom<LoDTensor>(inputs, scope);
    out_ = OutFrom<LoDTensor>(outputs, scope);
    slopes_ = GetAttr<vector<float>>("slopes", attrs);
  }
T
Tian 已提交
866

I
itminner 已提交
867 868 869
  const Tensor *InputX() const { return input_x_; }
  Tensor *Out() const { return out_; }
  const vector<float> &Slopes() const { return slopes_; }
T
Tian 已提交
870

I
itminner 已提交
871 872 873 874
 private:
  Tensor *input_x_;
  Tensor *out_;
  vector<float> slopes_;
T
Tian 已提交
875 876 877
};
#endif

L
liuruilong 已提交
878
#ifdef FUSION_FC_OP
L
liuruilong 已提交
879
class FusionFcParam : public OpParam {
E
eclipsess 已提交
880
 public:
L
liuruilong 已提交
881
  FusionFcParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
L
liuruilong 已提交
882
                const AttributeMap &attrs, const Scope &scope) {
E
eclipsess 已提交
883 884 885 886
    input_x_ = InputXFrom<LoDTensor>(inputs, scope);
    input_y_ = InputYFrom<LoDTensor>(inputs, scope);
    input_z_ = InputZFrom<LoDTensor>(inputs, scope);
    out_ = OutFrom<LoDTensor>(outputs, scope);
E
eclipsess 已提交
887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913
    x_num_col_dims_ = GetAttr<int>("x_num_col_dims", attrs);
    y_num_col_dims_ = GetAttr<int>("y_num_col_dims", attrs);
    axis_ = GetAttr<int>("axis", attrs);
  }
  const Tensor *InputX() const { return input_x_; }

  const Tensor *InputY() const { return input_y_; }

  const Tensor *InputZ() const { return input_z_; }

  Tensor *Out() const { return out_; }

  const int &XNumColDims() const { return x_num_col_dims_; }

  const int &YNumColDims() const { return y_num_col_dims_; }

  const int &Axis() const { return axis_; }

 private:
  Tensor *input_x_;
  Tensor *input_y_;
  Tensor *input_z_;
  Tensor *out_;
  int x_num_col_dims_;
  int y_num_col_dims_;
  int axis_;
};
L
liuruilong 已提交
914
#endif
E
eclipsess 已提交
915

W
wangliu 已提交
916
#ifdef FUSION_CONVADD_OP
L
liuruilong 已提交
917
class FusionConvAddParam : public OpParam {
W
wangliu 已提交
918
 public:
L
liuruilong 已提交
919
  FusionConvAddParam(const VariableNameMap &inputs,
L
liuruilong 已提交
920 921
                     const VariableNameMap &outputs, const AttributeMap &attrs,
                     const Scope &scope) {
W
wangliu 已提交
922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949
    bias_ = InputYFrom<LoDTensor>(inputs, scope);
    axis_ = GetAttr<int>("axis", attrs);
    filter_ = FilterFrom<LoDTensor>(inputs, scope);
    input_ = InputFrom<LoDTensor>(inputs, scope);
    output_ = OutFrom<LoDTensor>(outputs, scope);
    strides_ = GetAttr<vector<int>>("strides", attrs);
    paddings_ = GetAttr<vector<int>>("paddings", attrs);
    dilations_ = GetAttr<vector<int>>("dilations", attrs);
    groups = GetAttr<int>("groups", attrs);
  }
  Tensor *Bias() const { return bias_; }

  const int &Axis() const { return axis_; }

  const Tensor *Input() const { return input_; }

  const Tensor *Filter() const { return filter_; }

  Tensor *Output() const { return output_; }

  const vector<int> &Strides() const { return strides_; }

  const vector<int> &Paddings() const { return paddings_; }

  const vector<int> &Dilations() const { return dilations_; }

  const int &Groups() const { return groups; }

L
liuruilong 已提交
950
 protected:
W
wangliu 已提交
951 952 953 954 955 956 957 958 959 960 961
  Tensor *bias_;
  int axis_;
  Tensor *input_;
  Tensor *output_;
  Tensor *filter_;
  vector<int> strides_;
  vector<int> paddings_;
  vector<int> dilations_;
  int groups;
};

L
liuruilong 已提交
962
// Declaration only: printer overload taking a FusionConvAddParam. The
// definition is not visible in this part of the header.
Print &operator<<(Print &printer, const FusionConvAddParam &conv_param);
W
wangliu 已提交
963 964
#endif

L
liuruilong 已提交
965
#ifdef FUSION_CONVADD_RELU_OP
L
liuruilong 已提交
966
class FusionConvAddReluParam : public FusionConvAddParam {
L
liuruilong 已提交
967
 public:
L
liuruilong 已提交
968
  FusionConvAddReluParam(const VariableNameMap &inputs,
L
liuruilong 已提交
969 970
                         const VariableNameMap &outputs,
                         const AttributeMap &attrs, const Scope &scope)
L
liuruilong 已提交
971
      : FusionConvAddParam(inputs, outputs, attrs, scope) {}
L
liuruilong 已提交
972 973 974
};
#endif

E
eclipsess 已提交
975 976 977 978 979 980 981 982 983 984 985 986 987 988 989
#ifdef FUSION_CONVADDBNRELU_OP
class FusionConvAddBNReluParam : public OpParam {
 public:
  FusionConvAddBNReluParam(const VariableNameMap &inputs,
                           const VariableNameMap &outputs,
                           const AttributeMap &attrs, const Scope &scope) {
    bias_ = InputYFrom<LoDTensor>(inputs, scope);
    axis_ = GetAttr<int>("axis", attrs);
    filter_ = FilterFrom<LoDTensor>(inputs, scope);
    input_ = InputFrom<LoDTensor>(inputs, scope);
    output_ = OutFrom<LoDTensor>(outputs, scope);
    strides_ = GetAttr<vector<int>>("strides", attrs);
    paddings_ = GetAttr<vector<int>>("paddings", attrs);
    dilations_ = GetAttr<vector<int>>("dilations", attrs);
    groups = GetAttr<int>("groups", attrs);
990 991 992 993
    input_bias_ = InputBiasFrom<LoDTensor>(inputs, scope);
    input_mean_ = InputMeanFrom<LoDTensor>(inputs, scope);
    input_scale_ = InputScaleFrom<LoDTensor>(inputs, scope);
    input_variance_ = InputVarianceFrom<LoDTensor>(inputs, scope);
E
eclipsess 已提交
994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060
    epsilon_ = GetAttr<float>("epsilon", attrs);
    momentum_ = GetAttr<float>("momentum", attrs);
    is_test_ = GetAttr<bool>("is_test", attrs);
  }
  Tensor *Bias() const { return bias_; }

  const int &Axis() const { return axis_; }

  const Tensor *Input() const { return input_; }

  const Tensor *Filter() const { return filter_; }

  Tensor *Output() const { return output_; }

  const vector<int> &Strides() const { return strides_; }

  const vector<int> &Paddings() const { return paddings_; }

  const vector<int> &Dilations() const { return dilations_; }

  const int &Groups() const { return groups; }

  const Tensor *InputBias() const { return input_bias_; }

  const Tensor *InputMean() const { return input_mean_; }

  const Tensor *InputScale() const { return input_scale_; }

  const Tensor *InputVariance() const { return input_variance_; }

  const float &Epsilon() const { return epsilon_; }

  const float &Momentum() const { return momentum_; }

  const bool &IsTest() const { return is_test_; }

  void SetNewScale(Tensor *new_scale) { new_scale_ = new_scale; }

  void SetNewBias(Tensor *new_bias) { new_bias_ = new_bias; }

  const Tensor *NewScale() const { return new_scale_; }

  const Tensor *NewBias() const { return new_bias_; }

 protected:
  Tensor *bias_;
  int axis_;
  Tensor *input_;
  Tensor *output_;
  Tensor *filter_;
  vector<int> strides_;
  vector<int> paddings_;
  vector<int> dilations_;
  int groups;
  Tensor *input_bias_;
  Tensor *input_mean_;
  Tensor *input_scale_;
  Tensor *input_variance_;
  float epsilon_;
  float momentum_;
  bool is_test_;
  Tensor *new_bias_;
  Tensor *new_scale_;
};

// Printer overload taking a FusionConvAddParam. That class is only defined
// when FUSION_CONVADD_OP is set, so the declaration is guarded: without the
// guard, enabling FUSION_CONVADDBNRELU_OP alone fails on an unknown type.
#ifdef FUSION_CONVADD_OP
Print &operator<<(Print &printer, const FusionConvAddParam &conv_param);
#endif
#endif
Y
Yao,kun 已提交
1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091

#ifdef IM2SEQUENCE_OP
class Im2SequenceParam : public OpParam {
 public:
  Im2SequenceParam(const VariableNameMap &inputs,
                   const VariableNameMap &outputs, const AttributeMap &attrs,
                   const Scope &scope) {
    input_x_ = InputXFrom<LoDTensor>(inputs, scope);
    out_ = OutFrom<LoDTensor>(outputs, scope);
    kernels_ = GetAttr<vector<int>>("kernels", attrs);
    strides_ = GetAttr<vector<int>>("strides", attrs);
    paddings_ = GetAttr<vector<int>>("paddings", attrs);
  }

  const Tensor *Input() const { return input_x_; }

  Tensor *Output() const { return out_; }

  const vector<int> &Kernels() const { return kernels_; }

  const vector<int> &Strides() const { return strides_; }

  const vector<int> &Paddings() const { return paddings_; }

 private:
  Tensor *input_x_;
  Tensor *out_;
  vector<int> kernels_;
  vector<int> strides_;
  vector<int> paddings_;
};
1092
#endif
Y
Yao,kun 已提交
1093

1094
#ifdef DROPOUT_OP
Y
Yao,kun 已提交
1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110
class DropoutParam : public OpParam {
 public:
  DropoutParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
               const AttributeMap &attrs, const Scope &scope) {
    input_x_ = InputXFrom<LoDTensor>(inputs, scope);
    out_ = OutFrom<LoDTensor>(outputs, scope);
  }

  const Tensor *InputX() const { return input_x_; }

  Tensor *Out() const { return out_; }

 private:
  Tensor *input_x_;
  Tensor *out_;
};
1111
#endif
Y
Yao,kun 已提交
1112

朔-望's avatar
朔-望 已提交
1113 1114
}  // namespace operators
}  // namespace paddle_mobile