diff --git a/lite/kernels/cuda/sequence_mask_compute.cu b/lite/kernels/cuda/sequence_mask_compute.cu
index 8e227a6a272127f500e10775f7ed4db53660e1f8..e15c43342585a654fb9fc15a776ad5dc6f270696 100644
--- a/lite/kernels/cuda/sequence_mask_compute.cu
+++ b/lite/kernels/cuda/sequence_mask_compute.cu
@@ -37,6 +37,40 @@ __global__ void SequenceMaskKernel(T* dst,
   }
 }
 
+template <typename T>
+__global__ void VecMaxKernel(const T* in_data, T* out, const int count) {
+  extern __shared__ T cache[];
+
+  int i = blockDim.x * blockIdx.x + threadIdx.x;
+  int cache_index = threadIdx.x;
+  T tmp = -1;
+
+  while (i < count) {
+    if (in_data[i] > tmp) {
+      tmp = in_data[i];
+    }
+    i += blockDim.x * gridDim.x;
+  }
+  cache[cache_index] = tmp;
+
+  __syncthreads();
+
+  // perform parallel reduction, blockDim.x must be 2^n
+  int ib = blockDim.x / 2;
+  while (ib != 0) {
+    if (cache_index < ib && cache[cache_index + ib] > cache[cache_index]) {
+      cache[cache_index] = cache[cache_index + ib];
+    }
+
+    __syncthreads();
+
+    ib /= 2;
+  }
+  if (cache_index == 0) {
+    out[blockIdx.x] = cache[0];
+  }
+}
+
 template <typename T, PrecisionType Ptype>
 void SequenceMaskCompute<T, Ptype>::Run() {
   auto& param = this->template Param<param_t>();
@@ -57,11 +91,34 @@ void SequenceMaskCompute<T, Ptype>::Run() {
   }
 
   if (maxlen < 0) {
-    maxlen = static_cast<int>(
-        thrust::reduce(thrust::device_pointer_cast(x_data),
-                       thrust::device_pointer_cast(x_data) + x->numel(),
-                       static_cast<int64_t>(0),
-                       thrust::maximum<int64_t>()));
+    // choose algorithm according to magic_num.
+    const int magic_num = 256;
+    std::vector<int64_t> h_max_data;
+    if (x->numel() < magic_num) {
+      h_max_data.resize(x->numel());
+      TargetWrapperCuda::MemcpySync(h_max_data.data(),
+                                    x_data,
+                                    x->numel() * sizeof(int64_t),
+                                    IoDirection::DtoH);
+    } else {
+      const int threads = 256;
+      const int blocks = (x->numel() + threads - 1) / threads;
+      max_tensor_.Resize({blocks});
+      auto* max_data = max_tensor_.mutable_data<int64_t>(TARGET(kCUDA));
+      VecMaxKernel<
+          int64_t><<<blocks, threads, threads * sizeof(int64_t), stream>>>(
+          x_data, max_data, x->numel());
+      h_max_data.resize(blocks);
+      TargetWrapperCuda::MemcpyAsync(h_max_data.data(),
+                                     max_data,
+                                     sizeof(int64_t) * blocks,
+                                     IoDirection::DtoH,
+                                     stream);
+      TargetWrapperCuda::StreamSync(stream);
+    }
+    auto maxlen_iterator =
+        std::max_element(h_max_data.begin(), h_max_data.end());
+    maxlen = h_max_data[std::distance(h_max_data.begin(), maxlen_iterator)];
   }
 
   auto y_dim = x->dims().Vectorize();
diff --git a/lite/kernels/cuda/sequence_mask_compute.h b/lite/kernels/cuda/sequence_mask_compute.h
index 3611587f0ce7daef1a88f5b6a916e2d30d33bcc1..965be02e0aab4f37f92848f96f1d7b36bf6e493b 100644
--- a/lite/kernels/cuda/sequence_mask_compute.h
+++ b/lite/kernels/cuda/sequence_mask_compute.h
@@ -28,6 +28,9 @@ class SequenceMaskCompute : public KernelLite<TARGET(kCUDA), Ptype> {
   void Run() override;
 
   virtual ~SequenceMaskCompute() = default;
+
+ private:
+  lite::Tensor max_tensor_;
 };
 
 }  // namespace cuda
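For reference: the maxlen change above replaces a single thrust::reduce call with a hand-rolled two-stage max. For small inputs (fewer than magic_num elements) the data is simply copied to the host; otherwise VecMaxKernel folds each block's grid-stride slice into one partial max in shared memory, and the host finishes with std::max_element over the per-block results. Below is a minimal standalone sketch of the same pattern, compilable with nvcc outside the Paddle-Lite tree. The file name, the kernel name BlockMaxKernel, and the reinterpret_cast shared-memory idiom are illustrative, not part of the patch.

// block_max.cu: standalone sketch of the two-stage max reduction above.
#include <cuda_runtime.h>
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

template <typename T>
__global__ void BlockMaxKernel(const T* in, T* out, int count) {
  // Dynamic shared memory, sized at launch as blockDim.x * sizeof(T).
  extern __shared__ unsigned char smem[];
  T* cache = reinterpret_cast<T*>(smem);

  // Grid-stride loop: each thread folds its strided slice into a local max.
  T tmp = -1;  // valid lower bound: sequence lengths are non-negative
  for (int i = blockDim.x * blockIdx.x + threadIdx.x; i < count;
       i += blockDim.x * gridDim.x) {
    if (in[i] > tmp) tmp = in[i];
  }
  cache[threadIdx.x] = tmp;
  __syncthreads();

  // Shared-memory tree reduction; requires blockDim.x to be a power of two.
  for (int ib = blockDim.x / 2; ib > 0; ib /= 2) {
    if (threadIdx.x < ib && cache[threadIdx.x + ib] > cache[threadIdx.x]) {
      cache[threadIdx.x] = cache[threadIdx.x + ib];
    }
    __syncthreads();
  }
  if (threadIdx.x == 0) out[blockIdx.x] = cache[0];
}

int main() {
  const int n = 10000;
  std::vector<int64_t> h_in(n);
  for (int i = 0; i < n; ++i) h_in[i] = i % 777;  // true max is 776

  const int threads = 256;  // power of two, as the reduction requires
  const int blocks = (n + threads - 1) / threads;

  int64_t* d_in = nullptr;
  int64_t* d_out = nullptr;
  cudaMalloc(&d_in, n * sizeof(int64_t));
  cudaMalloc(&d_out, blocks * sizeof(int64_t));
  cudaMemcpy(d_in, h_in.data(), n * sizeof(int64_t), cudaMemcpyHostToDevice);

  // Stage 1: one partial max per block.
  BlockMaxKernel<int64_t>
      <<<blocks, threads, threads * sizeof(int64_t)>>>(d_in, d_out, n);

  // Stage 2: finish the reduction on the host, as Run() does above.
  std::vector<int64_t> h_out(blocks);
  cudaMemcpy(h_out.data(), d_out, blocks * sizeof(int64_t),
             cudaMemcpyDeviceToHost);
  const int64_t maxlen = *std::max_element(h_out.begin(), h_out.end());
  std::printf("max = %lld\n", static_cast<long long>(maxlen));

  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}

Finishing on the host is a reasonable trade-off here: only one partial per block survives stage 1 (40 values for 10000 inputs at 256 threads), so a second kernel launch would buy little.
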
diff --git a/lite/operators/fc_op.cc b/lite/operators/fc_op.cc
index e0591c5eae3bf351aa6dae5ff981e3b9c81249e0..5d60af4af075ac11b936868ed822a28e55baef6b 100644
--- a/lite/operators/fc_op.cc
+++ b/lite/operators/fc_op.cc
@@ -50,9 +50,15 @@ bool FcOpLite::CheckShape() const {
 
 bool FcOpLite::InferShapeImpl() const {
   const auto& input_dims = param_.input->dims();
-  const auto& w_dims = param_.w_dims;
+  int64_t w_dims_1;
+  if (param_.w_dims.empty()) {
+    const auto& w_dims = param_.w->dims();
+    w_dims_1 = param_.padding_weights ? w_dims[1] - 4 : w_dims[1];
+  } else {
+    const auto& w_dims = param_.w_dims;
+    w_dims_1 = param_.padding_weights ? w_dims[1] - 4 : w_dims[1];
+  }
   int in_num_col_dims = param_.in_num_col_dims;
-  int64_t w_dims_1 = param_.padding_weights ? w_dims[1] - 4 : w_dims[1];
 
   // Set output dims
   std::vector<int64_t> output_dims(in_num_col_dims + 1);
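
The fc_op.cc hunk makes InferShapeImpl tolerate an empty cached param_.w_dims by falling back to the weight tensor's live dims, stripping the 4 padding columns in either path when padding_weights is set. A small host-only sketch of that selection logic follows; the helper name WDims1 and the std::vector stand-in for Paddle-Lite's dim type are hypothetical.

// fc_infer_shape_sketch.cc: illustrative, not Paddle-Lite API.
#include <cassert>
#include <cstdint>
#include <vector>

using Dims = std::vector<int64_t>;

int64_t WDims1(const Dims& cached_w_dims,
               const Dims& live_w_dims,
               bool padding_weights) {
  // Prefer the cached dims; fall back to the live tensor dims when empty.
  const Dims& w_dims = cached_w_dims.empty() ? live_w_dims : cached_w_dims;
  assert(w_dims.size() == 2);
  // Padded weights carry 4 extra columns that must be excluded.
  return padding_weights ? w_dims[1] - 4 : w_dims[1];
}

int main() {
  // Cached dims present: used directly.
  assert(WDims1({768, 3072}, {}, false) == 3072);
  // Cached dims empty (e.g. shape inference before the cache is filled):
  // fall back to the tensor's dims.
  assert(WDims1({}, {768, 3072}, false) == 3072);
  // Padding case: 3076 stored columns, 3072 logical ones.
  assert(WDims1({}, {768, 3076}, true) == 3072);
  return 0;
}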