timesead.optim.trainer
======================

.. py:module:: timesead.optim.trainer


Attributes
----------

.. autoapisummary::

   timesead.optim.trainer.SUPPORTED_HOOKS


Classes
-------

.. autoapisummary::

   timesead.optim.trainer.default_log_fn
   timesead.optim.trainer.Trainer
   timesead.optim.trainer.CheckpointHook
   timesead.optim.trainer.EarlyStoppingHook


Functions
---------

.. autoapisummary::

   timesead.optim.trainer.open_file_in_dir


Module Contents
---------------

.. py:data:: SUPPORTED_HOOKS
   :value: ['post_validation']


.. py:class:: default_log_fn

   .. py:attribute:: history


   .. py:method:: __call__(metric_name: str, metric_value: Any)


.. py:class:: Trainer(train_iter: torch.utils.data.DataLoader, val_iter: torch.utils.data.DataLoader, optimizer: Callable = torch.optim.Adam, scheduler: Callable = torch.optim.lr_scheduler.MultiStepLR, device: Union[str, torch.device] = 'cpu', checkpoints: bool = False, out_dir: Optional[str] = None, batch_dimension: int = 0)

   .. py:attribute:: opt


   .. py:attribute:: sched


   .. py:attribute:: train_iter


   .. py:attribute:: val_iter


   .. py:attribute:: device
      :value: 'cpu'


   .. py:attribute:: batch_dimension
      :value: 0


   .. py:attribute:: hooks


   .. py:method:: validate_batch(network: torch.nn.Module, val_metrics: Dict[str, Callable], b_inputs: Tuple[torch.Tensor, Ellipsis], b_targets: Tuple[torch.Tensor, Ellipsis], *args, **kwargs) -> Dict[str, float]


   .. py:method:: validate_model_once(network: torch.nn.Module, val_metrics: Dict[str, Callable], *args, print_progress: bool = False, **kwargs) -> Dict[str, Any]


   .. py:method:: train_batch(network: torch.nn.Module, losses: List[timesead.optim.loss.Loss], optimizers: List[torch.optim.Optimizer], epoch: int, num_epochs: int, b_inputs: Tuple[torch.Tensor, Ellipsis], b_targets: Tuple[torch.Tensor, Ellipsis]) -> List[float]


   .. py:method:: train_epoch(network: torch.nn.Module, losses: List[timesead.optim.loss.Loss], optimizers: List[torch.optim.Optimizer], schedulers: List[torch.optim.lr_scheduler._LRScheduler], epoch: int, num_epochs: int, val_metrics: Dict[str, Callable], log_fn: Callable[[str, Any], None] = default_log_fn) -> bool


   .. py:method:: train(network: torch.nn.Module, losses: Union[List[timesead.optim.loss.Loss], timesead.optim.loss.Loss], num_epochs: int, val_metrics: Dict[str, Callable] = None, start_epoch: int = 0, log_fn: Callable[[str, Any], None] = default_log_fn)


   .. py:method:: add_hook(hook: Callable, type: str = 'post_validation')


.. py:function:: open_file_in_dir(name, out_dir, mode='w+b')


.. py:class:: CheckpointHook(out_dir: Optional[str] = None, checkpoint_interval: int = 10, file_write_fn: Optional[Callable] = None, file_read_fn: Optional[Callable] = None)

   .. py:attribute:: checkpoint_interval
      :value: 10


   .. py:attribute:: file_write_fn
      :value: None


   .. py:attribute:: file_read_fn
      :value: None


   .. py:method:: save_model(filename: str, network: torch.nn.Module, optimizers: List[torch.optim.Optimizer])


   .. py:method:: load_model_state(filename: str) -> Dict


   .. py:method:: __call__(trainer: Trainer, network: torch.nn.Module, optimizers: List[torch.optim.Optimizer], epoch: int, val_metrics: Dict[str, float]) -> bool


.. py:class:: EarlyStoppingHook(metric: str = 'loss_0', invert_metric: bool = True, patience: int = 10, epsilon: float = 0)

   .. py:attribute:: metric
      :value: 'loss_0'


   .. py:attribute:: invert_metric
      :value: True


   .. py:attribute:: patience
      :value: 10


   .. py:attribute:: epsilon
      :value: 0


   .. py:attribute:: best_epoch
      :value: 0


   .. py:attribute:: decrease_counter
      :value: 0


   .. py:property:: best_metric
      :type: float


   .. py:method:: save_best_model(trainer: Trainer, network: torch.nn.Module, optimizers: List[torch.optim.Optimizer])


   .. py:method:: load_best_model(trainer: Trainer, network: torch.nn.Module, total_epochs: int)


   .. py:method:: __call__(trainer: Trainer, network: torch.nn.Module, optimizers: List[torch.optim.Optimizer], epoch: int, val_metrics: Dict[str, float]) -> bool


   .. py:method:: get_best_epoch(_print: bool = False) -> int