Unverified commit 192be07b authored by: W Wilber, committed by: GitHub

optimize sequence_mask. test=develop (#4120)

Parent 9e38adc8
@@ -37,6 +37,40 @@ __global__ void SequenceMaskKernel(T* dst,
  }
}
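// VecMaxKernel: grid-stride max reduction. Each thread scans its strided slice
// of in_data, the block tree-reduces the per-thread maxima in shared memory,
// and thread 0 writes the block maximum to out[blockIdx.x]. The -1 initial
// value assumes non-negative inputs (sequence lengths).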
template <typename T>
__global__ void VecMaxKernel(const T* in_data, T* out, const int count) {
  extern __shared__ T cache[];
  int i = blockDim.x * blockIdx.x + threadIdx.x;
  int cache_index = threadIdx.x;
  T tmp = -1;
  while (i < count) {
    if (in_data[i] > tmp) {
      tmp = in_data[i];
    }
    i += blockDim.x * gridDim.x;
  }
  cache[cache_index] = tmp;
  __syncthreads();
  // perform parallel reduction, blockDim.x must be 2^n
  int ib = blockDim.x / 2;
  while (ib != 0) {
    if (cache_index < ib && cache[cache_index + ib] > cache[cache_index]) {
      cache[cache_index] = cache[cache_index + ib];
    }
    __syncthreads();
    ib /= 2;
  }
  if (cache_index == 0) {
    out[blockIdx.x] = cache[0];
  }
}
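For reference, a minimal standalone sketch of how this two-stage reduction is completed: the kernel leaves one partial maximum per block, and the small array of per-block results is then reduced on the CPU. The helper name MaxOnDevice, the buffer names d_in/d_partial, and the plain cudaMalloc/cudaMemcpy calls on the default stream are illustrative only; the patch itself uses Lite's TargetWrapperCuda wrappers and the kernel's CUDA stream, as in Run() below.

```cpp
// Illustrative only: run VecMaxKernel and finish the reduction on the host.
#include <cuda_runtime.h>
#include <algorithm>
#include <cstdint>
#include <vector>

int64_t MaxOnDevice(const int64_t* d_in, int count) {
  const int threads = 256;  // power of two, as the tree reduction requires
  const int blocks = (count + threads - 1) / threads;
  int64_t* d_partial = nullptr;
  cudaMalloc(&d_partial, blocks * sizeof(int64_t));
  // Stage 1: one partial maximum per block, computed in shared memory.
  VecMaxKernel<int64_t>
      <<<blocks, threads, threads * sizeof(int64_t)>>>(d_in, d_partial, count);
  // Stage 2: copy the few per-block results back and reduce them on the CPU.
  std::vector<int64_t> h_partial(blocks);
  cudaMemcpy(h_partial.data(), d_partial, blocks * sizeof(int64_t),
             cudaMemcpyDeviceToHost);
  cudaFree(d_partial);
  return *std::max_element(h_partial.begin(), h_partial.end());
}
```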
template <typename T, PrecisionType Ptype>
void SequenceMaskCompute<T, Ptype>::Run() {
  auto& param = this->template Param<param_t>();
@@ -57,11 +91,34 @@ void SequenceMaskCompute<T, Ptype>::Run() {
  }
  if (maxlen < 0) {
-    maxlen = static_cast<int>(
-        thrust::reduce(thrust::device_pointer_cast(x_data),
-                       thrust::device_pointer_cast(x_data) + x->numel(),
-                       static_cast<int64_t>(0),
-                       thrust::maximum<int64_t>()));
+    // choose algorithm according to magic_num.
+    const int magic_num = 256;
+    std::vector<int64_t> h_max_data;
+    if (x->numel() < magic_num) {
+      h_max_data.resize(x->numel());
+      TargetWrapperCuda::MemcpySync(h_max_data.data(),
+                                    x_data,
+                                    x->numel() * sizeof(int64_t),
+                                    IoDirection::DtoH);
+    } else {
+      const int threads = 256;
+      const int blocks = (x->numel() + threads - 1) / threads;
+      max_tensor_.Resize({blocks});
+      auto* max_data = max_tensor_.mutable_data<int64_t>(TARGET(kCUDA));
+      VecMaxKernel<
+          int64_t><<<blocks, threads, threads * sizeof(int64_t), stream>>>(
+          x_data, max_data, x->numel());
+      h_max_data.resize(blocks);
+      TargetWrapperCuda::MemcpyAsync(h_max_data.data(),
+                                     max_data,
+                                     sizeof(int64_t) * blocks,
+                                     IoDirection::DtoH,
+                                     stream);
+      TargetWrapperCuda::StreamSync(stream);
+    }
+    auto maxlen_iterator =
+        std::max_element(h_max_data.begin(), h_max_data.end());
+    maxlen = h_max_data[std::distance(h_max_data.begin(), maxlen_iterator)];
  }
  auto y_dim = x->dims().Vectorize();
......
@@ -28,6 +28,9 @@ class SequenceMaskCompute : public KernelLite<TARGET(kCUDA), Ptype> {
  void Run() override;
  virtual ~SequenceMaskCompute() = default;

 private:
  lite::Tensor max_tensor_;
};
}  // namespace cuda
......
@@ -50,9 +50,15 @@ bool FcOpLite::CheckShape() const {
bool FcOpLite::InferShapeImpl() const {
  const auto& input_dims = param_.input->dims();
-  const auto& w_dims = param_.w_dims;
+  int64_t w_dims_1;
+  if (param_.w_dims.empty()) {
+    const auto& w_dims = param_.w->dims();
+    w_dims_1 = param_.padding_weights ? w_dims[1] - 4 : w_dims[1];
+  } else {
+    const auto& w_dims = param_.w_dims;
+    w_dims_1 = param_.padding_weights ? w_dims[1] - 4 : w_dims[1];
+  }
  int in_num_col_dims = param_.in_num_col_dims;
-  int64_t w_dims_1 = param_.padding_weights ? w_dims[1] - 4 : w_dims[1];
  // Set output dims
  std::vector<DDim::value_type> output_dims(in_num_col_dims + 1);
......
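For context on the fc_op change: when param_.w_dims is empty the weight shape is now read from the weight tensor itself, and in both branches the logical output width is w_dims[1] minus the 4 padded columns when padding_weights is set. Below is a minimal sketch of the output-dims computation this hunk feeds into; the function name and the example dimensions are illustrative, not from the patch.

```cpp
// Illustrative shape arithmetic only; dimension values are hypothetical.
#include <cstdint>
#include <vector>

std::vector<int64_t> InferFcOutputDims(const std::vector<int64_t>& input_dims,
                                       const std::vector<int64_t>& w_dims,
                                       int in_num_col_dims,
                                       bool padding_weights) {
  // Padded weights carry 4 extra columns, so strip them to get the real width.
  const int64_t w_dims_1 = padding_weights ? w_dims[1] - 4 : w_dims[1];
  std::vector<int64_t> output_dims(in_num_col_dims + 1);
  for (int i = 0; i < in_num_col_dims; ++i) {
    output_dims[i] = input_dims[i];  // leading dims are taken from the input
  }
  output_dims[in_num_col_dims] = w_dims_1;  // last dim is the FC output width
  return output_dims;
}

// Example: input dims {8, 128}, padded weight dims {132, 68},
// in_num_col_dims = 1, padding_weights = true  ->  output dims {8, 64}.
```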