#!/usr/bin/env python3
#
# Simple multi-threading benchmark, rolling dice.

import sys
import time
import random
import concurrent.futures
from collections import Counter

# Intended number of worker threads for the threaded benchmark.
NUM_THREADS = 10
# Default number of turns simulated by one get_results() call.
WORKER_TURNS = 10_000
# Number of turns simulated over a whole benchmark run.
TOTAL_TURNS = 10_000_000
# The run is split into TOTAL_TURNS // WORKER_TURNS equal batches, so the
# division must leave no remainder.
assert not TOTAL_TURNS % WORKER_TURNS


def get_results(turns=WORKER_TURNS):
    """Simulate ``turns`` dice turns and return the total rolled in each.

    A turn rolls a pair of six-sided dice and keeps rolling for as long as
    the pair comes up as a double, up to a maximum of three rolls.

    Args:
        turns: Number of turns to simulate.

    Returns:
        A list with one entry per turn, holding the sum of every die
        rolled during that turn.
    """
    # Three doubles in a row and you go to gaol, so a turn never has more
    # than three rolls.
    max_rolls = 3
    # Each call gets its own generator instead of the module-level one.
    rnd = random.Random()
    totals = []
    for _ in range(turns):
        total = 0
        for _ in range(max_rolls):
            # randint() includes both endpoints: (1, 6) is a six-sided die.
            d1 = rnd.randint(1, 6)
            d2 = rnd.randint(1, 6)
            total += d1 + d2
            # The turn ends as soon as the pair is not a double.
            if d1 != d2:
                break
        totals.append(total)
    return totals


def main_threads():
    """Run the benchmark on a thread pool and print the tally of totals.

    Submits TOTAL_TURNS // WORKER_TURNS get_results() jobs to a pool of
    NUM_THREADS worker threads, merges each batch into a Counter as it
    completes, and prints the Counter.
    """
    results = Counter()
    # Size the pool explicitly; without max_workers the executor picks its
    # own default and NUM_THREADS would have no effect.
    pool = concurrent.futures.ThreadPoolExecutor(max_workers=NUM_THREADS)
    with pool as exc:
        futures = [
            exc.submit(get_results)
            for _ in range(TOTAL_TURNS // WORKER_TURNS)
        ]
        for future in concurrent.futures.as_completed(futures):
            # Counter.update() counts each element of the iterable.
            results.update(future.result())
    print(results)


def main():
    """Run the benchmark serially and print the tally of totals.

    Calls get_results() TOTAL_TURNS // WORKER_TURNS times, one batch after
    another, counts every returned total in a Counter, and prints it.
    """
    tally = Counter()
    batches = TOTAL_TURNS // WORKER_TURNS
    for _ in range(batches):
        # Counter.update() adds one count per element of the batch.
        tally.update(get_results())
    print(tally)


if __name__ == '__main__':
    # Time the whole run; '-t' on the command line selects the threaded
    # variant, anything else runs the serial one.
    start = time.perf_counter()
    (main_threads if '-t' in sys.argv else main)()
    print('time', time.perf_counter() - start)
