Commit df9be2d4 authored by hedaoyuan

fix CrossMapNormalFunc and ContextProjectionFunc (remove inouts argument)

Parent 57e25211
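In short, Function::calc drops its third `inouts` argument; whether a result overwrites the output buffer or is accumulated into it is now expressed by the ArgType attached to each output BufferArg. A minimal before/after sketch of a call site, using only names that appear in the diff below (the surrounding layer code is elided):

    // Before: in-place accumulation was signalled through a separate inouts list.
    BufferArgs inputs, outputs, inouts;
    inputs.addArg(*in_->value);
    outputs.addArg(*out_->value);
    forward_[0]->calc(inputs, outputs, inouts);

    // After: each output BufferArg carries its own ArgType; inputs ignore it.
    BufferArgs inputs, outputs;
    inputs.addArg(*in_->value);
    outputs.addArg(*out_->value, ADD_TO);  // add the result into out_->value
    forward_[0]->calc(inputs, outputs);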
@@ -57,58 +57,67 @@ typedef std::shared_ptr<BufferArg> BufferArgPtr;
  * output Buffer or added to the output Buffer is determined by the
  * argType_ property of the output BufferArg.
  */
+// ArgType is only used by output BufferArg.
+// For input argument, argType_ is ignored.
+// For output argument, need to set the argType_ of the BufferArg.
+enum ArgType {
+  UNSPECIFIED = 0,
+  ASSIGN_TO = 1,
+  ADD_TO = 2,
+};
 class BufferArg {
 public:
-  // ArgType is only used by output BufferArg.
-  // For input argument, argType_ is ignored.
-  // For output argument, need to set the argType_ of the BufferArg.
-  enum ArgType {
-    UNSPECIFIED = 0,
-    ASSIGN_TO = 1,
-    ADD_TO = 2,
-  };
   void setArgType(ArgType argType) { argType_ = argType; }
   ArgType getArgType() const { return argType_; }
 
 public:
-  BufferArg(void* buf, ValueType valueType, const TensorShape& shape)
-      : buf_(buf), valueType_(valueType), shape_(shape) {}
+  BufferArg(void* buf,
+            ValueType valueType,
+            const TensorShape& shape,
+            ArgType argType = UNSPECIFIED)
+      : buf_(buf), valueType_(valueType), shape_(shape), argType_(argType) {}
 
   BufferArg(void* buf, ValueType valueType)
       : buf_(buf), valueType_(valueType) {}
 
-  BufferArg(const Matrix& matrix)
+  BufferArg(const Matrix& matrix, ArgType argType = UNSPECIFIED)
       : buf_(
             const_cast<void*>(reinterpret_cast<const void*>(matrix.getData()))),
         valueType_(DataType<real>::value),
-        shape_(2) {
+        shape_(2),
+        argType_(argType) {
     shape_.setDim(0, matrix.getHeight());
     shape_.setDim(1, matrix.getWidth());
   }
 
-  BufferArg(const Matrix& matrix, const TensorShape& shape)
+  BufferArg(const Matrix& matrix,
+            const TensorShape& shape,
+            ArgType argType = UNSPECIFIED)
       : buf_(
             const_cast<void*>(reinterpret_cast<const void*>(matrix.getData()))),
         valueType_(DataType<real>::value),
-        shape_(shape) {
+        shape_(shape),
+        argType_(argType) {
     CHECK_EQ(matrix.getElementCnt(), shape.getElements());
   }
 
-  BufferArg(const Vector& vector)
+  BufferArg(const Vector& vector, ArgType argType = UNSPECIFIED)
       : buf_(
             const_cast<void*>(reinterpret_cast<const void*>(vector.getData()))),
         valueType_(DataType<real>::value),
-        shape_(1) {
+        shape_(1),
+        argType_(argType) {
     shape_.setDim(0, vector.getSize());
   }
 
-  BufferArg(const IVector& vector)
+  BufferArg(const IVector& vector, ArgType argType = UNSPECIFIED)
       : buf_(
             const_cast<void*>(reinterpret_cast<const void*>(vector.getData()))),
         valueType_(VALUE_TYPE_INT32),
-        shape_(1) {
+        shape_(1),
+        argType_(argType) {
     shape_.setDim(0, vector.getSize());
   }

@@ -163,8 +172,10 @@ protected:
 // if a < b then value_.buf_[a] < value_.buf_[b]
 class SequenceIdArg : public BufferArg {
 public:
-  SequenceIdArg(void* buf, const TensorShape& shape)
-      : BufferArg(buf, VALUE_TYPE_INT32, shape) {
+  SequenceIdArg(void* buf,
+                const TensorShape& shape,
+                ArgType argType = UNSPECIFIED)
+      : BufferArg(buf, VALUE_TYPE_INT32, shape, argType) {
     CHECK_EQ(shape_.ndims(), 1);
     numSeqs_ = shape_[0] - 1;
   }

@@ -187,11 +198,15 @@ public:
   SequenceArg(void* buf,
               ValueType valueType,
               const TensorShape& shape,
-              const SequenceIdArg& startPositions)
-      : BufferArg(buf, valueType, shape), startPositions_(startPositions) {}
+              const SequenceIdArg& startPositions,
+              ArgType argType = UNSPECIFIED)
+      : BufferArg(buf, valueType, shape, argType),
+        startPositions_(startPositions) {}
 
-  SequenceArg(const Matrix& matrix, const IVector& vector)
-      : BufferArg(matrix), startPositions_(vector) {}
+  SequenceArg(const Matrix& matrix,
+              const IVector& vector,
+              ArgType argType = UNSPECIFIED)
+      : BufferArg(matrix, argType), startPositions_(vector) {}
 
   ~SequenceArg() {}

@@ -214,8 +229,9 @@ public:
                   const BufferArg& col,
                   size_t nnz,
                   SparseDataFormat format,
-                  SparseDataType type)
-      : BufferArg(buf, valueType, shape),
+                  SparseDataType type,
+                  ArgType argType = UNSPECIFIED)
+      : BufferArg(buf, valueType, shape, argType),
         row_(row),
         col_(col),
         nnz_(nnz),

@@ -232,13 +248,13 @@ public:
     }
   }
 
-  SparseMatrixArg(const CpuSparseMatrix& sparse)
-      : BufferArg(sparse),
+  SparseMatrixArg(const CpuSparseMatrix& sparse, ArgType argType = UNSPECIFIED)
+      : BufferArg(sparse, argType),
         row_(reinterpret_cast<void*>(sparse.getRows()), VALUE_TYPE_INT32),
         col_(reinterpret_cast<void*>(sparse.getCols()), VALUE_TYPE_INT32) {}
 
-  SparseMatrixArg(const GpuSparseMatrix& sparse)
-      : BufferArg(sparse),
+  SparseMatrixArg(const GpuSparseMatrix& sparse, ArgType argType = UNSPECIFIED)
+      : BufferArg(sparse, argType),
         row_(reinterpret_cast<void*>(sparse.getRows()), VALUE_TYPE_INT32),
         col_(reinterpret_cast<void*>(sparse.getCols()), VALUE_TYPE_INT32) {}
...
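Every constructor above defaults argType to UNSPECIFIED, so code that builds input arguments keeps compiling unchanged; only output arguments need to state ASSIGN_TO or ADD_TO explicitly. A minimal sketch of the intended usage, assuming a hypothetical CpuMatrix `out` and its dimensions (the constructors, enum values, and getArgType() are the ones declared above):

    CpuMatrix out(outData, height, width);   // hypothetical buffer and dimensions
    BufferArg inArg(out);                    // input argument: argType_ stays UNSPECIFIED
    BufferArg outArg(out, ADD_TO);           // output argument: the result is added to the buffer
    CHECK_EQ(outArg.getArgType(), ADD_TO);   // the check the Function implementations below perform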
@@ -84,12 +84,9 @@ public:
     begin_pad_ = config.get<size_t>("begin_pad");
   }
 
-  void calc(const BufferArgs& inputs,
-            const BufferArgs& outputs,
-            const BufferArgs& inouts) override {
+  void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
     CHECK_EQ(3, inputs.size());
     CHECK_EQ(1, outputs.size());
-    CHECK_EQ(0, inouts.size());
 
     CHECK(outputs[0].data() && inputs[0].data() && inputs[2].data());
     CHECK_EQ(outputs[0].shape().ndims(), 2);

@@ -103,6 +100,7 @@ public:
     /// input and output has the same batch_size
     CHECK_EQ(inputs[0].shape()[0], outputs[0].shape()[0]);
 
+    CHECK_EQ(outputs[0].getArgType(), ADD_TO);
     auto out_mat = outputs[0].matrix<Device>();
     auto in_mat = inputs[0].matrix<Device>();
     auto w_mat = !inputs[1].data()

@@ -194,12 +192,9 @@ public:
     total_pad_ = config.get<size_t>("total_pad");
   }
 
-  void calc(const BufferArgs& inputs,
-            const BufferArgs& outputs,
-            const BufferArgs& inouts) override {
+  void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
     CHECK_EQ(3, inputs.size());
     CHECK_EQ(1, outputs.size());
-    CHECK_EQ(0, inouts.size());
 
     CHECK(outputs[0].data() && inputs[2].data());
     CHECK_EQ(outputs[0].shape().ndims(), 2);

@@ -214,6 +209,8 @@ public:
     /// dim of output = dim of input * context_length
     CHECK_EQ(outputs[0].shape()[1], inputs[0].shape()[1] * context_length_);
 
+    CHECK_EQ(outputs[0].getArgType(), ADD_TO);
+
     auto out_grad_mat = outputs[0].matrix<Device>();
     auto in_grad_mat =
         !inputs[0].data() ? typename Tensor<real, Device>::Matrix(nullptr, 0, 0)
...
@@ -112,6 +112,8 @@ void CrossMapNormalGrad<DEVICE_TYPE_CPU>(real* inputsGrad,
 }
 
 /**
+ * \brief {o_0, o_1} = calc(i_0)
+ *
  * \param inputs[0] input value.
  * \param outputs[0] output value.
  * \param outputs[1] denoms.

@@ -125,17 +127,16 @@ public:
     pow_ = config.get<real>("pow");
   }
 
-  void calc(const BufferArgs& inputs,
-            const BufferArgs& outputs,
-            const BufferArgs& inouts) override {
+  void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
     CHECK_EQ(1, inputs.size());
     CHECK_EQ(2, outputs.size());
-    CHECK_EQ(0, inouts.size());
 
     CHECK_EQ(inputs[0].shape().ndims(), 4);
     CHECK(inputs[0].shape() == outputs[0].shape());
     CHECK(inputs[0].shape() == outputs[1].shape());
 
+    CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO);
+    CHECK_EQ(outputs[1].getArgType(), ASSIGN_TO);
     size_t samples = inputs[0].shape()[0];
     size_t channels = inputs[0].shape()[1];
     size_t height = inputs[0].shape()[2];

@@ -160,6 +161,8 @@ private:
 };
 
 /**
+ * \brief {o_0} = calc(i_0, i_1, i_2, i_3)
+ *
  * \param inputs[0] input value.
  * \param inputs[1] output value.
  * \param inputs[2] output grad.

@@ -175,12 +178,9 @@ public:
     pow_ = config.get<real>("pow");
   }
 
-  void calc(const BufferArgs& inputs,
-            const BufferArgs& outputs,
-            const BufferArgs& inouts) override {
+  void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
     CHECK_EQ(4, inputs.size());
     CHECK_EQ(1, outputs.size());
-    CHECK_EQ(0, inouts.size());
 
     CHECK_EQ(inputs[0].shape().ndims(), 4);
     CHECK(inputs[0].shape() == inputs[1].shape());

@@ -188,6 +188,9 @@ public:
     CHECK(inputs[0].shape() == inputs[3].shape());
     CHECK(inputs[0].shape() == outputs[0].shape());
 
+    // TODO(hedaoyuan): need support ASSIGN_TO mode.
+    CHECK_EQ(outputs[0].getArgType(), ADD_TO);
+
     size_t samples = inputs[0].shape()[0];
     size_t channels = inputs[0].shape()[1];
     size_t height = inputs[0].shape()[2];
...
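The forward function above asserts ASSIGN_TO on both outputs, while the gradient function only supports ADD_TO for now (hence the TODO). A hedged sketch of what supporting both modes inside a calc() implementation could look like; this is not code from the commit, only an illustration of the distinction the checks enforce:

    void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
      if (outputs[0].getArgType() == ASSIGN_TO) {
        // overwrite outputs[0] with the newly computed result
      } else if (outputs[0].getArgType() == ADD_TO) {
        // accumulate the newly computed result into outputs[0]
      }
    }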
@@ -72,16 +72,18 @@ FuncConfig& FuncConfig::set<bool>(const std::string& key, bool v) {
   return *this;
 }
 
-void BufferArgs::addArg(const Matrix& arg, const TensorShape& shape) {
-  args_.push_back(std::make_shared<BufferArg>(arg, shape));
+void BufferArgs::addArg(const Matrix& arg,
+                        const TensorShape& shape,
+                        ArgType argType) {
+  args_.push_back(std::make_shared<BufferArg>(arg, shape, argType));
 }
 
-void BufferArgs::addArg(const CpuSparseMatrix& arg) {
-  args_.push_back(std::make_shared<SparseMatrixArg>(arg));
+void BufferArgs::addArg(const CpuSparseMatrix& arg, ArgType argType) {
+  args_.push_back(std::make_shared<SparseMatrixArg>(arg, argType));
 }
 
-void BufferArgs::addArg(const GpuSparseMatrix& arg) {
-  args_.push_back(std::make_shared<SparseMatrixArg>(arg));
+void BufferArgs::addArg(const GpuSparseMatrix& arg, ArgType argType) {
+  args_.push_back(std::make_shared<SparseMatrixArg>(arg, argType));
 }
 
 ClassRegistrar<FunctionBase> FunctionBase::funcRegistrar_;
...
@@ -49,7 +49,7 @@ protected:
 /**
  * Argument type for Function::calc().
  * A BufferArgs contains a set of BufferArg,
- * because Function can have multiple inputs, outputs and inouts.
+ * because Function can have multiple inputs and outputs.
  */
 class BufferArgs {
 public:

@@ -58,9 +58,11 @@ public:
 
   // add argument into BufferArgs
   // Tensor can be Matrix, Vector, IVector.
+  // For inputs, do not need argType.
+  // For outputs, the argType needs to be specified as ASSIGN_TO or ADD_TO.
   template <typename Tensor>
-  void addArg(const Tensor& arg) {
-    args_.push_back(std::make_shared<BufferArg>(arg));
+  void addArg(const Tensor& arg, ArgType argType = UNSPECIFIED) {
+    args_.push_back(std::make_shared<BufferArg>(arg, argType));
   }
 
   // Add arg into BufferArgs and reshape the arg.

@@ -68,10 +70,12 @@ public:
   // For example, arg represents an image buffer,
   // but Matrix can only represent a two-dimensional Tensor.
   // So need an extra argument to describe the shape of the image buffer.
-  void addArg(const Matrix& arg, const TensorShape& shape);
+  void addArg(const Matrix& arg,
+              const TensorShape& shape,
+              ArgType argType = UNSPECIFIED);
 
-  void addArg(const CpuSparseMatrix& arg);
-  void addArg(const GpuSparseMatrix& arg);
+  void addArg(const CpuSparseMatrix& arg, ArgType argType = UNSPECIFIED);
+  void addArg(const GpuSparseMatrix& arg, ArgType argType = UNSPECIFIED);
 
   // get argument
   const BufferArg& operator[](size_t num) const {
...
@@ -122,14 +122,13 @@ void ContextProjection::forward() {
 
   BufferArgs inputs;
   BufferArgs outputs;
-  BufferArgs inouts;
   inputs.addArg(*in_->value);
   inputs.addArg(CpuMatrix(w_ptr ? w_ptr->getData() : nullptr,
                           w_ptr ? w_ptr->getHeight() : 0,
                           input_dim));
   inputs.addArg(*in_->sequenceStartPositions->getVector(useGpu_));
-  outputs.addArg(*out_->value);
-  forward_[0]->calc(inputs, outputs, inouts);
+  outputs.addArg(*out_->value, ADD_TO);
+  forward_[0]->calc(inputs, outputs);
 
   if (state_ && config_.context_start() < 0) {
     CHECK_EQ(1, in_->getNumSequences());

@@ -166,15 +165,14 @@ void ContextProjection::backward(const UpdateCallback& callback) {
 
   BufferArgs inputs;
   BufferArgs outputs;
-  BufferArgs inouts;
   inputs.addArg(CpuMatrix(
       in_->grad ? in_->grad->getData() : nullptr, batch_size, input_dim));
   inputs.addArg(CpuMatrix(w_ptr ? w_ptr->getData() : nullptr,
                           w_ptr ? w_ptr->getHeight() : 0,
                           input_dim));
   inputs.addArg(*in_->sequenceStartPositions->getVector(useGpu_));
-  outputs.addArg(*out_->grad);
-  backward_[0]->calc(inputs, outputs, inouts);
+  outputs.addArg(*out_->grad, ADD_TO);
+  backward_[0]->calc(inputs, outputs);
 
   if (config_.trainable_padding()) {
     weight_->getParameterPtr()->incUpdate(callback);
...
@@ -59,7 +59,6 @@ bool CMRProjectionNormLayer::init(const LayerMap& layerMap,
 
 void CMRProjectionNormLayer::forward(PassType passType) {
   Layer::forward(passType);
-
   /* malloc memory for the output_ if necessary */
   /* note: one sample correspond to one row */
   MatrixPtr input = inputLayers_[0]->getOutputValue();

@@ -67,42 +66,36 @@ void CMRProjectionNormLayer::forward(PassType passType) {
   int size = getSize();
   resetOutput(batchSize, size);
 
-  MatrixPtr outV = getOutputValue();
-
   Matrix::resizeOrCreate(denoms_, batchSize, size, /* trans */ false, useGpu_);
 
   shape_ = TensorShape({batchSize, channels_, imgSizeH_, imgSizeW_});
 
-  // prepare forward arguments
   BufferArgs inputs;
   BufferArgs outputs;
-  BufferArgs inouts;
-  inputs.addArg(*input, shape_);
-  outputs.addArg(*outV, shape_);
-  outputs.addArg(*denoms_, shape_);
+  inputs.addArg(*getInputValue(0), shape_);
+  outputs.addArg(*getOutputValue(), shape_, ASSIGN_TO);
+  outputs.addArg(*denoms_, shape_, ASSIGN_TO);
 
-  forward_[0]->calc(inputs, outputs, inouts);
+  forward_[0]->calc(inputs, outputs);
 }
 
 void CMRProjectionNormLayer::backward(const UpdateCallback& callback) {
   (void)callback;
 
-  if (NULL == inputLayers_[0]->getOutputGrad()) {
+  if (NULL == getInputGrad(0)) {
     return;
   }
 
-  /* Do derivation */
-  MatrixPtr preOutGrad = inputLayers_[0]->getOutputGrad();
-  MatrixPtr localGrad = getOutputGrad();
-  MatrixPtr localOutV = getOutputValue();
-  MatrixPtr preOutV = inputLayers_[0]->getOutputValue();
-
-  // prepare backward arguments
   BufferArgs inputs;
   BufferArgs outputs;
-  BufferArgs inouts;
-  inputs.addArg(*preOutV, shape_);
-  inputs.addArg(*localOutV, shape_);
-  inputs.addArg(*localGrad, shape_);
+  inputs.addArg(*getInputValue(0), shape_);
+  inputs.addArg(*getOutputValue(), shape_);
+  inputs.addArg(*getOutputGrad(), shape_);
   inputs.addArg(*denoms_, shape_);
-  outputs.addArg(*preOutGrad, shape_);
-  backward_[0]->calc(inputs, outputs, inouts);
+  outputs.addArg(*getInputGrad(0), shape_, ADD_TO);
+
+  backward_[0]->calc(inputs, outputs);
 }
 
 } // namespace paddle