diff --git a/tests/pytest/concurrent_inquiry.py b/tests/pytest/concurrent_inquiry.py index e832c9a74e1c8b6c42a882a59931bff6d481f445..d1f180373bbe0585bd6e01b224e64359e0449e77 100644 --- a/tests/pytest/concurrent_inquiry.py +++ b/tests/pytest/concurrent_inquiry.py @@ -40,7 +40,7 @@ class ConcurrentInquiry: # stableNum = 2,subtableNum = 1000,insertRows = 100): def __init__(self,ts,host,user,password,dbname, stb_prefix,subtb_prefix,n_Therads,r_Therads,probabilities,loop, - stableNum ,subtableNum ,insertRows ,mix_table): + stableNum ,subtableNum ,insertRows ,mix_table, replay): self.n_numOfTherads = n_Therads self.r_numOfTherads = r_Therads self.ts=ts @@ -65,6 +65,7 @@ class ConcurrentInquiry: self.mix_table = mix_table self.max_ts = datetime.datetime.now() self.min_ts = datetime.datetime.now() - datetime.timedelta(days=5) + self.replay = replay def SetThreadsNum(self,num): self.numOfTherads=num @@ -412,7 +413,7 @@ class ConcurrentInquiry: ) cl = conn.cursor() cl.execute("use %s;" % self.dbname) - + fo = open('bak_sql_n_%d'%threadID,'w+') print("Thread %d: starting" % threadID) loop = self.loop while loop: @@ -423,6 +424,7 @@ class ConcurrentInquiry: else: sql=self.gen_query_join() print("sql is ",sql) + fo.write(sql+'\n') start = time.time() cl.execute(sql) cl.fetchall() @@ -438,13 +440,49 @@ class ConcurrentInquiry: exit(-1) loop -= 1 if loop == 0: break - + fo.close() cl.close() conn.close() print("Thread %d: finishing" % threadID) + + def query_thread_nr(self,threadID): #使用原生python接口进行重放 + host = self.host + user = self.user + password = self.password + conn = taos.connect( + host, + user, + password, + ) + cl = conn.cursor() + cl.execute("use %s;" % self.dbname) + replay_sql = [] + with open('bak_sql_n_%d'%threadID,'r') as f: + replay_sql = f.readlines() + print("Replay Thread %d: starting" % threadID) + for sql in replay_sql: + try: + print("sql is ",sql) + start = time.time() + cl.execute(sql) + cl.fetchall() + end = time.time() + print("time cost :",end-start) + except Exception as e: + print('-'*40) + print( + "Failure thread%d, sql: %s \nexception: %s" % + (threadID, str(sql),str(e))) + err_uec='Unable to establish connection' + if err_uec in str(e) and loop >0: + exit(-1) + cl.close() + conn.close() + print("Replay Thread %d: finishing" % threadID) def query_thread_r(self,threadID): #使用rest接口查询 print("Thread %d: starting" % threadID) + fo = open('bak_sql_r_%d'%threadID,'w+') loop = self.loop while loop: try: @@ -453,6 +491,7 @@ class ConcurrentInquiry: else: sql=self.gen_query_join() print("sql is ",sql) + fo.write(sql+'\n') start = time.time() self.rest_query(sql) end = time.time() @@ -467,20 +506,53 @@ class ConcurrentInquiry: exit(-1) loop -= 1 if loop == 0: break - - print("Thread %d: finishing" % threadID) + fo.close() + print("Thread %d: finishing" % threadID) + + def query_thread_rr(self,threadID): #使用rest接口重放 + print("Replay Thread %d: starting" % threadID) + replay_sql = [] + with open('bak_sql_r_%d'%threadID,'r') as f: + replay_sql = f.readlines() + + for sql in replay_sql: + try: + print("sql is ",sql) + start = time.time() + self.rest_query(sql) + end = time.time() + print("time cost :",end-start) + except Exception as e: + print('-'*40) + print( + "Failure thread%d, sql: %s \nexception: %s" % + (threadID, str(sql),str(e))) + err_uec='Unable to establish connection' + if err_uec in str(e) and loop >0: + exit(-1) + print("Replay Thread %d: finishing" % threadID) def run(self): print(self.n_numOfTherads,self.r_numOfTherads) threads = [] - for i in range(self.n_numOfTherads): - thread = threading.Thread(target=self.query_thread_n, args=(i,)) - threads.append(thread) - thread.start() - for i in range(self.r_numOfTherads): - thread = threading.Thread(target=self.query_thread_r, args=(i,)) - threads.append(thread) - thread.start() + if self.replay: #whether replay + for i in range(self.n_numOfTherads): + thread = threading.Thread(target=self.query_thread_nr, args=(i,)) + threads.append(thread) + thread.start() + for i in range(self.r_numOfTherads): + thread = threading.Thread(target=self.query_thread_rr, args=(i,)) + threads.append(thread) + thread.start() + else: + for i in range(self.n_numOfTherads): + thread = threading.Thread(target=self.query_thread_n, args=(i,)) + threads.append(thread) + thread.start() + for i in range(self.r_numOfTherads): + thread = threading.Thread(target=self.query_thread_r, args=(i,)) + threads.append(thread) + thread.start() parser = argparse.ArgumentParser() parser.add_argument( @@ -595,13 +667,20 @@ parser.add_argument( default=0, type=int, help='0:stable & substable ,1:subtable ,2:stable (default: 0)') +parser.add_argument( + '-R', + '--replay', + action='store', + default=0, + type=int, + help='0:not replay ,1:replay (default: 0)') args = parser.parse_args() q = ConcurrentInquiry( args.ts,args.host_name,args.user,args.password,args.db_name, args.stb_name_prefix,args.subtb_name_prefix,args.number_of_native_threads,args.number_of_rest_threads, args.probabilities,args.loop_per_thread,args.number_of_stables,args.number_of_tables ,args.number_of_records, - args.mix_stable_subtable ) + args.mix_stable_subtable, args.replay ) if args.create_table: q.gen_data()