diff --git a/qos_server/__init__.py b/qos_server/__init__.py index 9ab78dc..7a36884 100644 --- a/qos_server/__init__.py +++ b/qos_server/__init__.py @@ -15,27 +15,35 @@ # Copyright (c) 2013 Alexandre Dulaunoy - a@foo.be import tornado.escape -from tornado.ioloop import IOLoop import tornado.web import tornado.process +from tornado.ioloop import IOLoop from tornado.concurrent import run_on_executor -from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import ThreadPoolExecutor import argparse import sys import signal +from ipaddress import ip_address -from .query import Query +from .query import QueryRecords def handle_signal(sig, frame): IOLoop.instance().add_callback(IOLoop.instance().stop) +def is_ip(q): + try: + ip_address(q) + return True + except: + return False + + class InfoHandler(tornado.web.RequestHandler): def get(self): - response = {'version': 'git', - 'software': 'pdns-qof-server'} + response = {'version': 'git', 'software': 'pdns-qof-server'} self.write(response) @@ -48,13 +56,11 @@ class QueryHandler(tornado.web.RequestHandler): @run_on_executor def run_request(self, q): - to_return = [] - if query.is_ip(q): - for x in query.getAssociatedRecords(q): - to_return.append(query.getRecord(x)) + if is_ip(q): + q = query.getAssociatedRecords(q) else: - to_return.append(query.getRecord(t=q.strip())) - return to_return + q = [q] + return [query.getRecord(x) for x in q] @tornado.gen.coroutine def get(self, q): @@ -69,36 +75,6 @@ class QueryHandler(tornado.web.RequestHandler): self.finish() -class FullQueryHandler(tornado.web.RequestHandler): - # Default value in Python 3.5 - # https://docs.python.org/3/library/concurrent.futures.html#concurrent.futures.ThreadPoolExecutor - nb_threads = tornado.process.cpu_count() * 5 - executor = ThreadPoolExecutor(nb_threads) - - @run_on_executor - def run_request(self, q): - to_return = [] - if query.is_ip(q): - for x in query.getAssociatedRecords(q): - to_return.append(query.getRecord(x)) - else: - for x in query.getAssociatedRecords(q): - to_return.append(query.getRecord(t=x.strip())) - return to_return - - @tornado.gen.coroutine - def get(self, q): - print("fquery: " + q) - try: - responses = yield self.run_request(q) - for r in responses: - self.write(r) - except Exception as e: - print('Something went wrong with {}:\n{}'.format(q, e)) - finally: - self.finish() - - def main(): global query signal.signal(signal.SIGINT, handle_signal) @@ -111,6 +87,7 @@ def main(): argParser.add_argument('-rl', default='localhost', help='redis-server listen address (default localhost)') argParser.add_argument('-rd', default=0, help='redis-server database (default 0)') args = argParser.parse_args() + origin = args.o port = args.p listen = args.l @@ -118,13 +95,10 @@ def main(): redis_listen = args.rl redis_db = args.rd - query = Query(redis_listen, redis_port, redis_db, origin) + query = QueryRecords(redis_listen, redis_port, redis_db, origin) - application = tornado.web.Application([ - (r"/query/(.*)", QueryHandler), - (r"/fquery/(.*)", FullQueryHandler), - (r"/info", InfoHandler) - ]) + application = tornado.web.Application([(r"/query/(.*)", QueryHandler), + (r"/info", InfoHandler)]) application.listen(port, address=listen) IOLoop.instance().start() @@ -138,7 +112,7 @@ elif __name__ == "test": qq = ["foo.be", "8.8.8.8"] for q in qq: - if query.is_ip(q): + if is_ip(q): for x in query.getAssociatedRecords(q): print(query.getRecord(x)) else: diff --git a/qos_server/query.py b/qos_server/query.py index 0f671e2..b4d371f 100644 --- a/qos_server/query.py +++ b/qos_server/query.py @@ -3,10 +3,9 @@ import json import redis -from ipaddress import ip_address -class Query(object): +class QueryRecords(object): def __init__(self, redis_listen, redis_port, redis_db, origin): self.rrset = [ @@ -91,32 +90,32 @@ class Query(object): {"Reference": "[RFC4431]", "Type": "DLV", "Value": "32769", "Meaning": "DNSSEC Lookaside Validation", "Template": "", "Registration Date": ""}, {"Reference": "", "Type": "Reserved", "Value": "65535", "Meaning": "", "Template": "", "Registration Date": ""}] self.rrset_supported = ['1', '2', '5', '15', '28', '33'] - self.r = redis.StrictRedis(host=redis_listen, port=redis_port, db=redis_db) + self.r = redis.StrictRedis(host=redis_listen, port=redis_port, db=redis_db, decode_responses=True) self.origin = origin - def getFirstSeen(self, t1=None, t2=None): + def _getFirstSeen(self, t1=None, t2=None): if t1 is None or t2 is None: return False rec = "s:" + t1.lower() + ":" + t2.lower() recget = self.r.get(rec) if recget is not None: - return int(recget.decode(encoding='UTF-8')) + return int(recget) - def getLastSeen(self, t1=None, t2=None): + def _getLastSeen(self, t1=None, t2=None): if t1 is None or t2 is None: return False rec = "l:" + t1.lower() + ":" + t2.lower() recget = self.r.get(rec) if recget is not None: - return int(recget.decode(encoding='UTF-8')) + return int(recget) - def getCount(self, t1=None, t2=None): + def _getCount(self, t1=None, t2=None): if t1 is None or t2 is None: return False rec = "o:" + t1.lower() + ":" + t2.lower() recget = self.r.get(rec) if recget is not None: - return int(recget.decode(encoding='UTF-8')) + return int(recget) def getRecord(self, t=None): if t is None: @@ -129,19 +128,19 @@ class Query(object): if rs: for v in rs: rrval = {} - rdata = v.decode(encoding='UTF-8').strip() - rrval['time_first'] = self.getFirstSeen(t1=t, t2=rdata) - rrval['time_last'] = self.getLastSeen(t1=t, t2=rdata) + rdata = v.strip() + rrval['time_first'] = self._getFirstSeen(t1=t, t2=rdata) + rrval['time_last'] = self._getLastSeen(t1=t, t2=rdata) if rrval['time_first'] is None: break - rrval['count'] = self.getCount(t1=t, t2=rdata) + rrval['count'] = self._getCount(t1=t, t2=rdata) rrval['rrtype'] = rr['Type'] rrval['rrname'] = t rrval['rdata'] = rdata if self.origin: rrval['origin'] = self.origin rrfound.append(rrval) - return self.JsonQOF(rrfound) + return self._JsonQOF(rrfound) def getAssociatedRecords(self, rdata=None): if rdata is None: @@ -150,30 +149,23 @@ class Query(object): records = [] if self.r.smembers(rec): for v in self.r.smembers(rec): - records.append(v.decode(encoding='UTF-8')) + records.append(v) return records - def RemDuplicate(self, d=None): + def _RemDuplicate(self, d=None): if d is None: return False outd = [dict(t) for t in set([tuple(o.items()) for o in d])] return outd - def JsonQOF(self, rrfound=None, RemoveDuplicate=True): + def _JsonQOF(self, rrfound=None, RemoveDuplicate=True): if rrfound is None: return False rrqof = "" if RemoveDuplicate: - rrfound = self.RemDuplicate(d=rrfound) + rrfound = self._RemDuplicate(d=rrfound) for rr in rrfound: rrqof = rrqof + json.dumps(rr) + "\n" return rrqof - - def is_ip(self, q): - try: - ip_address(q) - return True - except: - return False