# env.py
# Commit: "Add env" (committed by wanghaoshuang)
import time
import sys
import numpy as np


class Env(object):
    """A one-dimensional stage on which an agent walks toward the right end.

    The agent occupies a single cell of a row of ``stage_len`` cells. It
    starts at cell 0 and an episode finishes once it reaches the last cell
    (``stage_len - 1``).

    Attributes:
        stage_len: Number of cells in the stage.
        end: Index of the goal cell (``stage_len - 1``).
        position: Current cell index of the agent.
        interval: Seconds to sleep after each rendered frame.
        step: Number of moves made since the last ``reset``.
        epoch: Number of ``reset`` calls made so far, minus one.
        render: When true, ``reset`` and ``move`` draw the stage to stdout.
    """

    def __init__(self, stage_len, interval):
        """Create a stage of ``stage_len`` cells.

        Args:
            stage_len: Number of cells in the stage.
            interval: Delay in seconds applied after each rendered frame.
        """
        self.stage_len = stage_len
        self.end = self.stage_len - 1
        self.position = 0
        self.interval = interval
        self.step = 0
        # Starts at -1 so that the first reset() yields epoch 0.
        self.epoch = -1
        self.render = False

    def reset(self):
        """Put the agent back on cell 0 and begin a new epoch."""
        self.end = self.stage_len - 1
        self.position = 0
        self.epoch += 1
        self.step = 0
        if self.render:
            # Start every epoch on a fresh terminal line.
            self.draw(True)

    def status(self):
        """Return the current position as a one-hot float32 vector.

        Returns:
            A numpy array of length ``stage_len`` that is 1 at
            ``position`` and 0 everywhere else.
        """
        one_hot = np.zeros(self.stage_len, dtype="float32")
        one_hot[self.position] = 1
        return one_hot

    def move(self, action):
        """Advance the agent by one cell and report the outcome.

        Args:
            action: ``0`` moves one cell to the left; any other value moves
                one cell to the right. The position is clamped to the
                range ``[0, end]``.

        Returns:
            A tuple ``(reward, done, status)``: ``reward`` is 1.0 and
            ``done`` is True when the agent stands on the goal cell after
            the move, otherwise 0.0 and False; ``status`` is the one-hot
            vector returned by ``status()``.
        """
        self.step += 1
        if action == 0:
            self.position = max(0, self.position - 1)
        else:
            self.position = min(self.end, self.position + 1)
        if self.render:
            self.draw()
        done = self.position == self.end
        reward = 1.0 if done else 0.0
        return reward, done, self.status()

    def draw(self, new_line=False):
        """Render the stage on stdout, then sleep for ``interval`` seconds.

        The agent is drawn as ``O`` and every other cell as ``-``, followed
        by the epoch and step counters.

        Args:
            new_line: When true, begin a new terminal line before drawing;
                otherwise return the cursor to the start of the current
                line so that the previous frame is overwritten.
        """
        # sys.stdout.write is used instead of the print statement so that
        # the module is valid on both Python 2 and Python 3.
        prefix = "\n" if new_line else "\r"
        cells = "".join(
            "O" if i == self.position else "-" for i in range(self.stage_len))
        counters = "    epoch: %d; steps: %d" % (self.epoch, self.step)
        sys.stdout.write(prefix + cells + counters)
        sys.stdout.flush()
        time.sleep(self.interval)