Algorytm random forest





Random Forest to zespół drzew decyzyjnych (ensemble), który zmniejsza przeuczenie i poprawia dokładność.

Wykorzystujemy train_test_split, aby uniknąć testowania modelu na tych samych danych, na których był trenowany.

feature_importances_ pokazuje, które cechy są najbardziej wpływowe w klasyfikacji.

Confusion matrix pomaga ocenić, jak model myli klasy.

Kod ładuje słynny zbiór danych kwiatu irysa (iris flower dataset), przygotowany przez R. A. Fishera w 1936 roku. Dane są wbudowane w scikit-learn, więc nie trzeba pobierać ich z internetu — działają offline.

# random_forest_example.py

from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# 1. Wczytanie danych
iris = load_iris()
X = iris.data
y = iris.target


# Stwórz DataFrame dla lepszej czytelności
df = pd.DataFrame(iris.data, columns=iris.feature_names)
df['species'] = [iris.target_names[i] for i in iris.target]

# Wyświetl 5 pierwszych wierszy
print(df.head())

feature_names = iris.feature_names
target_names = iris.target_names

# 2. Podział na zbiór treningowy i testowy
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# 3. Trenowanie modelu Random Forest
clf = RandomForestClassifier(n_estimators=100, random_state=42)
clf.fit(X_train, y_train)

# 4. Predykcja i ocena modelu
y_pred = clf.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy: {accuracy:.2f}")

# 5. Macierz pomyłek
cm = confusion_matrix(y_test, y_pred, labels=clf.classes_)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=target_names)
disp.plot()
plt.title("Confusion Matrix")
plt.show()

# 6. Wagi cech (Feature Importance)
importances = clf.feature_importances_
forest_importances = pd.Series(importances, index=feature_names)

# 7. Wizualizacja ważności cech
sns.barplot(x=forest_importances.values, y=forest_importances.index)
plt.title("Feature Importances in Random Forest")
plt.xlabel("Importance")
plt.ylabel("Feature")
plt.tight_layout()
plt.show()

	
:)