op_param.h 12.1 KB
Newer Older
朔-望's avatar
朔-望 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserved.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
==============================================================================*/

19
#pragma once
朔-望's avatar
朔-望 已提交
20

L
liuruilong 已提交
21
#include <cassert>
#include <string>
#include <vector>

#include "common/log.h"
#include "common/type_define.h"
#include "framework/lod_tensor.h"
#include "framework/scope.h"
#include "framework/tensor.h"
#include "framework/variable.h"

namespace paddle_mobile {
朔-望's avatar
朔-望 已提交
29 30 31 32 33
namespace operators {

using namespace framework;

class OpParam : PaddleMobileObject {
34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115
public:
protected:
  template <typename T>
  static T *InputFrom(const VariableNameMap &inputs, const Scope &scope) {
    return GetVarValue<T>("Input", inputs, scope);
  }

  template <typename T>
  static T *InputXFrom(const VariableNameMap &inputs, const Scope &scope) {
    return GetVarValue<T>("X", inputs, scope);
  }

  template <typename T>
  static T *InputYFrom(const VariableNameMap &inputs, const Scope &scope) {
    return GetVarValue<T>("Y", inputs, scope);
  }

  template <typename T>
  static T *InputBiasFrom(const VariableNameMap &inputs, const Scope &scope) {
    return GetVarValue<T>("Bias", inputs, scope);
  }
  template <typename T>
  static T *InputVarianceFrom(const VariableNameMap &inputs,
                              const Scope &scope) {
    return GetVarValue<T>("Variance", inputs, scope);
  }
  template <typename T>
  static T *InputMeanFrom(const VariableNameMap &inputs, const Scope &scope) {
    return GetVarValue<T>("Mean", inputs, scope);
  }
  template <typename T>
  static T *InputScaleFrom(const VariableNameMap &inputs, const Scope &scope) {
    return GetVarValue<T>("Scale", inputs, scope);
  }

  template <typename T>
  static std::vector<T *> InputMultiFrom(const VariableNameMap &inputs,
                                         const Scope &scope) {
    return GetMultiVarValue<T>("X", inputs, scope);
  }

  template <typename T>
  static T *OutputFrom(const VariableNameMap &outputs, const Scope &scope) {
    return GetVarValue<T>("Output", outputs, scope);
  }

  template <typename T>
  static T *OutFrom(const VariableNameMap &outputs, const Scope &scope) {
    return GetVarValue<T>("Out", outputs, scope);
  }

  template <typename T>
  static T *OutputYFrom(const VariableNameMap &outputs, const Scope &scope) {
    return GetVarValue<T>("Y", outputs, scope);
  }

  template <typename T>
  static T *MidOutFrom(const VariableNameMap &outputs, const Scope &scope) {
    return GetVarValue<T>("MidOut", outputs, scope);
  }

  template <typename T>
  static T *FilterFrom(const VariableNameMap &inputs, const Scope &scope) {
    return GetVarValue<T>("Filter", inputs, scope);
  }

  template <typename T>
  static const T GetAttr(const std::string &key, const AttributeMap &map) {
    return ((Attribute)map.at(key)).Get<T>();
  }

  template <typename T>
  static T *GetVarValue(const std::string &key, const VariableNameMap &var_map,
                        const Scope &scope) {
    auto var_vec = var_map.at(key);
    if (!var_vec.empty()) {
      //      std::cout << " get var value -- " << var_vec[0] <<
      //      std::endl;
      auto var = scope.FindVar(var_vec[0]);
      return var->GetMutable<T>();
    } else {
      return nullptr;
朔-望's avatar
朔-望 已提交
116
    }
117
  }
朔-望's avatar
朔-望 已提交
118

119 120 121
  template <typename T>
  static std::vector<T *> GetMultiVarValue(const std::string &key,
                                           const VariableNameMap &var_map,
朔-望's avatar
朔-望 已提交
122
                                           const Scope &scope) {
123 124 125 126 127 128
    auto var_vecs = var_map.at(key);
    assert(var_vecs.size() > 1);
    std::vector<T *> var_res;
    for (auto &var_vec : var_vecs) {
      auto var = scope.FindVar(var_vec);
      var_res.push_back(var->GetMutable<T>());
朔-望's avatar
朔-望 已提交
129
    }
130 131
    return var_res;
  }
朔-望's avatar
朔-望 已提交
132 133 134
};

class ConvParam : OpParam {
135 136 137 138 139 140 141 142 143 144 145 146
public:
  ConvParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
            const framework::AttributeMap &attrs,
            const framework::Scope &scope) {
    filter_ = FilterFrom<framework::LoDTensor>(inputs, scope);
    input_ = InputFrom<framework::Tensor>(inputs, scope);
    output_ = OutputFrom<framework::Tensor>(outputs, scope);
    strides_ = GetAttr<std::vector<int>>("strides", attrs);
    paddings_ = GetAttr<std::vector<int>>("paddings", attrs);
    dilations_ = GetAttr<std::vector<int>>("dilations", attrs);
    groups = GetAttr<int>("groups", attrs);
  }
朔-望's avatar
朔-望 已提交
147

148
  const Tensor *Input() const { return input_; }
朔-望's avatar
朔-望 已提交
149

150
  const LoDTensor *Filter() const { return filter_; }
朔-望's avatar
朔-望 已提交
151

152
  Tensor *Output() const { return output_; }
朔-望's avatar
朔-望 已提交
153

154
  const std::vector<int> &Strides() const { return strides_; }
朔-望's avatar
朔-望 已提交
155

156
  const std::vector<int> &Paddings() const { return paddings_; }
朔-望's avatar
朔-望 已提交
157

158
  const std::vector<int> &Dilations() const { return dilations_; }
朔-望's avatar
朔-望 已提交
159

160
  const int &Groups() const { return groups; }
朔-望's avatar
朔-望 已提交
161

162 163 164 165 166 167 168 169
private:
  Tensor *input_;
  Tensor *output_;
  LoDTensor *filter_;
  std::vector<int> strides_;
  std::vector<int> paddings_;
  std::vector<int> dilations_;
  int groups;
朔-望's avatar
朔-望 已提交
170 171 172 173 174
};

Print &operator<<(Print &printer, const ConvParam &conv_param);

class ElementwiseAddParam : OpParam {
175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198
public:
  ElementwiseAddParam(const VariableNameMap &inputs,
                      const VariableNameMap &outputs,
                      const framework::AttributeMap &attrs,
                      const framework::Scope &scope) {
    input_x_ = InputXFrom<framework::Tensor>(inputs, scope);
    input_y_ = InputYFrom<framework::Tensor>(inputs, scope);
    out_ = OutFrom<framework::Tensor>(outputs, scope);
    axis_ = GetAttr<int>("axis", attrs);
  }

  const Tensor *InputX() const { return input_x_; }

  const Tensor *InputY() const { return input_y_; }

  Tensor *Out() const { return out_; }

  const int &Axis() const { return axis_; }

private:
  Tensor *input_x_;
  Tensor *input_y_;
  Tensor *out_;
  int axis_;
朔-望's avatar
朔-望 已提交
199 200 201
};

class MulParam : OpParam {
202 203 204 205 206 207 208 209 210 211
public:
  MulParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
           const framework::AttributeMap &attrs,
           const framework::Scope &scope) {
    input_x_ = InputXFrom<framework::Tensor>(inputs, scope);
    input_y_ = InputYFrom<framework::Tensor>(inputs, scope);
    out_ = OutFrom<framework::Tensor>(outputs, scope);
    x_num_col_dims_ = GetAttr<int>("x_num_col_dims", attrs);
    y_num_col_dims_ = GetAttr<int>("y_num_col_dims", attrs);
  }
朔-望's avatar
朔-望 已提交
212

213
  const Tensor *InputX() const { return input_x_; }
朔-望's avatar
朔-望 已提交
214

215
  const Tensor *InputY() const { return input_y_; }
朔-望's avatar
朔-望 已提交
216

217
  Tensor *Out() const { return out_; }
朔-望's avatar
朔-望 已提交
218

219
  const int &XNumColDims() const { return x_num_col_dims_; }
朔-望's avatar
朔-望 已提交
220

221
  const int &YNumColDims() const { return y_num_col_dims_; }
朔-望's avatar
朔-望 已提交
222

223 224 225 226 227 228
private:
  Tensor *input_x_;
  Tensor *input_y_;
  Tensor *out_;
  int x_num_col_dims_;
  int y_num_col_dims_;
朔-望's avatar
朔-望 已提交
229 230 231
};

class ConcatParam : public OpParam {
232 233 234 235 236 237 238 239
public:
  ConcatParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
              const framework::AttributeMap &attrs,
              const framework::Scope &scope) {
    inputs_ = InputMultiFrom<framework::Tensor>(inputs, scope);
    out_ = OutFrom<framework::Tensor>(outputs, scope);
    axis_ = GetAttr<int>("axis", attrs);
  }
朔-望's avatar
朔-望 已提交
240

241
  std::vector<Tensor *> Inputs() const { return inputs_; }
朔-望's avatar
朔-望 已提交
242

243
  Tensor *Out() const { return out_; }
朔-望's avatar
朔-望 已提交
244

245
  const int &Axis() const { return axis_; }
朔-望's avatar
朔-望 已提交
246

247 248 249 250
private:
  std::vector<Tensor *> inputs_;
  Tensor *out_;
  int axis_;
朔-望's avatar
朔-望 已提交
251 252
};

E
eclipsess 已提交
253
class LrnParam : public OpParam {
254 255 256 257 258 259 260 261 262 263 264 265 266
public:
  LrnParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
           const framework::AttributeMap &attrs,
           const framework::Scope &scope) {
    input_x_ = InputXFrom<framework::Tensor>(inputs, scope);
    out_ = OutFrom<framework::Tensor>(outputs, scope);
    mid_out_ = MidOutFrom<framework::Tensor>(outputs, scope);
    n_ = GetAttr<int>("n", attrs);
    alpha_ = GetAttr<float>("alpha", attrs);
    beta_ = GetAttr<float>("beta", attrs);
    k_ = GetAttr<float>("k", attrs);
    data_format_ = GetAttr<std::string>("data_format", attrs);
  }
E
eclipsess 已提交
267

268
  const Tensor *InputX() const { return input_x_; }
E
eclipsess 已提交
269

270
  Tensor *Out() const { return out_; }
E
eclipsess 已提交
271

272
  Tensor *MidOut() const { return mid_out_; }
E
eclipsess 已提交
273

274
  const int &N() const { return n_; }
E
eclipsess 已提交
275

276
  const float &Alpha() const { return alpha_; }
E
eclipsess 已提交
277

278
  const float &Beta() const { return beta_; }
E
eclipsess 已提交
279

280
  const float &K() const { return k_; }
E
eclipsess 已提交
281

282
  const std::string &DataFormat() const { return data_format_; }
E
eclipsess 已提交
283

284 285 286 287 288 289 290 291 292
private:
  Tensor *input_x_;
  Tensor *out_;
  Tensor *mid_out_;
  int n_;
  float alpha_;
  float beta_;
  float k_;
  std::string data_format_;
E
eclipsess 已提交
293
};
E
eclipsess 已提交
294
class BatchNormParam : OpParam {
295 296 297 298 299 300 301 302 303 304 305 306 307 308
public:
  BatchNormParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
                 const framework::AttributeMap &attrs,
                 const framework::Scope &scope) {
    input_x_ = InputXFrom<framework::Tensor>(inputs, scope);
    output_y_ = OutputYFrom<framework::Tensor>(outputs, scope);
    input_bias_ = InputBiasFrom<framework::Tensor>(inputs, scope);
    input_mean_ = InputMeanFrom<framework::Tensor>(inputs, scope);
    input_scale_ = InputScaleFrom<framework::Tensor>(inputs, scope);
    input_variance_ = InputVarianceFrom<framework::Tensor>(inputs, scope);
    epsilon_ = GetAttr<float>("epsilon", attrs);
    momentum_ = GetAttr<float>("momentum", attrs);
    is_test_ = GetAttr<bool>("is_test", attrs);
  }
E
eclipsess 已提交
309

310
  const Tensor *InputX() const { return input_x_; }
E
eclipsess 已提交
311

312
  Tensor *OutputY() const { return output_y_; }
E
eclipsess 已提交
313

314
  const Tensor *InputBias() const { return input_bias_; }
E
eclipsess 已提交
315

316
  const Tensor *InputMean() const { return input_mean_; }
E
eclipsess 已提交
317

318
  const Tensor *InputScale() const { return input_scale_; }
E
eclipsess 已提交
319

320
  const Tensor *InputVariance() const { return input_variance_; }
E
eclipsess 已提交
321

322
  const float &Epsilon() const { return epsilon_; }
E
eclipsess 已提交
323

324
  const float &Momentum() const { return momentum_; }
E
eclipsess 已提交
325

326
  const bool &IsTest() const { return is_test_; }
E
eclipsess 已提交
327

328
  const std::string &DataFormat() const { return data_format_; }
E
eclipsess 已提交
329

330 331 332 333 334 335 336 337 338 339 340
private:
  Tensor *input_x_;
  Tensor *output_y_;
  Tensor *input_bias_;
  Tensor *input_mean_;
  Tensor *input_scale_;
  Tensor *input_variance_;
  float epsilon_;
  float momentum_;
  bool is_test_;
  std::string data_format_;
E
eclipsess 已提交
341
};
342
class PoolParam : public OpParam {
343 344 345 346 347 348 349 350 351 352 353 354 355 356
public:
  PoolParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
            const framework::AttributeMap &attrs,
            const framework::Scope &scope) {
    input_ = InputXFrom<framework::Tensor>(inputs, scope);

    output_ = OutFrom<framework::Tensor>(outputs, scope);
    pooling_type_ = GetAttr<std::string>("pooling_type", attrs);
    ksize_ = GetAttr<std::vector<int>>("ksize", attrs);
    strides_ = GetAttr<std::vector<int>>("strides", attrs);
    paddings_ = GetAttr<std::vector<int>>("paddings", attrs);
    ceil_mode_ = GetAttr<bool>("ceil_mode", attrs);
    gloabal_pooling_ = GetAttr<bool>("global_pooling", attrs);
  }
357

358
  const Tensor *Input() const { return input_; }
359

360
  Tensor *Output() const { return output_; }
361

362
  const std::string &PoolingType() const { return pooling_type_; }
363

364
  const std::vector<int> &Ksize() const { return ksize_; }
365

366
  const std::vector<int> &Strides() const { return strides_; }
367

368
  const std::vector<int> &Paddings() const { return paddings_; }
369

370
  bool isCeilMode() const { return ceil_mode_; }
371

372
  bool isGlobalPooling() const { return gloabal_pooling_; }
373

374 375 376 377 378 379 380 381 382
private:
  Tensor *input_;
  Tensor *output_;
  std::string pooling_type_;
  std::vector<int> ksize_;
  std::vector<int> strides_;
  std::vector<int> paddings_;
  bool ceil_mode_;
  bool gloabal_pooling_ = false;
383 384
};

朔-望's avatar
朔-望 已提交
385
} // namespace operators
L
liuruilong 已提交
386
} // namespace paddle_mobile