diff --git a/paddle/function/BufferArg.h b/paddle/function/BufferArg.h index 9fcda7a878aaddc47ad3f3e3b1a064cf8b5d2049..52494afed3b8554e9960f4ceccb86c9d94a6e91d 100644 --- a/paddle/function/BufferArg.h +++ b/paddle/function/BufferArg.h @@ -46,32 +46,6 @@ class SequenceArg; class SparseMatrixArg; typedef std::shared_ptr BufferArgPtr; -class BufferArgs { -public: - BufferArgs() {} - size_t size() const { return args_.size(); } - - // add argument into BufferArgss - template - void addArg(const Tensor& arg) { - args_.push_back(std::make_shared(arg)); - } - - void addArg(const Matrix& arg, const TensorShape& shape); - - void addArg(const CpuSparseMatrix& arg); - void addArg(const GpuSparseMatrix& arg); - - // get argument - const BufferArg& operator[](size_t num) const { - CHECK_LT(num, args_.size()); - return *args_[num]; - } - -private: - std::vector args_; -}; - // an array of arbitrary dimensions class BufferArg { public: diff --git a/paddle/function/Function.h b/paddle/function/Function.h index 024575b4f7bcdbd059338958b727a898d66cf1dc..27ebe808aaf446771afb931e3a4611519cf340f0 100644 --- a/paddle/function/Function.h +++ b/paddle/function/Function.h @@ -22,6 +22,11 @@ limitations under the License. */ namespace paddle { +/** + * Function Configuration. + * The argument type of Function::init. + * Follow-up will consider moving this data structure to Proto inside. + */ class FuncConfig { public: union value { @@ -41,6 +46,43 @@ protected: std::map valueMap_; }; +/** + * Argument type for Function::calc(). + * A BufferArgs contains a set of BufferArg, + * because Function can have multiple inputs, outputs and inouts. + */ +class BufferArgs { +public: + BufferArgs() {} + size_t size() const { return args_.size(); } + + // add argument into BufferArgss + template + void addArg(const Tensor& arg) { + args_.push_back(std::make_shared(arg)); + } + + void addArg(const Matrix& arg, const TensorShape& shape); + + void addArg(const CpuSparseMatrix& arg); + void addArg(const GpuSparseMatrix& arg); + + // get argument + const BufferArg& operator[](size_t num) const { + CHECK_LT(num, args_.size()); + return *args_[num]; + } + +private: + std::vector args_; +}; + +/** + * Base class for Function. + * The basic Function implementation requires override init and calc interfaces. + * Need to pay attention to the inouts argument. For the input argument + * that will be modified, it needs to be passed through inouts. + */ class FunctionBase { public: virtual ~FunctionBase() {}