提交 cc93971b 编写于 作者: S Steven Li

Fixed a barrier problem, now main/worker thread sync correctly

上级 4e3ad933
...@@ -74,8 +74,7 @@ class WorkerThread: ...@@ -74,8 +74,7 @@ class WorkerThread:
# Let us have a DB connection of our own # Let us have a DB connection of our own
if ( gConfig.per_thread_db_connection ): if ( gConfig.per_thread_db_connection ):
self._dbConn = DbConn() self._dbConn = DbConn()
def start(self): def start(self):
self._thread.start() # AFTER the thread is recorded self._thread.start() # AFTER the thread is recorded
...@@ -120,18 +119,16 @@ class WorkerThread: ...@@ -120,18 +119,16 @@ class WorkerThread:
self.verifyThreadSelf() # only allowed by ourselves self.verifyThreadSelf() # only allowed by ourselves
# self.verifyIsSleeping(False) # has to be awake # self.verifyIsSleeping(False) # has to be awake
logger.debug("Worker thread {} about to cross pool barrier".format(self._tid)) # logger.debug("Worker thread {} about to cross pool barrier".format(self._tid))
# self.isSleeping = True # TODO: maybe too early? # self.isSleeping = True # TODO: maybe too early?
self._pool.crossPoolBarrier() # wait for all other threads self._pool.crossPoolBarrier() # wait for all other threads
# Wait again at the "gate", waiting to be "tapped" # Wait again at the "gate", waiting to be "tapped"
logger.debug("Worker thread {} about to cross the step gate".format(self._tid)) # logger.debug("Worker thread {} about to cross the step gate".format(self._tid))
# self.stepGate.acquire() # acquire lock immediately
self._stepGate.wait() self._stepGate.wait()
self._stepGate.clear() self._stepGate.clear()
# self.stepGate.release() # release
logger.debug("Worker thread {} woke up".format(self._tid)) # logger.debug("Worker thread {} woke up".format(self._tid))
# Someone will wake us up here # Someone will wake us up here
self._curStep += 1 # off to a new step... self._curStep += 1 # off to a new step...
...@@ -151,9 +148,15 @@ class WorkerThread: ...@@ -151,9 +148,15 @@ class WorkerThread:
time.sleep(0) # let the released thread run a bit, IMPORTANT, do it after release time.sleep(0) # let the released thread run a bit, IMPORTANT, do it after release
def doWork(self): def doWork(self):
logger.info(" Step {}, thread {}: ".format(self._curStep, self._tid)) self.logInfo("Thread starting an execution")
self._pool.dispatcher.doWork(self) self._pool.dispatcher.doWork(self)
def logInfo(self, msg):
logger.info(" T[{}.{}]: ".format(self._curStep, self._tid) + msg)
def logDebug(self, msg):
logger.debug(" T[{}.{}]: ".format(self._curStep, self._tid) + msg)
def execSql(self, sql): def execSql(self, sql):
if ( gConfig.per_thread_db_connection ): if ( gConfig.per_thread_db_connection ):
return self._dbConn.execSql(sql) return self._dbConn.execSql(sql)
...@@ -175,9 +178,8 @@ class SteppingThreadPool: ...@@ -175,9 +178,8 @@ class SteppingThreadPool:
# self.numWaitingThreads = 0 # self.numWaitingThreads = 0
# Thread coordination # Thread coordination
self.barrier = threading.Barrier(numThreads + 1) # plus main thread self._lock = threading.RLock() # lock to control access (e.g. even reading it is dangerous)
# self.lock = threading.Lock() # for critical section execution self._poolBarrier = threading.Barrier(numThreads + 1) # do nothing before crossing this, except main thread
# self.mainGate = threading.Condition()
# starting to run all the threads, in locking steps # starting to run all the threads, in locking steps
def run(self): def run(self):
...@@ -189,11 +191,20 @@ class SteppingThreadPool: ...@@ -189,11 +191,20 @@ class SteppingThreadPool:
# Coordinate all threads step by step # Coordinate all threads step by step
self.curStep = -1 # not started yet self.curStep = -1 # not started yet
while(self.curStep < self.maxSteps): while(self.curStep < self.maxSteps):
print(".", end="", flush=True)
logger.debug("Main thread going to sleep") logger.debug("Main thread going to sleep")
self.crossPoolBarrier()
self.barrier.reset() # Other worker threads should now be at the "gate"
logger.debug("Main thread waking up, tapping worker threads".format(self.curStep)) # Now not all threads had time to go to sleep # Now ready to enter a step
self.crossPoolBarrier() # let other threads go past the pool barrier, but wait at the thread gate
self._poolBarrier.reset() # Other worker threads should now be at the "gate"
# Rare chance, when all threads should be blocked at the "step gate" for each thread
logger.info("<-- Step {} finished".format(self.curStep))
self.curStep += 1 # we are about to get into next step. TODO: race condition here!
logger.debug(" ") # line break
logger.debug("--> Step {} starts with main thread waking up".format(self.curStep)) # Now not all threads had time to go to sleep
logger.debug("Main thread waking up at step {}, tapping worker threads".format(self.curStep)) # Now not all threads had time to go to sleep
self.tapAllThreads() self.tapAllThreads()
# The threads will run through many steps # The threads will run through many steps
...@@ -201,15 +212,11 @@ class SteppingThreadPool: ...@@ -201,15 +212,11 @@ class SteppingThreadPool:
workerThread._thread.join() # slight hack, accessing members workerThread._thread.join() # slight hack, accessing members
logger.info("All threads finished") logger.info("All threads finished")
print("")
print("Finished")
def crossPoolBarrier(self): def crossPoolBarrier(self):
if ( self.barrier.n_waiting == self.numThreads ): # everyone else is waiting, inc main thread self._poolBarrier.wait()
logger.info("<-- Step {} finished".format(self.curStep))
self.curStep += 1 # we are about to get into next step. TODO: race condition here!
logger.debug(" ") # line break
logger.debug("--> Step {} starts with main thread waking up".format(self.curStep)) # Now not all threads had time to go to sleep
self.barrier.wait()
def tapAllThreads(self): # in a deterministic manner def tapAllThreads(self): # in a deterministic manner
wakeSeq = [] wakeSeq = []
...@@ -397,46 +404,51 @@ class DbState(): ...@@ -397,46 +404,51 @@ class DbState():
def cleanUp(self): def cleanUp(self):
self._dbConn.close() self._dbConn.close()
# A task is a long-living entity, carrying out short-lived "executions" for threads
class Task(): class Task():
def __init__(self, dbState): def __init__(self, dbState):
self.dbState = dbState self.dbState = dbState
def _executeInternal(self, wt):
raise RuntimeError("To be implemeted by child classes")
def execute(self, workerThread): def execute(self, workerThread):
raise RuntimeError("Must be overriden by child class") self._executeInternal(workerThread) # TODO: no return value?
workerThread.logDebug("[X] task execution completed")
def execSql(self, sql): def execSql(self, sql):
return self.dbState.execute(sql) return self.dbState.execute(sql)
class CreateTableTask(Task): class CreateTableTask(Task):
def execute(self, wt): def _executeInternal(self, wt):
tIndex = dbState.addTable() tIndex = dbState.addTable()
logger.debug(" Creating a table {} ...".format(tIndex)) wt.logDebug("Creating a table {} ...".format(tIndex))
wt.execSql("create table db.table_{} (ts timestamp, speed int)".format(tIndex)) wt.execSql("create table db.table_{} (ts timestamp, speed int)".format(tIndex))
logger.debug(" Table {} created.".format(tIndex)) wt.logDebug("Table {} created.".format(tIndex))
dbState.releaseTable(tIndex) dbState.releaseTable(tIndex)
class DropTableTask(Task): class DropTableTask(Task):
def execute(self, wt): def _executeInternal(self, wt):
tableName = dbState.getTableNameToDelete() tableName = dbState.getTableNameToDelete()
if ( not tableName ): # May be "False" if ( not tableName ): # May be "False"
logger.info(" Cannot generate a table to delete, skipping...") wt.logInfo("Cannot generate a table to delete, skipping...")
return return
logger.info(" Dropping a table db.{} ...".format(tableName)) wt.logInfo("Dropping a table db.{} ...".format(tableName))
wt.execSql("drop table db.{}".format(tableName)) wt.execSql("drop table db.{}".format(tableName))
class AddDataTask(Task): class AddDataTask(Task):
def execute(self, wt): def _executeInternal(self, wt):
ds = self.dbState ds = self.dbState
logger.info(" Adding some data... numQueue={}".format(ds.tableNumQueue.toText())) wt.logInfo("Adding some data... numQueue={}".format(ds.tableNumQueue.toText()))
tIndex = ds.pickAndAllocateTable() tIndex = ds.pickAndAllocateTable()
if ( tIndex == None ): if ( tIndex == None ):
logger.info(" No table found to add data, skipping...") wt.logInfo("No table found to add data, skipping...")
return return
sql = "insert into db.table_{} values ('{}', {});".format(tIndex, ds.getNextTick(), ds.getNextInt()) sql = "insert into db.table_{} values ('{}', {});".format(tIndex, ds.getNextTick(), ds.getNextInt())
logger.debug(" Executing SQL: {}".format(sql)) wt.logDebug("Executing SQL: {}".format(sql))
wt.execSql(sql) wt.execSql(sql)
ds.releaseTable(tIndex) ds.releaseTable(tIndex)
logger.debug(" Finished adding data") wt.logDebug("Finished adding data")
# Deterministic random number generator # Deterministic random number generator
class Dice(): class Dice():
...@@ -510,7 +522,7 @@ if __name__ == "__main__": ...@@ -510,7 +522,7 @@ if __name__ == "__main__":
Dice.seed(0) # initial seeding of dice Dice.seed(0) # initial seeding of dice
dbState = DbState() dbState = DbState()
threadPool = SteppingThreadPool(dbState, 5, 10, 0) threadPool = SteppingThreadPool(dbState, 5, 500, 0)
threadPool.run() threadPool.run()
logger.info("Finished running thread pool") logger.info("Finished running thread pool")
dbState.cleanUp() dbState.cleanUp()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册