diff --git a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h index 46da8e615169254e96275c865ecab532d2c6a614..35612905f8569d9fd5beebd0783d118e9f265bf0 100644 --- a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h +++ b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h @@ -24,6 +24,7 @@ #include #include +#include "paddle/fluid/framework/data_device_transform.h" #include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" @@ -421,6 +422,13 @@ class TensorRTEngineOp : public framework::OperatorBase { // convert input and copy to TRT engine's buffer auto &t = inference::analysis::GetFromScope(scope, x); + // check the input_tensor + if (!platform::is_gpu_place(t.place())) { + framework::Tensor out; + platform::CUDAPlace dst_place; + framework::TransDataDevice(t, dst_place, &out); + t.ShareDataWith(out); + } auto t_shape = framework::vectorize(t.dims()); const int bind_index = engine->engine()->getBindingIndex(x.c_str()); PADDLE_ENFORCE_LT(