diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec
index 85e92630557dd0c93b02e186b271db6d9a2767d8..41ed774bf8431fbd0d58da02c32b7fc64db0c942 100644
--- a/paddle/fluid/API.spec
+++ b/paddle/fluid/API.spec
@@ -107,7 +107,7 @@ paddle.fluid.layers.sequence_conv ArgSpec(args=['input', 'num_filters', 'filter_
 paddle.fluid.layers.conv2d ArgSpec(args=['input', 'num_filters', 'filter_size', 'stride', 'padding', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'use_mkldnn', 'act', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None, None, None, True, False, None, None))
 paddle.fluid.layers.conv3d ArgSpec(args=['input', 'num_filters', 'filter_size', 'stride', 'padding', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'use_mkldnn', 'act', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None, None, None, True, False, None, None))
 paddle.fluid.layers.sequence_pool ArgSpec(args=['input', 'pool_type'], varargs=None, keywords=None, defaults=None)
-paddle.fluid.layers.sequence_softmax ArgSpec(args=['input', 'param_attr', 'bias_attr', 'use_cudnn'], varargs=None, keywords=None, defaults=(None, None, True))
+paddle.fluid.layers.sequence_softmax ArgSpec(args=['input', 'param_attr', 'bias_attr', 'use_cudnn'], varargs=None, keywords=None, defaults=(None, None, False))
 paddle.fluid.layers.softmax ArgSpec(args=['input', 'param_attr', 'bias_attr', 'use_cudnn', 'name'], varargs=None, keywords=None, defaults=(None, None, True, None))
 paddle.fluid.layers.pool2d ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'use_mkldnn', 'name'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, False, None))
 paddle.fluid.layers.pool3d ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'use_mkldnn', 'name'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, False, None))
diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt
index ccb7fa1f8cce8cc757038904bce762af3b5ff30b..9c67df7bdfb2c4e5d1c9fe60676c412ab11b4fa5 100644
--- a/paddle/fluid/operators/CMakeLists.txt
+++ b/paddle/fluid/operators/CMakeLists.txt
@@ -252,12 +252,12 @@ endif()
 op_library(cross_entropy_op DEPS cross_entropy)
 if(WITH_GPU)
     op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax cub)
+    op_library(sequence_softmax_op DEPS cub)
 else()
     op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax)
 endif()
 op_library(softmax_op DEPS softmax)
-op_library(sequence_softmax_op DEPS softmax)
 if (WITH_GPU AND TENSORRT_FOUND)
   op_library(tensorrt_engine_op DEPS tensorrt_engine tensorrt_converter)
   file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(tensorrt_engine);\n")
diff --git a/paddle/fluid/operators/sequence_softmax_cudnn_op.cu.cc b/paddle/fluid/operators/sequence_softmax_cudnn_op.cu.cc
index 7aca9f7111956dba63e2ceee10077879fe092bdf..585363958696fa0d8ed1ffdc7b6fdaab26349b08 100644
--- a/paddle/fluid/operators/sequence_softmax_cudnn_op.cu.cc
+++ b/paddle/fluid/operators/sequence_softmax_cudnn_op.cu.cc
@@ -29,8 +29,8 @@ class SequenceSoftmaxCUDNNKernel : public framework::OpKernel<T> {
     auto* x = ctx.Input<LoDTensor>("X");
     auto* out = ctx.Output<LoDTensor>("Out");
 
-    auto lod = x->lod();
-    auto dims = x->dims();
+    auto& lod = x->lod();
+    auto& dims = x->dims();
 
     const size_t level = lod.size() - 1;
     PADDLE_ENFORCE_EQ(dims[0], static_cast<int64_t>(lod[level].back()),
@@ -71,7 +71,7 @@ class SequenceSoftmaxGradCUDNNKernel : public framework::OpKernel<T> {
     if (x_grad) {
       x_grad->set_lod(x->lod());
     }
-    auto lod = x->lod();
+    auto& lod = x->lod();
     const size_t level = lod.size() - 1;
 
     x_grad->mutable_data<T>(ctx.GetPlace());
diff --git a/paddle/fluid/operators/sequence_softmax_op.cu b/paddle/fluid/operators/sequence_softmax_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..e94ceaa170131e8bce7d1574b27f0baeaa8d1ffc
--- /dev/null
+++ b/paddle/fluid/operators/sequence_softmax_op.cu
@@ -0,0 +1,171 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <algorithm>
+#include <cub/cub.cuh>  // NOLINT
+#include "paddle/fluid/operators/sequence_softmax_op.h"
+
+namespace paddle {
+namespace operators {
+
+using LoDTensor = framework::LoDTensor;
+
+__device__ __forceinline__ float real_exp(float x) { return expf(x); }
+__device__ __forceinline__ double real_exp(double x) { return exp(x); }
+
+template <typename T, int BlockDim>
+using BlockReduce = cub::BlockReduce<T, BlockDim>;
+
+template <typename T, int BlockDim>
+using BlockReduceTempStorage = typename BlockReduce<T, BlockDim>::TempStorage;
+
+template <typename T, int BlockDim>
+__global__ void sequence_softmax_kernel(const T *in_data, const size_t *ref_lod,
+                                        const size_t src_hight, T *out_data) {
+  __shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;
+  __shared__ T shared_max_data;
+  __shared__ T shared_sum_data;
+
+  for (int i = blockIdx.x; i < src_hight; i += gridDim.x) {
+    size_t start = ref_lod[i];
+    size_t span = ref_lod[i + 1] - start;
+
+    // Find the max ele
+    T max_ele = -FLT_MAX;
+    for (int tid = threadIdx.x; tid < span; tid += blockDim.x) {
+      T ele = in_data[start + tid];
+      max_ele = max_ele > ele ? max_ele : ele;
+    }
+    max_ele =
+        BlockReduce<T, BlockDim>(temp_storage).Reduce(max_ele, cub::Max());
+    if (threadIdx.x == 0) {
+      shared_max_data = max_ele;
+    }
+    __syncthreads();
+
+    // sum
+    T sum_data = 0;
+    for (int tid = threadIdx.x; tid < span; tid += blockDim.x) {
+      T ele = in_data[start + tid];
+      sum_data += real_exp(ele - shared_max_data);
+    }
+    sum_data =
+        BlockReduce<T, BlockDim>(temp_storage).Reduce(sum_data, cub::Sum());
+    if (threadIdx.x == 0) {
+      shared_sum_data = sum_data;
+    }
+    __syncthreads();
+
+    // get final resit
+    for (int tid = threadIdx.x; tid < span; tid += blockDim.x) {
+      T ele = in_data[start + tid];
+      ele = real_exp(ele - shared_max_data) / shared_sum_data;
+      out_data[start + tid] = ele;
+    }
+  }
+}
+
+template <typename T, int BlockDim>
+__global__ void sequence_softmax_grad_kernel(const T *softmax_grad_data,
+                                             const T *softmax_data,
+                                             const size_t *ref_lod,
+                                             const size_t src_hight,
+                                             T *dx_data) {
+  __shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;
+  __shared__ T shared_data;
+
+  for (int i = blockIdx.x; i < src_hight; i += gridDim.x) {
+    size_t start = ref_lod[i];
+    size_t span = ref_lod[i + 1] - start;
+
+    T result = 0;
+    for (int tid = threadIdx.x; tid < span; tid += blockDim.x) {
+      size_t idx = start + tid;
+      T s_g_d = softmax_grad_data[idx];
+      T s_d = softmax_data[idx];
+      result += s_g_d * s_d;
+    }
+    result = BlockReduce<T, BlockDim>(temp_storage).Reduce(result, cub::Sum());
+    if (threadIdx.x == 0) {
+      shared_data = result;
+    }
+    __syncthreads();
+
+    for (int tid = threadIdx.x; tid < span; tid += blockDim.x) {
+      size_t idx = start + tid;
+      T s_g_d = softmax_grad_data[idx];
+      T s_d = softmax_data[idx];
+      dx_data[idx] = (s_g_d - shared_data) * s_d;
+    }
+  }
+}
+
+template <typename T>
+struct SequenceSoftmaxFunctor<platform::CUDADeviceContext, T> {
+  void operator()(const platform::CUDADeviceContext &context,
+                  const LoDTensor &x,
+                  const framework::Vector<size_t> &ref_lod, /*referenced lod*/
+                  LoDTensor *out) {
+    int hight = ref_lod.size() - 1;
+
+    const int kThreadsPerBlock = 32;
+    int thread_x = kThreadsPerBlock;
+    int max_threads = context.GetMaxPhysicalThreadCount();
+    int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
+
+    dim3 block_size(thread_x);
+    dim3 grid_size(max_blocks);
+    sequence_softmax_kernel<
+        T, kThreadsPerBlock><<<grid_size, block_size, 0, context.stream()>>>(
+        x.data<T>(), ref_lod.CUDAData(context.GetPlace()), hight,
+        out->mutable_data<T>(context.GetPlace()));
+  }
+};
+
+template <typename T>
+struct SequenceSoftmaxGradFunctor<platform::CUDADeviceContext, T> {
+  void operator()(const platform::CUDADeviceContext &context,
+                  const LoDTensor &dout, const LoDTensor &out,
+                  const framework::Vector<size_t> &ref_lod, /*referenced lod*/
+                  LoDTensor *dx) {
+    size_t hight = ref_lod.size() - 1;
+
+    const int kThreadsPerBlock = 32;
+    int thread_x = kThreadsPerBlock;
+    int max_threads = context.GetMaxPhysicalThreadCount();
+    int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
+
+    dim3 block_size(thread_x);
+    dim3 grid_size(max_blocks);
+
+    sequence_softmax_grad_kernel<
+        T, kThreadsPerBlock><<<grid_size, block_size, 0, context.stream()>>>(
+        dout.data<T>(), out.data<T>(), ref_lod.CUDAData(context.GetPlace()),
+        hight, dx->mutable_data<T>(context.GetPlace()));
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_CUDA_KERNEL(
+    sequence_softmax,
+    ops::SequenceSoftmaxKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::SequenceSoftmaxKernel<paddle::platform::CUDADeviceContext, double>);
+REGISTER_OP_CUDA_KERNEL(
+    sequence_softmax_grad,
+    ops::SequenceSoftmaxGradKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::SequenceSoftmaxGradKernel<paddle::platform::CUDADeviceContext, double>);
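Note on the new CUDA kernel above: each thread block walks the sequences in a grid-stride loop, block-reduces the maximum of one sequence, then the sum of `exp(x - max)`, and finally normalizes, so the per-sequence softmax is computed in a numerically stable way without cuDNN. A minimal NumPy sketch of the per-sequence math follows; the helper name and test data are illustrative, not part of the patch:

```python
import numpy as np

def sequence_softmax_ref(x, lod):
    """Softmax within each sequence of a concatenated 1-D input.

    x   -- shape [N]; all sequences concatenated along dim 0.
    lod -- offsets such as [0, 2, 5, 7]; sequence i is x[lod[i]:lod[i+1]].
    """
    out = np.empty_like(x)
    for i in range(len(lod) - 1):
        seq = x[lod[i]:lod[i + 1]]
        # Subtract the per-sequence max first, as sequence_softmax_kernel does,
        # so exp() cannot overflow for large inputs.
        e = np.exp(seq - seq.max())
        out[lod[i]:lod[i + 1]] = e / e.sum()
    return out

x = np.array([1.0, 2.0, 0.5, 0.5, 3.0, -1.0, 0.0], dtype=np.float32)
print(sequence_softmax_ref(x, lod=[0, 2, 5, 7]))
```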
diff --git a/paddle/fluid/operators/sequence_softmax_op.cu.cc b/paddle/fluid/operators/sequence_softmax_op.cu.cc
deleted file mode 100644
index 397df75415691e4f53bc399cd1868c3e37bc9110..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/sequence_softmax_op.cu.cc
+++ /dev/null
@@ -1,26 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/fluid/operators/sequence_softmax_op.h"
-
-namespace ops = paddle::operators;
-REGISTER_OP_CUDA_KERNEL(
-    sequence_softmax,
-    ops::SequenceSoftmaxKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::SequenceSoftmaxKernel<paddle::platform::CUDADeviceContext, double>);
-REGISTER_OP_CUDA_KERNEL(
-    sequence_softmax_grad,
-    ops::SequenceSoftmaxGradKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::SequenceSoftmaxGradKernel<paddle::platform::CUDADeviceContext, double>);
diff --git a/paddle/fluid/operators/sequence_softmax_op.h b/paddle/fluid/operators/sequence_softmax_op.h
index bca564e16f9951519eefe25126aadebb4c1326b6..ed49e9471458cbca2d4760d966ef30033f292778 100644
--- a/paddle/fluid/operators/sequence_softmax_op.h
+++ b/paddle/fluid/operators/sequence_softmax_op.h
@@ -15,7 +15,6 @@ limitations under the License. */
 #pragma once
 
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/math/softmax.h"
 
 namespace paddle {
 namespace operators {
@@ -23,12 +22,76 @@ namespace operators {
 using Tensor = framework::Tensor;
 using LoDTensor = framework::LoDTensor;
 
+template <typename DeviceContext, typename T>
+struct SequenceSoftmaxFunctor {
+  void operator()(
+      const DeviceContext &ctx, const LoDTensor &x,
+      const framework::Vector<size_t> &ref_lod, /*expand referenced lod*/
+      LoDTensor *out);
+};
+
+template <typename DeviceContext, typename T>
+struct SequenceSoftmaxGradFunctor {
+  void operator()(const DeviceContext &ctx, const LoDTensor &dout,
+                  const LoDTensor &out,
+                  const framework::Vector<size_t> &ref_lod, /*referenced lod*/
+                  LoDTensor *dx);
+};
+
+template <typename T>
+struct SequenceSoftmaxFunctor<platform::CPUDeviceContext, T> {
+  void operator()(const platform::CPUDeviceContext &ctx, const LoDTensor &x,
+                  const framework::Vector<size_t> &ref_lod, /*referenced lod*/
+                  LoDTensor *out) {
+    size_t hight = ref_lod.size() - 1;
+    const T *in_data = x.data<T>();
+    T *out_data = out->mutable_data<T>(ctx.GetPlace());
+    for (size_t i = 0; i < hight; ++i) {
+      size_t span = ref_lod[i + 1] - ref_lod[i];
+      T result = 0;
+      for (size_t j = 0; j < span; ++j) {
+        result += exp(in_data[ref_lod[i] + j]);
+      }
+      for (size_t j = 0; j < span; ++j) {
+        out_data[ref_lod[i] + j] = exp(in_data[ref_lod[i] + j]) / result;
+      }
+    }
+  }
+};
+
+template <typename T>
+struct SequenceSoftmaxGradFunctor<platform::CPUDeviceContext, T> {
+  void operator()(const platform::CPUDeviceContext &ctx, const LoDTensor &dout,
+                  const LoDTensor &out,
+                  const framework::Vector<size_t> &ref_lod, /*referenced lod*/
+                  LoDTensor *dx) {
+    size_t hight = ref_lod.size() - 1;
+
+    const T *softmax_grad_data = dout.data<T>();
+    const T *softmax = out.data<T>();
+    T *dx_data = dx->mutable_data<T>(ctx.GetPlace());
+
+    for (size_t i = 0; i < hight; ++i) {
+      size_t span = ref_lod[i + 1] - ref_lod[i];
+      T result = 0;
+      for (size_t j = 0; j < span; ++j) {
+        result += softmax_grad_data[ref_lod[i] + j] * softmax[ref_lod[i] + j];
+      }
+
+      for (size_t j = 0; j < span; ++j) {
+        dx_data[ref_lod[i] + j] = (softmax_grad_data[ref_lod[i] + j] - result) *
+                                  softmax[ref_lod[i] + j];
+      }
+    }
+  }
+};
+
 template <typename DeviceContext, typename T>
 class SequenceSoftmaxKernel : public framework::OpKernel<T> {
  public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    auto* x = ctx.Input<LoDTensor>("X");
-    auto* out = ctx.Output<LoDTensor>("Out");
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    auto *x = ctx.Input<LoDTensor>("X");
+    auto *out = ctx.Output<LoDTensor>("Out");
 
     auto lod = x->lod();
     auto dims = x->dims();
@@ -42,55 +105,33 @@ class SequenceSoftmaxKernel : public framework::OpKernel<T> {
                       "SequenceSoftmaxOp should be 1.");
 
     out->mutable_data<T>(ctx.GetPlace());
-    for (int i = 0; i < static_cast<int>(lod[level].size()) - 1; ++i) {
-      int start_pos = static_cast<int>(lod[level][i]);
-      int end_pos = static_cast<int>(lod[level][i + 1]);
-      Tensor x_i = x->Slice(start_pos, end_pos);
-      Tensor out_i = out->Slice(start_pos, end_pos);
-
-      // Reshape from (end_pos - start_pos) x 1UL to 1UL x (end_pos - start_pos)
-      framework::DDim dims_i = framework::make_ddim({1UL, end_pos - start_pos});
-      x_i.Resize(dims_i);
-      out_i.Resize(dims_i);
-      math::SoftmaxFunctor<DeviceContext, T>()(
-          ctx.template device_context<DeviceContext>(), &x_i, &out_i);
-    }
+
+    SequenceSoftmaxFunctor<DeviceContext, T> seq_softmax_functor;
+    seq_softmax_functor(ctx.template device_context<DeviceContext>(), *x,
+                        lod[level], out);
   }
 };
 
 template <typename DeviceContext, typename T>
 class SequenceSoftmaxGradKernel : public framework::OpKernel<T> {
  public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    auto* out = ctx.Input<LoDTensor>("Out");
-    auto* out_grad = ctx.Input<LoDTensor>(framework::GradVarName("Out"));
-    auto* x = ctx.Input<LoDTensor>("X");
-    auto* x_grad = ctx.Output<LoDTensor>(framework::GradVarName("X"));
-    if (x_grad) {
-      x_grad->set_lod(x->lod());
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    auto *out = ctx.Input<LoDTensor>("Out");
+    auto *out_grad = ctx.Input<LoDTensor>(framework::GradVarName("Out"));
+    auto *x = ctx.Input<LoDTensor>("X");
+    auto *x_grad = ctx.Output<LoDTensor>(framework::GradVarName("X"));
+    if (!x_grad) {
+      return;
     }
+    x_grad->set_lod(x->lod());
 
     auto lod = x->lod();
     const size_t level = lod.size() - 1;
-
     x_grad->mutable_data<T>(ctx.GetPlace());
-    for (int i = 0; i < static_cast<int>(lod[level].size()) - 1; ++i) {
-      int start_pos = static_cast<int>(lod[level][i]);
-      int end_pos = static_cast<int>(lod[level][i + 1]);
-
-      Tensor out_i = out->Slice(start_pos, end_pos);
-      Tensor out_grad_i = out_grad->Slice(start_pos, end_pos);
-      Tensor x_grad_i = x_grad->Slice(start_pos, end_pos);
-
-      // Reshape from (end_pos - start_pos) x 1UL to 1UL x (end_pos - start_pos)
-      framework::DDim dims_i = framework::make_ddim({1UL, end_pos - start_pos});
-      out_i.Resize(dims_i);
-      out_grad_i.Resize(dims_i);
-      x_grad_i.Resize(dims_i);
-      math::SoftmaxGradFunctor<DeviceContext, T>()(
-          ctx.template device_context<DeviceContext>(), &out_i, &out_grad_i,
-          &x_grad_i);
-    }
+
+    SequenceSoftmaxGradFunctor<DeviceContext, T> seq_softmax_grad_functor;
+    seq_softmax_grad_functor(ctx.template device_context<DeviceContext>(),
+                             *out_grad, *out, lod[level], x_grad);
   }
 };
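The CPU gradient functor added above implements the usual softmax backward restricted to each sequence: with `out = softmax(x)` and upstream gradient `dout`, `dx_j = (dout_j - sum_k dout_k * out_k) * out_j`, where the sum runs over the same sequence. A small NumPy check of that formula (names and data here are illustrative only, not from the patch):

```python
import numpy as np

def sequence_softmax_grad_ref(dout, out, lod):
    """dx_j = (dout_j - sum_k dout_k * out_k) * out_j, computed per sequence."""
    dx = np.empty_like(out)
    for i in range(len(lod) - 1):
        s, e = lod[i], lod[i + 1]
        inner = np.dot(dout[s:e], out[s:e])  # sum_k dout_k * out_k for sequence i
        dx[s:e] = (dout[s:e] - inner) * out[s:e]
    return dx

# Sanity check: if dout is all ones, dx must be ~0, because each sequence's
# softmax outputs sum to 1, so the inner product equals 1 for every sequence.
out = np.array([0.2, 0.3, 0.5, 0.6, 0.4])  # a valid per-sequence softmax output
print(sequence_softmax_grad_ref(np.ones_like(out), out, lod=[0, 3, 5]))
```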
Thus, the shape of @@ -1298,7 +1298,7 @@ def sequence_softmax(input, param_attr=None, bias_attr=None, use_cudnn=True): bias_attr (ParamAttr|None): attributes for bias param_attr (ParamAttr|None): attributes for parameter use_cudnn (bool): Use cudnn kernel or not, it is valid only when the cudnn \ - library is installed. Default: True + library is installed. Default: False Returns: Variable: output of sequence_softmax