
Dataloader

rerun.experimental.dataloader

PyTorch Datasets for training on data from the Rerun catalog.

class ColumnDecoder

Bases: ABC

Base class for column decoders.

Subclasses convert raw Arrow data into tensors. Stateless decoders (images, scalars) only need to implement decode. Context-aware decoders (compressed video) should also override context_range so the prefetcher fetches surrounding data.

def context_range(index_value)

Extra index-value range needed to decode index_value.

Returns (start, end) inclusive, or None when only the exact index value is required (the default).

def decode(raw, index_value, segment_id) abstractmethod

Decode raw Arrow data into a tensor, or return None to signal missing data.
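The interface can be sketched stand-alone. This is an illustration only: the base class is inlined rather than imported from rerun.experimental.dataloader, and a plain list stands in for a tensor.

```python
from abc import ABC, abstractmethod


class ColumnDecoderSketch(ABC):
    """Stand-in for the real ColumnDecoder base class."""

    def context_range(self, index_value):
        # Default: only the exact index value is required.
        return None

    @abstractmethod
    def decode(self, raw, index_value, segment_id):
        ...


class FloatListDecoder(ColumnDecoderSketch):
    """Stateless decoder: converts a raw list of numbers to floats."""

    def decode(self, raw, index_value, segment_id):
        if raw is None:
            return None  # signal missing data
        return [float(v) for v in raw]
```

A stateless decoder like this never overrides context_range; a context-aware decoder would return a (start, end) tuple there instead.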

class DataSource dataclass

An immutable reference to a dataset with an optional segment filter.

PARAMETER DESCRIPTION
dataset

The remote dataset to read from.

TYPE: DatasetEntry

segments

Optional list of segment IDs to restrict to.

TYPE: list[str] | None DEFAULT: None

def filter_segments(segment_ids)

Return a new DataSource narrowed to segment_ids.

class Field dataclass

Declarative spec for one field of a training sample.

Note

This API is provisional and will be improved; expect the surface to change.

PARAMETER DESCRIPTION
path

entity_path:Archetype:component triple identifying the source column (e.g. "/camera:EncodedImage:blob").

TYPE: str

decode

A ColumnDecoder that turns the Arrow column into a tensor.

TYPE: ColumnDecoder

select

Optional jq-like Selector applied client-side to the Arrow column before decode. Used for nested struct/list access. The server-side projection is unaffected.

Field(
    path="/agent:ListOfStructs:animals",
    select=Selector(".[0].dog"),
    decode=NumericDecoder(),
)

TYPE: Selector | None DEFAULT: None

window

Optional (start_offset, end_offset) range, inclusive on both ends and added to the current index value. The field then yields the slice of values across that window instead of a single sample. Offsets are in the index timeline's native unit: integer steps for integer-indexed timelines, nanoseconds for timestamp timelines (use multiples of the FixedRateSampling period to align with the sampling grid). For example, (1, 50) on an integer timeline fetches the next 50 values after the current sample.

TYPE: tuple[int, int] | None DEFAULT: None
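The offset arithmetic can be illustrated with a small sketch (hypothetical helper and values; `step` would be 1 for integer timelines and the sampling period in nanoseconds for fixed-rate timestamp timelines):

```python
def window_indices(current, window, step=1):
    """Index values a field with `window` yields around `current`.

    Both offsets are inclusive and added to the current index value.
    """
    start_offset, end_offset = window
    return list(range(current + start_offset * step,
                      current + end_offset * step + step,
                      step))


# window=(1, 50) on an integer timeline: the next 50 values.
nxt = window_indices(100, (1, 50))  # 101 .. 150
```

A symmetric window such as (-2, 2) yields five values centered on the current sample.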

class FixedRateSampling dataclass

Sample timestamp timelines at a fixed nominal rate.

Indices are drawn on an arithmetic grid seg.index_start + k * ns_per_sample. The server's fill_latest_at absorbs any drift between grid points and real row positions.
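The grid itself is simple arithmetic. A sketch for a 30 Hz grid (the exact rounding of ns_per_sample is an assumption of this sketch):

```python
def grid_values(index_start_ns, rate_hz, num_samples):
    """Nominal sample timestamps: index_start + k * ns_per_sample."""
    ns_per_sample = int(1e9 / rate_hz)
    return [index_start_ns + k * ns_per_sample for k in range(num_samples)]


# First three grid points of a 30 Hz timeline starting at t=0.
pts = grid_values(0, 30, 3)
```

Rows logged slightly off-grid are still picked up, because the server fills each grid point with the latest row at or before it.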

class ImageDecoder

Bases: ColumnDecoder

Decode a single encoded-image blob (JPEG/PNG) to a [C, H, W] uint8 tensor.

def context_range(index_value)

Extra index-value range needed to decode index_value.

Returns (start, end) inclusive, or None when only the exact index value is required (the default).

class NumericDecoder

Bases: ColumnDecoder

Decode Arrow numeric / list-of-numeric columns to a tensor.

def context_range(index_value)

Extra index-value range needed to decode index_value.

Returns (start, end) inclusive, or None when only the exact index value is required (the default).

class RerunIterableDataset

Bases: IterableDataset[dict[str, Tensor | None]]

Iterable dataset backed by a catalog server.

Fetches fetch_size samples per server query and yields individual samples, so per-query overhead is amortized across many samples while the DataLoader controls the training batch size independently.

The index list is partitioned across DDP ranks and DataLoader workers internally. With shuffle=True (default), the full list is shuffled once per epoch before partitioning; call set_epoch to re-seed between epochs.
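The shuffle-then-partition scheme can be sketched in plain Python. The round-robin slot assignment below is an assumption of this sketch, not necessarily the real partitioning strategy; what it preserves is the documented property that the full list is shuffled once per epoch before being split disjointly across rank/worker slots.

```python
import random


def partition(num_samples, epoch, num_ranks, workers_per_rank, seed=0):
    """Shuffle once per epoch, then split across rank/worker slots."""
    order = list(range(num_samples))
    random.Random(seed + epoch).shuffle(order)  # re-seeded via set_epoch
    slots = num_ranks * workers_per_rank
    # Slot s takes every slots-th sample, so the parts are disjoint
    # and together cover the whole index list.
    return [order[s::slots] for s in range(slots)]


parts = partition(num_samples=10, epoch=0, num_ranks=2, workers_per_rank=2)
```

Calling set_epoch before each epoch changes the seed, so every epoch sees a fresh shuffle while all ranks still agree on the same order.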

PARAMETER DESCRIPTION
source

The dataset to read from (with optional segment filter).

TYPE: DataSource

index

Timeline to iterate (e.g. "frame_nr").

TYPE: str

fields

Sample fields, keyed by output name.

TYPE: dict[str, Field]

timeline_sampling

Required when index is a timestamp timeline; ignored for integer indices. Pass FixedRateSampling to sample on a fixed grid (e.g. 30 Hz).

TYPE: FixedRateSampling | None DEFAULT: None

fetch_size

Number of samples to fetch per server query. Larger values amortize network overhead but use more memory. Defaults to 128.

TYPE: int DEFAULT: 128

shuffle

Whether to shuffle sample order each epoch. Defaults to True.

TYPE: bool DEFAULT: True

sample_index property

The underlying SampleIndex.

def __iter__()

Yield individual samples as they are decoded.

The Arrow fetch for chunk N+1 runs on a background thread while chunk N is being decoded, so samples stream out during decoding.

def __len__()

Total number of samples across all segments.

def set_epoch(epoch)

Set the epoch for shuffling (like DistributedSampler.set_epoch).

class RerunMapDataset

Bases: Dataset[dict[str, Tensor | None]]

Map-style dataset backed by a catalog server.

Supports random access by global index, so it works with PyTorch's sampler ecosystem (DistributedSampler, WeightedRandomSampler, SubsetRandomSampler, ...). Shuffling and cross-worker partitioning are driven by the DataLoader's sampler.

For streaming iteration with internal shuffling, use RerunIterableDataset instead.

PARAMETER DESCRIPTION
source

The dataset to read from (with optional segment filter).

TYPE: DataSource

index

Timeline column to use as the sample index (e.g. "frame_nr").

TYPE: str

fields

Sample fields, keyed by output name.

TYPE: dict[str, Field]

timeline_sampling

Required when index is a timestamp timeline; ignored for integer indices. Pass FixedRateSampling to sample on a fixed grid (e.g. 30 Hz).

TYPE: FixedRateSampling | None DEFAULT: None

Examples:

dataset = RerunMapDataset(
    source,
    "frame_nr",
    {"image": Field("/camera:Image:blob", decode=ImageDecoder())},
)
sampler = DistributedSampler(dataset)
loader = DataLoader(dataset, batch_size=8, sampler=sampler, num_workers=4)
for batch in loader:
    ...

sample_index property

The underlying SampleIndex.

def __getitem__(idx)

Fetch a single sample by global index (one server query).

def __getitems__(indices)

Fetch multiple samples by global index in a single server query.

PyTorch's DataLoader calls this automatically when present, so each batch round-trips once.

def __len__()

Total number of samples across all segments.

class SampleIndex

Pre-computed description of the complete sample space.

Maps every segment's positional indices to concrete index values, accounting for the timeline strategy (integer or fixed-rate grid). Small enough to hold in memory for any realistic dataset.

PARAMETER DESCRIPTION
segments

Per-segment metadata (window-adjusted index range + sample count).

TYPE: list[SegmentMetadata]

ns_per_sample

For FixedRateSampling: nanoseconds between grid points. None for integer indices.

TYPE: int | None DEFAULT: None

is_timestamp

True when the index is a timestamp timeline. Exposed so callers can decide whether to interpret returned int values as nanoseconds-since-epoch.

TYPE: bool DEFAULT: False

is_timestamp property

Whether the index is a timestamp timeline.

ns_per_sample property

Nanoseconds between grid points for fixed-rate sampling, or None.

segments property

Per-segment metadata list.

total_samples property

Total number of samples across all segments.

def build(source, index, fields, *, timeline_sampling=None) staticmethod

Build a SampleIndex from lightweight metadata queries.

PARAMETER DESCRIPTION
source

Data source to build from.

TYPE: DataSource

index

Name of the index timeline column.

TYPE: str

fields

Field definitions, used for window-trim calculation.

TYPE: dict[str, Field]

timeline_sampling

Required for timestamp indices; ignored for integer indices. Pass FixedRateSampling for a regular grid.

TYPE: FixedRateSampling | None DEFAULT: None

def global_to_local(idx)

Map a global index in [0, total_samples) to (segment, concrete_idx_value).

The returned index value is a plain int for integer timelines and a datetime64[ns] for timestamp timelines.
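For an integer timeline the mapping reduces to walking cumulative segment sizes; a sketch (SegSketch is a hypothetical stand-in for SegmentMetadata, and the datetime64[ns] wrapping for timestamp timelines is omitted):

```python
from dataclasses import dataclass


@dataclass
class SegSketch:
    """Hypothetical stand-in for SegmentMetadata."""
    segment_id: str
    index_start: int
    num_samples: int


def global_to_local(segments, idx):
    """Map a global index to (segment_id, concrete index value)."""
    for seg in segments:
        if idx < seg.num_samples:
            return seg.segment_id, seg.index_start + idx
        idx -= seg.num_samples
    raise IndexError("global index out of range")


segs = [SegSketch("a", 0, 3), SegSketch("b", 100, 2)]
```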

def indices_in_range(lo, hi)

Enumerate valid index values in [lo, hi].

For fixed-rate timelines the returned values walk down from hi in ns_per_sample steps (so they remain on the grid as long as hi is). For integer timelines, every value in [lo, hi] is returned. Values are plain int (ns-since-epoch for timestamp indices); the caller casts the aggregated set to the right numpy dtype.
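The downward walk can be sketched as follows (whether the real method returns the values ascending or descending is an assumption of this sketch):

```python
def indices_in_range(lo, hi, ns_per_sample=None):
    """Valid index values in [lo, hi], inclusive on both ends."""
    if ns_per_sample is None:
        # Integer timeline: every value in the range.
        return list(range(lo, hi + 1))
    # Fixed-rate timeline: walk down from hi in grid steps so the
    # values stay on the grid as long as hi is, then sort ascending.
    vals = []
    v = hi
    while v >= lo:
        vals.append(v)
        v -= ns_per_sample
    return vals[::-1]
```

Note that for a fixed-rate timeline lo itself may not appear in the result: the lowest value returned is the smallest grid point at or above lo.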

def resolve_local_index(seg, pos)

Convert a positional index within seg to a concrete index value.

pos is in [0, seg.num_samples). Returns a datetime64[ns] for timestamp timelines and a plain int for integer indices.
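On either kind of timeline the conversion is a single multiply-add; a sketch (the datetime64[ns] wrapping for timestamp timelines is omitted):

```python
def resolve_local_index(index_start, pos, ns_per_sample=None):
    """Positional index within a segment -> concrete index value.

    ns_per_sample is None for integer timelines (step of 1).
    """
    step = 1 if ns_per_sample is None else ns_per_sample
    return index_start + pos * step
```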

class SegmentMetadata dataclass

Per-segment metadata for sampling.

class VideoFrameDecoder

Bases: ColumnDecoder

Compressed video random access via context-aware fetching.

context_range(N) asks the prefetcher to pull the previous keyframe_interval samples (counted directly for integer indices, converted to keyframe_interval / fps_estimate seconds for timestamp indices). decode() runs the codec over that window in order and returns the final frame.

keyframe_interval must be greater than or equal to the actual GOP length, otherwise the window won't contain a keyframe and decode will fail. For timestamp indices fps_estimate must also be close to the true frame rate.

Samples may be raw H.264 AVC1/AVCC (length-prefixed NAL units) or Annex B; the format is detected automatically per sample.

Returns None when the prefetched range contains no keyframe — typically because the target precedes the entity's first frame in a multi-modal segment, or because keyframe_interval under-estimates the true GOP length. Consumers must filter these samples out in their collate function before stacking.
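The context window described above can be sketched as (hypothetical helper; the integer vs timestamp branches follow the description, and the exact rounding is an assumption):

```python
def video_context_range(index_value, keyframe_interval, fps_estimate=None,
                        is_timestamp=False):
    """(start, end) range expected to contain the previous keyframe."""
    if is_timestamp:
        # Convert keyframe_interval frames to nanoseconds of wall time.
        span_ns = int(keyframe_interval / fps_estimate * 1e9)
        return index_value - span_ns, index_value
    # Integer timeline: count samples directly.
    return index_value - keyframe_interval, index_value
```

If the true GOP length exceeds keyframe_interval, this window may start after the last keyframe, which is exactly the failure mode described above.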

def context_range(index_value)

Need frames from estimated keyframe position to target.

def decode(raw, index_value, segment_id)

Decode the target frame from the context samples in raw, or None if no keyframe is available.

def tracing_scope(name)

Open an OpenTelemetry span for the duration of a with block and propagate trace context into Rerun's Rust SDK.

Context-manager counterpart to with_tracing — use it to scope arbitrary blocks of code without extracting them into a function. Any Rust-side #[instrument] spans triggered from within will be parented under this span in Jaeger.

No-op unless TELEMETRY_ENABLED=true and an OTLP endpoint is configured (OTEL_EXPORTER_OTLP_TRACES_ENDPOINT or OTEL_EXPORTER_OTLP_ENDPOINT).

Examples:

for epoch in range(num_epochs):
    with tracing_scope(f"epoch {epoch}"):
        train_one_epoch(...)

def with_tracing(name)

Wrap a function in an OpenTelemetry span and propagate trace context into Rerun's Rust SDK.

When enabled, creates a span named name, injects the W3C traceparent header into Rerun's shared ContextVar, and runs the wrapped function. Any Rust-side #[instrument] spans triggered from within (e.g. catalog queries) will be parented under this span in Jaeger.

For ad-hoc blocks that don't belong in a dedicated function, use tracing_scope instead.

No-op unless TELEMETRY_ENABLED=true and an OTLP endpoint is configured (OTEL_EXPORTER_OTLP_TRACES_ENDPOINT or OTEL_EXPORTER_OTLP_ENDPOINT).