"""Source code for lir.data.models."""

import logging
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable, Iterator
from typing import Annotated, Any, Self, TypeVar

import numpy as np
from pydantic import AfterValidator, BaseModel, ConfigDict, model_validator

from lir.util import check_type


LOG = logging.getLogger(__name__)


def _validate_labels(labels: np.ndarray | None) -> np.ndarray | None:
    """
    Check if labels have the correct shape.

    Parameters
    ----------
    labels : np.ndarray | None
        Value passed via ``labels``.

    Returns
    -------
    np.ndarray | None
        Validated label array, or ``None`` when labels are absent.
    """
    if labels is None:
        return labels

    if len(labels.shape) != 1:
        raise ValueError(f'labels must be 1-dimensional; shape: {labels.shape}')

    if np.any((labels != 0) & (labels != 1)):
        raise ValueError(f'labels allowed: 0, 1; found: {np.unique(labels)}')

    return labels


def _validate_source_ids(source_ids: np.ndarray | None) -> np.ndarray | None:
    """
    Check if source_ids have the correct shape.

    Parameters
    ----------
    source_ids : np.ndarray | None
        Value passed via ``source_ids``.

    Returns
    -------
    np.ndarray | None
        Validated source-id array, or ``None`` when source IDs are absent.
    """
    if source_ids is None:
        return source_ids

    if len(source_ids.shape) == 1:
        return source_ids

    # if we have a 2d array with one column, silently reshape it to 1d
    if len(source_ids.shape) == 2 and source_ids.shape[1] == 1:
        return source_ids.reshape(-1)

    if len(source_ids.shape) == 2 and source_ids.shape[1] == 2:
        return source_ids

    raise ValueError(
        f'source_ids must be either 1-dimensional or 2-dimensional with 2 columns; found shape {source_ids.shape}'
    )


class InstanceData(BaseModel, ABC):
    """
    Base class for data on instances.

    An InstanceData object may be labeled or unlabeled with ground-truth data. If
    it is labeled, the label values correspond to the hypotheses and have values
    0 or 1. In literature, the labels may have different names for values 1 and 0
    respectively, such as:

    - hypothesis 1 and hypothesis 2 (or H1 and H2)
    - prosecutor's hypothesis and defense hypothesis (or Hp and Hd)
    - same-source and different-source (or Hss and Hds)

    The instances may optionally be associated with sources by means of the
    `source_ids` attribute. If available, each instance will generally have one
    source id if the object holds single instances, or two source ids if the
    object holds pairs of instances.

    This class imposes no restrictions on the actual instance data. Sub class
    implementations will specialize in particular data types.

    Attributes
    ----------
    - `labels`: The hypothesis labels of the instances, as a 1-dimensional array
      with one value per instance, can be either 0 or 1.
    - `source_ids`: The ids of all sources that contributed to the instances.
      Each instance is from a single source, except if it is a pair, in which
      case it has two sources. The source ids is either a 1-dimensional array or
      a 2-dimensional array with two columns.
    """

    # frozen: instances are immutable; extra='allow': additional numpy fields may
    # be attached dynamically and participate in `all_fields`;
    # arbitrary_types_allowed is required for numpy array fields.
    model_config = ConfigDict(frozen=True, extra='allow', arbitrary_types_allowed=True)

    # hypothesis labels (0/1), one per instance; validated by _validate_labels
    labels: Annotated[np.ndarray | None, AfterValidator(_validate_labels)] = None
    # source ids, 1d or 2d with 2 columns; validated by _validate_source_ids
    source_ids: Annotated[np.ndarray | None, AfterValidator(_validate_source_ids)] = None

    @property
    def require_labels(self) -> np.ndarray:
        """
        Return `labels` and guarantee that it is not None (or raise an error).

        Returns
        -------
        np.ndarray
            The label array.

        Raises
        ------
        ValueError
            If no labels are set on this object.
        """
        if self.labels is None:
            raise ValueError('labels not set')
        return self.labels
[docs] @model_validator(mode='after') def check_sourceids_labels_match(self) -> Self: """ Validate the source_ids and labels have matching shapes. Returns ------- Self This instance data object after post-init validation. """ if self.labels is not None and self.source_ids is not None and self.labels.shape[0] != self.source_ids.shape[0]: raise ValueError( f'dimensions of labels and source_ids do not match; "' f'{self.labels.shape[0]} != {self.source_ids.shape[0]}' ) return self
@property def source_ids_1d(self) -> np.ndarray: """ Return source identifiers as a one-dimensional array. Returns ------- np.ndarray One-dimensional source-id array with one source per instance. """ if self.source_ids is None: raise ValueError('source_ids not available') if len(self.source_ids.shape) != 1: raise ValueError(f'expected one source per instance; source_ids has illegal shape {self.source_ids.shape}') return self.source_ids @abstractmethod def __len__(self) -> int: """ Return the number of instances in this dataset. Returns ------- int Number of instances represented by this object. """ raise NotImplementedError def __getitem__(self, indexes: np.ndarray | int) -> Self: """ Get a copy of a subset of instances. All `ndarray` fields are indexed using `indexes`. All other fields are taken as-is. Parameters ---------- indexes : np.ndarray | int Value passed via ``indexes``. Returns ------- Self New instance data object containing only selected rows. """ data = {} for field in self.all_fields: values = getattr(self, field) if isinstance(values, np.ndarray): # If indexes is an int, convert to array. This ensures the result is still an array, even if a single # index is provided. if isinstance(indexes, int): indexes = np.array([indexes]) data[field] = values[indexes] else: data[field] = values return self.replace(**data) def __add__(self, other: 'InstanceData') -> Self: return self.concatenate(other)
[docs] def check_both_labels(self) -> np.ndarray: """ Return labels or raise an error if they are missing or if they do not represent both hypotheses. :raise: ValueError if hypothesis labels are missing or either label is not represented. Returns ------- np.ndarray Label array containing both classes 0 and 1. """ if self.labels is None: raise ValueError('labels not set') if not np.all(np.unique(self.labels) == np.arange(2)): raise ValueError(f'not all classes are represented; labels found: {np.unique(self.labels)}') return self.labels
@classmethod def _concatenate_field(cls, field: str, values: list[Any]) -> Any: if len(values) == 0: raise ValueError('no values to concatenate') if isinstance(values[0], np.ndarray): # we have a numpy array field -> use np.concatenate() return np.concatenate(values) # we have a non-numpy array field -> check if they have the same value all_equal = all(values[0] == other for other in values[1:]) if not all_equal: raise ValueError( f'unable to combine field `{field}` because it is not a numpy array and not all values are equal' ) # return the value, which is the same for all objects return values[0]
[docs] def concatenate(self, *others: 'InstanceData') -> Self: """ Concatenate instances from InstanceData objects. All concatenated objects must have the same types and fields. How fields are concatenated may depend on the subclass. By default, they must have the same values for all non-numpy array fields, or an error is raised. Numpy fields are concatenated using `np.concatenate`. Other fields are copied as-is. Returns a new object with the concatenated instances. Parameters ---------- *others : 'InstanceData' Value passed via ``others``. Returns ------- Self New instance data object with concatenated rows. """ for instances in others: if not self.has_same_type(instances): raise ValueError('instances to concatenate must have the same types and fields') # initialize the dictionary of fields to be updated data: dict[str, np.ndarray | None] = {} for field in self.all_fields: all_values = [getattr(self, field)] all_values.extend([getattr(instances, field) for instances in others]) data[field] = self._concatenate_field(field, all_values) return self.replace(**data)
[docs] def has_same_type(self, other: Any) -> bool: """ Compare these instance data to another class. Returns True iff: - `other` has the same class - `other` has the same fields - all fields have the same type Parameters ---------- other : Any Value passed via ``other``. Returns ------- bool ``True`` when type, fields, and field value types all match. """ if type(self) is not type(other): return False if self.model_extra.keys() != other.model_extra.keys(): # type: ignore return False for field in self.all_fields: if type(getattr(self, field)) is not type(getattr(other, field)): return False return True
    def combine(self, others: 'list[InstanceData] | InstanceData', fn: Callable, *args: Any, **kwargs: Any) -> Self:
        """
        Apply a custom combination function to InstanceData objects.

        All objects must have the same types and fields, and the same values for
        all non-numpy array fields, or an error is raised.

        Numpy fields are combined by calling `fn` on the list of per-object
        arrays; other fields are copied as-is.

        Parameters
        ----------
        others : 'list[InstanceData] | InstanceData'
            One object, or a list of objects, to combine with this one.
        fn : Callable
            Combination function; it receives the list of arrays of one field
            plus the extra arguments, and should return a numpy array.
        *args : Any
            Additional positional arguments forwarded to `fn`.
        **kwargs : Any
            Additional keyword arguments forwarded to `fn`.

        Returns
        -------
        Self
            New instance data object after applying the combination function.
        """
        if isinstance(others, InstanceData):
            others = [others]

        # initialize the dictionary of fields to be updated
        data: dict[str, np.ndarray | None] = {}
        for field in self.all_fields:
            first_value = getattr(self, field)
            if isinstance(first_value, np.ndarray):
                # we have a numpy array field -> update required

                # collect values for all objects involved
                values = [first_value]
                for instances in others:
                    if not self.has_same_type(instances):
                        raise ValueError('instances to concatenate must have the same types and fields')
                    values.append(getattr(instances, field))

                # apply the function
                # NOTE(review): the result is assumed to have a `.shape`
                # attribute (i.e. to be a numpy array) — confirm `fn` contract
                values = fn(values, *args, **kwargs)
                if field == 'labels' and len(values.shape) != 1:
                    # drop labels if they are in bad shape
                    data[field] = None
                else:
                    # store the value to be updated later
                    data[field] = values
            else:
                # we have a non-numpy array field -> check if they have the same value
                for instances in others:
                    other_value = getattr(instances, field)
                    if other_value != first_value:
                        raise ValueError(
                            f'unable to combine field `{field}`: value mismatch: {other_value} != {first_value}'
                        )

        return self.replace(**data)
[docs] def apply(self, fn: Callable, *args: Any, **kwargs: Any) -> Self: """ Apply a custom function to this InstanceData object. The function `fn` is applied to all Numpy fields. Other fields are copied as-is. Parameters ---------- fn : Callable Value passed via ``fn``. *args : Any Additional positional arguments forwarded to the underlying call. **kwargs : Any Additional keyword arguments forwarded to the underlying call. Returns ------- Self New instance data object after applying the function to numpy fields. """ # initialize the dictionary of fields to be updated data: dict[str, np.ndarray | None] = {} for field in self.all_fields: values = getattr(self, field) if isinstance(values, np.ndarray): # we have a numpy array field -> update required if field == 'labels' and len(values.shape) != 1: # drop labels if they are in bad shape data[field] = None else: # apply the function and store the value to be updated later data[field] = fn(values, *args, **kwargs) return self.replace(**data)
def __eq__(self, other: Any) -> bool: """ Compare these instance data to another class. Returns True iff: - the method `has_same_type()` returns `True` - all numpy fields in `other` have the same shape and the same values - all other fields are compared using the `!=` operator Parameters ---------- other : Any Value passed via ``other``. Returns ------- bool ``True`` when all fields are equal under the class comparison rules. """ if not self.has_same_type(other): return False for field in self.all_fields: value = getattr(self, field) other_value = getattr(other, field) if isinstance(value, np.ndarray) and isinstance(other_value, np.ndarray): if value.shape != other_value.shape or not np.all(value == other_value): return False else: if value != other_value: return False return True @property def all_fields(self) -> list[str]: """ Return all available field names for this data object. Returns ------- list[str] Names of all standard and extra fields available on the instance. """ all_fields = list(type(self).model_fields.keys()) if self.model_extra: all_fields += list(self.model_extra.keys()) return all_fields @property def has_labels(self) -> bool: """ Indicate whether label values are available. Returns ------- bool ``True`` when label information is present. """ return self.labels is not None
[docs] def replace(self, **kwargs: Any) -> Self: """ Return a modified copy with updated values. Parameters ---------- **kwargs : Any Additional keyword arguments forwarded to the underlying call. Returns ------- Self Copy of this object with the provided fields replaced. """ return self.replace_as(type(self), **kwargs)
[docs] def replace_as(self, datatype: type['InstanceDataType'], **kwargs: Any) -> 'InstanceDataType': """ Return a modified copy with updated data type and values. Parameters ---------- datatype : type['InstanceDataType'] Value passed via ``datatype``. **kwargs : Any Additional keyword arguments forwarded to the underlying call. Returns ------- 'InstanceDataType' Instance data object produced by this operation. """ args = self.model_dump() args.update(kwargs) return datatype(**args)
def _validate_features(features: np.ndarray) -> np.ndarray: """ Check if labels have the correct shape. Parameters ---------- features : np.ndarray Value passed via ``features``. Returns ------- np.ndarray Feature array reshaped to at least two dimensions. """ if len(features.shape) < 2: LOG.debug(f'1d features are silently converted to 2d; found shape: {features.shape}') features = np.expand_dims(features, axis=1) return features
class FeatureData(InstanceData):
    """
    Data class for feature data.

    Feature data can be any type of numeric data that is associated with the
    instances, such as measurements on a single instance or similarity scores
    between a pair of instances.

    If the object describes single instance data, the `features` attribute is
    generally 2-dimensional, with one row per instance and one or more feature
    columns. More than 2 dimensions may be used for paired data, see
    `PairedFeatureData`.

    Attributes
    ----------
    - features: an array of instance features, with one row per instance
    """

    # numeric feature array, one row per instance; a 1d input is reshaped to a
    # single column by the validator
    features: Annotated[np.ndarray, AfterValidator(_validate_features)]

    def __len__(self) -> int:
        """Return the number of instances (rows of `features`)."""
        return self.features.shape[0]
[docs] @model_validator(mode='after') def check_matching_shapes(self) -> Self: """ Validate the shape of the features and the labels are matching. Returns ------- Self This feature-data object after shape consistency checks. """ if self.labels is not None and self.labels.shape[0] != self.features.shape[0]: raise ValueError( f'dimensions of labels and features do not match; {self.labels.shape[0]} != {self.features.shape[0]}' ) if self.source_ids is not None and self.source_ids.shape[0] != self.features.shape[0]: raise ValueError( f'dimensions of source_ids and features do not match; "' f'{self.source_ids.shape[0]} != {self.features.shape[0]}' ) return self
[docs] @model_validator(mode='after') def check_features(self) -> Self: """ Validate the features. Returns ------- Self This feature-data object after numeric type validation. """ if not np.issubdtype(self.features.dtype, np.number): raise ValueError(f'features should be numeric; found: {self.features.dtype}') return self
class PairedFeatureData(FeatureData):
    """
    Data class for instance pair data.

    Each item in this data set represents instances from the "trace" source and
    from the "reference" source. The number of instances from either source must
    be at least one.

    The `features` attribute has at least 3 dimensions:

    - the pairs are along the first dimension;
    - the instances are along the second dimension (e.g. in a comparison of 1
      trace instance and 1 reference instance, the length of this dimension is 2);
    - the features are along the third dimension onward.

    The `source_ids`, if available, must have two values for each item, i.e. 2
    columns.

    Attributes
    ----------
    - n_trace_instances: the number of trace instances in each pair
    - n_ref_instances: the number of reference instances in each pair
    - features: the features of all instances in the pair, with pairs along the
      first dimension, and instances along the second
    - source_ids: the source ids of the trace and reference instances of each
      pair, a 2-dimensional array with two columns
    - features_trace: the features of the trace instances
    - features_ref: the features of the reference instances
    - source_ids_trace: the source ids of the trace instances
    - source_ids_ref: the source ids of the reference instances
    """

    n_trace_instances: int
    n_ref_instances: int

    @property
    def features_trace(self) -> np.ndarray:
        """
        Get the features of the trace instances.

        Returns
        -------
        np.ndarray
            Feature tensor slice containing trace-instance features.
        """
        return self.features[:, : self.n_trace_instances]

    @property
    def features_ref(self) -> np.ndarray:
        """
        Get the features of the reference instances.

        Returns
        -------
        np.ndarray
            Feature tensor slice containing reference-instance features.
        """
        return self.features[:, self.n_trace_instances :]  # noqa: E203

    @property
    def source_ids_trace(self) -> np.ndarray | None:
        """
        Get the source ids of the trace instances.

        Returns
        -------
        np.ndarray | None
            Trace source IDs when available, otherwise ``None``.
        """
        # bug fix: `if self.source_ids` invoked ndarray.__bool__, which raises
        # ValueError for arrays with more than one element; test against None.
        return self.source_ids[:, 0] if self.source_ids is not None else None

    @property
    def source_ids_ref(self) -> np.ndarray | None:
        """
        Get the source ids of the reference instances.

        Returns
        -------
        np.ndarray | None
            Reference source IDs when available, otherwise ``None``.
        """
        return self.source_ids[:, 1] if self.source_ids is not None else None
[docs] @model_validator(mode='after') def check_sourceid_shape(self) -> Self: """ Override the `InstanceData` implementation. Returns ------- Self This paired-feature object after source-id shape validation. """ if self.source_ids is not None and (len(self.source_ids.shape) != 2 or self.source_ids.shape[1] != 2): raise ValueError(f'source_ids should be 2-dimensional with 2 columns; found shape {self.source_ids.shape}') return self
[docs] @model_validator(mode='after') def check_features_dimensions(self) -> Self: """ Validate feature dimensions. Returns ------- Self This paired-feature object after feature-dimension validation. """ if len(self.features.shape) < 3: raise ValueError(f'features should have 3 or more dimensions; found shape: {self.features.shape}') if self.features.shape[1] != self.n_trace_instances + self.n_ref_instances: raise ValueError( f'features should have shape (*, {self.n_trace_instances}+{self.n_ref_instances}, *); ' f'found: {self.features.shape[1]}' ) return self
class LLRData(FeatureData):
    """
    Representation of calculated LLR values.

    An object of `LLRData` adds a specific interpretation to the `features`
    attribute.

    - If the `features` attribute has a single column (i.e. dimensions `(n, 1)`),
      the values are LLRs.
    - If the `features` attribute has three columns (i.e. dimensions `(n, 3)`),
      the values are LLRs and their confidence intervals.

    The values are also accessible by the attributes `llrs` and `llr_intervals`.

    Attributes
    ----------
    - llrs: 1-dimensional numpy array of LLR values
    - has_intervals: indicate whether the LLR's have intervals
    - llr_intervals: numpy array of LLR values of dimensions (n, 2), or `None`
      if the LLR's have no intervals
    - llr_upper_bound: upper bound applied to the LLRs, or `None` if no upper
      bound was applied
    - llr_lower_bound: lower bound applied to the LLRs, or `None` if no lower
      bound was applied
    """

    # global clipping bounds that were applied to the LLR values, if any
    llr_upper_bound: float | None = None
    llr_lower_bound: float | None = None

    @property
    def llrs(self) -> np.ndarray:
        """
        Return the core LLR values.

        Returns
        -------
        np.ndarray
            One-dimensional array containing the central LLR values.
        """
        if len(self.features.shape) == 1:
            return self.features
        else:
            # first column holds the LLRs; columns 1-2, if present, the intervals
            return self.features[:, 0]

    @property
    def has_intervals(self) -> bool:
        """
        Indicate whether interval bounds are present for each LLR.

        Returns
        -------
        bool
            ``True`` when lower and upper interval bounds are included.
        """
        return len(self.features.shape) == 2 and self.features.shape[1] == 3

    @property
    def llr_intervals(self) -> np.ndarray | None:
        """
        Return interval bounds for each LLR when available.

        Returns
        -------
        np.ndarray | None
            Two-column array with lower and upper LLR bounds, if available.
        """
        if self.has_intervals:
            return self.features[:, 1:]
        else:
            return None

    @property
    def llr_bounds(self) -> tuple[float | None, float | None]:
        """
        Return global lower and upper bounds applied to LLR values.

        Returns
        -------
        tuple[float | None, float | None]
            Tuple containing global lower and upper LLR clipping bounds.
        """
        return self.llr_lower_bound, self.llr_upper_bound
[docs] def feature_for_plot(self, source_key: str) -> np.ndarray | None: """ Return the feature values for a given source key, or None if not available. The return value has to be saved during the LR system execution by using the `save_features_after_step` configuration option. If the feature values for the given source key are not available, this method returns `None`. Use the `require_feature_for_plots` if you want to raise an error instead of returning `None` when the feature values are not available. Parameters ---------- source_key : str Key identifying the source of the feature values to be returned. Returns ------- np.ndarray | None Feature values for the specified source key, or ``None`` if not available. """ return self.model_extra.get(source_key) if self.model_extra is not None else None
[docs] def require_feature_for_plots(self, source_key: str) -> np.ndarray: """ Return the feature values for a given source key, raising an error if not available. If the feature values for the given source key are not available, this method raises a ValueError with an informative error message. Use the `feature_for_plot` method if you want to return `None` instead of raising an error when the feature values for the given source key are not available. Parameters ---------- source_key : str Key identifying the source of the feature values to be returned. Returns ------- np.ndarray Feature values for the specified source key. Raises ------ ValueError If the feature values for the given source key are not available. """ if self.model_extra is None or source_key not in self.model_extra: raise ValueError( f'{source_key} are not available for this instance. ' f'Add the method `save_features` with parameter `save_as: {source_key}` to your pipeline.' f'Currently available sources: {list(self.model_extra.keys()) if self.model_extra else "none"}' ) return self.model_extra[source_key]
[docs] @model_validator(mode='after') def check_features_are_llrs(self) -> Self: """ Validate the feature data. Returns ------- Self This LLR object after validating LLR-specific feature constraints. """ if len(self.features.shape) > 2: raise ValueError(f'features must have 1 or 2 dimensions; shape: {self.features.shape}') if len(self.features.shape) == 2 and self.features.shape[1] != 3 and self.features.shape[1] != 1: raise ValueError( f'features must be 1-dimensional or 2-dimensional with 1 or 3 columns; shape: {self.features.shape}' ) if self.has_intervals and ( np.all(self.features[:, 1] > self.features[:, 0]) or np.all(self.features[:, 2] < self.features[:, 0]) ): raise ValueError('LLRs should not exceed their own intervals') return self
@classmethod def _concatenate_field(cls, field: str, values: list[Any]) -> Any: """ Remove `llr_upper_bound` and `llr_lower_bound` when having different values. The fields `llr_upper_bound` and `llr_lower_bound` may have different values which is not allowed by default. Remove them instead of trying to combine them. Parameters ---------- field : str Value passed via ``field``. values : list[Any] Value passed via ``values``. Returns ------- Any Concatenated field value with LLR-bound handling for mixed values. """ match field: case 'llr_upper_bound' | 'llr_lower_bound': # Check if all values are the same; if so, preserve the value if all(v == values[0] for v in values): return values[0] # Otherwise, return None when values differ return None case _: return super()._concatenate_field(field, values)
[docs] def check_misleading_finite(self) -> None: """Check whether all values are either finite or not misleading.""" values, labels = self.llrs, self.require_labels # give error message if H1's contain zeros and H2's contain ones if np.any(np.isneginf(values[labels == 1])) and np.any(np.isposinf(values[labels == 0])): raise ValueError('invalid input: -inf found for H1 and inf found for H2') # give error message if H1's contain zeros if np.any(np.isneginf(values[labels == 1])): raise ValueError('invalid input: -inf found for H1') # give error message if H2's contain ones if np.any(np.isposinf(values[labels == 0])): raise ValueError('invalid input: inf found for H2')
# TypeVar bound to InstanceData: lets functions preserve the concrete subtype
# of their argument in the return type.
InstanceDataType = TypeVar('InstanceDataType', bound=InstanceData)
# TypeVar bound to FeatureData, analogous to InstanceDataType.
FeatureDataType = TypeVar('FeatureDataType', bound=FeatureData)
def concatenate_instances(first: InstanceDataType, *others: InstanceDataType) -> InstanceDataType:
    """
    Concatenate the results of the InstanceData objects.

    Alias for `first.concatenate(*others)`.

    Parameters
    ----------
    first : InstanceDataType
        Object whose `concatenate` method is invoked.
    *others : InstanceDataType
        Further objects whose instances are appended to `first`.

    Returns
    -------
    InstanceDataType
        Concatenation of all given objects, of the same type as `first`.
    """
    return first.concatenate(*others)
class DataProvider(ABC):
    """
    Base class for data providers.

    Each data provider should provide access to instance data by implementing
    the `get_instances()` method.
    """

    @abstractmethod
    def get_instances(self) -> InstanceData:
        """
        Return an InstanceData object, containing data for a set of instances.

        Returns
        -------
        InstanceData
            Instance data for a set of instances.
        """
        raise NotImplementedError
class DataStrategy(ABC):
    """Base class for data (splitting) strategies."""

    @abstractmethod
    def apply[DataType: InstanceData](self, instances: DataType) -> Iterable[tuple[DataType, DataType]]:
        """
        Provide iterator to access training and test set.

        Returns an iterator over tuples of a training set and a test set. Both
        the training set and the test set are represented by an `InstanceData`
        object.

        Parameters
        ----------
        instances : DataType
            Instances to be split into training and test sets.

        Returns
        -------
        Iterable[tuple[DataType, DataType]]
            Iterable of ``(train_set, test_set)`` splits for the provided data.
        """
        raise NotImplementedError
[docs] def get_instances_by_category[InstanceDataType: InstanceData]( instances: InstanceDataType, category_field: str, category_shape: tuple[int] | None = None ) -> Iterator[tuple[np.ndarray, InstanceDataType]]: """ Return subsets of a set of instances by category. The `instances` object must have a field by the name of `category_field`. That field is a numpy array with one row per instance. Its values are the categories of each instance. The field may have any shape, as long as the number of rows matches the number of instances. If `category_shape` is provided, the shape of the category field is checked against this value. The returned value is an iterator with each item being a tuple of the category and the subset of instances of that category. Parameters ---------- instances : InstanceDataType Input instances to be processed by this method. category_field : str Value passed via ``category_field``. category_shape : tuple[int] | None Value passed via ``category_shape``. """ # extract the category values from the instances if not hasattr(instances, category_field): raise ValueError(f'missing field: {category_field}') category_values = getattr(instances, category_field) # check the category values for sanity check_type(np.ndarray, category_values, 'categories must be a numpy array') if category_values.shape[0] != len(instances): raise ValueError( f'number of categories does not equal number of instances: {category_values.shape[0]} != {len(instances)}' ) # check for shape, if available if category_shape is not None: expected_category_shape = (len(instances),) + category_shape if category_values.shape != expected_category_shape: raise ValueError( f'expected shape of category field {category_field}: {expected_category_shape}; ' f'found: {category_values.shape}' ) # each unique value is a category unique_values = np.unique(category_values, axis=0) # return the subset of instances for each category separately for value in unique_values: current_category_rows = 
np.all(category_values == value, axis=tuple(range(1, category_values.ndim))) yield value, instances[current_category_rows]