// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include #include #include "paddle/fluid/inference/capi/c_api.h" #include "paddle/fluid/inference/capi/c_api_internal.h" using paddle::ConvertToPaddleDType; using paddle::ConvertToPDDataType; using paddle::ConvertToACPrecision; extern "C" { bool PD_PredictorRun(const PD_AnalysisConfig* config, PD_Tensor* inputs, int in_size, PD_Tensor* output_data, int** out_size, int batch_size) { auto predictor = paddle::CreatePaddlePredictor(config->config); std::vector in; for (int i = 0; i < in_size; ++i) { in.emplace_back(inputs->tensor); } std::vector out; if (predictor->Run(in, &out, batch_size)) { int osize = out.size(); for (int i = 0; i < osize; ++i) { output_data[i].tensor = out[i]; } *out_size = &osize; return true; } return false; } bool PD_PredictorZeroCopyRun(const PD_AnalysisConfig* config, PD_ZeroCopyData* inputs, int in_size, PD_ZeroCopyData* output, int** out_size) { auto predictor = paddle::CreatePaddlePredictor(config->config); auto input_names = predictor->GetInputNames(); PADDLE_ENFORCE_EQ( input_names.size(), in_size, "The number of input and the number of model's input must match. "); for (int i = 0; i < in_size; ++i) { auto input_t = predictor->GetInputTensor(inputs[i].name); std::vector tensor_shape; tensor_shape.assign(inputs[i].shape, inputs[i].shape + inputs[i].shape_size); input_t->Reshape(tensor_shape); switch (inputs[i].dtype) { case PD_FLOAT32: input_t->copy_from_cpu(static_cast(inputs[i].data)); break; case PD_INT32: input_t->copy_from_cpu(static_cast(inputs[i].data)); break; case PD_INT64: input_t->copy_from_cpu(static_cast(inputs[i].data)); break; case PD_UINT8: input_t->copy_from_cpu(static_cast(inputs[i].data)); break; default: CHECK(false) << "Unsupport data type."; break; } } CHECK(predictor->ZeroCopyRun()); auto output_names = predictor->GetOutputNames(); int osize = output_names.size(); *out_size = &osize; output = new PD_ZeroCopyData[osize]; for (int i = 0; i < osize; ++i) { LOG(INFO) << 1; output[i].name = new char[output_names[i].length() + 1]; snprintf(output[i].name, output_names[i].length() + 1, "%s", output_names[i].c_str()); auto output_t = predictor->GetOutputTensor(output_names[i]); output[i].dtype = ConvertToPDDataType(output_t->type()); std::vector output_shape = output_t->shape(); output[i].shape = new int[output_shape.size()]; output[i].shape = output_shape.data(); output[i].shape_size = output_shape.size(); switch (output[i].dtype) { case PD_FLOAT32: { std::vector out_data; int out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); out_data.resize(out_num); output_t->copy_to_cpu(out_data.data()); output[i].data = static_cast(out_data.data()); } break; case PD_INT32: { std::vector out_data; int out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); out_data.resize(out_num); output_t->copy_to_cpu(out_data.data()); output[i].data = static_cast(out_data.data()); } break; case PD_INT64: { std::vector out_data; int out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); out_data.resize(out_num); output_t->copy_to_cpu(out_data.data()); output[i].data = static_cast(out_data.data()); } break; case PD_UINT8: { std::vector out_data; int out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); out_data.resize(out_num); output_t->copy_to_cpu(out_data.data()); output[i].data = static_cast(out_data.data()); } break; default: CHECK(false) << "Unsupport data type."; break; } } return true; } } // extern "C"