diff --git a/CHANGELOG.md b/CHANGELOG.md index 29f59d10..1a4305fd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,33 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.17.12](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.17.12) - 2026-02-23 + +### Added +- `Dataset.deduplicate()` method to deduplicate images using perceptual hashing. Accepts optional `reference_ids` to deduplicate specific items, or deduplicates the entire dataset when only `threshold` is provided. Required `threshold` parameter (0-64) controls similarity matching (lower = stricter, 0 = exact matches only). +- `Dataset.deduplicate_by_ids()` method for deduplication using internal `dataset_item_ids` directly, avoiding the reference ID to item ID mapping for improved efficiency. +- `DeduplicationResult` and `DeduplicationStats` dataclasses for structured deduplication results. + +Example usage: + +```python +dataset = client.get_dataset("ds_...") + +# Deduplicate entire dataset +result = dataset.deduplicate(threshold=10) + +# Deduplicate specific items by reference IDs +result = dataset.deduplicate(threshold=10, reference_ids=["ref_1", "ref_2", "ref_3"]) + +# Deduplicate by internal item IDs (more efficient if you have them) +result = dataset.deduplicate_by_ids(threshold=10, dataset_item_ids=["item_1", "item_2"]) + +# Access results +print(f"Threshold: {result.stats.threshold}") +print(f"Original: {result.stats.original_count}, Unique: {result.stats.deduplicated_count}") +print(result.unique_reference_ids) +``` + ## [0.17.11](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.17.11) - 2025-11-03 ### Added diff --git a/nucleus/__init__.py b/nucleus/__init__.py index 3f970c2b..df97ddec 100644 --- a/nucleus/__init__.py +++ b/nucleus/__init__.py @@ -4,6 +4,8 @@ "AsyncJob", "EmbeddingsExportJob", "BoxAnnotation", + "DeduplicationResult", + "DeduplicationStats", "BoxPrediction", "CameraParams", "CategoryAnnotation", @@ -128,6 +130,7 @@ from .data_transfer_object.job_status import JobInfoRequestPayload from .dataset import Dataset from .dataset_item import DatasetItem +from .deduplication import DeduplicationResult, DeduplicationStats from .deprecation_warning import deprecated from .errors import ( DatasetItemRetrievalError, diff --git a/nucleus/constants.py b/nucleus/constants.py index 0a2bbf46..ebad94f5 100644 --- a/nucleus/constants.py +++ b/nucleus/constants.py @@ -149,6 +149,7 @@ SLICE_TAGS_KEY = "slice_tags" TAXONOMY_NAME_KEY = "taxonomy_name" TASK_ID_KEY = "task_id" +THRESHOLD_KEY = "threshold" TRACK_REFERENCE_ID_KEY = "track_reference_id" TRACK_REFERENCE_IDS_KEY = "track_reference_ids" TRACKS_KEY = "tracks" diff --git a/nucleus/dataset.py b/nucleus/dataset.py index ea95f840..be1c9242 100644 --- a/nucleus/dataset.py +++ b/nucleus/dataset.py @@ -67,6 +67,7 @@ REQUEST_ID_KEY, SCENE_IDS_KEY, SLICE_ID_KEY, + THRESHOLD_KEY, TRACK_REFERENCE_IDS_KEY, TRACKS_KEY, TRAINED_SLICE_ID_KEY, @@ -83,6 +84,7 @@ check_items_have_dimensions, ) from .dataset_item_uploader import DatasetItemUploader +from .deduplication import DeduplicationResult, DeduplicationStats from .deprecation_warning import deprecated from .errors import NotFoundError, NucleusAPIError from .job import CustomerJobTypes, jobs_status_overview @@ -1006,6 +1008,106 @@ def create_slice_by_ids( ) return Slice(response[SLICE_ID_KEY], self._client) + def deduplicate( + self, + threshold: int, + reference_ids: Optional[List[str]] = None, + ) -> DeduplicationResult: + """Deduplicate images or frames in this dataset. + + Parameters: + threshold: Hamming distance threshold (0-64). Lower = stricter. + 0 = exact matches only. + reference_ids: Optional list of reference IDs to deduplicate. + If not provided (or None), deduplicates the entire dataset. + Cannot be an empty list - use None for entire dataset. + + Returns: + DeduplicationResult with unique_reference_ids, unique_item_ids, and stats. + + Raises: + ValueError: If reference_ids is an empty list (use None for entire dataset). + NucleusAPIError: If threshold is not an integer between 0 and 64 inclusive. + NucleusAPIError: If any reference_id is not found in the dataset. + NucleusAPIError: If any item is missing a perceptual hash (pHash). + Contact Scale support if this occurs. + + Note: + - For scene datasets, this deduplicates the underlying scene frames, + not the scenes themselves. Frame reference IDs or dataset item IDs + should be provided for scene datasets. + - For very large datasets, this operation may take significant time. + """ + # Client-side validation + if reference_ids is not None and len(reference_ids) == 0: + raise ValueError( + "reference_ids cannot be empty. Omit reference_ids parameter to deduplicate entire dataset." + ) + + payload: Dict[str, Any] = {THRESHOLD_KEY: threshold} + if reference_ids is not None: + payload[REFERENCE_IDS_KEY] = reference_ids + + response = self._client.make_request( + payload, f"dataset/{self.id}/deduplicate" + ) + return DeduplicationResult( + unique_item_ids=response["unique_item_ids"], + unique_reference_ids=response["unique_reference_ids"], + stats=DeduplicationStats( + threshold=threshold, + original_count=response["stats"]["original_count"], + deduplicated_count=response["stats"]["deduplicated_count"], + ), + ) + + def deduplicate_by_ids( + self, + threshold: int, + dataset_item_ids: List[str], + ) -> DeduplicationResult: + """Deduplicate images or frames by internal dataset item IDs. + + Parameters: + threshold: Hamming distance threshold (0-64). Lower = stricter. + 0 = exact matches only. + dataset_item_ids: List of internal dataset item IDs to deduplicate. + Must be non-empty. To deduplicate the entire dataset, refer to + the documentation for `deduplicate()` instead. + + Returns: + DeduplicationResult with unique_item_ids, unique_reference_ids, and stats. + + Raises: + ValueError: If dataset_item_ids is empty. + NucleusAPIError: If threshold is not an integer between 0 and 64 inclusive. + NucleusAPIError: If any dataset_item_id is not found in the dataset. + NucleusAPIError: If any item is missing a perceptual hash (pHash). + Contact Scale support if this occurs. + """ + # Client-side validation + if not dataset_item_ids: + raise ValueError( + "dataset_item_ids must be non-empty. Use deduplicate() for entire dataset." + ) + + payload = { + DATASET_ITEM_IDS_KEY: dataset_item_ids, + THRESHOLD_KEY: threshold, + } + response = self._client.make_request( + payload, f"dataset/{self.id}/deduplicate" + ) + return DeduplicationResult( + unique_item_ids=response["unique_item_ids"], + unique_reference_ids=response["unique_reference_ids"], + stats=DeduplicationStats( + threshold=threshold, + original_count=response["stats"]["original_count"], + deduplicated_count=response["stats"]["deduplicated_count"], + ), + ) + def build_slice( self, name: str, diff --git a/nucleus/deduplication.py b/nucleus/deduplication.py new file mode 100644 index 00000000..f427c004 --- /dev/null +++ b/nucleus/deduplication.py @@ -0,0 +1,16 @@ +from dataclasses import dataclass +from typing import List + + +@dataclass +class DeduplicationStats: + threshold: int + original_count: int + deduplicated_count: int + + +@dataclass +class DeduplicationResult: + unique_item_ids: List[str] # Internal dataset item IDs + unique_reference_ids: List[str] # User-defined reference IDs + stats: DeduplicationStats diff --git a/pyproject.toml b/pyproject.toml index 4fe1aaa2..6622dcd6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ ignore = ["E501", "E741", "E731", "F401"] # Easy ignore for getting it running [tool.poetry] name = "scale-nucleus" -version = "0.17.11" +version = "0.17.12" description = "The official Python client library for Nucleus, the Data Platform for AI" license = "MIT" authors = ["Scale AI Nucleus Team "] diff --git a/tests/test_deduplication.py b/tests/test_deduplication.py new file mode 100644 index 00000000..75947296 --- /dev/null +++ b/tests/test_deduplication.py @@ -0,0 +1,313 @@ +import pytest + +from nucleus import Dataset, DatasetItem, NucleusClient, VideoScene +from nucleus.deduplication import DeduplicationResult +from nucleus.errors import NucleusAPIError + +from .helpers import ( + TEST_DATASET_ITEMS, + TEST_DATASET_NAME, + TEST_IMG_URLS, + TEST_VIDEO_DATASET_NAME, + TEST_VIDEO_SCENES, + TEST_VIDEO_URL, +) + + +def test_deduplicate_empty_reference_ids_raises_error(): + fake_dataset = Dataset("fake", NucleusClient("fake")) + with pytest.raises(ValueError, match="reference_ids cannot be empty"): + fake_dataset.deduplicate(threshold=10, reference_ids=[]) + + +def test_deduplicate_by_ids_empty_list_raises_error(): + fake_dataset = Dataset("fake", NucleusClient("fake")) + with pytest.raises(ValueError, match="dataset_item_ids must be non-empty"): + fake_dataset.deduplicate_by_ids(threshold=10, dataset_item_ids=[]) + + +@pytest.fixture(scope="module") +def dataset_image(CLIENT): + """Image dataset with TEST_DATASET_ITEMS (waits for phash calculation).""" + ds = CLIENT.create_dataset(TEST_DATASET_NAME + " dedup", is_scene=False) + ds.append(TEST_DATASET_ITEMS) + yield ds + CLIENT.delete_dataset(ds.id) + + +@pytest.mark.integration +def test_deduplicate_entire_dataset(dataset_image): + result = dataset_image.deduplicate(threshold=10) + assert isinstance(result, DeduplicationResult) + assert len(result.unique_reference_ids) > 0 + assert len(result.unique_item_ids) > 0 + assert result.stats.original_count == len(TEST_DATASET_ITEMS) + + +@pytest.mark.integration +def test_deduplicate_with_reference_ids(dataset_image): + reference_ids = [item.reference_id for item in TEST_DATASET_ITEMS[:2]] + result = dataset_image.deduplicate(threshold=10, reference_ids=reference_ids) + assert isinstance(result, DeduplicationResult) + assert result.stats.original_count == len(reference_ids) + assert len(result.unique_reference_ids) <= len(reference_ids) + assert len(result.unique_item_ids) <= len(reference_ids) + + +@pytest.mark.integration +def test_deduplicate_by_ids(dataset_image): + initial_result = dataset_image.deduplicate(threshold=10) + item_ids = initial_result.unique_item_ids + assert len(item_ids) > 0 + + result = dataset_image.deduplicate_by_ids(threshold=10, dataset_item_ids=item_ids) + assert isinstance(result, DeduplicationResult) + assert result.stats.original_count == len(item_ids) + assert result.unique_item_ids == initial_result.unique_item_ids + + +@pytest.fixture(scope="module") +def dataset_video_scene(CLIENT): + """Scene dataset with scene_1 (frame IDs: video_frame_0, video_frame_1).""" + ds = CLIENT.create_dataset(TEST_VIDEO_DATASET_NAME + " dedup", is_scene=True) + scene_1 = TEST_VIDEO_SCENES["scenes"][0] + scenes = [VideoScene.from_json(scene_1)] + job = ds.append(scenes, asynchronous=True) + job.sleep_until_complete() + yield ds + CLIENT.delete_dataset(ds.id) + + +def _get_scene_frame_ref_ids(): + """Extract frame reference IDs from TEST_VIDEO_SCENES scene_1.""" + return [frame["reference_id"] for frame in TEST_VIDEO_SCENES["scenes"][0]["frames"]] + + +@pytest.mark.integration +def test_deduplicate_video_scene_entire_dataset(dataset_video_scene): + result = dataset_video_scene.deduplicate(threshold=10) + assert isinstance(result, DeduplicationResult) + assert len(result.unique_reference_ids) > 0 + assert len(result.unique_item_ids) > 0 + assert result.stats.original_count == len(_get_scene_frame_ref_ids()) + + +@pytest.mark.integration +def test_deduplicate_video_scene_with_frame_reference_ids(dataset_video_scene): + frame_ref_ids = _get_scene_frame_ref_ids() + result = dataset_video_scene.deduplicate(threshold=10, reference_ids=frame_ref_ids) + assert isinstance(result, DeduplicationResult) + assert result.stats.original_count == len(frame_ref_ids) + assert len(result.unique_reference_ids) <= len(frame_ref_ids) + assert len(result.unique_item_ids) <= len(frame_ref_ids) + + +@pytest.mark.integration +def test_deduplicate_video_scene_by_ids(dataset_video_scene): + initial_result = dataset_video_scene.deduplicate(threshold=10) + item_ids = initial_result.unique_item_ids + assert len(item_ids) > 0 + + result = dataset_video_scene.deduplicate_by_ids( + threshold=10, dataset_item_ids=item_ids + ) + assert isinstance(result, DeduplicationResult) + assert result.stats.original_count == len(item_ids) + assert result.unique_item_ids == initial_result.unique_item_ids + + +@pytest.fixture(scope="module") +def dataset_video_url(CLIENT): + """Scene dataset created from a video URL (not a list of frames).""" + ds = CLIENT.create_dataset(TEST_VIDEO_DATASET_NAME + " video_url dedup", is_scene=True) + scene = VideoScene.from_json({ + "reference_id": "video_url_scene", + "video_url": TEST_VIDEO_URL, + "metadata": {"test": "video_url_dedup"}, + }) + job = ds.append([scene], asynchronous=True) + job.sleep_until_complete() + yield ds + CLIENT.delete_dataset(ds.id) + + +@pytest.mark.integration +def test_deduplicate_video_url_entire_dataset(dataset_video_url): + """Test deduplication on a dataset created from a video URL.""" + result = dataset_video_url.deduplicate(threshold=10) + assert isinstance(result, DeduplicationResult) + assert len(result.unique_reference_ids) > 0 + assert len(result.unique_item_ids) > 0 + assert result.stats.original_count > 0 + + +@pytest.mark.integration +def test_deduplicate_video_url_by_ids(dataset_video_url): + """Test deduplicate_by_ids on a dataset created from a video URL.""" + initial_result = dataset_video_url.deduplicate(threshold=10) + item_ids = initial_result.unique_item_ids + assert len(item_ids) > 0 + + result = dataset_video_url.deduplicate_by_ids( + threshold=10, dataset_item_ids=item_ids + ) + assert isinstance(result, DeduplicationResult) + assert result.stats.original_count == len(item_ids) + assert result.unique_item_ids == initial_result.unique_item_ids + + +# Edge case tests + + +@pytest.mark.integration +def test_deduplicate_threshold_zero(dataset_image): + """Threshold=0 means exact matches only.""" + result = dataset_image.deduplicate(threshold=0) + assert isinstance(result, DeduplicationResult) + assert result.stats.threshold == 0 + + +@pytest.mark.integration +def test_deduplicate_threshold_max(dataset_image): + """Threshold=64 is the maximum allowed value.""" + result = dataset_image.deduplicate(threshold=64) + assert isinstance(result, DeduplicationResult) + assert result.stats.threshold == 64 + + +@pytest.mark.integration +def test_deduplicate_threshold_negative(dataset_image): + """Threshold must be >= 0.""" + with pytest.raises(NucleusAPIError): + dataset_image.deduplicate(threshold=-1) + + +@pytest.mark.integration +def test_deduplicate_threshold_too_high(dataset_image): + """Threshold must be <= 64.""" + with pytest.raises(NucleusAPIError): + dataset_image.deduplicate(threshold=65) + + +@pytest.mark.integration +def test_deduplicate_threshold_non_integer(dataset_image): + """Threshold must be an integer.""" + with pytest.raises(NucleusAPIError): + dataset_image.deduplicate(threshold=10.5) + + +@pytest.mark.integration +def test_deduplicate_nonexistent_reference_id(dataset_image): + with pytest.raises(NucleusAPIError): + dataset_image.deduplicate(threshold=10, reference_ids=["nonexistent_ref_id"]) + + +@pytest.mark.integration +def test_deduplicate_by_ids_nonexistent_id(dataset_image): + with pytest.raises(NucleusAPIError): + dataset_image.deduplicate_by_ids(threshold=10, dataset_item_ids=["di_nonexistent"]) + + +@pytest.mark.integration +def test_deduplicate_idempotency(dataset_image): + result1 = dataset_image.deduplicate(threshold=10) + result2 = dataset_image.deduplicate(threshold=10) + + assert result1.unique_item_ids == result2.unique_item_ids + assert result1.unique_reference_ids == result2.unique_reference_ids + assert result1.stats.original_count == result2.stats.original_count + assert result1.stats.deduplicated_count == result2.stats.deduplicated_count + + +@pytest.mark.integration +def test_deduplicate_response_invariants(dataset_image): + result = dataset_image.deduplicate(threshold=10) + + assert len(result.unique_item_ids) == len(result.unique_reference_ids) + assert result.stats.deduplicated_count == len(result.unique_item_ids) + assert result.stats.deduplicated_count <= result.stats.original_count + assert result.stats.threshold == 10 + + +@pytest.mark.integration +def test_deduplicate_by_ids_threshold_negative(dataset_image): + """deduplicate_by_ids should enforce the same threshold constraints.""" + initial_result = dataset_image.deduplicate(threshold=10) + item_ids = initial_result.unique_item_ids + + with pytest.raises(NucleusAPIError): + dataset_image.deduplicate_by_ids(threshold=-1, dataset_item_ids=item_ids) + + +@pytest.mark.integration +def test_deduplicate_by_ids_threshold_too_high(dataset_image): + """deduplicate_by_ids should enforce the same threshold constraints.""" + initial_result = dataset_image.deduplicate(threshold=10) + item_ids = initial_result.unique_item_ids + + with pytest.raises(NucleusAPIError): + dataset_image.deduplicate_by_ids(threshold=65, dataset_item_ids=item_ids) + + +@pytest.mark.integration +def test_deduplicate_single_item(dataset_image): + """Single item should always be unique.""" + reference_ids = [TEST_DATASET_ITEMS[0].reference_id] + result = dataset_image.deduplicate(threshold=10, reference_ids=reference_ids) + + assert result.stats.original_count == 1 + assert result.stats.deduplicated_count == 1 + assert len(result.unique_reference_ids) == 1 + + +@pytest.fixture() +def dataset_empty(CLIENT): + """Empty dataset with no items.""" + ds = CLIENT.create_dataset(TEST_DATASET_NAME + " empty", is_scene=False) + yield ds + CLIENT.delete_dataset(ds.id) + + +@pytest.mark.integration +def test_deduplicate_empty_dataset(dataset_empty): + """Empty dataset should return zero counts.""" + result = dataset_empty.deduplicate(threshold=10) + + assert result.stats.original_count == 0 + assert result.stats.deduplicated_count == 0 + assert len(result.unique_reference_ids) == 0 + assert len(result.unique_item_ids) == 0 + + +@pytest.fixture() +def dataset_with_duplicates(CLIENT): + """Dataset with duplicate images (same image uploaded twice).""" + ds = CLIENT.create_dataset(TEST_DATASET_NAME + " duplicates", is_scene=False) + items = [ + DatasetItem(TEST_IMG_URLS[0], reference_id="img_original"), + DatasetItem(TEST_IMG_URLS[0], reference_id="img_duplicate"), + DatasetItem(TEST_IMG_URLS[1], reference_id="img_different"), + ] + ds.append(items) + yield ds + CLIENT.delete_dataset(ds.id) + + +@pytest.mark.integration +def test_deduplicate_identifies_duplicates(dataset_with_duplicates): + """Verify deduplication actually identifies duplicate images.""" + result = dataset_with_duplicates.deduplicate(threshold=0) + + assert result.stats.original_count == 3 + # With threshold=0, the two identical images should be deduplicated to one + assert result.stats.deduplicated_count == 2 + assert len(result.unique_reference_ids) == 2 + + +@pytest.mark.integration +def test_deduplicate_distinct_images_all_unique(dataset_image): + """Distinct images should all remain after deduplication.""" + result = dataset_image.deduplicate(threshold=0) + + # With threshold=0 (exact match only), all distinct images should be unique + assert result.stats.deduplicated_count == result.stats.original_count