提交 57e25211 编写于 作者: H hedaoyuan

BufferArg add ArgType and Function remove inouts

上级 d35ef9de
...@@ -38,16 +38,40 @@ enum SparseDataType { ...@@ -38,16 +38,40 @@ enum SparseDataType {
enum SparseDataFormat { SPARSE_CSR_FORMAT = 0, SPARSE_CSC_FORMAT = 1 }; enum SparseDataFormat { SPARSE_CSR_FORMAT = 0, SPARSE_CSC_FORMAT = 1 };
/**
* BufferArg used as the argument type for Function.
*/
class BufferArg; class BufferArg;
class SequenceArg; class SequenceArg;
class SparseMatrixArg; class SparseMatrixArg;
typedef std::shared_ptr<BufferArg> BufferArgPtr; typedef std::shared_ptr<BufferArg> BufferArgPtr;
// an array of arbitrary dimensions /**
* \brief BufferArg used as the argument type of Function.
*
* The arguments of the Paddle Function have four Buffer types.
* 1. BufferArg for a dense Buffer of any dimension.
* 2. SequenceIdArg for a Buffer of sequence start positions.
* 3. SequenceArg for a Buffer of sequence data.
* 4. SparseMatrixArg for a Buffer of sparse matrix.
*
* There is an ArgType property for the BufferArg used as Function Output.
* Whether the result of the Function calculation is assigned to the
* output Buffer or added to the output Buffer is determined by the
* argType_ property of the output BufferArg.
*/
class BufferArg { class BufferArg {
public:
// ArgType is only used by output BufferArg.
// For input argument, argType_ is ignored.
// For output argument, need to set the argType_ of the BufferArg.
enum ArgType {
UNSPECIFIED = 0,
ASSIGN_TO = 1,
ADD_TO = 2,
};
void setArgType(ArgType argType) { argType_ = argType; }
ArgType getArgType() const { return argType_; }
public: public:
BufferArg(void* buf, ValueType valueType, const TensorShape& shape) BufferArg(void* buf, ValueType valueType, const TensorShape& shape)
: buf_(buf), valueType_(valueType), shape_(shape) {} : buf_(buf), valueType_(valueType), shape_(shape) {}
...@@ -56,7 +80,8 @@ public: ...@@ -56,7 +80,8 @@ public:
: buf_(buf), valueType_(valueType) {} : buf_(buf), valueType_(valueType) {}
BufferArg(const Matrix& matrix) BufferArg(const Matrix& matrix)
: buf_(reinterpret_cast<void*>(matrix.getData())), : buf_(
const_cast<void*>(reinterpret_cast<const void*>(matrix.getData()))),
valueType_(DataType<real>::value), valueType_(DataType<real>::value),
shape_(2) { shape_(2) {
shape_.setDim(0, matrix.getHeight()); shape_.setDim(0, matrix.getHeight());
...@@ -64,21 +89,24 @@ public: ...@@ -64,21 +89,24 @@ public:
} }
BufferArg(const Matrix& matrix, const TensorShape& shape) BufferArg(const Matrix& matrix, const TensorShape& shape)
: buf_(reinterpret_cast<void*>(matrix.getData())), : buf_(
const_cast<void*>(reinterpret_cast<const void*>(matrix.getData()))),
valueType_(DataType<real>::value), valueType_(DataType<real>::value),
shape_(shape) { shape_(shape) {
CHECK_EQ(matrix.getElementCnt(), shape.getElements()); CHECK_EQ(matrix.getElementCnt(), shape.getElements());
} }
BufferArg(const Vector& vector) BufferArg(const Vector& vector)
: buf_(reinterpret_cast<void*>(vector.getData())), : buf_(
const_cast<void*>(reinterpret_cast<const void*>(vector.getData()))),
valueType_(DataType<real>::value), valueType_(DataType<real>::value),
shape_(1) { shape_(1) {
shape_.setDim(0, vector.getSize()); shape_.setDim(0, vector.getSize());
} }
BufferArg(const IVector& vector) BufferArg(const IVector& vector)
: buf_(reinterpret_cast<void*>(vector.getData())), : buf_(
const_cast<void*>(reinterpret_cast<const void*>(vector.getData()))),
valueType_(VALUE_TYPE_INT32), valueType_(VALUE_TYPE_INT32),
shape_(1) { shape_(1) {
shape_.setDim(0, vector.getSize()); shape_.setDim(0, vector.getSize());
...@@ -124,6 +152,7 @@ protected: ...@@ -124,6 +152,7 @@ protected:
ValueType valueType_; ValueType valueType_;
TensorShape shape_; TensorShape shape_;
BufferType bufferType_; BufferType bufferType_;
ArgType argType_ = UNSPECIFIED;
// leading dimensions. The size is dims_.size() // leading dimensions. The size is dims_.size()
// Dims lds_; // Dims lds_;
}; };
......
...@@ -56,12 +56,18 @@ public: ...@@ -56,12 +56,18 @@ public:
BufferArgs() {} BufferArgs() {}
size_t size() const { return args_.size(); } size_t size() const { return args_.size(); }
// add argument into BufferArgss // add argument into BufferArgs
// Tensor can be Matrix, Vector, IVector.
template <typename Tensor> template <typename Tensor>
void addArg(const Tensor& arg) { void addArg(const Tensor& arg) {
args_.push_back(std::make_shared<BufferArg>(arg)); args_.push_back(std::make_shared<BufferArg>(arg));
} }
// Add arg into BufferArgs and reshape the arg.
//
// For example, arg represents an image buffer,
// but Matrix can only represent a two-dimensional Tensor.
// So need an extra argument to describe the shape of the image buffer.
void addArg(const Matrix& arg, const TensorShape& shape); void addArg(const Matrix& arg, const TensorShape& shape);
void addArg(const CpuSparseMatrix& arg); void addArg(const CpuSparseMatrix& arg);
...@@ -78,10 +84,20 @@ private: ...@@ -78,10 +84,20 @@ private:
}; };
/** /**
* Base class for Function. * \brief Base class for Function.
* The basic Function implementation requires override init and calc interfaces. * The basic Function implementation requires override init and calc interfaces.
* Need to pay attention to the inouts argument. For the input argument *
* that will be modified, it needs to be passed through inouts. * Function inputs are readonly, Function outputs have two modes: ASSIGN_TO
* and ADD_TO.
* If output.getArgType() == ASSIGN_TO, this is assign mode, and the calculation
* result of Function assigned to the output BufferArg.
* If output.getArgType() == ADD_TO, this is add mode, and the calculation
* result of Function need added to the output BufferArg.
*
* For example:
* ASSIGN_TO: output = Function(inputs)
* ADD_TO: output += Function(inputs)
* If Function has more than one output, each output can have different modes.
*/ */
class FunctionBase { class FunctionBase {
public: public:
...@@ -89,9 +105,7 @@ public: ...@@ -89,9 +105,7 @@ public:
virtual void init(const FuncConfig& config) {} virtual void init(const FuncConfig& config) {}
virtual void calc(const BufferArgs& inputs, virtual void calc(const BufferArgs& inputs, const BufferArgs& outputs) {}
const BufferArgs& outputs,
const BufferArgs& inouts) {}
static ClassRegistrar<FunctionBase> funcRegistrar_; static ClassRegistrar<FunctionBase> funcRegistrar_;
}; };
......
...@@ -35,7 +35,7 @@ void FunctionApi<DEVICE_TYPE_GPU>(GpuMatrix& output, const GpuMatrix& input) { ...@@ -35,7 +35,7 @@ void FunctionApi<DEVICE_TYPE_GPU>(GpuMatrix& output, const GpuMatrix& input) {
template <DeviceType DType> template <DeviceType DType>
void Function(const BufferArgs& arguments) { void Function(const BufferArgs& arguments) {
auto input = arguments[0].matrix<DType>(); const auto input = arguments[0].matrix<DType>();
auto output = arguments[1].matrix<DType>(); auto output = arguments[1].matrix<DType>();
FunctionApi<DType>(output, input); FunctionApi<DType>(output, input);
} }
......
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册