From 4c750920f0b2dc350200099f4ba1a19d2a25bad7 Mon Sep 17 00:00:00 2001 From: Jiangtao Hu Date: Mon, 16 Mar 2020 14:14:01 -0700 Subject: [PATCH] planning: add an inference demo for one planning demo pytorch model. --- modules/planning/tools/BUILD | 9 ++++ modules/planning/tools/inference_demo.cc | 62 ++++++++++++++++++++++++ 2 files changed, 71 insertions(+) create mode 100644 modules/planning/tools/inference_demo.cc diff --git a/modules/planning/tools/BUILD b/modules/planning/tools/BUILD index 7b0623ee25..5ab8cd34f6 100644 --- a/modules/planning/tools/BUILD +++ b/modules/planning/tools/BUILD @@ -14,4 +14,13 @@ cc_binary( ], ) +cc_binary( + name = "inference_demo", + srcs = ["inference_demo.cc"], + deps = [ + "@com_github_gflags_gflags//:gflags", + "//third_party:libtorch", + ], +) + cpplint() diff --git a/modules/planning/tools/inference_demo.cc b/modules/planning/tools/inference_demo.cc new file mode 100644 index 0000000000..985c6c9174 --- /dev/null +++ b/modules/planning/tools/inference_demo.cc @@ -0,0 +1,62 @@ +/****************************************************************************** + * Copyright 2019 The Apollo Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *****************************************************************************/ + +#include + +#include "torch/script.h" +#include "torch/torch.h" + +DEFINE_string(model_file, + "/apollo/modules/planning/tools/planning_demo_model.pt", + "pytorch model file."); + +int main(int argc, char **argv) { + google::ParseCommandLineFlags(&argc, &argv, true); + + torch::jit::script::Module model; + torch::Device device(torch::kCPU); + + torch::set_num_threads(1); + model = torch::jit::load(FLAGS_model_file, device); + + std::vector torch_inputs; + + int input_dim = 2 * 3 * 224 * 224 + 2 * 14; + std::vector feature_values(input_dim, 0.5); + + std::vector inputs; + std::vector tuple; + tuple.push_back(torch::zeros({2, 3, 224, 224})); + tuple.push_back(torch::zeros({2, 14})); + inputs.push_back(torch::ivalue::Tuple::create(tuple)); + + auto torch_output = model.forward(inputs); + std::cout << torch_output << std::endl; + std::cout << "isDoubleList:" << torch_output.isDoubleList() << std::endl; + std::cout << "isTensorList:" << torch_output.isTensorList() << std::endl; + std::cout << "isTensor:" << torch_output.isTensor() << std::endl; + auto torch_output_tensor = torch_output.toTensor(); + std::cout << "tensor dim:" << torch_output_tensor.dim() << std::endl; + std::cout << "tensor sizes:" << torch_output_tensor.sizes() << std::endl; + std::cout << "tensor toString:" << torch_output_tensor.toString() + << std::endl; + std::cout << "tensor [0,0,0] element:" << torch_output_tensor[0][0][0] + << std::endl; + std::cout << "tensor [0,0,1] element:" << torch_output_tensor[0][0][1] + << std::endl; + + return 0; +} -- GitLab