alan-side-channel/alan/multi-processing.py
2025-07-05 20:29:24 +02:00

83 lines
3.1 KiB
Python

#!/usr/bin/python3
from functools import lru_cache
from multiprocessing import Pool

import matplotlib.pyplot as plt
import numpy as np

import aes
# Number of power traces in the data set (one per message / sample).
D = 6_000
# Number of data points per power trace (points in time).
T = 87
# Every candidate value for a single key byte: 0x00 .. 0xFF as uint8.
KEY_GUESSES = np.arange(0x100, dtype=np.uint8)
@lru_cache(maxsize=1)
def _rsbox_table() -> np.ndarray:
    """Return ``aes.core.rsbox`` tabulated over all 256 byte values.

    Built once per process, so the per-ciphertext model computation is a
    single NumPy gather instead of one Python call per byte. The array is
    marked read-only because the same cached object is shared by all callers.
    """
    table = np.fromiter(
        (aes.core.rsbox(value) for value in range(256)), dtype=np.uint8, count=256
    )
    table.flags.writeable = False
    return table


def calculate_models(
    ciphertext: np.ndarray[np.ndarray[np.uint8]],
) -> np.ndarray[np.ndarray[np.ndarray[np.uint8]]]:
    """Compute the hypothesis model ``rsbox(c XOR k_hyp)`` for one ciphertext.

    Args:
        ciphertext: 1-D array of ciphertext bytes (16 for AES-128).

    Returns:
        ``uint8`` array of shape ``(len(KEY_GUESSES), len(ciphertext))`` whose
        entry ``[k, i]`` is ``rsbox(ciphertext[i] ^ KEY_GUESSES[k])``.
    """
    # Outer XOR: row k holds every ciphertext byte XORed with key guess k.
    hypotheses = np.bitwise_xor.outer(KEY_GUESSES, ciphertext)
    # Fancy indexing applies rsbox to every element at once.
    return _rsbox_table()[hypotheses]
def read_msgs(file_name: str) -> np.ndarray[np.ndarray[np.ndarray[np.uint8]]]:
    """Load ``key,plaintext`` hex pairs from a CSV file and add their ciphertexts.

    Each non-blank line must contain exactly two comma-separated hex strings:
    a 128-bit key and a 128-bit plaintext. The ciphertext is produced with
    ``aes.aes(key, 128).enc_once(plaintext)``.

    Args:
        file_name: Path to the CSV file.

    Returns:
        ``uint8`` array of shape ``(n, 3, 16)``, where ``n`` is the number of
        non-blank lines. Row ``i`` holds ``[key, plaintext, ciphertext]``.

    Raises:
        ValueError: If a line does not have exactly two fields or a field is
            not valid hex of the expected length.
    """
    rows = []
    with open(file_name, "r", encoding="utf-8") as fd:
        for line_no, raw_line in enumerate(fd, start=1):
            line = raw_line.strip()
            if not line:
                continue  # tolerate blank / trailing lines
            fields = line.split(",")
            if len(fields) != 2:
                raise ValueError(
                    f"{file_name}:{line_no}: expected 'key,plaintext', got {line!r}"
                )
            rows.append((fields[0], fields[1]))

    # Size the result from the data actually read. A fixed-size np.empty
    # would leave uninitialised rows for short files and overflow on long ones.
    msgs = np.empty((len(rows), 3, 16), dtype=np.uint8)
    for idx, (key, plain_text) in enumerate(rows):
        msgs[idx, 0] = np.frombuffer(bytes.fromhex(key), dtype=np.uint8)  # key
        msgs[idx, 1] = np.frombuffer(bytes.fromhex(plain_text), dtype=np.uint8)  # plain text
        msgs[idx, 2] = np.array(  # ciphertext
            aes.aes(int(key, 16), 128).enc_once(int(plain_text, 16)), dtype=np.uint8
        )
    return msgs
def read_traces(file_name: str) -> np.ndarray[np.ndarray[np.uint8]]:
    """Load comma-separated power traces as a 2-D ``uint8`` array.

    Each line of *file_name* is one trace; each column is one time sample.

    Returns:
        Array indexed ``[trace, sample]``. ``ndmin=2`` keeps it 2-D even when
        the file holds a single trace; plain ``loadtxt`` would return 1-D.
    """
    return np.loadtxt(file_name, delimiter=",", dtype=np.uint8, ndmin=2)
def _bit_correlations(models: np.ndarray, traces: np.ndarray, bit: int) -> np.ndarray:
    """Return the Pearson correlation of one model bit against every time sample.

    Args:
        models: Stacked ``calculate_models`` output, indexed
            ``[trace, key_guess, byte]``.
        traces: Power traces indexed ``[trace, sample]``.
        bit: Model bit to test; selects byte ``bit // 8``, bit ``bit % 8``.

    Returns:
        Array of shape ``(n_key_guesses, n_samples)``. Entry ``[i, j]`` is the
        correlation of key guess ``i``'s predicted bit with time sample ``j``.
    """
    n_guesses = models.shape[1]
    mask = np.uint8(1 << (bit % 8))
    model = np.bitwise_and(models[:, :, bit // 8], mask)
    # corrcoef stacks the model columns and the trace columns into one
    # (n_guesses + n_samples) square matrix; keep only the model-vs-trace
    # block, sliced relative to n_guesses rather than a hard-coded sample count.
    return np.corrcoef(model, traces, rowvar=False)[:n_guesses, n_guesses:]


def _plot_bit(
    r: np.ndarray, bit: int, guess: int, reference_guess: int, min_confidence: float
) -> None:
    """Plot one bit's correlation curves if the extreme values stand out enough.

    Confidence is the larger of two gaps: between the two lowest correlation
    values, and between the two highest. Grey curves are all key guesses,
    blue is ``reference_guess`` and red is ``guess``.
    """
    finite = np.sort(r[np.isfinite(r)])
    if finite.size < 2:
        return
    confidence = max(abs(finite[0] - finite[1]), abs(finite[-2] - finite[-1]))
    if confidence <= min_confidence:
        return
    _, axs = plt.subplots(1, 1, layout="constrained")
    axs.set_title(f"Bit {bit%8 + 1} of Byte {bit//8 + 1} (Confidence: {confidence:.6f})")
    axs.plot(r.transpose(), alpha=0.3, color="grey")
    axs.plot(r[reference_guess], color="blue")
    axs.plot(r[guess], color="red")
    axs.set_xlabel("Time Samples")
    axs.set_ylabel("Correlation")
    axs.grid(True)
    plt.show()


def main(*, show_plots: bool = False, min_confidence: float = 0.005) -> None:
    """Run the per-bit correlation analysis on the example trace set.

    For every model bit, each key-byte guess's predicted bit is correlated with
    every time sample. The guess with the largest absolute correlation is
    selected. Nothing is printed. Pass ``show_plots=True`` to plot the bits
    whose confidence exceeds ``min_confidence``.
    """
    msgs = read_msgs("Task-3-example_traces/test_msgs.csv")
    traces = read_traces("Task-3-example_traces/test_traces.csv")
    # One model matrix per ciphertext, computed in parallel and stacked to
    # shape (n_traces, n_key_guesses, n_bytes).
    with Pool() as pool:
        models = np.stack(pool.map(calculate_models, msgs[:, 2]))
    # Reference key bytes for the plots: the last 16 bytes of the expanded key
    # of the first message. NOTE(review): assumes all messages share one key.
    last_round_key = aes.core.key_expansion(msgs[0][0].tolist())[-16:]
    for bit in range(models.shape[2] * 8):
        r = _bit_correlations(models, traces, bit)
        # nanmax: a time sample that is constant across all traces gives a NaN
        # correlation column. np.max would then return NaN for every row, and
        # argmax would always pick guess 0.
        guess = int(np.argmax(np.nanmax(np.abs(r), axis=1)))
        if show_plots:
            _plot_bit(r, bit, guess, last_round_key[bit // 8], min_confidence)


if __name__ == "__main__":
    main()