From a6b1c4c12b767f1d421b5e5a829d901738d5209e Mon Sep 17 00:00:00 2001 From: Wilber Date: Thu, 6 Jan 2022 13:53:01 +0800 Subject: [PATCH] fix:transform the data from cpu to gpu when trt is used (#37427) (#38745) Co-authored-by: feng_shuai --- paddle/fluid/operators/tensorrt/tensorrt_engine_op.h | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h index 46da8e6151..35612905f8 100644 --- a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h +++ b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h @@ -24,6 +24,7 @@ #include #include +#include "paddle/fluid/framework/data_device_transform.h" #include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" @@ -421,6 +422,13 @@ class TensorRTEngineOp : public framework::OperatorBase { // convert input and copy to TRT engine's buffer auto &t = inference::analysis::GetFromScope(scope, x); + // check the input_tensor + if (!platform::is_gpu_place(t.place())) { + framework::Tensor out; + platform::CUDAPlace dst_place; + framework::TransDataDevice(t, dst_place, &out); + t.ShareDataWith(out); + } auto t_shape = framework::vectorize(t.dims()); const int bind_index = engine->engine()->getBindingIndex(x.c_str()); PADDLE_ENFORCE_LT( -- GitLab