// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "lite/kernels/cuda/conv_compute.h" #include #include "lite/core/op_registry.h" namespace paddle { namespace lite { namespace kernels { namespace cuda { inline int ConvOutputSize(int input_size, int filter_size, int dilation, int pad_left, int pad_right, int stride) { const int dkernel = dilation * (filter_size - 1) + 1; int output_size = (input_size + pad_left + pad_right - dkernel) / stride + 1; CHECK_GT_OR_FALSE(output_size, 0); return output_size; } void ConvCompute::PrepareForRun() { auto& param = this->Param(); auto& ctx = this->ctx_->template As(); conv_impl_.reset(new lite::cuda::math::CudnnConv2D); conv_impl_->init(param, &ctx); } void ConvCompute::Run() { auto& param = this->Param(); conv_impl_->run(param); } template void ConvComputeInt8::PrepareForRun() { auto& param = this->Param(); const auto in_dims = param.x->dims(); const auto filter_dims = param.filter->dims(); std::vector output_shape({in_dims[0]}); auto paddings = *param.paddings; auto dilations = *param.dilations; for (size_t i = 0; i < param.strides.size(); ++i) { output_shape.push_back(ConvOutputSize(in_dims[i + 1], filter_dims[i + 1], dilations[i], paddings[2 * i], paddings[2 * i + 1], param.strides[i])); } output_shape.push_back(filter_dims[0]); param.output->Resize(lite::DDim(output_shape)); auto& ctx = this->ctx_->template As(); conv_impl_.reset(new lite::cuda::math::CudnnConv2DInt8); conv_impl_->init(param, &ctx); } template void ConvComputeInt8::Run() { auto& param = this->Param(); const auto in_dims = param.x->dims(); const auto filter_dims = param.filter->dims(); std::vector output_shape({in_dims[0]}); auto paddings = *param.paddings; auto dilations = *param.dilations; for (size_t i = 0; i < param.strides.size(); ++i) { output_shape.push_back(ConvOutputSize(in_dims[i + 1], filter_dims[i + 1], dilations[i], paddings[2 * i], paddings[2 * i + 1], param.strides[i])); } output_shape.push_back(filter_dims[0]); param.output->Resize(lite::DDim(output_shape)); conv_impl_->run(param); } template class ConvComputeInt8; template class ConvComputeInt8; } // namespace cuda } // namespace kernels } // namespace lite } // namespace paddle REGISTER_LITE_KERNEL( conv2d, kCUDA, kFloat, kNCHW, paddle::lite::kernels::cuda::ConvCompute, def) .BindInput("Input", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNCHW))}) .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFloat))}) .BindInput("Filter", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNCHW))}) .BindOutput("Output", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNCHW))}) .Finalize(); REGISTER_LITE_KERNEL(depthwise_conv2d, kCUDA, kFloat, kNCHW, paddle::lite::kernels::cuda::ConvCompute, def) .BindInput("Input", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNCHW))}) .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFloat))}) .BindInput("Filter", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNCHW))}) .BindOutput("Output", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNCHW))}) .Finalize(); REGISTER_LITE_KERNEL( conv2d, kCUDA, kInt8, kNHWC, paddle::lite::kernels::cuda::ConvComputeInt8, fp32_out) .BindInput("Input", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kInt8), DATALAYOUT(kNHWC))}) .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFloat))}) .BindInput("Filter", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kInt8), DATALAYOUT(kNHWC))}) .BindOutput("Output", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNHWC))}) .Finalize(); REGISTER_LITE_KERNEL( conv2d, kCUDA, kInt8, kNHWC, paddle::lite::kernels::cuda::ConvComputeInt8, int8_out) .BindInput("Input", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kInt8), DATALAYOUT(kNHWC))}) .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFloat))}) .BindInput("Filter", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kInt8), DATALAYOUT(kNHWC))}) .BindOutput("Output", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kInt8), DATALAYOUT(kNHWC))}) .Finalize();