# File size: 18,832 Bytes
# b89907e
import tensorflow as tf
import time as tm
import sys
# import numpy as np
# import cudnn as cd
# from tensorflow.keras import datasets, layers, models
# import matplotlib.pyplot as plt
# from tensorflow.python.client import device_lib
class Integrator_layer(tf.keras.layers.Layer):
    """Integrate-and-fire layer: cumulative-sums its input over time and emits spikes.

    Axis 1 of the input is treated as time. The time axis is cut into windows of
    ``integration_window`` steps (the last window may be shorter). Inside every
    window the input, scaled by ``time_constant``, is cumulatively summed along
    time; each time the sum exceeds ``V_m_threshold`` a spike is written at that
    timestep, the input up to and including that timestep is zeroed, and the
    integration is repeated until no position crosses the threshold any more.
    Windows are processed independently of each other.

    The result has the same layout as the input (time on axis 1) and is
    multiplied by ``amplitude`` when that differs from 1.0.

    Args:
        n_steps: Accepted for interface compatibility; not used by this class.
        integration_window: Number of timesteps integrated together in one window.
        time_constant: Factor the input is multiplied by before the cumulative sum.
        leakyness: Accepted for interface compatibility; not used by this class.
        V_m_threshold: Value the cumulative sum has to exceed to produce a spike.
        refractory_period: Timesteps to suppress after a spike. See the
            NOTE(review) in ``call``: the non-zero path looks unfinished.
        amplitude: Scale applied to the output spikes.
        V_m_min: Stored on the instance; not used by this class.
        V_cm: Accepted for interface compatibility; not used by this class.
        device: Stored on the instance; not used by this class.
        name: Keras layer name.
    """

    def __init__(self, n_steps=100, integration_window=100, time_constant=1.0, leakyness=0.0,
                 V_m_threshold=1.0, refractory_period=0, amplitude=1.0, V_m_min=0, V_cm=2.5,
                 device='cuda', name='I&F'):
        super(Integrator_layer, self).__init__(name=name)
        self.Vm_threshold = V_m_threshold
        self.integration_window = integration_window
        self.refractory_period = refractory_period
        self.time_constant = time_constant
        # Store the numeric fuzz factor itself; the original kept the un-called
        # function object, which cannot be used in arithmetic.
        self.epsilon = tf.keras.backend.epsilon()
        self.amplitude = amplitude
        self.V_m_min = V_m_min
        self.device = device

    @tf.function
    def chunk_sizes(self, length, chunk_size):
        """Split ``length`` into pieces of ``chunk_size`` plus a shorter remainder piece.

        Args:
            length: Total number of timesteps.
            chunk_size: Size of every full piece.

        Returns:
            The list of piece sizes, in order; the remainder (if any) comes last.
        """
        chunks = [chunk_size for x in range(length // chunk_size)]
        if length % chunk_size != 0:
            chunks.append(length % chunk_size)
        return chunks

    def build(self, input_shape):
        """Record the input geometry and pre-compute the integration windows.

        Args:
            input_shape: Shape of the input; index 0 is read as the batch size,
                index 1 as the number of timesteps, the rest as the per-step shape.
        """
        self.batch_size = input_shape[0]
        self.timesteps = input_shape[1]
        self.image_shape = input_shape[2:]
        self.image_rank = len(input_shape)
        # NOTE: this rebinds the instance attribute ``chunk_sizes`` from the
        # method above to its result, so the method is no longer reachable
        # through this instance afterwards.
        self.chunk_sizes = self.chunk_sizes(input_shape[1], self.integration_window)
        # Fully unknown shape of the right rank, used as the loop shape invariant.
        self.tensor_invariance = [None for i in range(self.image_rank)]

    @tf.function
    def call(self, inputs):
        """Convert ``inputs`` into spikes window by window.

        Args:
            inputs: Tensor with time on axis 1, as described on ``build``.

        Returns:
            Tensor of spikes with time moved back to axis 1, scaled by
            ``self.amplitude`` when that is not 1.0.
        """
        # Padding spec that prepends one zero along the time axis only; used
        # below to shift the crossing mask one step to the right.
        roll_padding = tf.zeros([self.image_rank, 2], dtype=tf.int32)
        roll_padding = tf.tensor_scatter_nd_update(roll_padding, indices=[[1, 0]], updates=[1])
        # Fragment the sample into windows of length equal to the integration window.
        images_chunks = tf.split(inputs, self.chunk_sizes, axis=1)
        first_chunk = True
        for chunk, n_timesteps in zip(images_chunks, self.chunk_sizes):
            # One extra slot on the last axis collects positions that never fire.
            Spikes_out = tf.zeros([tf.shape(chunk)[0], *self.image_shape, n_timesteps + 1])
            # Start non-zero so the loop body runs at least once.
            V_m_temp = tf.ones_like(chunk)
            while tf.math.count_nonzero(V_m_temp) != 0:
                tf.autograph.experimental.set_loop_options(
                    shape_invariants=[(V_m_temp, tf.TensorShape(self.tensor_invariance))])
                # Integration: cumulative sum of the scaled input along time.
                V_m_chunk = tf.math.cumsum(tf.math.multiply(chunk, self.time_constant), axis=1)
                # Everything at or below the threshold becomes zero.
                V_m_temp = tf.nn.relu(V_m_chunk - self.Vm_threshold)
                if tf.math.count_nonzero(V_m_temp) == 0:
                    # Nothing crossed the threshold: this window is finished.
                    break
                # relu output is non-negative, so after this second cumsum every
                # entry from the first crossing onwards stays non-zero.
                V_m_temp = tf.math.cumsum(V_m_temp, axis=1)
                # The count of leading zeros is the index of the first crossing;
                # positions that never cross land in the extra slot n_timesteps.
                Spikes_out = Spikes_out + tf.one_hot(
                    tf.reduce_sum(tf.cast((V_m_temp == 0), tf.int32), axis=1),
                    depth=n_timesteps + 1)
                # Shift right by one step with zero fill, so the crossing
                # timestep itself is also marked zero in the mask.
                V_m_temp = tf.pad(V_m_temp, paddings=roll_padding, mode="CONSTANT")
                V_m_temp, _ = tf.split(V_m_temp, [n_timesteps, 1], axis=1)
                if self.refractory_period != 0:
                    # NOTE(review): carried over unchanged. Tensors do not support
                    # slice assignment, the slice hard-codes a rank-5 input, and
                    # it rolls `chunk` rather than the mask - looks like an
                    # unfinished port; confirm the intended semantics before use.
                    V_m_temp = tf.roll(chunk, shift=self.refractory_period, axis=1)
                    V_m_temp[:, 0:(self.refractory_period-1), :, :, :] = 0
                # Zero the input wherever the mask is zero, so the next pass
                # integrates from after the spike.
                chunk = tf.where(V_m_temp == 0.0, 0.0, chunk)
            # Drop the extra "never fired" slot added for the one-hot encoding.
            Spikes_out, _ = tf.split(Spikes_out, [n_timesteps, 1], axis=-1)
            if first_chunk:
                Spikes_out_final = Spikes_out
                first_chunk = False
            else:
                # Only the spikes are accumulated. The original also concatenated
                # a never-initialised `V_m_final` here, which failed as soon as
                # a second window was processed.
                Spikes_out_final = tf.concat((Spikes_out_final, Spikes_out), axis=-1)
        # one_hot put time on the last axis; move it back right after the batch axis.
        Spikes_out_final = tf.experimental.numpy.moveaxis(Spikes_out_final, source=-1, destination=1)
        if self.amplitude != 1.0:
            Spikes_out_final = Spikes_out_final * self.amplitude
        return Spikes_out_final
def sparse_data_generator_non_spiking(input_images, input_labels, batch_size=32, nb_steps=100, shuffle=True, flatten=False):
    """Yield batches in which every sample is repeated ``nb_steps`` times along a new axis 1.

    The images are cast to float32, optionally shuffled (shuffle buffer of 100
    elements), and batched without dropping the final partial batch. A
    one-line progress report is written to stdout before each batch is yielded.

    Args:
        input_images: Samples; sliced along their first axis.
        input_labels: Labels, one per sample; must support ``len()``.
        batch_size: Number of samples per yielded batch.
        nb_steps: Number of repetitions along the inserted axis 1.
        shuffle: Whether to shuffle the samples before batching.
        flatten: Whether to flatten each sample to one dimension first.

    Yields:
        ``(X, y)`` where ``X`` is the batch with the repeated axis inserted at
        position 1 and ``y`` is the matching batch of labels.
    """
    data_loader_original = tf.data.Dataset.from_tensor_slices(
        (tf.cast(input_images, tf.float32), input_labels))
    if shuffle:
        data_loader_original = data_loader_original.shuffle(buffer_size=100)
    data_loader = data_loader_original.batch(batch_size=batch_size, drop_remainder=False)

    # drop_remainder=False keeps the final partial batch, so round up. This also
    # keeps the progress division below away from zero whenever a batch exists.
    number_of_batches = -(-len(input_labels) // batch_size)
    counter = 0
    last_tick = tm.time()
    for X, y in data_loader:
        if flatten:
            # tf.Tensor has no .reshape method; use the functional form.
            X = tf.reshape(X, [X.shape[0], -1])
        X = tf.expand_dims(X, axis=1)
        X = tf.repeat(X, repeats=nb_steps, axis=1)

        now = tm.time()
        time_taken = now - last_tick
        last_tick = now
        ETA = time_taken * (number_of_batches - counter)
        sys.stdout.write(
            "\rBatch: {0}/{1}, Progress: {2:0.2f}%, Time to process last batch: {3:0.2f} seconds, Estimated time to finish epoch: {4:0.2f} seconds | {5}:{6} minutes".format(
                counter, number_of_batches, (counter / number_of_batches) * 100, time_taken, ETA,
                int(ETA // 60), int(ETA % 60)))
        sys.stdout.flush()
        counter += 1
        yield X, y
class Reduce_sum(tf.keras.layers.Layer):
    """Keras layer that sums its input over axis 1 and drops that axis."""

    def __init__(self, name=None):
        super().__init__(name=name)

    def call(self, inputs):
        """Return ``inputs`` summed along axis 1, without keeping the reduced axis."""
        summed = tf.reduce_sum(inputs, axis=1)
        return summed
# """
# class Integrator_layer(tf.keras.layers.Layer):
# def __init__(self, n_steps=100, integration_window=100, time_constant=1.0, leakyness=0.0,
# V_m_threshold = 2.0, refractory_period=0, amplitude=1.0, V_m_min=0, V_cm=2.5, device='cuda', name='I&F'):
# super(Integrator_layer, self).__init__(name=name)
# # self.threshold = nn.Threshold(V_m_threshold, 0)
# # self.zero = torch.tensor(0, dtype=torch.float, device=device)
# self.Vm_threshold = V_m_threshold
# self.integration_window = integration_window
# self.refractory_period = refractory_period
# self.time_constant = time_constant
# # self.epsilon = 0.001
# self.epsilon = tf.keras.backend.epsilon
# self.amplitude = amplitude
# # self.threshold = nn.Threshold(V_m_threshold - self.epsilon, 0) ### Thresholding function
# # self.threshold = tf.nn.relu(V_m_threshold - self.epsilon, 0) ### Thresholding function
# self.V_m_min = V_m_min
# self.device = device
#
# @tf.function
# def chunk_sizes(self, length, chunk_size):
# chunks = [chunk_size for x in range(length//chunk_size)]
# if length % chunk_size != 0:
# chunks.append(length % chunk_size)
# return chunks
#
# def build(self, input_shape):
# self.batch_size = input_shape[0]
# self.timesteps = input_shape[1]
# self.image_shape = input_shape[2:]
# # self.chunk_sizes = self.chunk_sizes(self.timesteps, self.integration_window)
# self.chunk_sizes = self.chunk_sizes(input_shape[1], self.integration_window)
# ###
# ###
# ###
# # self.list_of_indices = [[x, 0] for x in tf.range(input_shape[0])]
# # self.list_of_indices = tf.range(input_shape[0])
#
# @tf.function
# def call(self, inputs):
# ### List of indices - list of indices to replace very first timestep with zero after the roll operation
# list_of_indices = tf.pad(tf.expand_dims(tf.range(tf.shape(inputs)[0]), axis=1),
# paddings=[[0, 0], [0, 1]],
# mode="CONSTANT")
# images_chunks = tf.split(inputs, self.chunk_sizes, axis=1) ### Fragment current sample into multiple chunks with length equal to the integration window
# first_chunk = True
# # zero = torch.tensor(0, dtype=torch.float, device=self.device)
# for chunk, n_timesteps in zip(images_chunks, self.chunk_sizes):
# ### n_timesteps - the number of timesteps for current chunk of integration window
# Spikes_out = tf.zeros([tf.shape(chunk)[0], *self.image_shape, n_timesteps + 1])
# ### V_m_out - array for storing membrane potential
# V_m_out = tf.zeros_like(chunk)
# V_m_temp = tf.zeros_like(chunk)
# while tf.math.count_nonzero(V_m_temp) != 0:
# ### V_m_chunk - cumulative summation (integration) along time dimension
# V_m_chunk = tf.math.cumsum(tf.math.multiply(chunk, self.time_constant), axis=1)
# ### Thresholding chunks, all values bellow threshold value are zeroed
# V_m_temp = tf.nn.relu(V_m_chunk - self.Vm_threshold)
# if tf.math.count_nonzero(V_m_temp) == 0: ### if Vm did not cross threshold, break the cycle
# V_m_out = V_m_out + V_m_chunk
# break
# ### Cumsum of the thresholded cumsum - to avoid any future threshold crossings (additional zeroes) that can occur after threshold is hit:
# V_m_temp = tf.math.cumsum(V_m_temp, axis=1)
# ### V_m_temp == 0 The amount of zero values before function crosses the threshold. Used to calculated how many timesteps it took for an integrator to fire an output spike
# Spikes_out = Spikes_out + tf.one_hot(tf.reduce_sum(tf.cast((V_m_temp == 0), tf.int32), axis=1), depth=n_timesteps + 1)### One hotted zero counts
# ### TF roll operation is used to shift the vector values by 1, other timestep which crossed threshold is not included:
# V_m_temp = tf.roll(V_m_temp, shift=1, axis=1)
# ###Since roll operation will shift the last value to the first place, the first value should be 0'ed for a proper counting of 0 in the next code fragments.
# V_m_temp = tf.tensor_scatter_nd_update(V_m_temp, indices=list_of_indices,
# updates=tf.zeros([1, *self.image_shape]))
# V_m_out = tf.where(V_m_temp == 0, V_m_out + V_m_chunk, 0) ### Resets V_m to 0 after firing
# if self.refractory_period!=0: ### Resets (=0) number of timesteps after output spike is fired
# V_m_temp = tf.roll(chunk, shift=self.refractory_period, axis=1)
# V_m_temp[:, 0:(self.refractory_period-1), :, :, :] = 0
# chunk = tf.where(V_m_temp == 0.0, 0.0, chunk) ### Removes spikes before firing. So new V_m can be calculated for a next spike.
# # Spikes_out = torch.narrow(Spikes_out, dim=-1, start=0, length= n_timesteps)
# Spikes_out, _ = tf.split(Spikes_out, [n_timesteps, 1], axis=-1) ### Onehot operation adds back time dimension to the last place, so it must be popped out
# if first_chunk:
# V_m_final = V_m_out
# Spikes_out_final = Spikes_out
# first_chunk = False
# else:
# V_m_final = tf.concat((V_m_final, V_m_out), axis=1)
# Spikes_out_final = tf.concat((Spikes_out_final, Spikes_out), axis=-1)
#
# # Spikes_out_final = torch.movedim(Spikes_out_final, source=-1, destination=1)
# Spikes_out_final = tf.experimental.numpy.swapaxes(Spikes_out_final, axis1=-1,
# axis2=1) ### Onehotting puts time as the last tensor dimension. 'movedim' moves time dimension to the 2nd place, after the batch number, as it was before.
#
# # return V_m_final, Spikes_out_final
# # print('LIF forward end:')
# # print(f'{datetime.now().time().replace(microsecond=0)} --- ')
# # print(Spikes_out.type())
# if self.amplitude !=1.0:
# Spikes_out_final = Spikes_out_final*self.amplitude
# return Spikes_out_final
# """
#
# def sparse_data_generator_non_spiking(input_images, input_labels, batch_size=32, nb_steps=100, shuffle=True, flatten= False):
# """ This generator takes datasets in analog format and generates network input as constant currents.
# If repeat=True, encoding is rate-based, otherwise it is a latency encoding
# Args:
# X: The data ( sample x event x 2 ) the last dim holds (time,neuron) tuples
# y: The labels
# """
#
# # def argument_free_generator():
# # data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=n_workers)
# data_loader_original = tf.data.Dataset.from_tensor_slices((tf.cast(input_images, tf.float32), input_labels))
# if shuffle:
# data_loader_original = data_loader_original.shuffle(buffer_size=100)
# data_loader = data_loader_original.batch(batch_size=batch_size, drop_remainder=False)
#
# number_of_batches = input_labels.__len__() // batch_size
# counter = 0
# time = tm.time()
#
# for X, y in data_loader:
# if flatten:
# X = X.reshape(X.shape[0], -1)
# # sample_dims = np.array(X.shape[1:], dtype=int)
# sample_dims = X.shape[1:]
# # X = torch.unsqueeze(X, dim=1)
# X = tf.expand_dims(X, axis=1)
# X = tf.repeat(X, repeats=nb_steps, axis=1)
# time_taken = tm.time() - time
# time = tm.time()
# ETA = time_taken * (number_of_batches - counter)
# sys.stdout.write(
# "\rBatch: {0}/{1}, Progress: {2:0.2f}%, Time to process last batch: {3:0.2f} seconds, Estimated time to finish epoch: {4:0.2f} seconds | {5}:{6} minutes".format(
# counter, number_of_batches, (counter / number_of_batches) * 100, time_taken, ETA, int(ETA // 60),
# int(ETA % 60)))
# sys.stdout.flush()
# # X_batch = torch.tensor(X, device=device, dtype=torch.float)
# # yield X.expand(-1, nb_steps, *sample_dims).to(device), y.to(device) ### Returns this values after each batch
# counter += 1
# yield X, y ### Returns this values after each batch
#
# # return argument_free_generator()
#
#
# class Reduce_sum(tf.keras.layers.Layer):
# def __init__(self, name=None):
# super(Reduce_sum, self).__init__(name=name)
#
# def call(self, inputs):
# return tf.math.reduce_sum(inputs, axis=1, keepdims=False)
#
#
#