提交 f3fdfd94 编写于 作者: H hedaoyuan

add some comments for Function.h

上级 41c52d3b
...@@ -46,32 +46,6 @@ class SequenceArg; ...@@ -46,32 +46,6 @@ class SequenceArg;
class SparseMatrixArg; class SparseMatrixArg;
typedef std::shared_ptr<BufferArg> BufferArgPtr; typedef std::shared_ptr<BufferArg> BufferArgPtr;
class BufferArgs {
public:
BufferArgs() {}
size_t size() const { return args_.size(); }
// add argument into BufferArgss
template <typename Tensor>
void addArg(const Tensor& arg) {
args_.push_back(std::make_shared<BufferArg>(arg));
}
void addArg(const Matrix& arg, const TensorShape& shape);
void addArg(const CpuSparseMatrix& arg);
void addArg(const GpuSparseMatrix& arg);
// get argument
const BufferArg& operator[](size_t num) const {
CHECK_LT(num, args_.size());
return *args_[num];
}
private:
std::vector<BufferArgPtr> args_;
};
// an array of arbitrary dimensions // an array of arbitrary dimensions
class BufferArg { class BufferArg {
public: public:
......
...@@ -22,6 +22,11 @@ limitations under the License. */ ...@@ -22,6 +22,11 @@ limitations under the License. */
namespace paddle { namespace paddle {
/**
* Function Configuration.
* The argument type of Function::init.
* Follow-up will consider moving this data structure to Proto inside.
*/
class FuncConfig { class FuncConfig {
public: public:
union value { union value {
...@@ -41,6 +46,43 @@ protected: ...@@ -41,6 +46,43 @@ protected:
std::map<std::string, value> valueMap_; std::map<std::string, value> valueMap_;
}; };
/**
* Argument type for Function::calc().
* A BufferArgs contains a set of BufferArg,
* because Function can have multiple inputs, outputs and inouts.
*/
class BufferArgs {
public:
BufferArgs() {}
size_t size() const { return args_.size(); }
// add argument into BufferArgss
template <typename Tensor>
void addArg(const Tensor& arg) {
args_.push_back(std::make_shared<BufferArg>(arg));
}
void addArg(const Matrix& arg, const TensorShape& shape);
void addArg(const CpuSparseMatrix& arg);
void addArg(const GpuSparseMatrix& arg);
// get argument
const BufferArg& operator[](size_t num) const {
CHECK_LT(num, args_.size());
return *args_[num];
}
private:
std::vector<BufferArgPtr> args_;
};
/**
* Base class for Function.
* The basic Function implementation requires override init and calc interfaces.
* Need to pay attention to the inouts argument. For the input argument
* that will be modified, it needs to be passed through inouts.
*/
class FunctionBase { class FunctionBase {
public: public:
virtual ~FunctionBase() {} virtual ~FunctionBase() {}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册