提交 2a20fdc1 编写于 作者: H hedaoyuan

Change BufferArgPtr to BufferArg*

上级 bff19f57
......@@ -79,15 +79,18 @@ FuncConfig& FuncConfig::set<bool>(const std::string& key, bool v) {
void BufferArgs::addArg(const Matrix& arg,
const TensorShape& shape,
ArgType argType) {
args_.push_back(std::make_shared<BufferArg>(arg, shape, argType));
_args_.push_back(new BufferArg(arg, shape, argType));
addArg(*_args_.back());
}
void BufferArgs::addArg(const CpuSparseMatrix& arg, ArgType argType) {
args_.push_back(std::make_shared<SparseMatrixArg>(arg, argType));
_args_.push_back(new SparseMatrixArg(arg, argType));
addArg(*_args_.back());
}
void BufferArgs::addArg(const GpuSparseMatrix& arg, ArgType argType) {
args_.push_back(std::make_shared<SparseMatrixArg>(arg, argType));
_args_.push_back(new SparseMatrixArg(arg, argType));
addArg(*_args_.back());
}
ClassRegistrar<FunctionBase> FunctionBase::funcRegistrar_;
......
......@@ -50,10 +50,25 @@ protected:
* Argument type for Function::calc().
* A BufferArgs contains a set of BufferArg,
* because Function can have multiple inputs and outputs.
*
* addArg() with Matix object used to adapt Layer Argument.
* Will create a BufferArg object in addArg(),
* and free in destructor of BufferArgs.
*
* addArg() with BufferArg object, just save BufferArg object address,
* and the caller needs to guarantee the validity of the BufferArg object
* in the BufferArgs life time.
*/
class BufferArgs {
public:
BufferArgs() {}
~BufferArgs() {
for (auto arg : _args_) {
delete arg;
}
}
size_t size() const { return args_.size(); }
// add argument into BufferArgs
......@@ -62,7 +77,8 @@ public:
// For outputs, the argType needs to be specified as ASSIGN_TO or ADD_TO.
template <typename Tensor>
void addArg(const Tensor& arg, ArgType argType = UNSPECIFIED) {
args_.push_back(std::make_shared<BufferArg>(arg, argType));
_args_.push_back(new BufferArg(arg, argType));
addArg(*_args_.back());
}
// Add arg into BufferArgs and reshape the arg.
......@@ -83,14 +99,27 @@ public:
return *args_[num];
}
void addArg(BufferArg& arg) { args_.push_back(&arg); }
void addArg(SequenceIdArg& arg) { args_.push_back(&arg); }
void addArg(SequenceArg& arg) { args_.push_back(&arg); }
void addArg(SparseMatrixArg& arg) { args_.push_back(&arg); }
private:
std::vector<BufferArgPtr> args_;
std::vector<BufferArg*> args_;
// The BufferArg object is constructed and freed by BufferArgs.
std::vector<BufferArg*> _args_;
};
/**
* \brief Base class for Function.
* The basic Function implementation requires override init and calc interfaces.
*
* The caller needs to ensure the validity of the arguments
* during Function execution.
*
* Function inputs are readonly, Function outputs have two modes: ASSIGN_TO
* and ADD_TO.
* If output.getArgType() == ASSIGN_TO, this is assign mode, and the calculation
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册