diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt
index 2569535c257c3210c239b69cd464ae59a8f4747c..2412ebd82a02c872e73fd310c56221309441f630 100644
--- a/paddle/operators/CMakeLists.txt
+++ b/paddle/operators/CMakeLists.txt
@@ -149,7 +149,7 @@ op_library(sequence_pool_op DEPS sequence_pooling)
 op_library(lstm_op DEPS sequence2batch lstm_compute)
 op_library(gru_op DEPS sequence2batch gru_compute)
 op_library(recurrent_op DEPS executor)
-op_library(warpctc_op DEPS dynload_warpctc sequence_padding math_function)
+op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale math_function)
 op_library(cos_sim_op DEPS cos_sim_functor)
 op_library(parallel_do_op DEPS executor)
 
diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt
index fd59eef7d650b48feae68c89be54ec4e48cbcc7e..c607704efac86982c8c22e462381aaab488a9b69 100644
--- a/paddle/operators/math/CMakeLists.txt
+++ b/paddle/operators/math/CMakeLists.txt
@@ -13,6 +13,7 @@ if(WITH_GPU)
     nv_library(context_project SRCS context_project.cc context_project.cu DEPS device_context math_function)
     nv_library(sequence2batch SRCS sequence2batch.cc sequence2batch.cu DEPS device_context tensor)
     nv_library(sequence_padding SRCS sequence_padding.cc sequence_padding.cu DEPS lod_tensor device_context)
+    nv_library(sequence_scale SRCS sequence_scale.cc sequence_scale.cu DEPS lod_tensor device_context)
     nv_library(lstm_compute SRCS lstm_compute.cc lstm_compute.cu DEPS device_context activation_functions)
     nv_library(maxouting SRCS maxouting.cc maxouting.cu DEPS device_context)
     nv_library(unpooling SRCS unpooling.cc unpooling.cu DEPS device_context)
@@ -29,6 +30,7 @@ else()
     cc_library(context_project SRCS context_project.cc DEPS device_context math_function)
     cc_library(sequence2batch SRCS sequence2batch.cc DEPS device_context tensor)
     cc_library(sequence_padding SRCS sequence_padding.cc DEPS lod_tensor device_context)
+    cc_library(sequence_scale SRCS sequence_scale.cc DEPS lod_tensor device_context)
     cc_library(lstm_compute SRCS lstm_compute.cc DEPS device_context activation_functions)
     cc_library(maxouting SRCS maxouting.cc DEPS device_context)
     cc_library(unpooling SRCS unpooling.cc DEPS device_context)
diff --git a/paddle/operators/math/sequence_scale.cc b/paddle/operators/math/sequence_scale.cc
new file mode 100644
index 0000000000000000000000000000000000000000..7e439e9a2cebaa5d494b185fd878e293a6895e45
--- /dev/null
+++ b/paddle/operators/math/sequence_scale.cc
@@ -0,0 +1,46 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/math/sequence_scale.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+template <typename T>
+class ScaleLoDTensorFunctor<platform::CPUDeviceContext, T> {
+ public:
+  void operator()(const platform::CPUDeviceContext& context,
+                  framework::LoDTensor& seq, const T* scales) {
+    const size_t level = 0;
+    auto lod = seq.lod();
+    const size_t num_seq = lod[level].size() - 1;
+    size_t seq_width = seq.dims()[1];
+    framework::LoD abs_offset_lod = framework::ToAbsOffset(lod);
+
+    T* seq_data = seq.mutable_data<T>(context.GetPlace());
+    for (size_t i = 0; i < num_seq; ++i) {
+      for (size_t j = lod[level][i] * seq_width;
+           j < lod[level][i + 1] * seq_width; ++j) {
+        seq_data[j] *= scales[i];
+      }
+    }
+  }
+};
+
+template class ScaleLoDTensorFunctor<platform::CPUDeviceContext, float>;
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/operators/math/sequence_scale.cu b/paddle/operators/math/sequence_scale.cu
new file mode 100644
index 0000000000000000000000000000000000000000..ceaabd8e0fd81c927fbd4333c0aa7954b8da8513
--- /dev/null
+++ b/paddle/operators/math/sequence_scale.cu
@@ -0,0 +1,57 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/math/sequence_scale.h"
+#include "paddle/platform/cuda_helper.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+using platform::PADDLE_CUDA_NUM_THREADS;
+
+template <typename T, int BlockSize>
+__global__ void SequenceScaleKernel(T* seq, size_t* lod, const T* scales,
+                                    const size_t seq_width) {
+  for (int i = threadIdx.x;
+       i < (lod[blockIdx.x + 1] - lod[blockIdx.x]) * seq_width;
+       i += BlockSize) {
+    int idx = lod[blockIdx.x] * seq_width + i;
+    seq[idx] *= scales[blockIdx.x];
+  }
+}
+
+template <typename T>
+class ScaleLoDTensorFunctor<platform::CUDADeviceContext, T> {
+ public:
+  void operator()(const platform::CUDADeviceContext& context,
+                  framework::LoDTensor& seq, const T* scales) {
+    const size_t level = 0;
+    auto lod = seq.lod();
+    const size_t num_seq = lod[level].size() - 1;
+    const size_t seq_width = seq.numel() / seq.dims()[0];
+    framework::LoD abs_offset_lod = framework::ToAbsOffset(lod);
+    T* seq_data = seq.mutable_data<T>(context.GetPlace());
+
+    SequenceScaleKernel<T, PADDLE_CUDA_NUM_THREADS><<<
+        num_seq, PADDLE_CUDA_NUM_THREADS, 0, context.stream()>>>(
+        seq_data, abs_offset_lod[level].data(), scales, seq_width);
+  }
+};
+
+template class ScaleLoDTensorFunctor<platform::CUDADeviceContext, float>;
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/operators/math/sequence_scale.h b/paddle/operators/math/sequence_scale.h
new file mode 100644
index 0000000000000000000000000000000000000000..ecd9a57c3f4d8d91bfb8933a0fd38355c227744d
--- /dev/null
+++ b/paddle/operators/math/sequence_scale.h
@@ -0,0 +1,55 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/framework/lod_tensor.h"
+#include "paddle/platform/device_context.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+/*
+ * \brief Scale each sequence of a LoDTensor.
+ *
+ * Every element of the i-th sequence is multiplied in place by scales[i].
+ *
+ * Example:
+ *   Given:
+ *     seq = (s0, s0, s0, s0; s1, s1; s2, s2, s2; s3)
+ *     scales = (2, 3, 4, 5)
+ *   then:
+ *     result = (2*s0, 2*s0, 2*s0, 2*s0; 3*s1, 3*s1; 4*s2, 4*s2, 4*s2; 5*s3)
+ *
+ * \param context  Device context of this functor.
+ * \param seq      LoDTensor which is stored in sequence format, the shape
+ *                 is [total_sequence_length, sequence_width] where
+ *                 total_sequence_length is the sum of all sequences'
+ *                 lengths.
+ * \param scales   Array<T>. The i-th sequence is scaled by scales[i].
+ */
+template <typename DeviceContext, typename T>
+class ScaleLoDTensorFunctor {
+ public:
+  void operator()(const DeviceContext& context, framework::LoDTensor& seq,
+                  const T* scales);
+};
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/operators/warpctc_op.h b/paddle/operators/warpctc_op.h
index 41899c7fe0c3089c4fc7c160c8896dec0e3cd6dd..8aea061c00cc9614db37ed408b6d330ef707d1cf 100644
--- a/paddle/operators/warpctc_op.h
+++ b/paddle/operators/warpctc_op.h
@@ -17,6 +17,7 @@ limitations under the License. */
 #include "paddle/framework/op_registry.h"
 #include "paddle/operators/math/math_function.h"
 #include "paddle/operators/math/sequence_padding.h"
+#include "paddle/operators/math/sequence_scale.h"
 #include "paddle/platform/dynload/warpctc.h"
 
 namespace paddle {
@@ -178,11 +179,14 @@ class WarpCTCKernel : public framework::OpKernel<T> {
     T* warpctc_grad_data =
         warpctc_grad->mutable_data<T>(warpctc_logits.dims(), ctx.GetPlace());
 
+    math::SetConstant<DeviceContext, T>()(
+        ctx.template device_context<DeviceContext>(), warpctc_grad,
+        static_cast<T>(0));
+
     // warpctc accesses labels in CPU memory
     Tensor warpctc_label;
     Copy(*label, platform::CPUPlace(), ctx.device_context(), &warpctc_label);
     const int* warpctc_label_data = warpctc_label.data<int>();
-
     // warpctc stores loss in CPU memory
     Tensor warpctc_loss;
     T* warpctc_loss_data =
@@ -206,11 +210,18 @@ class WarpCTCGradKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto* warpctc_grad = ctx.Input<LoDTensor>("WarpCTCGrad");
     auto* logits_grad = ctx.Output<LoDTensor>(framework::GradVarName("Logits"));
+    const Tensor* loss_grad = ctx.Input<Tensor>(framework::GradVarName("Loss"));
+
     logits_grad->mutable_data<T>(ctx.GetPlace());
     bool norm_by_times = ctx.Attr<bool>("norm_by_times");
     math::UnpaddingLoDTensorFunctor<DeviceContext, T>()(
         ctx.template device_context<DeviceContext>(), *logits_grad,
         *warpctc_grad, norm_by_times);
+
+    const T* loss_grad_data = loss_grad->data<T>();
+    math::ScaleLoDTensorFunctor<DeviceContext, T>()(
+        ctx.template device_context<DeviceContext>(), *logits_grad,
+        loss_grad_data);
   }
 };
 
diff --git a/python/paddle/v2/fluid/tests/test_warpctc_op.py b/python/paddle/v2/fluid/tests/test_warpctc_op.py
index 272e52c982a8ee75c1ccb0398cfb5abeab28ba64..9f565676c5af1685704681758a8590ecb6f59026 100644
--- a/python/paddle/v2/fluid/tests/test_warpctc_op.py
+++ b/python/paddle/v2/fluid/tests/test_warpctc_op.py
@@ -17,6 +17,8 @@ import numpy as np
 from op_test import OpTest
 from test_softmax_op import stable_softmax
 
+CUDA_BLOCK_SIZE = 512
+
 
 class CTCForward(object):
     def __init__(self, softmax, softmax_lod, labels, labels_lod, blank,
@@ -167,47 +169,63 @@ class CTCForward(object):
 
 
 class TestWarpCTCOp(OpTest):
+    def config(self):
+        self.batch_size = 4
+        self.num_classes = 8
+        self.logits_lod = [[0, 4, 5, 8, 11]]
+        self.labels_lod = [[0, 3, 4, 8, 12]]
+        self.blank = self.num_classes - 1
+        self.norm_by_times = False
+
     def setUp(self):
         self.op_type = "warpctc"
+        self.config()
 
-        batch_size = 4
-        num_classes = 8
-        logits_lod = [[0, 4, 5, 8, 11]]
-        logits = np.random.uniform(0.1, 1.0,
-                                   [11, num_classes]).astype("float32")
+        logits = np.random.uniform(
+            0.1, 1.0,
+            [self.logits_lod[0][-1], self.num_classes]).astype("float32")
         softmax = np.apply_along_axis(stable_softmax, 1, logits)
-        labels_lod = [[0, 3, 4, 8, 12]]
         # labels should not be blank
-        labels = np.random.randint(0, num_classes - 1, [12, 1], dtype="int32")
-
-        blank = num_classes - 1
-        norm_by_times = False
+        labels = np.random.randint(
+            0, self.num_classes - 1, [self.labels_lod[0][-1], 1], dtype="int32")
 
-        ctc = CTCForward(softmax, logits_lod, labels, labels_lod, blank,
-                         norm_by_times)
+        ctc = CTCForward(softmax, self.logits_lod, labels, self.labels_lod,
+                         self.blank, self.norm_by_times)
         loss = ctc.forward()
 
         max_sequence_length = 0
-        for i in range(batch_size):
-            max_sequence_length = max(max_sequence_length,
-                                      logits_lod[0][i + 1] - logits_lod[0][i])
-        gradient = np.zeros(
-            [max_sequence_length, batch_size, num_classes], dtype="float32")
+        for i in range(self.batch_size):
+            max_sequence_length = max(
+                max_sequence_length,
+                self.logits_lod[0][i + 1] - self.logits_lod[0][i])
+        self.gradient = np.zeros(
+            [max_sequence_length, self.batch_size, self.num_classes],
+            dtype="float32")
 
         self.inputs = {
-            "Logits": (logits, logits_lod),
-            "Label": (labels, labels_lod)
+            "Logits": (logits, self.logits_lod),
+            "Label": (labels, self.labels_lod)
         }
         self.outputs = {"Loss": loss}
-        self.attrs = {"blank": blank, "norm_by_times": norm_by_times}
+        self.attrs = {"blank": self.blank, "norm_by_times": self.norm_by_times}
 
     def test_check_output(self):
         self.check_output()
 
+    def test_check_grad(self):
+        self.outputs['WarpCTCGrad'] = self.gradient
+        self.check_grad(["Logits"], "Loss", max_relative_error=0.007)
+
+
+class TestWarpCTCOpCase1(TestWarpCTCOp):
+    def config(self):
+        self.batch_size = 4
+        self.num_classes = CUDA_BLOCK_SIZE + 2
+        self.logits_lod = [[0, 4, 5, 8, 11]]
+        self.labels_lod = [[0, 3, 4, 8, 12]]
+        self.blank = 0
+        self.norm_by_times = False
 
-#    def test_check_grad(self):
-#        self.outputs["WarpCTCGrad"] = None
-#        self.check_grad(["Logits"], "Loss", max_relative_error=0.01)
 
 if __name__ == "__main__":
     unittest.main()
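
Note: the sketch below is a rough, framework-free model of what the new ScaleLoDTensorFunctor does (plain std::vector in place of LoDTensor, absolute LoD offsets passed explicitly; ScaleSequences and the other names are made up for this illustration and are not part of the patch). Every value of the i-th sequence, which occupies rows lod[i] to lod[i + 1] of the flat [total_sequence_length, seq_width] buffer, is multiplied by scales[i]. In WarpCTCGradKernel the scales are the incoming Loss gradient, so each sequence's unpadded logits gradient is weighted by its own upstream loss gradient; the CUDA path follows the same indexing, with one block per sequence and PADDLE_CUDA_NUM_THREADS threads striding over that sequence's values.

#include <cstddef>
#include <iostream>
#include <vector>

// Simplified model of the CPU ScaleLoDTensorFunctor: `seq` stores all
// sequences back to back, `lod` holds absolute offsets (lod.size() - 1
// sequences), and every value of sequence i is multiplied by scales[i].
void ScaleSequences(std::vector<float>* seq, const std::vector<size_t>& lod,
                    const std::vector<float>& scales, size_t seq_width) {
  const size_t num_seq = lod.size() - 1;
  float* seq_data = seq->data();
  for (size_t i = 0; i < num_seq; ++i) {
    for (size_t j = lod[i] * seq_width; j < lod[i + 1] * seq_width; ++j) {
      seq_data[j] *= scales[i];
    }
  }
}

int main() {
  // Two sequences of lengths 2 and 1 (LoD offsets [0, 2, 3]), width 3.
  std::vector<float> seq(9, 1.0f);
  std::vector<size_t> lod = {0, 2, 3};
  std::vector<float> scales = {2.0f, 5.0f};
  ScaleSequences(&seq, lod, scales, /*seq_width=*/3);
  for (float v : seq) std::cout << v << " ";  // prints: 2 2 2 2 2 2 5 5 5
  std::cout << std::endl;
  return 0;
}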