#!/usr/bin/env python
# 2021 - Advent Of Code - 11

# Width and height of the square octopus grid. Loops bounded by this constant
# assume the input is exactly GRID_SIZE x GRID_SIZE: extra rows/columns are
# ignored, and a smaller grid raises IndexError.
GRID_SIZE = 10


def parse_file(file):
    """Read a grid of single-digit energy levels from a text file.

    Each non-blank line becomes one row: a list with one int per character.
    Blank lines (e.g. an extra empty line at the end of the file) are skipped
    so they cannot introduce empty rows into the grid.

    Args:
        file: path of the input file.

    Returns:
        A list of rows, each a list of ints.

    Raises:
        OSError: if the file cannot be opened.
        ValueError: if a line contains a character ``int()`` cannot parse.
    """
    rows = []
    with open(file, encoding="utf-8") as handle:
        for line in handle:
            digits = line.strip()
            if digits:  # skip blank lines instead of emitting empty rows
                rows.append([int(c) for c in digits])
    return rows


def print_grid(grid):
    """Print the grid as rows of digits, followed by one blank line.

    Every row and column of ``grid`` is printed, so grids of any size work
    (the original fixed GRID_SIZE bounds truncated or crashed on other sizes).

    Args:
        grid: rows of integer energy levels.
    """
    for row in grid:
        print(''.join(str(cell) for cell in row))
    print()  # blank separator line between successive grids


def check_flash(grid, f):
    """Flash every octopus whose energy exceeds 9 and return the updated count.

    Scans the grid row by row using the grid's own dimensions (not
    GRID_SIZE). Each over-charged octopus is handed to ``flash``, which
    handles the flashes it triggers in its neighbours.

    Args:
        grid: rows of integer energy levels, mutated in place.
        f: running flash count to add to.

    Returns:
        ``f`` plus the number of flashes that occurred.
    """
    for y, row in enumerate(grid):
        # enumerate fetches each cell lazily, so `energy` already reflects
        # any flashes triggered earlier in this scan.
        for x, energy in enumerate(row):
            if energy > 9:
                f = flash(grid, x, y, f)
    return f


def flash(grid, x, y, f):
    """Flash the octopus at (x, y) and resolve the cascade it triggers.

    The octopus is reset to 0 and counted. Each of its (up to eight)
    neighbours that has not already flashed gains one unit of energy; any
    neighbour pushed above 9 flashes in turn, until no more flashes occur.
    A cell holding 0 is treated as "already flashed": the main loop
    increments every cell before flashing starts, so 0 cannot otherwise
    occur at this point.

    Neighbour bounds come from the grid itself, so grids of any size work.
    An explicit stack replaces recursion so a long chain of flashes cannot
    exceed Python's recursion limit. The set of cells that flash (and thus
    the returned count and final grid) does not depend on visiting order.

    Args:
        grid: rows of integer energy levels, mutated in place.
        x: column of the flashing octopus.
        y: row of the flashing octopus.
        f: running flash count to add to.

    Returns:
        ``f`` plus the number of octopuses that flashed in this cascade.
    """
    neighbour_offsets = (
        (-1, -1), (0, -1), (1, -1),
        (-1, 0), (1, 0),
        (-1, 1), (0, 1), (1, 1),
    )
    grid[y][x] = 0
    f += 1
    to_process = [(x, y)]
    while to_process:
        cx, cy = to_process.pop()
        for dx, dy in neighbour_offsets:
            nx, ny = cx + dx, cy + dy
            if not (0 <= ny < len(grid) and 0 <= nx < len(grid[ny])):
                continue  # off the edge of the grid
            if grid[ny][nx] == 0:
                continue  # already flashed during this step
            grid[ny][nx] += 1
            if grid[ny][nx] > 9:
                grid[ny][nx] = 0  # mark as flashed before queuing it
                f += 1
                to_process.append((nx, ny))
    return f


# Load the puzzle grid; point parse_file at 'input_example.txt' to use the sample.
octopus_map = parse_file('input.txt')
print_grid(octopus_map)

nbflash = 0
for _ in range(100):
    # Every octopus gains one unit of energy at the start of a step.
    for row_idx in range(GRID_SIZE):
        current_row = octopus_map[row_idx]
        for col_idx in range(GRID_SIZE):
            current_row[col_idx] += 1

    # Resolve all flashes for this step and fold them into the running total.
    nbflash = check_flash(octopus_map, nbflash)
    print_grid(octopus_map)

print(f'number of flash: {nbflash}')