import mysql.connector as mysql
from .DbInterface import DbInterface


class DbMysql(DbInterface):
    """MySQL implementation of ``DbInterface`` on top of ``mysql.connector``.

    A single connection is kept in ``self.db``. The constructor opens it, and
    ``getCursor`` re-opens it whenever a ping fails. Write methods commit
    immediately after executing their statement.

    Security note: values are always sent to the driver as bound parameters.
    However, table names, column names/expressions, ``sort`` and ``limit`` in
    ``select``, and the query text of ``selectPlain`` are concatenated into the
    SQL verbatim. They must come from trusted code, never from user input.
    """

    def __init__(self, host, user, password, database, port):
        """Remember the connection settings and open the connection.

        The settings are kept on the instance so ``getCursor`` can reconnect
        later. Errors raised by ``mysql.connector.connect`` propagate unchanged.
        """
        self.db = None
        self.host = host
        self.user = user
        self.password = password
        self.port = port
        self.databaseName = database
        self.initDatabase(host, user, password, database, port)

    def initDatabase(self, host, user, password, database, port):
        """Open a new connection with the given settings and store it in ``self.db``."""
        self.db = mysql.connect(
            host=host,
            user=user,
            passwd=password,
            database=database,
            port=port
        )

    def insert(self, table, data):
        """Insert a single row.

        :param table: table name (inserted into the SQL verbatim).
        :param data: dict mapping column name -> value. Values are bound as
            named parameters.
        :return: ``lastrowid`` of the cursor that executed the INSERT.
        """
        assignments = ",".join(column + " = %(" + column + ")s" for column in data)
        query = "INSERT INTO " + table + " SET " + assignments
        return self._executeAndCommit(query, data)

    def insertMany(self, table, data):
        """Insert several rows with one multi-row ``INSERT ... VALUES`` statement.

        :param table: table name (inserted into the SQL verbatim).
        :param data: list of dicts. The column list is taken from the first
            dict, and every row must have exactly the same keys (in any order).
            An empty list is a no-op.
        :raises mysql.connector.ProgrammingError: if a row's keys differ from
            the first row's keys.
        """
        if not data:
            # Nothing to insert; data[0] below would raise IndexError.
            return

        columns = list(data[0].keys())
        columnSet = set(columns)
        for row in data:
            if set(row.keys()) != columnSet:
                raise mysql.ProgrammingError(
                    "insertMany: every row must have the same columns as the first row"
                )

        rowPlaceholder = "(" + ",".join(["%s"] * len(columns)) + ")"
        query = ("INSERT INTO " + table + " (" + ",".join(columns) + ") VALUES "
                 + ",".join([rowPlaceholder] * len(data)))
        # Read every row in the shared column order, so values line up with
        # the column list even when rows were built with different key orders.
        values = [row[column] for row in data for column in columns]

        self._executeAndCommit(query, values)

    def update(self, table, data, where=None):
        """Update rows of ``table``.

        :param table: table name (inserted into the SQL verbatim).
        :param data: dict column -> new value, used for the ``SET`` clause.
        :param where: dict column -> value, combined with ``AND`` using ``=``.
            ``None`` or an empty dict means no WHERE clause, so every row is
            updated.
        """
        if where is None:
            where = {}

        assignments = [column + " = %s" for column in data]
        conditions = [column + " = %s" for column in where]
        # Positional parameters keep SET and WHERE values independent. With
        # one shared named-parameter dict, a column present in both `data` and
        # `where` had its new value overwritten by the filter value.
        values = list(data.values()) + list(where.values())

        query = "UPDATE " + table + " SET " + ",".join(assignments)
        if conditions:
            query += " WHERE " + " AND ".join(conditions)

        self._executeAndCommit(query, values)

    def delete(self, table, where=None):
        """Delete rows of ``table``.

        :param table: table name (inserted into the SQL verbatim).
        :param where: dict column -> value, combined with ``AND`` using ``=``.
            ``None`` or an empty dict means no WHERE clause, so every row is
            deleted.
        """
        if where is None:
            where = {}

        query = "DELETE FROM " + table
        if where:
            query += " WHERE " + " AND ".join(
                column + " = %(" + column + ")s" for column in where
            )

        self._executeAndCommit(query, where)

    def select(self, table, fields, where=None, sort=None, limit=None):
        """Run a simple SELECT and return all rows as dicts.

        :param table: table name (inserted into the SQL verbatim).
        :param fields: iterable of column expressions, joined with ``", "``.
        :param where: dict column -> condition, combined with ``AND``:
            a ``list`` value becomes ``IN (...)`` (an empty list matches no
            rows), ``None`` becomes ``IS NULL``, and anything else becomes
            ``= value``. ``None`` or an empty dict means no WHERE clause.
        :param sort: optional ``(column, direction)`` pair, e.g.
            ``("id", "DESC")``, inserted verbatim.
        :param limit: optional value appended verbatim after ``LIMIT``.
        :return: list of rows, each a dict keyed by column name.
        """
        if where is None:
            where = {}

        query = "SELECT " + ", ".join(fields) + " FROM " + table

        conditions = []
        values = []
        for column, value in where.items():
            if isinstance(value, list):
                if value:
                    placeholders = ",".join(["%s"] * len(value))
                    conditions.append(column + " IN (" + placeholders + ")")
                    values.extend(value)
                else:
                    # "col IN ()" is a syntax error; an empty list matches nothing.
                    conditions.append("1 = 0")
            elif value is None:
                conditions.append(column + " IS NULL")
            else:
                conditions.append(column + " = %s")
                values.append(value)

        if conditions:
            query += " WHERE " + " AND ".join(conditions)

        if sort is not None:
            query += " ORDER BY " + sort[0] + " " + sort[1]

        if limit is not None:
            query += " LIMIT " + str(limit)

        return self._fetchAllAndCommit(query, values)

    def selectPlain(self, query, where=None):
        """Run a raw SELECT and return all rows as dicts.

        Every ``?`` in ``query`` is rewritten to the driver's ``%s``
        placeholder, including any ``?`` inside string literals.

        :param query: SQL text (trusted; executed as-is after the rewrite).
        :param where: parameters for the placeholders. Defaults to an empty
            dict, as before.
        :return: list of rows, each a dict keyed by column name.
        """
        if where is None:
            where = {}
        query = query.replace("?", "%s")
        return self._fetchAllAndCommit(query, where)

    def getCursor(self, dictionary=False):
        """Return a new cursor, reconnecting first if the connection is unusable.

        :param dictionary: passed through to ``cursor()``; when true, rows are
            returned as dicts.
        """
        try:
            # Liveness check only. ping() uses its attempts/delay arguments
            # only when reconnect=True; reconnecting is done below instead.
            self.db.ping(reconnect=False)
        except mysql.Error:
            self.initDatabase(self.host, self.user, self.password, self.databaseName, self.port)

        return self.db.cursor(dictionary=dictionary)

    def closeConnection(self):
        """Close the underlying connection."""
        self.db.close()

    def _executeAndCommit(self, query, params):
        """Execute a write statement, commit it, and return the cursor's ``lastrowid``.

        The cursor is closed even if execution or commit raises.
        """
        cursor = self.getCursor()
        try:
            cursor.execute(query, params)
            self.db.commit()
            # Evaluated before the ``finally`` clause closes the cursor.
            return cursor.lastrowid
        finally:
            cursor.close()

    def _fetchAllAndCommit(self, query, params):
        """Execute a query on a dictionary cursor and return every row.

        The commit ends the current transaction after reading.
        NOTE(review): presumably this keeps later reads from seeing a stale
        transaction snapshot; confirm intent. The cursor is always closed.
        """
        cursor = self.getCursor(dictionary=True)
        try:
            cursor.execute(query, params)
            result = cursor.fetchall()
            self.db.commit()
            return result
        finally:
            cursor.close()
