diff --git a/ameliabot/owncast.py b/ameliabot/owncast.py index e632131..5449649 100644 --- a/ameliabot/owncast.py +++ b/ameliabot/owncast.py @@ -60,8 +60,8 @@ def get_quote(num): pass except ValueError: pass - q = quote.get(num) - return html.escape(q) + return quote.get(num) + def process_chat(data): diff --git a/ameliabot/quote.py b/ameliabot/quote.py index a103199..812f661 100644 --- a/ameliabot/quote.py +++ b/ameliabot/quote.py @@ -1,4 +1,4 @@ -from datetime import datetime +from datetime import datetime, timezone import random import sqlite3 from ameliabot.logger import logging @@ -6,72 +6,62 @@ from ameliabot.logger import logging class Quote: def __init__(self): - # TODO: Generalise the db connection so that other parts can use it - self.conn = "" self.__init_table() self.num_quotes = self._get_num_quotes() logging.info("Quote subsystem online") def insert(self, owner, submitter, text): + conn = self.__connect() text = text.replace("'", "''") - self.conn.execute(''' + conn.execute(''' INSERT INTO quotes (submitter, text, timestamp) - VALUES ('{}', E'{}', {})'''.format( + VALUES ('{}', '{}', {})'''.format( submitter, text, datetime.now().replace(tzinfo=timezone.utc))) self.num_quotes += 1 logging.debug("Quote number %s inserted" % self.num_quotes) def get(self, arg=None): + conn = self.__connect() if arg: - cur = self.conn.cursor() - ret = "SELECT number, text, timestamp FROM quotes WHERE " + ret = "SELECT id, text, timestamp FROM quotes WHERE " try: - arg = int(arg) - if arg > self.num_quotes: + if int(arg) > self.num_quotes: return "No quote matching that number." - ret += "number = {}".format(arg) + ret += "id = %s" % arg except ValueError: ret += "text like '%{}%'".format(arg.lower()) - cur.execute(ret) - quotes = cur.fetchall() - cur.close() - - return self._format(quotes) + quote = conn.execute(ret) + return self._format(list(quote.fetchone())) else: return self.get(random.randint(0, self.num_quotes)) def _get_num_quotes(self): - cur = self.conn.cursor() + conn = self.__connect() try: - cur.execute("SELECT COUNT(id) FROM quotes") + num = conn.execute("SELECT COUNT(id) FROM quotes") except sqlite3.OperationalError: return 0 - return cur.fetchone()[0] + return int(num.fetchone()[0]) - def _format(self, quotes): - if len(quotes) >= 1: - data = quotes[random.randint(0, len(quotes)-1)] - else: - try: - data = quotes[0] - except IndexError: - return "No quote." + def _format(self, quote): + num, text, timestamp = quote + timestamp = datetime.strptime(timestamp, '%Y-%m-%d %H:%M:%S.%f%z') return "{}. {}, {}".format( - data[0], data[1], datetime.strftime(data[2], '%Y')) + num, text, datetime.strftime(timestamp, '%Y')) - def __connect(self): - self.conn = sqlite3.connect("data/quote.db") + def connect(self): + return sqlite3.connect("data/quote.db") def __init_table(self): try: - self.__connect() + conn = self.__connect() except sqlite3.OperationalError: import os os.makedirs("data") - self.__connect() + self.__init_table() - self.conn.execute(''' + conn.execute(''' CREATE TABLE IF NOT EXISTS quotes ( id INTEGER PRIMARY KEY, submitter TEXT, text TEXT, timestamp TEXT diff --git a/bot.py b/bot.py index a894253..dc90b0c 100644 --- a/bot.py +++ b/bot.py @@ -17,3 +17,4 @@ app = Flask(__name__) def respond(): owncast.parse_webhook() return Response(status=200) +