diff --git a/lite/kernels/cuda/var_conv_2d_compute.cu b/lite/kernels/cuda/var_conv_2d_compute.cu
index 1e42635934b67b28fca29808f484be53292d74cf..1417282dcba9751c583d69912dddbcd82ca28fe9 100644
--- a/lite/kernels/cuda/var_conv_2d_compute.cu
+++ b/lite/kernels/cuda/var_conv_2d_compute.cu
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
+#include <numeric>
 #include <memory>
 #include <vector>
 #include "lite/backends/cuda/math/gemm.h"
@@ -38,6 +39,32 @@ inline int ConvOutputSize(int input_size,
   return output_size;
 }
 
+// Eliminate the effects of pad, support batch > 1.
+template <typename dtype>
+__global__ void eliminate_pad_effect(dtype* src,
+                                     const int64_t* offset,
+                                     const int num_batch,
+                                     const int batch_stride,
+                                     const int num_channel,
+                                     const int channel_stride,
+                                     const int num_height,
+                                     const int height_stride,
+                                     const int num_width,
+                                     const int width_stride,
+                                     const int count) {
+  int tid = threadIdx.x + blockIdx.x * blockDim.x;
+  int thread_num = blockDim.x * gridDim.x;
+  for (tid = threadIdx.x + blockIdx.x * blockDim.x; tid < count;
+       tid += thread_num) {
+    int batch_id = tid / batch_stride;
+    int width_id = tid % num_width;
+    int cur_len = offset[batch_id + 1] - offset[batch_id];
+    if (width_id >= cur_len) {
+      src[tid] = 0.;
+    }
+  }
+}
+
 void VarConv2DCompute::PrepareForRun() {
   auto& context = this->ctx_->template As<CUDAContext>();
   auto stream = context.exec_stream();
@@ -102,6 +129,46 @@ void VarConv2DCompute::Run() {
   conv_param_.output->Resize({output_shape});
   conv_impl_->create(conv_param_, &context);
   conv_impl_->run(conv_param_);
+
+  // Avoid situations where cascading conv does not support multiple batch
+  // calculations.
+  float* out_data = param.Out->mutable_data<float>();
+  const int batch_num = output_shape[1] * output_shape[2] * output_shape[3];
+  std::vector<int64_t> lod(param.X->lod()[0].size(), 0);
+  for (size_t i = 0; i < param.X->lod()[0].size(); ++i) {
+    lod[i] = param.X->lod()[0][i];
+  }
+  int count = std::accumulate(
+      output_shape.begin(), output_shape.end(), 1, std::multiplies<int>());
+  int width_stride = 1;
+  int height_stride = output_shape[3];
+  int channel_stride = output_shape[2] * output_shape[3];
+  int batch_stride = output_shape[1] * output_shape[2] * output_shape[3];
+  int threads = 512;
+  int blocks = (count + threads - 1) / threads;
+
+  offset_.Resize({static_cast<int64_t>(lod.size())});
+  int64_t* d_offset = offset_.mutable_data<int64_t>(TARGET(kCUDA));
+  TargetWrapperCuda::MemcpyAsync(d_offset,
+                                 lod.data(),
+                                 sizeof(int64_t) * lod.size(),
+                                 IoDirection::HtoD,
+                                 stream);
+
+  eliminate_pad_effect<float><<<blocks, threads, 0, stream>>>(out_data,
+                                                              d_offset,
+                                                              output_shape[0],
+                                                              batch_stride,
+                                                              output_shape[1],
+                                                              channel_stride,
+                                                              output_shape[2],
+                                                              height_stride,
+                                                              output_shape[3],
+                                                              width_stride,
+                                                              count);
+
+  cudaError_t error = cudaGetLastError();
+  if (error != cudaSuccess) LOG(ERROR) << cudaGetErrorString(error);
 }
 
 }  // namespace cuda
diff --git a/lite/kernels/cuda/var_conv_2d_compute.h b/lite/kernels/cuda/var_conv_2d_compute.h
index 4bb61132dbb49579875fa6d3b311a80a0a177394..6f6b74e2fe41eb60acb242caffb7312cdb66595d 100644
--- a/lite/kernels/cuda/var_conv_2d_compute.h
+++ b/lite/kernels/cuda/var_conv_2d_compute.h
@@ -33,6 +33,7 @@ class VarConv2DCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
  private:
   mutable operators::ConvParam conv_param_;
   std::unique_ptr<lite::cuda::math::CudnnConv2D<PRECISION(kFloat)>> conv_impl_;
+  lite::Tensor offset_;
 };
 
 }  // namespace cuda
diff --git a/lite/kernels/x86/var_conv_2d_compute.h b/lite/kernels/x86/var_conv_2d_compute.h
index c94cb2ca2d43a138b5769653d6cad2d52d420563..7a9ba16d2ea87adb40df23e1fbe149ab805afbc8 100644
--- a/lite/kernels/x86/var_conv_2d_compute.h
+++ b/lite/kernels/x86/var_conv_2d_compute.h
@@ -44,6 +44,7 @@ class VarConv2DCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
     // 2-D lod info.
     // const auto& offset_x = in_col->lod()[0];
     // const auto& offset_y = in_row->lod()[0];
+    CHECK_EQ(param.X->lod().size(), 3) << "input lod size should be 3!";
     const auto& offset_y = param.X->lod()[1];
     const auto& offset_x = param.X->lod()[2];
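Not part of the patch: below is a minimal standalone sketch of the indexing that the new `eliminate_pad_effect` kernel relies on, with the unused channel/height stride parameters dropped and with made-up shapes and lod offsets. It zeroes every output column whose index is at or beyond the valid sequence length of its batch item, which is what removes the padded columns after the cascaded conv runs on a batched, width-padded tensor.

```cuda
// Simplified sketch of the pad-elimination indexing (hypothetical shapes).
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

template <typename dtype>
__global__ void eliminate_pad_effect(dtype* src,
                                     const int64_t* offset,
                                     const int num_width,
                                     const int batch_stride,
                                     const int count) {
  // Grid-stride loop over every element of the NCHW output tensor.
  for (int tid = threadIdx.x + blockIdx.x * blockDim.x; tid < count;
       tid += blockDim.x * gridDim.x) {
    int batch_id = tid / batch_stride;  // which sequence in the batch
    int width_id = tid % num_width;     // column inside the padded width
    int cur_len = offset[batch_id + 1] - offset[batch_id];  // valid width
    if (width_id >= cur_len) {
      src[tid] = 0.;  // zero the padded tail of this row
    }
  }
}

int main() {
  // Two sequences padded to width 4, one channel, height 1: N=2, C=1, H=1, W=4.
  const int num_width = 4, batch_stride = 4, count = 8;
  std::vector<float> h_src(count, 1.f);
  std::vector<int64_t> h_offset = {0, 2, 5};  // sequence lengths 2 and 3

  float* d_src = nullptr;
  int64_t* d_offset = nullptr;
  cudaMalloc(&d_src, count * sizeof(float));
  cudaMalloc(&d_offset, h_offset.size() * sizeof(int64_t));
  cudaMemcpy(d_src, h_src.data(), count * sizeof(float),
             cudaMemcpyHostToDevice);
  cudaMemcpy(d_offset, h_offset.data(), h_offset.size() * sizeof(int64_t),
             cudaMemcpyHostToDevice);

  eliminate_pad_effect<float><<<1, 128>>>(d_src, d_offset, num_width,
                                          batch_stride, count);
  cudaMemcpy(h_src.data(), d_src, count * sizeof(float),
             cudaMemcpyDeviceToHost);

  // Expected: 1 1 0 0 1 1 1 0 (columns past each sequence length zeroed).
  for (float v : h_src) printf("%g ", v);
  printf("\n");
  cudaFree(d_src);
  cudaFree(d_offset);
  return 0;
}
```

The patch itself passes all four dimension/stride pairs to the kernel even though only `num_width` and `batch_stride` take part in the masking; the sketch omits them only to keep the example short.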