#!/usr/bin/python3
# Multithreaded invoice-transaction stress test against MySQL or PostgreSQL, driven from Python

import psycopg2			# docs: http://initd.org/psycopg/docs/
import MySQLdb
import random
import threading
import time
import sys


# Workload size: NUMTHREADS worker threads each run PERTHREAD transactions.
# The basic test is (25, 40); a single-threaded baseline is NUMTHREADS = 1,
# PERTHREAD = 1000.
NUMTHREADS = 25
PERTHREAD = 40

# Target database name ('MYSQL' or 'POSTGRES'); assigned from argv in main().
DB = None

# Inclusive upper bounds for random part numbers, quantities and customer ids.
MAXPART = 1000
MAXQUAN = 5
MAXCUST = 2000

# Locking knobs (the value noted as "default" is the author's baseline).
SORT = False      # sort each purchase list by part number; reduces deadlocks; default False
REQLOCKS = True   # lock every purchased part row before writing; default False
NOWAIT = True     # PostgreSQL lock statement variant that does not wait for locks; default False

RETRYCOUNT = 10   # attempts per transaction while it keeps failing in the lock phase

# Transaction isolation level; 'read committed' is the weakest of the three.
ISOLATION = 'read committed'
# ISOLATION = 'repeatable read'
# ISOLATION = 'serializable'

globalpartquan = 0          # running total of part quantities reported by worker threads
glock = threading.Lock()    # serializes updates to globalpartquan

def makemconn():
    """Open a new MySQL connection to the 'transactions' database.

    Returns the MySQLdb connection, or None (after printing why) if
    MySQLdb raises a DatabaseError while connecting.
    """
    try:
        conn = MySQLdb.connect(user='pld', passwd='blue', db='transactions')
    except MySQLdb.DatabaseError as e:
        # Best effort: report the cause and let the caller treat None as a
        # failed transaction instead of exiting the whole run.
        print('Cannot connect to the database: {}'.format(e))
        return None
    return conn

def makepconn():
    """Open a new PostgreSQL connection to the 'transactions' database as user pld.

    Returns the psycopg2 connection, or None (after printing why) if
    psycopg2 raises a DatabaseError while connecting.
    """
    try:
        conn = psycopg2.connect("dbname=transactions user=pld")
    except psycopg2.DatabaseError as e:
        # Best effort: report the cause and let the caller treat None as a
        # failed transaction instead of exiting the whole run.
        print('Cannot connect to the database: {}'.format(e))
        return None
    return conn

def makeconn():
    """Open a connection to the database selected by the global DB.

    Delegates to makepconn() for 'POSTGRES' and makemconn() for 'MYSQL';
    any other value of DB yields None.
    """
    if DB == 'POSTGRES':
        return makepconn()
    if DB == 'MYSQL':
        return makemconn()
    return None

# SQL statement templates. Values are bound through the driver's %s
# placeholders; the trailing comment on each lists them in order.
isolation = 'set transaction isolation level '  # followed by ISOLATION + ';'
pinvoice_insert = 'insert into invoice(ordertime, custid) values (now(), %s) returning invnum;'  # custid (PostgreSQL)
minvoice_insert = 'insert into invoice(ordertime, custid) values (now(), %s);'  # custid (MySQL)
invitem_insert = 'insert into invitem values(%s, %s, %s);'  # invnum, partnum, quan
stock_update = 'update part set quan_in_stock = quan_in_stock - %s where partnum = %s;'  # quan, partnum

# Row-lock statement used by doptrans (PostgreSQL); NOWAIT selects the
# variant that does not wait for locks.
plock = ('select partnum from part where partnum = %s for no key update nowait;'
         if NOWAIT else
         'select partnum from part where partnum = %s for no key update;')

# Row-lock statement used by domtrans (MySQL); NOWAIT is not applied here.
mlock = 'select partnum from part where partnum = %s for update;'

# Result codes shared by the transaction functions and their callers,
# which compare against these exact values -- do not change them.
TNOLOCK = -2    # failed in the lock phase
TFAILURE = -1   # failed after the lock phase, or no connection
TSUCCESS = 0    # committed

def doptrans(custid, purchaselist):
    """Run one invoice transaction against PostgreSQL.

    Opens a fresh connection and sets the transaction isolation level. If
    REQLOCKS is set, it locks each purchased part row with ``plock``. It then
    inserts an invoice (reading its number back via RETURNING), inserts one
    invitem row per (part, quan) pair, decrements each part's stock and
    commits.

    Returns:
        TSUCCESS on commit.
        TNOLOCK if a psycopg2 error escaped the setup/lock phase (cursor
            creation, SET TRANSACTION or a row-lock request); a '.' is
            printed as a progress marker.
        TFAILURE if no connection could be made, or if the insert/update
            phase failed (that phase is rolled back and the error printed).

    The connection is closed on every path, so an aborted transaction and
    any row locks it already holds are released right away instead of
    whenever the connection object happens to be garbage-collected.
    """
    invnum = -1
    conn = None
    try:
        # Setup / lock phase: any driver error here counts as a lock failure.
        try:
            conn = makeconn()
            if conn is None:
                return TFAILURE
            cur = conn.cursor()
            cur.execute(isolation + ISOLATION + ';')
            if REQLOCKS:
                for (p, _q) in purchaselist:
                    cur.execute(plock, (p,))
        except psycopg2.Error:
            # For details: print('lock request failed: {}'.format(e.pgerror))
            print('.', end='')
            return TNOLOCK

        # Write phase: invoice, invoice items, stock decrements, commit.
        try:
            cur.execute(pinvoice_insert, (custid,))
            row = cur.fetchone()
            invnum = row[0]
            for (part, quan) in purchaselist:
                cur.execute(invitem_insert, (invnum, part, quan))
                cur.execute(stock_update, (quan, part))
            conn.commit()
            return TSUCCESS
        except psycopg2.Error as e:
            print('transaction failed: {}'.format(e.pgerror))
            print('customer={}, invoice={}, purchaselist={}'.format(custid, invnum, purchaselist))
            conn.rollback()
            return TFAILURE
    finally:
        if conn is not None:
            conn.close()

def domtrans(custid, purchaselist):
    """Run one invoice transaction against MySQL.

    Opens a fresh connection and sets the transaction isolation level. If
    REQLOCKS is set, it locks each purchased part row with ``mlock``. It then
    inserts an invoice, reads its number back with last_insert_id(), inserts
    one invitem row per (part, quan) pair, decrements each part's stock and
    commits.

    Returns:
        TSUCCESS on commit.
        TNOLOCK if a MySQLdb error escaped the setup/lock phase (cursor
            creation, SET TRANSACTION or a row-lock request); the error is
            printed.
        TFAILURE if no connection could be made, or if the insert/update
            phase failed (that phase is rolled back and the error printed).

    The connection is closed on every path, so a transaction that failed
    part-way through its lock requests does not keep holding the locks it
    already acquired until the connection object is garbage-collected.
    """
    invnum = -1
    conn = None
    try:
        # Setup / lock phase: any driver error here counts as a lock failure.
        try:
            conn = makeconn()
            if conn is None:
                return TFAILURE
            cur = conn.cursor()
            cur.execute(isolation + ISOLATION + ';')
            if REQLOCKS:
                for (p, _q) in purchaselist:
                    cur.execute(mlock, (p,))
        except MySQLdb.Error as e:
            print('transaction failed: {}'.format(str(e)))
            return TNOLOCK

        # Write phase: invoice, invoice items, stock decrements, commit.
        try:
            cur.execute(minvoice_insert, (custid,))
            cur.execute('select last_insert_id();')
            row = cur.fetchone()
            invnum = row[0]
            for (part, quan) in purchaselist:
                cur.execute(invitem_insert, (invnum, part, quan))
                cur.execute(stock_update, (quan, part))
            conn.commit()
            return TSUCCESS
        except MySQLdb.Error as e:
            print('transaction failed: {}'.format(str(e)))
            print('customer={}, invoice={}, purchaselist={}'.format(custid, invnum, purchaselist))
            conn.rollback()
            return TFAILURE
    finally:
        if conn is not None:
            conn.close()

def rdotrans(tid, custid, purchaselist, maxretries=RETRYCOUNT):
    """Run one transaction, retrying while it fails only in the lock phase.

    Dispatches to doptrans (DB == 'POSTGRES') or domtrans (DB == 'MYSQL')
    up to *maxretries* times.

    Returns a pair (result, count) with result in (TSUCCESS, TFAILURE, TNOLOCK):
      TSUCCESS: count = number of attempts made (1 if the first one committed)
      TFAILURE: count = 0 (also returned at once if DB names no supported database)
      TNOLOCK:  count = number of attempts made (maxretries, when positive)
    """
    if DB == 'POSTGRES':
        attempt = doptrans
    elif DB == 'MYSQL':
        attempt = domtrans
    else:
        # No transaction implementation for this DB: report a plain failure
        # rather than looping with nothing to call.
        return (TFAILURE, 0)

    rtcount = 0
    while rtcount < maxretries:
        rtcount += 1
        rval = attempt(custid, purchaselist)
        if rval == TSUCCESS:
            return (TSUCCESS, rtcount)
        if rval == TFAILURE:
            return (TFAILURE, 0)
    print('transaction in thread {} failed after {} tries'.format(tid, rtcount))
    return (TNOLOCK, rtcount)


def runmany(tid, count):
    """Run *count* random transactions on behalf of thread *tid*.

    Each transaction buys up to 10 distinct parts (see makeplist) for a
    random customer through rdotrans, which retries lock failures.

    Returns (committedcount, fcount, lcount, tries, totalquan):
      committedcount: transactions that committed
      fcount:         transactions that aborted after the lock phase
      lcount:         transactions that kept failing in the lock phase
                      until retries ran out
      tries:          total attempts spent on the committed transactions
      totalquan:      total part quantity of the committed transactions
    """
    committedcount = fcount = lcount = tries = totalquan = 0
    for _ in range(count):
        plist = makeplist(10)   # (part, quan) pairs
        cust = makecust()

        (rt, cnt) = rdotrans(tid, cust, plist)
        if rt == TSUCCESS:
            committedcount += 1
            tries += cnt
            # Only committed transactions actually decremented stock, so
            # only they count toward the parts-sold total.
            totalquan += sum(quan for _part, quan in plist)
        elif rt == TFAILURE:
            fcount += 1
        elif rt == TNOLOCK:
            lcount += 1
    return (committedcount, fcount, lcount, tries, totalquan)


def makeplist(len):
    """Build a random purchase list of at most *len* (part, quantity) pairs.

    Makes *len* draws from 1..MAXPART. A draw that repeats a part already in
    the list is discarded, so every part is distinct but the list may be
    shorter than *len*. Each kept part gets a quantity from 1..MAXQUAN.
    The list is sorted by part number when SORT is set.
    """
    pairs = []
    for _ in range(len):
        candidate = random.randrange(1, MAXPART + 1)
        if any(candidate == chosen for chosen, _q in pairs):
            continue   # duplicate part: drop this draw (no quantity drawn)
        pairs.append((candidate, random.randrange(1, MAXQUAN + 1)))
    if SORT:
        pairs.sort(key=lambda pair: pair[0])
    return pairs

def makecust():
    """Return a random customer id in 1..MAXCUST inclusive."""
    return random.randint(1, MAXCUST)

# Sample purchase list of (part, quantity) pairs; the functions above build
# their own local lists instead of using this one.
plist = [(309, 2), (714, 4), (666, 1), (15, 10)]

# Per-thread lock-failure counts indexed by thread id; filled with zeros in main().
lfcounts = []

class transactionthread(threading.Thread):
    """Worker thread that runs PERTHREAD transactions and records its results.

    When done, it stores its lock-failure count in lfcounts[threadID] and
    adds its part-quantity total to globalpartquan while holding glock.
    """

    def __init__(self, threadID):
        super().__init__(name="thread-" + str(threadID))
        self.threadID = threadID   # also this thread's slot in lfcounts

    def run(self):
        global globalpartquan
        print("Starting " + self.name)
        (comm, fc, nlc, tries, tquan) = runmany(self.threadID, PERTHREAD)
        print('Exiting {}, {}/{} committed, {} tries, {} failures, {} lockfails'.format(self.name, comm, PERTHREAD, tries, fc, nlc))
        # Each thread writes only its own slot, so this needs no lock.
        lfcounts[self.threadID] = nlc
        # += on a shared global is a read-modify-write; serialize it, and let
        # the context manager guarantee the lock is released.
        with glock:
            globalpartquan += tquan
        

def startthreads(count):
    """Start *count* transactionthread workers, wait for all of them, and
    return the elapsed time in whole milliseconds.

    All threads are created before timing begins, so only the run is
    measured. The interval is taken with time.perf_counter(), a monotonic
    clock, so wall-clock adjustments during the run cannot distort it.
    """
    threadlist = [transactionthread(tnum) for tnum in range(count)]
    starttime = time.perf_counter()
    for t in threadlist:
        t.start()
    for t in threadlist:
        t.join()
    return int(round((time.perf_counter() - starttime) * 1000))

def main():
    """Run the benchmark against the database named on the command line.

    Usage: <script> mysql|postgres  (case-insensitive). Starts NUMTHREADS
    worker threads and waits for them. It then prints the elapsed time, the
    total lock failures, the isolation level and the parts-quantity total.
    """
    global DB
    if len(sys.argv) < 2:
        print('supply a database!')
        return
    DB = sys.argv[1].upper()
    if DB not in ('MYSQL', 'POSTGRES'):
        # There is no connection/transaction code for any other name; stop
        # here instead of letting every worker thread fail.
        print('unknown database {}: use MYSQL or POSTGRES'.format(DB))
        return
    print('connecting to {}'.format(DB))
    # One lock-failure slot per worker; thread N writes lfcounts[N].
    lfcounts.extend([0] * NUMTHREADS)
    elapsed = startthreads(NUMTHREADS)
    print('elapsed time: {} ms'.format(elapsed))
    print('total lock failures: {}'.format(sum(lfcounts)))
    print('Isolation level: ' + ISOLATION)
    print('Total parts quantity = {}'.format(globalpartquan))

def getcurrentmillis():
    """Return the current wall-clock time (time.time()) in whole milliseconds."""
    millis = time.time() * 1000
    return int(round(millis))


# Run the benchmark only when executed as a script, not on import.
if __name__ == '__main__':
    main()


# while True:
#     row = cur.fetchone()
#     if row==None: break
#     print(row)


# basic transaction:
# insert a new invoice plus one invitem row per purchased part
# decrement each purchased part's stock

# TIMING INFORMATION
# postgres, repeatable read
# straight run of 1000: 9373 ms, 9782 ms, 9506 ms
# (25,40) WITH NOWAIT:
#  9819 ms, 12 lockfails / RETRYCOUNT=20
# 10192 ms,  6 lockfails
#  9417 ms,  9 lockfails
#  9892 ms, 12 lockfails
# (25,40), postgres, without nowait, read committed:  4380 ms, 4495 ms, 4344 ms
# (25,40), postgres, without nowait, read committed, NO SORTING: 18693 ms, 35503 ms, 27377 ms
# (25,40) without nowait, repeatable read:
# 6596 ms, 0 lockfails
# 6646 ms, 0 lockfails
# 6692
# postgres, nowait, serializable: LOTS of failures
#
# mysql
# straight run of 1000: 7959 ms, 7760 ms, 7641 ms
# (25,40), repeatable read: 3338 ms, 3529 ms 0 lockfails
# read committed: 3320, 3400 ms
# serializable: 3309 ms, 3667 ms
