提交 e9794214 编写于 作者: X xutianbing

Address further comments.

上级 8560ce69
......@@ -20,23 +20,27 @@ limitations under the License. */
namespace paddle {
const SequenceArg& BufferArg::sequence() const {
// CHECK_EQ(bufferType_, TENSOR_SEQUENCE_DATA);
CHECK_EQ(bufferType_, TENSOR_SEQUENCE_DATA);
return dynamic_cast<const SequenceArg&>(*this);
}
const SparseMatrixArg& BufferArg::sparse() const {
// CHECK_EQ(bufferType_, TENSOR_SPARSE);
CHECK_EQ(bufferType_, TENSOR_SPARSE);
return dynamic_cast<const SparseMatrixArg&>(*this);
}
SparseMatrixArg::SparseMatrixArg(const CpuSparseMatrix& sparse, ArgType argType)
: BufferArg(sparse, argType),
row_(reinterpret_cast<void*>(sparse.getRows()), VALUE_TYPE_INT32),
col_(reinterpret_cast<void*>(sparse.getCols()), VALUE_TYPE_INT32) {}
col_(reinterpret_cast<void*>(sparse.getCols()), VALUE_TYPE_INT32) {
bufferType_ = TENSOR_SPARSE;
}
SparseMatrixArg::SparseMatrixArg(const GpuSparseMatrix& sparse, ArgType argType)
: BufferArg(sparse, argType),
row_(reinterpret_cast<void*>(sparse.getRows()), VALUE_TYPE_INT32),
col_(reinterpret_cast<void*>(sparse.getCols()), VALUE_TYPE_INT32) {}
col_(reinterpret_cast<void*>(sparse.getCols()), VALUE_TYPE_INT32) {
bufferType_ = TENSOR_SPARSE;
}
} // namespace paddle
......@@ -23,10 +23,11 @@ limitations under the License. */
namespace paddle {
enum BufferType {
TENSOR_NORMAL = 0,
TENSOR_SEQUENCE_ID = 1,
TENSOR_SEQUENCE_DATA = 2,
TENSOR_SPARSE = 3
TENSOR_UNKNOWN = 0,
TENSOR_NORMAL = 1,
TENSOR_SEQUENCE_ID = 2,
TENSOR_SEQUENCE_DATA = 3,
TENSOR_SPARSE = 4
};
enum SparseDataType {
......@@ -86,6 +87,7 @@ public:
valueType_(DataType<real>::value),
shape_(2),
argType_(argType) {
bufferType_ = TENSOR_NORMAL;
shape_.setDim(0, matrix.getHeight());
shape_.setDim(1, matrix.getWidth());
}
......@@ -98,6 +100,7 @@ public:
valueType_(DataType<real>::value),
shape_(shape),
argType_(argType) {
bufferType_ = TENSOR_NORMAL;
CHECK_EQ(matrix.getElementCnt(), shape.getElements());
}
......@@ -107,6 +110,7 @@ public:
valueType_(DataType<real>::value),
shape_(1),
argType_(argType) {
bufferType_ = TENSOR_NORMAL;
shape_.setDim(0, vector.getSize());
}
......@@ -116,6 +120,7 @@ public:
valueType_(VALUE_TYPE_INT32),
shape_(1),
argType_(argType) {
bufferType_ = TENSOR_NORMAL;
shape_.setDim(0, vector.getSize());
}
......@@ -150,6 +155,8 @@ public:
ValueType valueType() const { return valueType_; }
BufferType bufferType() const { return bufferType_; }
const TensorShape& shape() const { return shape_; }
bool isSparse() const { return (TENSOR_SPARSE == bufferType_); }
bool isSequenceArg() const { return TENSOR_SEQUENCE_DATA == bufferType_; }
const SequenceArg& sequence() const;
const SparseMatrixArg& sparse() const;
......@@ -158,8 +165,8 @@ protected:
void* buf_;
ValueType valueType_;
TensorShape shape_;
BufferType bufferType_;
ArgType argType_ = UNSPECIFIED;
BufferType bufferType_{TENSOR_UNKNOWN};
ArgType argType_{UNSPECIFIED};
// leading dimensions. The size is dims_.size()
// Dims lds_;
};
......@@ -174,11 +181,13 @@ public:
const TensorShape& shape,
ArgType argType = UNSPECIFIED)
: BufferArg(buf, VALUE_TYPE_INT32, shape, argType) {
bufferType_ = TENSOR_SEQUENCE_ID;
CHECK_EQ(shape_.ndims(), (size_t)1);
numSeqs_ = shape_[0] - 1;
}
SequenceIdArg(const IVector& vector) : BufferArg(vector) {
bufferType_ = TENSOR_SEQUENCE_ID;
numSeqs_ = shape_[0] - 1;
}
......@@ -199,12 +208,16 @@ public:
const SequenceIdArg& startPositions,
ArgType argType = UNSPECIFIED)
: BufferArg(buf, valueType, shape, argType),
startPositions_(startPositions) {}
startPositions_(startPositions) {
bufferType_ = TENSOR_SEQUENCE_DATA;
}
SequenceArg(const Matrix& matrix,
const IVector& vector,
ArgType argType = UNSPECIFIED)
: BufferArg(matrix, argType), startPositions_(vector) {}
: BufferArg(matrix, argType), startPositions_(vector) {
bufferType_ = TENSOR_SEQUENCE_DATA;
}
~SequenceArg() {}
......@@ -236,6 +249,7 @@ public:
nnz_(nnz),
format_(format),
type_(type) {
bufferType_ = TENSOR_SPARSE;
CHECK((valueType == VALUE_TYPE_FLOAT) || (valueType == VALUE_TYPE_DOUBLE));
CHECK_EQ(shape_.ndims(), (size_t)2);
CHECK_EQ(row_.shape().ndims(), (size_t)1);
......
......@@ -74,9 +74,9 @@ void ContextProjectionForward<DEVICE_TYPE_CPU>(CpuMatrix& out_mat,
/**
* Paddle Function for Context Projection Forward.
* Calculate the output sequence after context projection.
* Calculate the output layer value sequence after context projection.
*
* What is Context Projection?
* What is Context Projection for a sequence?
* For example, assumed input (x) has 4 words and the dimension of each word
* representation is 2. If we use zero to pad instead of learned weight to pad,
* and the context_lenth is 3, the output (y) is:
......@@ -92,12 +92,11 @@ void ContextProjectionForward<DEVICE_TYPE_CPU>(CpuMatrix& out_mat,
* c1, c2, d1, d2, 0, 0]
* @endcode
*
* \param outputs[0].matrix output value, n * (d * l)
* \param outputs[0].vector input sequence, n * 1
* \param inputs[0].matrix input value, n * d
* \param inputs[0].vector input sequence, n * 1
* \param inputs[1].matrix input weight, pad * d
* \param inputs[1].vector input sequence, n * 1
* \param outputs[0].matrix output layer value, n * (d * l)
* \param outputs[0].vector start position sequence, n * 1
* \param inputs[0].matrix input layer value, n * d
* \param inputs[0].vector start position sequence, n * 1
* \param inputs[1].matrix input layer weight, pad * d
*/
template <DeviceType Device>
class ContextProjectionForwardFunc : public FunctionBase {
......@@ -111,37 +110,35 @@ public:
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK(1 == inputs.size() || 2 == inputs.size());
CHECK_EQ((size_t)1, outputs.size());
CHECK(inputs[0].isSequenceArg() && outputs[0].isSequenceArg())
<< "SequenceArg required here";
const auto val_seqs = dynamic_cast<const SequenceArg&>(inputs[0]);
const auto w_seqs = inputs.size() <= 1
? nullptr
: dynamic_cast<const SequenceArg*>(&inputs[1]);
auto out_seqs = dynamic_cast<const SequenceArg&>(outputs[0]);
auto out_seq = dynamic_cast<const SequenceArg&>(outputs[0]);
CHECK(out_seqs.data() && val_seqs.data() &&
CHECK(out_seq.data() && val_seqs.data() &&
val_seqs.getSequenceIds().data());
CHECK_EQ(out_seqs.shape().ndims(), (size_t)2);
CHECK_EQ(out_seq.shape().ndims(), (size_t)2);
CHECK_EQ(val_seqs.shape().ndims(), (size_t)2);
CHECK_EQ(val_seqs.getSequenceIds().shape().ndims(), (size_t)1);
if (w_seqs) {
CHECK_EQ(w_seqs->shape().ndims(), (size_t)2);
CHECK_EQ(w_seqs->getSequenceIds().shape().ndims(), (size_t)1);
if (2 == inputs.size()) {
CHECK_EQ(inputs[1].shape().ndims(), (size_t)2);
}
/// dim of output = dim of input * context_length
CHECK_EQ(out_seqs.shape()[1], val_seqs.shape()[1] * context_length_);
CHECK_EQ(out_seq.shape()[1], val_seqs.shape()[1] * context_length_);
/// input and output has the same batch_size
CHECK_EQ(val_seqs.shape()[0], out_seqs.shape()[0]);
CHECK_EQ(val_seqs.shape()[0], out_seq.shape()[0]);
/// dim of input == dim of weight
if (w_seqs) {
CHECK_EQ(val_seqs.shape()[1], w_seqs->shape()[1]);
if (2 == inputs.size()) {
CHECK_EQ(val_seqs.shape()[1], inputs[1].shape()[1]);
}
CHECK_EQ(out_seqs.getArgType(), ADD_TO);
auto out_mat = out_seqs.matrix<Device>();
CHECK_EQ(out_seq.getArgType(), ADD_TO);
auto out_mat = out_seq.matrix<Device>();
const auto in_mat = val_seqs.matrix<Device>();
const auto w_mat =
w_seqs ? w_seqs->matrix<Device>()
: typename Tensor<real, Device>::Matrix(nullptr, 0, 0);
(2 == inputs.size())
? inputs[1].matrix<Device>()
: typename Tensor<real, Device>::Matrix(nullptr, 0, 0);
const auto seq_vec = val_seqs.getSequenceIds().vector<int, Device>();
ContextProjectionForward<Device>(out_mat,
in_mat,
......@@ -221,10 +218,11 @@ void ContextProjectionBackward<DEVICE_TYPE_CPU>(const CpuMatrix& out_grad_mat,
* Context Projection Backward Function.
* Update the weight gradient and input layer gradient with backprop
*
* \param inputs[0].seq input sequence.
* \param inputs[0].matrix output layer grad.
* \param outputs[0] input layer grad.
* \param outputs[1] weight grad.
* \param inputs[0].matrix output layer grad, n * (d * l)
* \param inputs[0].vector start position sequence, n * 1
* \param outputs[0].matrix input layer grad, n * d
* \param outputs[0].vector start position sequence, n * 1
* \param outputs[1] weight grad, pad * d
*/
template <DeviceType Device>
class ContextProjectionBackwardFunc : public FunctionBase {
......@@ -240,30 +238,31 @@ public:
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ((size_t)1, inputs.size());
CHECK_EQ((size_t)2, outputs.size());
const auto seq_arg = dynamic_cast<const SequenceArg&>(inputs[0]);
CHECK(seq_arg.data() && inputs[0].data());
CHECK_EQ(seq_arg.shape().ndims(), (size_t)2);
CHECK_EQ(seq_arg.getSequenceIds().shape().ndims(), (size_t)1);
CHECK_EQ(outputs[0].shape().ndims(), (size_t)2);
CHECK(inputs[0].isSequenceArg() && outputs[0].isSequenceArg())
<< "SequenceArg required here";
const auto in_seq = dynamic_cast<const SequenceArg&>(inputs[0]);
auto out_seq = dynamic_cast<const SequenceArg&>(outputs[0]);
CHECK(in_seq.data() && in_seq.getSequenceIds().data());
CHECK_EQ(in_seq.shape().ndims(), (size_t)2);
CHECK_EQ(in_seq.getSequenceIds().shape().ndims(), (size_t)1);
CHECK_EQ(out_seq.shape().ndims(), (size_t)2);
CHECK_EQ(out_seq.getSequenceIds().shape().ndims(), (size_t)1);
CHECK_EQ(outputs[1].shape().ndims(), (size_t)2);
/// dim of input grad == dim of weight
CHECK_EQ(outputs[0].shape()[1], outputs[1].shape()[1]);
CHECK_EQ(out_seq.shape()[1], outputs[1].shape()[1]);
/// input and output grad has the same batch_size
CHECK_EQ(outputs[0].shape()[0], seq_arg.shape()[0]);
/// dim of output val = dim of input grad * context_length
CHECK_EQ(seq_arg.shape()[1], outputs[0].shape()[1] * context_length_);
CHECK_EQ(outputs[0].getArgType(), ADD_TO);
CHECK_EQ(out_seq.shape()[0], in_seq.shape()[0]);
/// dim of output grad = dim of input grad * context_length
CHECK_EQ(in_seq.shape()[1], out_seq.shape()[1] * context_length_);
CHECK_EQ(out_seq.getArgType(), ADD_TO);
CHECK_EQ(outputs[1].getArgType(), ADD_TO);
const auto seq_vec = seq_arg.getSequenceIds().vector<int, Device>();
const auto out_grad_mat = seq_arg.matrix<Device>();
const auto seq_vec = in_seq.getSequenceIds().vector<int, Device>();
const auto out_grad_mat = in_seq.matrix<Device>();
auto in_grad_mat =
!outputs[0].data()
? typename Tensor<real, Device>::Matrix(nullptr, 0, 0)
: outputs[0].matrix<Device>();
!out_seq.data() ? typename Tensor<real, Device>::Matrix(nullptr, 0, 0)
: out_seq.matrix<Device>();
auto w_grad_mat = !outputs[1].data()
? typename Tensor<real, Device>::Matrix(nullptr, 0, 0)
: outputs[1].matrix<Device>();
......@@ -287,9 +286,15 @@ private:
};
/**
* \param inputs[0].matrix input grad, n*d
* \param inputs[0].vector input sequence, n*1
* \param outputs[0] output grad, n*(d*l)
* Context Projection Backward Data Function
* Update input layer grad
* input: sequence of output layer grad
* output: sequence of input layer grad
*
* \param outputs[0].matrix input layer grad, n * d
* \param outputs[0].vector start position sequence, n * 1
* \param inputs[0].matrix output layer grad, n * (d * l)
* \param inputs[0].vector start positon sequence, n * 1
*/
template <DeviceType Device>
class ContextProjectionBackwardDataFunc : public FunctionBase {
......@@ -302,19 +307,24 @@ public:
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ(1, static_cast<int>(inputs.size()));
CHECK_EQ(1, static_cast<int>(outputs.size()));
const auto in_seqs = dynamic_cast<const SequenceArg&>(inputs[0]);
CHECK(in_seqs.data() && outputs[0].data() &&
in_seqs.getSequenceIds().data());
CHECK_EQ(static_cast<int>(outputs[0].shape().ndims()), 2);
CHECK_EQ(static_cast<int>(in_seqs.shape().ndims()), 2);
CHECK_EQ(static_cast<int>(in_seqs.getSequenceIds().shape().ndims()), 1);
CHECK_EQ(outputs[0].shape().ndims(),
in_seqs.shape().ndims() * context_length_);
CHECK(inputs[0].isSequenceArg() && outputs[0].isSequenceArg())
<< "SequenceArg required here";
const auto in_seq = dynamic_cast<const SequenceArg&>(inputs[0]);
const auto out_seq = dynamic_cast<const SequenceArg&>(outputs[0]);
CHECK(in_seq.data() && out_seq.data() && in_seq.getSequenceIds().data());
CHECK_EQ(static_cast<int>(out_seq.shape().ndims()), 2);
CHECK_EQ(static_cast<int>(in_seq.shape().ndims()), 2);
CHECK_EQ(static_cast<int>(in_seq.getSequenceIds().shape().ndims()), 1);
/// output layer grad dim == input layer grad dim * context_length_
CHECK_EQ(in_seq.shape().ndims(), out_seq.shape().ndims() * context_length_);
/// input and output has the same batch_size
CHECK_EQ(in_seqs.shape()[0], outputs[0].shape()[0]);
const auto out_grad_mat = outputs[0].matrix<Device>();
auto in_grad_mat = in_seqs.matrix<Device>();
const auto seq_vec = in_seqs.getSequenceIds().vector<int, Device>();
CHECK_EQ(in_seq.shape()[0], out_seq.shape()[0]);
CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO);
const auto out_grad_mat = in_seq.matrix<Device>();
const auto seq_vec = in_seq.getSequenceIds().vector<int, Device>();
auto in_grad_mat = out_seq.matrix<Device>();
ContextProjectionBackwardData<Device>(
out_grad_mat, in_grad_mat, seq_vec, context_length_, context_start_);
......@@ -326,9 +336,14 @@ private:
};
/**
* \param inputs[0].matrix weight grad, pad * d
* \param inputs[0].vecotr input sequence, n * 1
* \param outputs[0] output grad, n * (d * l)
* Context Projection Backward Weight Function
* Update weight grad by backprop
* input: sequence of output layer grad
* output: weight grad
*
* \param outputs[0] weight grad, pad * d
* \param inputs[0].matrix output layer grad, n * (d * l)
* \param inputs[0].vecotr start positon sequence, n * 1
*/
template <DeviceType Device>
class ContextProjectionBackwardWeightFunc : public FunctionBase {
......@@ -343,18 +358,20 @@ public:
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ(1, static_cast<int>(inputs.size()));
CHECK_EQ(1, static_cast<int>(outputs.size()));
const auto in_seqs = dynamic_cast<const SequenceArg&>(inputs[0]);
CHECK(in_seqs.data() && in_seqs.getSequenceIds().data() &&
outputs[0].data());
CHECK(inputs[0].isSequenceArg()) << "SequenceArg required here";
const auto in_seq = dynamic_cast<const SequenceArg&>(inputs[0]);
CHECK(in_seq.data() && in_seq.getSequenceIds().data() && outputs[0].data());
CHECK_EQ(static_cast<int>(outputs[0].shape().ndims()), 2);
CHECK_EQ(static_cast<int>(in_seqs.shape().ndims()), 2);
CHECK_EQ(static_cast<int>(in_seqs.getSequenceIds().shape().ndims()), 1);
CHECK_EQ(in_seqs.shape()[0], outputs[0].shape()[0]);
CHECK_EQ(outputs[0].shape()[1], in_seqs.shape()[1] * context_length_);
const auto out_grad_mat = outputs[0].matrix<Device>();
auto w_grad_mat = inputs[0].matrix<Device>();
const auto seq_vec = in_seqs.getSequenceIds().vector<int, Device>();
CHECK_EQ(static_cast<int>(in_seq.shape().ndims()), 2);
CHECK_EQ(static_cast<int>(in_seq.getSequenceIds().shape().ndims()), 1);
CHECK_EQ(in_seq.shape()[0], outputs[0].shape()[0]);
/// output layer grad dim == weight dim * context_length_
CHECK_EQ(in_seq.shape()[1], outputs[0].shape()[1] * context_length_);
CHECK_EQ(outputs[0].getArgType(), ADD_TO);
const auto seq_vec = in_seq.getSequenceIds().vector<int, Device>();
const auto out_grad_mat = in_seq.matrix<Device>();
auto w_grad_mat = outputs[0].matrix<Device>();
ContextProjectionBackwardWeight<Device>(out_grad_mat,
w_grad_mat,
seq_vec,
......
......@@ -123,7 +123,7 @@ void testMatrixProjectionBackward(int context_start,
BufferArgs cpu_inputs;
BufferArgs cpu_outputs;
cpu_inputs.addArg(cpu_out_grad, *cpu_seq);
cpu_outputs.addArg(cpu_in_grad, ADD_TO);
cpu_outputs.addArg(cpu_in_grad, *cpu_seq, ADD_TO);
cpu_outputs.addArg(
cpu_w_grad ? *cpu_w_grad : CpuMatrix(nullptr, 0, input_dim), ADD_TO);
......@@ -132,7 +132,7 @@ void testMatrixProjectionBackward(int context_start,
BufferArgs gpu_inputs;
BufferArgs gpu_outputs;
gpu_inputs.addArg(gpu_out_grad, *gpu_seq);
gpu_outputs.addArg(gpu_in_grad, ADD_TO);
gpu_outputs.addArg(gpu_in_grad, *gpu_seq, ADD_TO);
gpu_outputs.addArg(
gpu_w_grad ? *gpu_w_grad : GpuMatrix(nullptr, 0, input_dim), ADD_TO);
......
......@@ -169,6 +169,7 @@ void ContextProjection::backward(const UpdateCallback& callback) {
outputs.addArg(
CpuMatrix(
in_->grad ? in_->grad->getData() : nullptr, batch_size, input_dim),
*in_->sequenceStartPositions->getVector(useGpu_),
ADD_TO);
outputs.addArg(CpuMatrix(w_ptr ? w_ptr->getData() : nullptr,
w_ptr ? w_ptr->getHeight() : 0,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册