Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 47 additions & 17 deletions replicate/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class Prediction(Resource):
version: str
"""An identifier for the version of the model used to create the prediction."""

status: Literal["starting", "processing", "succeeded", "failed", "canceled"]
status: Literal["starting", "processing", "succeeded", "failed", "canceled", "aborted"]
"""The status of the prediction."""

input: Optional[Dict[str, Any]]
Expand Down Expand Up @@ -141,7 +141,7 @@ def wait(self) -> None:
Wait for prediction to finish.
"""

while self.status not in ["succeeded", "failed", "canceled"]:
while self.status not in ["succeeded", "failed", "canceled", "aborted"]:
time.sleep(self._client.poll_interval)
self.reload()

Expand All @@ -150,7 +150,7 @@ async def async_wait(self) -> None:
Wait for prediction to finish asynchronously.
"""

while self.status not in ["succeeded", "failed", "canceled"]:
while self.status not in ["succeeded", "failed", "canceled", "aborted"]:
await asyncio.sleep(self._client.poll_interval)
await self.async_reload()

Expand Down Expand Up @@ -249,20 +249,39 @@ def output_iterator(self) -> Iterator[Any]:
Return an iterator of the prediction output.
"""

# TODO: check output is list
previous_output = self.output or []
while self.status not in ["succeeded", "failed", "canceled"]:
output = self.output or []
def _as_list(value: Any) -> list:
"""Coerce output to a list.

``None`` means the model has not produced any output yet; treat it
as an empty list so the polling loop can start cleanly. Any other
non-list value (e.g. a plain string returned by a non-streaming
model) indicates a model whose output schema is not an array — in
that case we raise a ``ValueError`` rather than silently iterating
over the characters of a string or the keys of a dict.
"""
if value is None:
return []
if isinstance(value, list):
return value
raise ValueError(
f"output_iterator requires an array output type, "
f"but the model returned a {type(value).__name__!r}. "
f"Use prediction.output directly for non-array outputs."
)

previous_output = _as_list(self.output)
while self.status not in ["succeeded", "failed", "canceled", "aborted"]:
output = _as_list(self.output)
new_output = output[len(previous_output) :]
yield from new_output
previous_output = output
time.sleep(self._client.poll_interval) # pylint: disable=no-member
self.reload()

if self.status == "failed":
if self.status in ("failed", "aborted"):
raise ModelError(self)

output = self.output or []
output = _as_list(self.output)
new_output = output[len(previous_output) :]
yield from new_output

Expand All @@ -271,24 +290,35 @@ async def async_output_iterator(self) -> AsyncIterator[Any]:
Return an asynchronous iterator of the prediction output.
"""

# TODO: check output is list
previous_output = self.output or []
while self.status not in ["succeeded", "failed", "canceled"]:
output = self.output or []
def _as_list(value: Any) -> list:
"""Coerce output to a list (see sync variant for rationale)."""
if value is None:
return []
if isinstance(value, list):
return value
raise ValueError(
f"async_output_iterator requires an array output type, "
f"but the model returned a {type(value).__name__!r}. "
f"Use prediction.output directly for non-array outputs."
)

previous_output = _as_list(self.output)
while self.status not in ["succeeded", "failed", "canceled", "aborted"]:
output = _as_list(self.output)
new_output = output[len(previous_output) :]
for item in new_output:
yield item
previous_output = output
await asyncio.sleep(self._client.poll_interval) # pylint: disable=no-member
await self.async_reload()

if self.status == "failed":
if self.status in ("failed", "aborted"):
raise ModelError(self)

output = self.output or []
output = _as_list(self.output)
new_output = output[len(previous_output) :]
for output in new_output:
yield output
for item in new_output:
yield item


class Predictions(Namespace):
Expand Down
88 changes: 88 additions & 0 deletions tests/test_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import respx

import replicate
from replicate.prediction import Prediction


@pytest.mark.vcr("predictions-create.yaml")
Expand Down Expand Up @@ -540,3 +541,90 @@ async def test_predictions_stream(async_flag):
# assert progress.current == 5
# assert progress.total == 5
# assert progress.percentage == 1.0


# ---------------------------------------------------------------------------
# Unit tests: output_iterator / async_output_iterator non-list guard
# ---------------------------------------------------------------------------


def _make_prediction(output, status="succeeded"):
"""Build a minimal Prediction with a mock client (no HTTP calls needed)."""
p = Prediction(
id="p1",
model="owner/model",
version="v1",
urls={
"get": "https://api.replicate.com/v1/predictions/p1",
"cancel": "https://api.replicate.com/v1/predictions/p1/cancel",
},
created_at="2024-01-01T00:00:00.000000Z",
source="api",
status=status,
input={"prompt": "hello"},
output=output,
error=None,
logs="",
)
return p


def test_output_iterator_completed_with_list_output_yields_nothing():
"""output_iterator yields only items arriving *after* the iterator starts.

When called on an already-completed prediction, all output tokens were
present at start-time so no new items are yielded. This documents the
intended "streaming" contract: call output_iterator while the prediction
is still running, not after it has completed.
"""
p = _make_prediction(output=["token1", "token2", "token3"], status="succeeded")
# The full list is the "previous_output" baseline, so nothing is yielded.
assert list(p.output_iterator()) == []


def test_output_iterator_none_output_yields_nothing():
"""output_iterator must handle None output gracefully (empty sequence)."""
p = _make_prediction(output=None)
assert list(p.output_iterator()) == []


def test_output_iterator_string_output_raises():
"""output_iterator must raise ValueError when output is a plain string.

Before the fix, ``self.output or []`` returned the string intact, causing
``yield from "hello world"`` to silently iterate over individual characters
instead of raising a clear error.
"""
p = _make_prediction(output="hello world")
with pytest.raises(ValueError, match="array output type"):
list(p.output_iterator())


def test_output_iterator_dict_output_raises():
"""output_iterator must raise ValueError when output is a dict."""
p = _make_prediction(output={"url": "https://example.com/file.png"})
with pytest.raises(ValueError, match="array output type"):
list(p.output_iterator())


@pytest.mark.asyncio
async def test_async_output_iterator_none_output_yields_nothing():
"""async_output_iterator must handle None output gracefully."""
p = _make_prediction(output=None)
results = []
async for item in p.async_output_iterator():
results.append(item)
assert results == []


@pytest.mark.asyncio
async def test_async_output_iterator_string_output_raises():
"""async_output_iterator must raise ValueError for non-list outputs.

Before the fix, ``self.output or []`` returned the string intact,
causing iteration over individual characters silently.
"""
p = _make_prediction(output="some string")
with pytest.raises(ValueError, match="array output type"):
async for _ in p.async_output_iterator():
pass