#!/usr/bin/env python3
import sys
import random
import time
import threading

import pymysql.cursors


def open_conn():
    """Open a fresh PyMySQL connection to the local ``test`` database.

    Connects as user ``test`` through the Unix socket
    ``/var/run/mysqld/mysqld.sock`` using ``utf8mb4``, dictionary rows
    (``DictCursor``) and autocommit disabled, so each caller manages its own
    transaction boundaries.

    Returns:
        A new PyMySQL connection; the caller is responsible for closing it.
    """
    options = dict(
        user='test',
        unix_socket='/var/run/mysqld/mysqld.sock',
        db='test',
        charset='utf8mb4',
        cursorclass=pymysql.cursors.DictCursor,
        autocommit=False,
    )
    return pymysql.connect(**options)


def setup(num_rows):
    """Reset table ``test`` to ``num_rows`` rows, all with ``value = 1``.

    Truncates ``test``, inserts ids ``0 .. num_rows - 1`` (each with
    ``value`` 1) via a single ``executemany`` call, commits, and prints the
    elapsed wall-clock time.

    Args:
        num_rows: Number of rows to insert.
    """
    start = time.perf_counter()
    conn = open_conn()
    try:
        with conn.cursor() as cur:
            cur.execute('TRUNCATE test')
            sql = 'INSERT INTO test (id, value) VALUES (%s, %s)'
            cur.executemany(sql, [(row_id, 1) for row_id in range(num_rows)])
        conn.commit()
    finally:
        conn.close()
    end = time.perf_counter()
    print('setup in {:.3f} sec.'.format(end - start))


def run_test(num_threads, batch_size):
    """Hold row locks in one transaction while worker connections contend.

    On a single transaction of the main connection, updates ``value = 2``
    for ``num_threads`` consecutive batches of ``batch_size`` ids.  After
    each batch it starts a worker (``run_subconn``) that updates the first
    id of that batch on its own connection; that row was just modified by
    the still-open main transaction, so the worker presumably has to wait
    for its row lock.  The main transaction then commits.

    The printed time is taken before the main connection is closed and
    before the workers are joined, i.e. it covers the main transaction only.

    Args:
        num_threads: Number of batches / worker threads.
        batch_size: Number of consecutive ids updated per batch.
    """
    test_start = time.perf_counter()
    conn = open_conn()
    threads = []
    try:
        conn.begin()
        with conn.cursor() as cur:
            for batch in range(num_threads):
                base_id = batch * batch_size
                # The ids are generated integers, so inlining them is safe; one
                # IN list keeps each batch to a single statement.
                ids = ','.join(str(row_id) for row_id in range(base_id, base_id + batch_size))
                cur.execute('UPDATE test SET value = 2 WHERE id in (' + ids + ')')
                threads.append(run_subconn(base_id))
            conn.commit()
    finally:
        test_end = time.perf_counter()
        try:
            # Close the main connection BEFORE joining: if the commit was not
            # reached, closing releases its locks so blocked workers can finish.
            conn.close()
        finally:
            # Reap every worker even if closing the main connection failed,
            # so no thread or sub-connection is leaked.
            for t, conn2 in threads:
                t.join()
                conn2.close()

    print('test in {:.3f} sec.'.format(test_end - test_start))


def run_subconn(row_id):
    """Start a thread that updates a single row on its own connection.

    Opens a new connection, begins a transaction and starts a thread running
    ``UPDATE test SET value = 3 WHERE id = <row_id>`` on it.  The transaction
    is never committed here; the update is discarded when the caller closes
    the returned connection without committing.

    Args:
        row_id: ``id`` of the row the worker updates.

    Returns:
        ``(thread, conn)``: the already-started ``threading.Thread`` and the
        connection it uses.  Join the thread before closing ``conn``.
    """
    conn = open_conn()
    try:
        conn.begin()
        cur = conn.cursor()
    except Exception:
        # Don't leak the connection if the transaction/cursor setup fails.
        conn.close()
        raise

    def worker():
        try:
            # Query args must be a sequence: ``(row_id,)``, not ``(row_id)``.
            cur.execute('UPDATE test SET value = 3 WHERE id = %s', (row_id,))
        except (pymysql.err.InternalError, pymysql.err.OperationalError) as e:
            # Lock-wait timeouts / deadlocks are the expected outcome of the
            # contention; depending on the PyMySQL version they surface as
            # InternalError or OperationalError.  Report instead of dying.
            print(e, file=sys.stderr)
        finally:
            cur.close()

    t = threading.Thread(target=worker)
    t.start()
    return t, conn


def main():
    """Seed the table, then run the lock-contention test.

    Uses 100 batches (one worker thread each) of 10000 ids, so the table is
    seeded with 100 * 10000 rows before the test runs.
    """
    batch_size = 10000
    num_threads = 100
    setup(num_threads * batch_size)
    run_test(num_threads, batch_size)


if __name__ == '__main__':
    # Run only when executed as a script, not when imported.
    main()
