diff --git a/replicate/prediction.py b/replicate/prediction.py
index b4ff047..70f57e4 100644
--- a/replicate/prediction.py
+++ b/replicate/prediction.py
@@ -55,7 +55,7 @@ class Prediction(Resource):
     version: str
     """An identifier for the version of the model used to create the prediction."""
 
-    status: Literal["starting", "processing", "succeeded", "failed", "canceled"]
+    status: Literal["starting", "processing", "succeeded", "failed", "canceled", "aborted"]
     """The status of the prediction."""
 
     input: Optional[Dict[str, Any]]
@@ -141,7 +141,7 @@ def wait(self) -> None:
         Wait for prediction to finish.
         """
 
-        while self.status not in ["succeeded", "failed", "canceled"]:
+        while self.status not in ["succeeded", "failed", "canceled", "aborted"]:
             time.sleep(self._client.poll_interval)
             self.reload()
 
@@ -150,7 +150,7 @@ async def async_wait(self) -> None:
         Wait for prediction to finish asynchronously.
         """
 
-        while self.status not in ["succeeded", "failed", "canceled"]:
+        while self.status not in ["succeeded", "failed", "canceled", "aborted"]:
             await asyncio.sleep(self._client.poll_interval)
             await self.async_reload()
 
@@ -249,20 +249,39 @@ def output_iterator(self) -> Iterator[Any]:
         Return an iterator of the prediction output.
         """
 
-        # TODO: check output is list
-        previous_output = self.output or []
-        while self.status not in ["succeeded", "failed", "canceled"]:
-            output = self.output or []
+        def _as_list(value: Any) -> list:
+            """Coerce output to a list.
+
+            ``None`` means the model has not produced any output yet; treat it
+            as an empty list so the polling loop can start cleanly. Any other
+            non-list value (e.g. a plain string returned by a non-streaming
+            model) indicates a model whose output schema is not an array — in
+            that case we raise a ``ValueError`` rather than silently iterating
+            over the characters of a string or the keys of a dict.
+            """
+            if value is None:
+                return []
+            if isinstance(value, list):
+                return value
+            raise ValueError(
+                "output_iterator requires an array output type, "
+                f"but the model returned a {type(value).__name__!r}. "
+                "Use prediction.output directly for non-array outputs."
+            )
+
+        previous_output = _as_list(self.output)
+        while self.status not in ["succeeded", "failed", "canceled", "aborted"]:
+            output = _as_list(self.output)
             new_output = output[len(previous_output) :]
             yield from new_output
             previous_output = output
             time.sleep(self._client.poll_interval)  # pylint: disable=no-member
             self.reload()
 
-        if self.status == "failed":
+        if self.status in ("failed", "aborted"):
             raise ModelError(self)
 
-        output = self.output or []
+        output = _as_list(self.output)
         new_output = output[len(previous_output) :]
         yield from new_output
 
@@ -271,10 +290,21 @@ async def async_output_iterator(self) -> AsyncIterator[Any]:
         Return an asynchronous iterator of the prediction output.
         """
 
-        # TODO: check output is list
-        previous_output = self.output or []
-        while self.status not in ["succeeded", "failed", "canceled"]:
-            output = self.output or []
+        def _as_list(value: Any) -> list:
+            """Coerce output to a list (see sync variant for rationale)."""
+            if value is None:
+                return []
+            if isinstance(value, list):
+                return value
+            raise ValueError(
+                "async_output_iterator requires an array output type, "
+                f"but the model returned a {type(value).__name__!r}. "
+                "Use prediction.output directly for non-array outputs."
+            )
+
+        previous_output = _as_list(self.output)
+        while self.status not in ["succeeded", "failed", "canceled", "aborted"]:
+            output = _as_list(self.output)
             new_output = output[len(previous_output) :]
             for item in new_output:
                 yield item
@@ -282,13 +312,13 @@ async def async_output_iterator(self) -> AsyncIterator[Any]:
             await asyncio.sleep(self._client.poll_interval)  # pylint: disable=no-member
             await self.async_reload()
 
-        if self.status == "failed":
+        if self.status in ("failed", "aborted"):
             raise ModelError(self)
 
-        output = self.output or []
+        output = _as_list(self.output)
         new_output = output[len(previous_output) :]
-        for output in new_output:
-            yield output
+        for item in new_output:
+            yield item
 
 
 class Predictions(Namespace):
diff --git a/tests/test_prediction.py b/tests/test_prediction.py
index b3c110e..e5205b8 100644
--- a/tests/test_prediction.py
+++ b/tests/test_prediction.py
@@ -3,6 +3,8 @@
 import respx
 
 import replicate
+from replicate.exceptions import ModelError
+from replicate.prediction import Prediction
 
 
 @pytest.mark.vcr("predictions-create.yaml")
@@ -540,3 +542,111 @@ async def test_predictions_stream(async_flag):
     # assert progress.current == 5
     # assert progress.total == 5
     # assert progress.percentage == 1.0
+
+
+# ---------------------------------------------------------------------------
+# Unit tests: output_iterator / async_output_iterator non-list guard
+# ---------------------------------------------------------------------------
+
+
+def _make_prediction(output, status="succeeded"):
+    """Build a minimal Prediction without attaching a client.
+
+    Only terminal statuses are safe here: the iterators return or raise
+    before polling, so ``_client`` is never touched and no HTTP is needed.
+    """
+    return Prediction(
+        id="p1",
+        model="owner/model",
+        version="v1",
+        urls={
+            "get": "https://api.replicate.com/v1/predictions/p1",
+            "cancel": "https://api.replicate.com/v1/predictions/p1/cancel",
+        },
+        created_at="2024-01-01T00:00:00.000000Z",
+        source="api",
+        status=status,
+        input={"prompt": "hello"},
+        output=output,
+        error=None,
+        logs="",
+    )
+
+
+def test_output_iterator_completed_with_list_output_yields_nothing():
+    """output_iterator yields only items arriving *after* the iterator starts.
+
+    When called on an already-completed prediction, all output tokens were
+    present at start-time so no new items are yielded. This documents the
+    intended "streaming" contract: call output_iterator while the prediction
+    is still running, not after it has completed.
+    """
+    p = _make_prediction(output=["token1", "token2", "token3"], status="succeeded")
+    # The full list is the "previous_output" baseline, so nothing is yielded.
+    assert list(p.output_iterator()) == []
+
+
+def test_output_iterator_none_output_yields_nothing():
+    """output_iterator must handle None output gracefully (empty sequence)."""
+    p = _make_prediction(output=None)
+    assert list(p.output_iterator()) == []
+
+
+def test_output_iterator_string_output_raises():
+    """output_iterator must raise ValueError when output is a plain string.
+
+    Before the fix, ``self.output or []`` returned the string intact, causing
+    ``yield from "hello world"`` to silently iterate over individual characters
+    instead of raising a clear error.
+    """
+    p = _make_prediction(output="hello world")
+    with pytest.raises(ValueError, match="array output type"):
+        list(p.output_iterator())
+
+
+def test_output_iterator_dict_output_raises():
+    """output_iterator must raise ValueError when output is a dict."""
+    p = _make_prediction(output={"url": "https://example.com/file.png"})
+    with pytest.raises(ValueError, match="array output type"):
+        list(p.output_iterator())
+
+
+@pytest.mark.parametrize("status", ["failed", "aborted"])
+def test_output_iterator_unsuccessful_status_raises(status):
+    """output_iterator must raise ModelError for failed and aborted runs."""
+    p = _make_prediction(output=None, status=status)
+    with pytest.raises(ModelError):
+        list(p.output_iterator())
+
+
+@pytest.mark.asyncio
+async def test_async_output_iterator_none_output_yields_nothing():
+    """async_output_iterator must handle None output gracefully."""
+    p = _make_prediction(output=None)
+    results = []
+    async for item in p.async_output_iterator():
+        results.append(item)
+    assert results == []
+
+
+@pytest.mark.asyncio
+async def test_async_output_iterator_string_output_raises():
+    """async_output_iterator must raise ValueError for non-list outputs.
+
+    Before the fix, ``self.output or []`` returned the string intact,
+    causing iteration over individual characters silently.
+    """
+    p = _make_prediction(output="some string")
+    with pytest.raises(ValueError, match="array output type"):
+        async for _ in p.async_output_iterator():
+            pass
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("status", ["failed", "aborted"])
+async def test_async_output_iterator_unsuccessful_status_raises(status):
+    """async_output_iterator must raise ModelError on failed/aborted runs."""
+    p = _make_prediction(output=None, status=status)
+    with pytest.raises(ModelError):
+        async for _ in p.async_output_iterator():
+            pass