Sitemap generator
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

132 lines
4.6 KiB

4 years ago
import inspect
import json
from collections import OrderedDict

from aiohttp.http_exceptions import HttpBadRequest
from aiohttp.web_exceptions import HTTPBadRequest, HTTPMethodNotAllowed
from aiohttp.web_request import Request
from aiohttp.web_response import Response
# NOTE(review): aiohttp defines UrlDispatcher in aiohttp.web_urldispatcher; in some
# aiohttp versions web_routedef only re-exports it under TYPE_CHECKING — confirm.
from aiohttp.web_routedef import UrlDispatcher

from db import session, get_or_create
from models import Domain, DomainGroup
  11. DEFAULT_METHODS = ('GET', 'POST', 'PUT', 'DELETE')
  12. class RestEndpoint:
  13. def __init__(self):
  14. self.methods = {}
  15. for method_name in DEFAULT_METHODS:
  16. method = getattr(self, method_name.lower(), None)
  17. if method:
  18. self.register_method(method_name, method)
  19. def register_method(self, method_name, method):
  20. self.methods[method_name.upper()] = method
  21. async def dispatch(self, request: Request):
  22. method = self.methods.get(request.method.upper())
  23. if not method:
  24. raise HTTPMethodNotAllowed('', DEFAULT_METHODS)
  25. wanted_args = list(inspect.signature(method).parameters.keys())
  26. available_args = request.match_info.copy()
  27. available_args.update({'request': request})
  28. unsatisfied_args = set(wanted_args) - set(available_args.keys())
  29. if unsatisfied_args:
  30. # Expected match info that doesn't exist
  31. raise HttpBadRequest('')
  32. return await method(**{arg_name: available_args[arg_name] for arg_name in wanted_args})
  33. class DomainEndpoint(RestEndpoint):
  34. def __init__(self, resource):
  35. super().__init__()
  36. self.resource = resource
  37. async def get(self) -> Response:
  38. data = []
  39. domains = session.query(Domain).all()
  40. for instance in self.resource.collection.values():
  41. data.append(self.resource.render(instance))
  42. return Response(status=200, body=self.resource.encode({
  43. 'domains': [
  44. {
  45. 'id': domain.id, 'title': domain.domain,
  46. 'groups': [{'id': group.id, 'name': group.name} for group in domain.groups]
  47. } for domain in session.query(Domain).all()
  48. ]}), content_type='application/json')
  49. async def post(self, request):
  50. data = await request.json()
  51. domain, _created = get_or_create(Domain, domain=data['domain'])
  52. return Response(status=200, body=self.resource.encode(
  53. {'id': domain.id, 'domain': domain.domain, 'created': _created},
  54. ), content_type='application/json')
  55. class DomainGroupEndpoint(RestEndpoint):
  56. def __init__(self, resource):
  57. super().__init__()
  58. self.resource = resource
  59. async def get(self) -> Response:
  60. return Response(status=200, body=self.resource.encode({
  61. 'domain_groups': [
  62. {
  63. 'id': group.id, 'name': group.name,
  64. 'domains': [{'id': domain.id, 'name': domain.domain} for domain in group.domains]
  65. } for group in session.query(DomainGroup).all()
  66. ]}), content_type='application/json')
  67. async def post(self, request):
  68. data = await request.json()
  69. group, _created = get_or_create(session, DomainGroup, name=data['name'])
  70. domains = []
  71. if data.get('domains'):
  72. for domain_el in data.get('domains'):
  73. domain, _domain_created = get_or_create(session, Domain, domain=domain_el)
  74. domains.append({'id': domain.id, 'domain': domain_el, 'created': _domain_created})
  75. return Response(
  76. status=200,
  77. body=self.resource.encode({
  78. 'id': group.id,
  79. 'name': group.name,
  80. 'domains': domains,
  81. 'created': _created
  82. }), content_type='application/json')
  83. class RestResource:
  84. def __init__(self, notes, factory, collection, properties, id_field):
  85. self.notes = notes
  86. self.factory = factory
  87. self.collection = collection
  88. self.properties = properties
  89. self.id_field = id_field
  90. self.domain_endpoint = DomainEndpoint(self)
  91. self.domain_groups_endpoint = DomainGroupEndpoint(self)
  92. def register(self, router: UrlDispatcher):
  93. router.add_route('*', '/{domains}'.format(notes=self.notes), self.domain_endpoint.dispatch)
  94. router.add_route('*', '/{domain_groups}'.format(notes=self.notes), self.domain_groups_endpoint.dispatch)
  95. def render(self, instance):
  96. return OrderedDict((notes, getattr(instance, notes)) for notes in self.properties)
  97. @staticmethod
  98. def encode(data):
  99. return json.dumps(data, indent=4).encode('utf-8')
  100. def render_and_encode(self, instance):
  101. return self.encode(self.render(instance))