提交 940f5dbc 编写于 作者: N nhzlx

modify the tensorrt engine op to adapt to change

上级 82527696
...@@ -53,13 +53,14 @@ nvinfer1::Dims Vec2TRT_Dims(const std::vector<int64_t> &shape) { ...@@ -53,13 +53,14 @@ nvinfer1::Dims Vec2TRT_Dims(const std::vector<int64_t> &shape) {
PADDLE_ENFORCE_LE(shape.size(), 4UL, PADDLE_ENFORCE_LE(shape.size(), 4UL,
"TensorRT' tensor input requires at most 4 dimensions"); "TensorRT' tensor input requires at most 4 dimensions");
// We should delete the batch size here.
switch (shape.size()) { switch (shape.size()) {
case 2: case 2:
return nvinfer1::Dims2(shape[0], shape[1]); return nvinfer1::Dims2(1, shape[1]);
case 3: case 3:
return nvinfer1::Dims3(shape[0], shape[1], shape[2]); return nvinfer1::Dims3(1, shape[1], shape[2]);
case 4: case 4:
return nvinfer1::Dims4(shape[0], shape[1], shape[2], shape[3]); return nvinfer1::Dims4(1, shape[1], shape[2], shape[3]);
default: default:
return nvinfer1::Dims(); return nvinfer1::Dims();
} }
......
...@@ -95,16 +95,14 @@ class TensorRTEngineKernel : public framework::OpKernel<T> { ...@@ -95,16 +95,14 @@ class TensorRTEngineKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_NOT_NULL(fluid_v, "no output variable called %s", y); PADDLE_ENFORCE_NOT_NULL(fluid_v, "no output variable called %s", y);
auto* fluid_t = fluid_v->GetMutable<framework::LoDTensor>(); auto* fluid_t = fluid_v->GetMutable<framework::LoDTensor>();
fluid_t->Resize(framework::make_ddim(ddim)); fluid_t->Resize(framework::make_ddim(ddim));
auto size = inference::analysis::AccuDims(dims.d, dims.nbDims);
if (platform::is_cpu_place(fluid_t->place())) { if (platform::is_cpu_place(fluid_t->place())) {
// TODO(Superjomn) change this float to dtype size. // TODO(Superjomn) change this float to dtype size.
engine->GetOutputInCPU( engine->GetOutputInCPU(
y, fluid_t->mutable_data<float>(platform::CPUPlace()), y, fluid_t->mutable_data<float>(platform::CPUPlace()));
size * sizeof(float));
} else { } else {
engine->GetOutputInGPU( engine->GetOutputInGPU(
y, fluid_t->mutable_data<float>(platform::CUDAPlace()), y, fluid_t->mutable_data<float>(platform::CUDAPlace()));
size * sizeof(float));
} }
} }
......
...@@ -64,36 +64,37 @@ TEST(TensorRTEngineOp, manual) { ...@@ -64,36 +64,37 @@ TEST(TensorRTEngineOp, manual) {
LOG(INFO) << "create block desc"; LOG(INFO) << "create block desc";
framework::BlockDesc block_desc(&program, block_); framework::BlockDesc block_desc(&program, block_);
LOG(INFO) << "create mul op"; LOG(INFO) << "create fc op";
auto* mul = block_desc.AppendOp(); auto* fc0 = block_desc.AppendOp();
mul->SetType("mul"); fc0->SetType("mul");
mul->SetInput("X", std::vector<std::string>({"x"})); // 2 x 4 fc0->SetInput("X", std::vector<std::string>({"x"})); // 4 x 1 x 1
mul->SetInput("Y", std::vector<std::string>({"y"})); // 4 x 6 fc0->SetInput("Y", std::vector<std::string>({"y"})); // 4 x 6
mul->SetOutput("Out", std::vector<std::string>({"z"})); // 2 x 6 fc0->SetOutput("Out", std::vector<std::string>({"z"})); // 6 x 1 x 1
LOG(INFO) << "create fc op"; LOG(INFO) << "create fc op";
auto* fc = block_desc.AppendOp(); auto* fc1 = block_desc.AppendOp();
fc->SetType("mul"); fc1->SetType("mul");
fc->SetInput("X", std::vector<std::string>({"z"})); fc1->SetInput("X", std::vector<std::string>({"z"}));
fc->SetInput("Y", std::vector<std::string>({"y0"})); // 6 x 8 fc1->SetInput("Y", std::vector<std::string>({"y0"})); // 6 x 8
fc->SetOutput("Out", std::vector<std::string>({"z0"})); // 2 x 8 fc1->SetOutput("Out", std::vector<std::string>({"z0"})); // 8 x 1 x 1
// Set inputs' variable shape in BlockDesc // Set inputs' variable shape in BlockDesc
AddTensorToBlockDesc(block_, "x", std::vector<int64_t>({2, 4})); // the batch size is 2, so the dims of 'x' is {2, 4, 1, 1}
AddTensorToBlockDesc(block_, "x", std::vector<int64_t>({2, 4, 1, 1}));
AddTensorToBlockDesc(block_, "y", std::vector<int64_t>({4, 6})); AddTensorToBlockDesc(block_, "y", std::vector<int64_t>({4, 6}));
AddTensorToBlockDesc(block_, "y0", std::vector<int64_t>({6, 8})); AddTensorToBlockDesc(block_, "y0", std::vector<int64_t>({6, 8}));
AddTensorToBlockDesc(block_, "z", std::vector<int64_t>({2, 6})); AddTensorToBlockDesc(block_, "z", std::vector<int64_t>({2, 6}));
// It is weird, need to copy manually. // It is weird, need to copy manually.
*block_->add_ops() = *mul->Proto(); *block_->add_ops() = *fc0->Proto();
*block_->add_ops() = *fc->Proto(); *block_->add_ops() = *fc1->Proto();
ASSERT_EQ(block_->ops_size(), 2); ASSERT_EQ(block_->ops_size(), 2);
LOG(INFO) << "create tensorrt desc"; LOG(INFO) << "create tensorrt desc";
framework::OpDesc engine_op_desc(nullptr); framework::OpDesc engine_op_desc(nullptr);
engine_op_desc.SetType("tensorrt_engine"); engine_op_desc.SetType("tensorrt_engine");
engine_op_desc.SetInput("Xs", std::vector<std::string>({"x", "y", "y0"})); engine_op_desc.SetInput("Xs", std::vector<std::string>({"x"}));
engine_op_desc.SetOutput("Ys", std::vector<std::string>({"z0"})); engine_op_desc.SetOutput("Ys", std::vector<std::string>({"z0"}));
SetAttr<std::string>(engine_op_desc.Proto(), "subgraph", SetAttr<std::string>(engine_op_desc.Proto(), "subgraph",
block_->SerializeAsString()); block_->SerializeAsString());
...@@ -208,4 +209,3 @@ TEST(TensorRTEngineOp, fc) { Execute(40, 28, 28); } ...@@ -208,4 +209,3 @@ TEST(TensorRTEngineOp, fc) { Execute(40, 28, 28); }
} // namespace paddle } // namespace paddle
USE_TRT_CONVERTER(mul) USE_TRT_CONVERTER(mul)
USE_TRT_CONVERTER(fc)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册