// tensorrt_engine_op.h
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#pragma once

#ifdef PADDLE_WITH_CUDA

#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/engine.h"

namespace paddle {
namespace operators {

// Short names for the singleton helper and the TensorRT engine manager used
// by the kernel in this header.
// NOTE(review): these using-declarations live in a header, so they are visible
// to every file that includes it inside paddle::operators.
using inference::Singleton;
using inference::tensorrt::TRT_EngineManager;

// Operator class for the TensorRT engine op.
//
// Shape inference is a no-op here; the kernel type is derived from the data
// type of the first variable listed under the "Xs" input slot.
class TensorRTEngineOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  // Intentionally empty: no shapes are inferred at this stage.
  void InferShape(framework::InferShapeContext* ctx) const override {}

  // Returns the kernel type for this op.
  //
  // The data type is taken from the LoDTensor held by the first "Xs" input
  // variable found in the execution scope; the place is always CPUPlace.
  // Fails with an enforce error when "Xs" is empty or when the first input
  // variable cannot be found in the scope.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    const auto& input_names = ctx.Inputs("Xs");
    // Calling front() on an empty vector is undefined behaviour, so check
    // explicitly before touching the first name.
    PADDLE_ENFORCE(!input_names.empty(),
                   "TensorRTEngineOp should have at least one input in Xs");
    const auto& input0 = input_names.front();

    // FindVar may return null for an unknown name; report that instead of
    // dereferencing a null pointer.
    auto* input0_var = ctx.scope().FindVar(input0);
    PADDLE_ENFORCE_NOT_NULL(input0_var, "no input variable called %s", input0);

    return framework::OpKernelType(
        framework::ToDataType(
            input0_var->GetMutable<framework::LoDTensor>()->type()),
        platform::CPUPlace());
  }
};

// Kernel that pushes the op's "Xs" inputs into a TensorRT engine, executes
// the engine, and copies the engine outputs into the op's "Ys" variables.
//
// The engine is looked up in the global TRT_EngineManager by the
// "engine_uniq_key" attribute; when no engine is registered under that key,
// Prepare() is called first.
template <typename DeviceContext, typename T>
class TensorRTEngineKernel : public framework::OpKernel<T> {
 public:
  // Runs one inference pass through the engine.
  //
  // Enforce failures are raised when:
  //   - no engine can be obtained for "engine_uniq_key",
  //   - the "Xs" input list is empty,
  //   - the batch size (first dim of the first input) is not in
  //     (0, max_batch],
  //   - an output variable or an engine output tensor is missing,
  //   - synchronizing the engine's CUDA stream reports an error.
  void Compute(const framework::ExecutionContext& context) const override {
    const auto& engine_name = context.Attr<std::string>("engine_uniq_key");
    if (!Singleton<TRT_EngineManager>::Global().HasEngine(engine_name)) {
      Prepare(context);
    }
    auto* engine = Singleton<TRT_EngineManager>::Global().Get(engine_name);
    PADDLE_ENFORCE_NOT_NULL(engine, "no TensorRT engine called %s",
                            engine_name);

    const auto& input_names = context.op().Inputs("Xs");
    PADDLE_ENFORCE(!input_names.empty(), "should pass at least one input");

    // The batch size is taken from the first dimension of the first input.
    auto& tensor0 = inference::analysis::GetFromScope<framework::LoDTensor>(
        context.scope(), input_names.front());
    const int batch_size = static_cast<int>(tensor0.dims()[0]);
    // Validate the batch size up front so nothing is copied into the engine
    // for a batch that cannot be executed.
    PADDLE_ENFORCE_GT(batch_size, 0);
    PADDLE_ENFORCE_LE(batch_size, context.Attr<int>("max_batch"));

    // Copy every input tensor from fluid into the engine's buffers, choosing
    // the copy routine by where the tensor currently lives.
    for (const auto& x : input_names) {
      auto& t = inference::analysis::GetFromScope<framework::LoDTensor>(
          context.scope(), x);
      if (platform::is_cpu_place(t.place())) {
        engine->SetInputFromCPU(x, static_cast<const void*>(t.data<void>()),
                                t.memory_size());
      } else {
        engine->SetInputFromGPU(x, static_cast<const void*>(t.data<void>()),
                                t.memory_size());
      }
    }

    // Execute the engine.
    engine->Execute(batch_size);

    // Copy every engine output back into the matching fluid variable.
    for (const auto& y : context.Outputs("Ys")) {
      nvinfer1::ITensor* trt_t = engine->GetITensor(y);
      PADDLE_ENFORCE_NOT_NULL(trt_t, "no TensorRT output tensor called %s", y);
      auto dims = trt_t->getDimensions();
      // Use the output ITensor's dims to reshape the fluid tensor.
      std::vector<int> ddim(dims.d, dims.d + dims.nbDims);

      auto* fluid_v = context.scope().FindVar(y);
      PADDLE_ENFORCE_NOT_NULL(fluid_v, "no output variable called %s", y);
      auto* fluid_t = fluid_v->GetMutable<framework::LoDTensor>();
      fluid_t->Resize(framework::make_ddim(ddim));
      // NOTE(review): the element count comes from the engine dims only and
      // batch_size is not factored in -- confirm getDimensions() already
      // covers the batch dimension.
      auto size = inference::analysis::AccuDims(dims.d, dims.nbDims);
      // NOTE(review): place() is queried before mutable_data() is called;
      // presumably the output tensor already holds memory -- confirm.
      if (platform::is_cpu_place(fluid_t->place())) {
        // TODO(Superjomn) change this float to dtype size.
        engine->GetOutputInCPU(
            y, fluid_t->mutable_data<float>(platform::CPUPlace()),
            size * sizeof(float));
      } else {
        engine->GetOutputInGPU(
            y, fluid_t->mutable_data<float>(platform::CUDAPlace()),
            size * sizeof(float));
      }
    }

    // Wait for the work queued on the engine's stream and surface any CUDA
    // error instead of silently dropping the return code.
    PADDLE_ENFORCE(cudaStreamSynchronize(*engine->stream()),
                   "failed to synchronize the TensorRT engine stream");
  }

 protected:
  // Build the engine.
  void Prepare(const framework::ExecutionContext& context) const;
};

}  // namespace operators
}  // namespace paddle

#endif  // PADDLE_WITH_CUDA