/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/inference/tensorrt/convert/io_converter.h"
#include <cuda_runtime.h>
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace inference {
namespace tensorrt {

using platform::is_gpu_place;
using platform::is_cpu_place;

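// Default converter between a fluid LoDTensor and a raw TensorRT device
// buffer. It issues asynchronous cudaMemcpy calls on the stream set via
// SetStream(); only float tensors are handled (data<float>()).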
class DefaultIOConverter : public EngineIOConverter {
 public:
  DefaultIOConverter() {}
  // NOTE: 'out' must point to pre-allocated GPU memory of at least max_size
  // bytes.
  virtual void operator()(const LoDTensor& in, void* out,
                          size_t max_size) override {
    PADDLE_ENFORCE_NOT_NULL(out,
                            platform::errors::InvalidArgument(
                                "The input param 'out' must not be nullptr."));
    PADDLE_ENFORCE_NOT_NULL(stream_,
                            platform::errors::PreconditionNotMet(
                                "You should set up stream_ by SetStream() "
                                "before calling operator()."));
    const auto& place = in.place();
    size_t size = in.memory_size();
    PADDLE_ENFORCE_LE(
        size, max_size,
        platform::errors::InvalidArgument(
            "The input Tensor in's memory_size should be less than or equal "
            "to the input max_size. But in's memory_size = %u, max_size = %u.",
            size, max_size));
    if (is_cpu_place(place)) {
      PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemcpyAsync(
          out, in.data<float>(), size, cudaMemcpyHostToDevice, *stream_));
    } else if (is_gpu_place(place)) {
      PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemcpyAsync(
          out, in.data<float>(), size, cudaMemcpyDeviceToDevice, *stream_));
    } else {
      PADDLE_THROW(platform::errors::NotFound("Unknown device for converter"));
    }
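    // Block until the copy finishes: the memcpy above is enqueued
    // asynchronously on stream_, and callers expect the data to be ready
    // when operator() returns.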
    cudaStreamSynchronize(*stream_);
  }
  // NOTE: 'in' must point to GPU memory holding at least out->memory_size()
  // bytes.
  virtual void operator()(const void* in, LoDTensor* out,
                          size_t max_size) override {
    PADDLE_ENFORCE_NOT_NULL(in,
                            platform::errors::InvalidArgument(
                                "The input param 'in' must not be nullptr."));
    PADDLE_ENFORCE_NOT_NULL(stream_,
                            platform::errors::PreconditionNotMet(
                                "You should set up stream_ by SetStream() "
                                "before calling operator()."));
    const auto& place = out->place();
    size_t size = out->memory_size();
    PADDLE_ENFORCE_LE(
        size, max_size,
        platform::errors::InvalidArgument(
            "The input Tensor out's memory_size should be less than or equal "
            "to the input max_size. "
            "But out's memory_size = %u, max_size = %u.",
            size, max_size));
    if (is_cpu_place(place)) {
      PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemcpyAsync(
          out->data<float>(), in, size, cudaMemcpyDeviceToHost, *stream_));
    } else if (is_gpu_place(place)) {
      PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemcpyAsync(
          out->data<float>(), in, size, cudaMemcpyDeviceToDevice, *stream_));
    } else {
      PADDLE_THROW(platform::errors::NotFound("Unknown device for converter"));
    }
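    // As above, block until the copy has finished before returning.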
    cudaStreamSynchronize(*stream_);
  }
};

// fluid LoDTensor <-> tensorrt ITensor
REGISTER_TENSORRT_IO_CONVERTER(default, DefaultIOConverter);
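
// A minimal usage sketch (illustrative only; SetStream() is assumed to be
// provided by the EngineIOConverter base class, as the error messages above
// indicate, and cpu_tensor / trt_input_buffer / max_size are placeholders):
//
//   cudaStream_t stream;
//   cudaStreamCreate(&stream);
//   DefaultIOConverter converter;
//   converter.SetStream(&stream);
//   // Host -> device: copy a CPU LoDTensor into a TensorRT input buffer.
//   converter(cpu_tensor, trt_input_buffer, max_size);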

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle