// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/kernels/cuda/conv_compute.h"
#include <vector>
#include "lite/core/op_registry.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {

Z
Zhaolong Xing 已提交
24 25 26 27 28 29 30 31 32
inline int ConvOutputSize(
    int input_size, int filter_size, int dilation, int padding, int stride) {
  const int dkernel = dilation * (filter_size - 1) + 1;
  int output_size = (input_size + 2 * padding - dkernel) / stride + 1;
  CHECK_GT_OR_FALSE(output_size, 0);

  return output_size;
}

33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
void ConvCompute::PrepareForRun() {
  auto& param = this->Param<param_t>();
  auto& ctx = this->ctx_->template As<CUDAContext>();
  conv_impl_.reset(new lite::cuda::math::CudnnConv2D<PRECISION(kFloat)>);
  conv_impl_->init(param, &ctx);
}

void ConvCompute::Run() {
  auto& param = this->Param<param_t>();
  conv_impl_->run(param);
}

template <PrecisionType Ptype_out>
void ConvComputeInt8<Ptype_out>::PrepareForRun() {
  auto& param = this->Param<param_t>();
Z
Zhaolong Xing 已提交
48 49 50 51 52 53 54 55 56 57 58 59 60 61 62

  const auto in_dims = param.x->dims();
  const auto filter_dims = param.filter->dims();
  std::vector<int64_t> output_shape({in_dims[0]});

  for (size_t i = 0; i < param.strides.size(); ++i) {
    output_shape.push_back(ConvOutputSize(in_dims[i + 1],
                                          filter_dims[i + 1],
                                          param.dilations[i],
                                          param.paddings[i],
                                          param.strides[i]));
  }
  output_shape.push_back(filter_dims[0]);
  param.output->Resize(lite::DDim(output_shape));

63 64 65 66 67 68 69 70
  auto& ctx = this->ctx_->template As<CUDAContext>();
  conv_impl_.reset(new lite::cuda::math::CudnnConv2DInt8<Ptype_out>);
  conv_impl_->init(param, &ctx);
}

template <PrecisionType Ptype_out>
void ConvComputeInt8<Ptype_out>::Run() {
  auto& param = this->Param<param_t>();
Z
Zhaolong Xing 已提交
71 72 73 74 75 76 77 78 79 80 81 82 83 84
  const auto in_dims = param.x->dims();
  const auto filter_dims = param.filter->dims();
  std::vector<int64_t> output_shape({in_dims[0]});

  for (size_t i = 0; i < param.strides.size(); ++i) {
    output_shape.push_back(ConvOutputSize(in_dims[i + 1],
                                          filter_dims[i + 1],
                                          param.dilations[i],
                                          param.paddings[i],
                                          param.strides[i]));
  }
  output_shape.push_back(filter_dims[0]);
  param.output->Resize(lite::DDim(output_shape));

85 86 87 88 89 90 91 92 93 94 95 96 97
  conv_impl_->run(param);
}

template class ConvComputeInt8<PRECISION(kInt8)>;
template class ConvComputeInt8<PRECISION(kFloat)>;

}  // namespace cuda
}  // namespace kernels
}  // namespace lite
}  // namespace paddle

// conv2d on CUDA with float32 compute in NCHW layout, backed by ConvCompute.
// Input, Filter and Output are float32 NCHW tensors on CUDA. Bias is bound as
// a float32 CUDA tensor with no explicit layout.
REGISTER_LITE_KERNEL(
    conv2d, kCUDA, kFloat, kNCHW, paddle::lite::kernels::cuda::ConvCompute, def)
    .BindInput("Input",
               {LiteType::GetTensorTy(TARGET(kCUDA),
                                      PRECISION(kFloat),
                                      DATALAYOUT(kNCHW))})
    .BindInput("Bias",
               {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFloat))})
    .BindInput("Filter",
               {LiteType::GetTensorTy(TARGET(kCUDA),
                                      PRECISION(kFloat),
                                      DATALAYOUT(kNCHW))})
    .BindOutput("Output",
                {LiteType::GetTensorTy(TARGET(kCUDA),
                                       PRECISION(kFloat),
                                       DATALAYOUT(kNCHW))})
    .Finalize();

// depthwise_conv2d on CUDA with float32 compute in NCHW layout. It is served
// by the same ConvCompute kernel as the float conv2d registration, with the
// same tensor bindings.
REGISTER_LITE_KERNEL(depthwise_conv2d,
                     kCUDA,
                     kFloat,
                     kNCHW,
                     paddle::lite::kernels::cuda::ConvCompute,
                     def)
    .BindInput("Input",
               {LiteType::GetTensorTy(TARGET(kCUDA),
                                      PRECISION(kFloat),
                                      DATALAYOUT(kNCHW))})
    .BindInput("Bias",
               {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFloat))})
    .BindInput("Filter",
               {LiteType::GetTensorTy(TARGET(kCUDA),
                                      PRECISION(kFloat),
                                      DATALAYOUT(kNCHW))})
    .BindOutput("Output",
                {LiteType::GetTensorTy(TARGET(kCUDA),
                                       PRECISION(kFloat),
                                       DATALAYOUT(kNCHW))})
    .Finalize();

// conv2d on CUDA with int8 compute in NHWC layout, producing float32 output
// (alias "fp32_out"). Input and Filter are int8 NHWC tensors, Bias is float32
// with no explicit layout, and Output is float32 NHWC.
REGISTER_LITE_KERNEL(
    conv2d,
    kCUDA,
    kInt8,
    kNHWC,
    paddle::lite::kernels::cuda::ConvComputeInt8<PRECISION(kFloat)>,
    fp32_out)
    .BindInput("Input",
               {LiteType::GetTensorTy(TARGET(kCUDA),
                                      PRECISION(kInt8),
                                      DATALAYOUT(kNHWC))})
    .BindInput("Bias",
               {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFloat))})
    .BindInput("Filter",
               {LiteType::GetTensorTy(TARGET(kCUDA),
                                      PRECISION(kInt8),
                                      DATALAYOUT(kNHWC))})
    .BindOutput("Output",
                {LiteType::GetTensorTy(TARGET(kCUDA),
                                       PRECISION(kFloat),
                                       DATALAYOUT(kNHWC))})
    .Finalize();

// conv2d on CUDA with int8 compute in NHWC layout, producing int8 output
// (alias "int8_out"). Only the Output binding differs from the "fp32_out"
// variant: here it is declared as int8 NHWC.
REGISTER_LITE_KERNEL(conv2d,
                     kCUDA,
                     kInt8,
                     kNHWC,
                     paddle::lite::kernels::cuda::ConvComputeInt8<
                         PRECISION(kInt8)>,
                     int8_out)
    // Quantized activations, int8 NHWC.
    .BindInput("Input",
               {LiteType::GetTensorTy(
                   TARGET(kCUDA), PRECISION(kInt8), DATALAYOUT(kNHWC))})
    // Bias stays float32; no layout is specified.
    .BindInput("Bias",
               {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFloat))})
    // Quantized weights, int8 NHWC.
    .BindInput("Filter",
               {LiteType::GetTensorTy(
                   TARGET(kCUDA), PRECISION(kInt8), DATALAYOUT(kNHWC))})
    // Result tensor, int8 NHWC.
    .BindOutput("Output",
                {LiteType::GetTensorTy(
                    TARGET(kCUDA), PRECISION(kInt8), DATALAYOUT(kNHWC))})
    .Finalize();