diff --git a/core/general-client/example/simple_client.cpp b/core/general-client/example/simple_client.cpp index fc89146731fd9ddae0fd1e2d0b0a06b301dbcdbf..cd60018880e7cb99a128b7346999e7daee1b3412 100644 --- a/core/general-client/example/simple_client.cpp +++ b/core/general-client/example/simple_client.cpp @@ -27,6 +27,92 @@ using baidu::paddle_serving::client::PredictorOutputs; DEFINE_string(server_port, "127.0.0.1:9292", ""); DEFINE_string(client_conf, "serving_client_conf.prototxt", ""); DEFINE_string(test_type, "brpc", ""); +DEFINE_string(sample_type, "fit_a_line", ""); + +namespace { +int prepare_fit_a_line(PredictorInputs& input, std::vector& fetch_name) { + std::vector float_feed = {0.0137f, -0.1136f, 0.2553f, -0.0692f, + 0.0582f, -0.0727f, -0.1583f, -0.0584f, + 0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f}; + std::vector float_shape = {1, 13}; + std::string feed_name = "x"; + fetch_name = {"price"}; + std::vector lod; + input.add_float_data(float_feed, feed_name, float_shape, lod); + return 0; +} + +int prepare_bert(PredictorInputs& input, std::vector& fetch_name) { + float input_mask[] = { + 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}; + long position_ids[] = { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + long input_ids[] = { + 101, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + long segment_ids[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + { + std::vector float_feed(std::begin(input_mask), std::end(input_mask)); + std::vector float_shape = {1, 128, 1}; + std::string feed_name = "input_mask"; + std::vector lod; + input.add_float_data(float_feed, feed_name, float_shape, lod); + } + { + std::vector feed(std::begin(position_ids), std::end(position_ids)); + std::vector shape = {1, 128, 1}; + std::string feed_name = "position_ids"; + std::vector lod; + input.add_int64_data(feed, feed_name, shape, lod); + } + { + std::vector feed(std::begin(input_ids), std::end(input_ids)); + std::vector shape = {1, 128, 1}; + std::string feed_name = "input_ids"; + std::vector lod; + input.add_int64_data(feed, feed_name, shape, lod); + } + { + std::vector feed(std::begin(segment_ids), std::end(segment_ids)); + std::vector shape = {1, 128, 1}; + std::string feed_name = "segment_ids"; + std::vector lod; + input.add_int64_data(feed, feed_name, shape, lod); + } + + fetch_name = {"pooled_output"}; + return 0; +} +} // namespace int main(int argc, char* argv[]) { @@ -34,9 +120,11 @@ int main(int argc, char* argv[]) { std::string url = FLAGS_server_port; std::string conf = FLAGS_client_conf; std::string test_type = FLAGS_test_type; + std::string sample_type = FLAGS_sample_type; LOG(INFO) << "url = " << url << ";" << "client_conf = " << conf << ";" - << "type = " << test_type; + << "test_type = " << test_type + << "sample_type = " << sample_type; std::unique_ptr client; if (test_type == "brpc") { client.reset(new ServingBrpcClient()); @@ -50,18 +138,19 @@ int main(int argc, char* argv[]) { return 0; } - std::vector float_feed = {0.0137f, -0.1136f, 0.2553f, -0.0692f, - 0.0582f, -0.0727f, -0.1583f, -0.0584f, - 0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f}; - std::vector float_shape = {1, 13}; - std::string feed_name = "x"; - std::vector fetch_name = {"price"}; - std::vector lod; - PredictorInputs input; PredictorOutputs output; + std::vector fetch_name; - input.add_float_data(float_feed, feed_name, float_shape, lod); + if (sample_type == "fit_a_line") { + prepare_fit_a_line(input, fetch_name); + } + else if (sample_type == "bert") { + prepare_bert(input, fetch_name); + } + else { + prepare_fit_a_line(input, fetch_name); + } if (client->predict(input, output, fetch_name, 0) != 0) { LOG(ERROR) << "Failed to predict!";