提交 e7cd9ce2 编写于 作者: S ShiningZhang

add bert test example

上级 52aa1963
......@@ -27,6 +27,92 @@ using baidu::paddle_serving::client::PredictorOutputs;
DEFINE_string(server_port, "127.0.0.1:9292", "");
DEFINE_string(client_conf, "serving_client_conf.prototxt", "");
DEFINE_string(test_type, "brpc", "");
DEFINE_string(sample_type, "fit_a_line", "");
namespace {
int prepare_fit_a_line(PredictorInputs& input, std::vector<std::string>& fetch_name) {
std::vector<float> float_feed = {0.0137f, -0.1136f, 0.2553f, -0.0692f,
0.0582f, -0.0727f, -0.1583f, -0.0584f,
0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f};
std::vector<int> float_shape = {1, 13};
std::string feed_name = "x";
fetch_name = {"price"};
std::vector<int> lod;
input.add_float_data(float_feed, feed_name, float_shape, lod);
return 0;
}
int prepare_bert(PredictorInputs& input, std::vector<std::string>& fetch_name) {
float input_mask[] = {
1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f};
long position_ids[] = {
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
long input_ids[] = {
101, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
long segment_ids[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
{
std::vector<float> float_feed(std::begin(input_mask), std::end(input_mask));
std::vector<int> float_shape = {1, 128, 1};
std::string feed_name = "input_mask";
std::vector<int> lod;
input.add_float_data(float_feed, feed_name, float_shape, lod);
}
{
std::vector<int64_t> feed(std::begin(position_ids), std::end(position_ids));
std::vector<int> shape = {1, 128, 1};
std::string feed_name = "position_ids";
std::vector<int> lod;
input.add_int64_data(feed, feed_name, shape, lod);
}
{
std::vector<int64_t> feed(std::begin(input_ids), std::end(input_ids));
std::vector<int> shape = {1, 128, 1};
std::string feed_name = "input_ids";
std::vector<int> lod;
input.add_int64_data(feed, feed_name, shape, lod);
}
{
std::vector<int64_t> feed(std::begin(segment_ids), std::end(segment_ids));
std::vector<int> shape = {1, 128, 1};
std::string feed_name = "segment_ids";
std::vector<int> lod;
input.add_int64_data(feed, feed_name, shape, lod);
}
fetch_name = {"pooled_output"};
return 0;
}
} // namespace
int main(int argc, char* argv[]) {
......@@ -34,9 +120,11 @@ int main(int argc, char* argv[]) {
std::string url = FLAGS_server_port;
std::string conf = FLAGS_client_conf;
std::string test_type = FLAGS_test_type;
std::string sample_type = FLAGS_sample_type;
LOG(INFO) << "url = " << url << ";"
<< "client_conf = " << conf << ";"
<< "type = " << test_type;
<< "test_type = " << test_type
<< "sample_type = " << sample_type;
std::unique_ptr<ServingClient> client;
if (test_type == "brpc") {
client.reset(new ServingBrpcClient());
......@@ -50,18 +138,19 @@ int main(int argc, char* argv[]) {
return 0;
}
std::vector<float> float_feed = {0.0137f, -0.1136f, 0.2553f, -0.0692f,
0.0582f, -0.0727f, -0.1583f, -0.0584f,
0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f};
std::vector<int> float_shape = {1, 13};
std::string feed_name = "x";
std::vector<std::string> fetch_name = {"price"};
std::vector<int> lod;
PredictorInputs input;
PredictorOutputs output;
std::vector<std::string> fetch_name;
input.add_float_data(float_feed, feed_name, float_shape, lod);
if (sample_type == "fit_a_line") {
prepare_fit_a_line(input, fetch_name);
}
else if (sample_type == "bert") {
prepare_bert(input, fetch_name);
}
else {
prepare_fit_a_line(input, fetch_name);
}
if (client->predict(input, output, fetch_name, 0) != 0) {
LOG(ERROR) << "Failed to predict!";
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册