import sqlite3
from decimal import *
from package.Database.DbInterface import DbInterface


class DbSqlite(DbInterface):
    """SQLite-backed implementation of ``DbInterface``.

    Wraps one ``sqlite3`` connection and offers helpers that build
    INSERT/UPDATE/DELETE/SELECT statements from dicts. Values are always bound
    as ``?`` parameters. Table names, column names, ``sort`` and ``limit`` are
    concatenated into the SQL text verbatim, so they must come from trusted
    code and never from user input.
    """

    def __init__(self, config):
        """Store ``config`` and open the database connection.

        :param config: object providing ``config["File"]`` (database path) and
            ``config.getboolean("JournalModeWal")``. This is presumably a
            ``configparser`` section; confirm against callers.
        """
        self.db = None
        self.config = config
        self.initDatabase()

    def __del__(self):
        """Best-effort close of the connection when the object is collected.

        Tolerates a connection that never opened (``db`` is ``None``). Also
        tolerates one that ``closeConnection`` already closed.
        """
        db = getattr(self, "db", None)
        if db is None:
            return
        try:
            db.close()
        except Exception:
            # A finalizer must never raise. During interpreter shutdown even
            # module globals may already be torn down.
            pass

    def initDatabase(self):
        """Register Decimal conversion, open the connection, and optionally enable WAL.

        Any ``sqlite3.Error`` is printed and swallowed. If it happens before
        the connection is opened, ``self.db`` stays ``None`` and later queries
        will fail.
        """
        try:
            # NOTE: adapters/converters are process-global in sqlite3.
            # Registering bound methods keeps this instance referenced by the
            # sqlite3 module until another instance re-registers.
            sqlite3.register_adapter(Decimal, self.adaptDecimal)
            # Columns whose declared type is "Real" are read back as Decimal.
            # This relies on detect_types=PARSE_DECLTYPES below.
            sqlite3.register_converter("Real", self.convertDecimal)
            # check_same_thread=False allows use from other threads. This
            # class adds no locking of its own.
            self.db = sqlite3.connect(
                self.config["File"],
                check_same_thread=False,
                detect_types=sqlite3.PARSE_DECLTYPES,
            )
            self.db.row_factory = self.dictFactory

            # Set journal mode to WAL
            if self.config.getboolean("JournalModeWal"):
                self.db.execute('pragma journal_mode=wal')
        except sqlite3.Error as e:
            print(e)

    def insert(self, table, data):
        """Insert one row, commit, and return its rowid.

        :param table: table name (inserted verbatim into the SQL).
        :param data: dict mapping column name to value. An empty dict inserts
            a row consisting only of column defaults.
        :returns: ``cursor.lastrowid`` for the inserted row.
        """
        if data:
            columns = list(data)
            query = ("INSERT INTO " + table + " (" + ",".join(columns)
                     + ") VALUES (" + ",".join(["?"] * len(columns)) + ")")
            params = [data[column] for column in columns]
        else:
            # "INSERT INTO t () VALUES ()" is a syntax error in SQLite.
            # DEFAULT VALUES is the supported form for "no explicit columns".
            query = "INSERT INTO " + table + " DEFAULT VALUES"
            params = []

        cursor = self.db.cursor()
        cursor.execute(query, params)
        self.db.commit()
        return cursor.lastrowid

    def insertMany(self, table, data):
        """Insert several rows with a single multi-row INSERT, then commit.

        :param table: table name (inserted verbatim into the SQL).
        :param data: list of dicts. The first dict's keys define the column
            list, and every row must have exactly those keys (in any order).
            An empty list is a no-op.
        :raises sqlite3.ProgrammingError: if a row's keys differ from the first
            row's keys.

        NOTE(review): every value is a separate bound parameter. Very large
        batches can exceed SQLite's host-parameter limit
        (SQLITE_MAX_VARIABLE_NUMBER).
        """
        if not data:
            return

        first = data[0]
        columns = list(first.keys())
        for index, row in enumerate(data):
            # Key views compare as sets, so key order does not matter here.
            if row.keys() != first.keys():
                raise sqlite3.ProgrammingError(
                    "insertMany: row " + str(index) + " has columns "
                    + str(sorted(row.keys())) + ", expected "
                    + str(sorted(columns))
                )

        rowPlaceholder = "(" + ",".join(["?"] * len(columns)) + ")"
        query = ("INSERT INTO " + table + " (" + ",".join(columns) + ") VALUES "
                 + ",".join([rowPlaceholder] * len(data)))
        # Read values in the first row's column order, so a dict whose keys
        # are ordered differently still lines up with the column list.
        params = [row[column] for row in data for column in columns]

        cursor = self.db.cursor()
        cursor.execute(query, params)
        self.db.commit()

    @staticmethod
    def _equalityWhere(where):
        """Build a ``" WHERE a = ? AND b = ?"`` suffix and its parameter list.

        Returns ``("", [])`` when ``where`` is ``None`` or empty.
        """
        if not where:
            return "", []
        columns = list(where)
        clause = " WHERE " + " AND ".join(column + " = ?" for column in columns)
        return clause, [where[column] for column in columns]

    def update(self, table, data, where):
        """Update rows of ``table``, then commit.

        :param table: table name (inserted verbatim into the SQL).
        :param data: dict of column to new value (the SET list).
        :param where: dict of column to value, ANDed together as ``col = ?``.
            ``None`` or an empty dict updates EVERY row in the table.
            NOTE(review): a ``None`` value binds as ``col = NULL``, which
            matches no rows.
        """
        setClause = ",".join(column + " = ?" for column in data)
        params = [data[column] for column in data]

        whereClause, whereParams = self._equalityWhere(where)
        params.extend(whereParams)

        cursor = self.db.cursor()
        cursor.execute("UPDATE " + table + " SET " + setClause + whereClause, params)
        self.db.commit()

    def delete(self, table, where=None):
        """Delete rows of ``table``, then commit.

        :param table: table name (inserted verbatim into the SQL).
        :param where: dict of column to value, ANDed together as ``col = ?``.
            ``None`` or an empty dict deletes EVERY row in the table.
            NOTE(review): a ``None`` value binds as ``col = NULL``, which
            matches no rows.
        """
        whereClause, params = self._equalityWhere(where)

        cursor = self.db.cursor()
        cursor.execute("DELETE FROM " + table + whereClause, params)
        self.db.commit()

    def select(self, table, fields, where=None, sort=None, limit=None):
        """Run a SELECT and return all matching rows.

        :param table: table name (inserted verbatim into the SQL).
        :param fields: iterable of column expressions, joined with ``", "``.
        :param where: optional dict of column to condition, ANDed together.
            A list value becomes ``col IN (?, ...)``. ``None`` becomes
            ``col IS NULL``. Any other value becomes ``col = ?``.
        :param sort: optional ``(column, direction)`` pair for ORDER BY.
        :param limit: optional row limit, inserted as ``str(limit)``.
        :returns: ``cursor.fetchall()``. With the ``dictFactory`` row factory
            set in ``initDatabase``, this is a list of dicts.
        """
        conditions = []
        params = []
        for column, value in (where or {}).items():
            if isinstance(value, list):
                conditions.append(column + " IN (" + ",".join(["?"] * len(value)) + ")")
                params.extend(value)
            elif value is None:
                # "= NULL" is never true in SQL, so NULL needs IS NULL.
                conditions.append(column + " IS NULL")
            else:
                conditions.append(column + " = ?")
                params.append(value)

        query = "SELECT " + ", ".join(fields) + " FROM " + table
        if conditions:
            query += " WHERE " + " AND ".join(conditions)
        if sort is not None:
            query += " ORDER BY " + sort[0] + " " + sort[1]
        if limit is not None:
            query += " LIMIT " + str(limit)

        cursor = self.db.cursor()
        cursor.execute(query, params)
        return cursor.fetchall()

    def dictFactory(self, cursor, row):
        """Row factory that maps each column name in ``cursor.description`` to its value."""
        return {column[0]: row[index] for index, column in enumerate(cursor.description)}

    def selectPlain(self, query, where=None):
        """Execute a raw SQL statement and return ``fetchall()``. Does not commit.

        ``%s`` placeholders are rewritten to SQLite's ``?``, presumably so that
        queries written for a ``%s``-style driver can be reused.
        NOTE(review): the rewrite is purely textual, so a literal ``%s`` inside
        a quoted string in ``query`` is replaced as well.

        :param query: SQL text.
        :param where: bound parameters: a sequence for ``?`` or a mapping for
            named placeholders. Defaults to none.
        """
        if where is None:
            where = {}
        query = query.replace("%s", "?")

        cursor = self.db.cursor()
        cursor.execute(query, where)
        return cursor.fetchall()

    def adaptDecimal(self, value):
        """sqlite3 adapter: store a ``Decimal`` as its string form."""
        return str(value)

    def convertDecimal(self, value):
        """sqlite3 converter: parse the raw column bytes back into a ``Decimal``."""
        return Decimal(value.decode('ascii'))

    def closeConnection(self):
        """Close the underlying connection. This does not commit.

        It is a no-op when the connection never opened (``db`` is ``None``).
        """
        if self.db is not None:
            self.db.close()
