Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions machine_learning/cnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
"""
Convolutional Neural Network (CNN) implementation for image classification.

Reference: https://en.wikipedia.org/wiki/Convolutional_neural_network

>>> import numpy as np
>>> model = SimpleCNN(input_shape=(1, 28, 28), num_classes=10)
>>> dummy_input = np.random.rand(1, 28, 28)
>>> output = model.forward(dummy_input)
>>> output.shape
(10,)
"""

import numpy as np


class SimpleCNN:
def __init__(self, input_shape: tuple[int, int, int], num_classes: int) -> None:
"""
Initialize a simple CNN model.

Args:
input_shape: Tuple of (channels, height, width)
num_classes: Number of output classes
"""
self.input_shape = input_shape
self.num_classes = num_classes
rng = np.random.default_rng()
self.filters = rng.normal(0, 0.1, size=(8, input_shape[0], 3, 3)) # 8 filters
self.fc_weights = rng.normal(0, 0.1, size=(8 * 26 * 26, num_classes))

def relu(self, feature_map: np.ndarray) -> np.ndarray:
"""Apply ReLU activation to the feature map."""
return np.maximum(0, feature_map)

def convolve(self, input_tensor: np.ndarray, filters: np.ndarray) -> np.ndarray:
"""Apply convolution operation to the input tensor."""
_, height, width = input_tensor.shape
num_filters, _, fh, fw = filters.shape
output = np.zeros((num_filters, height - fh + 1, width - fw + 1))

for f in range(num_filters):
for i in range(height - fh + 1):
for j in range(width - fw + 1):
region = input_tensor[:, i : i + fh, j : j + fw]
output[f, i, j] = np.sum(region * filters[f])
return output

def flatten(self, feature_map: np.ndarray) -> np.ndarray:
"""Flatten the feature map into a 1D array."""
return feature_map.reshape(-1)

def forward(self, input_tensor: np.ndarray) -> np.ndarray:
"""
Forward pass through the CNN.

Args:
input_tensor: Input image of shape (channels, height, width)

Returns:
Output logits of shape (num_classes,)
"""
conv_out = self.convolve(input_tensor, self.filters)
activated = self.relu(conv_out)
flattened = self.flatten(activated)
logits = flattened @ self.fc_weights
return logits