Skip to content

Commit e26406b

Browse files
committed
Final Ruff-compliant ARIMA implementation
1 parent 533758f commit e26406b

File tree

1 file changed

+15
-7
lines changed

1 file changed

+15
-7
lines changed

machine_learning/arima.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,16 @@
1212
array([10.99999999, 12.00000001])
1313
"""
1414

15-
1615
import numpy as np
1716

1817

1918
class ARIMAModel:
20-
def __init__(self, ar_order: int = 1, diff_order: int = 0, ma_order: int = 0) -> None:
19+
def __init__(
20+
self,
21+
ar_order: int = 1,
22+
diff_order: int = 0,
23+
ma_order: int = 0,
24+
) -> None:
2125
"""Initialize ARIMA model.
2226
Args:
2327
ar_order: Autoregressive order (p)
@@ -50,14 +54,18 @@ def fit(self, time_series: np.ndarray) -> "ARIMAModel":
5054
"""
5155
y = np.asarray(time_series)
5256
y_diff = self.difference(y, self.diff_order)
57+
5358
# Build lagged feature matrix
54-
feature_matrix = np.column_stack([np.roll(y_diff, i) for i in range(1, self.ar_order + 1)])
59+
feature_matrix = np.column_stack(
60+
[np.roll(y_diff, i) for i in range(1, self.ar_order + 1)]
61+
)
5562
feature_matrix = feature_matrix[self.ar_order:]
5663
target = y_diff[self.ar_order:]
64+
5765
# Add intercept
58-
feature_matrix = np.hstack(
59-
[np.ones((feature_matrix.shape[0], 1)), feature_matrix]
60-
)
66+
intercept = np.ones((feature_matrix.shape[0], 1))
67+
feature_matrix = np.hstack([intercept, feature_matrix])
68+
6169
# Solve least squares for AR coefficients
6270
self.coef_ = np.linalg.lstsq(feature_matrix, target, rcond=None)[0]
6371
self.resid_ = target - feature_matrix @ self.coef_
@@ -82,7 +90,7 @@ def predict(self, time_series: np.ndarray, n_periods: int = 1) -> np.ndarray:
8290
y_pred = list(y[-self.ar_order:])
8391
for _ in range(n_periods):
8492
# Build feature vector for prediction
85-
features = [1] + y_pred[-self.ar_order:][::-1]
93+
features = [1, *y_pred[-self.ar_order:][::-1]]
8694
next_val = np.dot(features, self.coef_)
8795
y_pred.append(next_val)
8896
return np.array(y_pred[self.ar_order:])

0 commit comments

Comments
 (0)