Algorytm random forest
Random Forest to zespół drzew decyzyjnych (ensemble), który zmniejsza przeuczenie i poprawia dokładność. Wykorzystujemy train_test_split, aby uniknąć testowania modelu na tych samych danych, na których był trenowany. feature_importances_ pokazuje, które cechy są najbardziej wpływowe w klasyfikacji. Confusion matrix pomaga ocenić, jak model myli klasy. Kod ładuje słynny zbiór danych kwiatu irysa (iris flower dataset), przygotowany przez R. A. Fishera w 1936 roku. Dane są wbudowane w scikit-learn, więc nie trzeba pobierać ich z internetu — działają offline.# random_forest_example.py from sklearn.datasets import load_iris from sklearn.ensemble import RandomForestClassifier from sklearn.model_selection import train_test_split from sklearn.metrics import accuracy_score, confusion_matrix, ConfusionMatrixDisplay import matplotlib.pyplot as plt import pandas as pd import seaborn as sns # 1. Wczytanie danych iris = load_iris() X = iris.data y = iris.target # Stwórz DataFrame dla lepszej czytelności df = pd.DataFrame(iris.data, columns=iris.feature_names) df['species'] = [iris.target_names[i] for i in iris.target] # Wyświetl 5 pierwszych wierszy print(df.head()) feature_names = iris.feature_names target_names = iris.target_names # 2. Podział na zbiór treningowy i testowy X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42) # 3. Trenowanie modelu Random Forest clf = RandomForestClassifier(n_estimators=100, random_state=42) clf.fit(X_train, y_train) # 4. Predykcja i ocena modelu y_pred = clf.predict(X_test) accuracy = accuracy_score(y_test, y_pred) print(f"Accuracy: {accuracy:.2f}") # 5. Macierz pomyłek cm = confusion_matrix(y_test, y_pred, labels=clf.classes_) disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=target_names) disp.plot() plt.title("Confusion Matrix") plt.show() # 6. Wagi cech (Feature Importance) importances = clf.feature_importances_ forest_importances = pd.Series(importances, index=feature_names) # 7. Wizualizacja ważności cech sns.barplot(x=forest_importances.values, y=forest_importances.index) plt.title("Feature Importances in Random Forest") plt.xlabel("Importance") plt.ylabel("Feature") plt.tight_layout() plt.show()