Algorytm random forest
Random Forest to zespół drzew decyzyjnych (ensemble), który zmniejsza przeuczenie i poprawia dokładność.
Wykorzystujemy train_test_split, aby uniknąć testowania modelu na tych samych danych, na których był trenowany.
feature_importances_ pokazuje, które cechy są najbardziej wpływowe w klasyfikacji.
Confusion matrix pomaga ocenić, jak model myli klasy.
Kod ładuje słynny zbiór danych kwiatu irysa (iris flower dataset), przygotowany przez R. A. Fishera w 1936 roku. Dane są wbudowane w scikit-learn, więc nie trzeba pobierać ich z internetu — działają offline.
# random_forest_example.py
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
# 1. Wczytanie danych
iris = load_iris()
X = iris.data
y = iris.target
# Stwórz DataFrame dla lepszej czytelności
df = pd.DataFrame(iris.data, columns=iris.feature_names)
df['species'] = [iris.target_names[i] for i in iris.target]
# Wyświetl 5 pierwszych wierszy
print(df.head())
feature_names = iris.feature_names
target_names = iris.target_names
# 2. Podział na zbiór treningowy i testowy
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
# 3. Trenowanie modelu Random Forest
clf = RandomForestClassifier(n_estimators=100, random_state=42)
clf.fit(X_train, y_train)
# 4. Predykcja i ocena modelu
y_pred = clf.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy: {accuracy:.2f}")
# 5. Macierz pomyłek
cm = confusion_matrix(y_test, y_pred, labels=clf.classes_)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=target_names)
disp.plot()
plt.title("Confusion Matrix")
plt.show()
# 6. Wagi cech (Feature Importance)
importances = clf.feature_importances_
forest_importances = pd.Series(importances, index=feature_names)
# 7. Wizualizacja ważności cech
sns.barplot(x=forest_importances.values, y=forest_importances.index)
plt.title("Feature Importances in Random Forest")
plt.xlabel("Importance")
plt.ylabel("Feature")
plt.tight_layout()
plt.show()