#Импортируем все необходимые библиотеки import pandas as pd from catboost import CatBoostClassifier from sklearn.model_selection import train_test_split from sklearn.metrics import accuracy_score import numpy as np import matplotlib.pyplot as plt import seaborn as sns import json
# 🔕 Отключаем предупреждения, чтобы не загромождали вывод import warnings warnings.filterwarnings('ignore')
### Project-wide Matplotlib defaults: one place for the grid, background,
### legend, fonts and axes styling, so individual plots do not need to
### repeat the same keyword arguments over and over.
import matplotlib as mlp

mlp.rcParams.update({
    # Light dashed grid
    'axes.grid': True,
    'grid.color': '#D3D3D3',
    'grid.linestyle': '--',
    'grid.linewidth': 1,
    # Backgrounds: light-grey plotting area on a white canvas
    'axes.facecolor': '#F9F9F9',
    'figure.facecolor': '#FFFFFF',
    # Legend styling
    'legend.fontsize': 14,
    'legend.frameon': True,
    'legend.framealpha': 0.9,
    'legend.edgecolor': '#333333',
    # Default figure size
    'figure.figsize': (10, 6),
    # Fonts (swap for 'Arial', 'Roboto', etc. if preferred)
    'font.family': 'DejaVu Sans',
    'font.size': 16,
    # Axes spines and default text colour
    'axes.edgecolor': '#333333',
    'axes.linewidth': 2,
    'text.color': '#222222',
})
# Отдельно скачиваю train... train_df = pd.read_csv('../data/train.csv')
# ... и отдельно test test_df = pd.read_csv('../data/test.csv')
# Посмотрим первые 10 строк train'a train_df.head(10)
|
|
PassengerId |
Survived |
Pclass |
Name |
Sex |
Age |
SibSp |
Parch |
Ticket |
Fare |
Cabin |
Embarked |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
0 |
1 |
0 |
3 |
Braund, Mr. Owen Harris |
male |
22.0 |
1 |
0 |
A/5 21171 |
7.2500 |
NaN |
S |
|
1 |
2 |
1 |
1 |
Cumings, Mrs. John Bradley (Florence Briggs Th… |
female |
38.0 |
1 |
0 |
PC 17599 |
71.2833 |
C85 |
C |
|
2 |
3 |
1 |
3 |
Heikkinen, Miss. Laina |
female |
26.0 |
0 |
0 |
STON/O2. 3101282 |
7.9250 |
NaN |
S |
|
3 |
4 |
1 |
1 |
Futrelle, Mrs. Jacques Heath (Lily May Peel) |
female |
35.0 |
1 |
0 |
113803 |
53.1000 |
C123 |
S |
|
4 |
5 |
0 |
3 |
Allen, Mr. William Henry |
male |
35.0 |
0 |
0 |
373450 |
8.0500 |
NaN |
S |
|
5 |
6 |
0 |
3 |
Moran, Mr. James |
male |
NaN |
0 |
0 |
330877 |
8.4583 |
NaN |
Q |
|
6 |
7 |
0 |
1 |
McCarthy, Mr. Timothy J |
male |
54.0 |
0 |
0 |
17463 |
51.8625 |
E46 |
S |
|
7 |
8 |
0 |
3 |
Palsson, Master. Gosta Leonard |
male |
2.0 |
3 |
1 |
349909 |
21.0750 |
NaN |
S |
|
8 |
9 |
1 |
3 |
Johnson, Mrs. Oscar W (Elisabeth Vilhelmina Berg) |
female |
27.0 |
0 |
2 |
347742 |
11.1333 |
NaN |
S |
|
9 |
10 |
1 |
2 |
Nasser, Mrs. Nicholas (Adele Achem) |
female |
14.0 |
1 |
0 |
237736 |
30.0708 |
NaN |
C |
# Посмотрим информацию по train'у train_df.info()
RangeIndex: 891 entries, 0 to 890 Data columns (total 12 columns): # Column Non-Null Count Dtype --- ------ -------------- ----- 0 PassengerId 891 non-null int64 1 Survived 891 non-null int64 2 Pclass 891 non-null int64 3 Name 891 non-null object 4 Sex 891 non-null object 5 Age 714 non-null float64 6 SibSp 891 non-null int64 7 Parch 891 non-null int64 8 Ticket 891 non-null object 9 Fare 891 non-null float64 10 Cabin 204 non-null object 11 Embarked 889 non-null object dtypes: float64(2), int64(5), object(5) memory usage: 83.7+ KB
Видно, что есть пропуски в признаке возраст. Очень много пропусков в признаке кабина
# Получаем базовую статистику по числовым признакам (среднее, медиана, std и т.д.) train_df.describe()
|
|
PassengerId |
Survived |
Pclass |
Age |
SibSp |
Parch |
Fare |
|---|---|---|---|---|---|---|---|
|
count |
891.000000 |
891.000000 |
891.000000 |
714.000000 |
891.000000 |
891.000000 |
891.000000 |
|
mean |
446.000000 |
0.383838 |
2.308642 |
29.699118 |
0.523008 |
0.381594 |
32.204208 |
|
std |
257.353842 |
0.486592 |
0.836071 |
14.526497 |
1.102743 |
0.806057 |
49.693429 |
|
min |
1.000000 |
0.000000 |
1.000000 |
0.420000 |
0.000000 |
0.000000 |
0.000000 |
|
25% |
223.500000 |
0.000000 |
2.000000 |
20.125000 |
0.000000 |
0.000000 |
7.910400 |
|
50% |
446.000000 |
0.000000 |
3.000000 |
28.000000 |
0.000000 |
0.000000 |
14.454200 |
|
75% |
668.500000 |
1.000000 |
3.000000 |
38.000000 |
1.000000 |
0.000000 |
31.000000 |
|
max |
891.000000 |
1.000000 |
3.000000 |
80.000000 |
8.000000 |
6.000000 |
512.329200 |
# И статистику по объектным признакам train_df.describe(include='object')
|
|
Name |
Sex |
Ticket |
Cabin |
Embarked |
|---|---|---|---|---|---|
|
count |
891 |
891 |
891 |
204 |
889 |
|
unique |
891 |
2 |
681 |
147 |
3 |
|
top |
Braund, Mr. Owen Harris |
male |
347082 |
G6 |
S |
|
freq |
1 |
577 |
7 |
4 |
644 |
# Class balance of the target variable.
ax = sns.countplot(x='Survived', data=train_df)
ax.set_title('Распределение выживших')
plt.show()

Дисбаланса классов не наблюдаем
# Survival counts broken down by the main categorical features.
for feature, title in (
    ('Sex', 'Пол и выживание'),
    ('Pclass', 'Класс каюты и выживание'),
    ('Embarked', 'Порт посадки и выживание'),
):
    plt.figure(figsize=(5, 4))
    sns.countplot(x=feature, hue='Survived', data=train_df)
    plt.title(title)
    plt.show()

Видно, что все признаки являются важными для таргета. В противном случае графики для разных признаков были бы одинаковыми.
# Box plots of the continuous features, split by survival.
for feature, title in (
    ('Age', 'Возраст и выживание (boxplot)'),
    ('Fare', 'Стоимость билета и выживание (boxplot)'),
):
    plt.figure(figsize=(6, 5))
    sns.boxplot(x='Survived', y=feature, data=train_df)
    plt.title(title)
    plt.show()
Тоже видны различия в зависимости от таргета
# Сделаем списки: категориальные и числовые признаки categorical_cols = ['Sex', 'Pclass', 'Embarked', 'Cabin'] numeric_cols = ['Age', 'Fare', 'SibSp', 'Parch']
# Посмотрим тепловую карту корреляции между числовыми признаками plt.figure(figsize=(10,8)) sns.heatmap(train_df.corr(numeric_only=True), annot=True, cmap='coolwarm', fmt=".2f") plt.title('Корреляция числовых признаков') plt.show()
Визуально видно, что мультиколлинеарных признаков нет, но проверим с помощью функции
### Helper functions (adapted from Stack Overflow) to rank pairwise
### correlations and screen the features for multicollinearity.


def get_redundant_pairs(df):
    """Return the diagonal and lower-triangle column pairs of df's correlation matrix."""
    cols = df.columns
    return {(cols[i], cols[j]) for i in range(df.shape[1]) for j in range(i + 1)}


def get_top_abs_correlations(df, n=5):
    """Return the top-n column pairs by absolute correlation (upper triangle only)."""
    au_corr = df.corr().abs().unstack()
    au_corr = au_corr.drop(labels=get_redundant_pairs(df)).sort_values(ascending=False)
    return au_corr[0:n]


print("Top Absolute Correlations")
print(get_top_abs_correlations(train_df[numeric_cols], 10))
Top Absolute Correlations SibSp Parch 0.414838 Age SibSp 0.308247 Fare Parch 0.216225 Age Parch 0.189119 Fare SibSp 0.159651 Age Fare 0.096067 dtype: float64
Мультиколлинеарности нет — подтверждаем
# Смотрим пропуски train_df.isnull().sum()
PassengerId 0 Survived 0 Pclass 0 Name 0 Sex 0 Age 177 SibSp 0 Parch 0 Ticket 0 Fare 0 Cabin 687 Embarked 2 dtype: int64
# Drop the two rows with a missing 'Embarked' value.
# Done now, BEFORE train and test are concatenated, so that only
# training rows can ever be removed.
train_df = train_df[train_df['Embarked'].notna()].copy()
#Смотрим снова train_df.isnull().sum()
PassengerId 0 Survived 0 Pclass 0 Name 0 Sex 0 Age 177 SibSp 0 Parch 0 Ticket 0 Fare 0 Cabin 687 Embarked 0 dtype: int64
После удаления можно объединять строки. Удаляли до объединения, чтобы не повредить тестовые данные.
Теперь можно готовить данные к объединению и обработке общего датафрейма.
Обязательно нужно всё сделать правильно, чтобы не испортить данные
#Добавим колонку-метку, чтобы потом правильно разделить данные обратно train_df['is_train'] = 1 test_df['is_train'] = 0
#Добавим фиктивную колонку `Survived` в тест (чтобы структура была одинаковая) test_df['Survived'] = np.nan
#Сохраняем PassengerId из теста для submission passenger_ids = test_df['PassengerId'].copy()
(колонка не важна для обучения, но требуется в итоговом файле решения)
#Удаляем колонку PassengerId перед объединением — она не нужна для модели train_df = train_df.drop(columns=['PassengerId']) test_df = test_df.drop(columns=['PassengerId'])
#Объединяем тренировочные и тестовые данные для одинаковой обработки данных full_df = pd.concat([train_df, test_df], axis=0).reset_index(drop=True)
#Пропущенные значения к признаке Age заменяем медианным значением по всем пассажирам #Но которые были только в тренировочных данных #Медианное значение менее чувствительно к выбросам в данных full_df['Age'] = full_df['Age'].fillna(train_df['Age'].median())
# Снова смотрим пропуски, но уже в объединённом датафрейме full_df.isnull().sum()
Survived 418 Pclass 0 Name 0 Sex 0 Age 0 SibSp 0 Parch 0 Ticket 0 Fare 1 Cabin 1014 Embarked 0 is_train 0 dtype: int64
# Посмотрим ещё раз на данные, чтобы принять решение, что делать с признаком плата за проезд full_df.head(20)
|
|
Survived |
Pclass |
Name |
Sex |
Age |
SibSp |
Parch |
Ticket |
Fare |
Cabin |
Embarked |
is_train |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
0 |
0.0 |
3 |
Braund, Mr. Owen Harris |
male |
22.0 |
1 |
0 |
A/5 21171 |
7.2500 |
NaN |
S |
1 |
|
1 |
1.0 |
1 |
Cumings, Mrs. John Bradley (Florence Briggs Th… |
female |
38.0 |
1 |
0 |
PC 17599 |
71.2833 |
C85 |
C |
1 |
|
2 |
1.0 |
3 |
Heikkinen, Miss. Laina |
female |
26.0 |
0 |
0 |
STON/O2. 3101282 |
7.9250 |
NaN |
S |
1 |
|
3 |
1.0 |
1 |
Futrelle, Mrs. Jacques Heath (Lily May Peel) |
female |
35.0 |
1 |
0 |
113803 |
53.1000 |
C123 |
S |
1 |
|
4 |
0.0 |
3 |
Allen, Mr. William Henry |
male |
35.0 |
0 |
0 |
373450 |
8.0500 |
NaN |
S |
1 |
|
5 |
0.0 |
3 |
Moran, Mr. James |
male |
28.0 |
0 |
0 |
330877 |
8.4583 |
NaN |
Q |
1 |
|
6 |
0.0 |
1 |
McCarthy, Mr. Timothy J |
male |
54.0 |
0 |
0 |
17463 |
51.8625 |
E46 |
S |
1 |
|
7 |
0.0 |
3 |
Palsson, Master. Gosta Leonard |
male |
2.0 |
3 |
1 |
349909 |
21.0750 |
NaN |
S |
1 |
|
8 |
1.0 |
3 |
Johnson, Mrs. Oscar W (Elisabeth Vilhelmina Berg) |
female |
27.0 |
0 |
2 |
347742 |
11.1333 |
NaN |
S |
1 |
|
9 |
1.0 |
2 |
Nasser, Mrs. Nicholas (Adele Achem) |
female |
14.0 |
1 |
0 |
237736 |
30.0708 |
NaN |
C |
1 |
|
10 |
1.0 |
3 |
Sandstrom, Miss. Marguerite Rut |
female |
4.0 |
1 |
1 |
PP 9549 |
16.7000 |
G6 |
S |
1 |
|
11 |
1.0 |
1 |
Bonnell, Miss. Elizabeth |
female |
58.0 |
0 |
0 |
113783 |
26.5500 |
C103 |
S |
1 |
|
12 |
0.0 |
3 |
Saundercock, Mr. William Henry |
male |
20.0 |
0 |
0 |
A/5. 2151 |
8.0500 |
NaN |
S |
1 |
|
13 |
0.0 |
3 |
Andersson, Mr. Anders Johan |
male |
39.0 |
1 |
5 |
347082 |
31.2750 |
NaN |
S |
1 |
|
14 |
0.0 |
3 |
Vestrom, Miss. Hulda Amanda Adolfina |
female |
14.0 |
0 |
0 |
350406 |
7.8542 |
NaN |
S |
1 |
|
15 |
1.0 |
2 |
Hewlett, Mrs. (Mary D Kingcome) |
female |
55.0 |
0 |
0 |
248706 |
16.0000 |
NaN |
S |
1 |
|
16 |
0.0 |
3 |
Rice, Master. Eugene |
male |
2.0 |
4 |
1 |
382652 |
29.1250 |
NaN |
Q |
1 |
|
17 |
1.0 |
2 |
Williams, Mr. Charles Eugene |
male |
28.0 |
0 |
0 |
244373 |
13.0000 |
NaN |
S |
1 |
|
18 |
0.0 |
3 |
Vander Planke, Mrs. Julius (Emelia Maria Vande… |
female |
31.0 |
1 |
0 |
345763 |
18.0000 |
NaN |
S |
1 |
|
19 |
1.0 |
3 |
Masselmani, Mrs. Fatima |
female |
28.0 |
0 |
0 |
2649 |
7.2250 |
NaN |
C |
1 |
#Также закодируем и цену за проезд #Удалять после объединения нельзя - можно удалить строку из тестовых данных full_df['Fare'] = full_df['Fare'].fillna(train_df['Fare'].median())
# Проверяем full_df.isnull().sum()
Survived 418 Pclass 0 Name 0 Sex 0 Age 0 SibSp 0 Parch 0 Ticket 0 Fare 0 Cabin 1014 Embarked 0 is_train 0 dtype: int64
Три четверти данных по колонке Cabin в тестовых данных являются NaN. Скорее всего, это пассажиры второго или третьего класса, у которых просто не было собственной каюты. Совсем избавляться от этой колонки, наверное, не стоит. Вместо этого закодируем её как бинарный признак: была ли у пассажира указана каюта.
# Binary feature: 1 if a cabin was recorded for the passenger, else 0.
full_df['Has_Cabin'] = full_df['Cabin'].notna().astype(int)
# Удаляем оригинальную колонку Cabin, чтобы она не мешала full_df = full_df.drop(columns='Cabin')
# Посмотрим на наш изменённый датафрейм full_df
|
|
Survived |
Pclass |
Name |
Sex |
Age |
SibSp |
Parch |
Ticket |
Fare |
Embarked |
is_train |
Has_Cabin |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
0 |
0.0 |
3 |
Braund, Mr. Owen Harris |
male |
22.0 |
1 |
0 |
A/5 21171 |
7.2500 |
S |
1 |
0 |
|
1 |
1.0 |
1 |
Cumings, Mrs. John Bradley (Florence Briggs Th… |
female |
38.0 |
1 |
0 |
PC 17599 |
71.2833 |
C |
1 |
1 |
|
2 |
1.0 |
3 |
Heikkinen, Miss. Laina |
female |
26.0 |
0 |
0 |
STON/O2. 3101282 |
7.9250 |
S |
1 |
0 |
|
3 |
1.0 |
1 |
Futrelle, Mrs. Jacques Heath (Lily May Peel) |
female |
35.0 |
1 |
0 |
113803 |
53.1000 |
S |
1 |
1 |
|
4 |
0.0 |
3 |
Allen, Mr. William Henry |
male |
35.0 |
0 |
0 |
373450 |
8.0500 |
S |
1 |
0 |
|
… |
… |
… |
… |
… |
… |
… |
… |
… |
… |
… |
… |
… |
|
1302 |
NaN |
3 |
Spector, Mr. Woolf |
male |
28.0 |
0 |
0 |
A.5. 3236 |
8.0500 |
S |
0 |
0 |
|
1303 |
NaN |
1 |
Oliva y Ocana, Dona. Fermina |
female |
39.0 |
0 |
0 |
PC 17758 |
108.9000 |
C |
0 |
1 |
|
1304 |
NaN |
3 |
Saether, Mr. Simon Sivertsen |
male |
38.5 |
0 |
0 |
SOTON/O.Q. 3101262 |
7.2500 |
S |
0 |
0 |
|
1305 |
NaN |
3 |
Ware, Mr. Frederick |
male |
28.0 |
0 |
0 |
359309 |
8.0500 |
S |
0 |
0 |
|
1306 |
NaN |
3 |
Peter, Master. Michael J |
male |
28.0 |
1 |
1 |
2668 |
22.3583 |
C |
0 |
0 |
1307 rows × 12 columns
Начинаем избавляться от ненужных, не несущих полезной информации для обучения модели, признаков
# Имя и номер билета удаляем full_df = full_df.drop(columns=['Name','Ticket'])
# Смотрим результат full_df
|
|
Survived |
Pclass |
Sex |
Age |
SibSp |
Parch |
Fare |
Embarked |
is_train |
Has_Cabin |
|---|---|---|---|---|---|---|---|---|---|---|
|
0 |
0.0 |
3 |
male |
22.0 |
1 |
0 |
7.2500 |
S |
1 |
0 |
|
1 |
1.0 |
1 |
female |
38.0 |
1 |
0 |
71.2833 |
C |
1 |
1 |
|
2 |
1.0 |
3 |
female |
26.0 |
0 |
0 |
7.9250 |
S |
1 |
0 |
|
3 |
1.0 |
1 |
female |
35.0 |
1 |
0 |
53.1000 |
S |
1 |
1 |
|
4 |
0.0 |
3 |
male |
35.0 |
0 |
0 |
8.0500 |
S |
1 |
0 |
|
… |
… |
… |
… |
… |
… |
… |
… |
… |
… |
… |
|
1302 |
NaN |
3 |
male |
28.0 |
0 |
0 |
8.0500 |
S |
0 |
0 |
|
1303 |
NaN |
1 |
female |
39.0 |
0 |
0 |
108.9000 |
C |
0 |
1 |
|
1304 |
NaN |
3 |
male |
38.5 |
0 |
0 |
7.2500 |
S |
0 |
0 |
|
1305 |
NaN |
3 |
male |
28.0 |
0 |
0 |
8.0500 |
S |
0 |
0 |
|
1306 |
NaN |
3 |
male |
28.0 |
1 |
1 |
22.3583 |
C |
0 |
0 |
1307 rows × 10 columns
В целом не стесняемся часто смотреть и проверять результат. В процессе постоянного просмотра данных может прийти идея, которая улучшит качество модели
Теперь закодируем колонки, которые являются объектами (object), как категории (category). Это особенно важно, если мы собираемся использовать модель CatBoost — она умеет напрямую работать с категориальными признаками и не требует их one-hot-кодирования.
CatBoost сам обработает эти признаки, если они будут иметь тип category, поэтому просто приведём нужные колонки к этому типу.
# Cast Sex and Embarked to 'category': CatBoost consumes categorical
# features natively once they carry this dtype — no one-hot encoding needed.
for col in ('Sex', 'Embarked'):
    full_df[col] = full_df[col].astype('category')
# Проверяем full_df.describe(include='all')
|
|
Survived |
Pclass |
Sex |
Age |
SibSp |
Parch |
Fare |
Embarked |
is_train |
Has_Cabin |
|---|---|---|---|---|---|---|---|---|---|---|
|
count |
889.000000 |
1307.000000 |
1307 |
1307.000000 |
1307.000000 |
1307.000000 |
1307.000000 |
1307 |
1307.000000 |
1307.000000 |
|
unique |
NaN |
NaN |
2 |
NaN |
NaN |
NaN |
NaN |
3 |
NaN |
NaN |
|
top |
NaN |
NaN |
male |
NaN |
NaN |
NaN |
NaN |
S |
NaN |
NaN |
|
freq |
NaN |
NaN |
843 |
NaN |
NaN |
NaN |
NaN |
914 |
NaN |
NaN |
|
mean |
0.382452 |
2.296863 |
NaN |
29.471821 |
0.499617 |
0.385616 |
33.209595 |
NaN |
0.680184 |
0.224178 |
|
std |
0.486260 |
0.836942 |
NaN |
12.881592 |
1.042273 |
0.866092 |
51.748768 |
NaN |
0.466584 |
0.417199 |
|
min |
0.000000 |
1.000000 |
NaN |
0.170000 |
0.000000 |
0.000000 |
0.000000 |
NaN |
0.000000 |
0.000000 |
|
25% |
0.000000 |
2.000000 |
NaN |
22.000000 |
0.000000 |
0.000000 |
7.895800 |
NaN |
0.000000 |
0.000000 |
|
50% |
0.000000 |
3.000000 |
NaN |
28.000000 |
0.000000 |
0.000000 |
14.454200 |
NaN |
1.000000 |
0.000000 |
|
75% |
1.000000 |
3.000000 |
NaN |
35.000000 |
1.000000 |
0.000000 |
31.275000 |
NaN |
1.000000 |
0.000000 |
|
max |
1.000000 |
3.000000 |
NaN |
80.000000 |
8.000000 |
9.000000 |
512.329200 |
NaN |
1.000000 |
1.000000 |
Кроме того, колонка Pclass изначально имеет тип int, но на самом деле это категориальный признак (класс обслуживания: 1, 2 или 3). Если оставить её как числовую, модель может ошибочно посчитать, что класс 3 «больше» и важнее, чем класс 2, а тот — важнее, чем класс 1. Чтобы избежать этого, мы также приведём Pclass к категориальному типу.
# Приводим колонку класс к категориальному виду full_df['Pclass'] = full_df['Pclass'].astype('category')
Обработка данных завершена и теперь разделяем данные обратно:
# Split the combined frame back into train/test using the flag column.
is_train_mask = full_df['is_train'] == 1
feature_cols = full_df.columns.drop(['is_train', 'Survived'])
X_train = full_df.loc[is_train_mask, feature_cols]
y_train = full_df.loc[is_train_mask, 'Survived']
X_test = full_df.loc[~is_train_mask, feature_cols]
# Проверяем размеры print(X_train.shape, y_train.shape, X_test.shape)
(889, 8) (889,) (418, 8)
Начинаем обучение модели
# Начинаем обучение # Сначали сплиттим выборку X_train_split, X_valid, y_train_split, y_valid = train_test_split( X_train, y_train, test_size=0.2, random_state=42, stratify=y_train )
# Положим в список категориальных признаков для CatBoost наши приведённые к типу Category колонки cat_features = X_train.select_dtypes(include='category').columns.tolist()
# Проверяем cat_features
['Pclass', 'Sex', 'Embarked']
# Обучаем модель с достаточно средними параметрами # Пока не используем перебор гиперпараметров model = CatBoostClassifier( iterations=1000, learning_rate=0.05, depth=6, eval_metric='Accuracy', random_seed=42, early_stopping_rounds=50, verbose=100 ) model.fit( X_train_split, y_train_split, eval_set=(X_valid, y_valid), cat_features=cat_features )
0:learn: 0.8227848test: 0.7977528best: 0.7977528 (0)total: 160msremaining: 2m 39s Stopped by overfitting detector (50 iterations wait) bestTest = 0.8314606742 bestIteration = 32 Shrink model to first 33 iterations.
# Оценим качество y_pred = model.predict(X_valid) acc = accuracy_score(y_valid, y_pred) print(f"Validation Accuracy: {acc:.4f}")
Validation Accuracy: 0.8315
Доля правильных ответов 83,15%
# Предсказание на тесте test_preds = model.predict(X_test)
# Посмотрим на предсказания модели test_preds
array([0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 1., 0., 1., 1., 0., 0., 0., 1., 0., 0., 1., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 1., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 1., 1., 1., 0., 0., 1., 1., 0., 0., 0., 1., 0., 0., 1., 0., 1., 1., 0., 0., 0., 0., 0., 1., 0., 1., 1., 0., 0., 1., 0., 0., 0., 1., 0., 1., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 0., 0., 1., 0., 1., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 1., 0., 0., 1., 1., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 1., 0., 1., 1., 0., 0., 1., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 1., 0., 0., 1., 1., 0., 1., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 1., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 0., 0., 0., 0., 1., 0., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 1., 0., 0., 0., 0., 0., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 1., 0., 0., 0., 0., 0., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 1., 0., 1., 1., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 1., 0., 1., 0., 0., 0., 1., 0., 0., 1., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 1., 0., 0., 1., 0., 1., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 0., 1., 0., 0., 0.])
# Создание submission.csv submission = pd.DataFrame({ 'PassengerId': passenger_ids, 'Survived': test_preds.astype(int) }) submission.to_csv('../submissions/submission.csv', index=False) print("✅ Submission файл сохранён как submission.csv")
✅ Submission файл сохранён как submission.csv
# Посмотрим файл submission
|
|
PassengerId |
Survived |
|---|---|---|
|
0 |
892 |
0 |
|
1 |
893 |
0 |
|
2 |
894 |
0 |
|
3 |
895 |
0 |
|
4 |
896 |
0 |
|
… |
… |
… |
|
413 |
1305 |
0 |
|
414 |
1306 |
1 |
|
415 |
1307 |
0 |
|
416 |
1308 |
0 |
|
417 |
1309 |
0 |
418 rows × 2 columns
Результат Топ 2’400
Теперь попробуем улучшить качество модели с помощью подбора гиперпараметров
Буду использовать случайный подбор параметров с помощью RandomizedSearchCV
# Импортируем модуль from sklearn.model_selection import RandomizedSearchCV
# Hyperparameter grid for the randomized search.
param_grid = {
    'depth': [4, 6, 8, 10],              # tree depth (deeper = more complex model)
    'learning_rate': [0.01, 0.05, 0.1],  # smaller = slower training, potentially more accurate
    'iterations': [300, 500, 1000],      # number of boosting rounds (trees)
    'l2_leaf_reg': [1, 3, 5, 7, 9],      # L2 regularization against overfitting
    'border_count': [32, 64, 128]        # bins used to discretize numeric features
}

# Randomized search with cross-validation.
# FIX: previously this cell passed `estimator=model`, i.e. the already
# trained baseline model from above, while the fresh clean estimator was
# created only AFTER the search object and thus never used. Build the
# clean estimator directly here instead.
random_search = RandomizedSearchCV(
    estimator=CatBoostClassifier(silent=True, random_state=42),
    param_distributions=param_grid,
    n_iter=45,            # how many random parameter combinations to try
    scoring='accuracy',   # metric to maximize
    cv=10,                # number of cross-validation folds
    verbose=2,            # log the search progress
    n_jobs=-1             # use all available CPU cores
)
# Создаём экземпляр модели model = CatBoostClassifier(silent=True, random_state=42) # random state фиксированный
# Фиксируем некоторые параметры для модели fit_params = { "eval_set": [(X_valid, y_valid)], # Набор валидационных данных (для контроля переобучения и использования early stopping) "early_stopping_rounds": 100, # Если метрика не улучшается в течение 100 итераций — обучение остановится "cat_features": cat_features, # Указываем, какие признаки являются категориальными (CatBoost работает с ними нативно) "verbose": 1 # Показывать прогресс обучения во время тренировки }
#Запуск подбора random_search.fit(X_train_split, y_train_split, **fit_params)
Fitting 10 folds for each of 45 candidates, totalling 450 fits 0:learn: 0.7988748test: 0.7752809best: 0.7752809 (0)total: 18.8msremaining: 5.63s 1:learn: 0.8016878test: 0.7808989best: 0.7808989 (1)total: 39.8msremaining: 5.93s 2:learn: 0.8101266test: 0.7921348best: 0.7921348 (2)total: 56.4msremaining: 5.58s 3:learn: 0.8045007test: 0.7865169best: 0.7921348 (2)total: 76.9msremaining: 5.69s 4:learn: 0.8030942test: 0.7865169best: 0.7921348 (2)total: 97.1msremaining: 5.73s 5:learn: 0.8087201test: 0.7977528best: 0.7977528 (5)total: 118msremaining: 5.78s 6:learn: 0.8087201test: 0.7977528best: 0.7977528 (5)total: 139msremaining: 5.82s 7:learn: 0.8101266test: 0.8033708best: 0.8033708 (7)total: 160msremaining: 5.85s 8:learn: 0.8101266test: 0.7977528best: 0.8033708 (7)total: 181msremaining: 5.86s 9:learn: 0.8101266test: 0.7977528best: 0.8033708 (7)total: 201msremaining: 5.82s 10:learn: 0.8101266test: 0.7977528best: 0.8033708 (7)total: 220msremaining: 5.77s 11:learn: 0.8115331test: 0.7977528best: 0.8033708 (7)total: 241msremaining: 5.78s 12:learn: 0.8171589test: 0.7977528best: 0.8033708 (7)total: 262msremaining: 5.78s 13:learn: 0.8185654test: 0.7977528best: 0.8033708 (7)total: 293msremaining: 5.98s 14:learn: 0.8185654test: 0.8033708best: 0.8033708 (7)total: 322msremaining: 6.12s 15:learn: 0.8185654test: 0.8033708best: 0.8033708 (7)total: 348msremaining: 6.17s 16:learn: 0.8185654test: 0.8033708best: 0.8033708 (7)total: 369msremaining: 6.14s 17:learn: 0.8185654test: 0.8033708best: 0.8033708 (7)total: 385msremaining: 6.03s 18:learn: 0.8199719test: 0.8033708best: 0.8033708 (7)total: 407msremaining: 6.03s 19:learn: 0.8227848test: 0.8033708best: 0.8033708 (7)total: 430msremaining: 6.02s 20:learn: 0.8227848test: 0.8033708best: 0.8033708 (7)total: 452msremaining: 6s 21:learn: 0.8227848test: 0.8033708best: 0.8033708 (7)total: 473msremaining: 5.98s 22:learn: 0.8255977test: 0.8033708best: 0.8033708 (7)total: 495msremaining: 5.96s 23:learn: 0.8270042test: 0.8033708best: 0.8033708 
(7)total: 501msremaining: 5.76s 24:learn: 0.8270042test: 0.8033708best: 0.8033708 (7)total: 521msremaining: 5.73s 25:learn: 0.8255977test: 0.8033708best: 0.8033708 (7)total: 536msremaining: 5.65s 26:learn: 0.8255977test: 0.8033708best: 0.8033708 (7)total: 558msremaining: 5.64s 27:learn: 0.8241913test: 0.7977528best: 0.8033708 (7)total: 576msremaining: 5.6s 28:learn: 0.8255977test: 0.7977528best: 0.8033708 (7)total: 596msremaining: 5.57s 29:learn: 0.8255977test: 0.8033708best: 0.8033708 (7)total: 615msremaining: 5.54s 30:learn: 0.8255977test: 0.8033708best: 0.8033708 (7)total: 635msremaining: 5.51s 31:learn: 0.8255977test: 0.8089888best: 0.8089888 (31)total: 655msremaining: 5.48s 32:learn: 0.8284107test: 0.8089888best: 0.8089888 (31)total: 674msremaining: 5.45s 33:learn: 0.8298172test: 0.8146067best: 0.8146067 (33)total: 694msremaining: 5.43s 34:learn: 0.8340366test: 0.8146067best: 0.8146067 (33)total: 706msremaining: 5.35s 35:learn: 0.8354430test: 0.8146067best: 0.8146067 (33)total: 728msremaining: 5.34s 36:learn: 0.8354430test: 0.8146067best: 0.8146067 (33)total: 749msremaining: 5.32s 37:learn: 0.8368495test: 0.8089888best: 0.8146067 (33)total: 764msremaining: 5.27s 38:learn: 0.8382560test: 0.8089888best: 0.8146067 (33)total: 780msremaining: 5.22s 39:learn: 0.8368495test: 0.8089888best: 0.8146067 (33)total: 799msremaining: 5.19s 40:learn: 0.8368495test: 0.8089888best: 0.8146067 (33)total: 820msremaining: 5.18s 41:learn: 0.8382560test: 0.8089888best: 0.8146067 (33)total: 840msremaining: 5.16s 42:learn: 0.8382560test: 0.8202247best: 0.8202247 (42)total: 862msremaining: 5.15s 43:learn: 0.8410689test: 0.8146067best: 0.8202247 (42)total: 884msremaining: 5.14s 44:learn: 0.8396624test: 0.8146067best: 0.8202247 (42)total: 907msremaining: 5.14s 45:learn: 0.8438819test: 0.8258427best: 0.8258427 (45)total: 930msremaining: 5.14s 46:learn: 0.8466948test: 0.8258427best: 0.8258427 (45)total: 953msremaining: 5.13s 47:learn: 0.8466948test: 0.8258427best: 0.8258427 (45)total: 
976msremaining: 5.12s 48:learn: 0.8481013test: 0.8258427best: 0.8258427 (45)total: 999msremaining: 5.12s 49:learn: 0.8452883test: 0.8314607best: 0.8314607 (49)total: 1.02sremaining: 5.11s 50:learn: 0.8438819test: 0.8314607best: 0.8314607 (49)total: 1.04sremaining: 5.09s 51:learn: 0.8438819test: 0.8314607best: 0.8314607 (49)total: 1.06sremaining: 5.07s 52:learn: 0.8452883test: 0.8370787best: 0.8370787 (52)total: 1.08sremaining: 5.05s 53:learn: 0.8424754test: 0.8370787best: 0.8370787 (52)total: 1.1sremaining: 5.04s 54:learn: 0.8396624test: 0.8370787best: 0.8370787 (52)total: 1.13sremaining: 5.01s 55:learn: 0.8396624test: 0.8314607best: 0.8370787 (52)total: 1.15sremaining: 4.99s 56:learn: 0.8396624test: 0.8314607best: 0.8370787 (52)total: 1.17sremaining: 4.97s 57:learn: 0.8382560test: 0.8314607best: 0.8370787 (52)total: 1.18sremaining: 4.94s 58:learn: 0.8382560test: 0.8314607best: 0.8370787 (52)total: 1.2sremaining: 4.89s 59:learn: 0.8396624test: 0.8314607best: 0.8370787 (52)total: 1.22sremaining: 4.87s 60:learn: 0.8396624test: 0.8314607best: 0.8370787 (52)total: 1.24sremaining: 4.85s 61:learn: 0.8396624test: 0.8314607best: 0.8370787 (52)total: 1.27sremaining: 4.89s 62:learn: 0.8396624test: 0.8314607best: 0.8370787 (52)total: 1.29sremaining: 4.87s 63:learn: 0.8396624test: 0.8314607best: 0.8370787 (52)total: 1.32sremaining: 4.86s 64:learn: 0.8396624test: 0.8314607best: 0.8370787 (52)total: 1.34sremaining: 4.85s 65:learn: 0.8396624test: 0.8314607best: 0.8370787 (52)total: 1.36sremaining: 4.84s 66:learn: 0.8410689test: 0.8314607best: 0.8370787 (52)total: 1.39sremaining: 4.82s 67:learn: 0.8410689test: 0.8314607best: 0.8370787 (52)total: 1.41sremaining: 4.8s 68:learn: 0.8410689test: 0.8314607best: 0.8370787 (52)total: 1.42sremaining: 4.75s 69:learn: 0.8424754test: 0.8314607best: 0.8370787 (52)total: 1.44sremaining: 4.74s 70:learn: 0.8424754test: 0.8314607best: 0.8370787 (52)total: 1.46sremaining: 4.72s 71:learn: 0.8424754test: 0.8314607best: 0.8370787 (52)total: 
1.49sremaining: 4.7s 72:learn: 0.8438819test: 0.8314607best: 0.8370787 (52)total: 1.51sremaining: 4.69s 73:learn: 0.8396624test: 0.8314607best: 0.8370787 (52)total: 1.52sremaining: 4.66s 74:learn: 0.8382560test: 0.8258427best: 0.8370787 (52)total: 1.55sremaining: 4.64s 75:learn: 0.8410689test: 0.8258427best: 0.8370787 (52)total: 1.57sremaining: 4.63s 76:learn: 0.8424754test: 0.8202247best: 0.8370787 (52)total: 1.59sremaining: 4.6s 77:learn: 0.8424754test: 0.8202247best: 0.8370787 (52)total: 1.61sremaining: 4.58s 78:learn: 0.8438819test: 0.8202247best: 0.8370787 (52)total: 1.63sremaining: 4.56s 79:learn: 0.8452883test: 0.8202247best: 0.8370787 (52)total: 1.65sremaining: 4.54s 80:learn: 0.8452883test: 0.8202247best: 0.8370787 (52)total: 1.67sremaining: 4.52s 81:learn: 0.8452883test: 0.8202247best: 0.8370787 (52)total: 1.69sremaining: 4.49s 82:learn: 0.8452883test: 0.8202247best: 0.8370787 (52)total: 1.71sremaining: 4.47s 83:learn: 0.8452883test: 0.8202247best: 0.8370787 (52)total: 1.73sremaining: 4.44s 84:learn: 0.8438819test: 0.8202247best: 0.8370787 (52)total: 1.75sremaining: 4.42s 85:learn: 0.8452883test: 0.8202247best: 0.8370787 (52)total: 1.77sremaining: 4.39s 86:learn: 0.8452883test: 0.8202247best: 0.8370787 (52)total: 1.79sremaining: 4.37s 87:learn: 0.8438819test: 0.8202247best: 0.8370787 (52)total: 1.8sremaining: 4.35s 88:learn: 0.8452883test: 0.8146067best: 0.8370787 (52)total: 1.83sremaining: 4.33s 89:learn: 0.8438819test: 0.8146067best: 0.8370787 (52)total: 1.85sremaining: 4.31s 90:learn: 0.8438819test: 0.8146067best: 0.8370787 (52)total: 1.87sremaining: 4.29s 91:learn: 0.8438819test: 0.8146067best: 0.8370787 (52)total: 1.89sremaining: 4.26s 92:learn: 0.8438819test: 0.8146067best: 0.8370787 (52)total: 1.91sremaining: 4.24s 93:learn: 0.8438819test: 0.8258427best: 0.8370787 (52)total: 1.92sremaining: 4.22s 94:learn: 0.8438819test: 0.8258427best: 0.8370787 (52)total: 1.93sremaining: 4.16s 95:learn: 0.8438819test: 0.8258427best: 0.8370787 (52)total: 
1.95sremaining: 4.14s 96:learn: 0.8438819test: 0.8202247best: 0.8370787 (52)total: 1.97sremaining: 4.12s 97:learn: 0.8438819test: 0.8258427best: 0.8370787 (52)total: 1.99sremaining: 4.1s 98:learn: 0.8438819test: 0.8258427best: 0.8370787 (52)total: 2.01sremaining: 4.08s 99:learn: 0.8438819test: 0.8258427best: 0.8370787 (52)total: 2.02sremaining: 4.05s 100:learn: 0.8438819test: 0.8258427best: 0.8370787 (52)total: 2.04sremaining: 4.03s 101:learn: 0.8438819test: 0.8258427best: 0.8370787 (52)total: 2.06sremaining: 4.01s 102:learn: 0.8438819test: 0.8258427best: 0.8370787 (52)total: 2.09sremaining: 3.99s 103:learn: 0.8438819test: 0.8258427best: 0.8370787 (52)total: 2.11sremaining: 3.97s 104:learn: 0.8438819test: 0.8258427best: 0.8370787 (52)total: 2.13sremaining: 3.95s 105:learn: 0.8452883test: 0.8202247best: 0.8370787 (52)total: 2.15sremaining: 3.93s 106:learn: 0.8452883test: 0.8202247best: 0.8370787 (52)total: 2.2sremaining: 3.97s 107:learn: 0.8452883test: 0.8202247best: 0.8370787 (52)total: 2.24sremaining: 3.99s 108:learn: 0.8452883test: 0.8146067best: 0.8370787 (52)total: 2.28sremaining: 3.99s 109:learn: 0.8466948test: 0.8146067best: 0.8370787 (52)total: 2.3sremaining: 3.98s 110:learn: 0.8466948test: 0.8146067best: 0.8370787 (52)total: 2.32sremaining: 3.96s 111:learn: 0.8466948test: 0.8146067best: 0.8370787 (52)total: 2.35sremaining: 3.94s 112:learn: 0.8466948test: 0.8146067best: 0.8370787 (52)total: 2.37sremaining: 3.92s 113:learn: 0.8466948test: 0.8146067best: 0.8370787 (52)total: 2.39sremaining: 3.9s 114:learn: 0.8466948test: 0.8146067best: 0.8370787 (52)total: 2.41sremaining: 3.88s 115:learn: 0.8466948test: 0.8146067best: 0.8370787 (52)total: 2.43sremaining: 3.86s 116:learn: 0.8466948test: 0.8146067best: 0.8370787 (52)total: 2.45sremaining: 3.84s 117:learn: 0.8452883test: 0.8202247best: 0.8370787 (52)total: 2.48sremaining: 3.82s 118:learn: 0.8452883test: 0.8146067best: 0.8370787 (52)total: 2.5sremaining: 3.8s 119:learn: 0.8452883test: 0.8146067best: 0.8370787 
(52)total: 2.52sremaining: 3.78s 120:learn: 0.8452883test: 0.8146067best: 0.8370787 (52)total: 2.54sremaining: 3.76s 121:learn: 0.8452883test: 0.8146067best: 0.8370787 (52)total: 2.56sremaining: 3.74s 122:learn: 0.8452883test: 0.8146067best: 0.8370787 (52)total: 2.59sremaining: 3.72s 123:learn: 0.8452883test: 0.8146067best: 0.8370787 (52)total: 2.61sremaining: 3.71s 124:learn: 0.8452883test: 0.8146067best: 0.8370787 (52)total: 2.63sremaining: 3.69s 125:learn: 0.8452883test: 0.8146067best: 0.8370787 (52)total: 2.65sremaining: 3.67s 126:learn: 0.8452883test: 0.8146067best: 0.8370787 (52)total: 2.68sremaining: 3.65s 127:learn: 0.8466948test: 0.8146067best: 0.8370787 (52)total: 2.7sremaining: 3.63s 128:learn: 0.8466948test: 0.8146067best: 0.8370787 (52)total: 2.72sremaining: 3.61s 129:learn: 0.8466948test: 0.8146067best: 0.8370787 (52)total: 2.73sremaining: 3.57s 130:learn: 0.8466948test: 0.8146067best: 0.8370787 (52)total: 2.75sremaining: 3.55s 131:learn: 0.8466948test: 0.8146067best: 0.8370787 (52)total: 2.77sremaining: 3.53s 132:learn: 0.8466948test: 0.8146067best: 0.8370787 (52)total: 2.79sremaining: 3.51s 133:learn: 0.8466948test: 0.8146067best: 0.8370787 (52)total: 2.8sremaining: 3.47s 134:learn: 0.8466948test: 0.8146067best: 0.8370787 (52)total: 2.83sremaining: 3.46s 135:learn: 0.8466948test: 0.8146067best: 0.8370787 (52)total: 2.85sremaining: 3.43s 136:learn: 0.8466948test: 0.8146067best: 0.8370787 (52)total: 2.87sremaining: 3.42s 137:learn: 0.8466948test: 0.8146067best: 0.8370787 (52)total: 2.89sremaining: 3.4s 138:learn: 0.8481013test: 0.8146067best: 0.8370787 (52)total: 2.91sremaining: 3.37s 139:learn: 0.8481013test: 0.8146067best: 0.8370787 (52)total: 2.94sremaining: 3.35s 140:learn: 0.8481013test: 0.8146067best: 0.8370787 (52)total: 2.96sremaining: 3.33s 141:learn: 0.8481013test: 0.8146067best: 0.8370787 (52)total: 2.98sremaining: 3.31s 142:learn: 0.8481013test: 0.8146067best: 0.8370787 (52)total: 3sremaining: 3.29s 143:learn: 0.8495077test: 0.8146067best: 
0.8370787 (52)total: 3.02sremaining: 3.27s 144:learn: 0.8495077test: 0.8146067best: 0.8370787 (52)total: 3.04sremaining: 3.25s 145:learn: 0.8495077test: 0.8146067best: 0.8370787 (52)total: 3.06sremaining: 3.23s 146:learn: 0.8509142test: 0.8146067best: 0.8370787 (52)total: 3.08sremaining: 3.21s 147:learn: 0.8509142test: 0.8146067best: 0.8370787 (52)total: 3.1sremaining: 3.19s 148:learn: 0.8523207test: 0.8146067best: 0.8370787 (52)total: 3.12sremaining: 3.16s 149:learn: 0.8523207test: 0.8146067best: 0.8370787 (52)total: 3.14sremaining: 3.14s 150:learn: 0.8523207test: 0.8146067best: 0.8370787 (52)total: 3.16sremaining: 3.12s 151:learn: 0.8523207test: 0.8146067best: 0.8370787 (52)total: 3.18sremaining: 3.1s 152:learn: 0.8523207test: 0.8146067best: 0.8370787 (52)total: 3.2sremaining: 3.08s Stopped by overfitting detector (100 iterations wait) bestTest = 0.8370786517 bestIteration = 52 Shrink model to first 53 iterations.
RandomizedSearchCV(cv=10, estimator=<catboost.core.CatBoostClassifier object at 0x000002E16600D400>, n_iter=45, n_jobs=-1, param_distributions={'border_count': [32, 64, 128], 'depth': [4, 6, 8, 10], 'iterations': [300, 500, 1000], 'l2_leaf_reg': [1, 3, 5, 7, 9], 'learning_rate': [0.01, 0.05, 0.1]}, scoring='accuracy', verbose=2)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
RandomizedSearchCV
?Documentation for RandomizedSearchCViFitted
Parameters
<tr class="user-set"> <td><i class="copy-paste-icon"></i></td> <td class="param">estimator </td> <td class="value"><catboost.cor...002E16600D400></td> </tr> <tr class="user-set"> <td><i class="copy-paste-icon"></i></td> <td class="param">param_distributions </td> <td class="value">{'border_count': [32, 64, ...], 'depth': [4, 6, ...], 'iterations': [300, 500, ...], 'l2_leaf_reg': [1, 3, ...], ...}</td> </tr> <tr class="user-set"> <td><i class="copy-paste-icon"></i></td> <td class="param">n_iter </td> <td class="value">45</td> </tr> <tr class="user-set"> <td><i class="copy-paste-icon"></i></td> <td class="param">scoring </td> <td class="value">'accuracy'</td> </tr> <tr class="user-set"> <td><i class="copy-paste-icon"></i></td> <td class="param">n_jobs </td> <td class="value">-1</td> </tr> <tr class="default"> <td><i class="copy-paste-icon"></i></td> <td class="param">refit </td> <td class="value">True</td> </tr> <tr class="user-set"> <td><i class="copy-paste-icon"></i></td> <td class="param">cv </td> <td class="value">10</td> </tr> <tr class="user-set"> <td><i class="copy-paste-icon"></i></td> <td class="param">verbose </td> <td class="value">2</td> </tr> <tr class="default"> <td><i class="copy-paste-icon"></i></td> <td class="param">pre_dispatch </td> <td class="value">'2*n_jobs'</td> </tr> <tr class="default"> <td><i class="copy-paste-icon"></i></td> <td class="param">random_state </td> <td class="value">None</td> </tr> <tr class="default"> <td><i class="copy-paste-icon"></i></td> <td class="param">error_score </td> <td class="value">nan</td> </tr> <tr class="default"> <td><i class="copy-paste-icon"></i></td> <td class="param">return_train_score </td> <td class="value">False</td> </tr> </tbody> </table> </details> </div> </div></div></div><div class="sk-parallel"><div class="sk-parallel-item"><div class="sk-item"><div class="sk-label-container"><div class="sk-label fitted sk-toggleable"><input type="checkbox" id="sk-estimator-id-2" class="sk-toggleable__control 
sk-hidden--visually"><label class="sk-toggleable__label fitted sk-toggleable__label-arrow" for="sk-estimator-id-2"><div><div>best_estimator_: CatBoostClassifier</div></div></label><div data-param-prefix="best_estimator___" class="sk-toggleable__content fitted"><pre><catboost.core.CatBoostClassifier object at 0x000002E169667020></pre></div></div></div><div class="sk-serial"><div class="sk-item"><div class="sk-estimator fitted sk-toggleable"><input type="checkbox" id="sk-estimator-id-3" class="sk-toggleable__control sk-hidden--visually"><label class="sk-toggleable__label fitted sk-toggleable__label-arrow" for="sk-estimator-id-3"><div><div>CatBoostClassifier</div></div></label><div data-param-prefix="best_estimator___" class="sk-toggleable__content fitted"><pre><catboost.core.CatBoostClassifier object at 0x000002E169667020></pre></div></div></div></div></div></div></div></div></div></div>
# Выведем лучшие параметры random_search.best_params_
{'learning_rate': 0.05, 'l2_leaf_reg': 3, 'iterations': 300, 'depth': 4, 'border_count': 32}
# Выведем лучший скор random_search.best_score_
np.float64(0.82981220657277)
# Persist the best hyperparameters to a file so they are not lost.
# Open in "w" (not "a"): append mode would stack a new JSON object onto the
# file on every re-run, making it unparseable as JSON after the first save.
with open("best_params.txt", "w") as f:
    json.dump(random_search.best_params_, f, indent=4)
Отступление
Дважды при малом early_stopping_rounds, равном 30, при n_iter, равном 15, и cv, равном 3,
Модель показывала лучший accuracy, но при валидации на Kaggle, показывала результаты хуже
Увеличил early_stopping_rounds, n_iter и cv — и тогда
Получилось улучшить итоговый результат
# Посмотрим лучшую модель best_model = random_search.best_estimator_
# Обучаем модель с лучшими параметрами best_model.fit(X_train_split, y_train_split, **fit_params)
0:learn: 0.7988748test: 0.7752809best: 0.7752809 (0)total: 21.9msremaining: 6.54s 1:learn: 0.8016878test: 0.7808989best: 0.7808989 (1)total: 48.9msremaining: 7.29s 2:learn: 0.8101266test: 0.7921348best: 0.7921348 (2)total: 74.5msremaining: 7.37s 3:learn: 0.8045007test: 0.7865169best: 0.7921348 (2)total: 105msremaining: 7.76s 4:learn: 0.8030942test: 0.7865169best: 0.7921348 (2)total: 133msremaining: 7.85s 5:learn: 0.8087201test: 0.7977528best: 0.7977528 (5)total: 159msremaining: 7.77s 6:learn: 0.8087201test: 0.7977528best: 0.7977528 (5)total: 184msremaining: 7.7s 7:learn: 0.8101266test: 0.8033708best: 0.8033708 (7)total: 209msremaining: 7.63s 8:learn: 0.8101266test: 0.7977528best: 0.8033708 (7)total: 234msremaining: 7.57s 9:learn: 0.8101266test: 0.7977528best: 0.8033708 (7)total: 258msremaining: 7.49s 10:learn: 0.8101266test: 0.7977528best: 0.8033708 (7)total: 286msremaining: 7.52s 11:learn: 0.8115331test: 0.7977528best: 0.8033708 (7)total: 314msremaining: 7.54s 12:learn: 0.8171589test: 0.7977528best: 0.8033708 (7)total: 341msremaining: 7.54s 13:learn: 0.8185654test: 0.7977528best: 0.8033708 (7)total: 369msremaining: 7.53s 14:learn: 0.8185654test: 0.8033708best: 0.8033708 (7)total: 392msremaining: 7.45s 15:learn: 0.8185654test: 0.8033708best: 0.8033708 (7)total: 416msremaining: 7.38s 16:learn: 0.8185654test: 0.8033708best: 0.8033708 (7)total: 438msremaining: 7.29s 17:learn: 0.8185654test: 0.8033708best: 0.8033708 (7)total: 456msremaining: 7.14s 18:learn: 0.8199719test: 0.8033708best: 0.8033708 (7)total: 479msremaining: 7.08s 19:learn: 0.8227848test: 0.8033708best: 0.8033708 (7)total: 501msremaining: 7.02s 20:learn: 0.8227848test: 0.8033708best: 0.8033708 (7)total: 524msremaining: 6.96s 21:learn: 0.8227848test: 0.8033708best: 0.8033708 (7)total: 546msremaining: 6.9s 22:learn: 0.8255977test: 0.8033708best: 0.8033708 (7)total: 576msremaining: 6.93s 23:learn: 0.8270042test: 0.8033708best: 0.8033708 (7)total: 584msremaining: 6.72s 24:learn: 0.8270042test: 0.8033708best: 
0.8033708 (7)total: 607msremaining: 6.68s 25:learn: 0.8255977test: 0.8033708best: 0.8033708 (7)total: 624msremaining: 6.57s 26:learn: 0.8255977test: 0.8033708best: 0.8033708 (7)total: 646msremaining: 6.53s 27:learn: 0.8241913test: 0.7977528best: 0.8033708 (7)total: 670msremaining: 6.51s 28:learn: 0.8255977test: 0.7977528best: 0.8033708 (7)total: 692msremaining: 6.47s 29:learn: 0.8255977test: 0.8033708best: 0.8033708 (7)total: 716msremaining: 6.44s 30:learn: 0.8255977test: 0.8033708best: 0.8033708 (7)total: 739msremaining: 6.41s 31:learn: 0.8255977test: 0.8089888best: 0.8089888 (31)total: 760msremaining: 6.37s 32:learn: 0.8284107test: 0.8089888best: 0.8089888 (31)total: 784msremaining: 6.34s 33:learn: 0.8298172test: 0.8146067best: 0.8146067 (33)total: 807msremaining: 6.31s 34:learn: 0.8340366test: 0.8146067best: 0.8146067 (33)total: 820msremaining: 6.21s 35:learn: 0.8354430test: 0.8146067best: 0.8146067 (33)total: 845msremaining: 6.2s 36:learn: 0.8354430test: 0.8146067best: 0.8146067 (33)total: 868msremaining: 6.17s 37:learn: 0.8368495test: 0.8089888best: 0.8146067 (33)total: 886msremaining: 6.11s 38:learn: 0.8382560test: 0.8089888best: 0.8146067 (33)total: 908msremaining: 6.08s 39:learn: 0.8368495test: 0.8089888best: 0.8146067 (33)total: 936msremaining: 6.08s 40:learn: 0.8368495test: 0.8089888best: 0.8146067 (33)total: 959msremaining: 6.05s 41:learn: 0.8382560test: 0.8089888best: 0.8146067 (33)total: 979msremaining: 6.01s 42:learn: 0.8382560test: 0.8202247best: 0.8202247 (42)total: 1sremaining: 5.98s 43:learn: 0.8410689test: 0.8146067best: 0.8202247 (42)total: 1.02sremaining: 5.94s 44:learn: 0.8396624test: 0.8146067best: 0.8202247 (42)total: 1.04sremaining: 5.92s 45:learn: 0.8438819test: 0.8258427best: 0.8258427 (45)total: 1.07sremaining: 5.89s 46:learn: 0.8466948test: 0.8258427best: 0.8258427 (45)total: 1.09sremaining: 5.85s 47:learn: 0.8466948test: 0.8258427best: 0.8258427 (45)total: 1.11sremaining: 5.82s 48:learn: 0.8481013test: 0.8258427best: 0.8258427 
(45)total: 1.13sremaining: 5.78s 49:learn: 0.8452883test: 0.8314607best: 0.8314607 (49)total: 1.15sremaining: 5.75s 50:learn: 0.8438819test: 0.8314607best: 0.8314607 (49)total: 1.17sremaining: 5.71s 51:learn: 0.8438819test: 0.8314607best: 0.8314607 (49)total: 1.19sremaining: 5.67s 52:learn: 0.8452883test: 0.8370787best: 0.8370787 (52)total: 1.21sremaining: 5.63s 53:learn: 0.8424754test: 0.8370787best: 0.8370787 (52)total: 1.23sremaining: 5.59s 54:learn: 0.8396624test: 0.8370787best: 0.8370787 (52)total: 1.25sremaining: 5.56s 55:learn: 0.8396624test: 0.8314607best: 0.8370787 (52)total: 1.27sremaining: 5.53s 56:learn: 0.8396624test: 0.8314607best: 0.8370787 (52)total: 1.29sremaining: 5.5s 57:learn: 0.8382560test: 0.8314607best: 0.8370787 (52)total: 1.31sremaining: 5.46s 58:learn: 0.8382560test: 0.8314607best: 0.8370787 (52)total: 1.32sremaining: 5.41s 59:learn: 0.8396624test: 0.8314607best: 0.8370787 (52)total: 1.34sremaining: 5.38s 60:learn: 0.8396624test: 0.8314607best: 0.8370787 (52)total: 1.36sremaining: 5.35s 61:learn: 0.8396624test: 0.8314607best: 0.8370787 (52)total: 1.38sremaining: 5.31s 62:learn: 0.8396624test: 0.8314607best: 0.8370787 (52)total: 1.4sremaining: 5.27s 63:learn: 0.8396624test: 0.8314607best: 0.8370787 (52)total: 1.42sremaining: 5.24s 64:learn: 0.8396624test: 0.8314607best: 0.8370787 (52)total: 1.44sremaining: 5.21s 65:learn: 0.8396624test: 0.8314607best: 0.8370787 (52)total: 1.46sremaining: 5.18s 66:learn: 0.8410689test: 0.8314607best: 0.8370787 (52)total: 1.48sremaining: 5.15s 67:learn: 0.8410689test: 0.8314607best: 0.8370787 (52)total: 1.5sremaining: 5.13s 68:learn: 0.8410689test: 0.8314607best: 0.8370787 (52)total: 1.51sremaining: 5.07s 69:learn: 0.8424754test: 0.8314607best: 0.8370787 (52)total: 1.53sremaining: 5.05s 70:learn: 0.8424754test: 0.8314607best: 0.8370787 (52)total: 1.56sremaining: 5.03s 71:learn: 0.8424754test: 0.8314607best: 0.8370787 (52)total: 1.58sremaining: 5s 72:learn: 0.8438819test: 0.8314607best: 0.8370787 (52)total: 
1.6sremaining: 4.98s 73:learn: 0.8396624test: 0.8314607best: 0.8370787 (52)total: 1.61sremaining: 4.93s 74:learn: 0.8382560test: 0.8258427best: 0.8370787 (52)total: 1.64sremaining: 4.91s 75:learn: 0.8410689test: 0.8258427best: 0.8370787 (52)total: 1.66sremaining: 4.88s 76:learn: 0.8424754test: 0.8202247best: 0.8370787 (52)total: 1.68sremaining: 4.86s 77:learn: 0.8424754test: 0.8202247best: 0.8370787 (52)total: 1.7sremaining: 4.83s 78:learn: 0.8438819test: 0.8202247best: 0.8370787 (52)total: 1.72sremaining: 4.81s 79:learn: 0.8452883test: 0.8202247best: 0.8370787 (52)total: 1.74sremaining: 4.78s 80:learn: 0.8452883test: 0.8202247best: 0.8370787 (52)total: 1.76sremaining: 4.76s 81:learn: 0.8452883test: 0.8202247best: 0.8370787 (52)total: 1.78sremaining: 4.73s 82:learn: 0.8452883test: 0.8202247best: 0.8370787 (52)total: 1.8sremaining: 4.71s 83:learn: 0.8452883test: 0.8202247best: 0.8370787 (52)total: 1.82sremaining: 4.68s 84:learn: 0.8438819test: 0.8202247best: 0.8370787 (52)total: 1.84sremaining: 4.67s 85:learn: 0.8452883test: 0.8202247best: 0.8370787 (52)total: 1.87sremaining: 4.65s 86:learn: 0.8452883test: 0.8202247best: 0.8370787 (52)total: 1.89sremaining: 4.64s 87:learn: 0.8438819test: 0.8202247best: 0.8370787 (52)total: 1.92sremaining: 4.62s 88:learn: 0.8452883test: 0.8146067best: 0.8370787 (52)total: 1.94sremaining: 4.59s 89:learn: 0.8438819test: 0.8146067best: 0.8370787 (52)total: 1.96sremaining: 4.58s 90:learn: 0.8438819test: 0.8146067best: 0.8370787 (52)total: 1.98sremaining: 4.55s 91:learn: 0.8438819test: 0.8146067best: 0.8370787 (52)total: 2sremaining: 4.53s 92:learn: 0.8438819test: 0.8146067best: 0.8370787 (52)total: 2.02sremaining: 4.51s 93:learn: 0.8438819test: 0.8258427best: 0.8370787 (52)total: 2.05sremaining: 4.49s 94:learn: 0.8438819test: 0.8258427best: 0.8370787 (52)total: 2.05sremaining: 4.43s 95:learn: 0.8438819test: 0.8258427best: 0.8370787 (52)total: 2.08sremaining: 4.41s 96:learn: 0.8438819test: 0.8202247best: 0.8370787 (52)total: 
2.1sremaining: 4.39s 97:learn: 0.8438819test: 0.8258427best: 0.8370787 (52)total: 2.12sremaining: 4.36s 98:learn: 0.8438819test: 0.8258427best: 0.8370787 (52)total: 2.14sremaining: 4.34s 99:learn: 0.8438819test: 0.8258427best: 0.8370787 (52)total: 2.16sremaining: 4.32s 100:learn: 0.8438819test: 0.8258427best: 0.8370787 (52)total: 2.19sremaining: 4.31s 101:learn: 0.8438819test: 0.8258427best: 0.8370787 (52)total: 2.21sremaining: 4.29s 102:learn: 0.8438819test: 0.8258427best: 0.8370787 (52)total: 2.23sremaining: 4.27s 103:learn: 0.8438819test: 0.8258427best: 0.8370787 (52)total: 2.25sremaining: 4.25s 104:learn: 0.8438819test: 0.8258427best: 0.8370787 (52)total: 2.28sremaining: 4.23s 105:learn: 0.8452883test: 0.8202247best: 0.8370787 (52)total: 2.3sremaining: 4.21s 106:learn: 0.8452883test: 0.8202247best: 0.8370787 (52)total: 2.32sremaining: 4.19s 107:learn: 0.8452883test: 0.8202247best: 0.8370787 (52)total: 2.35sremaining: 4.17s 108:learn: 0.8452883test: 0.8146067best: 0.8370787 (52)total: 2.37sremaining: 4.15s 109:learn: 0.8466948test: 0.8146067best: 0.8370787 (52)total: 2.39sremaining: 4.13s 110:learn: 0.8466948test: 0.8146067best: 0.8370787 (52)total: 2.41sremaining: 4.11s 111:learn: 0.8466948test: 0.8146067best: 0.8370787 (52)total: 2.44sremaining: 4.09s 112:learn: 0.8466948test: 0.8146067best: 0.8370787 (52)total: 2.46sremaining: 4.07s 113:learn: 0.8466948test: 0.8146067best: 0.8370787 (52)total: 2.48sremaining: 4.05s 114:learn: 0.8466948test: 0.8146067best: 0.8370787 (52)total: 2.5sremaining: 4.02s 115:learn: 0.8466948test: 0.8146067best: 0.8370787 (52)total: 2.52sremaining: 4s 116:learn: 0.8466948test: 0.8146067best: 0.8370787 (52)total: 2.54sremaining: 3.97s 117:learn: 0.8452883test: 0.8202247best: 0.8370787 (52)total: 2.56sremaining: 3.95s 118:learn: 0.8452883test: 0.8146067best: 0.8370787 (52)total: 2.58sremaining: 3.92s 119:learn: 0.8452883test: 0.8146067best: 0.8370787 (52)total: 2.6sremaining: 3.9s 120:learn: 0.8452883test: 0.8146067best: 0.8370787 
(52)total: 2.62sremaining: 3.87s 121:learn: 0.8452883test: 0.8146067best: 0.8370787 (52)total: 2.64sremaining: 3.85s 122:learn: 0.8452883test: 0.8146067best: 0.8370787 (52)total: 2.66sremaining: 3.83s 123:learn: 0.8452883test: 0.8146067best: 0.8370787 (52)total: 2.68sremaining: 3.8s 124:learn: 0.8452883test: 0.8146067best: 0.8370787 (52)total: 2.7sremaining: 3.78s 125:learn: 0.8452883test: 0.8146067best: 0.8370787 (52)total: 2.72sremaining: 3.75s 126:learn: 0.8452883test: 0.8146067best: 0.8370787 (52)total: 2.74sremaining: 3.73s 127:learn: 0.8466948test: 0.8146067best: 0.8370787 (52)total: 2.76sremaining: 3.71s 128:learn: 0.8466948test: 0.8146067best: 0.8370787 (52)total: 2.78sremaining: 3.68s 129:learn: 0.8466948test: 0.8146067best: 0.8370787 (52)total: 2.79sremaining: 3.65s 130:learn: 0.8466948test: 0.8146067best: 0.8370787 (52)total: 2.81sremaining: 3.62s 131:learn: 0.8466948test: 0.8146067best: 0.8370787 (52)total: 2.82sremaining: 3.59s 132:learn: 0.8466948test: 0.8146067best: 0.8370787 (52)total: 2.84sremaining: 3.57s 133:learn: 0.8466948test: 0.8146067best: 0.8370787 (52)total: 2.85sremaining: 3.53s 134:learn: 0.8466948test: 0.8146067best: 0.8370787 (52)total: 2.87sremaining: 3.51s 135:learn: 0.8466948test: 0.8146067best: 0.8370787 (52)total: 2.89sremaining: 3.49s 136:learn: 0.8466948test: 0.8146067best: 0.8370787 (52)total: 2.91sremaining: 3.47s 137:learn: 0.8466948test: 0.8146067best: 0.8370787 (52)total: 2.94sremaining: 3.44s 138:learn: 0.8481013test: 0.8146067best: 0.8370787 (52)total: 2.96sremaining: 3.42s 139:learn: 0.8481013test: 0.8146067best: 0.8370787 (52)total: 2.97sremaining: 3.4s 140:learn: 0.8481013test: 0.8146067best: 0.8370787 (52)total: 3sremaining: 3.38s 141:learn: 0.8481013test: 0.8146067best: 0.8370787 (52)total: 3.02sremaining: 3.35s 142:learn: 0.8481013test: 0.8146067best: 0.8370787 (52)total: 3.03sremaining: 3.33s 143:learn: 0.8495077test: 0.8146067best: 0.8370787 (52)total: 3.05sremaining: 3.31s 144:learn: 0.8495077test: 0.8146067best: 
0.8370787 (52)total: 3.07sremaining: 3.29s 145:learn: 0.8495077test: 0.8146067best: 0.8370787 (52)total: 3.1sremaining: 3.27s 146:learn: 0.8509142test: 0.8146067best: 0.8370787 (52)total: 3.12sremaining: 3.24s 147:learn: 0.8509142test: 0.8146067best: 0.8370787 (52)total: 3.14sremaining: 3.22s 148:learn: 0.8523207test: 0.8146067best: 0.8370787 (52)total: 3.16sremaining: 3.2s 149:learn: 0.8523207test: 0.8146067best: 0.8370787 (52)total: 3.18sremaining: 3.18s 150:learn: 0.8523207test: 0.8146067best: 0.8370787 (52)total: 3.2sremaining: 3.15s 151:learn: 0.8523207test: 0.8146067best: 0.8370787 (52)total: 3.22sremaining: 3.13s 152:learn: 0.8523207test: 0.8146067best: 0.8370787 (52)total: 3.24sremaining: 3.11s Stopped by overfitting detector (100 iterations wait) bestTest = 0.8370786517 bestIteration = 52 Shrink model to first 53 iterations.
# Evaluate the tuned model on the hold-out validation set.
# Recompute predictions here: the original cell reused a stale `y_pred`
# produced by an earlier model, which is why it printed 0.8315 while the
# tuned model actually scores 0.8371 on the same split (see the
# accuracy_score check further below).
y_pred = best_model.predict(X_valid)
acc = accuracy_score(y_valid, y_pred)
print(f"Validation Accuracy: {acc:.4f}")
Validation Accuracy: 0.8315
#Предсказание на тесте best_test_preds = best_model.predict(X_test)
# Build and save the second submission file.
submission_V2 = pd.DataFrame(
    {
        'PassengerId': passenger_ids,
        'Survived': best_test_preds.astype(int),  # Kaggle expects integer labels
    }
)
submission_V2.to_csv('submission_V2.csv', index=False)
print("✅ Submission файл сохранён как submission_V2.csv")
✅ Submission файл сохранён как submission_V2.csv
# Смотрим оценку качества accuracy_score(y_valid, best_model.predict(X_valid))
0.8370786516853933
Смогли улучшить качество модели с помощью подбора гиперпараметров и отвоевать больше 500 мест в итоговом рейтинге
ссылка на оригинал статьи https://habr.com/ru/articles/935540/
Добавить комментарий