From e9794214cbca438b1b467d614c6398ec09ab1d0b Mon Sep 17 00:00:00 2001 From: xutianbing Date: Thu, 12 Jan 2017 13:26:10 -0800 Subject: [PATCH] Address further comments. --- paddle/function/BufferArg.cpp | 12 +- paddle/function/BufferArg.h | 30 +++- paddle/function/ContextProjectionOp.cpp | 169 +++++++++++--------- paddle/function/ContextProjectionOpTest.cpp | 4 +- paddle/gserver/layers/ContextProjection.cpp | 1 + 5 files changed, 126 insertions(+), 90 deletions(-) diff --git a/paddle/function/BufferArg.cpp b/paddle/function/BufferArg.cpp index fde48a73b6..5d595deb12 100644 --- a/paddle/function/BufferArg.cpp +++ b/paddle/function/BufferArg.cpp @@ -20,23 +20,27 @@ limitations under the License. */ namespace paddle { const SequenceArg& BufferArg::sequence() const { - // CHECK_EQ(bufferType_, TENSOR_SEQUENCE_DATA); + CHECK_EQ(bufferType_, TENSOR_SEQUENCE_DATA); return dynamic_cast(*this); } const SparseMatrixArg& BufferArg::sparse() const { - // CHECK_EQ(bufferType_, TENSOR_SPARSE); + CHECK_EQ(bufferType_, TENSOR_SPARSE); return dynamic_cast(*this); } SparseMatrixArg::SparseMatrixArg(const CpuSparseMatrix& sparse, ArgType argType) : BufferArg(sparse, argType), row_(reinterpret_cast(sparse.getRows()), VALUE_TYPE_INT32), - col_(reinterpret_cast(sparse.getCols()), VALUE_TYPE_INT32) {} + col_(reinterpret_cast(sparse.getCols()), VALUE_TYPE_INT32) { + bufferType_ = TENSOR_SPARSE; +} SparseMatrixArg::SparseMatrixArg(const GpuSparseMatrix& sparse, ArgType argType) : BufferArg(sparse, argType), row_(reinterpret_cast(sparse.getRows()), VALUE_TYPE_INT32), - col_(reinterpret_cast(sparse.getCols()), VALUE_TYPE_INT32) {} + col_(reinterpret_cast(sparse.getCols()), VALUE_TYPE_INT32) { + bufferType_ = TENSOR_SPARSE; +} } // namespace paddle diff --git a/paddle/function/BufferArg.h b/paddle/function/BufferArg.h index f3a4350e12..440a924a7a 100644 --- a/paddle/function/BufferArg.h +++ b/paddle/function/BufferArg.h @@ -23,10 +23,11 @@ limitations under the License. */ namespace paddle { enum BufferType { - TENSOR_NORMAL = 0, - TENSOR_SEQUENCE_ID = 1, - TENSOR_SEQUENCE_DATA = 2, - TENSOR_SPARSE = 3 + TENSOR_UNKNOWN = 0, + TENSOR_NORMAL = 1, + TENSOR_SEQUENCE_ID = 2, + TENSOR_SEQUENCE_DATA = 3, + TENSOR_SPARSE = 4 }; enum SparseDataType { @@ -86,6 +87,7 @@ public: valueType_(DataType::value), shape_(2), argType_(argType) { + bufferType_ = TENSOR_NORMAL; shape_.setDim(0, matrix.getHeight()); shape_.setDim(1, matrix.getWidth()); } @@ -98,6 +100,7 @@ public: valueType_(DataType::value), shape_(shape), argType_(argType) { + bufferType_ = TENSOR_NORMAL; CHECK_EQ(matrix.getElementCnt(), shape.getElements()); } @@ -107,6 +110,7 @@ public: valueType_(DataType::value), shape_(1), argType_(argType) { + bufferType_ = TENSOR_NORMAL; shape_.setDim(0, vector.getSize()); } @@ -116,6 +120,7 @@ public: valueType_(VALUE_TYPE_INT32), shape_(1), argType_(argType) { + bufferType_ = TENSOR_NORMAL; shape_.setDim(0, vector.getSize()); } @@ -150,6 +155,8 @@ public: ValueType valueType() const { return valueType_; } BufferType bufferType() const { return bufferType_; } const TensorShape& shape() const { return shape_; } + bool isSparse() const { return (TENSOR_SPARSE == bufferType_); } + bool isSequenceArg() const { return TENSOR_SEQUENCE_DATA == bufferType_; } const SequenceArg& sequence() const; const SparseMatrixArg& sparse() const; @@ -158,8 +165,8 @@ protected: void* buf_; ValueType valueType_; TensorShape shape_; - BufferType bufferType_; - ArgType argType_ = UNSPECIFIED; + BufferType bufferType_{TENSOR_UNKNOWN}; + ArgType argType_{UNSPECIFIED}; // leading dimensions. The size is dims_.size() // Dims lds_; }; @@ -174,11 +181,13 @@ public: const TensorShape& shape, ArgType argType = UNSPECIFIED) : BufferArg(buf, VALUE_TYPE_INT32, shape, argType) { + bufferType_ = TENSOR_SEQUENCE_ID; CHECK_EQ(shape_.ndims(), (size_t)1); numSeqs_ = shape_[0] - 1; } SequenceIdArg(const IVector& vector) : BufferArg(vector) { + bufferType_ = TENSOR_SEQUENCE_ID; numSeqs_ = shape_[0] - 1; } @@ -199,12 +208,16 @@ public: const SequenceIdArg& startPositions, ArgType argType = UNSPECIFIED) : BufferArg(buf, valueType, shape, argType), - startPositions_(startPositions) {} + startPositions_(startPositions) { + bufferType_ = TENSOR_SEQUENCE_DATA; + } SequenceArg(const Matrix& matrix, const IVector& vector, ArgType argType = UNSPECIFIED) - : BufferArg(matrix, argType), startPositions_(vector) {} + : BufferArg(matrix, argType), startPositions_(vector) { + bufferType_ = TENSOR_SEQUENCE_DATA; + } ~SequenceArg() {} @@ -236,6 +249,7 @@ public: nnz_(nnz), format_(format), type_(type) { + bufferType_ = TENSOR_SPARSE; CHECK((valueType == VALUE_TYPE_FLOAT) || (valueType == VALUE_TYPE_DOUBLE)); CHECK_EQ(shape_.ndims(), (size_t)2); CHECK_EQ(row_.shape().ndims(), (size_t)1); diff --git a/paddle/function/ContextProjectionOp.cpp b/paddle/function/ContextProjectionOp.cpp index ec697a381f..2ef53cd6d9 100644 --- a/paddle/function/ContextProjectionOp.cpp +++ b/paddle/function/ContextProjectionOp.cpp @@ -74,9 +74,9 @@ void ContextProjectionForward(CpuMatrix& out_mat, /** * Paddle Function for Context Projection Forward. - * Calculate the output sequence after context projection. + * Calculate the output layer value sequence after context projection. * - * What is Context Projection? + * What is Context Projection for a sequence? * For example, assumed input (x) has 4 words and the dimension of each word * representation is 2. If we use zero to pad instead of learned weight to pad, * and the context_lenth is 3, the output (y) is: @@ -92,12 +92,11 @@ void ContextProjectionForward(CpuMatrix& out_mat, * c1, c2, d1, d2, 0, 0] * @endcode * - * \param outputs[0].matrix output value, n * (d * l) - * \param outputs[0].vector input sequence, n * 1 - * \param inputs[0].matrix input value, n * d - * \param inputs[0].vector input sequence, n * 1 - * \param inputs[1].matrix input weight, pad * d - * \param inputs[1].vector input sequence, n * 1 + * \param outputs[0].matrix output layer value, n * (d * l) + * \param outputs[0].vector start position sequence, n * 1 + * \param inputs[0].matrix input layer value, n * d + * \param inputs[0].vector start position sequence, n * 1 + * \param inputs[1].matrix input layer weight, pad * d */ template class ContextProjectionForwardFunc : public FunctionBase { @@ -111,37 +110,35 @@ public: void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { CHECK(1 == inputs.size() || 2 == inputs.size()); CHECK_EQ((size_t)1, outputs.size()); - + CHECK(inputs[0].isSequenceArg() && outputs[0].isSequenceArg()) + << "SequenceArg required here"; const auto val_seqs = dynamic_cast(inputs[0]); - const auto w_seqs = inputs.size() <= 1 - ? nullptr - : dynamic_cast(&inputs[1]); - auto out_seqs = dynamic_cast(outputs[0]); + auto out_seq = dynamic_cast(outputs[0]); - CHECK(out_seqs.data() && val_seqs.data() && + CHECK(out_seq.data() && val_seqs.data() && val_seqs.getSequenceIds().data()); - CHECK_EQ(out_seqs.shape().ndims(), (size_t)2); + CHECK_EQ(out_seq.shape().ndims(), (size_t)2); CHECK_EQ(val_seqs.shape().ndims(), (size_t)2); CHECK_EQ(val_seqs.getSequenceIds().shape().ndims(), (size_t)1); - if (w_seqs) { - CHECK_EQ(w_seqs->shape().ndims(), (size_t)2); - CHECK_EQ(w_seqs->getSequenceIds().shape().ndims(), (size_t)1); + if (2 == inputs.size()) { + CHECK_EQ(inputs[1].shape().ndims(), (size_t)2); } /// dim of output = dim of input * context_length - CHECK_EQ(out_seqs.shape()[1], val_seqs.shape()[1] * context_length_); + CHECK_EQ(out_seq.shape()[1], val_seqs.shape()[1] * context_length_); /// input and output has the same batch_size - CHECK_EQ(val_seqs.shape()[0], out_seqs.shape()[0]); + CHECK_EQ(val_seqs.shape()[0], out_seq.shape()[0]); /// dim of input == dim of weight - if (w_seqs) { - CHECK_EQ(val_seqs.shape()[1], w_seqs->shape()[1]); + if (2 == inputs.size()) { + CHECK_EQ(val_seqs.shape()[1], inputs[1].shape()[1]); } - CHECK_EQ(out_seqs.getArgType(), ADD_TO); - auto out_mat = out_seqs.matrix(); + CHECK_EQ(out_seq.getArgType(), ADD_TO); + auto out_mat = out_seq.matrix(); const auto in_mat = val_seqs.matrix(); const auto w_mat = - w_seqs ? w_seqs->matrix() - : typename Tensor::Matrix(nullptr, 0, 0); + (2 == inputs.size()) + ? inputs[1].matrix() + : typename Tensor::Matrix(nullptr, 0, 0); const auto seq_vec = val_seqs.getSequenceIds().vector(); ContextProjectionForward(out_mat, in_mat, @@ -221,10 +218,11 @@ void ContextProjectionBackward(const CpuMatrix& out_grad_mat, * Context Projection Backward Function. * Update the weight gradient and input layer gradient with backprop * - * \param inputs[0].seq input sequence. - * \param inputs[0].matrix output layer grad. - * \param outputs[0] input layer grad. - * \param outputs[1] weight grad. + * \param inputs[0].matrix output layer grad, n * (d * l) + * \param inputs[0].vector start position sequence, n * 1 + * \param outputs[0].matrix input layer grad, n * d + * \param outputs[0].vector start position sequence, n * 1 + * \param outputs[1] weight grad, pad * d */ template class ContextProjectionBackwardFunc : public FunctionBase { @@ -240,30 +238,31 @@ public: void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { CHECK_EQ((size_t)1, inputs.size()); CHECK_EQ((size_t)2, outputs.size()); - - const auto seq_arg = dynamic_cast(inputs[0]); - CHECK(seq_arg.data() && inputs[0].data()); - CHECK_EQ(seq_arg.shape().ndims(), (size_t)2); - CHECK_EQ(seq_arg.getSequenceIds().shape().ndims(), (size_t)1); - CHECK_EQ(outputs[0].shape().ndims(), (size_t)2); + CHECK(inputs[0].isSequenceArg() && outputs[0].isSequenceArg()) + << "SequenceArg required here"; + const auto in_seq = dynamic_cast(inputs[0]); + auto out_seq = dynamic_cast(outputs[0]); + CHECK(in_seq.data() && in_seq.getSequenceIds().data()); + CHECK_EQ(in_seq.shape().ndims(), (size_t)2); + CHECK_EQ(in_seq.getSequenceIds().shape().ndims(), (size_t)1); + CHECK_EQ(out_seq.shape().ndims(), (size_t)2); + CHECK_EQ(out_seq.getSequenceIds().shape().ndims(), (size_t)1); CHECK_EQ(outputs[1].shape().ndims(), (size_t)2); /// dim of input grad == dim of weight - CHECK_EQ(outputs[0].shape()[1], outputs[1].shape()[1]); + CHECK_EQ(out_seq.shape()[1], outputs[1].shape()[1]); /// input and output grad has the same batch_size - CHECK_EQ(outputs[0].shape()[0], seq_arg.shape()[0]); - /// dim of output val = dim of input grad * context_length - CHECK_EQ(seq_arg.shape()[1], outputs[0].shape()[1] * context_length_); - - CHECK_EQ(outputs[0].getArgType(), ADD_TO); + CHECK_EQ(out_seq.shape()[0], in_seq.shape()[0]); + /// dim of output grad = dim of input grad * context_length + CHECK_EQ(in_seq.shape()[1], out_seq.shape()[1] * context_length_); + CHECK_EQ(out_seq.getArgType(), ADD_TO); CHECK_EQ(outputs[1].getArgType(), ADD_TO); - const auto seq_vec = seq_arg.getSequenceIds().vector(); - const auto out_grad_mat = seq_arg.matrix(); + const auto seq_vec = in_seq.getSequenceIds().vector(); + const auto out_grad_mat = in_seq.matrix(); auto in_grad_mat = - !outputs[0].data() - ? typename Tensor::Matrix(nullptr, 0, 0) - : outputs[0].matrix(); + !out_seq.data() ? typename Tensor::Matrix(nullptr, 0, 0) + : out_seq.matrix(); auto w_grad_mat = !outputs[1].data() ? typename Tensor::Matrix(nullptr, 0, 0) : outputs[1].matrix(); @@ -287,9 +286,15 @@ private: }; /** - * \param inputs[0].matrix input grad, n*d - * \param inputs[0].vector input sequence, n*1 - * \param outputs[0] output grad, n*(d*l) + * Context Projection Backward Data Function + * Update input layer grad + * input: sequence of output layer grad + * output: sequence of input layer grad + * + * \param outputs[0].matrix input layer grad, n * d + * \param outputs[0].vector start position sequence, n * 1 + * \param inputs[0].matrix output layer grad, n * (d * l) + * \param inputs[0].vector start positon sequence, n * 1 */ template class ContextProjectionBackwardDataFunc : public FunctionBase { @@ -302,19 +307,24 @@ public: void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { CHECK_EQ(1, static_cast(inputs.size())); CHECK_EQ(1, static_cast(outputs.size())); - const auto in_seqs = dynamic_cast(inputs[0]); - CHECK(in_seqs.data() && outputs[0].data() && - in_seqs.getSequenceIds().data()); - CHECK_EQ(static_cast(outputs[0].shape().ndims()), 2); - CHECK_EQ(static_cast(in_seqs.shape().ndims()), 2); - CHECK_EQ(static_cast(in_seqs.getSequenceIds().shape().ndims()), 1); - CHECK_EQ(outputs[0].shape().ndims(), - in_seqs.shape().ndims() * context_length_); + CHECK(inputs[0].isSequenceArg() && outputs[0].isSequenceArg()) + << "SequenceArg required here"; + const auto in_seq = dynamic_cast(inputs[0]); + const auto out_seq = dynamic_cast(outputs[0]); + + CHECK(in_seq.data() && out_seq.data() && in_seq.getSequenceIds().data()); + CHECK_EQ(static_cast(out_seq.shape().ndims()), 2); + CHECK_EQ(static_cast(in_seq.shape().ndims()), 2); + CHECK_EQ(static_cast(in_seq.getSequenceIds().shape().ndims()), 1); + /// output layer grad dim == input layer grad dim * context_length_ + CHECK_EQ(in_seq.shape().ndims(), out_seq.shape().ndims() * context_length_); /// input and output has the same batch_size - CHECK_EQ(in_seqs.shape()[0], outputs[0].shape()[0]); - const auto out_grad_mat = outputs[0].matrix(); - auto in_grad_mat = in_seqs.matrix(); - const auto seq_vec = in_seqs.getSequenceIds().vector(); + CHECK_EQ(in_seq.shape()[0], out_seq.shape()[0]); + CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO); + + const auto out_grad_mat = in_seq.matrix(); + const auto seq_vec = in_seq.getSequenceIds().vector(); + auto in_grad_mat = out_seq.matrix(); ContextProjectionBackwardData( out_grad_mat, in_grad_mat, seq_vec, context_length_, context_start_); @@ -326,9 +336,14 @@ private: }; /** - * \param inputs[0].matrix weight grad, pad * d - * \param inputs[0].vecotr input sequence, n * 1 - * \param outputs[0] output grad, n * (d * l) + * Context Projection Backward Weight Function + * Update weight grad by backprop + * input: sequence of output layer grad + * output: weight grad + * + * \param outputs[0] weight grad, pad * d + * \param inputs[0].matrix output layer grad, n * (d * l) + * \param inputs[0].vecotr start positon sequence, n * 1 */ template class ContextProjectionBackwardWeightFunc : public FunctionBase { @@ -343,18 +358,20 @@ public: void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { CHECK_EQ(1, static_cast(inputs.size())); CHECK_EQ(1, static_cast(outputs.size())); - - const auto in_seqs = dynamic_cast(inputs[0]); - CHECK(in_seqs.data() && in_seqs.getSequenceIds().data() && - outputs[0].data()); + CHECK(inputs[0].isSequenceArg()) << "SequenceArg required here"; + const auto in_seq = dynamic_cast(inputs[0]); + CHECK(in_seq.data() && in_seq.getSequenceIds().data() && outputs[0].data()); CHECK_EQ(static_cast(outputs[0].shape().ndims()), 2); - CHECK_EQ(static_cast(in_seqs.shape().ndims()), 2); - CHECK_EQ(static_cast(in_seqs.getSequenceIds().shape().ndims()), 1); - CHECK_EQ(in_seqs.shape()[0], outputs[0].shape()[0]); - CHECK_EQ(outputs[0].shape()[1], in_seqs.shape()[1] * context_length_); - const auto out_grad_mat = outputs[0].matrix(); - auto w_grad_mat = inputs[0].matrix(); - const auto seq_vec = in_seqs.getSequenceIds().vector(); + CHECK_EQ(static_cast(in_seq.shape().ndims()), 2); + CHECK_EQ(static_cast(in_seq.getSequenceIds().shape().ndims()), 1); + CHECK_EQ(in_seq.shape()[0], outputs[0].shape()[0]); + /// output layer grad dim == weight dim * context_length_ + CHECK_EQ(in_seq.shape()[1], outputs[0].shape()[1] * context_length_); + CHECK_EQ(outputs[0].getArgType(), ADD_TO); + + const auto seq_vec = in_seq.getSequenceIds().vector(); + const auto out_grad_mat = in_seq.matrix(); + auto w_grad_mat = outputs[0].matrix(); ContextProjectionBackwardWeight(out_grad_mat, w_grad_mat, seq_vec, diff --git a/paddle/function/ContextProjectionOpTest.cpp b/paddle/function/ContextProjectionOpTest.cpp index bd0c06c5f6..c9db2ff800 100644 --- a/paddle/function/ContextProjectionOpTest.cpp +++ b/paddle/function/ContextProjectionOpTest.cpp @@ -123,7 +123,7 @@ void testMatrixProjectionBackward(int context_start, BufferArgs cpu_inputs; BufferArgs cpu_outputs; cpu_inputs.addArg(cpu_out_grad, *cpu_seq); - cpu_outputs.addArg(cpu_in_grad, ADD_TO); + cpu_outputs.addArg(cpu_in_grad, *cpu_seq, ADD_TO); cpu_outputs.addArg( cpu_w_grad ? *cpu_w_grad : CpuMatrix(nullptr, 0, input_dim), ADD_TO); @@ -132,7 +132,7 @@ void testMatrixProjectionBackward(int context_start, BufferArgs gpu_inputs; BufferArgs gpu_outputs; gpu_inputs.addArg(gpu_out_grad, *gpu_seq); - gpu_outputs.addArg(gpu_in_grad, ADD_TO); + gpu_outputs.addArg(gpu_in_grad, *gpu_seq, ADD_TO); gpu_outputs.addArg( gpu_w_grad ? *gpu_w_grad : GpuMatrix(nullptr, 0, input_dim), ADD_TO); diff --git a/paddle/gserver/layers/ContextProjection.cpp b/paddle/gserver/layers/ContextProjection.cpp index edcef17ad4..d7042af1c2 100644 --- a/paddle/gserver/layers/ContextProjection.cpp +++ b/paddle/gserver/layers/ContextProjection.cpp @@ -169,6 +169,7 @@ void ContextProjection::backward(const UpdateCallback& callback) { outputs.addArg( CpuMatrix( in_->grad ? in_->grad->getData() : nullptr, batch_size, input_dim), + *in_->sequenceStartPositions->getVector(useGpu_), ADD_TO); outputs.addArg(CpuMatrix(w_ptr ? w_ptr->getData() : nullptr, w_ptr ? w_ptr->getHeight() : 0, -- GitLab