From 590ecba30576f166579ac49e4fa042af9e191fd1 Mon Sep 17 00:00:00 2001 From: xutianbing Date: Tue, 27 Dec 2016 15:40:35 -0800 Subject: [PATCH] add ContextProjectionBackward, ContextProjectionBackwardData, ContextProjectionBackwardWeightw --- paddle/function/CMakeLists.txt | 1 - paddle/function/context_projection_op.cpp | 197 +++++++++++++++- paddle/function/context_projection_op.h | 46 +++- paddle/function/context_projection_op_gpu.cu | 210 ++++++++++++++++++ .../function/context_projection_op_test.cpp | 100 ++++++++- 5 files changed, 548 insertions(+), 6 deletions(-) diff --git a/paddle/function/CMakeLists.txt b/paddle/function/CMakeLists.txt index f70ae9959b0..0b3126155d0 100644 --- a/paddle/function/CMakeLists.txt +++ b/paddle/function/CMakeLists.txt @@ -19,7 +19,6 @@ if(WITH_TESTING) add_simple_unittest(CrossMapNormalOpTest) add_unittest(ContextProjectionOpTest ContextProjectionOpTest.cpp - ContextProjectionOpGpu.cu ../gserver/tests/TestUtil.cpp) endif() endif() diff --git a/paddle/function/context_projection_op.cpp b/paddle/function/context_projection_op.cpp index 75c41eed141..a6a85fb6a46 100644 --- a/paddle/function/context_projection_op.cpp +++ b/paddle/function/context_projection_op.cpp @@ -41,7 +41,7 @@ void ContextProjectionForward(Tensor& output, !weight.getData() ? 
nullptr : std::make_shared( - weight.getData(), weight.dims_[0], input.dims_[1]); + weight.getData(), weight.dims_[0], weight.dims_[1]); CpuIVector seq_vec(sequence.dims_[0], reinterpret_cast(sequence.getData())); CHECK_EQ(out_mat->getWidth(), in_mat->getWidth() * context_length); @@ -125,12 +125,207 @@ private: bool is_padding_; }; +template <> +void ContextProjectionBackward(Tensor& out_grad, + const Tensor& in_grad, + const Tensor& w_grad, + const Tensor& sequence, + size_t context_length, + int context_start, + size_t begin_pad, + bool is_padding) { + CHECK(out_grad.getData() && sequence.getData()); + CHECK_EQ(out_grad.dims_.size(), 2); + CHECK_EQ(in_grad.dims_.size(), 2); + CHECK_EQ(w_grad.dims_.size(), 2); + CHECK_EQ(sequence.dims_.size(), 1); + + auto out_grad_mat = std::make_shared( + out_grad.getData(), out_grad.dims_[0], out_grad.dims_[1]); + const auto in_grad_mat = + !in_grad.getData() + ? nullptr + : std::make_shared( + in_grad.getData(), in_grad.dims_[0], in_grad.dims_[1]); + const auto w_grad_mat = + !w_grad.getData() + ? nullptr + : std::make_shared( + w_grad.getData(), w_grad.dims_[0], w_grad.dims_[1]); + CpuIVector seq_vec(sequence.dims_[0], + reinterpret_cast(sequence.getData())); + CHECK_EQ(out_grad_mat->getWidth(), in_grad_mat->getWidth() * context_length); + + size_t input_dim = in_grad_mat ? in_grad_mat->getWidth() + : w_grad_mat ? 
w_grad_mat->getWidth() : 0; + CHECK_EQ(out_grad_mat->getWidth(), input_dim * context_length); + + const int* starts = seq_vec.getData(); + size_t num_sequences = seq_vec.getSize() - 1; + for (size_t i = 0; i < num_sequences; ++i) { + for (size_t j = 0; j < context_length; ++j) { + int begin = starts[i] + context_start + j; + int end = starts[i + 1] + context_start + j; + int dst_begin = starts[i]; + int dst_end = starts[i + 1]; + if (begin < starts[i]) { + int64_t pad_size = + std::min(starts[i] - begin, starts[i + 1] - starts[i]); + if (is_padding && w_grad_mat) { + MatrixPtr mat = out_grad_mat->subMatrix(starts[i], pad_size); + MatrixPtr sub = w_grad_mat->subMatrix(j, pad_size); + sub->addAtOffset(*mat, j * input_dim); + } + dst_begin = starts[i] + pad_size; + begin = starts[i]; + } + if (end > starts[i + 1]) { + int64_t pad_size = + std::min(end - starts[i + 1], starts[i + 1] - starts[i]); + if (is_padding && w_grad_mat) { + MatrixPtr mat = + out_grad_mat->subMatrix(starts[i + 1] - pad_size, pad_size); + MatrixPtr sub = w_grad_mat->subMatrix( + begin_pad + context_start + j - pad_size, pad_size); + sub->addAtOffset(*mat, j * input_dim); + } + dst_end = starts[i + 1] - pad_size; + end = starts[i + 1]; + } + if (end <= begin) continue; + if (!in_grad_mat) continue; + MatrixPtr src = in_grad_mat->subMatrix(begin, end - begin); + MatrixPtr dst = out_grad_mat->subMatrix(dst_begin, dst_end - dst_begin); + src->addAtOffset(*dst, j * input_dim); + } + } +} + +/** + * \param inputs[0] input gradient. + * \param inputs[1] weight gradient. + * \param inputs[2] input sequence. + * \param outputs[0] output gradient. 
+ */ +template +class ContextProjectionBackwardFunc : public FunctionBase { +public: + void init(const FuncConfig& config) override { + context_length_ = config.get("context_length"); + context_start_ = config.get("context_start"); + begin_pad_ = config.get("begin_pad"); + is_padding_ = config.get("is_padding"); + } + + void calc(const Arguments& inputs, + const Arguments& outputs, + const Arguments& inouts) override { + CHECK_EQ(3, inputs.size()); + CHECK_EQ(1, outputs.size()); + CHECK_EQ(0, inouts.size()); + + ContextProjectionBackward((Tensor&)outputs[0], + inputs[0], + inputs[1], + inputs[2], + context_length_, + context_start_, + begin_pad_, + is_padding_); + } + +private: + size_t context_length_; + int context_start_; + size_t begin_pad_; + bool is_padding_; +}; + +/** + * \param inputs[0] input grad. + * \param inputs[1] input sequence. + * \param outputs[0] output grad. + */ +template +class ContextProjectionBackwardDataFunc : public FunctionBase { +public: + void init(const FuncConfig& config) override { + context_length_ = config.get("context_length"); + context_start_ = config.get("context_start"); + } + + void calc(const Arguments& inputs, + const Arguments& outputs, + const Arguments& inouts) override { + CHECK_EQ(2, inputs.size()); + CHECK_EQ(1, outputs.size()); + CHECK_EQ(0, inouts.size()); + + ContextProjectionBackwardData((Tensor&)outputs[0], + (Tensor&)inputs[0], + inputs[1], + context_length_, + context_start_); + } + +private: + size_t context_length_; + int context_start_; +}; + +/** + * \param inputs[0] weight grad. + * \param inputs[1] input sequence. + * \param outputs[0] output grad. 
+ */ +template +class ContextProjectionBackwardWeightFunc : public FunctionBase { +public: + void init(const FuncConfig& config) override { + context_length_ = config.get("context_length"); + context_start_ = config.get("context_start"); + begin_pad_ = config.get("begin_pad"); + total_pad_ = config.get("total_pad"); + } + + void calc(const Arguments& inputs, + const Arguments& outputs, + const Arguments& inouts) override { + CHECK_EQ(2, inputs.size()); + CHECK_EQ(1, outputs.size()); + CHECK_EQ(0, inouts.size()); + + ContextProjectionBackwardWeight((Tensor&)outputs[0], + (Tensor&)inputs[0], + inputs[1], + context_length_, + context_start_, + total_pad_, + begin_pad_); + } + +private: + size_t context_length_; + int context_start_; + size_t begin_pad_; + size_t total_pad_; +}; + REGISTER_TYPED_FUNC(ContextProjectionForward, CPU, ContextProjectionForwardFunc); +REGISTER_TYPED_FUNC(ContextProjectionBackward, + CPU, + ContextProjectionBackwardFunc); #ifndef PADDLE_ONLY_CPU REGISTER_TYPED_FUNC(ContextProjectionForward, GPU, ContextProjectionForwardFunc); +REGISTER_TYPED_FUNC(ContextProjectionBackwardData, + GPU, + ContextProjectionBackwardDataFunc); +REGISTER_TYPED_FUNC(ContextProjectionBackwardWeight, + GPU, + ContextProjectionBackwardWeightFunc); #endif } // namespace paddle diff --git a/paddle/function/context_projection_op.h b/paddle/function/context_projection_op.h index bdc5071bc62..5f4e0761db3 100644 --- a/paddle/function/context_projection_op.h +++ b/paddle/function/context_projection_op.h @@ -25,9 +25,10 @@ namespace paddle { * \param[in] input input data. * \param[in] weight input weight. * \param[in] sequence input data. - * \param[in] context_length consecutive rows for concatenation. - * \param[in] begin_pad context start position. - * \param[in] is_padding whether padding 0 or not. + * \param[in] context_length consecutive rows for concatenation. + * \param[in] context_start context start position. + * \param[in] begin_pad beginning pad position. 
+ * \param[in] is_padding whether padding 0 or not. * */ template @@ -40,4 +41,43 @@ void ContextProjectionForward(Tensor& output, size_t begin_pad, bool is_padding); +/** + * \brief Context Projection Backward. + * + * \param[out] outputs output gradient. + * \param[in] input input gradient. + * \param[in] weight input weight gradient. + * \param[in] sequence input data. + * \param[in] context_length consecutive rows for concatenation. + * \param[in] context_start context start position. + * \param[in] begin_pad beginning pad position. + * \param[in] is_padding whether padding 0 or not. + * + */ +template +void ContextProjectionBackward(Tensor& out_grad, + const Tensor& in_grad, + const Tensor& w_grad, + const Tensor& sequence, + size_t context_length, + int context_start, + size_t begin_pad, + bool is_padding); + +template +void ContextProjectionBackwardData(Tensor& out_grad, + Tensor& in_grad, + const Tensor& sequence, + size_t context_length, + int context_start); + +template +void ContextProjectionBackwardWeight(Tensor& out_grad, + Tensor& w_grad, + const Tensor& sequence, + size_t context_length, + int context_start, + size_t total_pad, + size_t begin_pad); + } // namespace paddle diff --git a/paddle/function/context_projection_op_gpu.cu b/paddle/function/context_projection_op_gpu.cu index 4e7958164b0..fdea433d07e 100644 --- a/paddle/function/context_projection_op_gpu.cu +++ b/paddle/function/context_projection_op_gpu.cu @@ -134,4 +134,214 @@ void ContextProjectionForward(Tensor& output, is_padding); } +__global__ void KeContextProjectionBackwardData(real* out_grad, + const int* sequence, + real* in_grad, + int input_dim, + int context_length, + int context_start) { + int idx = threadIdx.x; + int block_size = blockDim.x; + int sequenceId = blockIdx.x; + int seq_start = sequence[sequenceId]; + int seq_end = sequence[sequenceId+1]; + real value = 0; + + int instances = seq_end - seq_start + context_length - 1; + out_grad += seq_start * input_dim * 
context_length; + in_grad += seq_start * input_dim; + for (int k = 0; k <= input_dim / block_size; k++) { + if (idx < input_dim) { + for (int i = 0; i < instances; i++) { + if ((i + context_start) < 0) { + continue; + } else if ((i + context_start) >= (seq_end - seq_start)) { + continue; + } else { + // value = 0; + value = in_grad[(i + context_start) * input_dim + idx]; + } + + int outx = (i - context_length) < 0 ? i : (context_length - 1); + int outy = (i - context_length) < 0 ? 0 : (i - (context_length - 1)); + real* output_r = + out_grad + outy * input_dim * context_length + outx * input_dim; + for (int j = outy; j < seq_end - seq_start; j++) { + value += output_r[idx]; + if (j - outy == outx) break; + output_r += (context_length - 1) * input_dim; + } + in_grad[(i + context_start) * input_dim + idx] = value; + } + } + idx += block_size; + } +} + +void hl_context_projection_backward_data(real* out_grad, + const int* sequence, + real* input_grad, + int num_sequences, + int input_dim, + int context_length, + int context_start) { + CHECK_NOTNULL(out_grad); + CHECK_NOTNULL(sequence); + CHECK_NOTNULL(input_grad); + + int block_size = 128; + int blocks_x = num_sequences; + int blocks_y = 1; + dim3 threads(block_size, 1); + dim3 grid(blocks_x, blocks_y); + KeContextProjectionBackwardData<<< grid, threads, 0, STREAM_DEFAULT >>> + (out_grad, sequence, input_grad, input_dim, context_length, context_start); + CHECK_SYNC("hl_context_projection_backward_data failed"); +} + +template <> +void ContextProjectionBackwardData(Tensor& out_grad, + Tensor& in_grad, + const Tensor& sequence, + size_t context_length, + int context_start) { + CHECK(in_grad.getData() && out_grad.getData() && sequence.getData()); + CHECK_EQ(out_grad.dims_.size(), 2); + CHECK_EQ(in_grad.dims_.size(), 2); + CHECK_EQ(sequence.dims_.size(), 1); + CHECK_EQ(out_grad.dims_[1], in_grad.dims_[1] * context_length); + + hl_context_projection_backward_data(out_grad.getData(), + reinterpret_cast(sequence.getData()), 
+ in_grad.getData(), + sequence.dims_[0] - 1, + in_grad.dims_[1], + context_length, + context_start); +} + +template +__global__ void KeContextProjectionBackwardWeight(real* out_grad, + const int* sequence, + real* w_grad, + int num_sequences, + int w_dim, + int context_length, + int context_start, + int begin_pad) { + __shared__ real sum_s[THREADS_Y][THREADS_X]; + int pad_of_block = (w_dim + THREADS_X - 1) / THREADS_X; + const int idx = threadIdx.x; + const int idy = threadIdx.y; + int padId = blockIdx.x / pad_of_block; + int weight_idx = idx + THREADS_X * (blockIdx.x % pad_of_block); + int instanceId; + real value = 0; + real* output_r; + + sum_s[idy][idx] = 0.0f; + if (weight_idx < w_dim) { + for (int seqId = idy; seqId < num_sequences; seqId += THREADS_Y) { + int seq_start = sequence[seqId]; + int seq_end = sequence[seqId+1]; + output_r = out_grad + seq_start * w_dim * context_length; + + if (context_start < 0) { + if (padId + context_start < 0) { + instanceId = padId; + } else { + // begin_pad > 0; + instanceId = (padId - begin_pad) + + (seq_end - seq_start) - context_start; + } + } else { + if (padId + (seq_end - seq_start) < context_start) { + continue; + } else { + // begin_pad == 0; + instanceId = padId + (seq_end - seq_start) - context_start; + } + } + + int outx = (instanceId - context_length) < 0 ? + instanceId : (context_length - 1); + int outy = (instanceId - context_length) < 0 ? 
+ 0 : (instanceId - (context_length - 1)); + output_r += outy * w_dim * context_length + outx * w_dim; + for (int j = outy; j < seq_end - seq_start; j++) { + value += output_r[weight_idx]; + if (j - outy == outx) break; + output_r += (context_length - 1) * w_dim; + } + } + sum_s[idy][idx] = value; + } + __syncthreads(); + + for (int stride = THREADS_Y/2; stride > 0; stride = stride/2) { + if (idy < stride) { + sum_s[idy][idx] += sum_s[idy + stride][idx]; + } + __syncthreads(); + } + __syncthreads(); + + if (weight_idx < w_dim) { + if (idy == 0) { + w_grad[padId * w_dim + weight_idx] += sum_s[0][idx]; + } + } +} + +void hl_context_projection_backward_weight(real* out_grad, + const int* sequence, + real* w_grad, + int num_sequences, + int w_dim, + size_t total_pad, + int context_length, + int context_start, + int begin_pad) { + CHECK_NOTNULL(out_grad); + CHECK_NOTNULL(sequence); + CHECK_NOTNULL(w_grad); + + int threads_x = 32; + int threads_y = 32; + int blocks_x = total_pad * ((w_dim + threads_x - 1) / threads_x); + dim3 threads(threads_x, threads_y); + dim3 grid(blocks_x, 1); + + KeContextProjectionBackwardWeight<32, 32> + <<< grid, threads, 0, STREAM_DEFAULT >>> + (out_grad, sequence, w_grad, num_sequences, w_dim, + context_length, context_start, begin_pad); + CHECK_SYNC("hl_context_projection_backward_weight failed"); +} + +template <> +void ContextProjectionBackwardWeight(Tensor& out_grad, + Tensor& w_grad, + const Tensor& sequence, + size_t context_length, + int context_start, + size_t total_pad, + size_t begin_pad) { + CHECK(w_grad.getData() && out_grad.getData()); + CHECK_EQ(out_grad.dims_.size(), 2); + CHECK_EQ(w_grad.dims_.size(), 2); + CHECK_EQ(sequence.dims_.size(), 1); + CHECK_EQ(out_grad.dims_[1], w_grad.dims_[1] * context_length); + + hl_context_projection_backward_weight(out_grad.getData(), + reinterpret_cast(sequence.getData()), + w_grad.getData(), + sequence.dims_[0] - 1, + w_grad.dims_[1], + total_pad, + context_length, + context_start, + 
begin_pad); +} + } // namespace paddle diff --git a/paddle/function/context_projection_op_test.cpp b/paddle/function/context_projection_op_test.cpp index 98784471ae9..997bcb1bd2d 100644 --- a/paddle/function/context_projection_op_test.cpp +++ b/paddle/function/context_projection_op_test.cpp @@ -77,7 +77,100 @@ void testMatrixProjectionForward(int context_start, autotest::TensorCheckEqual(cpu_out, gpu_out); } -TEST(ContextProjectionForward, projection) { +void testMatrixProjectionBackward(int context_start, + int context_length, + bool is_padding, + size_t batch_size, + size_t input_dim) { + size_t pad = std::max(0, -context_start) + + std::max(0, (int)(context_start + context_length - 1)); + if (pad == 0) is_padding = false; + + std::shared_ptr cpu_func( + FunctionBase::funcRegistrar_.createByType( + "ContextProjectionBackward-CPU")); + FuncConfig cpu_config; + cpu_config.set("context_length", context_length) + .set("context_start", context_start) + .set("begin_pad", std::max(0, -context_start)) + .set("is_padding", is_padding); + cpu_func->init(cpu_config); + + std::shared_ptr gpu_data_func( + FunctionBase::funcRegistrar_.createByType( + "ContextProjectionBackwardData-GPU")); + FuncConfig gpu_data_config; + gpu_data_config.set("context_length", context_length) + .set("context_start", context_start); + gpu_data_func->init(gpu_data_config); + + std::shared_ptr gpu_w_func( + FunctionBase::funcRegistrar_.createByType( + "ContextProjectionBackwardWeight-GPU")); + FuncConfig gpu_w_config; + gpu_w_config.set("context_length", context_length) + .set("context_start", context_start) + .set("begin_pad", std::max(0, -context_start)) + .set("total_pad", pad); + gpu_w_func->init(gpu_w_config); + + CpuMatrix cpu_in_grad(batch_size, input_dim); + cpu_in_grad.randomizeUniform(); + GpuMatrix gpu_in_grad(batch_size, input_dim); + gpu_in_grad.copyFrom(cpu_in_grad); + + CpuMatrix cpu_out_grad(batch_size, input_dim * context_length); + cpu_out_grad.randomizeUniform(); + GpuMatrix 
gpu_out_grad(batch_size, input_dim * context_length); + gpu_out_grad.copyFrom(cpu_out_grad); + + IVectorPtr cpu_seq; + generateSequenceStartPositions(batch_size, cpu_seq); + IVectorPtr gpu_seq = IVector::create(cpu_seq->getSize(), true); + gpu_seq->copyFrom(*cpu_seq); + + auto cpu_w_grad = + is_padding ? std::make_shared(pad, input_dim) : nullptr; + auto gpu_w_grad = + is_padding ? std::make_shared(pad, input_dim) : nullptr; + if (is_padding) { + cpu_w_grad->randomizeUniform(); + gpu_w_grad->copyFrom(*cpu_w_grad); + } + + cpu_func->calc({Tensor(cpu_in_grad.getData(), Dims{batch_size, input_dim}), + Tensor(cpu_w_grad ? cpu_w_grad->getData() : nullptr, + Dims{pad, input_dim}), + Tensor(reinterpret_cast(cpu_seq->getData()), + Dims{cpu_seq->getSize()})}, + {Tensor(cpu_out_grad.getData(), + Dims{batch_size, input_dim * context_length})}, + {}); + + gpu_data_func->calc( + {Tensor(gpu_in_grad.getData(), Dims{batch_size, input_dim}), + Tensor(reinterpret_cast(gpu_seq->getData()), + Dims{gpu_seq->getSize()})}, + {Tensor(gpu_out_grad.getData(), + Dims{batch_size, input_dim * context_length})}, + {}); + + if (is_padding && gpu_w_grad) { + gpu_w_func->calc({Tensor(gpu_w_grad->getData(), Dims{pad, input_dim}), + Tensor(reinterpret_cast(gpu_seq->getData()), + Dims{gpu_seq->getSize()})}, + {Tensor(gpu_out_grad.getData(), + Dims{batch_size, input_dim * context_length})}, + {}); + } + + autotest::TensorCheckErr(cpu_in_grad, gpu_in_grad); + if (is_padding) { + autotest::TensorCheckErr(*cpu_w_grad, *gpu_w_grad); + } +} + +TEST(ContextProjection, projection) { for (auto context_start : {-5, -3, -1, 0, 3}) { for (auto context_length : {1, 2, 5, 7}) { for (auto trainable_padding : {false, true}) { @@ -93,6 +186,11 @@ TEST(ContextProjectionForward, projection) { trainable_padding, batch_size, input_dim); + testMatrixProjectionBackward(context_start, + context_length, + trainable_padding, + batch_size, + input_dim); } } } -- GitLab