提交 73bdaae1 编写于 作者: S Steven Li

Now using Python barriers and events to sync threads

上级 73b41cb4
...@@ -17,6 +17,7 @@ import getopt ...@@ -17,6 +17,7 @@ import getopt
import threading import threading
import random import random
import logging import logging
import datetime
from util.log import * from util.log import *
from util.dnodes import * from util.dnodes import *
...@@ -32,6 +33,34 @@ def runThread(workerThread): ...@@ -32,6 +33,34 @@ def runThread(workerThread):
logger.info("Running Thread: {}".format(workerThread.tid)) logger.info("Running Thread: {}".format(workerThread.tid))
workerThread.run() workerThread.run()
# Used by one process to block till another is ready
# class Baton:
# def __init__(self):
# self._lock = threading.Lock() # control access to object
# self._baton = threading.Condition() # let thread block
# self._hasGiver = False
# self._hasTaker = False
# def give(self):
# with self._lock:
# if ( self._hasGiver ): # already?
# raise RuntimeError("Cannot double-give a baton")
# self._hasGiver = True
# self._settle() # may block, OUTSIDE self lock
# def take(self):
# with self._lock:
# if ( self._hasTaker):
# raise RuntimeError("Cannot double-take a baton")
# self._hasTaker = True
# self._settle()
# def _settle(self):
class WorkerThread: class WorkerThread:
def __init__(self, pool, tid): # note: main thread context! def __init__(self, pool, tid): # note: main thread context!
self.curStep = -1 self.curStep = -1
...@@ -39,14 +68,14 @@ class WorkerThread: ...@@ -39,14 +68,14 @@ class WorkerThread:
self.tid = tid self.tid = tid
# self.threadIdent = threading.get_ident() # self.threadIdent = threading.get_ident()
self.thread = threading.Thread(target=runThread, args=(self,)) self.thread = threading.Thread(target=runThread, args=(self,))
self.stepGate = threading.Condition() self.stepGate = threading.Event()
def start(self): def start(self):
self.thread.start() # AFTER the thread is recorded self.thread.start() # AFTER the thread is recorded
def run(self): def run(self):
# initialization after thread starts, in the thread context # initialization after thread starts, in the thread context
self.isSleeping = False # self.isSleeping = False
while self.curStep < self.pool.maxSteps: while self.curStep < self.pool.maxSteps:
# stepNo = self.pool.waitForStep() # Step to run # stepNo = self.pool.waitForStep() # Step to run
...@@ -65,23 +94,26 @@ class WorkerThread: ...@@ -65,23 +94,26 @@ class WorkerThread:
if ( not self.thread.is_alive() ): if ( not self.thread.is_alive() ):
raise RuntimeError("Unexpected dead thread") raise RuntimeError("Unexpected dead thread")
def verifyIsSleeping(self, isSleeping): # def verifyIsSleeping(self, isSleeping):
if ( isSleeping != self.isSleeping ): # if ( isSleeping != self.isSleeping ):
raise RuntimeError("Unexpected thread sleep status") # raise RuntimeError("Unexpected thread sleep status")
# A gate is different from a barrier in that a thread needs to be "tapped"
def crossStepGate(self): def crossStepGate(self):
self.verifyThreadAlive() self.verifyThreadAlive()
self.verifyThreadSelf() # only allowed by ourselves self.verifyThreadSelf() # only allowed by ourselves
self.verifyIsSleeping(False) # has to be awake # self.verifyIsSleeping(False) # has to be awake
logger.debug("Worker thread {} going to sleep".format(self.tid)) logger.debug("Worker thread {} about to cross pool barrier".format(self.tid))
self.isSleeping = True # TODO: maybe too early? # self.isSleeping = True # TODO: maybe too early?
self.pool.reportThreadWaiting() # TODO: this triggers the main thread, TOO early self.pool.crossPoolBarrier() # wait for all other threads
# Actually going to sleep # Wait again at the "gate", waiting to be "tapped"
self.stepGate.acquire() # acquire lock immediately logger.debug("Worker thread {} about to cross the step gate".format(self.tid))
self.stepGate.wait() # release and then acquire # self.stepGate.acquire() # acquire lock immediately
self.stepGate.release() # release self.stepGate.wait()
self.stepGate.clear()
# self.stepGate.release() # release
logger.debug("Worker thread {} woke up".format(self.tid)) logger.debug("Worker thread {} woke up".format(self.tid))
# Someone will wake us up here # Someone will wake us up here
...@@ -90,15 +122,15 @@ class WorkerThread: ...@@ -90,15 +122,15 @@ class WorkerThread:
def tapStepGate(self): # give it a tap, release the thread waiting there def tapStepGate(self): # give it a tap, release the thread waiting there
self.verifyThreadAlive() self.verifyThreadAlive()
self.verifyThreadMain() # only allowed for main thread self.verifyThreadMain() # only allowed for main thread
self.verifyIsSleeping(True) # has to be sleeping # self.verifyIsSleeping(True) # has to be sleeping
logger.debug("Tapping worker thread {}".format(self.tid)) logger.debug("Tapping worker thread {}".format(self.tid))
self.stepGate.acquire() # self.stepGate.acquire()
# logger.debug("Tapping worker thread {}, lock acquired".format(self.tid)) # logger.debug("Tapping worker thread {}, lock acquired".format(self.tid))
self.stepGate.notify() # wake up! self.stepGate.set() # wake up!
# logger.debug("Tapping worker thread {}, notified!".format(self.tid)) # logger.debug("Tapping worker thread {}, notified!".format(self.tid))
self.isSleeping = False # No race condition for sure # self.isSleeping = False # No race condition for sure
self.stepGate.release() # this finishes before .wait() can return # self.stepGate.release() # this finishes before .wait() can return
# logger.debug("Tapping worker thread {}, lock released".format(self.tid)) # logger.debug("Tapping worker thread {}, lock released".format(self.tid))
time.sleep(0) # let the released thread run a bit, IMPORTANT, do it after release time.sleep(0) # let the released thread run a bit, IMPORTANT, do it after release
...@@ -109,20 +141,21 @@ class WorkerThread: ...@@ -109,20 +141,21 @@ class WorkerThread:
# We define a class to run a number of threads in locking steps. # We define a class to run a number of threads in locking steps.
class SteppingThreadPool: class SteppingThreadPool:
def __init__(self, numThreads, maxSteps, funcSequencer): def __init__(self, dbState, numThreads, maxSteps, funcSequencer):
self.numThreads = numThreads self.numThreads = numThreads
self.maxSteps = maxSteps self.maxSteps = maxSteps
self.funcSequencer = funcSequencer self.funcSequencer = funcSequencer
# Internal class variables # Internal class variables
self.dispatcher = WorkDispatcher(self) self.dispatcher = WorkDispatcher(self, dbState)
self.curStep = 0 self.curStep = 0
self.threadList = [] self.threadList = []
# self.stepGate = threading.Condition() # Gate to hold/sync all threads # self.stepGate = threading.Condition() # Gate to hold/sync all threads
self.numWaitingThreads = 0 # self.numWaitingThreads = 0
# Thread coordination # Thread coordination
self.lock = threading.Lock() # for critical section execution self.barrier = threading.Barrier(numThreads + 1) # plus main thread
self.mainGate = threading.Condition() # self.lock = threading.Lock() # for critical section execution
# self.mainGate = threading.Condition()
# starting to run all the threads, in locking steps # starting to run all the threads, in locking steps
def run(self): def run(self):
...@@ -136,13 +169,15 @@ class SteppingThreadPool: ...@@ -136,13 +169,15 @@ class SteppingThreadPool:
self.curStep = -1 # not started yet self.curStep = -1 # not started yet
while(self.curStep < self.maxSteps): while(self.curStep < self.maxSteps):
logger.debug("Main thread going to sleep") logger.debug("Main thread going to sleep")
self.mainGate.acquire() # self.mainGate.acquire()
self.mainGate.wait() # start snoozing # self.mainGate.wait() # start snoozing
self.mainGate.release # self.mainGate.release
logger.debug("Main thread woke up") # Now not all threads had time to go to sleep self.crossPoolBarrier()
time.sleep(0.01) # This is like forever self.barrier.reset() # Other worker threads should now be at the "gate"
logger.debug("Main thread waking up, tapping worker threads".format(self.curStep)) # Now not all threads had time to go to sleep
# time.sleep(0.01) # This is like forever
self.curStep += 1 # starts with 0
self.tapAllThreads() self.tapAllThreads()
# The threads will run through many steps # The threads will run through many steps
...@@ -151,21 +186,29 @@ class SteppingThreadPool: ...@@ -151,21 +186,29 @@ class SteppingThreadPool:
logger.info("All threads finished") logger.info("All threads finished")
def reportThreadWaiting(self): def crossPoolBarrier(self):
allThreadWaiting = False if ( self.barrier.n_waiting == self.numThreads ): # everyone else is waiting, inc main thread
with self.lock: logger.info("<-- Step {} finished".format(self.curStep))
self.numWaitingThreads += 1 self.curStep += 1 # we are about to get into next step. TODO: race condition here!
if ( self.numWaitingThreads == self.numThreads ): logger.debug(" ") # line break
allThreadWaiting = True logger.debug("--> Step {} starts with main thread waking up".format(self.curStep)) # Now not all threads had time to go to sleep
if (allThreadWaiting): # aha, pass the baton to the main thread
logger.debug("All threads are now waiting") self.barrier.wait()
self.numWaitingThreads = 0 # do this 1st to avoid race condition # allThreadWaiting = False
# time.sleep(0.001) # thread yield, so main thread can be ready # with self.lock:
self.mainGate.acquire() # self.numWaitingThreads += 1
self.mainGate.notify() # main thread would now start to run # if ( self.numWaitingThreads == self.numThreads ):
self.mainGate.release() # allThreadWaiting = True
time.sleep(0) # yield, maybe main thread can run for just a bit
# if (allThreadWaiting): # aha, pass the baton to the main thread
# logger.debug("All threads are now waiting")
# self.numWaitingThreads = 0 # do this 1st to avoid race condition
# # time.sleep(0.001) # thread yield, so main thread can be ready
# self.mainGate.acquire()
# self.mainGate.notify() # main thread would now start to run
# self.mainGate.release()
# time.sleep(0) # yield, maybe main thread can run for just a bit
# def waitForStep(self): # def waitForStep(self):
# shouldWait = True; # shouldWait = True;
...@@ -201,6 +244,7 @@ class SteppingThreadPool: ...@@ -201,6 +244,7 @@ class SteppingThreadPool:
else: else:
wakeSeq.insert(0, i) wakeSeq.insert(0, i)
logger.info("Waking up threads: {}".format(str(wakeSeq))) logger.info("Waking up threads: {}".format(str(wakeSeq)))
# TODO: set dice seed to a deterministic value
for i in wakeSeq: for i in wakeSeq:
self.threadList[i].tapStepGate() self.threadList[i].tapStepGate()
time.sleep(0) # yield time.sleep(0) # yield
...@@ -208,10 +252,13 @@ class SteppingThreadPool: ...@@ -208,10 +252,13 @@ class SteppingThreadPool:
# A queue of continguous POSITIVE integers # A queue of continguous POSITIVE integers
class LinearQueue(): class LinearQueue():
def __init__(self): def __init__(self):
self.firstIndex = 1 self.firstIndex = 1 # 1st ever element
self.lastIndex = 0 self.lastIndex = 0
self.lock = threading.RLock() # our functions may call each other
self.inUse = set() # the indexes that are in use right now
def push(self): # Push to the tail (largest) def push(self): # Push to the tail (largest)
with self.lock:
if ( self.firstIndex > self.lastIndex ): # impossible, meaning it's empty if ( self.firstIndex > self.lastIndex ): # impossible, meaning it's empty
self.lastIndex = self.firstIndex self.lastIndex = self.firstIndex
return self.firstIndex return self.firstIndex
...@@ -220,18 +267,71 @@ class LinearQueue(): ...@@ -220,18 +267,71 @@ class LinearQueue():
return self.lastIndex return self.lastIndex
def pop(self): def pop(self):
if ( self.firstIndex > self.lastIndex ): # empty with self.lock:
return 0 if ( self.isEmpty() ):
raise RuntimeError("Cannot pop an empty queue")
index = self.firstIndex index = self.firstIndex
self.firstIndex += 1 self.firstIndex += 1
return index return index
def isEmpty(self):
return self.firstIndex > self.lastIndex
def popIfNotEmpty(self):
with self.lock:
if (self.isEmpty()):
return 0
return self.pop()
def use(self, i):
with self.lock:
if ( i in self.inUse ):
raise RuntimeError("Cannot re-use same index in queue: {}".format(i))
self.inUse.add(i)
def unUse(self, i):
with self.lock:
self.inUse.remove(i) # KeyError possible
def size(self):
return self.lastIndex + 1 - self.firstIndex
def allocate(self):
with self.lock:
cnt = 0 # counting the interations
while True:
cnt += 1
if ( cnt > self.size()*10 ): # 10x iteration already
raise RuntimeError("Failed to allocate LinearQueue element")
ret = Dice.throwRange(self.firstIndex, self.lastIndex+1)
if ( not ret in self.inUse ):
return self.use(ret)
# State of the database as we believe it to be # State of the database as we believe it to be
class DbState(): class DbState():
def __init__(self): def __init__(self):
self.tableNumQueue = LinearQueue() self.tableNumQueue = LinearQueue()
self.tick = datetime.datetime(2019, 1, 1) # initial date time tick
self.int = 0 # initial integer
self.openDbServerConnection() self.openDbServerConnection()
self.lock = threading.RLock()
def pickTable(self): # pick any table, and "use" it
return self.tableNumQueue.allocate()
def getNextTick(self):
with self.lock: # prevent duplicate tick
self.tick += datetime.timedelta(0, 1) # add one second to it
return self.tick
def getNextInt(self):
with self.lock:
self.int += 1
return self.int
def unuseTable(self, i): # return the table back, so others can use it
self.tableNumQueue.unUse(i)
def openDbServerConnection(self): def openDbServerConnection(self):
cfgPath = "../../build/test/cfg" # was: tdDnodes.getSimCfgPath() cfgPath = "../../build/test/cfg" # was: tdDnodes.getSimCfgPath()
...@@ -250,12 +350,15 @@ class DbState(): ...@@ -250,12 +350,15 @@ class DbState():
return "table_{}".format(tblNum) return "table_{}".format(tblNum)
def getTableNameToDelete(self): def getTableNameToDelete(self):
tblNum = self.tableNumQueue.pop() if self.tableNumQueue.isEmpty:
if( tblNum==0 ) :
return False return False
tblNum = self.tableNumQueue.pop() # TODO: race condition!
return "table_{}".format(tblNum) return "table_{}".format(tblNum)
class Task(): class Task():
def __init__(self, dbState):
self.dbState = dbState
def execute(self): def execute(self):
raise RuntimeError("Must be overriden by child class") raise RuntimeError("Must be overriden by child class")
...@@ -277,6 +380,9 @@ class DropTableTask(Task): ...@@ -277,6 +380,9 @@ class DropTableTask(Task):
class AddDataTask(Task): class AddDataTask(Task):
def execute(self): def execute(self):
logger.info(" Adding some data...") logger.info(" Adding some data...")
# ds = self.dbState
# tIndex = self.dbState.pickTable()
# tdSql.execute("insert into table_{} values ('{}', {});".format(tIndex, ds.getNextTick(), ds.getNextInt()))
# Deterministic random number generator # Deterministic random number generator
class Dice(): class Dice():
...@@ -312,13 +418,13 @@ class Dice(): ...@@ -312,13 +418,13 @@ class Dice():
# Anyone needing to carry out work should simply come here # Anyone needing to carry out work should simply come here
class WorkDispatcher(): class WorkDispatcher():
def __init__(self, pool): def __init__(self, pool, dbState):
self.pool = pool self.pool = pool
self.totalNumMethods = 2 # self.totalNumMethods = 2
self.tasks = [ self.tasks = [
CreateTableTask(), CreateTableTask(dbState),
DropTableTask(), DropTableTask(dbState),
AddDataTask(), # AddDataTask(dbState),
] ]
def throwDice(self): def throwDice(self):
...@@ -337,7 +443,7 @@ if __name__ == "__main__": ...@@ -337,7 +443,7 @@ if __name__ == "__main__":
Dice.seed(0) # initial seeding of dice Dice.seed(0) # initial seeding of dice
dbState = DbState() dbState = DbState()
threadPool = SteppingThreadPool(3, 5, 0) threadPool = SteppingThreadPool(dbState, 3, 5, 0)
threadPool.run() threadPool.run()
logger.info("Finished running thread pool") logger.info("Finished running thread pool")
dbState.closeDbServerConnection() dbState.closeDbServerConnection()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册