/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <map>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "BufferArg.h"
#include "paddle/math/Matrix.h"
#include "paddle/utils/Any.h"
#include "paddle/utils/ClassRegistrar.h"

namespace paddle {

/**
 * Function Configuration.
 * The argument type of Function::init.
 *
 * A string-keyed bag of values of arbitrary types: set() stores a copy of
 * the value type-erased in an `any`, and get<T>() retrieves it with
 * any_cast<T>.
 */
class FuncConfig {
public:
  /**
   * Return the value stored under `key`, converted with any_cast<T>.
   *
   * The lookup never modifies the configuration. If `key` is missing, or the
   * stored value's type does not match T, bad_any_cast is thrown.
   */
  template <typename T>
  T get(const std::string& key) const {
    auto it = valueMap_.find(key);
    if (it == valueMap_.end()) {
      // A missing key behaves like an empty value: any_cast throws
      // bad_any_cast, the same error the former operator[]-based lookup
      // produced, but no empty entry is inserted into a const object.
      return any_cast<T>(any());
    }
    return any_cast<T>(it->second);
  }

  /**
   * Store `v` under `key`, replacing any previous value.
   * Returns *this so calls can be chained: config.set("a", 1).set("b", 2.0).
   */
  template <typename T>
  FuncConfig& set(const std::string& key, T v) {
    // `v` is already a by-value copy; move it instead of copying it again.
    valueMap_[key] = any(std::move(v));
    return *this;
  }

protected:
  std::unordered_map<std::string, any> valueMap_;
};

H
hedaoyuan 已提交
47 48 49
/**
 * Argument type for Function::calc().
 * A BufferArgs contains a set of BufferArg,
50
 * because Function can have multiple inputs and outputs.
H
hedaoyuan 已提交
51 52 53 54 55 56 57 58
 *
 * addArg() with Matix object used to adapt Layer Argument.
 * Will create a BufferArg object in addArg(),
 * and free in destructor of BufferArgs.
 *
 * addArg() with BufferArg object, just save BufferArg object address,
 * and the caller needs to guarantee the validity of the BufferArg object
 * in the BufferArgs life time.
H
hedaoyuan 已提交
59 60 61 62
 */
class BufferArgs {
public:
  BufferArgs() {}
H
hedaoyuan 已提交
63 64 65 66 67 68 69

  ~BufferArgs() {
    for (auto arg : _args_) {
      delete arg;
    }
  }

H
hedaoyuan 已提交
70 71
  size_t size() const { return args_.size(); }

72 73
  // add argument into BufferArgs
  // Tensor can be Matrix, Vector, IVector.
74 75
  // For inputs, do not need argType.
  // For outputs, the argType needs to be specified as ASSIGN_TO or ADD_TO.
H
hedaoyuan 已提交
76 77 78 79 80 81 82 83 84 85 86
  void addArg(const Matrix& arg, ArgType argType = UNSPECIFIED) {
    _args_.push_back(new BufferArg(arg, argType));
    addArg(*_args_.back());
  }

  void addArg(const Vector& arg, ArgType argType = UNSPECIFIED) {
    _args_.push_back(new BufferArg(arg, argType));
    addArg(*_args_.back());
  }

  void addArg(const IVector& arg, ArgType argType = UNSPECIFIED) {
H
hedaoyuan 已提交
87 88
    _args_.push_back(new BufferArg(arg, argType));
    addArg(*_args_.back());
H
hedaoyuan 已提交
89 90
  }

91 92 93 94 95
  // Add arg into BufferArgs and reshape the arg.
  //
  // For example, arg represents an image buffer,
  // but Matrix can only represent a two-dimensional Tensor.
  // So need an extra argument to describe the shape of the image buffer.
96 97 98
  void addArg(const Matrix& arg,
              const TensorShape& shape,
              ArgType argType = UNSPECIFIED);
H
hedaoyuan 已提交
99

100 101
  void addArg(const CpuSparseMatrix& arg, ArgType argType = UNSPECIFIED);
  void addArg(const GpuSparseMatrix& arg, ArgType argType = UNSPECIFIED);
H
hedaoyuan 已提交
102

X
xutianbing 已提交
103 104 105 106
  void addArg(const Matrix& matrix,
              const IVector& vector,
              ArgType argType = UNSPECIFIED);

H
hedaoyuan 已提交
107 108 109 110 111 112
  // get argument
  const BufferArg& operator[](size_t num) const {
    CHECK_LT(num, args_.size());
    return *args_[num];
  }

H
hedaoyuan 已提交
113 114 115 116 117 118 119 120
  void addArg(BufferArg& arg) { args_.push_back(&arg); }

  void addArg(SequenceIdArg& arg) { args_.push_back(&arg); }

  void addArg(SequenceArg& arg) { args_.push_back(&arg); }

  void addArg(SparseMatrixArg& arg) { args_.push_back(&arg); }

H
hedaoyuan 已提交
121
private:
H
hedaoyuan 已提交
122 123 124
  std::vector<BufferArg*> args_;
  // The BufferArg object is constructed and freed by BufferArgs.
  std::vector<BufferArg*> _args_;
H
hedaoyuan 已提交
125 126 127
};

/**
128
 * \brief Base class for Function.
H
hedaoyuan 已提交
129
 * The basic Function implementation requires override init and calc interfaces.
130
 *
H
hedaoyuan 已提交
131 132 133
 * The caller needs to ensure the validity of the arguments
 * during Function execution.
 *
134 135 136 137 138 139 140 141 142 143 144
 * Function inputs are readonly, Function outputs have two modes: ASSIGN_TO
 * and ADD_TO.
 * If output.getArgType() == ASSIGN_TO, this is assign mode, and the calculation
 * result of Function assigned to the output BufferArg.
 * If output.getArgType() == ADD_TO, this is add mode, and the calculation
 * result of Function need added to the output BufferArg.
 *
 * For example:
 * ASSIGN_TO: output = Function(inputs)
 * ADD_TO: output += Function(inputs)
 * If Function has more than one output, each output can have different modes.
H
hedaoyuan 已提交
145
 */
H
hedaoyuan 已提交
146 147 148 149 150 151
class FunctionBase {
public:
  virtual ~FunctionBase() {}

  virtual void init(const FuncConfig& config) {}

152
  virtual void calc(const BufferArgs& inputs, const BufferArgs& outputs) {}
H
hedaoyuan 已提交
153

H
hedaoyuan 已提交
154 155 156
  // This member function is used to check whether the BufferType and shape of
  // the inputs and outputs arguments of the Function are correct.
  // General calc function which will call this check to do arguments check.
H
hedaoyuan 已提交
157
  // And before the calc called, the caller can also check their own arguments.
H
hedaoyuan 已提交
158 159
  virtual void check(const BufferArgs& inputs, const BufferArgs& outputs) {}

H
hedaoyuan 已提交
160 161 162
  // Calculate the number of floating-point operations of this Function.
  // The inputs and outputs arguments do not need to contain the actual data,
  // only the shape.
H
hedaoyuan 已提交
163 164 165
  // And some Functions have the same input and output shapes,
  // so you may not need to enter the complete number of arguments.
  // But entering the full arguments is always correct for this interface.
H
hedaoyuan 已提交
166 167 168 169
  virtual size_t ops(const BufferArgs& inputs, const BufferArgs& outputs) {
    return 0;
  }

H
hedaoyuan 已提交
170 171 172 173
  int getNumInputs() const { return numInputs_; }

  int getNumOutputs() const { return numOutputs_; }

H
hedaoyuan 已提交
174
  static ClassRegistrar<FunctionBase> funcRegistrar_;
H
hedaoyuan 已提交
175 176 177 178 179 180 181 182 183

protected:
  // numInputs_ and numOutputs_ represents the maximum
  // input and output supported by Function.
  // Some functions are optimized for input and output,
  // so when comparing the number of arguments, for these functions
  // inputs.size() <= numInputs_ or outputs.size() <= numOutputs_
  size_t numInputs_;
  size_t numOutputs_;
H
hedaoyuan 已提交
184 185 186 187
};

// Registry name of a Function implementation: "<typeName>-<deviceName>".
// For example, FUNC_NAME(Pad, CPU) expands to the string literal "Pad-CPU".
#define FUNC_NAME(typeName, deviceName) #typeName "-" #deviceName

// Register className<DEVICE_TYPE_##deviceName> in
// FunctionBase::funcRegistrar_ under FUNC_NAME(typeName, deviceName).
// The registration lambda is handed to a file-scope static InitFunction
// object (presumably run during program initialization; see InitFunction).
//
// The helper variable avoids a double-underscore name, which is reserved
// for the implementation. Registering the same (typeName, deviceName) pair
// twice in one translation unit is still a redefinition error.
// Usage: REGISTER_TYPED_FUNC(Pad, CPU, PadFunc);
#define REGISTER_TYPED_FUNC(typeName, deviceName, className)            \
  static InitFunction reg_typed_func_##typeName##_##deviceName([]() { \
    FunctionBase::funcRegistrar_                                        \
        .registerClass<className<DEVICE_TYPE_##deviceName>>(            \
            FUNC_NAME(typeName, deviceName));                           \
  })

}  // namespace paddle