diff --git a/demo-serving/op/ctr_prediction_op.cpp b/demo-serving/op/ctr_prediction_op.cpp index a904562b6b303134d5198fbbe01ad2cb79c4ba97..483b52252e357559d82df2d5f694ab83a08cfa3b 100644 --- a/demo-serving/op/ctr_prediction_op.cpp +++ b/demo-serving/op/ctr_prediction_op.cpp @@ -23,6 +23,9 @@ #include "predictor/framework/kv_manager.h" #include "predictor/framework/memory.h" +// Flag that enables profiling mode +DECLARE_bool(enable_ctr_profiling); + namespace baidu { namespace paddle_serving { namespace serving { @@ -46,6 +49,11 @@ const int CTR_PREDICTION_DENSE_SLOT_ID = 26; const int CTR_PREDICTION_DENSE_DIM = 13; const int CTR_PREDICTION_EMBEDDING_SIZE = 10; +bthread::Mutex CTRPredictionOp::mutex_; +int64_t CTRPredictionOp::cube_time_us_ = 0; +int32_t CTRPredictionOp::cube_req_num_ = 0; +int32_t CTRPredictionOp::cube_req_key_num_ = 0; + void fill_response_with_message(Response *response, int err_code, std::string err_msg) { @@ -135,7 +143,28 @@ int CTRPredictionOp::inference() { return 0; } else if (kvinfo->sparse_param_service_type == configure::EngineDesc::REMOTE) { - int ret = cube->seek(table_name, keys, &values); + struct timeval start; + struct timeval end; + + int ret; + + if (FLAGS_enable_ctr_profiling) { + gettimeofday(&start, NULL); + ret = cube->seek(table_name, keys, &values); + gettimeofday(&end, NULL); + uint64_t usec = + end.tv_sec * 1e6 + end.tv_usec - start.tv_sec * 1e6 - start.tv_usec; + + // Statistics + mutex_.lock(); + cube_time_us_ += usec; + ++cube_req_num_; + cube_req_key_num_ += keys.size(); + mutex_.unlock(); + } else { + ret = cube->seek(table_name, keys, &values); + } + if (ret != 0) { fill_response_with_message(res, -1, "Query cube for embeddings error"); LOG(ERROR) << "Query cube for embeddings error"; diff --git a/demo-serving/op/ctr_prediction_op.h b/demo-serving/op/ctr_prediction_op.h index a12cccab68c06c2238e7205b90b095318b28f3f0..4cf60022ad065158c28445962b6db1ce66e0b2dd 100644 --- a/demo-serving/op/ctr_prediction_op.h +++ 
b/demo-serving/op/ctr_prediction_op.h @@ -55,6 +55,7 @@ static const char* CTR_PREDICTION_MODEL_NAME = "ctr_prediction"; * and modifications we made * */ + class CTRPredictionOp : public baidu::paddle_serving::predictor::OpWithChannel< baidu::paddle_serving::predictor::ctr_prediction::Response> { @@ -64,6 +65,12 @@ class CTRPredictionOp DECLARE_OP(CTRPredictionOp); int inference(); + + private: + static bthread::Mutex mutex_; + static int64_t cube_time_us_; + static int32_t cube_req_num_; + static int32_t cube_req_key_num_; }; } // namespace serving diff --git a/predictor/common/constant.h b/predictor/common/constant.h index da44103eb8e6d064a642520bb90dd2c9df293889..72509c8d9187f817cf4dd0dfef1bff06370ce537 100644 --- a/predictor/common/constant.h +++ b/predictor/common/constant.h @@ -40,8 +40,6 @@ DECLARE_int32(reload_interval_s); DECLARE_bool(enable_model_toolkit); DECLARE_string(enable_protocol_list); DECLARE_bool(enable_cube); -DECLARE_string(cube_config_path); -DECLARE_string(cube_config_file); // STATIC Variables extern const char* START_OP_NAME; diff --git a/predictor/src/pdserving.cpp b/predictor/src/pdserving.cpp index a86b39abac7bd007a8fd401bd9a0b8aaaa5c5114..28247bce41cd2eb429c5c0d317453733a2a26f73 100644 --- a/predictor/src/pdserving.cpp +++ b/predictor/src/pdserving.cpp @@ -69,6 +69,9 @@ static bvar::PassiveStatus s_predictor_revision( DEFINE_bool(V, false, "print version, bool"); DEFINE_bool(g, false, "user defined gflag path"); +DEFINE_bool(enable_ctr_profiling, + false, + "Enable profiling in CTR prediction demo"); DECLARE_string(flagfile); namespace bthread {