from itertools import permutations, chain, combinations
import copy

def check(square, line):
    """Return True if ``line`` can be placed below every row of ``square``.

    For each column index, the candidate's entry must differ from that
    column's entry in every existing row of ``square``.
    """
    for col, value in enumerate(line):
        if any(value == row[col] for row in square):
            return False
    return True

def all_possible_lines(n, square):
    """Yield each permutation of 1..n (as a tuple) that ``check`` accepts for ``square``.

    Candidates come from ``itertools.permutations`` in its usual order.
    """
    for candidate in permutations(range(1, n + 1)):
        if not check(square, candidate):
            continue
        yield candidate

def latins(n):
    """Return a generator over all n x n Latin squares on the symbols 1..n.

    Each square is a list of ``n`` row tuples. The search starts from an
    empty square and is delegated to ``_latins``.
    """
    empty_square = []
    return _latins(n, empty_square)
    
def _latins(n, square):
    """Recursively extend ``square`` one row at a time and yield it once it has ``n`` rows.

    ``square`` is never mutated: every extension builds a new list.
    """
    if len(square) != n:
        for row in all_possible_lines(n, square):
            yield from _latins(n, square + [row])
        return
    yield square

def print_square(square):
    """Print each row of ``square`` on its own line, using the row's repr."""
    for row in square:
        print(row)

def permutations_with_constraints(elements, constraints):
    """Yield permutations of ``elements`` that avoid per-position forbidden values.

    ``constraints`` is a list of lists: ``constraints[k]`` holds the values
    that may not be placed at position ``k``. Each result is a new list.
    """
    target_length = len(elements)
    return _permutations_with_contraints([], target_length, elements, constraints)

def _permutations_with_contraints(permutation, needed_length, elements, constraints):
    if len(permutation) == needed_length:
        yield permutation
    else:
        for element in elements:
            if element in constraints[len(permutation)]:
                continue
            else:
                new_elements = copy.copy(elements)
                new_elements.remove(element)
                for perm in _permutations_with_contraints(permutation + [element], needed_length, new_elements, constraints):
                    yield perm

def check_from_column_k_to_line_l(k, l, square, line):
    """Return True if ``line`` fits into columns ``k, k+1, ...`` above row ``l``.

    Entry ``line[offset]`` is compared with column ``k + offset`` of rows
    ``0 .. l-1`` (``l`` is exclusive). Returns False on the first clash.
    """
    clash = any(
        value == square[row][k + offset]
        for offset, value in enumerate(line)
        for row in range(l)
    )
    return not clash
  
def all_possible_symred_lines(n, lines, square):
    """Yield candidate completions for row ``lines`` of a symmetric reduced square.

    The candidates use the values 1..n not yet present in ``square[lines]``.
    Position ``k`` of a candidate must avoid every value in
    ``square[lines + k]``. Since ``_symred_latins`` mirrors each row into the
    later rows, those rows hold the matching columns' entries so far.
    """
    present = square[lines]
    missing = [value for value in range(1, n + 1) if value not in present]
    yield from permutations_with_constraints(missing, square[lines:])

def symred_latins(n):
    """Return a generator over symmetric reduced n x n Latin squares.

    Row 0 is 1..n and row ``k`` starts with ``k + 1``. ``_symred_latins`` then
    fills the rows in order, mirroring each new row into the rows below it.
    """
    header = list(range(1, n + 1))
    square = [header]
    square.extend([start] for start in range(2, n + 1))
    return _symred_latins(n, 1, square)
    
def _symred_latins(n, lines, square):
    """Fill row ``lines``, mirror it into the later rows, and recurse.

    Rows ``0 .. lines-1`` of ``square`` are complete; later rows are partial.
    Each candidate tail is appended to row ``lines``. Entry ``tail[r - lines]``
    is appended to every row ``r > lines``, which keeps the square symmetric.
    ``square`` is deep-copied per candidate. Completed squares are yielded.
    """
    if lines == n:
        yield square
        return
    for tail in all_possible_symred_lines(n, lines, square):
        extended = copy.deepcopy(square)
        extended[lines] += tail
        for row_index in range(lines + 1, n):
            extended[row_index].append(tail[row_index - lines])
        yield from _symred_latins(n, lines + 1, extended)

def sign(perm):
    """Return the sign of ``perm`` in one-line notation.

    The result is ``1`` if ``perm`` has an even number of inversions
    (pairs of positions whose values appear in decreasing order) and ``-1``
    otherwise.
    """
    inversions = sum(1 for earlier, later in combinations(perm, 2) if later < earlier)
    return -1 if inversions % 2 else 1

def transpose(square):
    """Return the transpose of ``square`` as a new list of lists.

    The number of columns is taken from ``square[0]``.
    """
    column_count = len(square[0])
    return [[row[col] for row in square] for col in range(column_count)]

def parity(square):
    """Return the product of ``sign`` over all rows and all columns of ``square``.

    The result is ``1`` or ``-1``. Only the first ``len(square)`` columns are
    used, so ``square`` is expected to be square.
    """
    columns = transpose(square)
    result = 1
    for idx, row in enumerate(square):
        result *= sign(row) * sign(columns[idx])
    return result

def red_latins(n):
    """Return a generator over reduced n x n Latin squares.

    Row 0 is 1..n and row ``k`` starts with ``k + 1``, so the first column is
    also 1..n. ``constraints[k]`` tracks the values already used in column
    ``k + 1``. At the start it holds only that column's header, ``k + 2``.
    """
    square = [list(range(1, n + 1))]
    square += [[value] for value in range(2, n + 1)]
    constraints = [list(row) for row in square[1:]]
    return _red_latins(n, 1, square, constraints)

def _red_latins(n, lines, square, constraints):
    """Fill the tail of row ``lines`` and recurse, yielding completed squares.

    ``constraints[k]`` lists the values already placed in column ``k + 1``.
    Both ``square`` and ``constraints`` are deep-copied before each
    extension, so the caller's objects are left untouched.
    """
    if lines == n:
        yield square
        return
    for tail in all_possible_red_lines(n, lines, square, constraints):
        next_square = copy.deepcopy(square)
        next_constraints = copy.deepcopy(constraints)
        next_square[lines] += tail
        for column_index, value in enumerate(tail):
            next_constraints[column_index].append(value)
        yield from _red_latins(n, lines + 1, next_square, next_constraints)

def all_possible_red_lines(n, lines, square, constraints):
    """Yield candidate tails (columns 1..n-1) for row ``lines`` of a reduced square.

    Row ``lines`` already starts with ``lines + 1``. A tail is therefore a
    permutation of the other values in 1..n, with position ``k`` avoiding
    ``constraints[k]``. ``square`` is not used by this function.
    """
    remaining = list(range(1, n + 1))
    remaining.remove(lines + 1)
    yield from permutations_with_constraints(remaining, constraints)

def _invert_cycle(cycle): #cycles are tuples, and no 1-length cycles are possible
    return (cycle[0],) + cycle[1:][::-1]

def invert_permutation(permutation):
    """Return the inverse of a permutation given as a list of cycle tuples.

    The elements are 1..n. Each cycle is inverted independently with
    ``_invert_cycle``, and the cycle order is preserved.
    """
    return list(map(_invert_cycle, permutation))

def image(element, permutation):
    """Return the image of ``element`` under ``permutation`` (a list of cycles).

    Within a cycle each element maps to the next one, and the last element
    wraps around to the first. An element that appears in no cycle is a
    fixed point, so it maps to itself.
    """
    for cycle in permutation:
        if element in cycle:
            position = cycle.index(element)
            return cycle[(position + 1) % len(cycle)]
    # Standard cycle notation omits fixed points. Returning None here would
    # make _mult_build_cycle keep appending None and never terminate.
    return element

def has_cycle_length(n, permutation):
    """Return True if some cycle in ``permutation`` has exactly ``n`` elements."""
    return any(len(cycle) == n for cycle in permutation)

def _mult_build_cycle(element, permutation1, permutation2):
    """Return the cycle through ``element`` of ``x -> permutation2(permutation1(x))``.

    Both permutations are lists of cycles. Images are looked up with
    ``image``. The returned tuple starts at ``element``.
    """
    def step(x):
        return image(image(x, permutation1), permutation2)

    members = [element]
    current = step(element)
    while current != element:
        members.append(current)
        current = step(current)
    return tuple(members)

def mult_permutations(permutation1, permutation2):
    """Compose two permutations in cycle notation: apply ``permutation1`` first, then ``permutation2``.

    The result is a list of cycles. Each cycle starts at the smallest element
    of ``permutation1`` not yet covered by an earlier cycle. Only elements
    that occur in ``permutation1`` are used as starting points.
    """
    covered = set()
    product = []
    for start in sorted(chain.from_iterable(permutation1)):
        if start in covered:
            continue
        cycle = _mult_build_cycle(start, permutation1, permutation2)
        product.append(cycle)
        covered.update(cycle)
    return product

def _build_cycle(element, perm): #perm is just a list of elements, like the line of a square
    cycle = (element,)
    last_element = perm[element - 1]
    while last_element != element:
        cycle += last_element,
        last_element = perm[last_element - 1]
    return cycle

def cycle_notation(perm):
    """Convert a one-line permutation of 1..n into a list of cycle tuples.

    ``perm`` is like a row of a square. Every element appears in exactly one
    cycle, and fixed points appear as 1-tuples. Cycles are ordered by their
    smallest element, which is also each cycle's first entry.
    """
    covered = set()
    cycles = []
    for start in sorted(perm):
        if start not in covered:
            cycle = _build_cycle(start, perm)
            cycles.append(cycle)
            covered.update(cycle)
    return cycles
        
def square_has_3_cycle(square):
    """Return True if some pair of rows (row 0 excluded) yields a 3-cycle.

    For rows ``a`` before ``b``, the cycles of
    ``mult_permutations(invert_permutation(a), b)`` are checked for one of
    length 3.
    """
    def pair_has_three_cycle(first_row, second_row):
        first = cycle_notation(first_row)
        second = cycle_notation(second_row)
        quotient = mult_permutations(invert_permutation(first), second)
        return has_cycle_length(3, quotient)

    return any(
        pair_has_three_cycle(a, b) for a, b in combinations(square[1:], 2)
    )

def has_correct_cycle(i, j, permutation):
    """Return True if ``permutation`` has an odd-length cycle containing both ``i`` and ``j``."""
    return any(
        len(cycle) % 2 == 1 and i in cycle and j in cycle
        for cycle in permutation
    )

def h_condition(square):
    """Return True if some pair of rows of ``square`` satisfies the H. condition.

    Only row pairs are checked (not column pairs), and row 0 is skipped.
    Rows ``i < j`` qualify when ``mult_permutations(invert_permutation(p_i), p_j)``
    has an odd-length cycle containing both ``i + 1`` and ``j + 1``. Here
    ``p_k`` is row ``k`` in cycle notation.
    """
    def pair_qualifies(first, second):
        p_first = cycle_notation(square[first])
        p_second = cycle_notation(square[second])
        product = mult_permutations(invert_permutation(p_first), p_second)
        return has_correct_cycle(first + 1, second + 1, product)

    return any(
        pair_qualifies(a, b) for a, b in combinations(range(1, len(square)), 2)
    )

def h_counterexamples_old(n):
    """Yield odd reduced squares for which neither the rows nor the columns satisfy ``h_condition``.

    Every square from ``red_latins(n)`` is enumerated, and only those with
    ``parity == -1`` are counted. Progress is printed every 1000 odd squares.
    Each odd square failing ``h_condition`` on itself and on its transpose is
    printed and yielded. A total is printed at the end.
    """
    odd_seen = 0
    for candidate in red_latins(n):
        if parity(candidate) != -1:
            continue
        odd_seen += 1
        if odd_seen % 1000 == 0:
            print("{} odd squares checked".format(odd_seen))
        if h_condition(candidate) or h_condition(transpose(candidate)):
            continue
        print("[{}]".format(odd_seen))
        print_square(candidate)
        yield candidate
    print("Total odd squares checked: {}".format(odd_seen))

def red_latins_not_h(n):
    """Return a generator over reduced n x n Latin squares that fail ``h_condition``.

    The start state matches ``red_latins``: the first row and first column
    are 1..n. The search prunes any partial square in which some pair of rows
    (row 0 excluded) already satisfies the H. condition.
    """
    square = [list(range(1, n + 1))]
    square += [[value] for value in range(2, n + 1)]
    # constraints[k] tracks the values used so far in column k + 1.
    constraints = [list(row) for row in square[1:]]
    return _red_latins_not_h(n, 1, square, constraints)

def _red_latins_not_h(n, lines, square, constraints):
    """Like ``_red_latins``, but abandons branches where a row pair meets the H. condition.

    On entry, rows ``0 .. lines-1`` are complete. Once at least three rows
    exist, the newest complete row (``lines - 1``) is compared with each
    earlier row ``1 .. lines-2``. Older pairs were already checked at
    shallower recursion depths.
    """
    if lines >= 3:
        newest = cycle_notation(square[lines - 1])
        for earlier in range(1, lines - 1):
            inverse_earlier = invert_permutation(cycle_notation(square[earlier]))
            quotient = mult_permutations(inverse_earlier, newest)
            if has_correct_cycle(earlier + 1, lines, quotient):
                # This pair satisfies the condition: prune the whole branch.
                return
    if lines != n:
        for tail in all_possible_red_lines(n, lines, square, constraints):
            next_square = copy.deepcopy(square)
            next_constraints = copy.deepcopy(constraints)
            next_square[lines] += tail
            for column_index, value in enumerate(tail):
                next_constraints[column_index].append(value)
            yield from _red_latins_not_h(n, lines + 1, next_square, next_constraints)
    else:
        yield square

def h_counterexamples(n):
    """Yield odd reduced squares that fail the H. condition on both rows and columns.

    Squares satisfying it on row pairs are already excluded by
    ``red_latins_not_h``. This keeps those with ``parity == -1`` whose
    transpose also fails ``h_condition``. Progress is printed every 1000 such
    odd squares, each hit is printed, and a total is printed at the end.
    """
    bad_count = 0
    for candidate in red_latins_not_h(n):
        if parity(candidate) != -1:
            continue
        bad_count += 1
        if bad_count % 1000 == 0:
            print("{} bad squares checked".format(bad_count))
        if h_condition(transpose(candidate)):
            continue
        print("[{}]".format(bad_count))
        print_square(candidate)
        yield candidate
    print("Total bad squares checked: {}".format(bad_count))


def h_lines(square):
    """Return the first row pair of ``square`` satisfying the H. condition, or None.

    The result is ``(i + 1, j + 1, m)`` for the first rows ``i < j`` (row 0
    excluded) whose product ``m = mult_permutations(invert_permutation(p_i), p_j)``
    has an odd cycle containing ``i + 1`` and ``j + 1``.
    """
    for upper, lower in combinations(range(1, len(square)), 2):
        inverse_upper = invert_permutation(cycle_notation(square[upper]))
        product = mult_permutations(inverse_upper, cycle_notation(square[lower]))
        if has_correct_cycle(upper + 1, lower + 1, product):
            return upper + 1, lower + 1, product
    return None


def h_lines_gen(square):
    """Yield every row pair of ``square`` that satisfies the H. condition.

    Yields ``(i + 1, j + 1, m)`` for each pair of rows ``i < j`` (row 0
    excluded), in ``combinations`` order. ``m`` is
    ``mult_permutations(invert_permutation(p_i), p_j)``, and the pair is
    yielded when ``m`` has an odd cycle containing ``i + 1`` and ``j + 1``.
    """
    row_count = len(square)
    for upper, lower in combinations(range(1, row_count), 2):
        inverse_upper = invert_permutation(cycle_notation(square[upper]))
        product = mult_permutations(inverse_upper, cycle_notation(square[lower]))
        if has_correct_cycle(upper + 1, lower + 1, product):
            yield upper + 1, lower + 1, product


def h_transform(square):
    """Apply one H. transformation to ``square``, trying rows before columns.

    Uses the first qualifying row pair from ``h_lines``. If no row pair
    qualifies, it works on the transpose and transposes the result back.
    Returns None when neither rows nor columns qualify.
    """
    row_pair = h_lines(square)
    if row_pair is not None:
        return _h_transform(square, *row_pair)
    flipped = transpose(square)
    column_pair = h_lines(flipped)
    if column_pair is None:
        return None
    return transpose(_h_transform(flipped, *column_pair))


def _h_transform(square, i, j, permutation):
    fix_cycle = None
    for cycle in permutation:
        if i in cycle and j in cycle:
            fix_cycle = cycle
            break
    new_square = copy.deepcopy(square)
    for index in range(0, len(new_square[0])):
        if new_square[i - 1][index] not in fix_cycle:
            temp = new_square[i - 1][index]
            new_square[i - 1][index] = new_square[j - 1][index]
            new_square[j - 1][index] = temp
    return new_square


def h_images(square):
    """Yield every H. image of ``square``: row-based images first, then column-based."""
    for image_source in (h_images_rows, h_images_cols):
        yield from image_source(square)


def h_images_rows(square):
    """Yield ``_h_transform`` of ``square`` for each row pair from ``h_lines_gen``."""
    yield from (_h_transform(square, *pair) for pair in h_lines_gen(square))


def h_images_cols(square):
    """Yield the column-based H. images of ``square``.

    The row-based transformation is applied to the transpose, and each
    result is transposed back.
    """
    flipped = transpose(square)
    for first, second, product in h_lines_gen(flipped):
        yield transpose(_h_transform(flipped, first, second, product))


def tuple_square(square):
    """Return ``square`` as a hashable tuple of row tuples."""
    return tuple(map(tuple, square))


def list_square(square):
    """Return ``square`` as a mutable list of row lists."""
    return list(map(list, square))


def _experiment_1(p, n=6):
    """Try to give every reduced n x n square of parity ``p`` a distinct H. image.

    Squares come from ``red_latins(n)``. For each square with
    ``parity(square) == p``, the first row-based H. image not already taken
    is recorded. Failing that, the first such column-based image is recorded.
    Squares with no unused image are printed. Summary counts are printed at
    the end.

    Args:
        p: target parity, expected to be -1 or 1.
        n: order of the squares (defaults to 6, the previously hard-coded value).

    Returns:
        The list of images, in the order they were taken.
    """
    images = []
    # Hashable mirror of ``images``. Membership tests are O(1) instead of a
    # linear scan over an ever-growing list of nested lists.
    taken = set()
    count = 0

    def claim(img):
        # Record img if it is new; report whether it was.
        key = tuple(tuple(row) for row in img)
        if key in taken:
            return False
        taken.add(key)
        images.append(img)
        return True

    for square in red_latins(n):
        if parity(square) != p:
            continue
        image_found = False
        for i, j, m in h_lines_gen(square):
            if claim(_h_transform(square, i, j, m)):
                image_found = True
                break
        if not image_found:
            square_t = transpose(square)
            for i, j, m in h_lines_gen(square_t):
                if claim(transpose(_h_transform(square_t, i, j, m))):
                    image_found = True
                    break
        if not image_found:
            count += 1
            print("Image not found!")
            print_square(square)
            print("-----")
    print("Images: {}".format(len(images)))
    print("Not found: {}".format(count))
    return images


def h_images_orbit_rows(square, orbit_size_limit=1000000):
    """Return the orbit of ``square`` under row-based H. transformations.

    Delegates to ``_h_images_orbit`` with ``h_images_rows``. That function
    raises ``OrbitSizeException`` once the orbit exceeds ``orbit_size_limit``.
    """
    image_source = h_images_rows
    return _h_images_orbit(square, image_source, orbit_size_limit=orbit_size_limit)


def h_images_orbit_cols(square, orbit_size_limit=1000000):
    """Return the orbit of ``square`` under column-based H. transformations.

    Delegates to ``_h_images_orbit`` with ``h_images_cols``. That function
    raises ``OrbitSizeException`` once the orbit exceeds ``orbit_size_limit``.
    """
    image_source = h_images_cols
    return _h_images_orbit(square, image_source, orbit_size_limit=orbit_size_limit)


def h_images_orbit(square, orbit_size_limit=1000000):
    """Return the orbit of ``square`` under both row- and column-based H. transformations.

    Delegates to ``_h_images_orbit`` with ``h_images``. That function raises
    ``OrbitSizeException`` once the orbit exceeds ``orbit_size_limit``.
    """
    image_source = h_images
    return _h_images_orbit(square, image_source, orbit_size_limit=orbit_size_limit)


def _h_images_orbit(square, h_img_gen, orbit_size_limit=1000000):
    """Return the closure of ``square`` under the image generator ``h_img_gen``.

    Starting from ``square``, every square that ``h_img_gen`` produces from
    an already-reached square is added, until nothing new appears. Squares
    are stored as tuples of tuples (via ``tuple_square``) so they can live in
    sets. ``h_img_gen`` receives them as lists of lists (via ``list_square``).

    Args:
        square: starting square (a sequence of rows).
        h_img_gen: callable taking a list-of-lists square and yielding its images.
        orbit_size_limit: maximum number of distinct squares allowed.

    Returns:
        The set of tuple-of-tuples squares in the orbit, including ``square``.

    Raises:
        OrbitSizeException: if more than ``orbit_size_limit`` squares are discovered.
    """
    frontier = {tuple_square(square)}  # discovered, not yet expanded
    closed = set()                     # discovered and expanded
    while frontier:
        if len(frontier) + len(closed) > orbit_size_limit:
            raise OrbitSizeException(
                "orbit exceeds {} squares".format(orbit_size_limit)
            )
        current = frontier.pop()
        # Mark current as expanded before generating images, so a self-image
        # is not re-queued.
        closed.add(current)
        for img in h_img_gen(list_square(current)):
            key = tuple_square(img)
            if key not in closed:
                frontier.add(key)
    return closed


class OrbitSizeException(Exception):
    """Thrown when the orbit of a square is too big.

    Raised by the orbit search once it has discovered more squares than its
    ``orbit_size_limit`` argument allows.
    """
