More cleanup and remove usused functions

This commit is contained in:
Raphaël Vinot 2016-08-29 23:23:29 +02:00
parent f5b2c32d03
commit 5cb010045f
2 changed files with 39 additions and 73 deletions

View file

@ -15,27 +15,35 @@
# Copyright (c) 2013 Alexandre Dulaunoy - a@foo.be # Copyright (c) 2013 Alexandre Dulaunoy - a@foo.be
import tornado.escape import tornado.escape
from tornado.ioloop import IOLoop
import tornado.web import tornado.web
import tornado.process import tornado.process
from tornado.ioloop import IOLoop
from tornado.concurrent import run_on_executor from tornado.concurrent import run_on_executor
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import ThreadPoolExecutor
import argparse import argparse
import sys import sys
import signal import signal
from ipaddress import ip_address
from .query import Query from .query import QueryRecords
def handle_signal(sig, frame): def handle_signal(sig, frame):
IOLoop.instance().add_callback(IOLoop.instance().stop) IOLoop.instance().add_callback(IOLoop.instance().stop)
def is_ip(q):
try:
ip_address(q)
return True
except:
return False
class InfoHandler(tornado.web.RequestHandler): class InfoHandler(tornado.web.RequestHandler):
def get(self): def get(self):
response = {'version': 'git', response = {'version': 'git', 'software': 'pdns-qof-server'}
'software': 'pdns-qof-server'}
self.write(response) self.write(response)
@ -48,13 +56,11 @@ class QueryHandler(tornado.web.RequestHandler):
@run_on_executor @run_on_executor
def run_request(self, q): def run_request(self, q):
to_return = [] if is_ip(q):
if query.is_ip(q): q = query.getAssociatedRecords(q)
for x in query.getAssociatedRecords(q):
to_return.append(query.getRecord(x))
else: else:
to_return.append(query.getRecord(t=q.strip())) q = [q]
return to_return return [query.getRecord(x) for x in q]
@tornado.gen.coroutine @tornado.gen.coroutine
def get(self, q): def get(self, q):
@ -69,36 +75,6 @@ class QueryHandler(tornado.web.RequestHandler):
self.finish() self.finish()
class FullQueryHandler(tornado.web.RequestHandler):
# Default value in Python 3.5
# https://docs.python.org/3/library/concurrent.futures.html#concurrent.futures.ThreadPoolExecutor
nb_threads = tornado.process.cpu_count() * 5
executor = ThreadPoolExecutor(nb_threads)
@run_on_executor
def run_request(self, q):
to_return = []
if query.is_ip(q):
for x in query.getAssociatedRecords(q):
to_return.append(query.getRecord(x))
else:
for x in query.getAssociatedRecords(q):
to_return.append(query.getRecord(t=x.strip()))
return to_return
@tornado.gen.coroutine
def get(self, q):
print("fquery: " + q)
try:
responses = yield self.run_request(q)
for r in responses:
self.write(r)
except Exception as e:
print('Something went wrong with {}:\n{}'.format(q, e))
finally:
self.finish()
def main(): def main():
global query global query
signal.signal(signal.SIGINT, handle_signal) signal.signal(signal.SIGINT, handle_signal)
@ -111,6 +87,7 @@ def main():
argParser.add_argument('-rl', default='localhost', help='redis-server listen address (default localhost)') argParser.add_argument('-rl', default='localhost', help='redis-server listen address (default localhost)')
argParser.add_argument('-rd', default=0, help='redis-server database (default 0)') argParser.add_argument('-rd', default=0, help='redis-server database (default 0)')
args = argParser.parse_args() args = argParser.parse_args()
origin = args.o origin = args.o
port = args.p port = args.p
listen = args.l listen = args.l
@ -118,13 +95,10 @@ def main():
redis_listen = args.rl redis_listen = args.rl
redis_db = args.rd redis_db = args.rd
query = Query(redis_listen, redis_port, redis_db, origin) query = QueryRecords(redis_listen, redis_port, redis_db, origin)
application = tornado.web.Application([ application = tornado.web.Application([(r"/query/(.*)", QueryHandler),
(r"/query/(.*)", QueryHandler), (r"/info", InfoHandler)])
(r"/fquery/(.*)", FullQueryHandler),
(r"/info", InfoHandler)
])
application.listen(port, address=listen) application.listen(port, address=listen)
IOLoop.instance().start() IOLoop.instance().start()
@ -138,7 +112,7 @@ elif __name__ == "test":
qq = ["foo.be", "8.8.8.8"] qq = ["foo.be", "8.8.8.8"]
for q in qq: for q in qq:
if query.is_ip(q): if is_ip(q):
for x in query.getAssociatedRecords(q): for x in query.getAssociatedRecords(q):
print(query.getRecord(x)) print(query.getRecord(x))
else: else:

View file

@ -3,10 +3,9 @@
import json import json
import redis import redis
from ipaddress import ip_address
class Query(object): class QueryRecords(object):
def __init__(self, redis_listen, redis_port, redis_db, origin): def __init__(self, redis_listen, redis_port, redis_db, origin):
self.rrset = [ self.rrset = [
@ -91,32 +90,32 @@ class Query(object):
{"Reference": "[RFC4431]", "Type": "DLV", "Value": "32769", "Meaning": "DNSSEC Lookaside Validation", "Template": "", "Registration Date": ""}, {"Reference": "[RFC4431]", "Type": "DLV", "Value": "32769", "Meaning": "DNSSEC Lookaside Validation", "Template": "", "Registration Date": ""},
{"Reference": "", "Type": "Reserved", "Value": "65535", "Meaning": "", "Template": "", "Registration Date": ""}] {"Reference": "", "Type": "Reserved", "Value": "65535", "Meaning": "", "Template": "", "Registration Date": ""}]
self.rrset_supported = ['1', '2', '5', '15', '28', '33'] self.rrset_supported = ['1', '2', '5', '15', '28', '33']
self.r = redis.StrictRedis(host=redis_listen, port=redis_port, db=redis_db) self.r = redis.StrictRedis(host=redis_listen, port=redis_port, db=redis_db, decode_responses=True)
self.origin = origin self.origin = origin
def getFirstSeen(self, t1=None, t2=None): def _getFirstSeen(self, t1=None, t2=None):
if t1 is None or t2 is None: if t1 is None or t2 is None:
return False return False
rec = "s:" + t1.lower() + ":" + t2.lower() rec = "s:" + t1.lower() + ":" + t2.lower()
recget = self.r.get(rec) recget = self.r.get(rec)
if recget is not None: if recget is not None:
return int(recget.decode(encoding='UTF-8')) return int(recget)
def getLastSeen(self, t1=None, t2=None): def _getLastSeen(self, t1=None, t2=None):
if t1 is None or t2 is None: if t1 is None or t2 is None:
return False return False
rec = "l:" + t1.lower() + ":" + t2.lower() rec = "l:" + t1.lower() + ":" + t2.lower()
recget = self.r.get(rec) recget = self.r.get(rec)
if recget is not None: if recget is not None:
return int(recget.decode(encoding='UTF-8')) return int(recget)
def getCount(self, t1=None, t2=None): def _getCount(self, t1=None, t2=None):
if t1 is None or t2 is None: if t1 is None or t2 is None:
return False return False
rec = "o:" + t1.lower() + ":" + t2.lower() rec = "o:" + t1.lower() + ":" + t2.lower()
recget = self.r.get(rec) recget = self.r.get(rec)
if recget is not None: if recget is not None:
return int(recget.decode(encoding='UTF-8')) return int(recget)
def getRecord(self, t=None): def getRecord(self, t=None):
if t is None: if t is None:
@ -129,19 +128,19 @@ class Query(object):
if rs: if rs:
for v in rs: for v in rs:
rrval = {} rrval = {}
rdata = v.decode(encoding='UTF-8').strip() rdata = v.strip()
rrval['time_first'] = self.getFirstSeen(t1=t, t2=rdata) rrval['time_first'] = self._getFirstSeen(t1=t, t2=rdata)
rrval['time_last'] = self.getLastSeen(t1=t, t2=rdata) rrval['time_last'] = self._getLastSeen(t1=t, t2=rdata)
if rrval['time_first'] is None: if rrval['time_first'] is None:
break break
rrval['count'] = self.getCount(t1=t, t2=rdata) rrval['count'] = self._getCount(t1=t, t2=rdata)
rrval['rrtype'] = rr['Type'] rrval['rrtype'] = rr['Type']
rrval['rrname'] = t rrval['rrname'] = t
rrval['rdata'] = rdata rrval['rdata'] = rdata
if self.origin: if self.origin:
rrval['origin'] = self.origin rrval['origin'] = self.origin
rrfound.append(rrval) rrfound.append(rrval)
return self.JsonQOF(rrfound) return self._JsonQOF(rrfound)
def getAssociatedRecords(self, rdata=None): def getAssociatedRecords(self, rdata=None):
if rdata is None: if rdata is None:
@ -150,30 +149,23 @@ class Query(object):
records = [] records = []
if self.r.smembers(rec): if self.r.smembers(rec):
for v in self.r.smembers(rec): for v in self.r.smembers(rec):
records.append(v.decode(encoding='UTF-8')) records.append(v)
return records return records
def RemDuplicate(self, d=None): def _RemDuplicate(self, d=None):
if d is None: if d is None:
return False return False
outd = [dict(t) for t in set([tuple(o.items()) for o in d])] outd = [dict(t) for t in set([tuple(o.items()) for o in d])]
return outd return outd
def JsonQOF(self, rrfound=None, RemoveDuplicate=True): def _JsonQOF(self, rrfound=None, RemoveDuplicate=True):
if rrfound is None: if rrfound is None:
return False return False
rrqof = "" rrqof = ""
if RemoveDuplicate: if RemoveDuplicate:
rrfound = self.RemDuplicate(d=rrfound) rrfound = self._RemDuplicate(d=rrfound)
for rr in rrfound: for rr in rrfound:
rrqof = rrqof + json.dumps(rr) + "\n" rrqof = rrqof + json.dumps(rr) + "\n"
return rrqof return rrqof
def is_ip(self, q):
try:
ip_address(q)
return True
except:
return False