diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index f96d996..790e404 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -23,8 +23,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install pylint black ratelimit requests urllib3 python-squarelet python_dateutil pytest - + pip install -e ".[dev]" - name: Run Pylint on muckrock directory run: | pylint --disable=missing-function-docstring,too-many-instance-attributes,too-many-arguments,too-many-positional-arguments src/muckrock @@ -46,7 +45,6 @@ jobs: steps: - name: Check out code uses: actions/checkout@v2 - - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v4 with: @@ -54,7 +52,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install -r requirements.txt + pip install -e ".[dev]" - name: Run tests run: | pytest src/muckrock/tests.py diff --git a/docs/changelog.rst b/docs/changelog.rst index c8220c5..b93a522 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -1,5 +1,8 @@ Changelog --------- +2.3.0 +~~~~~ +* Adds sane burst rate limits to endpoints. 2.2.0 ~~~~~ diff --git a/docs/conf.py b/docs/conf.py index 94201aa..2698b06 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -55,9 +55,9 @@ # built documents. # # The short X.Y version. -version = "2.2" +version = "2.3" # The full version, including alpha/beta/rc tags. -release = "2.2.0" +release = "2.3.0" # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. 
diff --git a/pyproject.toml b/pyproject.toml index 4ad49cb..bbd4ce4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,14 +3,14 @@ requires = ["hatchling"] build-backend = "hatchling.build" [project] -name = "python-muckrock" -version = "2.2.0" +name = "python-muckrock" +version = "2.3.0" authors = [ { name="duckduckgrayduck", email="sanjin@muckrock.com" }, ] description = "A simple Python wrapper for the MuckRock API v2" readme = "README.md" -requires-python = ">=3.7,<=3.12" +requires-python = ">=3.7,<3.13" classifiers = [ "Programming Language :: Python :: 3", "License :: OSI Approved :: MIT License", @@ -20,13 +20,22 @@ dependencies = [ "requests", "ratelimit", "urllib3", - "python-squarelet" + "python-squarelet", + "token-bucket", + "python-dateutil", +] + +[project.optional-dependencies] +dev = [ + "pylint", + "black", + "pytest", ] [project.urls] "Homepage" = "https://github.com/MuckRock/python-muckrock" "Bug Tracker" = "https://github.com/MuckRock/python-muckrock/issues" -"Documentation" ="https://python-muckrock.readthedocs.io/en/latest/" +"Documentation" = "https://python-muckrock.readthedocs.io/en/latest/" [tool.hatch.build.targets.wheel] -packages = ["src/muckrock"] +packages = ["src/muckrock"] \ No newline at end of file diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index e1456c8..0000000 --- a/requirements.txt +++ /dev/null @@ -1,4 +0,0 @@ -pytest==7.4.2 -python_dateutil==2.9.0.post0 -urllib3==1.26.5 -python-squarelet \ No newline at end of file diff --git a/src/muckrock/client.py b/src/muckrock/client.py index 62f62a6..be502af 100644 --- a/src/muckrock/client.py +++ b/src/muckrock/client.py @@ -1,9 +1,13 @@ -""" Provides the client wrapper with Squarelet """ +""" +Provides the client wrapper with Squarelet +""" # Standard Library import logging +import time # Third Party +import token_bucket from squarelet import SquareletClient from .agencies import AgencyClient @@ -12,19 +16,44 @@ from .jurisdictions import 
def request(self, method, url, raise_error=True, **kwargs):
    """
    Apply the matching per-endpoint rate limit, then delegate to the parent.

    The first entry of ``self._endpoint_limiters`` whose pattern occurs as a
    substring of ``url`` decides which token bucket is charged. If that
    bucket has no token available, a warning is logged once and the call
    blocks, re-trying every 0.1 seconds, until a token can be consumed.
    URLs that match no pattern are not throttled.

    ``method``, ``url``, ``raise_error`` and any extra keyword arguments are
    forwarded unchanged to the parent class' ``request``; its return value
    is returned as-is.
    """
    # Entries are (pattern, limiter, bucket_key); only the first hit counts.
    matched = next(
        (entry for entry in self._endpoint_limiters if entry[0] in url),
        None,
    )
    if matched is not None:
        endpoint, bucket, key = matched
        if not bucket.consume(key):
            logger.warning("Rate limit reached for %s, throttling...", endpoint)
            # Poll until the bucket has refilled enough for one token.
            while not bucket.consume(key):
                time.sleep(0.1)
    return super().request(method, url, raise_error=raise_error, **kwargs)
def test_retrieve_agency(client):
    """Fetch agency 1 and check that the returned object reports that ID."""
    fetched = client.agencies.retrieve(1)
    assert fetched.id == 1


def test_list_communications(client):
    """The communications listing should be truthy (non-empty)."""
    listing = client.communications.list()
    assert listing, "Expected a non-empty list of communications."


def test_retrieve_communication(client):
    """Fetch communication 1 and check that the returned object reports that ID."""
    fetched = client.communications.retrieve(1)
    assert fetched.id == 1


def test_list_files(client):
    """The files listing should be truthy (non-empty)."""
    listing = client.files.list()
    assert listing, "Expected a non-empty list of files."


def test_retrieve_file(client):
    """Fetch file 1 and check that the returned object reports that ID."""
    fetched = client.files.retrieve(1)
    assert fetched.id == 1


def test_list_jurisdictions(client):
    """The jurisdictions listing should be truthy (non-empty)."""
    listing = client.jurisdictions.list()
    assert listing, "Expected a non-empty list of jurisdictions."
def test_retrieve_jurisdiction(client):
    """Retrieving jurisdiction 1 should return an object with that ID."""
    jurisdiction = client.jurisdictions.retrieve(1)
    assert jurisdiction.id == 1


def test_list_organizations(client):
    """Listing organizations should yield at least one result."""
    organizations = client.organizations.list()
    assert organizations.results, "Expected a non-empty list of organizations."


def test_retrieve_organization(client):
    """Retrieve the first listed organization by ID and check the ID round-trips."""
    organizations = client.organizations.list()
    # Guard the index below: an empty listing should fail with a readable
    # assertion rather than an IndexError.
    assert organizations.results, "Expected a non-empty list of organizations."
    org_id = organizations.results[0].id
    organization = client.organizations.retrieve(org_id)
    assert organization.id == org_id


def test_list_requests(client):
    """The requests listing should be truthy (non-empty)."""
    requests = client.requests.list()
    assert requests, "Expected a non-empty list of requests."
def test_retrieve_request(client):
    """Retrieving request 17 should return an object with that ID."""
    request = client.requests.retrieve(17)
    assert request.id == 17


def test_list_users(client):
    """Listing users should yield at least one result."""
    users = client.users.list()
    assert users.results, "Expected a non-empty list of users."


def test_retrieve_user(client):
    """Retrieve the first listed user by ID and check the ID round-trips."""
    users = client.users.list()
    # Guard the index below: an empty listing should fail with a readable
    # assertion rather than an IndexError.
    assert users.results, "Expected a non-empty list of users."
    user_id = users.results[0].id
    user = client.users.retrieve(user_id)
    assert user.id == user_id


def test_list_projects(client):
    """The projects listing should be truthy (non-empty)."""
    projects = client.projects.list()
    assert projects, "Expected a non-empty list of projects."
def _limiter_for(client, pattern):
    """Return the limiter the client holds for the endpoint ``pattern``."""
    # pylint:disable=protected-access
    return next(lim for p, lim, _ in client._endpoint_limiters if p == pattern)


def test_retrieve_project(client):
    """Retrieving project 10 should return an object with that ID."""
    project = client.projects.retrieve(10)
    assert project.id == 10


def test_create_request(client):
    """The value returned by create() should contain "test-foia-request"."""
    new_request = client.requests.create(
        title="Test FOIA Request",
        requested_docs="This is a test FOIA request.",
        organization=1,
        agencies=[248],
    )
    assert "test-foia-request" in new_request


def test_rate_limit_tokens_consumed(client):
    """Tokens should be consumed when making real API calls"""
    agencies_limiter = _limiter_for(client, "agencies")
    # Consume all but one token. Each priming consume must succeed; if one
    # fails the bucket did not start full and the final check proves nothing.
    for _ in range(99):
        assert agencies_limiter.consume("agencies")
    # make a real API call - should consume the last token
    client.agencies.list()
    # bucket should now be empty
    assert not agencies_limiter.consume("agencies")


def test_rate_limit_throttles_after_burst(client):
    """Client should throttle after burst capacity is exhausted"""
    agencies_limiter = _limiter_for(client, "agencies")
    # exhaust the bucket, verifying that every token was really there
    for _ in range(100):
        assert agencies_limiter.consume("agencies")

    # Use a monotonic clock: wall-clock time can jump (e.g. NTP adjustment)
    # and corrupt an interval measurement.
    start = time.monotonic()
    client.agencies.list()  # this should throttle
    elapsed = time.monotonic() - start
    # should have waited for at least one token to refill (1/rate = 4 seconds)
    assert elapsed >= 4


def test_rate_limit_tokens_consumed_users(client):
    """Tokens should be consumed when making real API calls to users endpoint"""
    users_limiter = _limiter_for(client, "users")
    # Consume all but one token; each priming consume must succeed.
    for _ in range(4):
        assert users_limiter.consume("users")
    # make a real API call - should consume the last token
    client.users.list()
    # bucket should now be empty
    assert not users_limiter.consume("users")