import warnings

warnings.filterwarnings("ignore")
from statsmodels.tools.sm_exceptions import ConvergenceWarning

warnings.simplefilter("ignore", ConvergenceWarning)
# Libraries to help with reading and manipulating data
import pandas as pd
import numpy as np
# Library to split data
from sklearn.model_selection import train_test_split
# Libraries to help with data visualization
import matplotlib.pyplot as plt
import seaborn as sns
# To build models for prediction
import statsmodels.stats.api as sms
from statsmodels.stats.outliers_influence import variance_inflation_factor
import statsmodels.api as sm
from statsmodels.tools.tools import add_constant
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
from sklearn import tree
# To tune different models
from sklearn.model_selection import GridSearchCV
# To get different metric scores
from sklearn.metrics import (
    f1_score,
    accuracy_score,
    recall_score,
    precision_score,
    confusion_matrix,
    roc_auc_score,
    precision_recall_curve,
    roc_curve,
    make_scorer,
)
# Removes the limit for the number of displayed columns
pd.set_option("display.max_columns", None)
# Sets the limit for the number of displayed rows
pd.set_option("display.max_rows", 200)
# To suppress scientific notation in dataframes
pd.set_option("display.float_format", lambda x: "%.3f" % x)
df = pd.read_csv('heart.csv')
# Observing the first five rows of the dataset.
df.head()
Age | Sex | ChestPainType | RestingBP | Cholesterol | FastingBS | RestingECG | MaxHR | ExerciseAngina | Oldpeak | ST_Slope | HeartDisease | |
---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 40 | M | ATA | 140 | 289 | 0 | Normal | 172 | N | 0.000 | Up | 0 |
1 | 49 | F | NAP | 160 | 180 | 0 | Normal | 156 | N | 1.000 | Flat | 1 |
2 | 37 | M | ATA | 130 | 283 | 0 | ST | 98 | N | 0.000 | Up | 0 |
3 | 48 | F | ASY | 138 | 214 | 0 | Normal | 108 | Y | 1.500 | Flat | 1 |
4 | 54 | M | NAP | 150 | 195 | 0 | Normal | 122 | N | 0.000 | Up | 0 |
# Observing the last five rows of the dataset.
df.tail()
Age | Sex | ChestPainType | RestingBP | Cholesterol | FastingBS | RestingECG | MaxHR | ExerciseAngina | Oldpeak | ST_Slope | HeartDisease | |
---|---|---|---|---|---|---|---|---|---|---|---|---|
913 | 45 | M | TA | 110 | 264 | 0 | Normal | 132 | N | 1.200 | Flat | 1 |
914 | 68 | M | ASY | 144 | 193 | 1 | Normal | 141 | N | 3.400 | Flat | 1 |
915 | 57 | M | ASY | 130 | 131 | 0 | Normal | 115 | Y | 1.200 | Flat | 1 |
916 | 57 | F | ATA | 130 | 236 | 0 | LVH | 174 | N | 0.000 | Flat | 1 |
917 | 38 | M | NAP | 138 | 175 | 0 | Normal | 173 | N | 0.000 | Up | 0 |
# Making a copy of the data set to preserve integrity
hd = df.copy()
print(f"There are {hd.shape[0]} rows and {hd.shape[1]} columns.") # f-string
# Viewing random rows
np.random.seed(1)
hd.sample(n=10)
There are 918 rows and 12 columns.
Age | Sex | ChestPainType | RestingBP | Cholesterol | FastingBS | RestingECG | MaxHR | ExerciseAngina | Oldpeak | ST_Slope | HeartDisease | |
---|---|---|---|---|---|---|---|---|---|---|---|---|
900 | 58 | M | ASY | 114 | 318 | 0 | ST | 140 | N | 4.400 | Down | 1 |
570 | 56 | M | ASY | 128 | 223 | 0 | ST | 119 | Y | 2.000 | Down | 1 |
791 | 51 | M | ASY | 140 | 298 | 0 | Normal | 122 | Y | 4.200 | Flat | 1 |
189 | 53 | M | ASY | 180 | 285 | 0 | ST | 120 | Y | 1.500 | Flat | 1 |
372 | 63 | M | ASY | 185 | 0 | 0 | Normal | 98 | Y | 0.000 | Up | 1 |
191 | 50 | M | ATA | 170 | 209 | 0 | ST | 116 | N | 0.000 | Up | 0 |
643 | 58 | M | NAP | 112 | 230 | 0 | LVH | 165 | N | 2.500 | Flat | 1 |
474 | 62 | M | ATA | 131 | 0 | 0 | Normal | 130 | N | 0.100 | Up | 0 |
65 | 37 | F | ATA | 120 | 260 | 0 | Normal | 130 | N | 0.000 | Up | 0 |
890 | 64 | M | TA | 170 | 227 | 0 | LVH | 155 | N | 0.600 | Flat | 0 |
# Looking at rows and columns
hd.shape
(918, 12)
# Checking for missing values
hd.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 918 entries, 0 to 917
Data columns (total 12 columns):
 #   Column          Non-Null Count  Dtype
---  ------          --------------  -----
 0   Age             918 non-null    int64
 1   Sex             918 non-null    object
 2   ChestPainType   918 non-null    object
 3   RestingBP       918 non-null    int64
 4   Cholesterol     918 non-null    int64
 5   FastingBS       918 non-null    int64
 6   RestingECG      918 non-null    object
 7   MaxHR           918 non-null    int64
 8   ExerciseAngina  918 non-null    object
 9   Oldpeak         918 non-null    float64
 10  ST_Slope        918 non-null    object
 11  HeartDisease    918 non-null    int64
dtypes: float64(1), int64(6), object(5)
memory usage: 86.2+ KB
# Converting the data type of categorical features to 'category'
cat_cols = ['Sex', 'ChestPainType', 'FastingBS', 'ExerciseAngina', 'ST_Slope', 'HeartDisease','RestingECG']
hd[cat_cols] = hd[cat_cols].astype("category")
hd.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 918 entries, 0 to 917
Data columns (total 12 columns):
 #   Column          Non-Null Count  Dtype
---  ------          --------------  -----
 0   Age             918 non-null    int64
 1   Sex             918 non-null    category
 2   ChestPainType   918 non-null    category
 3   RestingBP       918 non-null    int64
 4   Cholesterol     918 non-null    int64
 5   FastingBS       918 non-null    category
 6   RestingECG      918 non-null    category
 7   MaxHR           918 non-null    int64
 8   ExerciseAngina  918 non-null    category
 9   Oldpeak         918 non-null    float64
 10  ST_Slope        918 non-null    category
 11  HeartDisease    918 non-null    category
dtypes: category(7), float64(1), int64(4)
memory usage: 43.2 KB
hd.describe().T
count | mean | std | min | 25% | 50% | 75% | max | |
---|---|---|---|---|---|---|---|---|
Age | 918.000 | 53.511 | 9.433 | 28.000 | 47.000 | 54.000 | 60.000 | 77.000 |
RestingBP | 918.000 | 132.397 | 18.514 | 0.000 | 120.000 | 130.000 | 140.000 | 200.000 |
Cholesterol | 918.000 | 198.800 | 109.384 | 0.000 | 173.250 | 223.000 | 267.000 | 603.000 |
MaxHR | 918.000 | 136.809 | 25.460 | 60.000 | 120.000 | 138.000 | 156.000 | 202.000 |
Oldpeak | 918.000 | 0.887 | 1.067 | -2.600 | 0.000 | 0.600 | 1.500 | 6.200 |
hd.describe(include=["category"]).T
count | unique | top | freq | |
---|---|---|---|---|
Sex | 918 | 2 | M | 725 |
ChestPainType | 918 | 4 | ASY | 496 |
FastingBS | 918 | 2 | 0 | 704 |
RestingECG | 918 | 3 | Normal | 552 |
ExerciseAngina | 918 | 2 | N | 547 |
ST_Slope | 918 | 3 | Flat | 460 |
HeartDisease | 918 | 2 | 1 | 508 |
# Looking at unique values for categorical variables.
for i in cat_cols:
print('Unique values in', i, 'are:')
print(hd[i].value_counts())
print("*" * 50)
Unique values in Sex are:
M    725
F    193
Name: Sex, dtype: int64
**************************************************
Unique values in ChestPainType are:
ASY    496
NAP    203
ATA    173
TA      46
Name: ChestPainType, dtype: int64
**************************************************
Unique values in FastingBS are:
0    704
1    214
Name: FastingBS, dtype: int64
**************************************************
Unique values in ExerciseAngina are:
N    547
Y    371
Name: ExerciseAngina, dtype: int64
**************************************************
Unique values in ST_Slope are:
Flat    460
Up      395
Down     63
Name: ST_Slope, dtype: int64
**************************************************
Unique values in HeartDisease are:
1    508
0    410
Name: HeartDisease, dtype: int64
**************************************************
Unique values in RestingECG are:
Normal    552
LVH       188
ST        178
Name: RestingECG, dtype: int64
**************************************************
def histogram_boxplot(data, feature, figsize=(12, 7), kde=False, bins=None):
"""
Boxplot and histogram combined
data: dataframe
feature: dataframe column
figsize: size of figure (default (12,7))
kde: whether to show the density curve (default False)
bins: number of bins for histogram (default None)
"""
f2, (ax_box2, ax_hist2) = plt.subplots(
nrows=2, # Number of rows of the subplot grid= 2
sharex=True, # x-axis will be shared among all subplots
gridspec_kw={"height_ratios": (0.25, 0.75)},
figsize=figsize,
) # creating the 2 subplots
sns.boxplot(
data=data, x=feature, ax=ax_box2, showmeans=True, color="violet"
) # boxplot will be created and a star will indicate the mean value of the column
    sns.histplot(
        data=data, x=feature, kde=kde, ax=ax_hist2, bins=bins if bins else "auto"
    )  # histogram; uses the given bin count if provided, otherwise lets seaborn choose
ax_hist2.axvline(
data[feature].mean(), color="green", linestyle="--"
) # Add mean to the histogram
ax_hist2.axvline(
data[feature].median(), color="black", linestyle="-"
) # Add median to the histogram
# Observation on Age
histogram_boxplot(hd, 'Age')
# Observation on RestingBP
histogram_boxplot(hd, 'RestingBP')
# Observation on Cholesterol
histogram_boxplot(hd, 'Cholesterol')
hd['Cholesterol'].describe()
count   918.000
mean    198.800
std     109.384
min       0.000
25%     173.250
50%     223.000
75%     267.000
max     603.000
Name: Cholesterol, dtype: float64
# A cholesterol of 0 is not physiologically plausible, so treat zeros as missing
# and impute them with the column mean
hd["Cholesterol"] = hd["Cholesterol"].replace(0, hd["Cholesterol"].mean())
hd['Cholesterol'].describe()
count   918.000
mean    236.047
std      56.241
min      85.000
25%     198.800
50%     223.000
75%     267.000
max     603.000
Name: Cholesterol, dtype: float64
hd.describe().T
count | mean | std | min | 25% | 50% | 75% | max | |
---|---|---|---|---|---|---|---|---|
Age | 918.000 | 53.511 | 9.433 | 28.000 | 47.000 | 54.000 | 60.000 | 77.000 |
RestingBP | 918.000 | 132.397 | 18.514 | 0.000 | 120.000 | 130.000 | 140.000 | 200.000 |
Cholesterol | 918.000 | 236.047 | 56.241 | 85.000 | 198.800 | 223.000 | 267.000 | 603.000 |
MaxHR | 918.000 | 136.809 | 25.460 | 60.000 | 120.000 | 138.000 | 156.000 | 202.000 |
Oldpeak | 918.000 | 0.887 | 1.067 | -2.600 | 0.000 | 0.600 | 1.500 | 6.200 |
# Observation on MaxHR
histogram_boxplot(hd, 'MaxHR')
# Observation on Oldpeak
histogram_boxplot(hd, 'Oldpeak')
# function to create labeled barplots
def labeled_barplot(data, feature, perc=False, n=None):
"""
Barplot with percentage at the top
data: dataframe
feature: dataframe column
perc: whether to display percentages instead of count (default is False)
n: displays the top n category levels (default is None, i.e., display all levels)
"""
total = len(data[feature]) # length of the column
count = data[feature].nunique()
if n is None:
plt.figure(figsize=(count + 1, 5))
else:
plt.figure(figsize=(n + 1, 5))
plt.xticks(rotation=90, fontsize=15)
ax = sns.countplot(
data=data,
x=feature,
palette="Paired",
order=data[feature].value_counts().index[:n].sort_values(),
)
    for p in ax.patches:
        if perc:
            label = "{:.1f}%".format(
                100 * p.get_height() / total
            )  # percentage of each class of the category
        else:
            label = p.get_height()  # count of each level of the category
        x = p.get_x() + p.get_width() / 2  # horizontal center of the bar
        y = p.get_height()  # top of the bar
ax.annotate(
label,
(x, y),
ha="center",
va="center",
size=12,
xytext=(0, 5),
textcoords="offset points",
) # annotate the percentage
plt.show() # show the plot
labeled_barplot(hd, 'Sex', perc=True)
labeled_barplot(hd, 'ChestPainType', perc=True)
labeled_barplot(hd, 'FastingBS', perc=True)
labeled_barplot(hd, 'RestingECG', perc=True)
labeled_barplot(hd, 'ExerciseAngina', perc=True)
labeled_barplot(hd, 'ST_Slope', perc=True)
# Changing HeartDisease to an integer for the heat map
hd['HeartDisease'] = hd.HeartDisease.astype(int)
plt.figure(figsize=(15, 7))
# numeric_only=True restricts the correlation matrix to the numeric columns
sns.heatmap(hd.corr(numeric_only=True), annot=True, vmin=-1, vmax=1, fmt=".2f", cmap="Spectral")
plt.show()
# Changing HeartDisease back to a categorical variable
hd['HeartDisease'] = hd.HeartDisease.astype('category')
def stacked_barplot(data, predictor, target):
"""
Print the category counts and plot a stacked bar chart
data: dataframe
predictor: independent variable
target: target variable
"""
count = data[predictor].nunique()
sorter = data[target].value_counts().index[-1]
tab1 = pd.crosstab(data[predictor], data[target], margins=True).sort_values(
by=sorter, ascending=False
)
print(tab1)
print("-" * 120)
tab = pd.crosstab(data[predictor], data[target], normalize="index").sort_values(
by=sorter, ascending=False
)
tab.plot(kind="bar", stacked=True, figsize=(count + 5, 5))
    plt.legend(loc="upper left", bbox_to_anchor=(1, 1))
plt.show()
stacked_barplot(hd, 'Sex', 'HeartDisease')
HeartDisease    0    1  All
Sex
All           410  508  918
M             267  458  725
F             143   50  193
------------------------------------------------------------------------------------------------------------------------
stacked_barplot(hd, 'ChestPainType', 'HeartDisease')
HeartDisease     0    1  All
ChestPainType
All            410  508  918
ATA            149   24  173
NAP            131   72  203
ASY            104  392  496
TA              26   20   46
------------------------------------------------------------------------------------------------------------------------
stacked_barplot(hd, 'FastingBS', 'HeartDisease')
HeartDisease    0    1  All
FastingBS
All           410  508  918
0             366  338  704
1              44  170  214
------------------------------------------------------------------------------------------------------------------------
stacked_barplot(hd, 'RestingECG', 'HeartDisease')
HeartDisease    0    1  All
RestingECG
All           410  508  918
Normal        267  285  552
LVH            82  106  188
ST             61  117  178
------------------------------------------------------------------------------------------------------------------------
stacked_barplot(hd, 'ST_Slope', 'HeartDisease')
HeartDisease    0    1  All
ST_Slope
All           410  508  918
Up            317   78  395
Flat           79  381  460
Down           14   49   63
------------------------------------------------------------------------------------------------------------------------
# Looking at the average age of positive and negative cases among men and women.
hd.groupby(["Sex","HeartDisease"], as_index=False)["Age"].mean()
Sex | HeartDisease | Age | |
---|---|---|---|
0 | F | 0 | 51.203 |
1 | F | 1 | 56.180 |
2 | M | 0 | 50.202 |
3 | M | 1 | 55.869 |
# Looking at the average MaxHR of positive and negative cases among men and women.
hd.groupby(["Sex","HeartDisease"], as_index=False)["MaxHR"].mean()
Sex | HeartDisease | MaxHR | |
---|---|---|---|
0 | F | 0 | 149.049 |
1 | F | 1 | 137.820 |
2 | M | 0 | 147.670 |
3 | M | 1 | 126.546 |
# Looking at the average cholesterol of positive and negative cases among men and women.
hd.groupby(["Sex","HeartDisease"], as_index=False)["Cholesterol"].mean()
Sex | HeartDisease | Cholesterol | |
---|---|---|---|
0 | F | 0 | 248.831 |
1 | F | 1 | 263.100 |
2 | M | 0 | 230.386 |
3 | M | 1 | 232.403 |
# function to plot distributions wrt target
def distribution_plot_wrt_target(data, predictor, target):
fig, axs = plt.subplots(2, 2, figsize=(12, 10))
target_uniq = data[target].unique()
    axs[0, 0].set_title("Distribution of " + predictor + " for " + target + "=" + str(target_uniq[0]))
sns.histplot(
data=data[data[target] == target_uniq[0]],
x=predictor,
kde=True,
ax=axs[0, 0],
color="teal",
stat="density",
)
    axs[0, 1].set_title("Distribution of " + predictor + " for " + target + "=" + str(target_uniq[1]))
sns.histplot(
data=data[data[target] == target_uniq[1]],
x=predictor,
kde=True,
ax=axs[0, 1],
color="orange",
stat="density",
)
axs[1, 0].set_title("Boxplot w.r.t target")
sns.boxplot(data=data, x=target, y=predictor, ax=axs[1, 0], palette="gist_rainbow")
axs[1, 1].set_title("Boxplot (without outliers) w.r.t target")
sns.boxplot(
data=data,
x=target,
y=predictor,
ax=axs[1, 1],
showfliers=False,
palette="gist_rainbow",
)
plt.tight_layout()
plt.show()
distribution_plot_wrt_target(hd, 'Age', 'HeartDisease')
distribution_plot_wrt_target(hd, 'RestingBP', 'HeartDisease')
distribution_plot_wrt_target(hd, 'Cholesterol', 'HeartDisease')
distribution_plot_wrt_target(hd, 'MaxHR', 'HeartDisease')
# let's plot the boxplots of all columns to check for outliers
numeric_columns = ['Age', 'RestingBP', 'Cholesterol', 'MaxHR', 'Oldpeak']
plt.figure(figsize=(20, 30))
for i, variable in enumerate(numeric_columns):
plt.subplot(5, 4, i + 1)
    plt.boxplot(hd[variable], whis=1.5)
plt.tight_layout()
plt.title(variable)
plt.show()
# Let's treat outliers by flooring and capping
def treat_outliers(data, col):
    """
    Treats outliers in a variable by flooring and capping

    data: dataframe
    col: name of the numerical column
    """
    Q1 = data[col].quantile(0.25)  # 25th quantile
    Q3 = data[col].quantile(0.75)  # 75th quantile
    IQR = Q3 - Q1
    Lower_Whisker = Q1 - 1.5 * IQR
    Upper_Whisker = Q3 + 1.5 * IQR
    # all the values smaller than Lower_Whisker will be assigned the value of Lower_Whisker
    # all the values greater than Upper_Whisker will be assigned the value of Upper_Whisker
    data[col] = np.clip(data[col], Lower_Whisker, Upper_Whisker)
    return data
def treat_outliers_all(data, col_list):
    """
    Treats outliers in all given numerical variables

    data: dataframe
    col_list: list of numerical columns
    """
    for c in col_list:
        data = treat_outliers(data, c)
    return data
numerical_col = hd.select_dtypes(include=np.number).columns.tolist()
hd = treat_outliers_all(hd, numerical_col)
# let's look at box plot to see if outliers have been treated or not
plt.figure(figsize=(20, 30))
for i, variable in enumerate(numeric_columns):
plt.subplot(5, 4, i + 1)
    plt.boxplot(hd[variable], whis=1.5)
plt.tight_layout()
plt.title(variable)
plt.show()
hd.head(10)
Age | Sex | ChestPainType | RestingBP | Cholesterol | FastingBS | RestingECG | MaxHR | ExerciseAngina | Oldpeak | ST_Slope | HeartDisease | |
---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 40 | M | ATA | 140 | 289.000 | 0 | Normal | 172 | N | 0.000 | Up | 0 |
1 | 49 | F | NAP | 160 | 180.000 | 0 | Normal | 156 | N | 1.000 | Flat | 1 |
2 | 37 | M | ATA | 130 | 283.000 | 0 | ST | 98 | N | 0.000 | Up | 0 |
3 | 48 | F | ASY | 138 | 214.000 | 0 | Normal | 108 | Y | 1.500 | Flat | 1 |
4 | 54 | M | NAP | 150 | 195.000 | 0 | Normal | 122 | N | 0.000 | Up | 0 |
5 | 39 | M | NAP | 120 | 339.000 | 0 | Normal | 170 | N | 0.000 | Up | 0 |
6 | 45 | F | ATA | 130 | 237.000 | 0 | Normal | 170 | N | 0.000 | Up | 0 |
7 | 54 | M | ATA | 110 | 208.000 | 0 | Normal | 142 | N | 0.000 | Up | 0 |
8 | 37 | M | ASY | 140 | 207.000 | 0 | Normal | 130 | Y | 1.500 | Flat | 1 |
9 | 48 | F | ATA | 120 | 284.000 | 0 | Normal | 120 | N | 0.000 | Up | 0 |
# defining a function to compute different metrics to check performance of a classification model built using statsmodels
def model_performance_classification_statsmodels(
model, predictors, target, threshold=0.5
):
"""
Function to compute different metrics to check classification model performance
model: classifier
predictors: independent variables
target: dependent variable
threshold: threshold for classifying the observation as class 1
"""
    # classifying an observation as class 1 when its predicted probability exceeds the threshold
    pred = (model.predict(predictors) > threshold).astype(int)
acc = accuracy_score(target, pred) # to compute Accuracy
recall = recall_score(target, pred) # to compute Recall
precision = precision_score(target, pred) # to compute Precision
f1 = f1_score(target, pred) # to compute F1-score
# creating a dataframe of metrics
df_perf = pd.DataFrame(
{"Accuracy": acc, "Recall": recall, "Precision": precision, "F1": f1,},
index=[0],
)
return df_perf
# defining a function to plot the confusion_matrix of a classification model
def confusion_matrix_statsmodels(model, predictors, target, threshold=0.5):
"""
To plot the confusion_matrix with percentages
model: classifier
predictors: independent variables
target: dependent variable
threshold: threshold for classifying the observation as class 1
"""
y_pred = model.predict(predictors) > threshold
cm = confusion_matrix(target, y_pred)
labels = np.asarray(
[
["{0:0.0f}".format(item) + "\n{0:.2%}".format(item / cm.flatten().sum())]
for item in cm.flatten()
]
).reshape(2, 2)
plt.figure(figsize=(6, 4))
sns.heatmap(cm, annot=labels, fmt="")
plt.ylabel("True label")
plt.xlabel("Predicted label")
X = hd.drop("HeartDisease", axis=1)
Y = hd["HeartDisease"]
# creating dummy variables
X = pd.get_dummies(X, drop_first=True)
# adding constant
X = sm.add_constant(X)
# splitting in training and test set
X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.3, random_state=1)
print("Shape of Training set : ", X_train.shape)
print("Shape of test set : ", X_test.shape)
print("Percentage of classes in training set:")
print(y_train.value_counts(normalize=True))
print("Percentage of classes in test set:")
print(y_test.value_counts(normalize=True))
Shape of Training set :  (642, 16)
Shape of test set :  (276, 16)
Percentage of classes in training set:
1   0.531
0   0.469
Name: HeartDisease, dtype: float64
Percentage of classes in test set:
1   0.605
0   0.395
Name: HeartDisease, dtype: float64
logit = sm.Logit(y_train, X_train.astype(float))
lg = logit.fit(
disp=False
) # setting disp=False will remove the information on number of iterations
print(lg.summary())
                           Logit Regression Results
==============================================================================
Dep. Variable:           HeartDisease   No. Observations:                  642
Model:                          Logit   Df Residuals:                      626
Method:                           MLE   Df Model:                           15
Date:                Tue, 08 Mar 2022   Pseudo R-squ.:                  0.5284
Time:                        22:02:12   Log-Likelihood:                -209.28
converged:                       True   LL-Null:                       -443.75
Covariance Type:            nonrobust   LLR p-value:                 2.058e-90
=====================================================================================
                        coef    std err          z      P>|z|      [0.025      0.975]
-------------------------------------------------------------------------------------
const                -2.1176      1.721     -1.230      0.219      -5.491       1.255
Age                   0.0185      0.016      1.145      0.252      -0.013       0.050
RestingBP             0.0094      0.008      1.150      0.250      -0.007       0.025
Cholesterol          -0.0055      0.002     -3.609      0.000      -0.008      -0.003
FastingBS             1.0620      0.328      3.237      0.001       0.419       1.705
MaxHR                -0.0023      0.006     -0.362      0.717      -0.014       0.010
Oldpeak               0.4359      0.141      3.096      0.002       0.160       0.712
Sex_M                 1.5323      0.343      4.466      0.000       0.860       2.205
ChestPainType_ATA    -2.0184      0.416     -4.849      0.000      -2.834      -1.203
ChestPainType_NAP    -1.9138      0.317     -6.043      0.000      -2.534      -1.293
ChestPainType_TA     -1.6108      0.528     -3.053      0.002      -2.645      -0.577
RestingECG_Normal    -0.0792      0.326     -0.243      0.808      -0.719       0.560
RestingECG_ST        -0.3289      0.414     -0.795      0.427      -1.140       0.482
ExerciseAngina_Y      0.8733      0.284      3.072      0.002       0.316       1.431
ST_Slope_Flat         1.4135      0.527      2.680      0.007       0.380       2.447
ST_Slope_Up          -0.9200      0.556     -1.655      0.098      -2.010       0.170
=====================================================================================
print("Training performance:")
model_performance_classification_statsmodels(lg, X_train, y_train)
Training performance:
Accuracy | Recall | Precision | F1 | |
---|---|---|---|---|
0 | 0.871 | 0.903 | 0.860 | 0.881 |
Observations
Negative coefficients indicate that the probability of a patient having heart disease decreases as the corresponding attribute value increases.
Positive coefficients indicate that the probability of a patient having heart disease increases as the corresponding attribute value increases.
The p-value of a variable indicates whether the variable is significant. If we take the significance level to be 0.05 (5%), then any variable with a p-value below 0.05 is considered significant.
But these variables might exhibit multicollinearity, which distorts the p-values.
We will have to remove multicollinearity from the data to get reliable coefficients and p-values.
There are different ways of detecting (or testing for) multicollinearity; one such way is the Variance Inflation Factor.
Variance Inflation Factor: Variance inflation factors measure the inflation in the variances of the regression coefficient estimates due to collinearity among the predictors. The VIF measures how much the variance of an estimated regression coefficient βk is "inflated" by the correlation among the predictor variables in the model.
General rule of thumb: If VIF is 1, there is no correlation between the kth predictor and the remaining predictor variables, and hence the variance of β̂k is not inflated at all. If VIF exceeds 5, we say there is moderate multicollinearity, and a VIF of 10 or above signals high multicollinearity. The purpose of the analysis should dictate which threshold to use.
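For reference, the quantity computed below for each predictor is the standard VIF definition (not specific to this dataset): VIF_k = 1 / (1 - R_k²), where R_k² is the coefficient of determination obtained by regressing the kth predictor on all the other predictors. A VIF of 5 therefore means the variance of β̂k is five times what it would be with uncorrelated predictors.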
vif_series = pd.Series(
[variance_inflation_factor(X_train.values, i) for i in range(X_train.shape[1])],
index=X_train.columns,
dtype=float,
)
print("Series before feature selection: \n\n{}\n".format(vif_series))
Series before feature selection:

const                200.343
Age                    1.386
RestingBP              1.130
Cholesterol            1.309
FastingBS              1.170
MaxHR                  1.577
Oldpeak                1.460
Sex_M                  1.113
ChestPainType_ATA      1.484
ChestPainType_NAP      1.231
ChestPainType_TA       1.103
RestingECG_Normal      1.773
RestingECG_ST          1.774
ExerciseAngina_Y       1.552
ST_Slope_Flat          5.237
ST_Slope_Up            6.179
dtype: float64
X_train1 = X_train.drop("ST_Slope_Flat", axis=1)
vif_series2 = pd.Series(
    [variance_inflation_factor(X_train1.values, i) for i in range(X_train1.shape[1])],
    index=X_train1.columns,
)
print("Series after removing ST_Slope_Flat: \n\n{}\n".format(vif_series2))
Series after removing ST_Slope_Flat:

const                186.049
Age                    1.386
RestingBP              1.124
Cholesterol            1.297
FastingBS              1.170
MaxHR                  1.576
Oldpeak                1.410
Sex_M                  1.113
ChestPainType_ATA      1.483
ChestPainType_NAP      1.230
ChestPainType_TA       1.103
RestingECG_Normal      1.771
RestingECG_ST          1.770
ExerciseAngina_Y       1.552
ST_Slope_Up            1.669
dtype: float64
logit1 = sm.Logit(y_train, X_train1.astype(float))
lg1 = logit1.fit(disp=False)
print("Training performance:")
model_performance_classification_statsmodels(lg1, X_train1, y_train)
Training performance:
Accuracy | Recall | Precision | F1 | |
---|---|---|---|---|
0 | 0.864 | 0.897 | 0.855 | 0.876 |
X_train2 = X_train.drop("ST_Slope_Up", axis=1)
vif_series3 = pd.Series(
    [variance_inflation_factor(X_train2.values, i) for i in range(X_train2.shape[1])],
    index=X_train2.columns,
)
print("Series after removing ST_Slope_Up: \n\n{}\n".format(vif_series3))
Series after removing ST_Slope_Up:

const                187.160
Age                    1.384
RestingBP              1.127
Cholesterol            1.302
FastingBS              1.167
MaxHR                  1.567
Oldpeak                1.331
Sex_M                  1.113
ChestPainType_ATA      1.475
ChestPainType_NAP      1.231
ChestPainType_TA       1.103
RestingECG_Normal      1.771
RestingECG_ST          1.769
ExerciseAngina_Y       1.521
ST_Slope_Flat          1.415
dtype: float64
logit2 = sm.Logit(y_train, X_train2.astype(float))
lg2 = logit2.fit(disp=False)
print("Training performance:")
model_performance_classification_statsmodels(lg2, X_train2, y_train)
Training performance:
Accuracy | Recall | Precision | F1 | |
---|---|---|---|---|
0 | 0.864 | 0.894 | 0.857 | 0.875 |
# Let's look at the summary of the lg2 model.
print(lg2.summary())
                           Logit Regression Results
==============================================================================
Dep. Variable:           HeartDisease   No. Observations:                  642
Model:                          Logit   Df Residuals:                      627
Method:                           MLE   Df Model:                           14
Date:                Tue, 08 Mar 2022   Pseudo R-squ.:                  0.5252
Time:                        22:02:13   Log-Likelihood:                -210.69
converged:                       True   LL-Null:                       -443.75
Covariance Type:            nonrobust   LLR p-value:                 1.388e-90
=====================================================================================
                        coef    std err          z      P>|z|      [0.025      0.975]
-------------------------------------------------------------------------------------
const                -2.5868      1.670     -1.549      0.121      -5.860       0.686
Age                   0.0180      0.016      1.119      0.263      -0.014       0.049
RestingBP             0.0086      0.008      1.066      0.286      -0.007       0.024
Cholesterol          -0.0056      0.002     -3.682      0.000      -0.009      -0.003
FastingBS             1.0958      0.327      3.355      0.001       0.456       1.736
MaxHR                -0.0037      0.006     -0.599      0.549      -0.016       0.008
Oldpeak               0.5148      0.132      3.885      0.000       0.255       0.774
Sex_M                 1.5003      0.341      4.399      0.000       0.832       2.169
ChestPainType_ATA    -2.0392      0.414     -4.922      0.000      -2.851      -1.227
ChestPainType_NAP    -1.8823      0.314     -5.992      0.000      -2.498      -1.267
ChestPainType_TA     -1.5810      0.527     -3.002      0.003      -2.613      -0.549
RestingECG_Normal    -0.1117      0.324     -0.345      0.730      -0.747       0.524
RestingECG_ST        -0.3484      0.411     -0.849      0.396      -1.153       0.456
ExerciseAngina_Y      0.9375      0.281      3.342      0.001       0.388       1.487
ST_Slope_Flat         2.1507      0.273      7.864      0.000       1.615       2.687
=====================================================================================
# Dropping Age
X_train3 = X_train2.drop(
['Age'],
axis=1,
)
logit3 = sm.Logit(y_train, X_train3.astype(float))
lg3 = logit3.fit(disp=False)
print(lg3.summary())
                           Logit Regression Results
==============================================================================
Dep. Variable:           HeartDisease   No. Observations:                  642
Model:                          Logit   Df Residuals:                      628
Method:                           MLE   Df Model:                           13
Date:                Tue, 08 Mar 2022   Pseudo R-squ.:                  0.5238
Time:                        22:02:13   Log-Likelihood:                -211.32
converged:                       True   LL-Null:                       -443.75
Covariance Type:            nonrobust   LLR p-value:                 4.193e-91
=====================================================================================
                        coef    std err          z      P>|z|      [0.025      0.975]
-------------------------------------------------------------------------------------
const                -1.5608      1.386     -1.126      0.260      -4.278       1.156
RestingBP             0.0106      0.008      1.346      0.178      -0.005       0.026
Cholesterol          -0.0057      0.002     -3.793      0.000      -0.009      -0.003
FastingBS             1.1181      0.326      3.434      0.001       0.480       1.756
MaxHR                -0.0057      0.006     -0.971      0.331      -0.017       0.006
Oldpeak               0.5467      0.130      4.199      0.000       0.292       0.802
Sex_M                 1.4986      0.341      4.389      0.000       0.829       2.168
ChestPainType_ATA    -2.0367      0.413     -4.933      0.000      -2.846      -1.227
ChestPainType_NAP    -1.8694      0.313     -5.976      0.000      -2.483      -1.256
ChestPainType_TA     -1.5265      0.521     -2.930      0.003      -2.548      -0.505
RestingECG_Normal    -0.1924      0.315     -0.610      0.542      -0.810       0.425
RestingECG_ST        -0.3915      0.408     -0.959      0.338      -1.192       0.409
ExerciseAngina_Y      0.9365      0.280      3.340      0.001       0.387       1.486
ST_Slope_Flat         2.1626      0.273      7.925      0.000       1.628       2.697
=====================================================================================
# Dropping RestingBP
X_train4 = X_train3.drop(["RestingBP"], axis=1)
logit4 = sm.Logit(y_train, X_train4.astype(float))
lg4 = logit4.fit(disp=False)
print(lg4.summary())
                           Logit Regression Results
==============================================================================
Dep. Variable:           HeartDisease   No. Observations:                  642
Model:                          Logit   Df Residuals:                      629
Method:                           MLE   Df Model:                           12
Date:                Tue, 08 Mar 2022   Pseudo R-squ.:                  0.5217
Time:                        22:02:13   Log-Likelihood:                -212.24
converged:                       True   LL-Null:                       -443.75
Covariance Type:            nonrobust   LLR p-value:                 1.609e-91
=====================================================================================
                        coef    std err          z      P>|z|      [0.025      0.975]
-------------------------------------------------------------------------------------
const                -0.2343      0.971     -0.241      0.809      -2.137       1.669
Cholesterol          -0.0053      0.001     -3.637      0.000      -0.008      -0.002
FastingBS             1.1464      0.325      3.528      0.000       0.510       1.783
MaxHR                -0.0061      0.006     -1.050      0.294      -0.017       0.005
Oldpeak               0.5633      0.129      4.371      0.000       0.311       0.816
Sex_M                 1.4695      0.338      4.342      0.000       0.806       2.133
ChestPainType_ATA    -1.9931      0.410     -4.861      0.000      -2.797      -1.189
ChestPainType_NAP    -1.8467      0.312     -5.915      0.000      -2.459      -1.235
ChestPainType_TA     -1.4814      0.521     -2.844      0.004      -2.502      -0.461
RestingECG_Normal    -0.1838      0.314     -0.586      0.558      -0.798       0.431
RestingECG_ST        -0.3352      0.404     -0.830      0.407      -1.127       0.456
ExerciseAngina_Y      0.9585      0.280      3.426      0.001       0.410       1.507
ST_Slope_Flat         2.1747      0.272      8.000      0.000       1.642       2.707
=====================================================================================
# Dropping MaxHR
X_train5 = X_train4.drop(["MaxHR"], axis=1)
logit5 = sm.Logit(y_train, X_train5.astype(float))
lg5 = logit5.fit(disp=False)
print(lg5.summary())
                           Logit Regression Results
==============================================================================
Dep. Variable:           HeartDisease   No. Observations:                  642
Model:                          Logit   Df Residuals:                      630
Method:                           MLE   Df Model:                           11
Date:                Tue, 08 Mar 2022   Pseudo R-squ.:                  0.5205
Time:                        22:02:14   Log-Likelihood:                -212.79
converged:                       True   LL-Null:                       -443.75
Covariance Type:            nonrobust   LLR p-value:                 4.155e-92
=====================================================================================
                        coef    std err          z      P>|z|      [0.025      0.975]
-------------------------------------------------------------------------------------
const                -1.0800      0.546     -1.979      0.048      -2.150      -0.010
Cholesterol          -0.0057      0.001     -4.034      0.000      -0.008      -0.003
FastingBS             1.1256      0.325      3.467      0.001       0.489       1.762
Oldpeak               0.5593      0.128      4.357      0.000       0.308       0.811
Sex_M                 1.5028      0.338      4.453      0.000       0.841       2.164
ChestPainType_ATA    -2.0306      0.409     -4.967      0.000      -2.832      -1.229
ChestPainType_NAP    -1.8881      0.309     -6.103      0.000      -2.494      -1.282
ChestPainType_TA     -1.5145      0.527     -2.876      0.004      -2.547      -0.482
RestingECG_Normal    -0.1571      0.311     -0.505      0.614      -0.767       0.453
RestingECG_ST        -0.2662      0.398     -0.668      0.504      -1.047       0.514
ExerciseAngina_Y      1.0246      0.273      3.757      0.000       0.490       1.559
ST_Slope_Flat         2.2346      0.266      8.390      0.000       1.713       2.757
=====================================================================================
# Dropping RestingECG_Normal
X_train6 = X_train5.drop(["RestingECG_Normal"], axis=1)
logit6 = sm.Logit(y_train, X_train6.astype(float))
lg6 = logit6.fit(disp=False)
print(lg6.summary())
                           Logit Regression Results
==============================================================================
Dep. Variable:           HeartDisease   No. Observations:                  642
Model:                          Logit   Df Residuals:                      631
Method:                           MLE   Df Model:                           10
Date:                Tue, 08 Mar 2022   Pseudo R-squ.:                  0.5202
Time:                        22:02:14   Log-Likelihood:                -212.92
converged:                       True   LL-Null:                       -443.75
Covariance Type:            nonrobust   LLR p-value:                 6.745e-93
=====================================================================================
                        coef    std err          z      P>|z|      [0.025      0.975]
-------------------------------------------------------------------------------------
const                -1.2053      0.487     -2.474      0.013      -2.160      -0.250
Cholesterol          -0.0056      0.001     -4.006      0.000      -0.008      -0.003
FastingBS             1.1283      0.325      3.475      0.001       0.492       1.765
Oldpeak               0.5648      0.128      4.419      0.000       0.314       0.815
Sex_M                 1.4960      0.337      4.436      0.000       0.835       2.157
ChestPainType_ATA    -2.0493      0.407     -5.035      0.000      -2.847      -1.251
ChestPainType_NAP    -1.8868      0.309     -6.101      0.000      -2.493      -1.281
ChestPainType_TA     -1.5026      0.524     -2.865      0.004      -2.531      -0.475
RestingECG_ST        -0.1488      0.322     -0.461      0.644      -0.781       0.483
ExerciseAngina_Y      1.0147      0.272      3.731      0.000       0.482       1.548
ST_Slope_Flat         2.2274      0.266      8.383      0.000       1.707       2.748
=====================================================================================
# Dropping RestingECG_ST
X_train7 = X_train6.drop(["RestingECG_ST"], axis=1)
logit7 = sm.Logit(y_train, X_train7.astype(float))
lg7 = logit7.fit(disp=False)
print(lg7.summary())
                           Logit Regression Results
==============================================================================
Dep. Variable:           HeartDisease   No. Observations:                  642
Model:                          Logit   Df Residuals:                      632
Method:                           MLE   Df Model:                            9
Date:                Tue, 08 Mar 2022   Pseudo R-squ.:                  0.5200
Time:                        22:02:15   Log-Likelihood:                -213.02
converged:                       True   LL-Null:                       -443.75
Covariance Type:            nonrobust   LLR p-value:                 1.015e-93
=====================================================================================
                        coef    std err          z      P>|z|      [0.025      0.975]
-------------------------------------------------------------------------------------
const                -1.2368      0.482     -2.567      0.010      -2.181      -0.292
Cholesterol          -0.0055      0.001     -3.987      0.000      -0.008      -0.003
FastingBS             1.1108      0.322      3.446      0.001       0.479       1.742
Oldpeak               0.5666      0.128      4.437      0.000       0.316       0.817
Sex_M                 1.4837      0.336      4.422      0.000       0.826       2.141
ChestPainType_ATA    -2.0534      0.408     -5.028      0.000      -2.854      -1.253
ChestPainType_NAP    -1.8793      0.309     -6.087      0.000      -2.484      -1.274
ChestPainType_TA     -1.5076      0.523     -2.884      0.004      -2.532      -0.483
ExerciseAngina_Y      1.0023      0.270      3.707      0.000       0.472       1.532
ST_Slope_Flat         2.2259      0.266      8.384      0.000       1.706       2.746
=====================================================================================
Checking model performance on the training set
# creating confusion matrix
confusion_matrix_statsmodels(lg7, X_train7, y_train)
log_reg_model_train_perf = model_performance_classification_statsmodels(
lg7, X_train7, y_train
)
print("Training performance:")
log_reg_model_train_perf
Training performance:
Accuracy | Recall | Precision | F1 | |
---|---|---|---|---|
0 | 0.858 | 0.886 | 0.853 | 0.869 |
logit_roc_auc_train = roc_auc_score(y_train, lg7.predict(X_train7))
fpr, tpr, thresholds = roc_curve(y_train, lg7.predict(X_train7))
plt.figure(figsize=(7, 5))
plt.plot(fpr, tpr, label="Logistic Regression (area = %0.2f)" % logit_roc_auc_train)
plt.plot([0, 1], [0, 1], "r--")
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("Receiver operating characteristic")
plt.legend(loc="lower right")
plt.show()
# Optimal threshold as per AUC-ROC curve
# The optimal cut off would be where tpr is high and fpr is low
fpr, tpr, thresholds = roc_curve(y_train, lg7.predict(X_train7))
optimal_idx = np.argmax(tpr - fpr)
optimal_threshold_auc_roc = thresholds[optimal_idx]
print(optimal_threshold_auc_roc)
0.4116979571949174
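For reference, the np.argmax(tpr - fpr) step above selects the threshold that maximizes Youden's J statistic, a standard cutoff criterion: J = max over thresholds t of (TPR(t) - FPR(t)), i.e. the point on the ROC curve farthest above the diagonal.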
# creating confusion matrix
confusion_matrix_statsmodels(
    lg7, X_train7, y_train, threshold=optimal_threshold_auc_roc
)
# checking model performance for this model
log_reg_model_train_perf_threshold_auc_roc = model_performance_classification_statsmodels(
lg7, X_train7, y_train, threshold=optimal_threshold_auc_roc
)
print("Training performance:")
log_reg_model_train_perf_threshold_auc_roc
Training performance:
Accuracy | Recall | Precision | F1 | |
---|---|---|---|---|
0 | 0.868 | 0.921 | 0.844 | 0.881 |
y_scores = lg7.predict(X_train7)
prec, rec, tre = precision_recall_curve(y_train, y_scores,)
def plot_precision_recall_vs_threshold(precisions, recalls, thresholds):
    plt.plot(thresholds, precisions[:-1], "b--", label="precision")
    plt.plot(thresholds, recalls[:-1], "g--", label="recall")
    plt.xlabel("Threshold")
    plt.legend(loc="upper left")
    plt.ylim([0, 1])
plt.figure(figsize=(10, 7))
plot_precision_recall_vs_threshold(prec, rec, tre)
plt.show()
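Rather than reading the crossover point off the plot, it can also be located programmatically. A minimal sketch using the prec, rec, and tre arrays computed above (relying on scikit-learn's convention that prec[:-1] and rec[:-1] align with the thresholds in tre):
# find the threshold where precision and recall are closest to each other
crossover_idx = np.argmin(np.abs(prec[:-1] - rec[:-1]))
print("Precision/recall crossover threshold:", tre[crossover_idx])
The 0.58 used below was read off the plot; the programmatic value should land in the same neighborhood.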
# setting the threshold
optimal_threshold_curve = 0.58
# creating confusion matrix
confusion_matrix_statsmodels(lg7, X_train7, y_train, threshold=optimal_threshold_curve)
log_reg_model_train_perf_threshold_curve = model_performance_classification_statsmodels(
lg7, X_train7, y_train, threshold=optimal_threshold_curve
)
print("Training performance:")
log_reg_model_train_perf_threshold_curve
Training performance:
Accuracy | Recall | Precision | F1 | |
---|---|---|---|---|
0 | 0.857 | 0.859 | 0.869 | 0.864 |
X_test7 = X_test[list(X_train7.columns)]
# creating confusion matrix
confusion_matrix_statsmodels(lg7, X_test7, y_test)
log_reg_model_test_perf = model_performance_classification_statsmodels(
lg7, X_test7, y_test
)
print("Test performance:")
log_reg_model_test_perf
Test performance:
Accuracy | Recall | Precision | F1 | |
---|---|---|---|---|
0 | 0.888 | 0.880 | 0.930 | 0.905 |
# ROC curve on test set
logit_roc_auc_test = roc_auc_score(y_test, lg7.predict(X_test7))
fpr, tpr, thresholds = roc_curve(y_test, lg7.predict(X_test7))
plt.figure(figsize=(7, 5))
plt.plot(fpr, tpr, label="Logistic Regression (area = %0.2f)" % logit_roc_auc_test)
plt.plot([0, 1], [0, 1], "r--")
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("Receiver operating characteristic")
plt.legend(loc="lower right")
plt.show()
# Using model with threshold=0.41
# creating confusion matrix
confusion_matrix_statsmodels(lg7, X_test7, y_test, threshold=optimal_threshold_auc_roc)
# checking model performance for this model
log_reg_model_test_perf_threshold_auc_roc = model_performance_classification_statsmodels(
lg7, X_test7, y_test, threshold=optimal_threshold_auc_roc
)
print("Test performance:")
log_reg_model_test_perf_threshold_auc_roc
Test performance:
Accuracy | Recall | Precision | F1 | |
---|---|---|---|---|
0 | 0.880 | 0.898 | 0.904 | 0.901 |
# Using model with threshold = 0.58
# creating confusion matrix
confusion_matrix_statsmodels(lg7, X_test7, y_test, threshold=optimal_threshold_curve)
log_reg_model_test_perf_threshold_curve = model_performance_classification_statsmodels(
lg7, X_test7, y_test, threshold=optimal_threshold_curve
)
print("Test performance:")
log_reg_model_test_perf_threshold_curve
Test performance:
Accuracy | Recall | Precision | F1 | |
---|---|---|---|---|
0 | 0.877 | 0.838 | 0.952 | 0.892 |
# training performance comparison
models_train_comp_df = pd.concat(
[
log_reg_model_train_perf.T,
log_reg_model_train_perf_threshold_auc_roc.T,
log_reg_model_train_perf_threshold_curve.T,
],
axis=1,
)
models_train_comp_df.columns = [
    "Logistic Regression-0.50 Threshold",
    "Logistic Regression-0.41 Threshold",
    "Logistic Regression-0.58 Threshold",
]
print("Training performance comparison:")
models_train_comp_df
Training performance comparison:
Logistic Regression-0.50 Threshold | Logistic Regression-0.41 Threshold | Logistic Regression-0.58 Threshold | |
---|---|---|---|
Accuracy | 0.858 | 0.868 | 0.857 |
Recall | 0.886 | 0.921 | 0.859 |
Precision | 0.853 | 0.844 | 0.869 |
F1 | 0.869 | 0.881 | 0.864 |
# testing performance comparison
models_test_comp_df = pd.concat(
[
log_reg_model_test_perf.T,
log_reg_model_test_perf_threshold_auc_roc.T,
log_reg_model_test_perf_threshold_curve.T,
],
axis=1,
)
models_test_comp_df.columns = [
    "Logistic Regression-0.50 Threshold",
    "Logistic Regression-0.41 Threshold",
    "Logistic Regression-0.58 Threshold",
]
print("Test set performance comparison:")
models_test_comp_df
Test set performance comparison:
Logistic Regression-0.50 Threshold | Logistic Regression-0.41 Threshold | Logistic Regression-0.58 Threshold | |
---|---|---|---|
Accuracy | 0.888 | 0.880 | 0.877 |
Recall | 0.880 | 0.898 | 0.838 |
Precision | 0.930 | 0.904 | 0.952 |
F1 | 0.905 | 0.901 | 0.892 |
Conclusion
We have been able to build a predictive model that a medical provider can use to identify patients who may be at risk of heart disease, with a recall of 0.88 on the training set, and to prioritize screening and follow-up accordingly.
The logistic regression models give a generalized performance on the training and test sets.
The coefficients of FastingBS, Oldpeak, ExerciseAngina_Y, and ST_Slope_Flat are positive; an increase in any of these increases the chance of a patient having heart disease.
The coefficients of Cholesterol, ChestPainType_ATA, ChestPainType_NAP, and ChestPainType_TA are negative; an increase in any of these decreases the chance of a patient having heart disease.
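Since a logit model reports log-odds, exponentiating the coefficients gives odds ratios, which are often easier to communicate. A minimal sketch, assuming the final statsmodels result lg7 fitted above:
# odds ratios and their 95% confidence intervals for the final model
odds = pd.DataFrame(
    {
        "Odds Ratio": np.exp(lg7.params),   # e^coef: multiplicative change in odds per unit increase
        "2.5%": np.exp(lg7.conf_int()[0]),  # confidence bounds, also on the odds scale
        "97.5%": np.exp(lg7.conf_int()[1]),
    }
)
print(odds)
For example, the ST_Slope_Flat coefficient of 2.2259 corresponds to an odds ratio of roughly e^2.23 ≈ 9.3.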
# Splitting the data again
X = hd.drop(["HeartDisease"], axis=1)
y = hd["HeartDisease"]
X = pd.get_dummies(X, columns=["Sex", "ChestPainType", "FastingBS", "RestingECG", "ExerciseAngina", "ST_Slope"], drop_first=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=1)
First, let's create functions to calculate different metrics and confusion matrix so that we don't have to use the same code repeatedly for each model.
# defining a function to compute different metrics to check performance of a classification model built using sklearn
def model_performance_classification_sklearn(model, predictors, target):
"""
Function to compute different metrics to check classification model performance
model: classifier
predictors: independent variables
target: dependent variable
"""
# predicting using the independent variables
pred = model.predict(predictors)
acc = accuracy_score(target, pred) # to compute Accuracy
recall = recall_score(target, pred) # to compute Recall
precision = precision_score(target, pred) # to compute Precision
f1 = f1_score(target, pred) # to compute F1-score
# creating a dataframe of metrics
df_perf = pd.DataFrame(
{"Accuracy": acc, "Recall": recall, "Precision": precision, "F1": f1,},
index=[0],
)
return df_perf
def confusion_matrix_sklearn(model, predictors, target):
"""
To plot the confusion_matrix with percentages
model: classifier
predictors: independent variables
target: dependent variable
"""
y_pred = model.predict(predictors)
cm = confusion_matrix(target, y_pred)
labels = np.asarray(
[
["{0:0.0f}".format(item) + "\n{0:.2%}".format(item / cm.flatten().sum())]
for item in cm.flatten()
]
).reshape(2, 2)
plt.figure(figsize=(6, 4))
sns.heatmap(cm, annot=labels, fmt="")
plt.ylabel("True label")
plt.xlabel("Predicted label")
model = DecisionTreeClassifier(criterion="gini", random_state=1)
model.fit(X_train, y_train)
DecisionTreeClassifier(random_state=1)
confusion_matrix_sklearn(model, X_train, y_train)
decision_tree_perf_train = model_performance_classification_sklearn(
model, X_train, y_train
)
decision_tree_perf_train
Accuracy | Recall | Precision | F1 | |
---|---|---|---|---|
0 | 1.000 | 1.000 | 1.000 | 1.000 |
feature_names = list(X_train.columns)
print(feature_names)
['Age', 'RestingBP', 'Cholesterol', 'MaxHR', 'Oldpeak', 'Sex_M', 'ChestPainType_ATA', 'ChestPainType_NAP', 'ChestPainType_TA', 'FastingBS_1', 'RestingECG_Normal', 'RestingECG_ST', 'ExerciseAngina_Y', 'ST_Slope_Flat', 'ST_Slope_Up']
plt.figure(figsize=(20, 30))
out = tree.plot_tree(
model,
feature_names=feature_names,
filled=True,
fontsize=9,
node_ids=False,
class_names=None,
)
# below code will add arrows to the decision tree split if they are missing
for o in out:
arrow = o.arrow_patch
if arrow is not None:
arrow.set_edgecolor("black")
arrow.set_linewidth(1)
plt.show()
# Text report showing the rules of a decision tree -
print(tree.export_text(model, feature_names=feature_names, show_weights=True))
|--- ST_Slope_Up <= 0.50
|   |--- MaxHR <= 150.50
|   |   |--- Sex_M <= 0.50
|   |   |   |--- ExerciseAngina_Y <= 0.50
|   |   |   |   |--- RestingBP <= 135.00
|   |   |   |   |   |--- Cholesterol <= 276.00
|   |   |   |   |   |   |--- weights: [7.00, 0.00] class: 0
|   |   |   |   |   |--- Cholesterol >  276.00
|   |   |   |   |   |   |--- RestingBP <= 125.00
|   |   |   |   |   |   |   |--- weights: [0.00, 1.00] class: 1
|   |   |   |   |   |   |--- RestingBP >  125.00
|   |   |   |   |   |   |   |--- weights: [1.00, 0.00] class: 0
|   |   |   |   |--- RestingBP >  135.00
|   |   |   |   |   |--- Age <= 64.00
|   |   |   |   |   |   |--- weights: [0.00, 4.00] class: 1
|   |   |   |   |   |--- Age >  64.00
|   |   |   |   |   |   |--- weights: [2.00, 0.00] class: 0
... (the fully grown tree continues for several hundred more rule lines, splitting down to single-observation leaves; the pruned tree's complete, much shorter rule set is shown after tuning below)
# Importance of each feature in building the tree. A feature's importance is the
# (normalized) total reduction of the splitting criterion it brings about,
# also known as the Gini importance.
print(
pd.DataFrame(
model.feature_importances_, columns=["Imp"], index=X_train.columns
).sort_values(by="Imp", ascending=False)
)
                     Imp
ST_Slope_Up        0.374
Cholesterol        0.175
MaxHR              0.102
Age                0.078
RestingBP          0.076
Oldpeak            0.043
ExerciseAngina_Y   0.037
Sex_M              0.036
ChestPainType_NAP  0.032
FastingBS_1        0.013
ChestPainType_ATA  0.013
RestingECG_Normal  0.012
RestingECG_ST      0.006
ST_Slope_Flat      0.005
ChestPainType_TA   0.000
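Since Gini importances are normalized, they should sum to 1. Below is a quick sanity check, plus a simple way to shortlist the stronger predictors (a minimal sketch; the 0.01 cutoff is an arbitrary illustration, not part of the analysis):
# Gini importances are normalized across features, so they sum to 1.
assert np.isclose(model.feature_importances_.sum(), 1.0)
# Shortlist features contributing at least 1% of the total impurity reduction
# (the 0.01 threshold is illustrative only).
top_features = [
    name
    for name, imp in zip(X_train.columns, model.feature_importances_)
    if imp >= 0.01
]
print(top_features)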
# Feature importance plot for the fully grown tree
importances = model.feature_importances_
indices = np.argsort(importances)
plt.figure(figsize=(8, 8))
plt.title("Feature Importances")
plt.barh(range(len(indices)), importances[indices], color="violet", align="center")
plt.yticks(range(len(indices)), [feature_names[i] for i in indices])
plt.xlabel("Relative Importance")
plt.show()
# Choose the type of classifier.
estimator = DecisionTreeClassifier(random_state=1)
# Grid of parameters to choose from
parameters = {
"max_depth": np.arange(6, 15),
"min_samples_leaf": [1, 2, 5, 7, 10],
"max_leaf_nodes": [2, 3, 5, 10],
}
# Recall is the scoring metric used to compare parameter combinations:
# a false negative (missing a true heart-disease case) is the costliest error here.
recall_scorer = make_scorer(recall_score)
# Run the grid search
grid_obj = GridSearchCV(estimator, parameters, scoring=recall_scorer, cv=5)
grid_obj = grid_obj.fit(X_train, y_train)
# Set the clf to the best combination of parameters
estimator = grid_obj.best_estimator_
# Fit the best algorithm to the data.
estimator.fit(X_train, y_train)
DecisionTreeClassifier(max_depth=6, max_leaf_nodes=3, random_state=1)
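Before trusting the tuned tree, it is worth inspecting what the grid search actually found; a short sketch using standard GridSearchCV attributes (best_params_, best_score_, cv_results_):
# Winning hyper-parameters and their mean cross-validated recall
print("Best parameters:", grid_obj.best_params_)
print("Mean CV recall: %.3f" % grid_obj.best_score_)
# Top of the full results grid, sorted by mean validation score
cv_results = pd.DataFrame(grid_obj.cv_results_)
print(
    cv_results[["params", "mean_test_score", "std_test_score"]]
    .sort_values(by="mean_test_score", ascending=False)
    .head()
)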
confusion_matrix_sklearn(estimator, X_train, y_train)
decision_tree_tune_perf_train = model_performance_classification_sklearn(
estimator, X_train, y_train
)
decision_tree_tune_perf_train
Accuracy | Recall | Precision | F1 | |
---|---|---|---|---|
0 | 0.840 | 0.918 | 0.807 | 0.859 |
plt.figure(figsize=(10, 10))
out = tree.plot_tree(
estimator,
feature_names=feature_names,
filled=True,
fontsize=9,
node_ids=False,
class_names=None,
)
# The code below adds arrows to the decision tree splits if they are missing
for o in out:
arrow = o.arrow_patch
if arrow is not None:
arrow.set_edgecolor("black")
arrow.set_linewidth(1)
plt.show()
# Text report showing the rules of the pre-pruned decision tree
print(tree.export_text(estimator, feature_names=feature_names, show_weights=True))
|--- ST_Slope_Up <= 0.50
|   |--- weights: [68.00, 285.00] class: 1
|--- ST_Slope_Up > 0.50
|   |--- Cholesterol <= 58.81
|   |   |--- weights: [7.00, 28.00] class: 1
|   |--- Cholesterol > 58.81
|   |   |--- weights: [226.00, 28.00] class: 0
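Because the pre-pruned tree has only three leaves, its decision logic fits in a single function; the hypothetical helper below simply transcribes the rules printed above:
def predict_heart_disease(st_slope_up, cholesterol):
    """Hand transcription of the pre-pruned tree's rules (illustrative only)."""
    if st_slope_up <= 0.5:  # ST slope not upward -> majority class 1
        return 1
    if cholesterol <= 58.81:  # very low recorded cholesterol -> majority class 1
        return 1
    return 0  # upward ST slope with higher cholesterol -> majority class 0

print(predict_heart_disease(st_slope_up=0, cholesterol=250))  # prints 1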
Observations

- After pre-pruning, the tree collapses to just three leaves: any record without an upward ST slope is classified as heart disease, and among upward-slope records only those with very low recorded Cholesterol (<= 58.81) are flagged.
- Training recall falls from a perfect 1.000 to 0.918, but the far simpler tree should generalize much better than the fully grown one.
# Feature importance plot for the tuned (pre-pruned) tree
importances = estimator.feature_importances_
indices = np.argsort(importances)
plt.figure(figsize=(8, 8))
plt.title("Feature Importances")
plt.barh(range(len(indices)), importances[indices], color="violet", align="center")
plt.yticks(range(len(indices)), [feature_names[i] for i in indices])
plt.xlabel("Relative Importance")
plt.show()
# Cost Complexity Pruning
clf = DecisionTreeClassifier(random_state=1)
path = clf.cost_complexity_pruning_path(X_train, y_train)
ccp_alphas, impurities = path.ccp_alphas, path.impurities
pd.DataFrame(path)
ccp_alphas | impurities | |
---|---|---|
0 | 0.000 | 0.000 |
1 | 0.001 | 0.003 |
2 | 0.001 | 0.006 |
3 | 0.001 | 0.008 |
4 | 0.001 | 0.011 |
5 | 0.001 | 0.014 |
6 | 0.001 | 0.016 |
7 | 0.001 | 0.022 |
8 | 0.001 | 0.025 |
9 | 0.001 | 0.028 |
10 | 0.001 | 0.031 |
11 | 0.001 | 0.034 |
12 | 0.001 | 0.039 |
13 | 0.001 | 0.042 |
14 | 0.002 | 0.045 |
15 | 0.002 | 0.049 |
16 | 0.002 | 0.052 |
17 | 0.002 | 0.064 |
18 | 0.002 | 0.070 |
19 | 0.002 | 0.074 |
20 | 0.002 | 0.076 |
21 | 0.002 | 0.097 |
22 | 0.002 | 0.099 |
23 | 0.002 | 0.104 |
24 | 0.002 | 0.107 |
25 | 0.003 | 0.109 |
26 | 0.003 | 0.112 |
27 | 0.003 | 0.115 |
28 | 0.003 | 0.120 |
29 | 0.003 | 0.122 |
30 | 0.003 | 0.131 |
31 | 0.003 | 0.144 |
32 | 0.003 | 0.147 |
33 | 0.003 | 0.150 |
34 | 0.003 | 0.155 |
35 | 0.003 | 0.158 |
36 | 0.003 | 0.167 |
37 | 0.003 | 0.183 |
38 | 0.004 | 0.187 |
39 | 0.004 | 0.195 |
40 | 0.004 | 0.199 |
41 | 0.005 | 0.204 |
42 | 0.005 | 0.209 |
43 | 0.005 | 0.214 |
44 | 0.007 | 0.221 |
45 | 0.008 | 0.229 |
46 | 0.008 | 0.238 |
47 | 0.010 | 0.247 |
48 | 0.019 | 0.266 |
49 | 0.046 | 0.312 |
50 | 0.186 | 0.498 |
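For reference, cost-complexity pruning minimizes R_alpha(T) = R(T) + alpha * |leaves(T)|: as alpha grows, pruning more of the tree becomes worthwhile, so both sequences returned by cost_complexity_pruning_path are non-decreasing. A quick check on the arrays above:
# Effective alphas and total leaf impurities should both be non-decreasing.
assert np.all(np.diff(ccp_alphas) >= 0)
assert np.all(np.diff(impurities) >= 0)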
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(ccp_alphas[:-1], impurities[:-1], marker="o", drawstyle="steps-post")
ax.set_xlabel("effective alpha")
ax.set_ylabel("total impurity of leaves")
ax.set_title("Total Impurity vs effective alpha for training set")
plt.show()
Next, we train a decision tree for each of the effective alphas. The last value in ccp_alphas is the alpha that prunes the whole tree, leaving the final tree, clfs[-1], with a single node.
clfs = []
for ccp_alpha in ccp_alphas:
clf = DecisionTreeClassifier(random_state=1, ccp_alpha=ccp_alpha)
clf.fit(X_train, y_train)
clfs.append(clf)
print(
"Number of nodes in the last tree is: {} with ccp_alpha: {}".format(
clfs[-1].tree_.node_count, ccp_alphas[-1]
)
)
Number of nodes in the last tree is: 1 with ccp_alpha: 0.18637790733584375
clfs = clfs[:-1]
ccp_alphas = ccp_alphas[:-1]
node_counts = [clf.tree_.node_count for clf in clfs]
depth = [clf.tree_.max_depth for clf in clfs]
fig, ax = plt.subplots(2, 1, figsize=(10, 7))
ax[0].plot(ccp_alphas, node_counts, marker="o", drawstyle="steps-post")
ax[0].set_xlabel("alpha")
ax[0].set_ylabel("number of nodes")
ax[0].set_title("Number of nodes vs alpha")
ax[1].plot(ccp_alphas, depth, marker="o", drawstyle="steps-post")
ax[1].set_xlabel("alpha")
ax[1].set_ylabel("depth of tree")
ax[1].set_title("Depth vs alpha")
fig.tight_layout()
recall_train = []
for clf in clfs:
pred_train = clf.predict(X_train)
values_train = recall_score(y_train, pred_train)
recall_train.append(values_train)
recall_test = []
for clf in clfs:
pred_test = clf.predict(X_test)
values_test = recall_score(y_test, pred_test)
recall_test.append(values_test)
fig, ax = plt.subplots(figsize=(15, 5))
ax.set_xlabel("alpha")
ax.set_ylabel("Recall")
ax.set_title("Recall vs alpha for training and testing sets")
ax.plot(ccp_alphas, recall_train, marker="o", label="train", drawstyle="steps-post")
ax.plot(ccp_alphas, recall_test, marker="o", label="test", drawstyle="steps-post")
ax.legend()
plt.show()
index_best_model = np.argmax(recall_test)
best_model = clfs[index_best_model]
print(best_model)
DecisionTreeClassifier(ccp_alpha=0.007897469530524122, random_state=1)
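The comparison below covers only the default and pre-pruned trees. For completeness, the post-pruned best_model selected here could be scored the same way, reusing the notebook's helper functions (a sketch; its outputs are not shown in this notebook):
# Test-set performance of the cost-complexity-pruned tree
confusion_matrix_sklearn(best_model, X_test, y_test)
decision_tree_post_perf_test = model_performance_classification_sklearn(
    best_model, X_test, y_test
)
decision_tree_post_perf_test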
# Test performance of the default (fully grown) decision tree
confusion_matrix_sklearn(model, X_test, y_test)
decision_tree_perf_test = model_performance_classification_sklearn(
model, X_test, y_test
)
decision_tree_perf_test
Accuracy | Recall | Precision | F1 | |
---|---|---|---|---|
0 | 0.739 | 0.749 | 0.806 | 0.776 |
# Test performance of the pre-pruned (tuned) decision tree
confusion_matrix_sklearn(estimator, X_test, y_test)
decision_tree_tune_perf_test = model_performance_classification_sklearn(
estimator, X_test, y_test
)
decision_tree_tune_perf_test
Accuracy | Recall | Precision | F1 | |
---|---|---|---|---|
0 | 0.826 | 0.898 | 0.829 | 0.862 |
# training performance comparison
models_train_comp_df = pd.concat(
[decision_tree_perf_train.T, decision_tree_tune_perf_train.T], axis=1,
)
models_train_comp_df.columns = ["Decision Tree sklearn", "Decision Tree (Pre-Pruning)"]
print("Training performance comparison:")
models_train_comp_df
Training performance comparison:
Decision Tree sklearn | Decision Tree (Pre-Pruning) | |
---|---|---|
Accuracy | 1.000 | 0.840 |
Recall | 1.000 | 0.918 |
Precision | 1.000 | 0.807 |
F1 | 1.000 | 0.859 |
# testing performance comparison
models_test_comp_df = pd.concat(
[decision_tree_perf_test.T, decision_tree_tune_perf_test.T], axis=1,
)
models_test_comp_df.columns = ["Decision Tree sklearn", "Decision Tree (Pre-Pruning)"]
print("Test set performance comparison:")
models_test_comp_df
Test set performance comparison:
Decision Tree sklearn | Decision Tree (Pre-Pruning) | |
---|---|---|
Accuracy | 0.739 | 0.826 |
Recall | 0.749 | 0.898 |
Precision | 0.806 | 0.829 |
F1 | 0.776 | 0.862 |
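If the post-pruned tree were evaluated as sketched earlier, it could be appended as a third column, mirroring the concat pattern above (hypothetical; its scores are not computed in this notebook):
# Extend the test-set comparison with the cost-complexity-pruned tree
models_test_comp_df = pd.concat(
    [models_test_comp_df, decision_tree_post_perf_test.T], axis=1
)
models_test_comp_df.columns = [
    "Decision Tree sklearn",
    "Decision Tree (Pre-Pruning)",
    "Decision Tree (Post-Pruning)",
]
models_test_comp_df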