-
Notifications
You must be signed in to change notification settings - Fork 11
add eval capes to sdk #460
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
4c6083e
36f6b4a
3caaf8d
13a91b2
cce066e
aced4aa
866ac71
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,7 +1,8 @@ | ||
| """Nucleus Python SDK. """ | ||
| """Nucleus Python SDK.""" | ||
|
|
||
| __all__ = [ | ||
| "AsyncJob", | ||
| "AllowedLabelMatch", | ||
| "EmbeddingsExportJob", | ||
| "BoxAnnotation", | ||
| "DeduplicationJob", | ||
|
|
@@ -17,6 +18,12 @@ | |
| "DatasetInfo", | ||
| "DatasetItem", | ||
| "DatasetItemRetrievalError", | ||
| "EvaluationV2", | ||
| "EvaluationV2Charts", | ||
| "EvaluationV2ExamplesPage", | ||
| "EvaluationV2FilterArgs", | ||
| "EvaluationV2MatchExample", | ||
| "EvaluationV2Status", | ||
| "Frame", | ||
| "Keypoint", | ||
| "KeypointsAnnotation", | ||
|
|
@@ -129,6 +136,12 @@ | |
| ) | ||
| from .data_transfer_object.dataset_details import DatasetDetails | ||
| from .data_transfer_object.dataset_info import DatasetInfo | ||
| from .data_transfer_object.evaluation_v2 import ( | ||
| EvaluationV2Charts, | ||
| EvaluationV2ExamplesPage, | ||
| EvaluationV2FilterArgs, | ||
| EvaluationV2MatchExample, | ||
| ) | ||
| from .data_transfer_object.job_status import JobInfoRequestPayload | ||
| from .dataset import Dataset | ||
| from .dataset_item import DatasetItem | ||
|
|
@@ -146,6 +159,7 @@ | |
| NotFoundError, | ||
| NucleusAPIError, | ||
| ) | ||
| from .evaluation_v2 import AllowedLabelMatch, EvaluationV2, EvaluationV2Status | ||
| from .job import CustomerJobTypes | ||
| from .model import Model | ||
| from .model_run import ModelRun | ||
|
|
@@ -875,6 +889,61 @@ def commit_model_run( | |
| payload = {} | ||
| return self.make_request(payload, f"modelRun/{model_run_id}/commit") | ||
|
|
||
| def create_evaluation_v2( | ||
| self, | ||
| model_run_id: str, | ||
| *, | ||
| name: Optional[str] = None, | ||
| allowed_label_matches: Optional[List[AllowedLabelMatch]] = None, | ||
| allowed_label_matches_id: Optional[str] = None, | ||
| ) -> EvaluationV2: | ||
| """Create an Evaluation V2 job for a model run. | ||
|
|
||
| Starts a Temporal workflow that fills ``evaluation_match_v2``. Use | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: i don't think infra specific information like |
||
| :meth:`EvaluationV2.wait_for_completion` then :meth:`EvaluationV2.charts` | ||
| or :meth:`EvaluationV2.examples` for results. | ||
|
|
||
| Parameters: | ||
| model_run_id: Nucleus model run id (``run_*``). | ||
| name: Optional human-readable name. | ||
| allowed_label_matches: Optional explicit allowed label pairs; omit to use | ||
| the model run's default configuration. | ||
| allowed_label_matches_id: Optional existing allowed-label-matches config id. | ||
|
|
||
| Returns: | ||
| :class:`EvaluationV2` loaded via ``GET /nucleus/evaluationsV2/:id``. | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: is it necessary to specify the get endpoint that the return is using? Ideally we'd want to point users to use the sdk. For this case, probably over information for how the information is retrieved |
||
| """ | ||
| payload: Dict[str, Any] = {} | ||
| if name is not None: | ||
| payload["name"] = name | ||
| if allowed_label_matches is not None: | ||
| payload["allowed_label_matches"] = [ | ||
| m.to_api_dict() for m in allowed_label_matches | ||
| ] | ||
| if allowed_label_matches_id is not None: | ||
| payload["allowed_label_matches_id"] = allowed_label_matches_id | ||
| result = self.make_request( | ||
| payload, f"modelRun/{model_run_id}/evaluationsV2" | ||
| ) | ||
| eval_id = result.get("evaluation_id") | ||
| if not eval_id: | ||
| raise RuntimeError( | ||
| f"Unexpected create evaluation V2 response: {result}" | ||
| ) | ||
| return self.get_evaluation_v2(str(eval_id)) | ||
|
|
||
| def get_evaluation_v2(self, evaluation_id: str) -> EvaluationV2: | ||
| """Fetch a single Evaluation V2 row.""" | ||
| data = self.get(f"evaluationsV2/{evaluation_id}") | ||
| return EvaluationV2.from_json(data, self) | ||
|
|
||
| def list_evaluations_v2(self, model_run_id: str) -> List[EvaluationV2]: | ||
| """List Evaluation V2 rows for a model run (newest first).""" | ||
| rows = self.get(f"modelRun/{model_run_id}/evaluationsV2") | ||
| if not isinstance(rows, list): | ||
| return [] | ||
|
Comment on lines
+943
to
+944
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we throw an error for cases like this instead of empty list? That way callers can distinguish between "no evals" and "broken/wrong response" |
||
| return [EvaluationV2.from_json(r, self) for r in rows] | ||
|
|
||
| @deprecated(msg="Prefer calling Dataset.info() directly.") | ||
| def dataset_info(self, dataset_id: str): | ||
| dataset = self.get_dataset(dataset_id) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,145 @@ | ||
| """Pydantic models for Nucleus Evaluations V2 REST payloads.""" | ||
|
|
||
| from typing import Any, Dict, List, Literal, Optional | ||
|
|
||
| from nucleus.pydantic_base import DictCompatibleModel | ||
|
|
||
|
|
||
| class RangeNum(DictCompatibleModel): | ||
| min: Optional[float] = None | ||
| max: Optional[float] = None | ||
|
|
||
|
|
||
| class MetadataPredicate(DictCompatibleModel): | ||
| key: str | ||
| op: Literal["EQ", "IN", "GT", "LT"] | ||
| value: Optional[Any] = None | ||
|
|
||
|
|
||
| class EvaluationV2FilterArgs(DictCompatibleModel): | ||
| """Filter object for charts/examples calls (mirrors server evaluation_v2 SQL filters).""" | ||
|
|
||
| confidence_range: Optional[RangeNum] = None | ||
| iou_range: Optional[RangeNum] = None | ||
| pred_labels: Optional[List[str]] = None | ||
| gt_labels: Optional[List[str]] = None | ||
| item_metadata: Optional[List[MetadataPredicate]] = None | ||
| prediction_metadata: Optional[List[MetadataPredicate]] = None | ||
| label_equality: Optional[Literal["EQ", "NEQ"]] = None | ||
| has_ground_truth: Optional[bool] = None | ||
| tide_background: Optional[bool] = None | ||
|
|
||
| def to_api_filters(self) -> Dict[str, Any]: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i think only top-level keys are converted to camelCase here, which is fine if that's the intention, just wanted to flag. if there are nested fields that are not singlewords in the future, they'll be sent in snake_case i think |
||
| """Serialize to camelCase keys expected by the GraphQL / REST layer.""" | ||
| d = self.dict(exclude_none=True) | ||
| # pydantic v1 uses snake_case fields; server expects camelCase in JSON filters | ||
| out: Dict[str, Any] = {} | ||
| if "confidence_range" in d: | ||
| out["confidenceRange"] = d["confidence_range"] | ||
| if "iou_range" in d: | ||
| out["iouRange"] = d["iou_range"] | ||
| if "pred_labels" in d: | ||
| out["predLabels"] = d["pred_labels"] | ||
| if "gt_labels" in d: | ||
| out["gtLabels"] = d["gt_labels"] | ||
| if "item_metadata" in d: | ||
| out["itemMetadata"] = d["item_metadata"] | ||
| if "prediction_metadata" in d: | ||
| out["predictionMetadata"] = d["prediction_metadata"] | ||
| if "label_equality" in d: | ||
| out["labelEquality"] = d["label_equality"] | ||
| if "has_ground_truth" in d: | ||
| out["hasGroundTruth"] = d["has_ground_truth"] | ||
| if "tide_background" in d: | ||
| out["tideBackground"] = d["tide_background"] | ||
|
Comment on lines
+37
to
+54
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: this is a bit verbose and long, consider using a map |
||
| return out | ||
|
|
||
|
|
||
| class MapSummary(DictCompatibleModel): | ||
| mapAt50: Optional[float] = None | ||
| mapAt75: Optional[float] = None | ||
| mapAt5095: Optional[float] = None | ||
|
|
||
|
|
||
| class PerClassAp(DictCompatibleModel): | ||
| classLabel: str | ||
| ap: float | ||
|
|
||
|
|
||
| class ConfusionEntry(DictCompatibleModel): | ||
| gtLabel: str | ||
| predLabel: str | ||
| count: int | ||
|
|
||
|
|
||
| class ScoreHistogramBucket(DictCompatibleModel): | ||
| bucketMin: float | ||
| bucketMax: float | ||
| count: int | ||
|
|
||
|
|
||
| class TotalCounts(DictCompatibleModel): | ||
| tp: int | ||
| fp: int | ||
| fn: int | ||
| predsWithConfidence: int | ||
|
|
||
|
|
||
| class ApBySize(DictCompatibleModel): | ||
| small: Optional[float] = None | ||
| medium: Optional[float] = None | ||
| large: Optional[float] = None | ||
|
|
||
|
|
||
| class PrCurvePoint(DictCompatibleModel): | ||
| classLabel: str | ||
| recall: float | ||
| precision: float | ||
|
|
||
|
|
||
| class TideAttribution(DictCompatibleModel): | ||
| truePositive: int | ||
| localization: int | ||
| classification: int | ||
| both: int | ||
| duplicate: int | ||
| background: int | ||
| missed: int | ||
|
|
||
|
|
||
| class EvaluationV2Charts(DictCompatibleModel): | ||
| mapSummary: MapSummary | ||
| perClassAp: List[PerClassAp] | ||
| confusionMatrix: List[ConfusionEntry] | ||
| scoreHistogram: List[ScoreHistogramBucket] | ||
| computedIouRanges: List[float] | ||
| totalCounts: TotalCounts | ||
| apBySize: ApBySize | ||
| prCurve: List[PrCurvePoint] | ||
| tideAttribution: TideAttribution | ||
|
|
||
|
|
||
| class EvaluationV2MatchExample(DictCompatibleModel): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: not sure if intentional, but the example keys use snake_case, while the char models above all use camelCase |
||
| id: str | ||
| evaluation_id: str | ||
| dataset_item_id: str | ||
| model_prediction_id: Optional[str] = None | ||
| ground_truth_annotation_id: Optional[str] = None | ||
| pred_canonical_label: Optional[str] = None | ||
| gt_canonical_label: Optional[str] = None | ||
| pred_raw_label: Optional[str] = None | ||
| gt_raw_label: Optional[str] = None | ||
| iou: Optional[float] = None | ||
| confidence: Optional[float] = None | ||
| true_positive: bool | ||
| match_type: str | ||
| gt_area: Optional[float] = None | ||
| item_metadata: Optional[Dict[str, Any]] = None | ||
| prediction_metadata: Optional[Dict[str, Any]] = None | ||
|
luke-e-schaefer marked this conversation as resolved.
|
||
| prediction_row: Optional[Dict[str, Any]] = None | ||
| annotation_row: Optional[Dict[str, Any]] = None | ||
|
|
||
|
|
||
| class EvaluationV2ExamplesPage(DictCompatibleModel): | ||
| rows: List[EvaluationV2MatchExample] | ||
| total: int | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same as https://github.com/scaleapi/nucleus-python-client/pull/460/changes#r3311980952