// feed_compute.cc
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/kernels/cuda/feed_compute.h"
#include "lite/core/op_registry.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {

23 24 25
template <typename T, PrecisionType Ptype>
void FeedCompute<T, Ptype>::Run() {
  auto& param = this->template Param<param_t>();
26 27 28 29 30 31
  auto& ctx = this->ctx_->template As<CUDAContext>();
  auto stream = ctx.exec_stream();
  VLOG(4) << "feed_list.size: " << param.feed_list->size();
  const lite::Tensor& feed_item = (*param.feed_list)[param.col];

  int num = static_cast<int>(feed_item.numel());
32
  auto input = feed_item.data<T>();
33
  param.out->Resize(feed_item.dims());
34
  auto output = param.out->template mutable_data<T>(TARGET(kCUDA));
35 36 37
  VLOG(4) << "col: " << param.col << " num:" << num;

  TargetW::MemcpyAsync(
38
      output, input, num * sizeof(T), IoDirection::HtoD, stream);
39 40 41 42 43 44 45
}

}  // namespace cuda
}  // namespace kernels
}  // namespace lite
}  // namespace paddle

// Concrete FeedCompute instantiations, one per supported element type.
// Each alias pairs the C++ element type that Run() reads/writes with the
// matching PrecisionType template argument.
using FeedFp32 =
    paddle::lite::kernels::cuda::FeedCompute<float, PRECISION(kFloat)>;

using FeedInt64 =
    paddle::lite::kernels::cuda::FeedCompute<int64_t, PRECISION(kInt64)>;

using FeedInt32 =
    paddle::lite::kernels::cuda::FeedCompute<int32_t, PRECISION(kInt32)>;

55
REGISTER_LITE_KERNEL(feed, kCUDA, kFloat, kNCHW, FeedFp32, nchw)
56
    .BindInput("X",
Z
Zhaolong Xing 已提交
57
               {LiteType::GetTensorTy(TARGET(kHost),
58 59 60 61 62 63 64 65
                                      PRECISION(kFloat),
                                      DATALAYOUT(kNCHW))})
    .BindOutput("Out",
                {LiteType::GetTensorTy(TARGET(kCUDA),
                                       PRECISION(kFloat),
                                       DATALAYOUT(kNCHW))})
    .Finalize();

66
REGISTER_LITE_KERNEL(feed, kCUDA, kFloat, kNHWC, FeedFp32, nhwc)
67
    .BindInput("X",
Z
Zhaolong Xing 已提交
68
               {LiteType::GetTensorTy(TARGET(kHost),
69 70 71 72 73 74 75
                                      PRECISION(kFloat),
                                      DATALAYOUT(kNHWC))})
    .BindOutput("Out",
                {LiteType::GetTensorTy(TARGET(kCUDA),
                                       PRECISION(kFloat),
                                       DATALAYOUT(kNHWC))})
    .Finalize();
76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97

REGISTER_LITE_KERNEL(feed, kCUDA, kInt64, kNCHW, FeedInt64, nchw)
    .BindInput("X",
               {LiteType::GetTensorTy(TARGET(kHost),
                                      PRECISION(kInt64),
                                      DATALAYOUT(kNCHW))})
    .BindOutput("Out",
                {LiteType::GetTensorTy(TARGET(kCUDA),
                                       PRECISION(kInt64),
                                       DATALAYOUT(kNCHW))})
    .Finalize();

// int64 / NHWC variant (`nhwc`): host-side X in, CUDA-side Out out.
REGISTER_LITE_KERNEL(feed, kCUDA, kInt64, kNHWC, FeedInt64, nhwc)
    .BindInput(
        "X",
        {LiteType::GetTensorTy(
            TARGET(kHost), PRECISION(kInt64), DATALAYOUT(kNHWC))})
    .BindOutput(
        "Out",
        {LiteType::GetTensorTy(
            TARGET(kCUDA), PRECISION(kInt64), DATALAYOUT(kNHWC))})
    .Finalize();
98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119

REGISTER_LITE_KERNEL(feed, kCUDA, kInt32, kNCHW, FeedInt32, nchw)
    .BindInput("X",
               {LiteType::GetTensorTy(TARGET(kHost),
                                      PRECISION(kAny),
                                      DATALAYOUT(kNCHW))})
    .BindOutput("Out",
                {LiteType::GetTensorTy(TARGET(kCUDA),
                                       PRECISION(kInt32),
                                       DATALAYOUT(kNCHW))})
    .Finalize();

// int32 / NHWC variant (`nhwc`): host-side X (any precision) in,
// CUDA-side int32 Out out.
// NOTE(review): X is PRECISION(kAny) while Run() reads it via
// data<int32_t>() — confirm only int32 host tensors reach this kernel.
REGISTER_LITE_KERNEL(feed, kCUDA, kInt32, kNHWC, FeedInt32, nhwc)
    .BindInput(
        "X",
        {LiteType::GetTensorTy(
            TARGET(kHost), PRECISION(kAny), DATALAYOUT(kNHWC))})
    .BindOutput(
        "Out",
        {LiteType::GetTensorTy(
            TARGET(kCUDA), PRECISION(kInt32), DATALAYOUT(kNHWC))})
    .Finalize();