import inspect
import json
from collections import OrderedDict

from aiohttp.http_exceptions import HttpBadRequest
from aiohttp.web_exceptions import HTTPMethodNotAllowed
from aiohttp.web_request import Request
from aiohttp.web_response import Response
from aiohttp.web_routedef import UrlDispatcher

from db import session, get_or_create
from models import Domain, DomainGroup

# HTTP verbs that RestEndpoint looks up as lower-case handler methods.
DEFAULT_METHODS = ('GET', 'POST', 'PUT', 'DELETE')


class RestEndpoint:
    """Base endpoint that routes a request to the handler named after its HTTP verb.

    On construction, every verb in ``DEFAULT_METHODS`` whose lower-case name
    exists as an attribute of the instance (``get``, ``post``, ...) is
    registered as a handler.
    """

    def __init__(self):
        # Maps upper-case HTTP verb -> bound handler coroutine.
        self.methods = {}

        for method_name in DEFAULT_METHODS:
            method = getattr(self, method_name.lower(), None)
            if method:
                self.register_method(method_name, method)

    def register_method(self, method_name, method):
        """Register ``method`` as the handler for the HTTP verb ``method_name``."""
        self.methods[method_name.upper()] = method

    async def dispatch(self, request: Request):
        """Call the handler registered for ``request.method`` and return its result.

        The handler's parameter names are read from its signature and filled
        from the route's match info, plus ``request`` for the request itself.

        Raises:
            HTTPMethodNotAllowed: no handler is registered for the verb.
            HttpBadRequest: the handler names a parameter that is neither
                ``request`` nor a key of the route's match info.
        """
        method = self.methods.get(request.method.upper())
        if not method:
            # Report the verb that was actually used and only the verbs this
            # endpoint really implements.
            raise HTTPMethodNotAllowed(request.method, sorted(self.methods))

        wanted_args = list(inspect.signature(method).parameters)

        available_args = dict(request.match_info)
        available_args['request'] = request

        if any(arg_name not in available_args for arg_name in wanted_args):
            # Expected match info that doesn't exist.
            # NOTE(review): HttpBadRequest comes from aiohttp.http_exceptions, not
            # aiohttp.web_exceptions, so it is not a web response exception and
            # the server may answer 500 rather than 400 -- confirm whether
            # web_exceptions.HTTPBadRequest was intended. Type kept unchanged so
            # existing callers that catch it keep working.
            raise HttpBadRequest('')

        return await method(**{arg_name: available_args[arg_name] for arg_name in wanted_args})


class DomainEndpoint(RestEndpoint):
    """Endpoint that lists ``Domain`` rows (GET) and gets-or-creates one (POST)."""

    def __init__(self, resource):
        super().__init__()
        # Owning resource; only its ``encode`` method is used here.
        self.resource = resource

    async def get(self) -> Response:
        """Return every ``Domain`` with its groups as a JSON document."""
        domains = session.query(Domain).all()
        payload = {
            'domains': [
                {
                    'id': domain.id,
                    'title': domain.domain,
                    'groups': [
                        {'id': group.id, 'name': group.name}
                        for group in domain.groups
                    ],
                }
                for domain in domains
            ],
        }
        return Response(
            status=200,
            body=self.resource.encode(payload),
            content_type='application/json',
        )

    async def post(self, request):
        """Get or create the ``Domain`` named by the ``domain`` key of the JSON body.

        The response reports the row's id, its domain and whether it was created.
        """
        data = await request.json()
        # Pass the session explicitly, matching the other get_or_create calls
        # in this module.
        domain, created = get_or_create(session, Domain, domain=data['domain'])

        payload = {'id': domain.id, 'domain': domain.domain, 'created': created}
        return Response(
            status=200,
            body=self.resource.encode(payload),
            content_type='application/json',
        )


class DomainGroupEndpoint(RestEndpoint):
    """Endpoint that lists ``DomainGroup`` rows (GET) and gets-or-creates one (POST)."""

    def __init__(self, resource):
        super().__init__()
        # Owning resource; only its ``encode`` method is used here.
        self.resource = resource

    async def get(self) -> Response:
        """Return every ``DomainGroup`` with its domains as a JSON document."""
        payload = {
            'domain_groups': [
                {
                    'id': group.id,
                    'name': group.name,
                    'domains': [
                        {'id': domain.id, 'name': domain.domain}
                        for domain in group.domains
                    ],
                }
                for group in session.query(DomainGroup).all()
            ],
        }
        return Response(
            status=200,
            body=self.resource.encode(payload),
            content_type='application/json',
        )

    async def post(self, request):
        """Get or create the group named by ``name`` and each entry of ``domains``.

        ``domains`` is optional in the JSON body; a missing or empty value
        yields an empty ``domains`` list in the response.
        """
        data = await request.json()
        group, created = get_or_create(session, DomainGroup, name=data['name'])

        # NOTE(review): the domains are fetched/created but never attached to
        # ``group`` in this method -- confirm whether that link is made elsewhere.
        domains = []
        for domain_el in data.get('domains') or []:
            domain, domain_created = get_or_create(session, Domain, domain=domain_el)
            domains.append({'id': domain.id, 'domain': domain_el, 'created': domain_created})

        payload = {
            'id': group.id,
            'name': group.name,
            'domains': domains,
            'created': created,
        }
        return Response(
            status=200,
            body=self.resource.encode(payload),
            content_type='application/json',
        )


class RestResource:
    """Bundles the domain and domain-group endpoints with JSON rendering helpers."""

    def __init__(self, notes, factory, collection, properties, id_field):
        self.notes = notes
        self.factory = factory
        self.collection = collection
        # Attribute names copied from an instance by ``render``.
        self.properties = properties
        self.id_field = id_field

        self.domain_endpoint = DomainEndpoint(self)
        self.domain_groups_endpoint = DomainGroupEndpoint(self)

    def register(self, router: UrlDispatcher):
        """Add the ``/domains`` and ``/domain_groups`` routes to ``router``.

        Both routes accept any HTTP method; the endpoint's ``dispatch`` picks
        the handler.
        """
        # The previous '/{domains}'.format(notes=...) templates raised KeyError
        # because the placeholder name never matched the supplied keyword.
        router.add_route('*', '/domains', self.domain_endpoint.dispatch)
        router.add_route('*', '/domain_groups', self.domain_groups_endpoint.dispatch)

    def render(self, instance):
        """Return an ``OrderedDict`` of ``instance``'s attributes listed in ``properties``."""
        return OrderedDict((name, getattr(instance, name)) for name in self.properties)

    @staticmethod
    def encode(data):
        """Serialize ``data`` to indented JSON and return it as UTF-8 bytes."""
        return json.dumps(data, indent=4).encode('utf-8')

    def render_and_encode(self, instance):
        """Render ``instance`` and return the result as UTF-8 encoded JSON."""
        return self.encode(self.render(instance))