"""
File for applying a batch normalisation layer after a DNPU class.
"""
import torch.nn as nn
from brainspy.processors.dnpu import DNPU
from brainspy.processors.processor import Processor
[docs]
class DNPUBatchNorm(DNPU):
"""
A child of brainspy.processors.dnpu.DNPU class that adds a batch normalisation layer after the
output. for adding a batch normalisation layer after the output of a DNPU.
More information about batch normalisation can be found in:
https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html
"""
def __init__(self,
processor: Processor,
data_input_indices: list,
forward_pass_type: str = 'vec',
affine=False,
track_running_stats=True,
momentum=0.1,
eps=1e-5,
custom_bn=nn.BatchNorm1d):
"""
Initialises the super class and the batch normalisation module, according to the batch norm
parameters given (affine, track_running_stats, momentum, eps, custom_bn).
More information about batch normalisation can be found in:
https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html
Parameters
----------
processor : brainspy.processors.processor.Processor
An instance of a Processor, which can hold a DNPU model or a driver connection to the
DNPU hardware.
data_input_indices: list
Specifies which electrodes are going to be used for inputing data. The reminder of the
activation electrodes will be automatically selected as control electrodes. The list
should have the following shape (dnpu_node_no,data_input_electrode_no). The minimum
dnpu_node_no should be 1, e.g., data_input_indices = [[1,2]]. When specifying more than
one dnpu node in the list, the module will simulate, in time-multiplexing,
as if there was a layer of DNPU devices. Fore example, for an 8 electrode DNPU device
with a single readout electrode and 7 activation electrodes, when
data_input_indices = [[1,2],[1,3],[3,4]], it will be considered that there are 3 DNPU
devices, where the first DNPU device will use the data input electrodes
1 and 2, the second DNPU device will use data input electrodes 1 and 3 and the third
DNPU device will use data input electrodes 3 and 4. Also, the first DNPU device will
have electrodes 0, 3, 4, 5, and 6 defined as control electrodes. The second DNPU device
will have electrodes 0,2,4,5, and 6 defined as control electrodes. The third DNPU device
will have electrodes 0,1,2,5, and 6 defined as control electrodes. More information
about what activation, readout, data input and control electrodes are can be found at
the wiki: https://github.com/BraiNEdarwin/brains-py/wiki/A.-Introduction
forward_pass_type : str
It indicates if the forward pass for more than one DNPU devices on time-multiplexing
will be executed using vectorisation or a for loop. The available options are 'vec' or
'for'. By default it uses the vectorised version.
affine : A boolean value that when set to True, this module has learnable affine parameters.
By default is set to False, in order to save using extra parameters.
track_running_stats : bool
A boolean value that when set to True, this module tracks the running mean and variance,
and when set to False, this module does not track such statistics, and initializes
statistics buffers running_mean and running_var as None. When these buffers are None,
this module always uses batch statistics in both training and eval modes. Default: True
momentum : float
The value used for the running_mean and running_var computation. Can be set to None for
cumulative moving average (i.e. simple average). Default: 0.1
eps : float
A value added to the denominator for numerical stability. Default: 1e-5
custom_bn : torch.nn.Module
A batch normalisation module that is an instance of a torch.nn.Module. By default
torch.nn.BatchNorm1d
"""
super(DNPUBatchNorm,
self).__init__(processor,
data_input_indices,
forward_pass_type=forward_pass_type)
self.bn = custom_bn(self.get_node_no(),
affine=affine,
track_running_stats=track_running_stats,
momentum=momentum,
eps=eps)
[docs]
def forward(self, x):
"""
Run a forward pass through the processor, including any time-multiplexing modules that
are declared to be measured in the same layer. After getting the output from the processor
the output is passed through the batch normalisation layer.
Parameters
----------
x : torch.Tensor
Input data.
Returns
-------
torch.Tensor
Output data.
"""
self.dnpu_output = self.forward_pass(x)
self.batch_norm_output = self.bn(self.dnpu_output)
return self.batch_norm_output
[docs]
def get_logged_variables(self):
"""
Get the otuput results from each layer from the last forward pass.
Returns
-------
dict
Dictionary containing the output from the last forward pass as a dictionary.
1. c_dnpu_output: Output of the dnpu / dnpu layer
2. d_batch_norm_output: Output of the batch norm layer
"""
return {
"c_dnpu_output": self.dnpu_output.clone().detach(),
"d_batch_norm_output": self.batch_norm_output.clone().detach(),
}