The Continuous Thought Machine – Tutorial 01: MNIST

Modern deep learning models ignore time as a core computational element. In contrast, the Continuous Thought Machine (CTM) introduces internal recurrence and neural synchronization to model thinking as a temporal process.
Key Ideas
- Internal Ticks: The CTM runs over a self-generated temporal axis (independent of input), which we view as a dimension over which thought can unfold.
- Neuron-Level Models: Each neuron has a private MLP that processes its own history of pre-activations over time.
- Synchronization as Representation: CTMs compute neuron-to-neuron synchronization over time and use these signals for attention and output.
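To make the last point concrete, here is a minimal, standalone sketch (illustrative only, not part of the tutorial code) of reading a synchronization representation off a history of post-activations; the CTM below accumulates a decayed version of this recursively rather than averaging:
import torch

D, T = 4, 8                             # neurons, internal ticks
z = torch.randn(D, T)                   # post-activations over ticks
sync_matrix = (z @ z.T) / T             # neuron-to-neuron synchronization, shape (D, D)
i, j = torch.triu_indices(D, D)         # keep the upper triangle (including the diagonal)
sync_vector = sync_matrix[i, j]         # flat representation with D * (D + 1) / 2 entries
print(sync_vector.shape)                # torch.Size([10])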
Why It Matters
- Enables interpretable, dynamic reasoning
- Supports adaptive compute (e.g. more ticks for harder tasks)
- Works across tasks: classification, reasoning, memory, RL—without changing the core mechanisms.
MNIST Classification
In this tutorial we walk through a simple example: training a CTM to classify MNIST digits. We cover:
- Defining the model
- Constructing the loss function
- Training
- Building visualizations of the CTM's internal dynamics
Imports
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np
import random
from scipy.special import softmax
import math
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from IPython.display import display, clear_output
import seaborn as sns
import imageio
import mediapy
import torch._dynamo as dynamo
dynamo.config.suppress_errors = False # Raise errors on fallback

Set the seed for reproducibility.
def set_seed(seed=42, deterministic=True):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = deterministic
torch.backends.cudnn.benchmark = False

set_seed(42)

We start by defining some helper classes, which we will use in the CTM.
Of note is the SuperLinear class, which implements N unique linear transformations. This SuperLinear class will be used for the neuron-level models.
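As a quick intuition for what SuperLinear computes, here is a minimal sketch (illustrative shapes only) of applying a private linear map to every neuron's history with a single einsum; this is exactly the pattern used in the class below:
import torch

B, D, M, H = 2, 5, 10, 3          # batch, neurons, history length, output size
x = torch.randn(B, D, M)          # each neuron's pre-activation history
w = torch.randn(M, H, D)          # a separate (M -> H) weight matrix for every neuron
b = torch.zeros(1, D, H)
out = torch.einsum('BDM,MHD->BDH', x, w) + b
print(out.shape)                  # torch.Size([2, 5, 3])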
class Identity(nn.Module):
"""Identity Module."""
def __init__(self):
super().__init__()
def forward(self, x):
return x
class Squeeze(nn.Module):
"""Squeeze Module."""
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x):
return x.squeeze(self.dim)
class SuperLinear(nn.Module):
"""SuperLinear Layer: Implements Neuron-Level Models (NLMs) for the CTM."""
def __init__(self, in_dims, out_dims, N, dropout=0.0):
super().__init__()
self.in_dims = in_dims
self.dropout = nn.Dropout(dropout) if dropout > 0 else Identity()
self.register_parameter('w1', nn.Parameter(
torch.empty((in_dims, out_dims, N)).uniform_(
-1/math.sqrt(in_dims + out_dims),
1/math.sqrt(in_dims + out_dims)
), requires_grad=True)
)
self.register_parameter('b1', nn.Parameter(torch.zeros((1, N, out_dims)), requires_grad=True))
def forward(self, x):
out = self.dropout(x)
out = torch.einsum('BDM,MHD->BDH', out, self.w1) + self.b1
out = out.squeeze(-1)
return out

Next, we define a helper function compute_normalized_entropy. We will use this function inside the CTM to compute the certainty of the model at each internal tick as certainty = 1 - normalized entropy.
def compute_normalized_entropy(logits, reduction='mean'):
"""Computes the normalized entropy for certainty-loss."""
preds = F.softmax(logits, dim=-1)
log_preds = torch.log_softmax(logits, dim=-1)
entropy = -torch.sum(preds * log_preds, dim=-1)
num_classes = preds.shape[-1]
max_entropy = torch.log(torch.tensor(num_classes, dtype=torch.float32))
normalized_entropy = entropy / max_entropy
if len(logits.shape)>2 and reduction == 'mean':
normalized_entropy = normalized_entropy.flatten(1).mean(-1)
return normalized_entropy
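A quick sanity check (illustrative values only): uniform logits over 10 classes give a normalized entropy of 1, i.e. certainty 0, while sharply peaked logits give a value near 0, i.e. certainty near 1.
uniform_logits = torch.zeros(1, 10)                  # uniform distribution over 10 classes
peaked_logits = torch.full((1, 10), -10.0)
peaked_logits[0, 3] = 10.0                           # almost all probability mass on class 3
print(compute_normalized_entropy(uniform_logits))    # tensor([1.])
print(compute_normalized_entropy(peaked_logits))     # close to 0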
CTM Architecture Overview
The CTM implementation is initialized with the following core parameters:
- iterations: Number of internal ticks (recurrent steps).
- d_model: Total number of neurons.
- d_input: Input and attention embedding dimension.
- memory_length: Length of the sliding activation window used by each neuron.
- memory_hidden_dims: Hidden width of each neuron-level model.
- heads: Number of attention heads.
- n_synch_out: Number of neurons used for output synchronization.
- n_synch_action: Number of neurons used for computing attention queries.
- out_dims: Dimensionality of the model's output.
Key Components
Upon initialization, the CTM builds the following modules:
- Backbone: A CNN feature extractor for the input (e.g. image).
- Synapses: A communication layer allowing neurons to interact.
- Trace Processor: A neuron-level model that operates on each neuron’s temporal activation trace.
- Synchronization Decay Parameters: Learnable decay rates used when accumulating pairwise synchronization over ticks.
- Learned Initial States: Starting activations and traces for the system.
Forward Pass Mechanics
The forward pass proceeds as follows:
- Initialize recurrent state:
  - state_trace: Memory trace per neuron.
  - activated_state: Current post-activations.
  - decay_alpha_out, decay_beta_out: Values for calculating synchronization.
- Featurize input:
  - Use the CNN backbone to extract key-value attention pairs kv.
- Internal loop (for each tick stepi):
  - Compute synchronisation_action from n_synch_action neurons.
  - Generate attention query q from this synchronization.
  - Perform multi-head cross-attention over kv.
  - Concatenate attention output with the current neuron activations.
  - Update neurons via the synaptic module.
  - Append new activation to the trace window.
  - Update neuron states using the trace_processor.
  - Compute synchronisation_out from n_synch_out neurons.
  - Project to the output space via output_projector.
  - Compute prediction certainty from normalized entropy.

This inner loop is repeated for the configured number of internal ticks. The CTM emits predictions and certainties at every internal tick.
For detailed mathematical formulation of the synchronization mechanism, please refer to the technical report.
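Concretely, the synchronization used below is a decaying running sum of pairwise neuron products. Writing z_t for the post-activations at tick t, r = exp(-decay_params) for a learnable per-pair decay rate, and upper(.) for the upper triangle (including the diagonal) of a matrix:
alpha_t = r * alpha_(t-1) + upper(z_t z_t^T)
beta_t = r * beta_(t-1) + 1
synchronisation_t = alpha_t / sqrt(beta_t)
with alpha_0 = upper(z_0 z_0^T) and beta_0 = 1. This is what compute_synchronisation implements for both the 'action' and 'out' neuron subsets.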
class ContinuousThoughtMachine(nn.Module):
def __init__(self,
iterations,
d_model,
d_input,
memory_length,
heads,
n_synch_out,
n_synch_action,
out_dims,
memory_hidden_dims,
dropout=0.0
):
super(ContinuousThoughtMachine, self).__init__()
# --- Core Parameters ---
self.iterations = iterations
self.d_model = d_model
self.d_input = d_input
self.memory_length = memory_length
self.n_synch_out = n_synch_out
self.n_synch_action = n_synch_action
self.out_dims = out_dims
self.memory_hidden_dims = memory_hidden_dims
# --- Input Processing ---
self.backbone = nn.Sequential(
nn.LazyConv2d(d_input, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(d_input),
nn.ReLU(),
nn.MaxPool2d(2, 2),
nn.LazyConv2d(d_input, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(d_input),
nn.ReLU(),
nn.MaxPool2d(2, 2),
)
self.attention = nn.MultiheadAttention(self.d_input, heads, dropout, batch_first=True)
self.kv_proj = nn.Sequential(nn.LazyLinear(self.d_input), nn.LayerNorm(self.d_input))
self.q_proj = nn.LazyLinear(self.d_input)
# --- Core CTM Modules ---
self.synapses = nn.Sequential(
nn.Dropout(dropout),
nn.LazyLinear(d_model * 2),
nn.GLU(),
nn.LayerNorm(d_model)
)
self.trace_processor = nn.Sequential(
SuperLinear(in_dims=memory_length, out_dims=2 * memory_hidden_dims, N=d_model, dropout=dropout),
nn.GLU(),
SuperLinear(in_dims=memory_hidden_dims, out_dims=2, N=d_model, dropout=dropout),
nn.GLU(),
Squeeze(-1)
)
# --- Start States ---
self.register_parameter('start_activated_state', nn.Parameter(
torch.zeros((d_model)).uniform_(-math.sqrt(1/(d_model)), math.sqrt(1/(d_model))),
requires_grad=True
))
self.register_parameter('start_trace', nn.Parameter(
torch.zeros((d_model, memory_length)).uniform_(-math.sqrt(1/(d_model+memory_length)), math.sqrt(1/(d_model+memory_length))),
requires_grad=True
))
# --- Synchronisation ---
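# Synchronisation uses the upper triangle (including the diagonal) of the pairwise
# product matrix, hence n * (n + 1) / 2 values for a subset of n neurons.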
self.synch_representation_size_action = (self.n_synch_action * (self.n_synch_action+1))//2
self.synch_representation_size_out = (self.n_synch_out * (self.n_synch_out+1))//2
for synch_type, size in [('action', self.synch_representation_size_action), ('out', self.synch_representation_size_out)]:
print(f"Synch representation size {synch_type}: {size}")
self.set_synchronisation_parameters('out')
self.set_synchronisation_parameters('action')
# --- Output Processing ---
self.output_projector = nn.Sequential(nn.LazyLinear(self.out_dims))
def set_synchronisation_parameters(self, synch_type: str):
synch_representation_size = self.synch_representation_size_action if synch_type == 'action' else self.synch_representation_size_out
self.register_parameter(f'decay_params_{synch_type}', nn.Parameter(torch.zeros(synch_representation_size), requires_grad=True))
def compute_synchronisation(self, activated_state, decay_alpha, decay_beta, r, synch_type):
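# 'action' synchronisation is computed from the last n_synch_action neurons,
# 'out' synchronisation from the first n_synch_out neurons.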
if synch_type == 'action':
n_synch = self.n_synch_action
selected_left = selected_right = activated_state[:, -n_synch:]
elif synch_type == 'out':
n_synch = self.n_synch_out
selected_left = selected_right = activated_state[:, :n_synch]
outer = selected_left.unsqueeze(2) * selected_right.unsqueeze(1)
i, j = torch.triu_indices(n_synch, n_synch)
pairwise_product = outer[:, i, j]
if decay_alpha is None or decay_beta is None:
decay_alpha = pairwise_product
decay_beta = torch.ones_like(pairwise_product)
else:
decay_alpha = r * decay_alpha + pairwise_product
decay_beta = r * decay_beta + 1
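# Normalise the decayed running sum so its scale stays comparable across ticks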
synchronisation = decay_alpha / (torch.sqrt(decay_beta))
return synchronisation, decay_alpha, decay_beta
def compute_features(self, x):
input_features = self.backbone(x)
kv = self.kv_proj(input_features.flatten(2).transpose(1, 2))
return kv
def compute_certainty(self, current_prediction):
ne = compute_normalized_entropy(current_prediction)
current_certainty = torch.stack((ne, 1-ne), -1)
return current_certainty
def forward(self, x, track=False):
B = x.size(0)
device = x.device
# --- Tracking Initialization ---
pre_activations_tracking = []
post_activations_tracking = []
synch_out_tracking = []
synch_action_tracking = []
attention_tracking = []
# --- Featurise Input Data ---
kv = self.compute_features(x)
# --- Initialise Recurrent State ---
state_trace = self.start_trace.unsqueeze(0).expand(B, -1, -1) # Shape: (B, H, T)
activated_state = self.start_activated_state.unsqueeze(0).expand(B, -1) # Shape: (B, H)
# --- Storage for outputs per iteration ---
predictions = torch.empty(B, self.out_dims, self.iterations, device=device, dtype=x.dtype)
certainties = torch.empty(B, 2, self.iterations, device=device, dtype=x.dtype)
decay_alpha_action, decay_beta_action = None, None
r_action, r_out = torch.exp(-self.decay_params_action).unsqueeze(0).repeat(B, 1), torch.exp(-self.decay_params_out).unsqueeze(0).repeat(B, 1)
_, decay_alpha_out, decay_beta_out = self.compute_synchronisation(activated_state, None, None, r_out, synch_type='out')
# --- Recurrent Loop ---
for stepi in range(self.iterations):
# --- Calculate Synchronisation for Input Data Interaction ---
synchronisation_action, decay_alpha_action, decay_beta_action = self.compute_synchronisation(activated_state, decay_alpha_action, decay_beta_action, r_action, synch_type='action')
# --- Interact with Data via Attention ---
q = self.q_proj(synchronisation_action).unsqueeze(1)
attn_out, attn_weights = self.attention(q, kv, kv, average_attn_weights=False, need_weights=True)
attn_out = attn_out.squeeze(1)
pre_synapse_input = torch.concatenate((attn_out, activated_state), dim=-1)
# --- Apply Synapses ---
state = self.synapses(pre_synapse_input)
state_trace = torch.cat((state_trace[:, :, 1:], state.unsqueeze(-1)), dim=-1)
# --- Activate ---
activated_state = self.trace_processor(state_trace)
# --- Calculate Synchronisation for Output Predictions ---
synchronisation_out, decay_alpha_out, decay_beta_out = self.compute_synchronisation(activated_state, decay_alpha_out, decay_beta_out, r_out, synch_type='out')
# --- Get Predictions and Certainties ---
current_prediction = self.output_projector(synchronisation_out)
current_certainty = self.compute_certainty(current_prediction)
predictions[..., stepi] = current_prediction
certainties[..., stepi] = current_certainty
# --- Tracking ---
if track:
pre_activations_tracking.append(state_trace[:,:,-1].detach().cpu().numpy())
post_activations_tracking.append(activated_state.detach().cpu().numpy())
attention_tracking.append(attn_weights.detach().cpu().numpy())
synch_out_tracking.append(synchronisation_out.detach().cpu().numpy())
synch_action_tracking.append(synchronisation_action.detach().cpu().numpy())
# --- Return Values ---
if track:
return predictions, certainties, (np.array(synch_out_tracking), np.array(synch_action_tracking)), np.array(pre_activations_tracking), np.array(post_activations_tracking), np.array(attention_tracking)
return predictions, certainties, synchronisation_out
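As a quick shape check (a standalone sketch; the small dimensions below are arbitrary and assume the class above has been run), the model stacks per-tick predictions and certainties along the last dimension:
tiny_ctm = ContinuousThoughtMachine(
    iterations=5, d_model=32, d_input=16, memory_length=4, heads=2,
    n_synch_out=8, n_synch_action=8, out_dims=10, memory_hidden_dims=4
)
dummy = torch.randn(2, 1, 28, 28)            # a batch of two MNIST-sized images
with torch.no_grad():
    preds, certs, sync_out = tiny_ctm(dummy)
print(preds.shape)   # torch.Size([2, 10, 5]) -> (batch, classes, internal ticks)
print(certs.shape)   # torch.Size([2, 2, 5])  -> (batch, [normalized entropy, certainty], ticks)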
Certainty-Based Loss Function
The CTM produces outputs at each internal tick, so the question arises: how do we optimize the model across this internal temporal dimension?
Our answer is a simple but effective certainty-based loss that encourages the model to reason meaningfully across time. Instead of relying on the final tick alone, we aggregate loss from two key internal ticks:
- The tick where the prediction loss is lowest.
- The tick where the certainty (1 - normalized entropy) is highest.
We then take the average of the losses at these two points.
This approach encourages the CTM to both make accurate predictions and express high confidence in them—while supporting adaptive, interpretable computation over time.
def get_loss(predictions, certainties, targets, use_most_certain=True):
"""use_most_certain will select either the most certain point or the final point."""
losses = nn.CrossEntropyLoss(reduction='none')(predictions,
torch.repeat_interleave(targets.unsqueeze(-1), predictions.size(-1), -1))
loss_index_1 = losses.argmin(dim=1)
loss_index_2 = certainties[:,1].argmax(-1)
if not use_most_certain:
loss_index_2[:] = -1
batch_indexer = torch.arange(predictions.size(0), device=predictions.device)
loss_minimum_ce = losses[batch_indexer, loss_index_1].mean()
loss_selected = losses[batch_indexer, loss_index_2].mean()
loss = (loss_minimum_ce + loss_selected)/2
return loss, loss_index_2
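A small illustration of the index selection (random tensors, purely to show the mechanics): the loss averages the cross-entropy at the tick with the lowest loss and at the tick with the highest certainty, and the latter index is returned so that accuracy can be measured at the most certain tick.
dummy_preds = torch.randn(4, 10, 15)     # (batch, classes, internal ticks)
dummy_certs = torch.rand(4, 2, 15)       # index 1 along dim 1 holds the per-tick certainty
dummy_targets = torch.randint(0, 10, (4,))
loss, where_most_certain = get_loss(dummy_preds, dummy_certs, dummy_targets)
print(loss.item(), where_most_certain)   # scalar loss, one tick index per example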
"""Calculate the accuracy based on the prediction at the most certain internal tick."""
B = predictions.size(0)
device = predictions.device
predictions_at_most_certain_internal_tick = predictions.argmax(1)[torch.arange(B, device=device), where_most_certain].detach().cpu().numpy()
accuracy = (targets.detach().cpu().numpy() == predictions_at_most_certain_internal_tick).mean()
return accuracy

def prepare_data():
transform = transforms.Compose([
transforms.RandomAffine(degrees=15, translate=(0.1, 0.1), scale=(0.9, 1.1)),
transforms.RandomRotation(10),
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
train_data = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
test_data = datasets.MNIST(root="./data", train=False, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(train_data, batch_size=256, shuffle=True, num_workers=1)
testloader = torch.utils.data.DataLoader(test_data, batch_size=256, shuffle=True, num_workers=1, drop_last=False)
return trainloader, testloader

def update_training_curve_plot(fig, ax1, ax2, train_losses, test_losses, train_accuracies, test_accuracies, steps):
clear_output(wait=True)
# Plot loss
ax1.clear()
ax1.plot(range(len(train_losses)), train_losses, 'b-', alpha=0.7, label=f'Train Loss: {train_losses[-1]:.3f}')
ax1.plot(steps, test_losses, 'r-', marker='o', label=f'Test Loss: {test_losses[-1]:.3f}')
ax1.set_title('Loss')
ax1.set_xlabel('Step')
ax1.set_ylabel('Loss')
ax1.legend()
ax1.grid(True, alpha=0.3)
# Plot accuracy
ax2.clear()
ax2.plot(range(len(train_accuracies)), train_accuracies, 'b-', alpha=0.7, label=f'Train Accuracy: {train_accuracies[-1]:.3f}')
ax2.plot(steps, test_accuracies, 'r-', marker='o', label=f'Test Accuracy: {test_accuracies[-1]:.3f}')
ax2.set_title('Accuracy')
ax2.set_xlabel('Step')
ax2.set_ylabel('Accuracy')
ax2.legend()
ax2.grid(True, alpha=0.3)
plt.tight_layout()
display(fig)

def train(model, trainloader, testloader, iterations, test_every, device):
optimizer = torch.optim.AdamW(params=list(model.parameters()), lr=0.0001, eps=1e-8)
iterator = iter(trainloader)
model.train()
train_losses = []
test_losses = []
train_accuracies = []
test_accuracies = []
steps = []
plt.ion()
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
with tqdm(total=iterations, initial=0, dynamic_ncols=True) as pbar:
test_loss = None
test_accuracy = None
for stepi in range(iterations):
try:
inputs, targets = next(iterator)
except StopIteration:
iterator = iter(trainloader)
inputs, targets = next(iterator)
inputs, targets = inputs.to(device), targets.to(device)
predictions, certainties, _ = model(inputs, track=False)
train_loss, where_most_certain = get_loss(predictions, certainties, targets)
train_accuracy = calculate_accuracy(predictions, targets, where_most_certain)
train_losses.append(train_loss.item())
train_accuracies.append(train_accuracy)
train_loss.backward()
optimizer.step()
optimizer.zero_grad()
if stepi % test_every == 0 or stepi == iterations - 1:
model.eval()
with torch.inference_mode():
all_test_predictions = []
all_test_targets = []
all_test_where_most_certain = []
all_test_losses = []
for inputs, targets in testloader:
inputs, targets = inputs.to(device), targets.to(device)
predictions, certainties, _ = model(inputs, track=False)
test_loss, where_most_certain = get_loss(predictions, certainties, targets)
all_test_losses.append(test_loss.item())
all_test_predictions.append(predictions)
all_test_targets.append(targets)
all_test_where_most_certain.append(where_most_certain)
all_test_predictions = torch.cat(all_test_predictions, dim=0)
all_test_targets = torch.cat(all_test_targets, dim=0)
all_test_where_most_certain = torch.cat(all_test_where_most_certain, dim=0)
test_accuracy = calculate_accuracy(all_test_predictions, all_test_targets, all_test_where_most_certain)
test_loss = sum(all_test_losses) / len(all_test_losses)
test_losses.append(test_loss)
test_accuracies.append(test_accuracy)
steps.append(stepi)
model.train()
update_training_curve_plot(fig, ax1, ax2, train_losses, test_losses, train_accuracies, test_accuracies, steps)
pbar.set_description(f'Train Loss: {train_loss:.3f}, Train Accuracy: {train_accuracy:.3f} Test Loss: {test_loss:.3f}, Test Accuracy: {test_accuracy:.3f}')
pbar.update(1)
plt.ioff()
plt.close(fig)
return model

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
trainloader, testloader = prepare_data()
model = ContinuousThoughtMachine(
iterations=15,
d_model=128,
d_input=128,
memory_length=10,
heads=2,
n_synch_out=16,
n_synch_action=16,
memory_hidden_dims=8,
out_dims=10,
dropout=0.0
).to(device)
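# Run one dummy forward pass so the Lazy* modules materialize their shapes
# before the model is compiled and its parameters are counted.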
sample_batch = next(iter(trainloader))
dummy_input = sample_batch[0][:1].to(device)
with torch.no_grad():
_ = model(dummy_input)
# Now compile the model
model = torch.compile(model)
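# Clamp the learnable decay parameters to [0, 15] before every forward pass so that
# r = exp(-decay) stays in (0, 1].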
def clamp_decay_params(module, _input):
with torch.no_grad():
module.decay_params_action.data.clamp_(0, 15)
module.decay_params_out.data.clamp_(0, 15)
model.register_forward_pre_hook(clamp_decay_params)
print(f'Model parameters: {sum(p.numel() for p in model.parameters()):,}')
model = train(model=model, trainloader=trainloader, testloader=testloader, iterations=10000, test_every=500, device=device)

Visualizing CTM Dynamics
We define a function to create GIFs that show how the CTM's dynamics unfold over internal ticks. These visualizations include:
- Neuron activations at each internal tick
- Attention patterns across the input
- Predictions and certainty at every step
def make_gif(predictions, certainties, targets, pre_activations, post_activations, attention, inputs_to_model, filename):
def reshape_attention_weights(attention, target_size=28):
# num_positions may not be a perfect square (it depends on the spatial size of the backbone features);
# find the closest factorisation and interpolate to the target size.
T, B, num_heads, _, num_positions = attention.shape
attention = torch.tensor(attention, dtype=torch.float32).mean(dim=2).squeeze(2)
height = int(num_positions**0.5)
while num_positions % height != 0: height -= 1
width = num_positions // height
attention = attention.view(T, B, height, width)
return F.interpolate(attention, size=(target_size, target_size), mode='bilinear', align_corners=False)
batch_index = 0
n_neurons_to_visualise = 16
figscale = 0.28
n_steps = len(pre_activations)
heatmap_cmap = sns.color_palette("viridis", as_cmap=True)
frames = []
attention = reshape_attention_weights(attention)
these_pre_acts = pre_activations[:, batch_index, :]
these_post_acts = post_activations[:, batch_index, :]
these_inputs = inputs_to_model[batch_index,:, :, :]
these_attention_weights = attention[:, batch_index, :, :]
these_predictions = predictions[batch_index, :, :]
these_certainties = certainties[batch_index, :, :]
this_target = targets[batch_index]
class_labels = [str(i) for i in range(these_predictions.shape[0])]
mosaic = [['img_data', 'img_data', 'attention', 'attention', 'probs', 'probs', 'probs', 'probs'] for _ in range(2)] + \
[['img_data', 'img_data', 'attention', 'attention', 'probs', 'probs', 'probs', 'probs'] for _ in range(2)] + \
[['certainty'] * 8] + \
[[f'trace_{ti}'] * 8 for ti in range(n_neurons_to_visualise)]
for stepi in range(n_steps):
fig_gif, axes_gif = plt.subplot_mosaic(mosaic=mosaic, figsize=(31*figscale*8/4, 76*figscale))
probs = softmax(these_predictions[:, stepi])
colors = [('g' if i == this_target else 'b') for i in range(len(probs))]
axes_gif['probs'].bar(np.arange(len(probs)), probs, color=colors, width=0.9, alpha=0.5)
axes_gif['probs'].set_title('Probabilities')
axes_gif['probs'].set_xticks(np.arange(len(probs)))
axes_gif['probs'].set_xticklabels(class_labels, fontsize=24)
axes_gif['probs'].set_yticks([])
axes_gif['probs'].tick_params(left=False, bottom=False)
axes_gif['probs'].set_ylim([0, 1])
for spine in axes_gif['probs'].spines.values():
spine.set_visible(False)
# Certainty
axes_gif['certainty'].plot(np.arange(n_steps), these_certainties[1], 'k-', linewidth=2)
axes_gif['certainty'].set_xlim([0, n_steps-1])
axes_gif['certainty'].axvline(x=stepi, color='black', linewidth=1, alpha=0.5)
axes_gif['certainty'].set_xticklabels([])
axes_gif['certainty'].set_yticklabels([])
axes_gif['certainty'].grid(False)
for spine in axes_gif['certainty'].spines.values():
spine.set_visible(False)
# Neuron Traces
for neuroni in range(n_neurons_to_visualise):
ax = axes_gif[f'trace_{neuroni}']
pre_activation = these_pre_acts[:, neuroni]
post_activation = these_post_acts[:, neuroni]
ax_pre = ax.twinx()
ax_pre.plot(np.arange(n_steps), pre_activation, color='grey', linestyle='--', linewidth=1, alpha=0.4)
color = 'blue' if neuroni % 2 else 'red'
ax.plot(np.arange(n_steps), post_activation, color=color, linewidth=2, alpha=1.0)
ax.set_xlim([0, n_steps-1])
ax_pre.set_xlim([0, n_steps-1])
ax.set_ylim([np.min(post_activation), np.max(post_activation)])
ax_pre.set_ylim([np.min(pre_activation), np.max(pre_activation)])
ax.axvline(x=stepi, color='black', linewidth=1, alpha=0.5)
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.grid(False)
ax_pre.set_xticklabels([])
ax_pre.set_yticklabels([])
ax_pre.grid(False)
for spine in ax.spines.values():
spine.set_visible(False)
for spine in ax_pre.spines.values():
spine.set_visible(False)
# Input image
this_image = these_inputs[0]
this_image = (this_image - this_image.min()) / (this_image.max() - this_image.min() + 1e-8)
axes_gif['img_data'].set_title('Input Image')
axes_gif['img_data'].imshow(this_image, cmap='binary', vmin=0, vmax=1)
axes_gif['img_data'].axis('off')
# Attention
this_input_gate = these_attention_weights[stepi]
gate_min, gate_max = np.nanmin(this_input_gate), np.nanmax(this_input_gate)
if not np.isclose(gate_min, gate_max):
normalized_gate = (this_input_gate - gate_min) / (gate_max - gate_min + 1e-8)
else:
normalized_gate = np.zeros_like(this_input_gate)
attention_weights_heatmap = heatmap_cmap(normalized_gate)[:,:,:3]
axes_gif['attention'].imshow(attention_weights_heatmap, vmin=0, vmax=1)
axes_gif['attention'].axis('off')
axes_gif['attention'].set_title('Attention')
fig_gif.tight_layout()
canvas = fig_gif.canvas
canvas.draw()
image_numpy = np.frombuffer(canvas.buffer_rgba(), dtype='uint8')
image_numpy = image_numpy.reshape(*reversed(canvas.get_width_height()), 4)[:, :, :3]
frames.append(image_numpy)
plt.close(fig_gif)
mediapy.show_video(frames, width=400, codec="gif")
imageio.mimsave(filename, frames, fps=5, loop=100)

The top row of the gif shows the input image, the attention weights, and the model's predictions.
The second plot, in black, shows the model's certainty over time.
The red and blue lines correspond to the post-activations of different neurons, with the corresponding pre-activation shown in gray.
logdir = f"mnist_logs"
if not os.path.exists(logdir):
os.makedirs(logdir)
model.eval()
with torch.inference_mode():
inputs, targets = next(iter(testloader))
inputs = inputs.to(device)
predictions, certainties, (synch_out_tracking, synch_action_tracking), \
pre_activations_tracking, post_activations_tracking, attention = model(inputs, track=True)
make_gif(
predictions.detach().cpu().numpy(),
certainties.detach().cpu().numpy(),
targets.detach().cpu().numpy(),
pre_activations_tracking,
post_activations_tracking,
attention,
inputs.detach().cpu().numpy(),
f"{logdir}/prediction.gif"
)