"""
File containing the gradient descent algorithm methods adapted for DNPU classes and
custom torch.nn.Module children custom classes that contain DNPU classes or DNPU based
modules from brainspy.processors.modules.
"""
import os
import torch
import numpy as np
from tqdm import trange
from brainspy.utils.pytorch import TorchUtils
from torch.utils.data import DataLoader
[docs]
def train(
model: torch.nn.Module,
dataloaders: list,
criterion,
optimizer: torch.optim.Optimizer,
configs: dict,
logger=None,
save_dir: str = None,
return_best_model: bool = True,
):
"""
Main training loop for off-chip gradient descent training with early stopping using PyTorch.
It is a default training loop used for simple training tasks, but its code can be taken as a
reference on how to implement a training loop for more specific or complext tasks.
Parameters
----------
model : torch.nn.Module
The model to be trained. It should be an instance of a torch.nn.Module. It can be a
Processor, representing a hardware DNPU or a DNPU model, but it also can be a model that
contains different more complex architectures using several processors.
Note that the model can be a custom model (child of torch.nn.Module) containing multiple
DNPU instances, but the model cannot be an instance of SurrogateModel or
HardwareProcessor. If the model is a custom model, it should have the following methods
implemented:
format_targets : The hardware processor uses a waveform to represent points
(see 5.1 in Introduction of the Wiki). Each point is represented with some
slope and some plateau points. When passing through the hardware, there will
be a difference between the output from the device and the input (in points).
This function is used for the targets to have the same length in shape as the
outputs. It simply repeats each point in the input as many times as there are
points in the plateau. In this way, targets can then be compared against hardware
outputs in the loss function. This function should have the following input
(x : torch.Tensor), that represents the rgets of the supervised learning problem,
which will be extended to have the same length shape as the outputs from the processor.
regularizer : When the constraint_control_voltages parameter is set to "regul", the
result from the custom method regularizer will be added to the loss function. It is used
to add a penalisation to the loss function when found control voltages are outside the
control electrode ranges. The developer should decide how this value will be computed.
Each DNPU class contains a regularizer method that returns how much the current control
voltages of the DNPU are outside from the control electrode ranges. In a custom model, the
custom regularizer function can be composed by calling the regularizer function of instantiated
DNPUs. The custom regularizer method of a custom model only needs to be implemented if
constraint_control_voltages = "regul" in the configs. An example can be
found at: brainspy.processors.dnpu, inside the class DNPU.
constraint_weights : When the constraint_control_voltages parameter is set to "clip", the
trainer will call this function to clip the current control voltages, if they are outside from
the control electrode ranges to which they correspond. Each DNPU class contains a clip method
(constraint_control_voltages) that clips current control voltage electrodes in this way.
This method only needs to be implemented in a custom model if
constraint_control_voltages = "clip" in the configs.
dataloaders : list
A list containing one or two Pytorch dataloaders. The first dataloader corresponds to the
training dataset. The second dataloader is optional, and it corresponds to the validation
dataset. If no validation dataset is given, the training loop will train the model and
return the trained model only after reaching to the latest epoch. If a second dataloader is
given, it will be used as a validation dataset. When a validation dataset is present, only
models with solutions that achieve the lowest validation score will be saved. It is
recommended to have an additional test dataset on the side, to check the model against,
after training it with an additional validation datasetz
More information about dataloaders can be found at:
https://pytorch.org/tutorials/beginner/basics/data_tutorial.html
criterion : Object <method>
Loss function criterion that will be used to optimise the model. More information on
several loss functions supported can be found at:
https://pytorch.org/docs/stable/nn.html#loss-functions
optimizer : torch.optim.Optimizer
Optimisation algorithm to be used during the training process. More on Pytorch's optimizer
package can be found at:
https://pytorch.org/docs/stable/optim.html
configs : dict
Dictionary containing the following extra configuration keys:
epochs : int
Number of passes through the entire training dataset.
constraint_control_voltages : str
When training models, typically it is desired for the control voltages to stay
within the ranges in which they where trained, in order to avoid extrapolating, or
reaching the clipping values. This str key can have the following values:
1. 'regul' : It applies a penalty to the loss function when control voltages go
outside the ranges in which they were trained. This method allows a
bit of flexibility, enabling to find solutions that are, in some
cases, slightly outside of the control voltage ranges. In order to be
used, it also requires that the model has a method called
'regularizer' which controls that penalty. An example can be found at:
brainspy.processors.dnpu, inside the class DNPU, method regularizer.
2. 'clip' : It applies clipping after the backward pass and optimiser step. It
enforces that the control voltage ranges will not be outside the
ranges in which the model was trained. In order to use it, the model
should have a method called 'constraint_weights'. An example can be
found at: brainspy.processors.dnpu, inside the class DNPU, method
constraint_weights.
logger: logging (optional)
It provides a way for applications to configure different log handlers.
By default None.
The logger should be an already initialised class that contains a method called
'log_output', where the input is a single numpy array variable. It can be any class,
and the data can be treated in the way the user wants.You can get more information about
loggers at https://pytorch.org/docs/stable/tensorboard.html
Logger directory info:
1. log_train_step: to log each step in the training process
2. log_val_step: to log each step in the validation process
save_dir : Optional[str]
Folder where the trained model is going to be saved.
When None, the model will not be saved.
By default None.
return_best_model : bool, optional
to return the trained model instead of saving
it to a directory, by default True
Returns
-------
model : torch.nn.Module
Trained model with best results according to the criterion fitness function.
training_data: dict
Dictionary returning relevant data produced while training the model.
configs['return_best_model']: boolean
It also adds to the configs dictionary whether the algorithm was returning the best model or not
at configs['return_best_model'].
Notes
-----
A) After the end of the last epoch, the algorithm saves two main files:
model_raw.pt: An exact copy of the model after the end of the training process. It can be loaded directly as
an instance of the model using: my_model_instance_at_best_val_results = torch.load('best_model_raw.pt').
training_data.pickle: A pytorch picle which contains the following keys:
epochs: int
Number of epochs used for training the model
algorithm:
Algorithm type that was being used. Either 'genetic' or 'gradient'.
optimizer_state_dict: OrderedDict
State of the optimizer at the end of last epoch. It can be used to resume model training at that
exact point.
model_state_dict: OrderedDict
It contains the value of the learnable parameters (weights, or in this case, control voltages) at
the point where all the training was finised.
train_losses: list
A list of the loss performance over all epochs
val_losses: list
A list of the loss performance over all epochs
B) If there is a validation dataset present, and return_best_model is set to true. The algorithm will
save, each time that the validation loss is better than the previous, the following files:
best_model_raw.pt: An exact copy of the model when it got the best validation results. It can be
loaded directly as an instance of the model using:
my_model_instance_at_best_val_results = torch.load('best_model_raw.pt').
best_training_data.pickle: A pytorch picle which contains the following keys:
epoch: int
Epoch at which the model with best validation loss was found.
algorithm: str
Algorithm type that was being used. Either 'genetic' or 'gradient'.
optimizer_state_dict: OrderedDict
State of the optimizer at the moment when the best validation loss was achieved. It can be used
to resume model training at that exact point.
model_state_dict: OrderedDict
It contains the value of the learnable parameters (weights, or in this case, control voltages) at
the point where the best validation was achieved.
train_loss: float
Training loss at the point where the best validation was achieved.
validation_loss: float
Best validation loss achieved.
"""
train_checks(model, dataloaders, criterion, optimizer, configs, save_dir, return_best_model)
start_epoch = 0
train_losses, val_losses = [], []
min_val_loss = np.inf
looper = trange(configs["epochs"], desc=" Initialising")
looper.update(start_epoch)
configs['return_best_model'] = return_best_model
model.to(device=TorchUtils.get_device())
for epoch in looper:
model, running_loss = default_train_step(
model,
epoch,
dataloaders[0],
criterion,
optimizer,
logger=logger,
constraint_control_voltages=configs['constraint_control_voltages'])
train_losses.append(running_loss)
description = "Training Loss: {:.6f}.. ".format(train_losses[-1])
if len(dataloaders) > 1 and dataloaders[1] is not None and len(
dataloaders[1]) > 0:
val_loss = default_val_step(
epoch,
model,
dataloaders[1],
criterion,
logger=logger,
)
val_losses.append(val_loss)
description += "Validation Loss: {:.6f}.. ".format(val_losses[-1])
# Save only when peak val performance is reached
if save_dir is not None and (val_losses[-1] < min_val_loss
or epoch == 0):
min_val_loss = val_losses[-1]
description += " Saving model ..."
torch.save(model, os.path.join(save_dir, "best_model_raw.pt"))
torch.save(
{
"epochs": epoch,
"algorithm": 'gradient',
"optimizer_state_dict": optimizer.state_dict(),
"model_state_dict": model.state_dict(),
"train_loss": train_losses[-1],
"val_loss": val_losses[-1],
},
os.path.join(save_dir, "best_training_data.pickle"),
)
looper.set_description(description)
if logger is not None and "log_performance" in dir(logger):
logger.log_performance(train_losses, val_losses, epoch)
# TODO: Add a save instruction and a stopping criteria
# if stopping_criteria(train_losses, val_losses):
# break
if save_dir is not None:
torch.save(
model,
os.path.join(
save_dir, # type: ignore[arg-type]
"model_raw.pt"))
torch.save(
{
"epoch": epoch + 1,
"algorithm": 'gradient',
"optimizer_state_dict": optimizer.state_dict(),
"model_state_dict": model.state_dict(),
"train_losses": train_losses,
"val_losses": val_losses,
"min_val_loss": min_val_loss,
},
os.path.join(
save_dir, # type: ignore[arg-type]
"training_data.pickle"),
)
if logger is not None:
logger.close()
if (save_dir is not None and return_best_model
and (len(dataloaders) == 1 or
(dataloaders[1] is not None and len(dataloaders[1]))) > 0):
if os.path.exists(os.path.join(save_dir, "best_model_raw.pt")):
model = torch.load(os.path.join(save_dir, "best_model_raw.pt"))
return model, {
"performance_history":
[torch.tensor(train_losses),
torch.tensor(val_losses)]
}
[docs]
def train_checks(model, dataloaders, criterion, optimizer, configs, save_dir, return_best_model):
"""
Main training loop for off-chip gradient descent training with early stopping using PyTorch.
It is a default training loop used for simple training tasks, but its code can be taken as a
reference on how to implement a training loop for more specific or complext tasks.
Parameters
----------
model : torch.nn.Module
The model to be trained. It should be an instance of a torch.nn.Module. It can be a
Processor, representing a hardware DNPU or a DNPU model, but it also can be a model that
contains different more complex architectures using several processors.
Note that the model can be a custom model (child of torch.nn.Module) containing multiple
DNPU instances, but the model cannot be an instance of SurrogateModel or
HardwareProcessor. If the model is a custom model, it should have the following methods
implemented:
format_targets : The hardware processor uses a waveform to represent points
(see 5.1 in Introduction of the Wiki). Each point is represented with some
slope and some plateau points. When passing through the hardware, there will
be a difference between the output from the device and the input (in points).
This function is used for the targets to have the same length in shape as the
outputs. It simply repeats each point in the input as many times as there are
points in the plateau. In this way, targets can then be compared against hardware
outputs in the loss function. This function should have the following input
(x : torch.Tensor), that represents the rgets of the supervised learning problem,
which will be extended to have the same length shape as the outputs from the processor.
regularizer : When the constraint_control_voltages parameter is set to "regul", the
result from the custom method regularizer will be added to the loss function. It is used
to add a penalisation to the loss function when found control voltages are outside the
control electrode ranges. The developer should decide how this value will be computed.
Each DNPU class contains a regularizer method that returns how much the current control
voltages of the DNPU are outside from the control electrode ranges. In a custom model, the
custom regularizer function can be composed by calling the regularizer function of instantiated
DNPUs. The custom regularizer method of a custom model only needs to be implemented if
constraint_control_voltages = "regul" in the configs. An example can be
found at: brainspy.processors.dnpu, inside the class DNPU.
constraint_weights : When the constraint_control_voltages parameter is set to "clip", the
trainer will call this function to clip the current control voltages, if they are outside from
the control electrode ranges to which they correspond. Each DNPU class contains a clip method
(constraint_control_voltages) that clips current control voltage electrodes in this way.
This method only needs to be implemented in a custom model if
constraint_control_voltages = "clip" in the configs.
dataloaders : list
A list containing one or two Pytorch dataloaders. The first dataloader corresponds to the
training dataset. The second dataloader is optional, and it corresponds to the validation
dataset. If no validation dataset is given, the training loop will train the model and
return the trained model only after reaching to the latest epoch. If a second dataloader is
given, it will be used as a validation dataset. When a validation dataset is present, only
models with solutions that achieve the lowest validation score will be saved. It is
recommended to have an additional test dataset on the side, to check the model against,
after training it with an additional validation datasetz
More information about dataloaders can be found at:
https://pytorch.org/tutorials/beginner/basics/data_tutorial.html
criterion : Object <method>
Loss function criterion that will be used to optimise the model. More information on
several loss functions supported can be found at:
https://pytorch.org/docs/stable/nn.html#loss-functions
optimizer : torch.optim.Optimizer
Optimisation algorithm to be used during the training process. More on Pytorch's optimizer
package can be found at:
https://pytorch.org/docs/stable/optim.html
configs : dict
Dictionary containing the following extra configuration keys:
epochs : int
Number of passes through the entire training dataset.
constraint_control_voltages : str
When training models, typically it is desired for the control voltages to stay
within the ranges in which they where trained, in order to avoid extrapolating, or
reaching the clipping values. This str key can have the following values:
1. 'regul' : It applies a penalty to the loss function when control voltages go
outside the ranges in which they were trained. This method allows a
bit of flexibility, enabling to find solutions that are, in some
cases, slightly outside of the control voltage ranges. In order to be
used, it also requires that the model has a method called
'regularizer' which controls that penalty. An example can be found at:
brainspy.processors.dnpu, inside the class DNPU, method regularizer.
2. 'clip' : It applies clipping after the backward pass and optimiser step. It
enforces that the control voltage ranges will not be outside the
ranges in which the model was trained. In order to use it, the model
should have a method called 'constraint_weights'. An example can be
found at: brainspy.processors.dnpu, inside the class DNPU, method
constraint_weights.
logger: logging (optional)
It provides a way for applications to configure different log handlers.
By default None.
The logger should be an already initialised class that contains a method called
'log_output', where the input is a single numpy array variable. It can be any class,
and the data can be treated in the way the user wants.You can get more information about
loggers at https://pytorch.org/docs/stable/tensorboard.html
Logger directory info:
1. log_train_step: to log each step in the training process
2. log_val_step: to log each step in the validation process
save_dir : Optional[str]
Folder where the trained model is going to be saved.
When None, the model will not be saved.
By default None.
return_best_model : bool, optional
to return the trained model instead of saving
it to a directory, by default True
Returns
-------
model : torch.nn.Module
Trained model with best results according to the criterion fitness function.
training_data: dict
Dictionary returning relevant data produced while training the model.
configs['return_best_model']: boolean
It also adds to the configs dictionary whether the algorithm was returning the best model or not
at configs['return_best_model'].
Notes
-----
A) After the end of the last epoch, the algorithm saves two main files:
model_raw.pt: An exact copy of the model after the end of the training process. It can be loaded directly as
an instance of the model using: my_model_instance_at_best_val_results = torch.load('best_model_raw.pt').
training_data.pickle: A pytorch picle which contains the following keys:
epochs: int
Number of epochs used for training the model
algorithm:
Algorithm type that was being used. Either 'genetic' or 'gradient'.
optimizer_state_dict: OrderedDict
State of the optimizer at the end of last epoch. It can be used to resume model training at that
exact point.
model_state_dict: OrderedDict
It contains the value of the learnable parameters (weights, or in this case, control voltages) at
the point where all the training was finised.
train_losses: list
A list of the loss performance over all epochs
val_losses: list
A list of the loss performance over all epochs
B) If there is a validation dataset present, and return_best_model is set to true. The algorithm will
save, each time that the validation loss is better than the previous, the following files:
best_model_raw.pt: An exact copy of the model when it got the best validation results. It can be
loaded directly as an instance of the model using:
my_model_instance_at_best_val_results = torch.load('best_model_raw.pt').
best_training_data.pickle: A pytorch picle which contains the following keys:
epoch: int
Epoch at which the model with best validation loss was found.
algorithm: str
Algorithm type that was being used. Either 'genetic' or 'gradient'.
optimizer_state_dict: OrderedDict
State of the optimizer at the moment when the best validation loss was achieved. It can be used
to resume model training at that exact point.
model_state_dict: OrderedDict
It contains the value of the learnable parameters (weights, or in this case, control voltages) at
the point where the best validation was achieved.
train_loss: float
Training loss at the point where the best validation was achieved.
validation_loss: float
Best validation loss achieved.
"""
assert isinstance(
model,
torch.nn.Module), "The model should be an instance of torch.nn.Module"
assert "format_targets" in dir(
model), "The format_targets function should be implemeted in the model"
assert type(
dataloaders) == list, "The dataloaders should be of type - list"
for dataloader in dataloaders:
assert isinstance(
dataloader, DataLoader
), "The dataloader should be an instance of torch.utils.data.DataLoader"
assert callable(criterion), "The criterion should be a callable method"
assert isinstance(
optimizer, torch.optim.Optimizer
), "The optimizer object should be an instance of torch.optim.Optimizer"
assert type(configs) == dict, "The extra configs should be of type - dict"
if configs["epochs"]:
assert type(
configs["epochs"]) == int, "The epochs key should be of type - int"
assert type(
configs["constraint_control_voltages"]
) == str, "The constraint_control_voltages key should be of type str"
assert configs["constraint_control_voltages"] == "clip" or configs[
"constraint_control_voltages"] == "regul", "The constraint_control_voltages should be either clip or regul"
if configs["constraint_control_voltages"] == "regul":
assert "regularizer" in dir(
model
), "The model should implement the regularizer function for this option"
else:
assert "constraint_weights" in dir(
model
), "The model should implement the constraint_weights function for this option"
assert save_dir is None or type(
save_dir
) == str, "The name/path of the save_dir should be of type - str"
assert type(return_best_model) == bool, "Return best model should be boolean."
[docs]
def default_train_step(model,
epoch,
dataloader,
criterion,
optimizer,
logger=None,
constraint_control_voltages=None):
"""
Deafult training step for training a torch model in Gradiet descent. The method calulates the
training loss in each training step. The training loss indicates how well the model is fitting
the training data.
More information about training loss can be found at
https://www.baeldung.com/cs/learning-curve-ml
The method returns the trained model and the running loss, which is used to calculate the
training loss, in that step.
Parameters
----------
model : torch.nn.Module
The model to be trained. It should be an instance of a torch.nn.Module. It can be a
Processor, representing a hardware DNPU or a DNPU model, but it also can be a model that
contains different more complex architectures using several processors. Refer to the
documentation of the train function above for more inforamtion about defining a model.
epoch : int
Number of passes through the entire training dataset.
dataloader : torch.utils.data.Dataloader
A Pytorch dataloaders that corresponds to the training dataset.
More information about dataloaders can be found at:
https://pytorch.org/tutorials/beginner/basics/data_tutorial.html
criterion : Object <method>
Loss function criterion that will be used to optimise the model. More information on
several loss functions supported can be found at:
https://pytorch.org/docs/stable/nn.html#loss-functions
optimizer : torch.optim.Optimizer
Optimisation algorithm to be used during the training process. More on Pytorch's optimizer
package can be found at:
https://pytorch.org/docs/stable/optim.html
logger: logging (optional)
It provides a way for applications to configure different log handlers.
by default None.
The logger should be an already initialised class that contains a method called
'log_output', where the input is a single numpy array variable. It can be any class,
and the data can be treated in the way the user wants.You can get more information about
loggers at https://pytorch.org/docs/stable/tensorboard.html
Logger directory info :
log_train_step: to log each step in the training process
constraint_control_voltages : str
When training models, typically it is desired for the control voltages to stay
within the ranges in which they where trained, in order to avoid extrapolating, or
reaching the clipping values. This str key can have the following values:
1. 'regul' : It applies a penalty to the loss function when control voltages go
outside the ranges in which they were trained. This method allows a
bit of flexibility, enabling to find solutions that are, in some
cases, slightly outside of the control voltage ranges. In order to be
used, it also requires that the model has a method called
'regularizer' which controls that penalty. An example can be found at:
brainspy.processors.dnpu, inside the class DNPU, method regularizer.
2. 'clip' : It applies clipping after the backward pass and optimiser step. It
enforces that the control voltage ranges will not be outside the
ranges in which the model was trained. In order to use it, the model
should have a method called 'constraint_weights'. An example can be
found at: brainspy.processors.dnpu, inside the class DNPU, method
constraint_weights.
Returns
-------
model : torch.nn.Module
Trained model with best results according to the criterion fitness function.
running loss : int
To assess the training loss: how far the predictions of the model are from the actual
targets.
"""
assert isinstance(
model,
torch.nn.Module), "The model should be an instance of torch.nn.Module"
assert "format_targets" in dir(
model), "The format_targets function should be implemeted in the model"
assert type(epoch) == int, "The epoch param should be of type - int"
assert isinstance(
dataloader, DataLoader
), "The dataloader should be an instance of torch.utils.data.DataLoader"
assert callable(criterion), "The criterion should be a callable method"
assert isinstance(
optimizer, torch.optim.Optimizer
), "The optimizer should be an instance of torch.optim.Optimizer"
assert (
constraint_control_voltages is None
or constraint_control_voltages == "clip"
or constraint_control_voltages == "regul"
), "The constraint_control_voltages should be None or 'clip' or 'regul'"
if constraint_control_voltages == "regul":
assert "regularizer" in dir(
model
), "The model should implement the regularizer function for this option"
else:
assert "constraint_weights" in dir(
model
), "The model should implement the constraint_weights function for this option"
assert (
constraint_control_voltages is None
or constraint_control_voltages == 'clip'
or constraint_control_voltages == 'regul'
), "Variable constraint_control_voltages should be 'regul', 'clip' or None."
running_loss = 0
model.train()
for inputs, targets in dataloader:
inputs, targets = TorchUtils.format(inputs), model.format_targets(
TorchUtils.format(targets))
optimizer.zero_grad()
predictions = model(inputs)
if constraint_control_voltages is None or constraint_control_voltages == 'clip':
loss = criterion(predictions, targets)
elif constraint_control_voltages == 'regul':
loss = criterion(predictions, targets) + model.regularizer()
loss.backward()
optimizer.step()
if constraint_control_voltages is not None and constraint_control_voltages == 'clip':
# with torch.no_grad():
model.constraint_weights()
running_loss += loss.item() * inputs.shape[0]
if logger is not None and "log_train_step" in dir(logger):
logger.log_train_step(epoch, inputs, targets, predictions, model,
loss, running_loss)
running_loss /= len(dataloader.dataset)
return model, running_loss
[docs]
def default_val_step(epoch, model, dataloader, criterion, logger=None):
"""
To calulate the validation loss in each training step of the Gradient descent.
Validation loss indicates how well the model fits unseen data.
More information about validation loss and training loss can be
found at https://www.baeldung.com/cs/learning-curve-ml
Parameters
----------
epoch : int
Number of passes through the entire training dataset.
model : torch.nn.Module
The model to be trained. It should be an instance of a torch.nn.Module. It can be a
Processor, representing a hardware DNPU or a DNPU model, but it also can be a model that
contains different more complex architectures using several processors.Refer to the
documentation of the train function above for more inforamtion about defining a model.
dataloader : torch.utils.data.Dataloader
A Pytorch dataloaders that corresponds to the validation dataset.
More information about dataloaders can be found at:
https://pytorch.org/tutorials/beginner/basics/data_tutorial.html
criterion : Object <method>
Loss function criterion that will be used to optimise the model. More information on
several loss functions supported can be found at:
https://pytorch.org/docs/stable/nn.html#loss-functions
logger: logging (optional)
It provides a way for applications to configure different log handlers.
by default None.
The logger should be an already initialised class that contains a method called
'log_output', where the input is a single numpy array variable. It can be any class,
and the data can be treated in the way the user wants.You can get more information about
loggers at https://pytorch.org/docs/stable/tensorboard.html
Logger directory info :
log_val_step: to log each step in the validation process
Returns
-------
val_loss : int
To assess how well the model fits new data.
It is the sum of errors made for each example in training or validation sets.
"""
assert isinstance(
model,
torch.nn.Module), "The model should be an instance of torch.nn.Module"
assert "format_targets" in dir(
model), "The format_targets function should be implemeted in the model"
assert type(epoch) == int, "The epoch param should be of type - int"
assert isinstance(
dataloader, DataLoader
), "The dataloader should be an instance of torch.utils.data.DataLoader"
assert callable(criterion), "The criterion should be a callable method"
with torch.no_grad():
val_loss = 0
model.eval()
for inputs, targets in dataloader:
inputs, targets = TorchUtils.format(inputs), model.format_targets(
TorchUtils.format(targets))
predictions = model(inputs)
loss = criterion(predictions, targets).item()
val_loss += loss * inputs.shape[0]
if logger is not None and "log_val_step" in dir(logger):
logger.log_val_step(epoch, inputs, targets, predictions, model,
loss, val_loss)
val_loss /= len(dataloader.dataset)
return val_loss