提交 e9794214 编写于 作者: X xutianbing

Address further comments.

上级 8560ce69
...@@ -20,23 +20,27 @@ limitations under the License. */ ...@@ -20,23 +20,27 @@ limitations under the License. */
namespace paddle { namespace paddle {
const SequenceArg& BufferArg::sequence() const { const SequenceArg& BufferArg::sequence() const {
// CHECK_EQ(bufferType_, TENSOR_SEQUENCE_DATA); CHECK_EQ(bufferType_, TENSOR_SEQUENCE_DATA);
return dynamic_cast<const SequenceArg&>(*this); return dynamic_cast<const SequenceArg&>(*this);
} }
const SparseMatrixArg& BufferArg::sparse() const { const SparseMatrixArg& BufferArg::sparse() const {
// CHECK_EQ(bufferType_, TENSOR_SPARSE); CHECK_EQ(bufferType_, TENSOR_SPARSE);
return dynamic_cast<const SparseMatrixArg&>(*this); return dynamic_cast<const SparseMatrixArg&>(*this);
} }
SparseMatrixArg::SparseMatrixArg(const CpuSparseMatrix& sparse, ArgType argType) SparseMatrixArg::SparseMatrixArg(const CpuSparseMatrix& sparse, ArgType argType)
: BufferArg(sparse, argType), : BufferArg(sparse, argType),
row_(reinterpret_cast<void*>(sparse.getRows()), VALUE_TYPE_INT32), row_(reinterpret_cast<void*>(sparse.getRows()), VALUE_TYPE_INT32),
col_(reinterpret_cast<void*>(sparse.getCols()), VALUE_TYPE_INT32) {} col_(reinterpret_cast<void*>(sparse.getCols()), VALUE_TYPE_INT32) {
bufferType_ = TENSOR_SPARSE;
}
SparseMatrixArg::SparseMatrixArg(const GpuSparseMatrix& sparse, ArgType argType) SparseMatrixArg::SparseMatrixArg(const GpuSparseMatrix& sparse, ArgType argType)
: BufferArg(sparse, argType), : BufferArg(sparse, argType),
row_(reinterpret_cast<void*>(sparse.getRows()), VALUE_TYPE_INT32), row_(reinterpret_cast<void*>(sparse.getRows()), VALUE_TYPE_INT32),
col_(reinterpret_cast<void*>(sparse.getCols()), VALUE_TYPE_INT32) {} col_(reinterpret_cast<void*>(sparse.getCols()), VALUE_TYPE_INT32) {
bufferType_ = TENSOR_SPARSE;
}
} // namespace paddle } // namespace paddle
...@@ -23,10 +23,11 @@ limitations under the License. */ ...@@ -23,10 +23,11 @@ limitations under the License. */
namespace paddle { namespace paddle {
enum BufferType { enum BufferType {
TENSOR_NORMAL = 0, TENSOR_UNKNOWN = 0,
TENSOR_SEQUENCE_ID = 1, TENSOR_NORMAL = 1,
TENSOR_SEQUENCE_DATA = 2, TENSOR_SEQUENCE_ID = 2,
TENSOR_SPARSE = 3 TENSOR_SEQUENCE_DATA = 3,
TENSOR_SPARSE = 4
}; };
enum SparseDataType { enum SparseDataType {
...@@ -86,6 +87,7 @@ public: ...@@ -86,6 +87,7 @@ public:
valueType_(DataType<real>::value), valueType_(DataType<real>::value),
shape_(2), shape_(2),
argType_(argType) { argType_(argType) {
bufferType_ = TENSOR_NORMAL;
shape_.setDim(0, matrix.getHeight()); shape_.setDim(0, matrix.getHeight());
shape_.setDim(1, matrix.getWidth()); shape_.setDim(1, matrix.getWidth());
} }
...@@ -98,6 +100,7 @@ public: ...@@ -98,6 +100,7 @@ public:
valueType_(DataType<real>::value), valueType_(DataType<real>::value),
shape_(shape), shape_(shape),
argType_(argType) { argType_(argType) {
bufferType_ = TENSOR_NORMAL;
CHECK_EQ(matrix.getElementCnt(), shape.getElements()); CHECK_EQ(matrix.getElementCnt(), shape.getElements());
} }
...@@ -107,6 +110,7 @@ public: ...@@ -107,6 +110,7 @@ public:
valueType_(DataType<real>::value), valueType_(DataType<real>::value),
shape_(1), shape_(1),
argType_(argType) { argType_(argType) {
bufferType_ = TENSOR_NORMAL;
shape_.setDim(0, vector.getSize()); shape_.setDim(0, vector.getSize());
} }
...@@ -116,6 +120,7 @@ public: ...@@ -116,6 +120,7 @@ public:
valueType_(VALUE_TYPE_INT32), valueType_(VALUE_TYPE_INT32),
shape_(1), shape_(1),
argType_(argType) { argType_(argType) {
bufferType_ = TENSOR_NORMAL;
shape_.setDim(0, vector.getSize()); shape_.setDim(0, vector.getSize());
} }
...@@ -150,6 +155,8 @@ public: ...@@ -150,6 +155,8 @@ public:
ValueType valueType() const { return valueType_; } ValueType valueType() const { return valueType_; }
BufferType bufferType() const { return bufferType_; } BufferType bufferType() const { return bufferType_; }
const TensorShape& shape() const { return shape_; } const TensorShape& shape() const { return shape_; }
bool isSparse() const { return (TENSOR_SPARSE == bufferType_); }
bool isSequenceArg() const { return TENSOR_SEQUENCE_DATA == bufferType_; }
const SequenceArg& sequence() const; const SequenceArg& sequence() const;
const SparseMatrixArg& sparse() const; const SparseMatrixArg& sparse() const;
...@@ -158,8 +165,8 @@ protected: ...@@ -158,8 +165,8 @@ protected:
void* buf_; void* buf_;
ValueType valueType_; ValueType valueType_;
TensorShape shape_; TensorShape shape_;
BufferType bufferType_; BufferType bufferType_{TENSOR_UNKNOWN};
ArgType argType_ = UNSPECIFIED; ArgType argType_{UNSPECIFIED};
// leading dimensions. The size is dims_.size() // leading dimensions. The size is dims_.size()
// Dims lds_; // Dims lds_;
}; };
...@@ -174,11 +181,13 @@ public: ...@@ -174,11 +181,13 @@ public:
const TensorShape& shape, const TensorShape& shape,
ArgType argType = UNSPECIFIED) ArgType argType = UNSPECIFIED)
: BufferArg(buf, VALUE_TYPE_INT32, shape, argType) { : BufferArg(buf, VALUE_TYPE_INT32, shape, argType) {
bufferType_ = TENSOR_SEQUENCE_ID;
CHECK_EQ(shape_.ndims(), (size_t)1); CHECK_EQ(shape_.ndims(), (size_t)1);
numSeqs_ = shape_[0] - 1; numSeqs_ = shape_[0] - 1;
} }
SequenceIdArg(const IVector& vector) : BufferArg(vector) { SequenceIdArg(const IVector& vector) : BufferArg(vector) {
bufferType_ = TENSOR_SEQUENCE_ID;
numSeqs_ = shape_[0] - 1; numSeqs_ = shape_[0] - 1;
} }
...@@ -199,12 +208,16 @@ public: ...@@ -199,12 +208,16 @@ public:
const SequenceIdArg& startPositions, const SequenceIdArg& startPositions,
ArgType argType = UNSPECIFIED) ArgType argType = UNSPECIFIED)
: BufferArg(buf, valueType, shape, argType), : BufferArg(buf, valueType, shape, argType),
startPositions_(startPositions) {} startPositions_(startPositions) {
bufferType_ = TENSOR_SEQUENCE_DATA;
}
SequenceArg(const Matrix& matrix, SequenceArg(const Matrix& matrix,
const IVector& vector, const IVector& vector,
ArgType argType = UNSPECIFIED) ArgType argType = UNSPECIFIED)
: BufferArg(matrix, argType), startPositions_(vector) {} : BufferArg(matrix, argType), startPositions_(vector) {
bufferType_ = TENSOR_SEQUENCE_DATA;
}
~SequenceArg() {} ~SequenceArg() {}
...@@ -236,6 +249,7 @@ public: ...@@ -236,6 +249,7 @@ public:
nnz_(nnz), nnz_(nnz),
format_(format), format_(format),
type_(type) { type_(type) {
bufferType_ = TENSOR_SPARSE;
CHECK((valueType == VALUE_TYPE_FLOAT) || (valueType == VALUE_TYPE_DOUBLE)); CHECK((valueType == VALUE_TYPE_FLOAT) || (valueType == VALUE_TYPE_DOUBLE));
CHECK_EQ(shape_.ndims(), (size_t)2); CHECK_EQ(shape_.ndims(), (size_t)2);
CHECK_EQ(row_.shape().ndims(), (size_t)1); CHECK_EQ(row_.shape().ndims(), (size_t)1);
......
...@@ -74,9 +74,9 @@ void ContextProjectionForward<DEVICE_TYPE_CPU>(CpuMatrix& out_mat, ...@@ -74,9 +74,9 @@ void ContextProjectionForward<DEVICE_TYPE_CPU>(CpuMatrix& out_mat,
/** /**
* Paddle Function for Context Projection Forward. * Paddle Function for Context Projection Forward.
* Calculate the output sequence after context projection. * Calculate the output layer value sequence after context projection.
* *
* What is Context Projection? * What is Context Projection for a sequence?
* For example, assumed input (x) has 4 words and the dimension of each word * For example, assumed input (x) has 4 words and the dimension of each word
* representation is 2. If we use zero to pad instead of learned weight to pad, * representation is 2. If we use zero to pad instead of learned weight to pad,
* and the context_lenth is 3, the output (y) is: * and the context_lenth is 3, the output (y) is:
...@@ -92,12 +92,11 @@ void ContextProjectionForward<DEVICE_TYPE_CPU>(CpuMatrix& out_mat, ...@@ -92,12 +92,11 @@ void ContextProjectionForward<DEVICE_TYPE_CPU>(CpuMatrix& out_mat,
* c1, c2, d1, d2, 0, 0] * c1, c2, d1, d2, 0, 0]
* @endcode * @endcode
* *
* \param outputs[0].matrix output value, n * (d * l) * \param outputs[0].matrix output layer value, n * (d * l)
* \param outputs[0].vector input sequence, n * 1 * \param outputs[0].vector start position sequence, n * 1
* \param inputs[0].matrix input value, n * d * \param inputs[0].matrix input layer value, n * d
* \param inputs[0].vector input sequence, n * 1 * \param inputs[0].vector start position sequence, n * 1
* \param inputs[1].matrix input weight, pad * d * \param inputs[1].matrix input layer weight, pad * d
* \param inputs[1].vector input sequence, n * 1
*/ */
template <DeviceType Device> template <DeviceType Device>
class ContextProjectionForwardFunc : public FunctionBase { class ContextProjectionForwardFunc : public FunctionBase {
...@@ -111,37 +110,35 @@ public: ...@@ -111,37 +110,35 @@ public:
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK(1 == inputs.size() || 2 == inputs.size()); CHECK(1 == inputs.size() || 2 == inputs.size());
CHECK_EQ((size_t)1, outputs.size()); CHECK_EQ((size_t)1, outputs.size());
CHECK(inputs[0].isSequenceArg() && outputs[0].isSequenceArg())
<< "SequenceArg required here";
const auto val_seqs = dynamic_cast<const SequenceArg&>(inputs[0]); const auto val_seqs = dynamic_cast<const SequenceArg&>(inputs[0]);
const auto w_seqs = inputs.size() <= 1 auto out_seq = dynamic_cast<const SequenceArg&>(outputs[0]);
? nullptr
: dynamic_cast<const SequenceArg*>(&inputs[1]);
auto out_seqs = dynamic_cast<const SequenceArg&>(outputs[0]);
CHECK(out_seqs.data() && val_seqs.data() && CHECK(out_seq.data() && val_seqs.data() &&
val_seqs.getSequenceIds().data()); val_seqs.getSequenceIds().data());
CHECK_EQ(out_seqs.shape().ndims(), (size_t)2); CHECK_EQ(out_seq.shape().ndims(), (size_t)2);
CHECK_EQ(val_seqs.shape().ndims(), (size_t)2); CHECK_EQ(val_seqs.shape().ndims(), (size_t)2);
CHECK_EQ(val_seqs.getSequenceIds().shape().ndims(), (size_t)1); CHECK_EQ(val_seqs.getSequenceIds().shape().ndims(), (size_t)1);
if (w_seqs) { if (2 == inputs.size()) {
CHECK_EQ(w_seqs->shape().ndims(), (size_t)2); CHECK_EQ(inputs[1].shape().ndims(), (size_t)2);
CHECK_EQ(w_seqs->getSequenceIds().shape().ndims(), (size_t)1);
} }
/// dim of output = dim of input * context_length /// dim of output = dim of input * context_length
CHECK_EQ(out_seqs.shape()[1], val_seqs.shape()[1] * context_length_); CHECK_EQ(out_seq.shape()[1], val_seqs.shape()[1] * context_length_);
/// input and output has the same batch_size /// input and output has the same batch_size
CHECK_EQ(val_seqs.shape()[0], out_seqs.shape()[0]); CHECK_EQ(val_seqs.shape()[0], out_seq.shape()[0]);
/// dim of input == dim of weight /// dim of input == dim of weight
if (w_seqs) { if (2 == inputs.size()) {
CHECK_EQ(val_seqs.shape()[1], w_seqs->shape()[1]); CHECK_EQ(val_seqs.shape()[1], inputs[1].shape()[1]);
} }
CHECK_EQ(out_seqs.getArgType(), ADD_TO); CHECK_EQ(out_seq.getArgType(), ADD_TO);
auto out_mat = out_seqs.matrix<Device>(); auto out_mat = out_seq.matrix<Device>();
const auto in_mat = val_seqs.matrix<Device>(); const auto in_mat = val_seqs.matrix<Device>();
const auto w_mat = const auto w_mat =
w_seqs ? w_seqs->matrix<Device>() (2 == inputs.size())
: typename Tensor<real, Device>::Matrix(nullptr, 0, 0); ? inputs[1].matrix<Device>()
: typename Tensor<real, Device>::Matrix(nullptr, 0, 0);
const auto seq_vec = val_seqs.getSequenceIds().vector<int, Device>(); const auto seq_vec = val_seqs.getSequenceIds().vector<int, Device>();
ContextProjectionForward<Device>(out_mat, ContextProjectionForward<Device>(out_mat,
in_mat, in_mat,
...@@ -221,10 +218,11 @@ void ContextProjectionBackward<DEVICE_TYPE_CPU>(const CpuMatrix& out_grad_mat, ...@@ -221,10 +218,11 @@ void ContextProjectionBackward<DEVICE_TYPE_CPU>(const CpuMatrix& out_grad_mat,
* Context Projection Backward Function. * Context Projection Backward Function.
* Update the weight gradient and input layer gradient with backprop * Update the weight gradient and input layer gradient with backprop
* *
* \param inputs[0].seq input sequence. * \param inputs[0].matrix output layer grad, n * (d * l)
* \param inputs[0].matrix output layer grad. * \param inputs[0].vector start position sequence, n * 1
* \param outputs[0] input layer grad. * \param outputs[0].matrix input layer grad, n * d
* \param outputs[1] weight grad. * \param outputs[0].vector start position sequence, n * 1
* \param outputs[1] weight grad, pad * d
*/ */
template <DeviceType Device> template <DeviceType Device>
class ContextProjectionBackwardFunc : public FunctionBase { class ContextProjectionBackwardFunc : public FunctionBase {
...@@ -240,30 +238,31 @@ public: ...@@ -240,30 +238,31 @@ public:
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ((size_t)1, inputs.size()); CHECK_EQ((size_t)1, inputs.size());
CHECK_EQ((size_t)2, outputs.size()); CHECK_EQ((size_t)2, outputs.size());
CHECK(inputs[0].isSequenceArg() && outputs[0].isSequenceArg())
const auto seq_arg = dynamic_cast<const SequenceArg&>(inputs[0]); << "SequenceArg required here";
CHECK(seq_arg.data() && inputs[0].data()); const auto in_seq = dynamic_cast<const SequenceArg&>(inputs[0]);
CHECK_EQ(seq_arg.shape().ndims(), (size_t)2); auto out_seq = dynamic_cast<const SequenceArg&>(outputs[0]);
CHECK_EQ(seq_arg.getSequenceIds().shape().ndims(), (size_t)1); CHECK(in_seq.data() && in_seq.getSequenceIds().data());
CHECK_EQ(outputs[0].shape().ndims(), (size_t)2); CHECK_EQ(in_seq.shape().ndims(), (size_t)2);
CHECK_EQ(in_seq.getSequenceIds().shape().ndims(), (size_t)1);
CHECK_EQ(out_seq.shape().ndims(), (size_t)2);
CHECK_EQ(out_seq.getSequenceIds().shape().ndims(), (size_t)1);
CHECK_EQ(outputs[1].shape().ndims(), (size_t)2); CHECK_EQ(outputs[1].shape().ndims(), (size_t)2);
/// dim of input grad == dim of weight /// dim of input grad == dim of weight
CHECK_EQ(outputs[0].shape()[1], outputs[1].shape()[1]); CHECK_EQ(out_seq.shape()[1], outputs[1].shape()[1]);
/// input and output grad has the same batch_size /// input and output grad has the same batch_size
CHECK_EQ(outputs[0].shape()[0], seq_arg.shape()[0]); CHECK_EQ(out_seq.shape()[0], in_seq.shape()[0]);
/// dim of output val = dim of input grad * context_length /// dim of output grad = dim of input grad * context_length
CHECK_EQ(seq_arg.shape()[1], outputs[0].shape()[1] * context_length_); CHECK_EQ(in_seq.shape()[1], out_seq.shape()[1] * context_length_);
CHECK_EQ(out_seq.getArgType(), ADD_TO);
CHECK_EQ(outputs[0].getArgType(), ADD_TO);
CHECK_EQ(outputs[1].getArgType(), ADD_TO); CHECK_EQ(outputs[1].getArgType(), ADD_TO);
const auto seq_vec = seq_arg.getSequenceIds().vector<int, Device>(); const auto seq_vec = in_seq.getSequenceIds().vector<int, Device>();
const auto out_grad_mat = seq_arg.matrix<Device>(); const auto out_grad_mat = in_seq.matrix<Device>();
auto in_grad_mat = auto in_grad_mat =
!outputs[0].data() !out_seq.data() ? typename Tensor<real, Device>::Matrix(nullptr, 0, 0)
? typename Tensor<real, Device>::Matrix(nullptr, 0, 0) : out_seq.matrix<Device>();
: outputs[0].matrix<Device>();
auto w_grad_mat = !outputs[1].data() auto w_grad_mat = !outputs[1].data()
? typename Tensor<real, Device>::Matrix(nullptr, 0, 0) ? typename Tensor<real, Device>::Matrix(nullptr, 0, 0)
: outputs[1].matrix<Device>(); : outputs[1].matrix<Device>();
...@@ -287,9 +286,15 @@ private: ...@@ -287,9 +286,15 @@ private:
}; };
/** /**
* \param inputs[0].matrix input grad, n*d * Context Projection Backward Data Function
* \param inputs[0].vector input sequence, n*1 * Update input layer grad
* \param outputs[0] output grad, n*(d*l) * input: sequence of output layer grad
* output: sequence of input layer grad
*
* \param outputs[0].matrix input layer grad, n * d
* \param outputs[0].vector start position sequence, n * 1
* \param inputs[0].matrix output layer grad, n * (d * l)
* \param inputs[0].vector start positon sequence, n * 1
*/ */
template <DeviceType Device> template <DeviceType Device>
class ContextProjectionBackwardDataFunc : public FunctionBase { class ContextProjectionBackwardDataFunc : public FunctionBase {
...@@ -302,19 +307,24 @@ public: ...@@ -302,19 +307,24 @@ public:
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ(1, static_cast<int>(inputs.size())); CHECK_EQ(1, static_cast<int>(inputs.size()));
CHECK_EQ(1, static_cast<int>(outputs.size())); CHECK_EQ(1, static_cast<int>(outputs.size()));
const auto in_seqs = dynamic_cast<const SequenceArg&>(inputs[0]); CHECK(inputs[0].isSequenceArg() && outputs[0].isSequenceArg())
CHECK(in_seqs.data() && outputs[0].data() && << "SequenceArg required here";
in_seqs.getSequenceIds().data()); const auto in_seq = dynamic_cast<const SequenceArg&>(inputs[0]);
CHECK_EQ(static_cast<int>(outputs[0].shape().ndims()), 2); const auto out_seq = dynamic_cast<const SequenceArg&>(outputs[0]);
CHECK_EQ(static_cast<int>(in_seqs.shape().ndims()), 2);
CHECK_EQ(static_cast<int>(in_seqs.getSequenceIds().shape().ndims()), 1); CHECK(in_seq.data() && out_seq.data() && in_seq.getSequenceIds().data());
CHECK_EQ(outputs[0].shape().ndims(), CHECK_EQ(static_cast<int>(out_seq.shape().ndims()), 2);
in_seqs.shape().ndims() * context_length_); CHECK_EQ(static_cast<int>(in_seq.shape().ndims()), 2);
CHECK_EQ(static_cast<int>(in_seq.getSequenceIds().shape().ndims()), 1);
/// output layer grad dim == input layer grad dim * context_length_
CHECK_EQ(in_seq.shape().ndims(), out_seq.shape().ndims() * context_length_);
/// input and output has the same batch_size /// input and output has the same batch_size
CHECK_EQ(in_seqs.shape()[0], outputs[0].shape()[0]); CHECK_EQ(in_seq.shape()[0], out_seq.shape()[0]);
const auto out_grad_mat = outputs[0].matrix<Device>(); CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO);
auto in_grad_mat = in_seqs.matrix<Device>();
const auto seq_vec = in_seqs.getSequenceIds().vector<int, Device>(); const auto out_grad_mat = in_seq.matrix<Device>();
const auto seq_vec = in_seq.getSequenceIds().vector<int, Device>();
auto in_grad_mat = out_seq.matrix<Device>();
ContextProjectionBackwardData<Device>( ContextProjectionBackwardData<Device>(
out_grad_mat, in_grad_mat, seq_vec, context_length_, context_start_); out_grad_mat, in_grad_mat, seq_vec, context_length_, context_start_);
...@@ -326,9 +336,14 @@ private: ...@@ -326,9 +336,14 @@ private:
}; };
/** /**
* \param inputs[0].matrix weight grad, pad * d * Context Projection Backward Weight Function
* \param inputs[0].vecotr input sequence, n * 1 * Update weight grad by backprop
* \param outputs[0] output grad, n * (d * l) * input: sequence of output layer grad
* output: weight grad
*
* \param outputs[0] weight grad, pad * d
* \param inputs[0].matrix output layer grad, n * (d * l)
* \param inputs[0].vecotr start positon sequence, n * 1
*/ */
template <DeviceType Device> template <DeviceType Device>
class ContextProjectionBackwardWeightFunc : public FunctionBase { class ContextProjectionBackwardWeightFunc : public FunctionBase {
...@@ -343,18 +358,20 @@ public: ...@@ -343,18 +358,20 @@ public:
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ(1, static_cast<int>(inputs.size())); CHECK_EQ(1, static_cast<int>(inputs.size()));
CHECK_EQ(1, static_cast<int>(outputs.size())); CHECK_EQ(1, static_cast<int>(outputs.size()));
CHECK(inputs[0].isSequenceArg()) << "SequenceArg required here";
const auto in_seqs = dynamic_cast<const SequenceArg&>(inputs[0]); const auto in_seq = dynamic_cast<const SequenceArg&>(inputs[0]);
CHECK(in_seqs.data() && in_seqs.getSequenceIds().data() && CHECK(in_seq.data() && in_seq.getSequenceIds().data() && outputs[0].data());
outputs[0].data());
CHECK_EQ(static_cast<int>(outputs[0].shape().ndims()), 2); CHECK_EQ(static_cast<int>(outputs[0].shape().ndims()), 2);
CHECK_EQ(static_cast<int>(in_seqs.shape().ndims()), 2); CHECK_EQ(static_cast<int>(in_seq.shape().ndims()), 2);
CHECK_EQ(static_cast<int>(in_seqs.getSequenceIds().shape().ndims()), 1); CHECK_EQ(static_cast<int>(in_seq.getSequenceIds().shape().ndims()), 1);
CHECK_EQ(in_seqs.shape()[0], outputs[0].shape()[0]); CHECK_EQ(in_seq.shape()[0], outputs[0].shape()[0]);
CHECK_EQ(outputs[0].shape()[1], in_seqs.shape()[1] * context_length_); /// output layer grad dim == weight dim * context_length_
const auto out_grad_mat = outputs[0].matrix<Device>(); CHECK_EQ(in_seq.shape()[1], outputs[0].shape()[1] * context_length_);
auto w_grad_mat = inputs[0].matrix<Device>(); CHECK_EQ(outputs[0].getArgType(), ADD_TO);
const auto seq_vec = in_seqs.getSequenceIds().vector<int, Device>();
const auto seq_vec = in_seq.getSequenceIds().vector<int, Device>();
const auto out_grad_mat = in_seq.matrix<Device>();
auto w_grad_mat = outputs[0].matrix<Device>();
ContextProjectionBackwardWeight<Device>(out_grad_mat, ContextProjectionBackwardWeight<Device>(out_grad_mat,
w_grad_mat, w_grad_mat,
seq_vec, seq_vec,
......
...@@ -123,7 +123,7 @@ void testMatrixProjectionBackward(int context_start, ...@@ -123,7 +123,7 @@ void testMatrixProjectionBackward(int context_start,
BufferArgs cpu_inputs; BufferArgs cpu_inputs;
BufferArgs cpu_outputs; BufferArgs cpu_outputs;
cpu_inputs.addArg(cpu_out_grad, *cpu_seq); cpu_inputs.addArg(cpu_out_grad, *cpu_seq);
cpu_outputs.addArg(cpu_in_grad, ADD_TO); cpu_outputs.addArg(cpu_in_grad, *cpu_seq, ADD_TO);
cpu_outputs.addArg( cpu_outputs.addArg(
cpu_w_grad ? *cpu_w_grad : CpuMatrix(nullptr, 0, input_dim), ADD_TO); cpu_w_grad ? *cpu_w_grad : CpuMatrix(nullptr, 0, input_dim), ADD_TO);
...@@ -132,7 +132,7 @@ void testMatrixProjectionBackward(int context_start, ...@@ -132,7 +132,7 @@ void testMatrixProjectionBackward(int context_start,
BufferArgs gpu_inputs; BufferArgs gpu_inputs;
BufferArgs gpu_outputs; BufferArgs gpu_outputs;
gpu_inputs.addArg(gpu_out_grad, *gpu_seq); gpu_inputs.addArg(gpu_out_grad, *gpu_seq);
gpu_outputs.addArg(gpu_in_grad, ADD_TO); gpu_outputs.addArg(gpu_in_grad, *gpu_seq, ADD_TO);
gpu_outputs.addArg( gpu_outputs.addArg(
gpu_w_grad ? *gpu_w_grad : GpuMatrix(nullptr, 0, input_dim), ADD_TO); gpu_w_grad ? *gpu_w_grad : GpuMatrix(nullptr, 0, input_dim), ADD_TO);
......
...@@ -169,6 +169,7 @@ void ContextProjection::backward(const UpdateCallback& callback) { ...@@ -169,6 +169,7 @@ void ContextProjection::backward(const UpdateCallback& callback) {
outputs.addArg( outputs.addArg(
CpuMatrix( CpuMatrix(
in_->grad ? in_->grad->getData() : nullptr, batch_size, input_dim), in_->grad ? in_->grad->getData() : nullptr, batch_size, input_dim),
*in_->sequenceStartPositions->getVector(useGpu_),
ADD_TO); ADD_TO);
outputs.addArg(CpuMatrix(w_ptr ? w_ptr->getData() : nullptr, outputs.addArg(CpuMatrix(w_ptr ? w_ptr->getData() : nullptr,
w_ptr ? w_ptr->getHeight() : 0, w_ptr ? w_ptr->getHeight() : 0,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册