from microbit import *

# I2C address of the HT16K33; used as the constructor's default address.
HT16K33_ADDRESS = 0x70
# Display-setup command byte; ORed with the display-on bit and the
# blink-rate field (shifted left by one) when setting the blink rate.
HT16K33_BLINK_CMD = 0x80
# Display-on bit of the display-setup command.
HT16K33_BLINK_DISPLAYON = 0x01
# Dimming command byte; the low 4 bits carry the brightness level (0-15).
HT16K33_CMD_BRIGHTNESS = 0xE0
# System-setup command with the oscillator bit set (per the HT16K33
# datasheet this starts the internal oscillator); sent first at init.
# The original misspelled name is kept for compatibility with callers.
HT16K33_OSCILATOR_ON = 0x21
# Correctly spelled alias of HT16K33_OSCILATOR_ON.
HT16K33_OSCILLATOR_ON = HT16K33_OSCILATOR_ON

class HT16K33(object):
    """Minimal I2C driver for an HT16K33 LED controller.

    A local copy of the display RAM is kept in ``self.buffer``: byte 0 is
    the RAM start address (0x00) sent ahead of the data on every write, and
    bytes 1-16 hold two bytes (low, high) per row for rows 0-7.
    ``pixel()`` and ``fill()`` only edit this copy; ``show()`` sends it to
    the chip.
    """

    def __init__(self, address=HT16K33_ADDRESS):
        """Set up the controller at I2C *address*.

        Clears the local buffer, sends the oscillator-on command, then sets
        blink rate 0 and brightness 1. The cleared buffer is not written to
        the display until ``show()`` is called.
        """
        self.address = address
        self.temp = bytearray(1)  # reusable 1-byte buffer for commands
        self.buffer = bytearray(17)
        self.buffer[0] = 0x00  # display-RAM start address prefix
        self.fill(0)
        self.write_cmd(HT16K33_OSCILATOR_ON)
        self.blink_rate(0)
        self.brightness(1)

    def write_cmd(self, byte):
        """Send the single command *byte* to the controller over I2C."""
        self.temp[0] = byte
        i2c.write(self.address, self.temp)

    def blink_rate(self, rate=None):
        """Get or set the blink rate.

        With no argument, return the last rate that was set. Otherwise
        only the low two bits of *rate* are kept (per the HT16K33 datasheet:
        0 = off, 1 = 2 Hz, 2 = 1 Hz, 3 = 0.5 Hz); the display is kept on.
        """
        if rate is None:
            return self._blink_rate
        # The blink field is two bits wide (placed at bits 1-2 by the
        # shift below); masking with 0x02 would drop rates 1 and 3.
        rate &= 0x03
        # Stored under a private name so the method is not shadowed by an
        # instance attribute (which made later calls fail).
        self._blink_rate = rate
        self.write_cmd(HT16K33_BLINK_CMD |
                       HT16K33_BLINK_DISPLAYON | rate << 1)

    def brightness(self, brightness=None):
        """Get or set the brightness level.

        With no argument (or ``None``), return the last level that was set.
        Otherwise only the low four bits are kept (0-15) and sent with the
        dimming command.
        """
        if brightness is None:
            return self._brightness
        brightness &= 0x0F
        # Private name: assigning to ``self.brightness`` would replace this
        # method with an int on the instance.
        self._brightness = brightness
        self.write_cmd(HT16K33_CMD_BRIGHTNESS | brightness)

    def show(self):
        """Write the whole local buffer (address byte + 16 data bytes)."""
        i2c.write(self.address, self.buffer)

    def fill(self, color):
        """Set every data byte of the local buffer to all-on or all-off."""
        fill = 0xFF if color else 0x00
        for i in range(16):
            self.buffer[i + 1] = fill

    def pixel(self, x, y, color=None):
        """Get or set one bit of the local buffer.

        *y* selects the row (0-7); *x* selects a bit of that row's 16-bit
        word, low byte first (expected 0-15). With ``color=None`` return
        whether the bit is set; otherwise set it (truthy) or clear it.
        Only the local buffer changes; call ``show()`` to update the chip.
        """
        mask = 1 << x
        low = y * 2 + 1
        high = low + 1
        if color is None:
            # Read from the same two bytes the setter writes to.
            return bool((self.buffer[low] | self.buffer[high] << 8) & mask)
        if color:
            self.buffer[low] |= mask & 0xFF
            self.buffer[high] |= mask >> 8
        else:
            self.buffer[low] &= ~(mask & 0xFF)
            self.buffer[high] &= ~(mask >> 8)


class Matrix8x8(HT16K33):
    """8x8 LED matrix driven through an HT16K33.

    Only the low byte of each of the eight rows is used (columns 0-7).
    """

    # Raw controller column for each bit of an image row byte, indexed by
    # bit number (bit 0 = LSB). These are passed straight to
    # HT16K33.pixel, bypassing the (x - 1) % 8 remap in pixel(), and are
    # exactly the values of the original hand-written mapping.
    # NOTE(review): under pixel()'s remap this places bits 0-5 at columns
    # 7..2 but bits 6 and 7 at columns 0 and 1 (a plain mirror would be 1
    # and 0) — confirm against the actual hardware.
    _IMAGE_COLUMNS = (6, 5, 4, 3, 2, 1, 7, 0)

    def pixel(self, x, y, color=1):
        """Get or set the pixel at column *x*, row *y* (both 0-7).

        Out-of-range coordinates are silently ignored (returns None).
        *color* defaults to 1, so ``pixel(x, y)`` turns the LED on; pass
        ``None`` to query instead. The column is remapped with
        ``(x - 1) % 8`` before delegating to ``HT16K33.pixel``
        (presumably to match the backpack's column wiring — confirm).
        Only the local buffer changes; call ``show()`` to update the chip.
        """
        if not (0 <= x <= 7 and 0 <= y <= 7):
            return
        x = (x - 1) % 8
        return super().pixel(x, y, color)

    def clear(self):
        """Turn off all 64 LEDs and write the buffer to the display once."""
        for y in range(8):
            for x in range(8):
                super().pixel(x, y, 0)
        # One I2C transfer instead of one per pixel.
        self.show()

    def show_image(self, image):
        """Light the pixels set in *image*, then refresh the display once.

        *image* is anything ``bytearray()`` accepts (e.g. bytes or a list
        of ints 0-255), one value per row starting at y = 0. Bits that are
        0 leave the corresponding pixel unchanged, so call ``clear()`` first
        for a fresh frame. Rows past the eighth are ignored, matching
        ``pixel()``'s handling of out-of-range coordinates.
        """
        columns = self._IMAGE_COLUMNS
        for y, line in enumerate(bytearray(image)):
            if y > 7:
                break
            for bit in range(8):
                if line & (1 << bit):
                    super().pixel(columns[bit], y, 1)
        # Single I2C transfer for the whole image instead of one per pixel.
        self.show()