diff --git a/documentcloud/addons/tests/test_views.py b/documentcloud/addons/tests/test_views.py
index 21a285f4..da6cf4a9 100644
--- a/documentcloud/addons/tests/test_views.py
+++ b/documentcloud/addons/tests/test_views.py
@@ -309,6 +309,44 @@ def test_filter_site_absent_is_noop(self, client):
         assert response.status_code == status.HTTP_200_OK
         assert len(response.json()["results"]) == 3
 
+    def test_filter_domain(self, client):
+        """The domain filter selects runs by the host of the event's site URL"""
+        user = UserFactory()
+        matching_event = AddOnEventFactory(
+            user=user,
+            parameters={
+                "site": "https://www.nifc.gov/fire-information/statistics/wildfires",
+                "selector": "*",
+            },
+        )
+        other_event = AddOnEventFactory(
+            user=user, parameters={"site": "https://www.other.com/path"}
+        )
+        no_site_event = AddOnEventFactory(user=user, parameters={"selector": "*"})
+        matching_run = AddOnRunFactory(user=user, event=matching_event)
+        AddOnRunFactory(user=user, event=other_event)
+        AddOnRunFactory(user=user, event=no_site_event)
+        AddOnRunFactory(user=user, event=None)
+        client.force_authenticate(user=user)
+        # a bare host and a full origin must each return only the nifc.gov run
+        for domain in ("www.nifc.gov", "https://www.nifc.gov"):
+            response = client.get("/api/addon_runs/", {"domain": domain})
+            assert response.status_code == status.HTTP_200_OK
+            uuids = [r["uuid"] for r in response.json()["results"]]
+            assert uuids == [str(matching_run.uuid)], domain
+
+    def test_filter_domain_no_partial_host_match(self, client):
+        """A site whose host merely starts with the domain must not match"""
+        user = UserFactory()
+        event = AddOnEventFactory(
+            user=user, parameters={"site": "https://www.nifc.gov.evil.com/path"}
+        )
+        AddOnRunFactory(user=user, event=event)
+        client.force_authenticate(user=user)
+        response = client.get("/api/addon_runs/", {"domain": "www.nifc.gov"})
+        assert response.status_code == status.HTTP_200_OK
+        assert response.json()["results"] == []
+
 
 @pytest.mark.django_db()
 class TestAddOnEventAPI:
@@ -351,6 +389,34 @@ def test_filter_site_no_match(self, client):
         assert response.status_code == status.HTTP_200_OK
         assert response.json()["results"] == []
 
+    def test_filter_domain(self, client):
+        """Filter events by the host of their parameters.site"""
+        user = UserFactory()
+        matching = AddOnEventFactory(
+            user=user,
+            parameters={
+                "site": "https://www.nifc.gov/fire-information/statistics/wildfires",
+                "selector": "*",
+            },
+        )
+        AddOnEventFactory(user=user, parameters={"site": "https://www.other.com/path"})
+        AddOnEventFactory(user=user, parameters={"selector": "*"})
+        client.force_authenticate(user=user)
+        for domain in ("www.nifc.gov", "https://www.nifc.gov"):
+            response = client.get("/api/addon_events/", {"domain": domain})
+            assert response.status_code == status.HTTP_200_OK
+            ids = [r["id"] for r in response.json()["results"]]
+            assert ids == [matching.pk], domain
+
+    def test_filter_domain_malformed(self, client):
+        """A domain that furl rejects matches nothing instead of erroring"""
+        user = UserFactory()
+        AddOnEventFactory(user=user, parameters={"site": "https://www.nifc.gov/"})
+        client.force_authenticate(user=user)
+        response = client.get("/api/addon_events/", {"domain": "www.nifc.gov:abc"})
+        assert response.status_code == status.HTTP_200_OK
+        assert response.json()["results"] == []
+
     def test_filter_message(self, client):
         """Filter runs by message"""
         user = UserFactory()
diff --git a/documentcloud/addons/views.py b/documentcloud/addons/views.py
index acd03fd0..3e643b4f 100644
--- a/documentcloud/addons/views.py
+++ b/documentcloud/addons/views.py
@@ -30,6 +30,7 @@
 import hmac
 import json
 import logging
+import re
 from collections import defaultdict
 from datetime import timedelta
 from functools import lru_cache
@@ -65,6 +66,27 @@
 logger = logging.getLogger(__name__)
 
 
+def domain_site_regex(value):
+    """Build a regex matching `site` URLs whose host equals the given domain.
+
+    Accepts either a bare host (`www.nifc.gov`) or a full URL
+    (`https://www.nifc.gov/path`) and matches stored site values regardless of
+    scheme, port, path or query string. Returns None if no host can be parsed,
+    including when furl rejects the value as malformed (e.g. a bad port).
+    """
+    value = value.strip()
+    try:
+        # furl only parses the host when an authority is present, so ensure one
+        host = furl(value if "//" in value else "//" + value).host
+    except ValueError:
+        # furl raises on invalid hosts and ports; this value comes straight
+        # from the query string, so report "no host" instead of erroring
+        return None
+    if not host:
+        return None
+    return r"^(https?://)?" + re.escape(host) + r"(:\d+)?($|[/?#])"
+
+
 class AddOnViewSet(viewsets.ModelViewSet):
     serializer_class = AddOnSerializer
     queryset = AddOn.objects.none()
@@ -746,6 +768,14 @@ class Filter(django_filters.FilterSet):
             label="Site",
             help_text="Filter runs by the `site` value in the event's parameters.",
         )
+        domain = django_filters.CharFilter(
+            method="domain_filter",
+            label="Domain",
+            help_text=(
+                "Filter runs by the host of the event's `site` parameter, e.g. "
+                "`www.nifc.gov` or `https://www.nifc.gov`."
+            ),
+        )
         message = django_filters.CharFilter(
             field_name="message",
             lookup_expr="exact",
@@ -753,6 +783,13 @@ class Filter(django_filters.FilterSet):
             help_text="Filter runs by their progress message.",
         )
 
+        def domain_filter(self, queryset, name, value):
+            # pylint: disable=unused-argument
+            pattern = domain_site_regex(value)
+            if pattern is None:
+                return queryset.none()
+            return queryset.filter(event__parameters__site__iregex=pattern)
+
         class Meta:
             model = AddOnRun
             fields = {
@@ -988,6 +1025,21 @@ class Filter(django_filters.FilterSet):
             label="Site",
             help_text="Filter events by the `site` value in their parameters.",
         )
+        domain = django_filters.CharFilter(
+            method="domain_filter",
+            label="Domain",
+            help_text=(
+                "Filter events by the host of their `site` parameter, e.g. "
+                "`www.nifc.gov` or `https://www.nifc.gov`."
+            ),
+        )
+
+        def domain_filter(self, queryset, name, value):
+            # pylint: disable=unused-argument
+            pattern = domain_site_regex(value)
+            if pattern is None:
+                return queryset.none()
+            return queryset.filter(parameters__site__iregex=pattern)
 
         class Meta:
             model = AddOnEvent