diff --git a/paddle/function/ContextProjectionOp.cpp b/paddle/function/ContextProjectionOp.cpp
index 177708d00f83c75c5e3ae6124a318610a914ee4b..ec697a381f97fad93930341c1e21c3ede0e0c015 100644
--- a/paddle/function/ContextProjectionOp.cpp
+++ b/paddle/function/ContextProjectionOp.cpp
@@ -74,7 +74,7 @@ void ContextProjectionForward(CpuMatrix& out_mat,
 
 /**
  * Paddle Function for Context Projection Forward.
- * Calculate the value for the output layer with context projection.
+ * Calculate the output sequence after context projection.
  *
  * What is Context Projection?
  * For example, assumed input (x) has 4 words and the dimension of each word
@@ -92,10 +92,12 @@ void ContextProjectionForward(CpuMatrix& out_mat,
  *    c1, c2, d1, d2, 0, 0]
  * @endcode
  *
- * \param outputs[0] output value.
- * \param inputs[0] input value.
- * \param inputs[1] input weight.
- * \param inputs[2] input sequence.
+ * \param outputs[0].matrix output value, n * (d * l)
+ * \param outputs[0].vector input sequence, n * 1
+ * \param inputs[0].matrix input value, n * d
+ * \param inputs[0].vector input sequence, n * 1
+ * \param inputs[1].matrix input weight, pad * d
+ * \param inputs[1].vector input sequence, n * 1
 */
 template <DeviceType Device>
 class ContextProjectionForwardFunc : public FunctionBase {
@@ -107,28 +109,40 @@ public:
   }
 
   void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
-    CHECK_EQ((size_t)3, inputs.size());
+    CHECK(1 == inputs.size() || 2 == inputs.size());
     CHECK_EQ((size_t)1, outputs.size());
 
-    CHECK(outputs[0].data() && inputs[0].data() && inputs[2].data());
-    CHECK_EQ(outputs[0].shape().ndims(), (size_t)2);
-    CHECK_EQ(inputs[0].shape().ndims(), (size_t)2);
-    CHECK_EQ(inputs[1].shape().ndims(), (size_t)2);
-    CHECK_EQ(inputs[2].shape().ndims(), (size_t)1);
+    const auto val_seqs = dynamic_cast<const SequenceArg&>(inputs[0]);
+    const auto w_seqs = inputs.size() <= 1
+                            ? nullptr
+                            : dynamic_cast<const SequenceArg*>(&inputs[1]);
+    auto out_seqs = dynamic_cast<const SequenceArg&>(outputs[0]);
+
+    CHECK(out_seqs.data() && val_seqs.data() &&
+          val_seqs.getSequenceIds().data());
+    CHECK_EQ(out_seqs.shape().ndims(), (size_t)2);
+    CHECK_EQ(val_seqs.shape().ndims(), (size_t)2);
+    CHECK_EQ(val_seqs.getSequenceIds().shape().ndims(), (size_t)1);
+    if (w_seqs) {
+      CHECK_EQ(w_seqs->shape().ndims(), (size_t)2);
+      CHECK_EQ(w_seqs->getSequenceIds().shape().ndims(), (size_t)1);
+    }
     /// dim of output = dim of input * context_length
-    CHECK_EQ(outputs[0].shape()[1], inputs[0].shape()[1] * context_length_);
-    /// dim of input == dim of weight
-    CHECK_EQ(inputs[0].shape()[1], inputs[1].shape()[1]);
+    CHECK_EQ(out_seqs.shape()[1], val_seqs.shape()[1] * context_length_);
     /// input and output has the same batch_size
-    CHECK_EQ(inputs[0].shape()[0], outputs[0].shape()[0]);
+    CHECK_EQ(val_seqs.shape()[0], out_seqs.shape()[0]);
+    /// dim of input == dim of weight
+    if (w_seqs) {
+      CHECK_EQ(val_seqs.shape()[1], w_seqs->shape()[1]);
+    }
 
-    CHECK_EQ(outputs[0].getArgType(), ADD_TO);
-    auto out_mat = outputs[0].matrix<Device>();
-    const auto in_mat = inputs[0].matrix<Device>();
+    CHECK_EQ(out_seqs.getArgType(), ADD_TO);
+    auto out_mat = out_seqs.matrix<Device>();
+    const auto in_mat = val_seqs.matrix<Device>();
     const auto w_mat =
-        !inputs[1].data() ? typename Tensor<real, Device>::Matrix(nullptr, 0, 0)
-                          : inputs[1].matrix<Device>();
-    const auto seq_vec = inputs[2].vector<int, Device>();
+        w_seqs ? w_seqs->matrix<Device>()
+               : typename Tensor<real, Device>::Matrix(nullptr, 0, 0);
+    const auto seq_vec = val_seqs.getSequenceIds().vector<int, Device>();
     ContextProjectionForward<Device>(out_mat,
                                      in_mat,
                                      w_mat,
@@ -227,25 +241,25 @@ public:
     CHECK_EQ((size_t)1, inputs.size());
     CHECK_EQ((size_t)2, outputs.size());
 
-    const auto seqArg = dynamic_cast<const SequenceArg&>(inputs[0]);
-    CHECK(seqArg.data() && inputs[0].data());
-    CHECK_EQ(seqArg.shape().ndims(), (size_t)2);
-    CHECK_EQ(seqArg.getSequenceIds().shape().ndims(), (size_t)1);
+    const auto seq_arg = dynamic_cast<const SequenceArg&>(inputs[0]);
+    CHECK(seq_arg.data() && inputs[0].data());
+    CHECK_EQ(seq_arg.shape().ndims(), (size_t)2);
+    CHECK_EQ(seq_arg.getSequenceIds().shape().ndims(), (size_t)1);
     CHECK_EQ(outputs[0].shape().ndims(), (size_t)2);
     CHECK_EQ(outputs[1].shape().ndims(), (size_t)2);
     /// dim of input grad == dim of weight
     CHECK_EQ(outputs[0].shape()[1], outputs[1].shape()[1]);
     /// input and output grad has the same batch_size
-    CHECK_EQ(outputs[0].shape()[0], seqArg.shape()[0]);
+    CHECK_EQ(outputs[0].shape()[0], seq_arg.shape()[0]);
     /// dim of output val = dim of input grad * context_length
-    CHECK_EQ(seqArg.shape()[1], outputs[0].shape()[1] * context_length_);
+    CHECK_EQ(seq_arg.shape()[1], outputs[0].shape()[1] * context_length_);
 
     CHECK_EQ(outputs[0].getArgType(), ADD_TO);
     CHECK_EQ(outputs[1].getArgType(), ADD_TO);
 
-    const auto seq_vec = seqArg.getSequenceIds().vector<int, Device>();
-    const auto out_grad_mat = seqArg.matrix<Device>();
+    const auto seq_vec = seq_arg.getSequenceIds().vector<int, Device>();
+    const auto out_grad_mat = seq_arg.matrix<Device>();
     auto in_grad_mat =
         !outputs[0].data() ? typename Tensor<real, Device>::Matrix(nullptr, 0, 0)
@@ -272,6 +286,91 @@ private:
   size_t total_pad_;
 };
 
+/**
+ * \param inputs[0].matrix input grad, n * d
+ * \param inputs[0].vector input sequence, n * 1
+ * \param outputs[0] output grad, n * (d * l)
+ */
+template <DeviceType Device>
+class ContextProjectionBackwardDataFunc : public FunctionBase {
+public:
+  void init(const FuncConfig& config) override {
+    context_length_ = config.get<size_t>("context_length");
+    context_start_ = config.get<int>("context_start");
+  }
+
+  void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
+    CHECK_EQ(1, static_cast<int>(inputs.size()));
+    CHECK_EQ(1, static_cast<int>(outputs.size()));
+    const auto in_seqs = dynamic_cast<const SequenceArg&>(inputs[0]);
+    CHECK(in_seqs.data() && outputs[0].data() &&
+          in_seqs.getSequenceIds().data());
+    CHECK_EQ(static_cast<int>(outputs[0].shape().ndims()), 2);
+    CHECK_EQ(static_cast<int>(in_seqs.shape().ndims()), 2);
+    CHECK_EQ(static_cast<int>(in_seqs.getSequenceIds().shape().ndims()), 1);
+    /// dim of output grad == dim of input grad * context_length
+    CHECK_EQ(outputs[0].shape()[1], in_seqs.shape()[1] * context_length_);
+    /// input and output has the same batch_size
+    CHECK_EQ(in_seqs.shape()[0], outputs[0].shape()[0]);
+    const auto out_grad_mat = outputs[0].matrix<Device>();
+    auto in_grad_mat = in_seqs.matrix<Device>();
+    const auto seq_vec = in_seqs.getSequenceIds().vector<int, Device>();
+
+    ContextProjectionBackwardData<Device>(
+        out_grad_mat, in_grad_mat, seq_vec, context_length_, context_start_);
+  }
+
+private:
+  size_t context_length_;
+  int context_start_;
+};
+
+/**
+ * \param inputs[0].matrix weight grad, pad * d
+ * \param inputs[0].vector input sequence, n * 1
+ * \param outputs[0] output grad, n * (d * l)
+ */
+template <DeviceType Device>
+class ContextProjectionBackwardWeightFunc : public FunctionBase {
+public:
+  void init(const FuncConfig& config) override {
+    context_length_ = config.get<size_t>("context_length");
+    context_start_ = config.get<int>("context_start");
+    begin_pad_ = config.get<size_t>("begin_pad");
+    total_pad_ = config.get<size_t>("total_pad");
+  }
+
+  void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
+    CHECK_EQ(1, static_cast<int>(inputs.size()));
+    CHECK_EQ(1, static_cast<int>(outputs.size()));
+
+    const auto in_seqs = dynamic_cast<const SequenceArg&>(inputs[0]);
+    CHECK(in_seqs.data() && in_seqs.getSequenceIds().data() &&
+          outputs[0].data());
+    CHECK_EQ(static_cast<int>(outputs[0].shape().ndims()), 2);
+    CHECK_EQ(static_cast<int>(in_seqs.shape().ndims()), 2);
+    CHECK_EQ(static_cast<int>(in_seqs.getSequenceIds().shape().ndims()), 1);
+    CHECK_EQ(in_seqs.shape()[0], outputs[0].shape()[0]);
+    CHECK_EQ(outputs[0].shape()[1], in_seqs.shape()[1] * context_length_);
+    const auto out_grad_mat = outputs[0].matrix<Device>();
+    auto w_grad_mat = inputs[0].matrix<Device>();
+    const auto seq_vec = in_seqs.getSequenceIds().vector<int, Device>();
+    ContextProjectionBackwardWeight<Device>(out_grad_mat,
+                                            w_grad_mat,
+                                            seq_vec,
+                                            context_length_,
+                                            context_start_,
+                                            total_pad_,
+                                            begin_pad_);
+  }
+
+private:
+  size_t context_length_;
+  int context_start_;
+  size_t begin_pad_;
+  size_t total_pad_;
+};
+
 REGISTER_TYPED_FUNC(ContextProjectionForward,
                     CPU,
                     ContextProjectionForwardFunc);
@@ -285,5 +384,11 @@ REGISTER_TYPED_FUNC(ContextProjectionForward,
 REGISTER_TYPED_FUNC(ContextProjectionBackward,
                     GPU,
                     ContextProjectionBackwardFunc);
+REGISTER_TYPED_FUNC(ContextProjectionBackwardData,
+                    GPU,
+                    ContextProjectionBackwardDataFunc);
+REGISTER_TYPED_FUNC(ContextProjectionBackwardWeight,
+                    GPU,
+                    ContextProjectionBackwardWeightFunc);
 #endif
 }  // namespace paddle
diff --git a/paddle/function/ContextProjectionOpTest.cpp b/paddle/function/ContextProjectionOpTest.cpp
index 50ca2040050cada9cda9c2de4fc24412addf297e..bd0c06c5f64a8807847f7ebeaa7255c795688a14 100644
--- a/paddle/function/ContextProjectionOpTest.cpp
+++ b/paddle/function/ContextProjectionOpTest.cpp
@@ -58,21 +58,21 @@ void testMatrixProjectionForward(int context_start,
 
   BufferArgs cpu_inputs;
   BufferArgs cpu_outputs;
-  cpu_inputs.addArg(cpu_in);
-  cpu_inputs.addArg(cpu_weight ? *cpu_weight
-                               : CpuMatrix(nullptr, 0, input_dim));
-  cpu_inputs.addArg(*cpu_seq);
-  cpu_outputs.addArg(cpu_out, ADD_TO);
+  cpu_inputs.addArg(cpu_in, *cpu_seq);
+  if (cpu_weight) {
+    cpu_inputs.addArg(*cpu_weight, *cpu_seq);
+  }
+  cpu_outputs.addArg(cpu_out, *cpu_seq, ADD_TO);
 
   compare.getCpuFunction()->calc(cpu_inputs, cpu_outputs);
 
   BufferArgs gpu_inputs;
   BufferArgs gpu_outputs;
-  gpu_inputs.addArg(gpu_in);
-  gpu_inputs.addArg(gpu_weight ? *gpu_weight
-                               : GpuMatrix(nullptr, 0, input_dim));
-  gpu_inputs.addArg(*gpu_seq);
-  gpu_outputs.addArg(gpu_out, ADD_TO);
+  gpu_inputs.addArg(gpu_in, *gpu_seq);
+  if (gpu_weight) {
+    gpu_inputs.addArg(*gpu_weight, *gpu_seq);
+  }
+  gpu_outputs.addArg(gpu_out, *gpu_seq, ADD_TO);
 
   compare.getGpuFunction()->calc(gpu_inputs, gpu_outputs);
 
diff --git a/paddle/gserver/layers/ContextProjection.cpp b/paddle/gserver/layers/ContextProjection.cpp
index 17fd36ef563c1bd40b3368ba4cead5b041ca40d3..edcef17ad4705fdf7a2221806ae3e928fa318e4c 100644
--- a/paddle/gserver/layers/ContextProjection.cpp
+++ b/paddle/gserver/layers/ContextProjection.cpp
@@ -118,16 +118,15 @@ void ContextProjection::forward() {
   /// first use state_, otherwise use weight_(padding false === w nullptr)
   auto w_ptr =
       state_ ? state_.get() : is_padding ? weight_->getW().get() : nullptr;
-  auto start_pos = in_->sequenceStartPositions;
-
+  const auto start_pos = in_->sequenceStartPositions->getVector(useGpu_);
   BufferArgs inputs;
   BufferArgs outputs;
-  inputs.addArg(*in_->value);
-  inputs.addArg(CpuMatrix(w_ptr ? w_ptr->getData() : nullptr,
-                          w_ptr ? w_ptr->getHeight() : 0,
-                          input_dim));
-  inputs.addArg(*in_->sequenceStartPositions->getVector(useGpu_));
-  outputs.addArg(*out_->value, ADD_TO);
+  inputs.addArg(*in_->value, *start_pos);
+  if (w_ptr) {
+    inputs.addArg(CpuMatrix(w_ptr->getData(), w_ptr->getHeight(), input_dim),
+                  *start_pos);
+  }
+  outputs.addArg(*out_->value, *start_pos, ADD_TO);
   forward_[0]->calc(inputs, outputs);
 
   if (state_ && config_.context_start() < 0) {
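
Note (not part of the patch): after this change each matrix argument carries its sequence start positions when it is added to BufferArgs, instead of the sequence being passed as a separate third input. A minimal sketch of the new calling convention, modeled on the test changes above; `func`, `in_mat`, `w_mat`, `out_mat`, `seq`, and `has_weight` are hypothetical placeholders for an already-initialized ContextProjectionForward function object and the caller's buffers:

  // Sketch only; names are placeholders, APIs are the ones used in this diff.
  BufferArgs inputs;
  BufferArgs outputs;
  inputs.addArg(in_mat, seq);            // input value (n x d) plus its sequence ids
  if (has_weight) {
    inputs.addArg(w_mat, seq);           // padding weight (pad x d) is now optional
  }
  outputs.addArg(out_mat, seq, ADD_TO);  // output (n x (d * l)) still accumulates
  func->calc(inputs, outputs);

Pairing the sequence ids with each argument is what lets the weight input be dropped entirely when there is no padding, rather than passing a null-data placeholder matrix as before.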