提交 f08e9b08 编写于 作者: S Steven Li

now able to fan out conflicting tasks, and detect incorrect results

上级 5d14a745
...@@ -20,12 +20,15 @@ if sys.version_info[0] < 3: ...@@ -20,12 +20,15 @@ if sys.version_info[0] < 3:
import getopt import getopt
import argparse import argparse
import copy
import threading import threading
import random import random
import logging import logging
import datetime import datetime
from typing import List
from util.log import * from util.log import *
from util.dnodes import * from util.dnodes import *
from util.cases import * from util.cases import *
...@@ -42,14 +45,13 @@ def runThread(wt: WorkerThread): ...@@ -42,14 +45,13 @@ def runThread(wt: WorkerThread):
wt.run() wt.run()
class WorkerThread: class WorkerThread:
def __init__(self, pool: SteppingThreadPool, tid, dbState, def __init__(self, pool: ThreadPool, tid,
tc: ThreadCoordinator, tc: ThreadCoordinator,
# te: TaskExecutor, # te: TaskExecutor,
): # note: main thread context! ): # note: main thread context!
# self._curStep = -1 # self._curStep = -1
self._pool = pool self._pool = pool
self._tid = tid self._tid = tid
self._dbState = dbState
self._tc = tc self._tc = tc
# self.threadIdent = threading.get_ident() # self.threadIdent = threading.get_ident()
self._thread = threading.Thread(target=runThread, args=(self,)) self._thread = threading.Thread(target=runThread, args=(self,))
...@@ -82,16 +84,18 @@ class WorkerThread: ...@@ -82,16 +84,18 @@ class WorkerThread:
def _doTaskLoop(self) : def _doTaskLoop(self) :
# while self._curStep < self._pool.maxSteps: # while self._curStep < self._pool.maxSteps:
# tc = ThreadCoordinator(None) # tc = ThreadCoordinator(None)
while True: while True:
self._tc.crossStepBarrier() # shared barrier first, INCLUDING the last one tc = self._tc # Thread Coordinator, the overall master
tc.crossStepBarrier() # shared barrier first, INCLUDING the last one
logger.debug("Thread task loop exited barrier...") logger.debug("Thread task loop exited barrier...")
self.crossStepGate() # then per-thread gate, after being tapped self.crossStepGate() # then per-thread gate, after being tapped
logger.debug("Thread task loop exited step gate...") logger.debug("Thread task loop exited step gate...")
if not self._tc.isRunning(): if not self._tc.isRunning():
break break
task = self._tc.fetchTask() task = tc.fetchTask()
task.execute(self) task.execute(self)
tc.saveExecutedTask(task)
def verifyThreadSelf(self): # ensure we are called by this own thread def verifyThreadSelf(self): # ensure we are called by this own thread
if ( threading.get_ident() != self._thread.ident ): if ( threading.get_ident() != self._thread.ident ):
...@@ -129,25 +133,31 @@ class WorkerThread: ...@@ -129,25 +133,31 @@ class WorkerThread:
if ( gConfig.per_thread_db_connection ): if ( gConfig.per_thread_db_connection ):
return self._dbConn.execSql(sql) return self._dbConn.execSql(sql)
else: else:
return self._dbState.getDbConn().execSql(sql) return self._tc.getDbState.getDbConn().execSql(sql)
class ThreadCoordinator: class ThreadCoordinator:
def __init__(self, pool, wd: WorkDispatcher): def __init__(self, pool, wd: WorkDispatcher, dbState):
self._curStep = -1 # first step is 0 self._curStep = -1 # first step is 0
self._pool = pool self._pool = pool
self._wd = wd self._wd = wd
self._te = None # prepare for every new step self._te = None # prepare for every new step
self._dbState = dbState
self._executedTasks: List[Task] = [] # in a given step
self._lock = threading.RLock() # sync access for a few things
self._stepBarrier = threading.Barrier(self._pool.numThreads + 1) # one barrier for all threads self._stepBarrier = threading.Barrier(self._pool.numThreads + 1) # one barrier for all threads
def getTaskExecutor(self): def getTaskExecutor(self):
return self._te return self._te
def getDbState(self) -> DbState :
return self._dbState
def crossStepBarrier(self): def crossStepBarrier(self):
self._stepBarrier.wait() self._stepBarrier.wait()
def run(self, dbState): def run(self):
self._pool.createAndStartThreads(dbState, self) self._pool.createAndStartThreads(self)
# Coordinate all threads step by step # Coordinate all threads step by step
self._curStep = -1 # not started yet self._curStep = -1 # not started yet
...@@ -161,10 +171,14 @@ class ThreadCoordinator: ...@@ -161,10 +171,14 @@ class ThreadCoordinator:
self._stepBarrier.reset() # Other worker threads should now be at the "gate" self._stepBarrier.reset() # Other worker threads should now be at the "gate"
# At this point, all threads should be pass the overall "barrier" and before the per-thread "gate" # At this point, all threads should be pass the overall "barrier" and before the per-thread "gate"
self._dbState.transition(self._executedTasks) # at end of step, transiton the DB state
# Get ready for next step
logger.info("<-- Step {} finished".format(self._curStep)) logger.info("<-- Step {} finished".format(self._curStep))
self._curStep += 1 # we are about to get into next step. TODO: race condition here! self._curStep += 1 # we are about to get into next step. TODO: race condition here!
logger.debug("\r\n--> Step {} starts with main thread waking up".format(self._curStep)) # Now not all threads had time to go to sleep logger.debug("\r\n--> Step {} starts with main thread waking up".format(self._curStep)) # Now not all threads had time to go to sleep
# A new TE for the new step
self._te = TaskExecutor(self._curStep) self._te = TaskExecutor(self._curStep)
logger.debug("Main thread waking up at step {}, tapping worker threads".format(self._curStep)) # Now not all threads had time to go to sleep logger.debug("Main thread waking up at step {}, tapping worker threads".format(self._curStep)) # Now not all threads had time to go to sleep
...@@ -202,10 +216,19 @@ class ThreadCoordinator: ...@@ -202,10 +216,19 @@ class ThreadCoordinator:
def fetchTask(self) -> Task : def fetchTask(self) -> Task :
if ( not self.isRunning() ): # no task if ( not self.isRunning() ): # no task
raise RuntimeError("Cannot fetch task when not running") raise RuntimeError("Cannot fetch task when not running")
return self._wd.pickTask() # return self._wd.pickTask()
# Alternatively, let's ask the DbState for the appropriate task
dbState = self.getDbState()
tasks = dbState.getTasksAtState()
i = Dice.throw(len(tasks))
return copy.copy(tasks[i]) # Needs a fresh copy, to save execution results, etc.
def saveExecutedTask(self, task):
with self._lock:
self._executedTasks.append(task)
# We define a class to run a number of threads in locking steps. # We define a class to run a number of threads in locking steps.
class SteppingThreadPool: class ThreadPool:
def __init__(self, dbState, numThreads, maxSteps, funcSequencer): def __init__(self, dbState, numThreads, maxSteps, funcSequencer):
self.numThreads = numThreads self.numThreads = numThreads
self.maxSteps = maxSteps self.maxSteps = maxSteps
...@@ -215,13 +238,12 @@ class SteppingThreadPool: ...@@ -215,13 +238,12 @@ class SteppingThreadPool:
self.curStep = 0 self.curStep = 0
self.threadList = [] self.threadList = []
# self.stepGate = threading.Condition() # Gate to hold/sync all threads # self.stepGate = threading.Condition() # Gate to hold/sync all threads
# self.numWaitingThreads = 0 # self.numWaitingThreads = 0
# starting to run all the threads, in locking steps # starting to run all the threads, in locking steps
def createAndStartThreads(self, dbState, tc: ThreadCoordinator): def createAndStartThreads(self, tc: ThreadCoordinator):
for tid in range(0, self.numThreads): # Create the threads for tid in range(0, self.numThreads): # Create the threads
workerThread = WorkerThread(self, tid, dbState, tc) workerThread = WorkerThread(self, tid, tc)
self.threadList.append(workerThread) self.threadList.append(workerThread)
workerThread.start() # start, but should block immediately before step 0 workerThread.start() # start, but should block immediately before step 0
...@@ -263,9 +285,6 @@ class LinearQueue(): ...@@ -263,9 +285,6 @@ class LinearQueue():
if ( index in self.inUse ): if ( index in self.inUse ):
return False return False
# if ( index in self.inUse ):
# self.inUse.remove(index) # TODO: what about discard?
self.firstIndex += 1 self.firstIndex += 1
return index return index
...@@ -337,7 +356,8 @@ class DbConn: ...@@ -337,7 +356,8 @@ class DbConn:
# self._tdSql.prepare() # Recreate database, etc. # self._tdSql.prepare() # Recreate database, etc.
self._cursor.execute('drop database if exists db') self._cursor.execute('drop database if exists db')
self._cursor.execute('create database db') logger.debug("Resetting DB, dropped database")
# self._cursor.execute('create database db')
# self._cursor.execute('use db') # self._cursor.execute('use db')
# tdSql.execute('show databases') # tdSql.execute('show databases')
...@@ -355,16 +375,24 @@ class DbConn: ...@@ -355,16 +375,24 @@ class DbConn:
# State of the database as we believe it to be # State of the database as we believe it to be
class DbState(): class DbState():
STATE_INVALID = -1
STATE_EMPTY = 1 # nothing there, no even a DB
STATE_DB_ONLY = 2 # we have a DB, but nothing else
STATE_TABLE_ONLY = 3 # we have a table, but totally empty
STATE_HAS_DATA = 4 # we have some data in the table
def __init__(self): def __init__(self):
self.tableNumQueue = LinearQueue() self.tableNumQueue = LinearQueue()
self._lastTick = datetime.datetime(2019, 1, 1) # initial date time tick self._lastTick = datetime.datetime(2019, 1, 1) # initial date time tick
self._lastInt = 0 # next one is initial integer self._lastInt = 0 # next one is initial integer
self._lock = threading.RLock() self._lock = threading.RLock()
self._state = self.STATE_INVALID
# self.openDbServerConnection() # self.openDbServerConnection()
self._dbConn = DbConn() self._dbConn = DbConn()
self._dbConn.open() self._dbConn.open()
self._dbConn.resetDb() # drop and recreate DB self._dbConn.resetDb() # drop and recreate DB
self._state = self.STATE_EMPTY # initial state, the result of above
def getDbConn(self): def getDbConn(self):
return self._dbConn return self._dbConn
...@@ -403,12 +431,63 @@ class DbState(): ...@@ -403,12 +431,63 @@ class DbState():
def cleanUp(self): def cleanUp(self):
self._dbConn.close() self._dbConn.close()
def getTasksAtState(self):
if ( self._state == self.STATE_EMPTY ):
return [CreateDbTask(self), CreateTableTask(self)]
elif ( self._state == self.STATE_DB_ONLY ):
return [DeleteDbTask(self), CreateTableTask(self), AddDataTask(self)]
else:
raise RuntimeError("Unexpected DbState state: {}".format(self._state))
def transition(self, tasks):
if ( len(tasks) == 0 ): # before 1st step, or otherwise empty
return # do nothing
if ( self._state == self.STATE_EMPTY ):
self.assertAtMostOneSuccess(tasks, CreateDbTask) # param is class
self.assertIfExistThenSuccess(tasks, CreateDbTask)
self.assertAtMostOneSuccess(tasks, CreateTableTask)
if ( self.hasSuccess(tasks, CreateDbTask) ):
self._state = self.STATE_DB_ONLY
if ( self.hasSuccess(tasks, CreateTableTask) ):
self._state = self.STATE_TABLE_ONLY
else:
raise RuntimeError("Unexpected DbState state: {}".format(self._state))
logger.debug("New DB state is: {}".format(self._state))
def assertAtMostOneSuccess(self, tasks, cls):
sCnt = 0
for task in tasks :
if not isinstance(task, cls):
continue
if task.isSuccess():
sCnt += 1
if ( sCnt >= 2 ):
raise RuntimeError("Unexpected more than 1 success with task: {}".format(cls))
def assertIfExistThenSuccess(self, tasks, cls):
sCnt = 0
for task in tasks :
if not isinstance(task, cls):
continue
if task.isSuccess():
sCnt += 1
if ( sCnt <= 0 ):
raise RuntimeError("Unexpected zero success for task: {}".format(cls))
def hasSuccess(self, tasks, cls):
for task in tasks :
if not isinstance(task, cls):
continue
if task.isSuccess():
return True
return False
class TaskExecutor(): class TaskExecutor():
def __init__(self, curStep): def __init__(self, curStep):
self._curStep = curStep self._curStep = curStep
def execute(self, task, wt: WorkerThread): # execute a task on a thread def execute(self, task: Task, wt: WorkerThread): # execute a task on a thread
task.execute(self, wt) task.execute(wt)
def logInfo(self, msg): def logInfo(self, msg):
logger.info(" T[{}.x]: ".format(self._curStep) + msg) logger.info(" T[{}.x]: ".format(self._curStep) + msg)
...@@ -416,10 +495,13 @@ class TaskExecutor(): ...@@ -416,10 +495,13 @@ class TaskExecutor():
def logDebug(self, msg): def logDebug(self, msg):
logger.debug(" T[{}.x]: ".format(self._curStep) + msg) logger.debug(" T[{}.x]: ".format(self._curStep) + msg)
class Task(): class Task():
def __init__(self, dbState): def __init__(self, dbState):
self.dbState = dbState self.dbState = dbState
self._err = None
def isSuccess(self):
return self._err == None
def _executeInternal(self, te: TaskExecutor, wt: WorkerThread): def _executeInternal(self, te: TaskExecutor, wt: WorkerThread):
raise RuntimeError("To be implemeted by child classes") raise RuntimeError("To be implemeted by child classes")
...@@ -428,12 +510,31 @@ class Task(): ...@@ -428,12 +510,31 @@ class Task():
wt.verifyThreadSelf() wt.verifyThreadSelf()
te = wt.getTaskExecutor() te = wt.getTaskExecutor()
self._executeInternal(te, wt) # TODO: no return value? te.logDebug("[-] executing task {}...".format(self.__class__.__name__))
self._err = None
try:
self._executeInternal(te, wt) # TODO: no return value?
except taos.error.ProgrammingError as err:
te.logDebug("[=]Taos Execution exception: {0}".format(err))
self._err = err
except:
te.logDebug("[=]Unexpected exception")
raise
te.logDebug("[X] task execution completed") te.logDebug("[X] task execution completed")
def execSql(self, sql): def execSql(self, sql):
return self.dbState.execute(sql) return self.dbState.execute(sql)
class CreateDbTask(Task):
def _executeInternal(self, te: TaskExecutor, wt: WorkerThread):
wt.execSql("create database db")
class DeleteDbTask(Task):
def _executeInternal(self, te: TaskExecutor, wt: WorkerThread):
wt.execSql("drop database db")
class CreateTableTask(Task): class CreateTableTask(Task):
def _executeInternal(self, te: TaskExecutor, wt: WorkerThread): def _executeInternal(self, te: TaskExecutor, wt: WorkerThread):
tIndex = self.dbState.addTable() tIndex = self.dbState.addTable()
...@@ -487,14 +588,14 @@ class Dice(): ...@@ -487,14 +588,14 @@ class Dice():
raise RuntimeError("System RNG is not deterministic") raise RuntimeError("System RNG is not deterministic")
@classmethod @classmethod
def throw(cls, max): # get 0 to max-1 def throw(cls, stop): # get 0 to stop-1
return cls.throwRange(0, max) return cls.throwRange(0, stop)
@classmethod @classmethod
def throwRange(cls, min, max): # up to max-1 def throwRange(cls, start, stop): # up to stop-1
if ( not cls.seeded ): if ( not cls.seeded ):
raise RuntimeError("Cannot throw dice before seeding it") raise RuntimeError("Cannot throw dice before seeding it")
return random.randrange(min, max) return random.randrange(start, stop)
# Anyone needing to carry out work should simply come here # Anyone needing to carry out work should simply come here
...@@ -546,10 +647,11 @@ def main(): ...@@ -546,10 +647,11 @@ def main():
dbState = DbState() dbState = DbState()
Dice.seed(0) # initial seeding of dice Dice.seed(0) # initial seeding of dice
tc = ThreadCoordinator( tc = ThreadCoordinator(
SteppingThreadPool(dbState, gConfig.num_threads, gConfig.max_steps, 0), ThreadPool(dbState, gConfig.num_threads, gConfig.max_steps, 0),
WorkDispatcher(dbState) WorkDispatcher(dbState),
dbState
) )
tc.run(dbState) tc.run()
dbState.cleanUp() dbState.cleanUp()
logger.info("Finished running thread pool") logger.info("Finished running thread pool")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册