Unverified · Commit 6757a315 · authored by chengduo, committed by GitHub

[Accelerate] Refine seq_softmax_op (#13421)

* refine seq_softmax_op

* fix seq_softmax

* use cub in seq_softmax
Parent c6865954
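For context, `sequence_softmax` applies a softmax independently over each sequence of a LoD batch, with the LoD offsets marking sequence boundaries. A minimal NumPy sketch of the computation (illustration only; `sequence_softmax_ref` is a hypothetical helper, not PaddlePaddle code):

```python
import numpy as np

def sequence_softmax_ref(x, lod):
    # x: 1-D array of scores; lod: offsets, e.g. [0, 2, 5] -> segments [0:2), [2:5)
    out = np.empty_like(x)
    for start, end in zip(lod[:-1], lod[1:]):
        seg = x[start:end]
        e = np.exp(seg - seg.max())  # subtract the max for numerical stability
        out[start:end] = e / e.sum()
    return out

# Two sequences: [0, 2) and [2, 5)
print(sequence_softmax_ref(np.array([1., 2., 3., 4., 5.]), [0, 2, 5]))
```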
......@@ -107,7 +107,7 @@ paddle.fluid.layers.sequence_conv ArgSpec(args=['input', 'num_filters', 'filter_
paddle.fluid.layers.conv2d ArgSpec(args=['input', 'num_filters', 'filter_size', 'stride', 'padding', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'use_mkldnn', 'act', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None, None, None, True, False, None, None))
paddle.fluid.layers.conv3d ArgSpec(args=['input', 'num_filters', 'filter_size', 'stride', 'padding', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'use_mkldnn', 'act', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None, None, None, True, False, None, None))
paddle.fluid.layers.sequence_pool ArgSpec(args=['input', 'pool_type'], varargs=None, keywords=None, defaults=None)
-paddle.fluid.layers.sequence_softmax ArgSpec(args=['input', 'param_attr', 'bias_attr', 'use_cudnn'], varargs=None, keywords=None, defaults=(None, None, True))
+paddle.fluid.layers.sequence_softmax ArgSpec(args=['input', 'param_attr', 'bias_attr', 'use_cudnn'], varargs=None, keywords=None, defaults=(None, None, False))
paddle.fluid.layers.softmax ArgSpec(args=['input', 'param_attr', 'bias_attr', 'use_cudnn', 'name'], varargs=None, keywords=None, defaults=(None, None, True, None))
paddle.fluid.layers.pool2d ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'use_mkldnn', 'name'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, False, None))
paddle.fluid.layers.pool3d ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'use_mkldnn', 'name'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, False, None))
......
......@@ -252,12 +252,12 @@ endif()
op_library(cross_entropy_op DEPS cross_entropy)
if(WITH_GPU)
op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax cub)
+op_library(sequence_softmax_op DEPS cub)
else()
op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax)
endif()
op_library(softmax_op DEPS softmax)
-op_library(sequence_softmax_op DEPS softmax)
if (WITH_GPU AND TENSORRT_FOUND)
op_library(tensorrt_engine_op DEPS tensorrt_engine tensorrt_converter)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(tensorrt_engine);\n")
......
......@@ -29,8 +29,8 @@ class SequenceSoftmaxCUDNNKernel : public framework::OpKernel<T> {
auto* x = ctx.Input<LoDTensor>("X");
auto* out = ctx.Output<LoDTensor>("Out");
-auto lod = x->lod();
-auto dims = x->dims();
+auto& lod = x->lod();
+auto& dims = x->dims();
const size_t level = lod.size() - 1;
PADDLE_ENFORCE_EQ(dims[0], static_cast<int64_t>(lod[level].back()),
......@@ -71,7 +71,7 @@ class SequenceSoftmaxGradCUDNNKernel : public framework::OpKernel<T> {
if (x_grad) {
x_grad->set_lod(x->lod());
}
-auto lod = x->lod();
+auto& lod = x->lod();
const size_t level = lod.size() - 1;
x_grad->mutable_data<T>(ctx.GetPlace());
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include <cfloat>  // for FLT_MAX, used as the reduction identity below
#include <cub/cub.cuh> // NOLINT
#include "paddle/fluid/operators/sequence_softmax_op.h"
namespace paddle {
namespace operators {
using LoDTensor = framework::LoDTensor;
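// Overloads so device code calls the right-precision exp intrinsic for T.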
__device__ __forceinline__ float real_exp(float x) { return expf(x); }
__device__ __forceinline__ double real_exp(double x) { return exp(x); }
template <typename T, int BlockDim>
using BlockReduce = cub::BlockReduce<T, BlockDim>;
template <typename T, int BlockDim>
using BlockReduceTempStorage = typename BlockReduce<T, BlockDim>::TempStorage;
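// Forward: each block handles one sequence at a time (grid-stride loop over
// sequences). Softmax is computed in three block-wide passes: max, sum of
// exp(x - max), then normalization; cub::BlockReduce does the reductions.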
template <typename T, int BlockDim>
__global__ void sequence_softmax_kernel(const T *in_data, const size_t *ref_lod,
const size_t src_height, T *out_data) {
__shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;
__shared__ T shared_max_data;
__shared__ T shared_sum_data;
for (int i = blockIdx.x; i < src_height; i += gridDim.x) {
size_t start = ref_lod[i];
size_t span = ref_lod[i + 1] - start;
// Pass 1: find the max element of the sequence
T max_ele = -FLT_MAX;
for (int tid = threadIdx.x; tid < span; tid += blockDim.x) {
T ele = in_data[start + tid];
max_ele = max_ele > ele ? max_ele : ele;
}
max_ele =
BlockReduce<T, BlockDim>(temp_storage).Reduce(max_ele, cub::Max());
if (threadIdx.x == 0) {
shared_max_data = max_ele;
}
__syncthreads();
// Pass 2: sum of exp(x - max)
T sum_data = 0;
for (int tid = threadIdx.x; tid < span; tid += blockDim.x) {
T ele = in_data[start + tid];
sum_data += real_exp(ele - shared_max_data);
}
sum_data =
BlockReduce<T, BlockDim>(temp_storage).Reduce(sum_data, cub::Sum());
if (threadIdx.x == 0) {
shared_sum_data = sum_data;
}
__syncthreads();
// Pass 3: compute the final softmax result
for (int tid = threadIdx.x; tid < span; tid += blockDim.x) {
T ele = in_data[start + tid];
ele = real_exp(ele - shared_max_data) / shared_sum_data;
out_data[start + tid] = ele;
}
}
}
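// Backward: per sequence, dx_i = (dout_i - dot(dout, out)) * out_i, where the
// dot product is a block-wide reduction shared by all threads of the block.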
template <typename T, int BlockDim>
__global__ void sequence_softmax_grad_kernel(const T *softmax_grad_data,
const T *softmax_data,
const size_t *ref_lod,
const size_t src_height,
T *dx_data) {
__shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;
__shared__ T shared_data;
for (int i = blockIdx.x; i < src_height; i += gridDim.x) {
size_t start = ref_lod[i];
size_t span = ref_lod[i + 1] - start;
T result = 0;
for (int tid = threadIdx.x; tid < span; tid += blockDim.x) {
size_t idx = start + tid;
T s_g_d = softmax_grad_data[idx];
T s_d = softmax_data[idx];
result += s_g_d * s_d;
}
result = BlockReduce<T, BlockDim>(temp_storage).Reduce(result, cub::Sum());
if (threadIdx.x == 0) {
shared_data = result;
}
__syncthreads();
for (int tid = threadIdx.x; tid < span; tid += blockDim.x) {
size_t idx = start + tid;
T s_g_d = softmax_grad_data[idx];
T s_d = softmax_data[idx];
dx_data[idx] = (s_g_d - shared_data) * s_d;
}
}
}
template <typename T>
struct SequenceSoftmaxFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext &context,
const LoDTensor &x,
const framework::Vector<size_t> &ref_lod, /*referenced lod*/
LoDTensor *out) {
int height = ref_lod.size() - 1;
const int kThreadsPerBlock = 32;
int thread_x = kThreadsPerBlock;
int max_threads = context.GetMaxPhysicalThreadCount();
int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
dim3 block_size(thread_x);
dim3 grid_size(max_blocks);
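// A fixed-size grid suffices: each block strides over sequences, so the
// launch is bounded by the device's physical thread count, not batch size.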
sequence_softmax_kernel<
T, kThreadsPerBlock><<<grid_size, block_size, 0, context.stream()>>>(
x.data<T>(), ref_lod.CUDAData(context.GetPlace()), height,
out->mutable_data<T>(context.GetPlace()));
}
};
template <typename T>
struct SequenceSoftmaxGradFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext &context,
const LoDTensor &dout, const LoDTensor &out,
const framework::Vector<size_t> &ref_lod, /*referenced lod*/
LoDTensor *dx) {
size_t height = ref_lod.size() - 1;
const int kThreadsPerBlock = 32;
int thread_x = kThreadsPerBlock;
int max_threads = context.GetMaxPhysicalThreadCount();
int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
dim3 block_size(thread_x);
dim3 grid_size(max_blocks);
sequence_softmax_grad_kernel<
T, kThreadsPerBlock><<<grid_size, block_size, 0, context.stream()>>>(
dout.data<T>(), out.data<T>(), ref_lod.CUDAData(context.GetPlace()),
height, dx->mutable_data<T>(context.GetPlace()));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
sequence_softmax,
ops::SequenceSoftmaxKernel<paddle::platform::CUDADeviceContext, float>,
ops::SequenceSoftmaxKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
sequence_softmax_grad,
ops::SequenceSoftmaxGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::SequenceSoftmaxGradKernel<paddle::platform::CUDADeviceContext,
double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/sequence_softmax_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
sequence_softmax,
ops::SequenceSoftmaxKernel<paddle::platform::CUDADeviceContext, float>,
ops::SequenceSoftmaxKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
sequence_softmax_grad,
ops::SequenceSoftmaxGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::SequenceSoftmaxGradKernel<paddle::platform::CUDADeviceContext,
double>);
......@@ -15,7 +15,6 @@ limitations under the License. */
#pragma once
#include <cmath>  // for std::exp in the CPU functor below
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/softmax.h"
namespace paddle {
namespace operators {
......@@ -23,12 +22,76 @@ namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
template <typename DeviceContext, typename T>
struct SequenceSoftmaxFunctor {
void operator()(
const DeviceContext &ctx, const LoDTensor &x,
const framework::Vector<size_t> &ref_lod, /*expand referenced lod*/
LoDTensor *out);
};
template <typename DeviceContext, typename T>
struct SequenceSoftmaxGradFunctor {
void operator()(const DeviceContext &ctx, const LoDTensor &dout,
const LoDTensor &out,
const framework::Vector<size_t> &ref_lod, /*referenced lod*/
LoDTensor *dx);
};
template <typename T>
struct SequenceSoftmaxFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext &ctx, const LoDTensor &x,
const framework::Vector<size_t> &ref_lod, /*referenced lod*/
LoDTensor *out) {
size_t height = ref_lod.size() - 1;
const T *in_data = x.data<T>();
T *out_data = out->mutable_data<T>(ctx.GetPlace());
for (size_t i = 0; i < height; ++i) {
size_t span = ref_lod[i + 1] - ref_lod[i];
// Subtract the per-sequence max before exponentiating so large inputs
// do not overflow, matching the CUDA kernel above.
T max_ele = in_data[ref_lod[i]];
for (size_t j = 1; j < span; ++j) {
T ele = in_data[ref_lod[i] + j];
max_ele = ele > max_ele ? ele : max_ele;
}
T result = 0;
for (size_t j = 0; j < span; ++j) {
result += std::exp(in_data[ref_lod[i] + j] - max_ele);
}
for (size_t j = 0; j < span; ++j) {
out_data[ref_lod[i] + j] =
std::exp(in_data[ref_lod[i] + j] - max_ele) / result;
}
}
}
};
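// CPU gradient: dx_i = (dout_i - sum_j dout_j * out_j) * out_i per sequence,
// the standard softmax Jacobian-vector product.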
template <typename T>
struct SequenceSoftmaxGradFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext &ctx, const LoDTensor &dout,
const LoDTensor &out,
const framework::Vector<size_t> &ref_lod, /*referenced lod*/
LoDTensor *dx) {
size_t height = ref_lod.size() - 1;
const T *softmax_grad_data = dout.data<T>();
const T *softmax = out.data<T>();
T *dx_data = dx->mutable_data<T>(ctx.GetPlace());
for (size_t i = 0; i < height; ++i) {
size_t span = ref_lod[i + 1] - ref_lod[i];
T result = 0;
for (size_t j = 0; j < span; ++j) {
result += softmax_grad_data[ref_lod[i] + j] * softmax[ref_lod[i] + j];
}
for (size_t j = 0; j < span; ++j) {
dx_data[ref_lod[i] + j] = (softmax_grad_data[ref_lod[i] + j] - result) *
softmax[ref_lod[i] + j];
}
}
}
};
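// The Compute kernels below only look up the last LoD level and dispatch to
// the device-specific functors above, so CPU and CUDA share one code path.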
template <typename DeviceContext, typename T>
class SequenceSoftmaxKernel : public framework::OpKernel<T> {
public:
-void Compute(const framework::ExecutionContext& ctx) const override {
-auto* x = ctx.Input<LoDTensor>("X");
-auto* out = ctx.Output<LoDTensor>("Out");
+void Compute(const framework::ExecutionContext &ctx) const override {
+auto *x = ctx.Input<LoDTensor>("X");
+auto *out = ctx.Output<LoDTensor>("Out");
auto lod = x->lod();
auto dims = x->dims();
......@@ -42,55 +105,33 @@ class SequenceSoftmaxKernel : public framework::OpKernel<T> {
"SequenceSoftmaxOp should be 1.");
out->mutable_data<T>(ctx.GetPlace());
-for (int i = 0; i < static_cast<int>(lod[level].size()) - 1; ++i) {
-int start_pos = static_cast<int>(lod[level][i]);
-int end_pos = static_cast<int>(lod[level][i + 1]);
-Tensor x_i = x->Slice(start_pos, end_pos);
-Tensor out_i = out->Slice(start_pos, end_pos);
-// Reshape from (end_pos - start_pos) x 1UL to 1UL x (end_pos - start_pos)
-framework::DDim dims_i = framework::make_ddim({1UL, end_pos - start_pos});
-x_i.Resize(dims_i);
-out_i.Resize(dims_i);
-math::SoftmaxFunctor<DeviceContext, T>()(
-ctx.template device_context<DeviceContext>(), &x_i, &out_i);
-}
+SequenceSoftmaxFunctor<DeviceContext, T> seq_softmax_functor;
+seq_softmax_functor(ctx.template device_context<DeviceContext>(), *x,
+lod[level], out);
}
};
template <typename DeviceContext, typename T>
class SequenceSoftmaxGradKernel : public framework::OpKernel<T> {
public:
-void Compute(const framework::ExecutionContext& ctx) const override {
-auto* out = ctx.Input<LoDTensor>("Out");
-auto* out_grad = ctx.Input<LoDTensor>(framework::GradVarName("Out"));
-auto* x = ctx.Input<LoDTensor>("X");
-auto* x_grad = ctx.Output<LoDTensor>(framework::GradVarName("X"));
-if (x_grad) {
-x_grad->set_lod(x->lod());
+void Compute(const framework::ExecutionContext &ctx) const override {
+auto *out = ctx.Input<LoDTensor>("Out");
+auto *out_grad = ctx.Input<LoDTensor>(framework::GradVarName("Out"));
+auto *x = ctx.Input<LoDTensor>("X");
+auto *x_grad = ctx.Output<LoDTensor>(framework::GradVarName("X"));
+if (!x_grad) {
+return;
}
+x_grad->set_lod(x->lod());
auto lod = x->lod();
const size_t level = lod.size() - 1;
x_grad->mutable_data<T>(ctx.GetPlace());
-for (int i = 0; i < static_cast<int>(lod[level].size()) - 1; ++i) {
-int start_pos = static_cast<int>(lod[level][i]);
-int end_pos = static_cast<int>(lod[level][i + 1]);
-Tensor out_i = out->Slice(start_pos, end_pos);
-Tensor out_grad_i = out_grad->Slice(start_pos, end_pos);
-Tensor x_grad_i = x_grad->Slice(start_pos, end_pos);
-// Reshape from (end_pos - start_pos) x 1UL to 1UL x (end_pos - start_pos)
-framework::DDim dims_i = framework::make_ddim({1UL, end_pos - start_pos});
-out_i.Resize(dims_i);
-out_grad_i.Resize(dims_i);
-x_grad_i.Resize(dims_i);
-math::SoftmaxGradFunctor<DeviceContext, T>()(
-ctx.template device_context<DeviceContext>(), &out_i, &out_grad_i,
-&x_grad_i);
-}
+SequenceSoftmaxGradFunctor<DeviceContext, T> seq_softmax_grad_functor;
+seq_softmax_grad_functor(ctx.template device_context<DeviceContext>(),
+*out_grad, *out, lod[level], x_grad);
}
};
......
......@@ -1275,7 +1275,7 @@ def sequence_conv(input,
return helper.append_activation(pre_act)
-def sequence_softmax(input, param_attr=None, bias_attr=None, use_cudnn=True):
+def sequence_softmax(input, param_attr=None, bias_attr=None, use_cudnn=False):
"""
This function computes the softmax activation among all time-steps for each
sequence. The dimension of each time-step should be 1. Thus, the shape of
......@@ -1298,7 +1298,7 @@ def sequence_softmax(input, param_attr=None, bias_attr=None, use_cudnn=True):
bias_attr (ParamAttr|None): attributes for bias
param_attr (ParamAttr|None): attributes for parameter
use_cudnn (bool): Use cudnn kernel or not, it is valid only when the cudnn \
-library is installed. Default: True
+library is installed. Default: False
Returns:
Variable: output of sequence_softmax
......
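For reference, a hedged usage sketch of the updated layer (variable names are illustrative; assumes the contemporary `fluid.layers.data` API):

```python
import paddle.fluid as fluid

# Variable-length sequences of scalar scores (lod_level=1).
scores = fluid.layers.data(
    name='scores', shape=[1], dtype='float32', lod_level=1)
# use_cudnn now defaults to False; pass use_cudnn=True to opt back in
# when the cuDNN kernel is available.
probs = fluid.layers.sequence_softmax(input=scores)
```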