From ea4d08dab6b76a685e9277f28daf6b594912b97f Mon Sep 17 00:00:00 2001
From: xutianbing
Date: Thu, 29 Dec 2016 15:37:00 -0800
Subject: [PATCH] update interface of context projection functions, Tensor -> Matrix/Vector

---
 paddle/function/ContextProjectionOp.cpp     | 185 ++++++++++++--------
 paddle/function/ContextProjectionOp.h       |  45 ++---
 paddle/function/ContextProjectionOpGpu.cu   | 130 ++++++--------
 paddle/function/ContextProjectionOpTest.cpp |   3 +-
 paddle/function/Function.h                  |  13 ++
 paddle/gserver/layers/ContextProjection.cpp |  12 +-
 6 files changed, 207 insertions(+), 181 deletions(-)

diff --git a/paddle/function/ContextProjectionOp.cpp b/paddle/function/ContextProjectionOp.cpp
index 40852e1ab4..bd367a859e 100644
--- a/paddle/function/ContextProjectionOp.cpp
+++ b/paddle/function/ContextProjectionOp.cpp
@@ -19,35 +19,17 @@ limitations under the License. */
 namespace paddle {
 
 template <>
-void ContextProjectionForward<DEVICE_TYPE_CPU>(Tensor& output,
-                                               const Tensor& input,
-                                               const Tensor& weight,
-                                               const Tensor& sequence,
+void ContextProjectionForward<DEVICE_TYPE_CPU>(CpuMatrix* out_mat,
+                                               const CpuMatrix* input_mat,
+                                               const CpuMatrix* weight_mat,
+                                               const CpuIVector& seq_vec,
                                                size_t context_length,
                                                int context_start,
-                                               size_t begin_pad,
-                                               bool is_padding) {
-  CHECK(output.getData() && input.getData() && sequence.getData());
-  CHECK_EQ(output.dims_.size(), 2);
-  CHECK_EQ(input.dims_.size(), 2);
-  CHECK_EQ(weight.dims_.size(), 2);
-  CHECK_EQ(sequence.dims_.size(), 1);
-
-  auto out_mat = std::make_shared<CpuMatrix>(
-      output.getData(), output.dims_[0], output.dims_[1]);
-  const auto in_mat = std::make_shared<CpuMatrix>(
-      input.getData(), input.dims_[0], input.dims_[1]);
-  const auto weight_mat =
-      !weight.getData()
-          ? nullptr
-          : std::make_shared<CpuMatrix>(
-                weight.getData(), weight.dims_[0], weight.dims_[1]);
-  CpuIVector seq_vec(sequence.dims_[0],
-                     reinterpret_cast<int*>(sequence.getData()));
-  CHECK_EQ(out_mat->getWidth(), in_mat->getWidth() * context_length);
-
+                                               size_t begin_pad) {
   const int* starts = seq_vec.getData();
   const size_t num_sequences = seq_vec.getSize() - 1;
+  auto w_mat = const_cast<CpuMatrix*>(weight_mat);
+  auto in_mat = const_cast<CpuMatrix*>(input_mat);
   for (size_t i = 0; i < num_sequences; ++i) {
     for (size_t j = 0; j < context_length; ++j) {
       int begin = starts[i] + context_start + j;
@@ -58,8 +40,8 @@ void ContextProjectionForward(Tensor& output,
         int64_t pad_size =
             std::min(starts[i] - begin, starts[i + 1] - starts[i]);
         MatrixPtr mat = out_mat->subMatrix(starts[i], pad_size);
-        if (is_padding && weight_mat) {
-          MatrixPtr sub = weight_mat->subMatrix(j, pad_size);
+        if (w_mat) {
+          MatrixPtr sub = w_mat->subMatrix(j, pad_size);
           mat->addAtOffset(*sub, j * in_mat->getWidth());
         }
         dst_begin = starts[i] + pad_size;
@@ -69,8 +51,8 @@ void ContextProjectionForward(Tensor& output,
         int64_t pad_size =
             std::min(end - starts[i + 1], starts[i + 1] - starts[i]);
         MatrixPtr mat = out_mat->subMatrix(starts[i + 1] - pad_size, pad_size);
-        if (is_padding && weight_mat) {
-          MatrixPtr sub = weight_mat->subMatrix(
+        if (w_mat) {
+          MatrixPtr sub = w_mat->subMatrix(
               begin_pad + context_start + j - pad_size, pad_size);
           mat->addAtOffset(*sub, j * in_mat->getWidth());
         }
@@ -98,7 +80,6 @@ public:
     context_length_ = config.get<size_t>("context_length");
     context_start_ = config.get<int>("context_start");
     begin_pad_ = config.get<size_t>("begin_pad");
-    is_padding_ = config.get<bool>("is_padding");
   }
 
   void calc(const Arguments& inputs,
@@ -108,59 +89,58 @@ public:
     CHECK_EQ(1, outputs.size());
    CHECK_EQ(0, inouts.size());
 
-    ContextProjectionForward((Tensor&)outputs[0],
-                             inputs[0],
-                             inputs[1],
-                             inputs[2],
+    CHECK(outputs[0].getData() && inputs[0].getData() && inputs[2].getData());
+    CHECK_EQ(outputs[0].dims_.size(), 2);
+    CHECK_EQ(inputs[0].dims_.size(), 2);
+    CHECK_EQ(inputs[1].dims_.size(), 2);
+    CHECK_EQ(inputs[2].dims_.size(), 1);
+    /// dim of output = dim of input * context_length
+    CHECK_EQ(outputs[0].dims_[1], inputs[0].dims_[1] * context_length_);
+    /// dim of input == dim of weight
+    CHECK_EQ(inputs[0].dims_[1], inputs[1].dims_[1]);
+    /// input and output has the same batch_size
+    CHECK_EQ(inputs[0].dims_[0], outputs[0].dims_[0]);
+
+    auto out_mat = std::make_shared<typename MatrixT<Device>::type>(
+        outputs[0].getData(), outputs[0].dims_[0], outputs[0].dims_[1]);
+    const auto in_mat = std::make_shared<typename MatrixT<Device>::type>(
+        inputs[0].getData(), inputs[0].dims_[0], inputs[0].dims_[1]);
+    const auto w_mat =
+        !inputs[1].getData()
+            ? nullptr
+            : std::make_shared<typename MatrixT<Device>::type>(
+                  inputs[1].getData(), inputs[1].dims_[0], inputs[1].dims_[1]);
+    typename SequenceT<Device>::type seq_vec(
+        inputs[2].dims_[0], reinterpret_cast<int*>(inputs[2].getData()));
+
+    ContextProjectionForward<Device>(out_mat.get(),
+                                     in_mat.get(),
+                                     w_mat.get(),
+                                     seq_vec,
                                      context_length_,
                                      context_start_,
-                                     begin_pad_,
-                                     is_padding_);
+                                     begin_pad_);
   }
 
 private:
   size_t context_length_;
   int context_start_;
   size_t begin_pad_;
-  bool is_padding_;
 };
 
 template <>
-void ContextProjectionBackward<DEVICE_TYPE_CPU>(Tensor& out_grad,
-                                                Tensor& in_grad,
-                                                Tensor& w_grad,
-                                                const Tensor& sequence,
+void ContextProjectionBackward<DEVICE_TYPE_CPU>(CpuMatrix* out_grad_mat,
+                                                CpuMatrix* in_grad_mat,
+                                                CpuMatrix* w_grad_mat,
+                                                const CpuIVector& seq_vec,
                                                 size_t context_length,
                                                 int context_start,
                                                 size_t begin_pad,
                                                 bool is_padding,
                                                 size_t total_pad) {
-  CHECK(out_grad.getData() && sequence.getData());
-  CHECK_EQ(out_grad.dims_.size(), 2);
-  CHECK_EQ(in_grad.dims_.size(), 2);
-  CHECK_EQ(w_grad.dims_.size(), 2);
-  CHECK_EQ(sequence.dims_.size(), 1);
-
-  auto out_grad_mat = std::make_shared<CpuMatrix>(
-      out_grad.getData(), out_grad.dims_[0], out_grad.dims_[1]);
-  const auto in_grad_mat =
-      !in_grad.getData()
-          ? nullptr
-          : std::make_shared<CpuMatrix>(
-                in_grad.getData(), in_grad.dims_[0], in_grad.dims_[1]);
-  const auto w_grad_mat =
-      !w_grad.getData()
-          ? nullptr
-          : std::make_shared<CpuMatrix>(
-                w_grad.getData(), w_grad.dims_[0], w_grad.dims_[1]);
-  CpuIVector seq_vec(sequence.dims_[0],
-                     reinterpret_cast<int*>(sequence.getData()));
-  CHECK_EQ(out_grad_mat->getWidth(), in_grad_mat->getWidth() * context_length);
-
+  CHECK(out_grad_mat);
   size_t input_dim = in_grad_mat ? in_grad_mat->getWidth()
                                  : w_grad_mat ? w_grad_mat->getWidth() : 0;
-  CHECK_EQ(out_grad_mat->getWidth(), input_dim * context_length);
-
   const int* starts = seq_vec.getData();
   size_t num_sequences = seq_vec.getSize() - 1;
   for (size_t i = 0; i < num_sequences; ++i) {
@@ -226,10 +206,38 @@ public:
     CHECK_EQ(1, outputs.size());
     CHECK_EQ(0, inouts.size());
 
-    ContextProjectionBackward((Tensor&)outputs[0],
-                              (Tensor&)inputs[0],
-                              (Tensor&)inputs[1],
-                              inputs[2],
+    CHECK(outputs[0].getData() && inputs[2].getData());
+    CHECK_EQ(outputs[0].dims_.size(), 2);
+    CHECK_EQ(inputs[0].dims_.size(), 2);
+    CHECK_EQ(inputs[1].dims_.size(), 2);
+    CHECK_EQ(inputs[2].dims_.size(), 1);
+
+    /// dim of input == dim of weight
+    CHECK_EQ(inputs[0].dims_[1], inputs[1].dims_[1]);
+    /// input and output has the same batch_size
+    CHECK_EQ(inputs[0].dims_[0], outputs[0].dims_[0]);
+    /// dim of output = dim of input * context_length
+    CHECK_EQ(outputs[0].dims_[1], inputs[0].dims_[1] * context_length_);
+
+    auto out_grad_mat = std::make_shared<typename MatrixT<Device>::type>(
+        outputs[0].getData(), outputs[0].dims_[0], outputs[0].dims_[1]);
+    auto in_grad_mat =
+        !inputs[0].getData()
+            ? nullptr
+            : std::make_shared<typename MatrixT<Device>::type>(
+                  inputs[0].getData(), inputs[0].dims_[0], inputs[0].dims_[1]);
+    auto w_grad_mat =
+        !inputs[1].getData()
+            ? nullptr
+            : std::make_shared<typename MatrixT<Device>::type>(
+                  inputs[1].getData(), inputs[1].dims_[0], inputs[1].dims_[1]);
+    typename SequenceT<Device>::type seq_vec(
+        inputs[2].dims_[0], reinterpret_cast<int*>(inputs[2].getData()));
+
+    ContextProjectionBackward<Device>(out_grad_mat.get(),
+                                      in_grad_mat ? in_grad_mat.get() : nullptr,
+                                      w_grad_mat ? w_grad_mat.get() : nullptr,
+                                      seq_vec,
                                       context_length_,
                                       context_start_,
                                       begin_pad_,
@@ -264,10 +272,24 @@ public:
     CHECK_EQ(2, inputs.size());
     CHECK_EQ(1, outputs.size());
     CHECK_EQ(0, inouts.size());
+    CHECK(inputs[0].getData() && outputs[0].getData() && inputs[1].getData());
+    CHECK_EQ(outputs[0].dims_.size(), 2);
+    CHECK_EQ(inputs[0].dims_.size(), 2);
+    CHECK_EQ(inputs[1].dims_.size(), 1);
+    CHECK_EQ(outputs[0].dims_[1], inputs[0].dims_[1] * context_length_);
+    /// input and output has the same batch_size
+    CHECK_EQ(inputs[0].dims_[0], outputs[0].dims_[0]);
 
-    ContextProjectionBackwardData((Tensor&)outputs[0],
-                                  (Tensor&)inputs[0],
-                                  inputs[1],
+    auto out_grad_mat = std::make_shared<typename MatrixT<Device>::type>(
+        outputs[0].getData(), outputs[0].dims_[0], outputs[0].dims_[1]);
+    const auto in_grad_mat = std::make_shared<typename MatrixT<Device>::type>(
+        inputs[0].getData(), inputs[0].dims_[0], inputs[0].dims_[1]);
+    typename SequenceT<Device>::type seq_vec(
+        inputs[1].dims_[0], reinterpret_cast<int*>(inputs[1].getData()));
+
+    ContextProjectionBackwardData<Device>(out_grad_mat.get(),
+                                          in_grad_mat.get(),
+                                          seq_vec,
                                           context_length_,
                                           context_start_);
   }
@@ -299,9 +321,22 @@ public:
     CHECK_EQ(1, outputs.size());
     CHECK_EQ(0, inouts.size());
 
-    ContextProjectionBackwardWeight((Tensor&)outputs[0],
-                                    (Tensor&)inputs[0],
-                                    inputs[1],
+    CHECK(inputs[0].getData() && outputs[0].getData() && inputs[1].getData());
+    CHECK_EQ(outputs[0].dims_.size(), 2);
+    CHECK_EQ(inputs[0].dims_.size(), 2);
+    CHECK_EQ(inputs[1].dims_.size(), 1);
+    CHECK_EQ(outputs[0].dims_[1], inputs[0].dims_[1] * context_length_);
+
+    auto out_grad_mat = std::make_shared<typename MatrixT<Device>::type>(
+        outputs[0].getData(), outputs[0].dims_[0], outputs[0].dims_[1]);
+    auto w_grad_mat = std::make_shared<typename MatrixT<Device>::type>(
+        inputs[0].getData(), inputs[0].dims_[0], inputs[0].dims_[1]);
+    typename SequenceT<Device>::type seq_vec(
+        inputs[1].dims_[0], reinterpret_cast<int*>(inputs[1].getData()));
+
+    ContextProjectionBackwardWeight<Device>(out_grad_mat.get(),
+                                            w_grad_mat.get(),
+                                            seq_vec,
                                             context_length_,
                                             context_start_,
                                             total_pad_,
diff --git a/paddle/function/ContextProjectionOp.h b/paddle/function/ContextProjectionOp.h
index e0f1beb496..93eb050fde 100644
--- a/paddle/function/ContextProjectionOp.h
+++ b/paddle/function/ContextProjectionOp.h
@@ -32,14 +32,13 @@ namespace paddle {
  *
  */
 template <DeviceType Device>
-void ContextProjectionForward(Tensor& output,
-                              const Tensor& input,
-                              const Tensor& weight,
-                              const Tensor& sequence,
+void ContextProjectionForward(typename MatrixT<Device>::type* output,
+                              const typename MatrixT<Device>::type* input,
+                              const typename MatrixT<Device>::type* weight,
+                              const typename SequenceT<Device>::type& sequence,
                               size_t context_length,
                               int context_start,
-                              size_t begin_pad,
-                              bool is_padding);
+                              size_t begin_pad);
 
 /**
  * \brief Context Projection Backward.
@@ -55,10 +54,10 @@ void ContextProjectionForward(Tensor& output,
  *
  */
 template <DeviceType Device>
-void ContextProjectionBackward(Tensor& out_grad,
-                               Tensor& in_grad,
-                               Tensor& w_grad,
-                               const Tensor& sequence,
+void ContextProjectionBackward(typename MatrixT<Device>::type* out_grad,
+                               typename MatrixT<Device>::type* in_grad,
+                               typename MatrixT<Device>::type* w_grad,
+                               const typename SequenceT<Device>::type& seq_vec,
                                size_t context_length,
                                int context_start,
                                size_t begin_pad,
@@ -66,19 +65,21 @@ void ContextProjectionBackward(Tensor& out_grad,
                                size_t total_pad);
 
 template <DeviceType Device>
-void ContextProjectionBackwardData(Tensor& out_grad,
-                                   Tensor& in_grad,
-                                   const Tensor& sequence,
-                                   size_t context_length,
-                                   int context_start);
+void ContextProjectionBackwardData(
+    typename MatrixT<Device>::type* out_grad,
+    typename MatrixT<Device>::type* in_grad,
+    const typename SequenceT<Device>::type& sequence,
+    size_t context_length,
+    int context_start);
 
 template <DeviceType Device>
-void ContextProjectionBackwardWeight(Tensor& out_grad,
-                                     Tensor& w_grad,
-                                     const Tensor& sequence,
-                                     size_t context_length,
-                                     int context_start,
-                                     size_t total_pad,
-                                     size_t begin_pad);
+void ContextProjectionBackwardWeight(
+    typename MatrixT<Device>::type* out_grad,
+    typename MatrixT<Device>::type* w_grad,
+    const typename SequenceT<Device>::type& seq_vec,
+    size_t context_length,
+    int context_start,
+    size_t total_pad,
+    size_t begin_pad);
 
 }  // namespace paddle
diff --git a/paddle/function/ContextProjectionOpGpu.cu b/paddle/function/ContextProjectionOpGpu.cu
index 1e5916002c..7c4ebacdbf 100644
--- a/paddle/function/ContextProjectionOpGpu.cu
+++ b/paddle/function/ContextProjectionOpGpu.cu
@@ -75,18 +75,16 @@ __global__ void KeContextProjectionForward(const real* input,
 
 void hl_context_projection_forward(const real* input,
                                    const int* sequence,
-                                   real* weight,
+                                   const real* weight,
                                    real* output,
                                    int num_sequences,
                                    int input_dim,
                                    int context_length,
                                    int context_start,
-                                   int begin_pad,
-                                   bool is_padding) {
+                                   int begin_pad) {
   CHECK_NOTNULL(input);
   CHECK_NOTNULL(sequence);
   CHECK_NOTNULL(output);
-  CHECK(!is_padding || weight);
 
   int block_size = 128;
   int blocks_x = num_sequences;
@@ -94,7 +92,7 @@ void hl_context_projection_forward(const real* input,
   dim3 threads(block_size, 1);
   dim3 grid(blocks_x, blocks_y);
 
-  if (is_padding) {
+  if (weight) {
     KeContextProjectionForward<true><<< grid, threads, 0, STREAM_DEFAULT >>>
       (input, sequence, weight, output, input_dim,
        context_length, context_start, begin_pad);
@@ -107,31 +105,23 @@ void hl_context_projection_forward(const real* input,
 }
 
 template <>
-void ContextProjectionForward<DEVICE_TYPE_GPU>(Tensor& output,
-                                               const Tensor& input,
-                                               const Tensor& weight,
-                                               const Tensor& sequence,
+void ContextProjectionForward<DEVICE_TYPE_GPU>(GpuMatrix* output,
+                                               const GpuMatrix* input,
+                                               const GpuMatrix* weight,
+                                               const GpuIVector& sequence,
                                                size_t context_length,
                                                int context_start,
-                                               size_t begin_pad,
-                                               bool is_padding) {
-  CHECK(output.getData() && input.getData() && sequence.getData());
-  CHECK_EQ(output.dims_.size(), 2);
-  CHECK_EQ(input.dims_.size(), 2);
-  CHECK_EQ(weight.dims_.size(), 2);
-  CHECK_EQ(sequence.dims_.size(), 1);
-  CHECK_EQ(output.dims_[1], input.dims_[1] * context_length);
-
-  hl_context_projection_forward(input.getData(),
-                                reinterpret_cast<int*>(sequence.getData()),
-                                weight.getData(),
-                                output.getData(),
-                                sequence.dims_[0] - 1,
-                                input.dims_[1],
+                                               size_t begin_pad) {
+  CHECK(input && output);
+  hl_context_projection_forward(input->getData(),
+                                sequence.getData(),
+                                weight ? weight->getData() : nullptr,
+                                output->getData(),
+                                sequence.getSize() - 1,
+                                input->getWidth(),
                                 context_length,
                                 context_start,
-                                begin_pad,
-                                is_padding);
+                                begin_pad);
 }
 
 __global__ void KeContextProjectionBackwardData(real* out_grad,
@@ -200,22 +190,17 @@ void hl_context_projection_backward_data(real* out_grad,
 }
 
 template <>
-void ContextProjectionBackwardData<DEVICE_TYPE_GPU>(Tensor& out_grad,
-                                                     Tensor& in_grad,
-                                                     const Tensor& sequence,
-                                                     size_t context_length,
-                                                     int context_start) {
-  CHECK(in_grad.getData() && out_grad.getData() && sequence.getData());
-  CHECK_EQ(out_grad.dims_.size(), 2);
-  CHECK_EQ(in_grad.dims_.size(), 2);
-  CHECK_EQ(sequence.dims_.size(), 1);
-  CHECK_EQ(out_grad.dims_[1], in_grad.dims_[1] * context_length);
-
-  hl_context_projection_backward_data(out_grad.getData(),
-                                      reinterpret_cast<int*>(sequence.getData()),
-                                      in_grad.getData(),
-                                      sequence.dims_[0] - 1,
-                                      in_grad.dims_[1],
+void ContextProjectionBackwardData<DEVICE_TYPE_GPU>(GpuMatrix* out_grad,
+                                                     GpuMatrix* in_grad,
+                                                     const GpuIVector& sequence,
+                                                     size_t context_length,
+                                                     int context_start) {
+  CHECK(in_grad && out_grad);
+  hl_context_projection_backward_data(out_grad->getData(),
+                                      sequence.getData(),
+                                      in_grad->getData(),
+                                      sequence.getSize() - 1,
+                                      in_grad->getWidth(),
                                       context_length,
                                       context_start);
 }
@@ -320,24 +305,20 @@ void hl_context_projection_backward_weight(real* out_grad,
 }
 
 template <>
-void ContextProjectionBackwardWeight<DEVICE_TYPE_GPU>(Tensor& out_grad,
-                                                      Tensor& w_grad,
-                                                      const Tensor& sequence,
-                                                      size_t context_length,
-                                                      int context_start,
-                                                      size_t total_pad,
-                                                      size_t begin_pad) {
-  CHECK(w_grad.getData() && out_grad.getData() && sequence.getData());
-  CHECK_EQ(out_grad.dims_.size(), 2);
-  CHECK_EQ(w_grad.dims_.size(), 2);
-  CHECK_EQ(sequence.dims_.size(), 1);
-  CHECK_EQ(out_grad.dims_[1], w_grad.dims_[1] * context_length);
-
-  hl_context_projection_backward_weight(out_grad.getData(),
-                                        reinterpret_cast<int*>(sequence.getData()),
-                                        w_grad.getData(),
-                                        sequence.dims_[0] - 1,
-                                        w_grad.dims_[1],
+void ContextProjectionBackwardWeight<DEVICE_TYPE_GPU>(
+    GpuMatrix* out_grad,
+    GpuMatrix* w_grad,
+    const GpuIVector& seq_vec,
+    size_t context_length,
+    int context_start,
+    size_t total_pad,
+    size_t begin_pad) {
+  CHECK(out_grad && w_grad);
+  hl_context_projection_backward_weight(out_grad->getData(),
+                                        seq_vec.getData(),
+                                        w_grad->getData(),
+                                        seq_vec.getSize() - 1,
+                                        w_grad->getWidth(),
                                         total_pad,
                                         context_length,
                                         context_start,
@@ -345,24 +326,27 @@ void ContextProjectionBackwardWeight(Tensor& out_grad,
 }
 
 template <>
-void ContextProjectionBackward<DEVICE_TYPE_GPU>(Tensor& out_grad,
-                                                Tensor& in_grad,
-                                                Tensor& w_grad,
-                                                const Tensor& sequence,
-                                                size_t context_length,
-                                                int context_start,
-                                                size_t begin_pad,
-                                                bool is_padding,
-                                                size_t total_pad) {
-  if (in_grad.getData()) {
-    ContextProjectionBackwardData(out_grad,
+void ContextProjectionBackward<DEVICE_TYPE_GPU>(GpuMatrix* out_grad,
+                                                GpuMatrix* in_grad,
+                                                GpuMatrix* w_grad,
+                                                const GpuIVector& sequence,
+                                                size_t context_length,
+                                                int context_start,
+                                                size_t begin_pad,
+                                                bool is_padding,
+                                                size_t total_pad) {
+  CHECK(out_grad);
+  if (in_grad) {
+    ContextProjectionBackwardData<DEVICE_TYPE_GPU>(
+        out_grad,
         in_grad,
         sequence,
         context_length,
         context_start);
   }
-  if (is_padding && w_grad.getData()) {
-    ContextProjectionBackwardWeight(out_grad,
+  if (is_padding && w_grad) {
+    ContextProjectionBackwardWeight<DEVICE_TYPE_GPU>(
+        out_grad,
         w_grad,
         sequence,
         context_length,
diff --git a/paddle/function/ContextProjectionOpTest.cpp b/paddle/function/ContextProjectionOpTest.cpp
index 372fc21cf1..359428fc03 100644
--- a/paddle/function/ContextProjectionOpTest.cpp
+++ b/paddle/function/ContextProjectionOpTest.cpp
@@ -32,8 +32,7 @@ void testMatrixProjectionForward(int context_start,
                 FuncConfig()
                     .set("context_length", context_length)
                     .set("context_start", context_start)
-                    .set("begin_pad", std::max(0, -context_start))
-                    .set("is_padding", is_padding));
+                    .set("begin_pad", std::max(0, -context_start)));
 
   CpuMatrix cpu_in(batch_size, input_dim);
   cpu_in.randomizeUniform();
diff --git a/paddle/function/Function.h b/paddle/function/Function.h
index 210eba1301..9e8cbb8e48 100644
--- a/paddle/function/Function.h
+++ b/paddle/function/Function.h
@@ -40,6 +40,19 @@ struct MatrixT<DEVICE_TYPE_GPU> {
   using type = GpuMatrix;
 };
 
+template <DeviceType Device>
+struct SequenceT;
+
+template <>
+struct SequenceT<DEVICE_TYPE_CPU> {
+  using type = CpuIVector;
+};
+
+template <>
+struct SequenceT<DEVICE_TYPE_GPU> {
+  using type = GpuIVector;
+};
+
 typedef std::vector<size_t> Dims;
 
 class Tensor {
diff --git a/paddle/gserver/layers/ContextProjection.cpp b/paddle/gserver/layers/ContextProjection.cpp
index 37e951a1e3..e947b2b9ec 100644
--- a/paddle/gserver/layers/ContextProjection.cpp
+++ b/paddle/gserver/layers/ContextProjection.cpp
@@ -53,8 +53,7 @@ bool ContextProjection::init() {
                  FuncConfig()
                      .set("context_length", context_length)
                      .set("context_start", context_start)
-                     .set("begin_pad", beginPad_)
-                     .set("is_padding", is_padding));
+                     .set("begin_pad", beginPad_));
   createFunction(backward_,
                  "ContextProjectionBackward",
                  FuncConfig()
@@ -112,7 +111,7 @@ void ContextProjection::forward() {
   size_t dim = out_->value->getWidth();
   CHECK_EQ(dim, input_dim * config_.context_length());
   size_t batch_size = in_->value->getHeight();
-  CHECK_EQ(batch_size, out_->value->getHeight());
+  CHECK_EQ(forward_.size(), 1) << "Only one forward function here";
 
   REGISTER_TIMER_INFO("ContextProjectionForward", getName().c_str());
   bool is_padding = config_.trainable_padding();
@@ -120,12 +119,6 @@ void ContextProjection::forward() {
   auto w_ptr =
       state_ ? state_.get() : is_padding ? weight_->getW().get() : nullptr;
   auto start_pos = in_->sequenceStartPositions;
-  /// if use state_ as weight_, w_ptr already has mem, so padding true
-  forward_[0]->init(FuncConfig()
-                        .set("context_length", config_.context_length())
-                        .set("context_start", config_.context_start())
-                        .set("begin_pad", beginPad_)
-                        .set("is_padding", state_ ? true : is_padding));
   forward_[0]->calc({Tensor(in_->value->getData(), Dims{batch_size, input_dim}),
                      Tensor(w_ptr ? w_ptr->getData() : nullptr,
                             Dims{w_ptr ? w_ptr->getHeight() : 0, input_dim}),
@@ -161,6 +154,7 @@ void ContextProjection::backward(const UpdateCallback& callback) {
   CHECK_EQ(dim, input_dim * config_.context_length());
   size_t batch_size = in_->value->getHeight();
   CHECK_EQ(batch_size, out_->value->getHeight());
+  CHECK_EQ(backward_.size(), 1) << "Only one backward function here";
 
   REGISTER_TIMER_INFO("ContextProjectionBackward", getName().c_str());
   bool is_padding = config_.trainable_padding();
-- 
GitLab
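
Usage note (illustrative, not part of the patch): the sketch below shows how the reworked CPU entry point is called after this change. The sizes, the single-sequence start positions, and the choice to pass a null padding weight are assumed example values, not taken from the commit.

// Minimal sketch of the new Matrix/Vector interface, assuming example sizes.
#include "paddle/function/ContextProjectionOp.h"
#include "paddle/math/Matrix.h"
#include "paddle/math/Vector.h"

void contextProjectionForwardSketch() {
  using namespace paddle;
  const size_t batch_size = 10;       // assumed
  const size_t input_dim = 16;        // assumed
  const size_t context_length = 3;    // assumed
  const int context_start = -1;       // assumed
  const size_t begin_pad = 1;         // = max(0, -context_start)

  CpuMatrix in(batch_size, input_dim);
  CpuMatrix out(batch_size, input_dim * context_length);
  in.randomizeUniform();
  out.zeroMem();

  // Sequence start positions for one sequence spanning the whole batch.
  CpuIVector seq(2);
  seq.getData()[0] = 0;
  seq.getData()[1] = static_cast<int>(batch_size);

  // Matrices/vectors are now passed directly; a null weight means
  // "no trainable padding", replacing the removed is_padding flag.
  ContextProjectionForward<DEVICE_TYPE_CPU>(
      &out, &in, /* weight */ nullptr, seq,
      context_length, context_start, begin_pad);
}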