timesead.optim.trainer
Module Contents
- timesead.optim.trainer.SUPPORTED_HOOKS = ['post_validation']
- class timesead.optim.trainer.Trainer(train_iter: torch.utils.data.DataLoader, val_iter: torch.utils.data.DataLoader, optimizer: Callable = torch.optim.Adam, scheduler: Callable = torch.optim.lr_scheduler.MultiStepLR, device: str | torch.device = 'cpu', checkpoints: bool = False, out_dir: str | None = None, batch_dimension: int = 0)
- Parameters:
train_iter (torch.utils.data.DataLoader)
val_iter (torch.utils.data.DataLoader)
optimizer (Callable)
scheduler (Callable)
device (Union[str, torch.device])
checkpoints (bool)
out_dir (Optional[str])
batch_dimension (int)
- opt
- sched
- train_iter
- val_iter
- device = 'cpu'
- batch_dimension = 0
- hooks
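The `hooks` attribute and `SUPPORTED_HOOKS` suggest that the trainer groups callbacks by event name, and `'post_validation'` is the only supported event. The following is a hypothetical plain-Python model of that bookkeeping. `HookRegistry`, `register` and `fire` are illustrative names, not TimeSeAD API. The rule that a `True` return requests a stop is an assumption, based on the `bool` return type of the hook `__call__` methods documented in this module.

```python
from collections import defaultdict
from typing import Callable, DefaultDict, List

# Copied from the documented module attribute.
SUPPORTED_HOOKS = ['post_validation']


class HookRegistry:
    """Hypothetical stand-in for the Trainer's `hooks` attribute (not the TimeSeAD class)."""

    def __init__(self) -> None:
        # event name -> list of callables, fired in registration order
        self.hooks: DefaultDict[str, List[Callable[..., bool]]] = defaultdict(list)

    def register(self, hook: Callable[..., bool], event: str = 'post_validation') -> None:
        # Reject events the trainer does not know how to fire.
        if event not in SUPPORTED_HOOKS:
            raise ValueError(f'unsupported hook event {event!r}; expected one of {SUPPORTED_HOOKS}')
        self.hooks[event].append(hook)

    def fire(self, event: str, **kwargs) -> bool:
        # Call every hook; assumed convention: any True return asks training to stop.
        stop = False
        for hook in self.hooks[event]:
            stop = bool(hook(**kwargs)) or stop
        return stop


registry = HookRegistry()
registry.register(lambda epoch, **_: epoch >= 3)  # e.g. request a stop from epoch 3 on
```

Every hook is called even after one has requested a stop, so side effects such as checkpointing still run in that epoch.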
- validate_batch(network: torch.nn.Module, val_metrics: Dict[str, Callable], b_inputs: Tuple[torch.Tensor, Ellipsis], b_targets: Tuple[torch.Tensor, Ellipsis], *args, **kwargs) -> Dict[str, float]
- Parameters:
network (torch.nn.Module)
val_metrics (Dict[str, Callable])
b_inputs (Tuple[torch.Tensor, Ellipsis])
b_targets (Tuple[torch.Tensor, Ellipsis])
- Return type:
Dict[str, float]
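`validate_batch` takes a dictionary of named metric callables and returns a dictionary of floats with the same keys. The following is a minimal plain-Python sketch of that mapping. It assumes that the network is called on the input tuple as-is and that each metric is called as `metric(outputs, targets)`. Both conventions are assumptions, and `validate_batch_sketch` is a hypothetical name. A real implementation would work on tensors and typically run under `torch.no_grad()`.

```python
from typing import Any, Callable, Dict, Sequence


def validate_batch_sketch(network: Callable[[Sequence[Any]], Any],
                          val_metrics: Dict[str, Callable[[Any, Any], float]],
                          b_inputs: Sequence[Any],
                          b_targets: Sequence[Any]) -> Dict[str, float]:
    # Forward pass on the whole input tuple (assumed calling convention).
    outputs = network(b_inputs)
    # Evaluate every named metric on the same predictions; keys are preserved.
    return {name: float(metric(outputs, b_targets)) for name, metric in val_metrics.items()}


# Usage: a toy "network" that doubles its first input, scored with mean squared error.
double = lambda inputs: [2 * x for x in inputs[0]]
mse = lambda out, tgt: sum((o - t) ** 2 for o, t in zip(out, tgt[0])) / len(out)
result = validate_batch_sketch(double, {'mse': mse}, ([1.0, 2.0],), ([2.0, 5.0],))
```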
- validate_model_once(network: torch.nn.Module, val_metrics: Dict[str, Callable], *args, print_progress: bool = False, **kwargs) -> Dict[str, Any]
- Parameters:
network (torch.nn.Module)
val_metrics (Dict[str, Callable])
print_progress (bool)
- Return type:
Dict[str, Any]
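`validate_model_once` produces one metrics dictionary for the whole validation set, so the per-batch results of `validate_batch` have to be reduced somehow. The following sketch assumes a mean weighted by batch size, which gives the exact dataset mean for per-sample-averaged metrics. The reduction TimeSeAD actually uses is not stated here, and `average_batch_metrics` is a hypothetical helper.

```python
from typing import Dict, Iterable, Tuple


def average_batch_metrics(batch_results: Iterable[Tuple[Dict[str, float], int]]) -> Dict[str, float]:
    """Combine (metrics, batch_size) pairs into dataset-level metrics (batch-size-weighted mean)."""
    totals: Dict[str, float] = {}
    n_samples = 0
    for metrics, batch_size in batch_results:
        for name, value in metrics.items():
            # Undo the per-batch averaging so larger batches count proportionally more.
            totals[name] = totals.get(name, 0.0) + value * batch_size
        n_samples += batch_size
    if n_samples == 0:
        return {}
    return {name: total / n_samples for name, total in totals.items()}


# Usage: two batches of sizes 2 and 1.
epoch_metrics = average_batch_metrics([({'loss_0': 1.0}, 2), ({'loss_0': 4.0}, 1)])
```

Here `epoch_metrics` evaluates to `{'loss_0': 2.0}`. A plain mean over batches would give 2.5, which overweights the smaller final batch.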
- train_batch(network: torch.nn.Module, losses: List[timesead.optim.loss.Loss], optimizers: List[torch.optim.Optimizer], epoch: int, num_epochs: int, b_inputs: Tuple[torch.Tensor, Ellipsis], b_targets: Tuple[torch.Tensor, Ellipsis]) -> List[float]
- Parameters:
network (torch.nn.Module)
losses (List[timesead.optim.loss.Loss])
optimizers (List[torch.optim.Optimizer])
epoch (int)
num_epochs (int)
b_inputs (Tuple[torch.Tensor, Ellipsis])
b_targets (Tuple[torch.Tensor, Ellipsis])
- Return type:
List[float]
- train_epoch(network: torch.nn.Module, losses: List[timesead.optim.loss.Loss], optimizers: List[torch.optim.Optimizer], schedulers: List[torch.optim.lr_scheduler._LRScheduler], epoch: int, num_epochs: int, val_metrics: Dict[str, Callable], log_fn: Callable[[str, Any], None] = default_log_fn) -> bool
- Parameters:
network (torch.nn.Module)
losses (List[timesead.optim.loss.Loss])
optimizers (List[torch.optim.Optimizer])
schedulers (List[torch.optim.lr_scheduler._LRScheduler])
epoch (int)
num_epochs (int)
val_metrics (Dict[str, Callable])
log_fn (Callable[[str, Any], None])
- Return type:
bool
- train(network: torch.nn.Module, losses: List[timesead.optim.loss.Loss] | timesead.optim.loss.Loss, num_epochs: int, val_metrics: Dict[str, Callable] = None, start_epoch: int = 0, log_fn: Callable[[str, Any], None] = default_log_fn)
- Parameters:
network (torch.nn.Module)
losses (Union[List[timesead.optim.loss.Loss], timesead.optim.loss.Loss])
num_epochs (int)
val_metrics (Dict[str, Callable])
start_epoch (int)
log_fn (Callable[[str, Any], None])
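`train` accepts either a single `Loss` or a list of them, together with `start_epoch` and `num_epochs`, and delegates each epoch to `train_epoch`, which returns a `bool`. The following sketch shows one plausible shape of that outer loop in plain Python. It makes three assumptions: `num_epochs` is the total epoch count (so a resumed run covers `range(start_epoch, num_epochs)`), a `True` return from `train_epoch` means "stop early", and `as_list` / `run_epochs` are hypothetical helpers.

```python
from typing import Callable, List, Sequence, TypeVar, Union

T = TypeVar('T')


def as_list(losses: Union[T, Sequence[T]]) -> List[T]:
    # Normalise the Union[List[Loss], Loss] argument to a list.
    if isinstance(losses, (list, tuple)):
        return list(losses)
    return [losses]


def run_epochs(train_epoch: Callable[[int, int], bool], num_epochs: int, start_epoch: int = 0) -> int:
    """Call train_epoch(epoch, num_epochs) per epoch; stop when it returns True. Returns epochs run."""
    epochs_run = 0
    for epoch in range(start_epoch, num_epochs):
        epochs_run += 1
        if train_epoch(epoch, num_epochs):
            break  # assumed meaning: a hook such as early stopping asked to stop
    return epochs_run
```

With this convention, resuming from a checkpoint only requires passing the saved epoch as `start_epoch` while keeping `num_epochs` unchanged.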
- timesead.optim.trainer.open_file_in_dir(name, out_dir, mode='w+b')
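Judging from its name and the binary default mode `'w+b'`, `open_file_in_dir` opens a file named `name` inside `out_dir`, for example to write a checkpoint. The following standard-library sketch shows that behaviour under the assumption that missing directories are created first. This assumption is not confirmed by the listing, and `open_file_in_dir_sketch` is a hypothetical stand-in.

```python
import os
from typing import IO, Any


def open_file_in_dir_sketch(name: str, out_dir: str, mode: str = 'w+b') -> IO[Any]:
    # Create the output directory (and any parents) if it does not exist yet.
    os.makedirs(out_dir, exist_ok=True)
    # 'w+b' creates or truncates the file and allows reading it back in binary mode.
    return open(os.path.join(out_dir, name), mode)
```

The returned object is an ordinary file handle, so it should be used in a `with` statement to make sure it is closed.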
- class timesead.optim.trainer.CheckpointHook(out_dir: str | None = None, checkpoint_interval: int = 10, file_write_fn: Callable | None = None, file_read_fn: Callable | None = None)
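- Parameters:
out_dir (Optional[str])
checkpoint_interval (int)
file_write_fn (Optional[Callable])
file_read_fn (Optional[Callable])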
- checkpoint_interval = 10
- file_write_fn = None
- file_read_fn = None
- save_model(filename: str, network: torch.nn.Module, optimizers: List[torch.optim.Optimizer])
- Parameters:
filename (str)
network (torch.nn.Module)
optimizers (List[torch.optim.Optimizer])
- __call__(trainer: Trainer, network: torch.nn.Module, optimizers: List[torch.optim.Optimizer], epoch: int, val_metrics: Dict[str, float]) -> bool
- Parameters:
trainer (Trainer)
network (torch.nn.Module)
optimizers (List[torch.optim.Optimizer])
epoch (int)
val_metrics (Dict[str, float])
- Return type:
bool
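`CheckpointHook` has the hook signature `(trainer, network, optimizers, epoch, val_metrics) -> bool` and a `checkpoint_interval` that defaults to 10. The following sketch models only the interval logic. It assumes 0-based epochs, so a save happens after every `checkpoint_interval`-th completed epoch, and it assumes the hook always returns `False` because checkpointing should not stop training. `IntervalCheckpointSketch`, `save_fn` and the `checkpoint_<n>.pth` file name are hypothetical. The real hook saves through `save_model` and `file_write_fn`.

```python
from typing import Any, Callable, Dict, List, Optional


class IntervalCheckpointSketch:
    """Hypothetical hook: calls save_fn every `checkpoint_interval` epochs."""

    def __init__(self, checkpoint_interval: int = 10,
                 save_fn: Optional[Callable[[str], None]] = None) -> None:
        self.checkpoint_interval = checkpoint_interval
        # Default to a no-op so the sketch runs without any storage backend.
        self.save_fn = save_fn if save_fn is not None else (lambda filename: None)

    def __call__(self, trainer: Any, network: Any, optimizers: List[Any],
                 epoch: int, val_metrics: Dict[str, float]) -> bool:
        # 0-based epochs: with interval 10 this saves after epochs 9, 19, 29, ...
        if self.checkpoint_interval > 0 and (epoch + 1) % self.checkpoint_interval == 0:
            self.save_fn(f'checkpoint_{epoch + 1}.pth')
        # Checkpointing never requests early termination.
        return False
```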
- class timesead.optim.trainer.EarlyStoppingHook(metric: str = 'loss_0', invert_metric: bool = True, patience: int = 10, epsilon: float = 0)
- Parameters:
metric (str)
invert_metric (bool)
patience (int)
epsilon (float)
- metric = 'loss_0'
- invert_metric = True
- patience = 10
- epsilon = 0
- best_epoch = 0
- decrease_counter = 0
- save_best_model(trainer: Trainer, network: torch.nn.Module, optimizers: List[torch.optim.Optimizer])
- Parameters:
trainer (Trainer)
network (torch.nn.Module)
optimizers (List[torch.optim.Optimizer])
- load_best_model(trainer: Trainer, network: torch.nn.Module, total_epochs: int)
- Parameters:
trainer (Trainer)
network (torch.nn.Module)
total_epochs (int)
- __call__(trainer: Trainer, network: torch.nn.Module, optimizers: List[torch.optim.Optimizer], epoch: int, val_metrics: Dict[str, float]) -> bool
- Parameters:
trainer (Trainer)
network (torch.nn.Module)
optimizers (List[torch.optim.Optimizer])
epoch (int)
val_metrics (Dict[str, float])
- Return type:
bool
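The parameters and attributes of `EarlyStoppingHook` (`metric`, `invert_metric`, `patience`, `epsilon`, `best_epoch`, `decrease_counter`) match standard patience-based early stopping. The following sketch covers only the decision logic, under these assumptions:

- `invert_metric=True` means lower is better, as for the default `'loss_0'`.
- An improvement must exceed `epsilon`.
- `True` is returned once `patience` consecutive epochs pass without improvement.

`EarlyStoppingSketch` is hypothetical. It drops the `trainer`, `network` and `optimizers` arguments, which the real hook presumably uses with `save_best_model` and `load_best_model`.

```python
from typing import Dict, Optional


class EarlyStoppingSketch:
    """Hypothetical patience-based early stopping on one validation metric."""

    def __init__(self, metric: str = 'loss_0', invert_metric: bool = True,
                 patience: int = 10, epsilon: float = 0) -> None:
        self.metric = metric
        self.invert_metric = invert_metric
        self.patience = patience
        self.epsilon = epsilon
        self.best_score: Optional[float] = None
        self.best_epoch = 0
        self.decrease_counter = 0

    def __call__(self, epoch: int, val_metrics: Dict[str, float]) -> bool:
        score = val_metrics[self.metric]
        # invert_metric=True: lower is better (e.g. a loss), so flip the sign.
        if self.invert_metric:
            score = -score
        if self.best_score is None or score > self.best_score + self.epsilon:
            # Improvement by more than epsilon: remember this epoch, reset the counter.
            self.best_score = score
            self.best_epoch = epoch
            self.decrease_counter = 0
        else:
            self.decrease_counter += 1
        # Request a stop after `patience` consecutive epochs without improvement.
        return self.decrease_counter >= self.patience
```

For example, with `patience=2` and validation losses `1.0, 0.8, 0.85, 0.9`, the sketch records epoch 1 as the best epoch and requests a stop after epoch 3.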