"""Definitions for monitoring training metrics and determining whether metrics
achieve some condition (most commonly useful for determining when to stop
training).
"""
from typing import Protocol
from .data import TrainStatKey, TrainStats
class Condition(Protocol):
    """Structural interface for early-stopping checks.

    Any callable that accepts the current ``TrainStats`` as its single
    positional argument and returns a ``bool`` satisfies this protocol.
    Returning ``True`` means the condition is met, i.e. training should stop.
    """

    def __call__(self, train_stats: TrainStats, /) -> bool:
        """Evaluate the condition against the latest training statistics.

        Args:
            train_stats: Statistics collected so far during training.

        Returns:
            ``True`` when the condition is met and training within the
            current iterations should be forced to stop.
        """
        ...
class And:
    """Join the results of several conditions with a logical ``AND``.

    Every wrapped condition is invoked on each call; evaluation does not stop
    at the first ``False``. This matters for stateful conditions (such as the
    patience-counting conditions in this module), whose counters are only
    updated when they are called.

    Args:
        conditions: Conditions to join results for with an ``AND``.
    """

    #: Conditions to join results for with an ``AND``.
    conditions: list[Condition]

    def __init__(self, conditions: list[Condition], /) -> None:
        self.conditions = conditions

    def __call__(self, train_stats: TrainStats, /) -> bool:
        # Build the full list first: handing ``all`` a generator would stop
        # calling conditions as soon as one of them returned ``False``.
        outcomes = [condition(train_stats) for condition in self.conditions]
        return all(outcomes)
class HitsLowerBound:
    """Condition met once a monitored train stat falls to a lower bound.

    Args:
        key: Key of train stat to monitor.
        lower_bound: Minimum threshold for the value of ``key`` to reach before
            this condition returns ``True`` when called.
    """

    #: Key of train stat to inspect when called.
    key: TrainStatKey
    #: Threshold at or below which the value of ``key`` makes this
    #: condition return ``True`` when called.
    lower_bound: float

    def __init__(self, key: TrainStatKey, lower_bound: float, /) -> None:
        self.key, self.lower_bound = key, lower_bound

    def __call__(self, train_stats: TrainStats, /) -> bool:
        # Inclusive comparison: landing exactly on the bound already counts.
        current = train_stats[self.key]
        return current <= self.lower_bound
class HitsUpperBound:
    """Condition met once a monitored train stat rises to an upper bound.

    Args:
        key: Key of train stat to monitor.
        upper_bound: Maximum threshold for the value of ``key`` to reach before
            this condition returns ``True`` when called.
    """

    #: Key of train stat to inspect when called.
    key: TrainStatKey
    #: Threshold at or above which the value of ``key`` makes this
    #: condition return ``True`` when called.
    upper_bound: float

    def __init__(self, key: TrainStatKey, upper_bound: float, /) -> None:
        self.key, self.upper_bound = key, upper_bound

    def __call__(self, train_stats: TrainStats, /) -> bool:
        # Inclusive comparison: landing exactly on the bound already counts.
        current = train_stats[self.key]
        return current >= self.upper_bound
class Plateaus:
    """Condition that returns ``True`` if the value being monitored plateaus
    for ``patience`` consecutive calls.

    A call counts as a plateau step when the new value lies within
    ``atol + rtol * abs(previous value)`` of the previous value. Any larger
    change resets the count. The very first call never counts, because there
    is no previous value to compare against yet.

    Args:
        key: Key of train stat to monitor.
        patience: Threshold for :attr:`Plateaus.losses` to reach for the condition
            to return ``True``.
        rtol: Relative tolerance when comparing values of :attr:`Plateaus.key`
            between calls to determine if the call contributes to
            :attr:`Plateaus.losses`.
        atol: Absolute tolerance added to the relative tolerance. Defaults to
            ``0.0`` (a purely relative comparison). Set it when the monitored
            value can plateau at or near zero, where a purely relative
            tolerance degenerates to exact equality.
    """

    #: Key of train stat to inspect when called.
    key: TrainStatKey
    #: Number of times the value of :attr:`Plateaus.key` has been within
    #: tolerance of the previous value in a row. If this reaches
    #: :attr:`Plateaus.patience`, then the condition is met and
    #: this condition returns ``True``.
    losses: int
    #: Last value of :attr:`Plateaus.key` (NaN before the first call).
    old_value: float
    #: Threshold for :attr:`Plateaus.losses` to reach for the condition
    #: to return ``True``.
    patience: int
    #: Relative tolerance when comparing values of :attr:`Plateaus.key`
    #: between calls to determine if the call contributes to
    #: :attr:`Plateaus.losses`.
    rtol: float
    #: Absolute tolerance added to ``rtol * abs(old_value)``.
    atol: float

    def __init__(
        self,
        key: TrainStatKey,
        /,
        *,
        patience: int = 5,
        rtol: float = 1e-3,
        atol: float = 0.0,
    ) -> None:
        self.key = key
        self.patience = patience
        self.rtol = rtol
        self.atol = atol
        self.losses = 0
        # NaN never satisfies ``<=``, so the first observed value can never be
        # counted as a plateau step. A baseline of 0 would make a first reading
        # of exactly 0 count as one.
        self.old_value = float("nan")

    def __call__(self, train_stats: TrainStats, /) -> bool:
        new_value = train_stats[self.key]
        tolerance = self.atol + self.rtol * abs(self.old_value)
        if abs(new_value - self.old_value) <= tolerance:
            self.losses += 1
        else:
            self.losses = 0
        self.old_value = new_value
        return self.losses >= self.patience
class StopsDecreasing:
    """Condition that returns ``True`` once the monitored value has failed to
    set a new minimum for ``patience`` consecutive calls.

    A call improves only when its value is strictly below the smallest value
    seen so far. Anything else counts towards :attr:`StopsDecreasing.losses`:
    an equal or larger value, or a NaN that cannot be ordered.

    Args:
        key: Key of train stat to monitor.
        patience: Threshold for :attr:`StopsDecreasing.losses` to reach for the condition
            to return ``True``.
    """

    #: Key of train stat to inspect when called.
    key: TrainStatKey
    #: Number of consecutive calls in which the value of
    #: :attr:`StopsDecreasing.key` has not gone below
    #: :attr:`StopsDecreasing.min_`. If this reaches
    #: :attr:`StopsDecreasing.patience`, then the condition is met and this
    #: condition returns ``True``.
    losses: int
    #: Smallest value of :attr:`StopsDecreasing.key` seen so far
    #: (``inf`` before the first call).
    min_: float
    #: Threshold for :attr:`StopsDecreasing.losses` to reach for the condition
    #: to return ``True``.
    patience: int

    def __init__(self, key: TrainStatKey, /, *, patience: int = 5) -> None:
        self.key = key
        self.patience = patience
        self.losses = 0
        self.min_ = float("inf")

    def __call__(self, train_stats: TrainStats, /) -> bool:
        new_value = train_stats[self.key]
        # Test for strict improvement instead of ``new_value >= self.min_``.
        # NaN fails every comparison, so the inverted test would treat a NaN as
        # an improvement, reset the counter and overwrite the minimum with NaN.
        if new_value < self.min_:
            self.losses = 0
            self.min_ = new_value
        else:
            self.losses += 1
        return self.losses >= self.patience
class StopsIncreasing:
    """Condition that returns ``True`` once the monitored value has failed to
    set a new maximum for ``patience`` consecutive calls.

    A call improves only when its value is strictly above the largest value
    seen so far. Anything else counts towards :attr:`StopsIncreasing.losses`:
    an equal or smaller value, or a NaN that cannot be ordered.

    Args:
        key: Key of train stat to monitor.
        patience: Threshold for :attr:`StopsIncreasing.losses` to reach for the condition
            to return ``True``.
    """

    #: Key of train stat to inspect when called.
    key: TrainStatKey
    #: Number of consecutive calls in which the value of
    #: :attr:`StopsIncreasing.key` has not gone above
    #: :attr:`StopsIncreasing.max_`. If this reaches
    #: :attr:`StopsIncreasing.patience`, then the condition is met and this
    #: condition returns ``True``.
    losses: int
    #: Largest value of :attr:`StopsIncreasing.key` seen so far
    #: (``-inf`` before the first call).
    max_: float
    #: Threshold for :attr:`StopsIncreasing.losses` to reach for the condition
    #: to return ``True``.
    patience: int

    def __init__(self, key: TrainStatKey, /, *, patience: int = 5) -> None:
        self.key = key
        self.patience = patience
        self.losses = 0
        self.max_ = float("-inf")

    def __call__(self, train_stats: TrainStats, /) -> bool:
        new_value = train_stats[self.key]
        # Test for strict improvement instead of ``new_value <= self.max_``.
        # NaN fails every comparison, so the inverted test would treat a NaN as
        # an improvement, reset the counter and overwrite the maximum with NaN.
        if new_value > self.max_:
            self.losses = 0
            self.max_ = new_value
        else:
            self.losses += 1
        return self.losses >= self.patience