import gym
import numpy as np

class GameControl:
    """Create a gym environment and play one episode with the given agent.

    Constructing an instance immediately runs :meth:`play`, so the constructor
    only returns once the episode is over and the environment has been closed.

    :param game: environment id handed to ``gym.make``
    :param agent: object with ``makeMove(stepCount, dynInfo, reward)``
    :param renderMode: mode passed to ``env.render`` after every step
    """

    # Steps 1..WARMUP_STEPS are driven by a fixed script instead of the agent.
    WARMUP_STEPS = 45
    # Textual forms of the actions accepted from an agent (0 .. 5).
    VALID_ACTIONS = frozenset('012345')

    def __init__(self, game, agent, renderMode='human'):
        self.game = game
        self.env = gym.make(self.game)
        self.agent = agent
        self.dynInfo = DynamicInfo()
        # (computer points, agent points)
        self.score = (0, 0)
        self.renderMode = renderMode
        self.play(self.env)

    @staticmethod
    def _warmupAction(stepCount):
        """Return the scripted action for a warm-up step, or None once warm-up is over.

        The script pushes the paddle one way (action 2), then the other
        (action 3), then back (action 2) so the field's upper and lower bounds
        get observed, and finally idles (action 0) while waiting for the
        first serve.
        """
        if stepCount <= 10:
            return 2
        if stepCount <= 24:
            return 3
        if stepCount <= 29:
            return 2
        if stepCount <= GameControl.WARMUP_STEPS:
            return 0
        return None

    @staticmethod
    def _toAction(inp):
        """Convert an agent's answer into a valid action, falling back to 0 (no-op).

        Only the exact values 0 .. 5 (as ints or as strings) are accepted.
        A substring test against '012345' would also let through answers such
        as '12' or '345', which are not valid actions.
        """
        text = str(inp).strip()
        return int(text) if text in GameControl.VALID_ACTIONS else 0

    def play(self, env):
        """Run one episode on ``env``, printing a move trace and the running score.

        :param env: the environment to step; it is closed when the episode
                    ends, even if an exception interrupts the loop
        """
        (obs, reward) = (env.reset(), 0)
        stepCount = 0
        done = False
        try:
            while not done:
                stepCount += 1
                self.dynInfo.extractInfo(obs)
                # Trace only while the agent is in charge and the puck is heading east.
                trace = stepCount > self.WARMUP_STEPS and \
                    self.dynInfo.puckDirection == self.dynInfo.east
                if trace:
                    print(f'{self.dynInfo.str()}. >> ', end='')
                inp = self._warmupAction(stepCount)
                if inp is None:
                    inp = self.agent.makeMove(stepCount, self.dynInfo, reward)
                # The only valid actions are 0 .. 5.
                # Per the original notes they mean: nop, nop, up, down, up, down.
                action = self._toAction(inp)
                (obs, reward, done, _) = env.step(action)
                if trace:
                    print(f'{inp}')
                if reward != 0:
                    (computer, player) = self.score
                    # A negative reward means the computer scored; positive, the agent.
                    if reward < 0:
                        pointWinner = 'Computer'
                        computer += 1
                    else:
                        pointWinner = 'Agent'
                        player += 1
                    self.score = (computer, player)
                    print(f'\n\t\t{pointWinner} won a point. Computer: {computer}, Agent: {player}\n')
                env.render(self.renderMode)
            print(input("Game over. Take a picture?"))
        finally:
            env.close()  # https://github.com/openai/gym/issues/893


class DynamicInfo():
    """Puck and paddle state derived from successive game frames.

    Rows are counted from the top of the (cropped) frame downward, so ``yTop``
    is numerically smaller than ``yBottom``. Positions are stored as
    ``(mean_row, column)`` pairs.
    """

    def __init__(self):
        # Blocks found in the most recent frame, sorted left to right.
        self.blocks = None

        # Smallest / largest row ever seen framing a block (with a 1-row margin).
        self.yTop = None
        self.yBottom = None

        # Half heights used to turn a centre row into a top/bottom range.
        self.puckHalfHeight = 1
        self.paddleHalfHeight = 7

        self.puck = None
        self.puckSlopeList = []
        self.puckSlope = None
        self.puckDirection = None
        self.puckIntercept = None

        self.paddle = None
        self.paddleVelocity = None
        self.paddleTarget = None

        self.east = '-->'
        self.west = '<--'
        self.NoneStr = '**'

    def allInfoDefined(self):
        """Return True when an intercept and a paddle target exist and the puck heads east."""
        return self.puckIntercept is not None and \
               self.paddleTarget is not None and \
               self.puckDirection == self.east

    def extractInfo(self, obs):
        """Update all tracked state from one observation and return ``self``.

        :param obs: a frame indexed as [row, column, channel]
        :return: this instance, so calls can be chained
        """
        self.blocks = self.getBlocks(obs)

        # Find the topmost and bottommost rows, i.e., the rows that frame the field.
        # Since this system counts rows downward, the topmost/bottommost row is the
        # row with lowest/highest number.
        # A frame with no blocks at all carries no bounds information; skipping it
        # avoids calling min()/max() on an empty sequence.
        if self.blocks:
            blockYTop = min(rowTop for ((rowTop, _), _) in self.blocks) - 1
            blockYBottom = max(rowBottom for ((_, rowBottom), _) in self.blocks) + 1
            self.yTop = blockYTop if self.yTop is None else min(self.yTop, blockYTop)
            self.yBottom = blockYBottom if self.yBottom is None else max(self.yBottom, blockYBottom)

        # Exactly three blocks are read as: left paddle, puck, right paddle.
        if len(self.blocks) == 3:
            self.computePuckIntercept()
            self.computePaddleTarget()
        else:
            # Lost track of the puck: forget everything derived from it.
            self.puckDirection = None
            self.puckIntercept = None
            self.puck = None
            self.puckSlopeList = []
        return self

    def computePaddleTarget(self):
        """Update paddle position, velocity and projected next row from the rightmost block."""
        paddle = self.blocks[-1]
        # Care only about the right-most paddle col. Increases chance of hitting puck with top
        # or bottom of paddle. Those create high-angle shots, which often win points.
        # Also, creates margin of error for intercept on defense.
        paddleRightCol = paddle[1][1]
        paddleMeanRow = self.mean(paddle[0])
        # self.paddle will be None initially, in which case no velocity can be computed.
        if self.paddle is None:
            self.paddleVelocity = None
        else:
            self.paddleVelocity = int(paddleMeanRow - self.paddle[0])
        self.paddle = (paddleMeanRow, paddleRightCol)
        # The target is where the paddle will be if it keeps its current velocity one more step.
        if self.paddleVelocity is None:
            self.paddleTarget = None
        else:
            self.paddleTarget = paddleMeanRow + self.paddleVelocity

    def computePuckIntercept(self):
        """Update puck position, direction, slope and projected intercept row.

        The intercept is only recomputed while the puck moves east with a
        confirmed slope; otherwise the previous value is left untouched.
        """
        newPuckBlock = self.blocks[1]
        # We only care about the left puck column. (Same reasoning as rightmost paddle col.)
        newPuckLeftCol = newPuckBlock[1][0]
        newPuckMeanRow = self.mean(newPuckBlock[0])
        newPuck = (newPuckMeanRow, newPuckLeftCol)
        self.puckSlope = None
        if self.puck is not None:
            colDelta = newPuck[1] - self.puck[1]
            self.puckDirection = self.east if colDelta > 0 else self.west
            slope = 0 if colDelta == 0 else (newPuck[0] - self.puck[0]) / colDelta
            self.puckSlopeList.append(int(round(slope, 0)))
            # Keep the two most recent slopes.
            self.puckSlopeList = self.puckSlopeList[-2:]
            # If the two most recent slopes are the same, take that as the slope.
            if len(self.puckSlopeList) > 1 and self.puckSlopeList[-1] == self.puckSlopeList[-2]:
                self.puckSlope = self.puckSlopeList[-1]
        self.puck = newPuck
        # Compute the projected intercept of the puck with the paddle column.
        # self.paddle still holds the position from the previous frame here.
        if self.puckSlope is not None and self.puckDirection == self.east \
                and self.paddle is not None:
            distance = self.paddle[1] - self.puck[1]
            self.puckIntercept = self.puck[0] + distance * self.puckSlope
            # If the intercept is above or below the frame of the board, fold it over.
            # Need three comparisons to allow for hitting both top and bottom.
            # This system counts rows from top down. yTop is smaller than yBottom.
            if self.puckIntercept < self.yTop:
                self.puckIntercept += 2 * (self.yTop - self.puckIntercept)
            if self.puckIntercept > self.yBottom:
                self.puckIntercept -= 2 * (self.puckIntercept - self.yBottom)
            if self.puckIntercept < self.yTop:
                self.puckIntercept += 2 * (self.yTop - self.puckIntercept)

    def getBlocks(self, obs):
        """
        Returns a list of blocks. Each block is ((row_top, row_bottom), (col_left, col_right)),
        expressed in the coordinates of the cropped frame. Columns that share the same
        (row_top, row_bottom) range are grouped into one block.
        :param obs: An observation, indexed [row, column, channel]
        :return: [((int, int), (int, int))] sorted by column, i.e. left to right
        """
        # Drop the top 30 rows and the leftmost 15 columns
        obs = np.asarray(obs)[30:, 15:]
        (_, cols, _) = np.shape(obs)
        # A cell is non-background when any channel differs from the first cell of
        # its own row, which serves as that row's background sample.
        differs = np.any(obs != obs[:, :1], axis=2)
        # Map each (row_top, row_bottom) range to the columns that exhibit it.
        # Columns without any non-background cell are ignored.
        colsByRowRange = {}
        for col in range(cols):
            hitRows = np.flatnonzero(differs[:, col])
            if hitRows.size:
                rowRange = (int(hitRows[0]), int(hitRows[-1]))
                colsByRowRange.setdefault(rowRange, []).append(col)
        # Columns were visited in ascending order, so first/last are min/max.
        blocks = [(rowRange, (colList[0], colList[-1]))
                  for (rowRange, colList) in colsByRowRange.items()]
        # Return the blocks sorted by column, i.e., from left to right on the screen.
        return sorted(blocks, key=(lambda rc: rc[1]))

    @staticmethod
    def mean(elts: tuple) -> float:
        """Return the arithmetic mean of a non-empty tuple of numbers."""
        return sum(elts) / len(elts)

    @staticmethod
    def minAndMax(lst):
        """Return (min, max) of ``lst``, or None when ``lst`` is empty."""
        return (min(lst), max(lst)) if lst else None

    def paddleTargetBottom(self):
        """Bottom row of the paddle at its target, or None if there is no target."""
        return None if self.paddleTarget is None else self.paddleTarget + self.paddleHalfHeight

    def paddleTargetTop(self):
        """Top row of the paddle at its target, or None if there is no target."""
        return None if self.paddleTarget is None else self.paddleTarget - self.paddleHalfHeight

    def puckInterceptBottom(self):
        """Bottom row of the puck at its intercept, or None if there is no intercept."""
        return None if self.puckIntercept is None else self.puckIntercept + self.puckHalfHeight

    def puckInterceptTop(self):
        """Top row of the puck at its intercept, or None if there is no intercept."""
        return None if self.puckIntercept is None else self.puckIntercept - self.puckHalfHeight

    def separation(self):
        """Column distance from puck to paddle, or None if either is unknown."""
        return None if self.paddle is None or self.puck is None else self.paddle[1]-self.puck[1]

    def str(self):
        """
        A string representation of the important features of the world.
        Safe to call at any time: an unknown puck or paddle row is shown as NoneStr.
        :return: the formatted one-line summary
        """
        puckRow = None if self.puck is None else self.puck[0]
        paddleRow = self.NoneStr if self.paddle is None else self.paddle[0]
        st = f'puck: [{self.toStr(puckRow)}] {self.toStr(self.puckSlope)} {self.puckDirection} ' + \
             f'{self.toStr(self.puckInterceptTop())} - {self.toStr(self.puckInterceptBottom())};   '  + \
             f'paddle: [{paddleRow}] {self.paddleVelocity} --> '     + \
             f'{self.paddleTargetTop()} - {self.paddleTargetBottom()};  ' + \
             f'separation: {self.separation()}'
        return st

    def toStr(self, x):
        """
        :param x: Any value
        :return: the NoneStr ('**') if x is None; otherwise str(x)
        """
        return self.NoneStr if x is None else str(x)


class Bot:
    """
    The Bot player.

    Remembers the last fully defined puck intercept and paddle target and
    steers the paddle so the two ranges overlap.
    """
    def __init__(self):
        self.puckInterceptTop = None
        self.puckInterceptBottom = None
        self.paddleTargetTop = None
        self.paddleTargetBottom = None
        self.paddle = None
        # Act only on steps that are a multiple of this number.
        self.moveFrequency = 4

    def makeMove(self, stepCount, dynInfo, reward):
        """Choose the next action.

        :param stepCount: number of the current step
        :param dynInfo: tracker exposing allInfoDefined(), yTop, yBottom, paddle
                        and the intercept/target accessors
        :param reward: reward returned by the previous step
        :return: 2, 3 or 0 (presumably up, down and no-op; confirm against
                 the environment's action table)
        """
        # Cache these values in case we lose track of the puck or the paddle.
        if dynInfo.allInfoDefined():
            self.puckInterceptTop = dynInfo.puckInterceptTop()
            self.puckInterceptBottom = dynInfo.puckInterceptBottom()
            self.paddleTargetTop = dynInfo.paddleTargetTop()
            self.paddleTargetBottom = dynInfo.paddleTargetBottom()
            self.paddle = dynInfo.paddle[0]

        # Until a complete snapshot has been cached (or the field bounds are
        # known) there is nothing to compare, and None cannot be ordered
        # against a number: do nothing.
        if self.paddle is None or dynInfo.yTop is None or dynInfo.yBottom is None:
            return 0

        # The first two checks are intended to keep the paddle from going too far
        # out of the field. (The original author noted this does not seem to work.)
        if self.paddle < dynInfo.yTop:
            return 3
        if self.paddle > dynInfo.yBottom:
            return 2

        # Act only every moveFrequency-th step (every 2nd step right after losing
        # a point). This gives the paddle a chance to settle down.
        period = 2 if reward == -1 else self.moveFrequency
        if stepCount % period != 0:
            return 0

        # Rows grow downward: a puck entirely below the paddle target needs
        # action 3, one entirely above it needs action 2.
        if self.puckInterceptTop > self.paddleTargetBottom:
            return 3
        if self.puckInterceptBottom < self.paddleTargetTop:
            return 2
        return 0


class Human:
    """
    The Human player: every move is typed at the console.
    """
    @staticmethod
    def makeMove(stepCount, dynInfo, reward):
        """Prompt with the step number and return whatever text was entered.

        ``dynInfo`` and ``reward`` are accepted for interface compatibility
        and are not used.
        """
        prompt = f'{stepCount}. >> '
        answer = input(prompt)
        return answer


if __name__ == '__main__':
    # Choose the player here: Bot() plays automatically,
    # Human() reads every move from the console.
    player = Bot()
    GameControl('Pong-v0', player)