timesead.optim.trainer

Attributes

SUPPORTED_HOOKS

Classes

default_log_fn

Trainer

CheckpointHook

EarlyStoppingHook

Functions

open_file_in_dir(name, out_dir[, mode])

Module Contents

timesead.optim.trainer.SUPPORTED_HOOKS = ['post_validation']
class timesead.optim.trainer.default_log_fn
history
__call__(metric_name: str, metric_value: Any)
Parameters:
  • metric_name (str)

  • metric_value (Any)

class timesead.optim.trainer.Trainer(train_iter: torch.utils.data.DataLoader, val_iter: torch.utils.data.DataLoader, optimizer: Callable = torch.optim.Adam, scheduler: Callable = torch.optim.lr_scheduler.MultiStepLR, device: str | torch.device = 'cpu', checkpoints: bool = False, out_dir: str | None = None, batch_dimension: int = 0)
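Parameters:
  • train_iter (torch.utils.data.DataLoader)

  • val_iter (torch.utils.data.DataLoader)

  • optimizer (Callable)

  • scheduler (Callable)

  • device (Union[str, torch.device])

  • checkpoints (bool)

  • out_dir (Optional[str])

  • batch_dimension (int)
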
opt
sched
train_iter
val_iter
device = 'cpu'
batch_dimension = 0
hooks
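validate_batch(network: torch.nn.Module, val_metrics: Dict[str, Callable], b_inputs: Tuple[torch.Tensor, ...], b_targets: Tuple[torch.Tensor, ...], *args, **kwargs) → Dict[str, float]
Parameters:
  • network (torch.nn.Module)

  • val_metrics (Dict[str, Callable])

  • b_inputs (Tuple[torch.Tensor, ...])

  • b_targets (Tuple[torch.Tensor, ...])
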
Return type:

Dict[str, float]

validate_model_once(network: torch.nn.Module, val_metrics: Dict[str, Callable], *args, print_progress: bool = False, **kwargs) → Dict[str, Any]
Parameters:
  • network (torch.nn.Module)

  • val_metrics (Dict[str, Callable])

  • print_progress (bool)

Return type:

Dict[str, Any]

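train_batch(network: torch.nn.Module, losses: List[timesead.optim.loss.Loss], optimizers: List[torch.optim.Optimizer], epoch: int, num_epochs: int, b_inputs: Tuple[torch.Tensor, ...], b_targets: Tuple[torch.Tensor, ...]) → List[float]
Parameters:
  • network (torch.nn.Module)

  • losses (List[timesead.optim.loss.Loss])

  • optimizers (List[torch.optim.Optimizer])

  • epoch (int)

  • num_epochs (int)

  • b_inputs (Tuple[torch.Tensor, ...])

  • b_targets (Tuple[torch.Tensor, ...])
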
Return type:

List[float]

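train_epoch(network: torch.nn.Module, losses: List[timesead.optim.loss.Loss], optimizers: List[torch.optim.Optimizer], schedulers: List[torch.optim.lr_scheduler._LRScheduler], epoch: int, num_epochs: int, val_metrics: Dict[str, Callable], log_fn: Callable[[str, Any], None] = default_log_fn) → bool
Parameters:
  • network (torch.nn.Module)

  • losses (List[timesead.optim.loss.Loss])

  • optimizers (List[torch.optim.Optimizer])

  • schedulers (List[torch.optim.lr_scheduler._LRScheduler])

  • epoch (int)

  • num_epochs (int)

  • val_metrics (Dict[str, Callable])

  • log_fn (Callable[[str, Any], None])
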
Return type:

bool

train(network: torch.nn.Module, losses: List[timesead.optim.loss.Loss] | timesead.optim.loss.Loss, num_epochs: int, val_metrics: Dict[str, Callable] = None, start_epoch: int = 0, log_fn: Callable[[str, Any], None] = default_log_fn)
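Parameters:
  • network (torch.nn.Module)

  • losses (Union[List[timesead.optim.loss.Loss], timesead.optim.loss.Loss])

  • num_epochs (int)

  • val_metrics (Dict[str, Callable])

  • start_epoch (int)

  • log_fn (Callable[[str, Any], None])
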
add_hook(hook: Callable, type: str = 'post_validation')
Parameters:
  • hook (Callable)

  • type (str)

timesead.optim.trainer.open_file_in_dir(name, out_dir, mode='w+b')
class timesead.optim.trainer.CheckpointHook(out_dir: str | None = None, checkpoint_interval: int = 10, file_write_fn: Callable | None = None, file_read_fn: Callable | None = None)
Parameters:
  • out_dir (Optional[str])

  • checkpoint_interval (int)

  • file_write_fn (Optional[Callable])

  • file_read_fn (Optional[Callable])

checkpoint_interval = 10
file_write_fn = None
file_read_fn = None
save_model(filename: str, network: torch.nn.Module, optimizers: List[torch.optim.Optimizer])
Parameters:
  • filename (str)

  • network (torch.nn.Module)

  • optimizers (List[torch.optim.Optimizer])

load_model_state(filename: str) → Dict
Parameters:

filename (str)

Return type:

Dict

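__call__(trainer: Trainer, network: torch.nn.Module, optimizers: List[torch.optim.Optimizer], epoch: int, val_metrics: Dict[str, float]) → bool
Parameters:
  • trainer (Trainer)

  • network (torch.nn.Module)

  • optimizers (List[torch.optim.Optimizer])

  • epoch (int)

  • val_metrics (Dict[str, float])
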
Return type:

bool

class timesead.optim.trainer.EarlyStoppingHook(metric: str = 'loss_0', invert_metric: bool = True, patience: int = 10, epsilon: float = 0)
Parameters:
  • metric (str)

  • invert_metric (bool)

  • patience (int)

  • epsilon (float)

metric = 'loss_0'
invert_metric = True
patience = 10
epsilon = 0
best_epoch = 0
decrease_counter = 0
property best_metric: float
Return type:

float

save_best_model(trainer: Trainer, network: torch.nn.Module, optimizers: List[torch.optim.Optimizer])
Parameters:
  • trainer (Trainer)

  • network (torch.nn.Module)

  • optimizers (List[torch.optim.Optimizer])

load_best_model(trainer: Trainer, network: torch.nn.Module, total_epochs: int)
Parameters:
  • trainer (Trainer)

  • network (torch.nn.Module)

  • total_epochs (int)

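__call__(trainer: Trainer, network: torch.nn.Module, optimizers: List[torch.optim.Optimizer], epoch: int, val_metrics: Dict[str, float]) → bool
Parameters:
  • trainer (Trainer)

  • network (torch.nn.Module)

  • optimizers (List[torch.optim.Optimizer])

  • epoch (int)

  • val_metrics (Dict[str, float])
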
Return type:

bool

get_best_epoch(_print: bool = False) → int
Parameters:

_print (bool)

Return type:

int