"""
It is a very similar class to that of Processor, but it contains the con-
trol voltages, declared as learnable parameters. Therefore, it only has the same
input dimensions as the number of available data input electrodes. It is also a
child of torch.nn.Module, and it allows for representing a layer of DNPUs in a time-
multiplexing fashion (with the same Processor instance). It also enables applying
linear transformations to the inputs before passing them to the processor.
"""
from random import random
import torch
import collections
import numpy as np
from torch import nn, Tensor
from typing import Sequence, Union
from brainspy.processors.processor import Processor
from brainspy.utils.performance import data
from brainspy.utils.transforms import get_linear_transform_constants
import warnings
[docs]
class DNPU(nn.Module):
"""
This class enables to declare the control voltages as learnable parameters in order to train a
surrogate model to find optimal control voltage values for a particular problem. It is a child
of an nn.Module class that contains a processor. Therefore, the solutions found for control
voltages using a surrogate model, can be seamlessly applied to hardware DNPUs. The class also
enables to declare more than one internal DNPU node, that will act as a single layer of the
specified node number, in time-multiplexing of the same model.
"""
def __init__(self,
             processor: Processor,
             data_input_indices: list,
             forward_pass_type: str = 'vec'):
    """
    Build the module around an existing processor.

    The constructor stores the processor, splits its activation electrodes
    into data input and control electrodes, creates the learnable control
    voltages and selects the forward pass implementation.

    Parameters
    ----------
    processor : Processor
        Processor instance (simulation or hardware) that evaluates the
        electrode voltages.
    data_input_indices : list
        Nested list with shape (dnpu_node_no, data_input_electrode_no)
        naming, per DNPU node, the electrodes that receive data. At least
        one node is required, e.g. [[1, 2]]. Every remaining activation
        electrode of a node becomes a control electrode. With several
        nodes, e.g. [[1, 2], [1, 3], [3, 4]] on a device with 7 activation
        electrodes, the module time-multiplexes the same processor as a
        layer of 3 nodes whose control electrodes are [0, 3, 4, 5, 6],
        [0, 2, 4, 5, 6] and [0, 1, 2, 5, 6] respectively. See
        https://github.com/BraiNEdarwin/brains-py/wiki/A.-Introduction
        for the electrode terminology.
    forward_pass_type : str
        'vec' (default) for the vectorised forward pass or 'for' for the
        loop-based one.

    Raises
    ------
    AssertionError
        If processor is not a Processor instance.
    """
    super(DNPU, self).__init__()
    assert isinstance(
        processor, Processor
    ), "Processor should be an instance of brainspy.processors.processor.Processor"
    self.processor = processor
    # Input transformation and clipping are disabled on a fresh module.
    self.input_transform = False
    self.input_clip = False
    # Electrode bookkeeping must exist before control voltages are sampled.
    self.init_electrode_info(data_input_indices)
    self._init_learnable_parameters()
    self.set_forward_pass(forward_pass_type)
[docs]
def set_forward_pass(self, forward_pass_type: str):
    """
    Select the forward pass (and matching input clipping) implementation.

    Parameters
    ----------
    forward_pass_type : str
        'vec' binds forward_vec / clip_input_vec, 'for' binds
        forward_for / clip_input_for. The chosen callables are stored in
        self.forward_pass and self.clip_input.

    Raises
    ------
    AssertionError
        If forward_pass_type is not exactly the string 'vec' or 'for'.
    """
    is_known_type = (type(forward_pass_type) is str
                     and forward_pass_type in ('vec', 'for'))
    assert is_known_type, "Forward pass type not recognised. It should be either 'for' or 'vec'"
    self.forward_pass_type = forward_pass_type
    use_vectorised = forward_pass_type == 'vec'
    # Only the attributes of the selected variant are looked up.
    # NOTE(review): clip_input_vec / clip_input_for are not defined in the
    # visible part of this class - confirm they exist elsewhere.
    self.forward_pass = self.forward_vec if use_vectorised else self.forward_for
    self.clip_input = self.clip_input_vec if use_vectorised else self.clip_input_for
[docs]
def init_activation_electrode_no(self):
    """
    Count the data input and control electrodes of a single DNPU node.

    The data input count is the length of the first row of
    self.data_input_indices; the control count is whatever remains of the
    processor's activation electrodes.

    Returns
    -------
    input_data_electrode_no : int
        Number of data input electrodes.
    control_electrode_no : int
        Number of control electrodes.
    """
    first_node_indices = self.data_input_indices[0]
    data_count = len(first_node_indices)
    activation_count = self.processor.get_activation_electrode_no()
    return data_count, activation_count - data_count
[docs]
def get_node_no(self):
    """
    Return the number of DNPU nodes handled by this module.

    Returns
    -------
    The value stored in self.node_no, i.e. the number of rows of the data
    input indices set by init_electrode_info.
    """
    node_count = self.node_no
    return node_count
[docs]
def init_electrode_info(self, data_input_indices: Sequence[int]):
    """
    Set the data input and control electrode indices and voltage ranges.

    Registers the buffers data_input_indices, data_input_ranges,
    control_indices and control_ranges and sets node_no,
    data_input_electrode_no and control_electrode_no. Both range buffers
    have shape (node_no, electrode_no, 2), where the last dimension holds
    the minimum and maximum of the range.
    Method is called by the constructor.

    Example
    -------
    >>> dnpu.processor.get_activation_electrode_no()
    7
    >>> dnpu.init_electrode_info([[0, 2]])
    >>> dnpu.data_input_indices
    tensor([[0, 2]])
    >>> dnpu.control_indices
    tensor([[1, 3, 4, 5, 6]])

    Here we have a DNPU with 7 activation electrodes, where the two with
    indices 0 and 2 are set as input electrodes, and the rest become
    control electrodes.

    Parameters
    ----------
    data_input_indices : list
        Nested list with shape (dnpu_node_no, data_input_electrode_no)
        holding, per node, the indices of the data input electrodes.

    Raises
    ------
    AssertionError
        If the argument is not a list, is not two dimensional, has more
        columns than activation electrodes, contains an index outside
        [0, activation_electrode_no) or repeats an index within a node.
    """
    assert type(
        data_input_indices
    ) is list, "Data input indices should be provided as a list."
    activation_electrode_no = self.processor.get_activation_electrode_no()
    format_error = (
        "Please revise the format in which data input indices has been passed to the DNPU. "
        +
        "Data input indices should be represented with two dimensions (number of DNPU nodes, "
        +
        "number of data input electrodes) (e.g., [[1,2]] or [[1,2],[1,3]], data input indices"
        + "CANNOT be represented as just [1,2]. )")
    indices = torch.tensor(data_input_indices)
    assert indices.dim() == 2 and indices.shape[0] >= 1, format_error
    assert indices.shape[1] <= activation_electrode_no, format_error
    # Negative indices would alias an electrode that may also be listed by
    # its positive index, leaving node and control layouts inconsistent.
    assert ((indices >= 0) &
            (indices < activation_electrode_no)).all().item(), format_error
    # Every electrode may appear at most once per node.
    assert all(row.unique().numel() == row.numel()
               for row in indices), format_error

    self.register_buffer("data_input_indices", indices.to(torch.long))
    self.node_no = len(self.data_input_indices)
    (self.data_input_electrode_no,
     self.control_electrode_no) = self.init_activation_electrode_no()

    voltage_ranges = self.processor.processor.get_voltage_ranges()
    # Data input ranges: (node_no, data_input_electrode_no, 2).
    self.register_buffer(
        "data_input_ranges",
        torch.stack([voltage_ranges[i] for i in data_input_indices]))

    # Control electrodes of a node are all activation electrodes that are
    # not used as data inputs by that node.
    activation_electrode_indices = np.arange(activation_electrode_no)
    control_list = [
        np.delete(activation_electrode_indices, i)
        for i in data_input_indices
    ]
    # Control ranges: (node_no, control_electrode_no, 2).
    self.register_buffer(
        "control_ranges",
        torch.stack([voltage_ranges[i] for i in control_list]))
    # Indices are stored as long tensors so they can index other tensors.
    self.register_buffer(
        "control_indices",
        torch.tensor(np.array(control_list), dtype=torch.long))
[docs]
def get_control_electrode_no(self):
    """
    Return the number of control electrodes used by each DNPU node.

    Returns
    -------
    control_electrode_no : int
        Value stored in self.control_electrode_no.
    """
    control_count = self.control_electrode_no
    return control_count
[docs]
def constraint_control_voltages(self):
    """
    Clip the control voltages to their minimum and maximum allowed values.

    Can be used after loss.backward() and optimizer.step() to bring back
    into range those control voltages that left it. The parameter is
    updated in place through its .data attribute.

    Example
    -------
    [...]
    loss.backward()
    optimizer.step()
    model.constraint_control_voltages()
    [...]
    """
    # The last dimension of the ranges is (min, max). Indexing it directly
    # works on every torch version, so no version check (previously a
    # lexicographic string comparison) and no .T on a 3-D tensor is needed.
    control_ranges = self.get_control_ranges()
    min_ranges = control_ranges[..., 0]
    max_ranges = control_ranges[..., 1]
    with torch.no_grad():
        self.control_voltages.data = torch.max(
            torch.min(self.control_voltages, max_ranges), min_ranges)
# Returns a random single value of a control voltage within a specified range.
# Control voltage range = [min,max]
[docs]
def sample_controls(self):
    """
    Draw random control voltages inside the control voltage ranges.

    One uniform sample is drawn per control electrode, scaled to the
    [min, max] range stored for that electrode in self.control_ranges.

    Example
    -------
    >>> dnpu.control_ranges.shape
    torch.Size([1, 5, 2])
    >>> dnpu.sample_controls().shape
    torch.Size([1, 5])

    Returns
    -------
    torch.Tensor
        Tensor with the shape of self.control_indices, i.e.
        (node_no, control_electrode_no), on the device of the ranges.
    """
    # The last dimension of the ranges is (min, max). Indexing it directly
    # is version independent; the former string comparison of
    # torch.__version__ mis-ordered releases such as 1.9 and 1.10.
    control_ranges = self.control_ranges
    range_base = control_ranges[..., 0]
    range_size = control_ranges[..., 1] - range_base
    random_voltages = torch.rand(self.control_indices.shape,
                                 device=control_ranges.device)
    return (range_size * random_voltages) + range_base
def _init_learnable_parameters(self):
    """
    Freeze every parameter currently registered in the module and then
    create the control voltages as the learnable parameters.
    """
    # Freezing happens first, so the control voltages created afterwards
    # are not affected by it.
    for frozen in self.parameters():
        frozen.requires_grad_(False)
    self._init_control_voltages()
def _init_control_voltages(self):
    """
    Create the control voltages as a learnable pytorch parameter.

    The initial values come from sample_controls, which draws them at
    random inside the control voltage range of each electrode.
    Method is called by the constructor (indirectly) and by reset.
    """
    initial_voltages = self.sample_controls()
    self.control_voltages = nn.Parameter(initial_voltages)
# node_index: For time multiplexing only.
# Is the index of the processor to which the data will be sent.
[docs]
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Run a forward pass by delegating to the implementation selected with
    set_forward_pass (stored in self.forward_pass).

    Parameters
    ----------
    x : torch.Tensor
        Input data.

    Returns
    -------
    torch.Tensor
        Output data.
    """
    selected_forward = self.forward_pass
    return selected_forward(x)
[docs]
def forward_single(self, x: torch.Tensor, control_voltages: torch.Tensor,
                   data_input_indices: torch.Tensor,
                   control_indices: torch.Tensor):
    """
    Run a forward pass through the processor for a single DNPU node.

    Parameters
    ----------
    x : torch.Tensor
        Input data of one node; its first dimension is the batch size.
    control_voltages : torch.Tensor
        Control voltages of the node, broadcast over the batch before
        being merged with the data.
    data_input_indices : torch.Tensor [int64]
        Indices of the data input electrodes.
    control_indices : torch.Tensor [int64]
        Indices of the control electrodes.

    Returns
    -------
    torch.Tensor
        Whatever the processor returns for the merged electrode data.
    """
    batch_size = x.shape[0]
    # Every sample in the batch receives the same control voltages.
    batch_controls = control_voltages.expand(batch_size, -1)
    electrode_data = merge_electrode_data(x, batch_controls,
                                          data_input_indices,
                                          control_indices)
    return self.processor(electrode_data)
[docs]
def forward_for(self, x):
    """
    Run a forward pass through all of the DNPU nodes, one node at a time.

    Parameters
    ----------
    x : torch.Tensor
        Input data whose last dimension must hold one value per data
        input electrode of every node (data_input_indices.numel()).

    Returns
    -------
    torch.Tensor
        Outputs of the nodes concatenated along dimension 1.

    Raises
    ------
    AssertionError
        If the last dimension of the (clipped and transformed) input does
        not match the total number of data input electrodes.
    """
    # Cut input values to force linear transform or force being between a voltage range.
    x = self.clip_input(x)
    # Apply a linear transformation from raw data to the voltage ranges of the dnpu.
    if self.input_transform:
        x = (self.scale.flatten() * x) + self.offset.flatten()
    expected_width = self.data_input_indices.numel()
    assert (x.shape[-1] == expected_width), (
        f"size mismatch: data is {x.shape}," +
        f"DNPU_Layer expecting {self.data_input_indices.numel()}")
    node_outputs = []
    for node_index, node_x in enumerate(self.get_node_input_data(x)):
        node_outputs.append(
            self.forward_single(
                node_x,
                self.control_voltages[node_index],
                self.data_input_indices[node_index],
                self.control_indices[node_index],
            ))
    return torch.cat(node_outputs, dim=1)
[docs]
def forward_vec(self, x):
    """
    Run a forward pass through all of the DNPU nodes at once by arranging
    the input as (batch_size, dnpu_node_no, activation_electrode_no).

    Parameters
    ----------
    x : torch.Tensor
        Input data; it is reshaped to
        (batch_size, *data_input_indices.shape), so it must hold one value
        per data input electrode of every node.

    Returns
    -------
    torch.Tensor
        For simulation processors, the processor output with its last
        dimension squeezed. For hardware processors, the processor output,
        reshaped to (plateau_length * batch_size, dnpu_node_no) when
        plateaus are not averaged.
    """
    batch_size = x.size(0)
    data_input_shape = [batch_size, *self.data_input_indices.shape]
    control_voltages_shape = [batch_size, *self.control_voltages.shape]

    # Reshape input and expand controls
    x = x.reshape(data_input_shape)
    last_dim = x.dim() - 1
    controls = self.control_voltages.expand(control_voltages_shape)

    # Cut input values to force linear transform or force being between a voltage range.
    x = self.clip_input(x)

    # Apply a linear transformation from raw data to the voltage ranges of the dnpu.
    if self.input_transform:
        x = (self.scale.to(x.device) * x) + self.offset.to(x.device)

    # Expand indices according to batch size
    input_indices = self.data_input_indices.expand(data_input_shape)
    control_indices = self.control_indices.expand(control_voltages_shape)

    # Sorting the concatenated electrode indices yields the permutation that
    # puts the concatenated voltages back into electrode order.
    electrode_indices = torch.cat((input_indices, control_indices),
                                  dim=last_dim)
    electrode_order = torch.argsort(electrode_indices, dim=last_dim)
    data = torch.cat((x, controls.to(x.device)), dim=last_dim)
    data = torch.gather(data, last_dim, electrode_order.to(x.device))

    # pass data through the processor
    if not self.processor.is_hardware():
        return self.processor(data).squeeze(-1)

    # Hardware processes a flat batch: fold the node dimension into it.
    batch_dim, node_dim = data.shape[0], data.shape[1]
    result = self.processor(data.reshape(batch_dim * node_dim, -1))
    if not self.processor.average_plateaus:
        plateau_length = self.processor.waveform_mgr.plateau_length
        result = result.reshape(int(plateau_length * batch_dim), node_dim)
    return result
# def refresh_after_processor_swap():
# pass
[docs]
def regularizer(self) -> torch.Tensor:
    """
    Return a penalty term that grows with the distance of the control
    voltages from their allowed voltage ranges.

    Example
    -------
    >>> dnpu.get_control_ranges()
    tensor([[[1., 3.], [5., 7.]]])
    >>> dnpu.get_control_voltages()
    tensor([[2.5, 8.0]])
    >>> dnpu.regularizer()
    tensor(1.)

    In this example we have two control electrodes, the first with voltage
    range 1 to 3 and the second 5 to 7. The control voltage is 2.5 for the
    first electrode and 8 for the second. The first is within the
    boundaries so no penalty is generated from it but the second is
    outside, which means a penalty will be generated, which is equal to
    how far outside the interval the value is. In this case 8 - 7 = 1, so
    the penalty will be 0 + 1 = 1.

    Returns
    -------
    torch.Tensor
        Scalar penalty term >= 0.
    """
    # NOTE(review): get_control_voltages() returns a detached tensor, so
    # this penalty carries no gradient towards the control voltages.
    # Confirm whether that is intended before relying on it in a loss.
    control_voltages = self.get_control_voltages()
    # The last dimension of the ranges is (min, max). Indexing it directly
    # is version independent; the former string comparison of
    # torch.__version__ mis-ordered releases such as 1.9 and 1.10.
    control_ranges = self.get_control_ranges()
    below_minimum = torch.relu(control_ranges[..., 0] - control_voltages)
    above_maximum = torch.relu(control_voltages - control_ranges[..., 1])
    return torch.sum(below_minimum + above_maximum)
# @TODO: Update documentation
[docs]
def hw_eval(
    self,
    configs: dict,
    data_input_indices: list = None,
):
    """
    Put the module in evaluation mode and swap its processor.

    While training happens in simulation, evaluation typically happens in
    hardware. The module is switched to eval mode, the processor is
    swapped using the given configs together with the current info
    dictionary and, optionally, the electrode layout is re-initialised.
    A warning is issued when the voltage ranges reported by the processor
    (compared in half precision) changed with the swap.

    Parameters
    ----------
    configs : dict
        Configuration passed to the processor swap.
    data_input_indices : list, optional
        New data input indices with shape (DNPU node number,
        electrode_no). When None the current layout is kept.

    Raises
    ------
    AssertionError
        If configs is not a dict or data_input_indices is neither None
        nor a list.
    """
    assert type(
        configs
    ) is dict, "Configs should be a dictionary. Check Processor for the information that the dictionary should have."
    indices_are_valid = (data_input_indices is None
                         or type(data_input_indices) is list)
    assert indices_are_valid, "Data input indices should be None or a list with shape (DNPU node number, electrode_no)"

    self.eval()
    # Snapshot the ranges before the swap so they can be compared afterwards.
    old_ranges = self.processor.get_voltage_ranges().cpu().half().clone()
    self.processor.swap(configs, self.get_info_dict())
    if data_input_indices is not None:
        self.init_electrode_info(data_input_indices)
    new_ranges = self.processor.get_voltage_ranges().cpu().half().clone()

    if torch.equal(old_ranges, new_ranges):
        return
    warnings.warn(
        "Voltage ranges for the new processor are different "
        "than the control voltage ranges for which the DNPU was trained: \n\n"
        f"* old voltage ranges: \n {old_ranges.cpu().numpy()} \n\n"
        f"* new voltage ranges: \n {new_ranges.cpu().numpy()} \n")
# @TODO: Update documentation
[docs]
def sw_train(
    self,
    configs: dict,
    info: dict,
    model_state_dict: collections.OrderedDict = None,
):
    """
    Put the module in training mode and swap its processor.

    While evaluation happens in hardware, training happens in software.
    The module is switched to train mode and the processor is swapped
    (typically to a software one). The voltage ranges reported by the
    processor before and after the swap (compared in half precision) must
    be equal.

    Parameters
    ----------
    configs : dict
        Configuration passed to the processor swap.
    info : dict
        Info dictionary passed to the processor swap.
    model_state_dict : collections.OrderedDict, optional
        State dictionary passed to the processor swap.

    Raises
    ------
    AssertionError
        If the voltage ranges of the new processor are different than the
        ones of the processor the DNPU was trained with.
    """
    self.train()
    old_ranges = self.processor.get_voltage_ranges().clone().cpu().half()
    self.processor.swap(configs, info, model_state_dict)
    new_ranges = self.processor.get_voltage_ranges().clone().cpu().half()
    # The message is parenthesised so that both halves belong to the
    # assert; previously the second half was a stand-alone no-op string.
    assert torch.equal(old_ranges, new_ranges), (
        "Voltage ranges for the new processor are different "
        "than the control voltage ranges for which the DNPU was trained.")
[docs]
def set_control_voltages(self, control_voltages: torch.Tensor):
    """
    Replace the control voltages of the module.

    The current parameter is replaced by a new torch.nn.Parameter built
    from the given values, placed on the device of the current one.

    Example
    -------
    >>> dnpu.control_voltages.shape
    torch.Size([1, 4])
    >>> dnpu.set_control_voltages(torch.tensor([[1.0, 2.0, 3.0, 4.0]]))

    Parameters
    ----------
    control_voltages : torch.Tensor
        New control voltages, with the same shape and data type as the
        ones already in the model.

    Raises
    ------
    AssertionError
        If the shape or the data type differ from the current ones.
    """
    with torch.no_grad():
        # Messages are parenthesised so both halves belong to the assert;
        # previously the second half was a stand-alone no-op string.
        assert self.control_voltages.shape == control_voltages.shape, (
            "Control voltages could not be set due to a shape missmatch "
            "with regard to the ones already in the model.")
        assert self.control_voltages.dtype == control_voltages.dtype, (
            "Control voltages could not be set due to a data type mismatch "
            "with regard to the ones already in the model.")
        self.control_voltages = torch.nn.Parameter(
            control_voltages.type_as(self.control_voltages).to(
                self.control_voltages.device))
[docs]
def get_control_voltages(self) -> torch.Tensor:
    """
    Return the current control voltages, detached from the computational
    graph.

    Returns
    -------
    torch.Tensor
        Detached view of self.control_voltages.
    """
    learnable_voltages = self.control_voltages
    return learnable_voltages.detach()
[docs]
def get_control_ranges(self) -> torch.Tensor:
    """
    Return the voltage ranges of the control electrodes.

    The tensor has a shape of (dnpu_no, electrode_no, 2), where index 0 of
    the last dimension is the minimum of the range and index 1 the maximum.

    Returns
    -------
    torch.Tensor
        Control voltage ranges (self.control_ranges).
    """
    ranges = self.control_ranges
    return ranges
[docs]
def get_clipping_value(self) -> torch.Tensor:
    """
    Return the output clipping value reported by the processor.

    Returns
    -------
    torch.Tensor
        Result of the processor's get_clipping_value().
    """
    processor = self.processor
    return processor.get_clipping_value()
[docs]
def reset(self):
    """
    Discard the current control voltages and create a fresh set through
    _init_control_voltages.
    """
    delattr(self, "control_voltages")
    self._init_control_voltages()
# TODO: Document the need to override the closing of the processor on
# custom models.
[docs]
def close(self):
    """
    Close the underlying processor by calling its close() method. For
    simulation models, it does nothing. For hardware models it closes the
    drivers.
    """
    processor = self.processor
    processor.close()
[docs]
def is_hardware(self) -> bool:
    """
    Tell whether the processor is a hardware processor or a simulation one.

    Returns
    -------
    bool
        Result of the processor's is_hardware(): True if hardware, False if
        simulation.
    """
    processor = self.processor
    return processor.is_hardware()
# TODO: Document the need to override the closing of the return of the
# info dictionary.
[docs]
def get_info_dict(self) -> dict:
    """
    Return the info dictionary of the processor.

    Returns
    -------
    dict
        The object stored in the processor's info attribute.
    """
    processor = self.processor
    return processor.info
[docs]
def merge_electrode_data(
    input_data,
    control_data,
    input_data_indices: torch.Tensor,
    control_indices,
) -> Union[np.array, Tensor]:
    """
    Merge data input voltages and control voltages into a single tensor,
    placing each column at the electrode position given by its index.

    Example
    -------
    >>> input_data = torch.tensor([[1.0, 3.0], [2.0, 4.0]])
    >>> control_data = torch.tensor([[5.0, 7.0], [6.0, 8.0]])
    >>> merge_electrode_data(input_data, control_data,
    ...                      torch.tensor([0, 2]), torch.tensor([3, 1]))
    tensor([[1., 7., 3., 5.],
            [2., 8., 4., 6.]])

    Merging two tensors of size 2x2, resulting in a tensor of size 2x4.

    Parameters
    ----------
    input_data : torch.Tensor
        Data for the input electrodes; its first dimension is the batch.
    control_data : torch.Tensor
        Data for the control electrodes, with the same dtype and device as
        input_data.
    input_data_indices : torch.Tensor
        Column positions of the input electrodes in the result.
    control_indices : torch.Tensor
        Column positions of the control electrodes in the result.

    Returns
    -------
    torch.Tensor
        Tensor of shape (batch, len(input_data_indices) +
        len(control_indices)) with the dtype and device of input_data.
        Columns not named by any index are left uninitialised.

    Raises
    ------
    AssertionError
        If the two data tensors differ in dtype or device.
    """
    same_dtype = input_data.dtype == control_data.dtype
    assert (
        same_dtype and input_data.device == control_data.device
    ), ("Input data voltages and control voltages have a different data type "
        + "or are in a different device (CUDA or CPU). ")
    batch_size = input_data.shape[0]
    electrode_no = len(input_data_indices) + len(control_indices)
    # new_empty inherits dtype and device from input_data.
    merged = input_data.new_empty((batch_size, electrode_no))
    merged[:, input_data_indices] = input_data
    merged[:, control_indices] = control_data
    return merged