Project Overview
Purpose: To build an interactive web application that predicts heart disease risk using machine learning models trained on the UCI Heart Disease dataset, with a user-friendly interface for inputting patient data and viewing results.
Dataset: The UCI Heart Disease dataset (heart.csv) contains 303 records with 13 features (e.g., age, sex, chol, thalach) and a binary target (target: 1 = heart disease, 0 = no heart disease).
Models: The app supports three machine learning models:
- Random Forest (F1 score: ~0.885, per the notebook).
- Logistic Regression (F1 score: ~0.869).
- Naive Bayes (F1 score: ~0.873).
Features:
- Model selection and performance metrics (F1 score, accuracy).
- Input form for patient data with validation.
- Visualizations (correlation heatmap, feature importance for Random Forest).
- Prediction with confidence score and input summary.
- Medical disclaimer to clarify the app is not a diagnostic tool.
Code Explanation
The code is organized into several key sections, each handling a specific part of the app’s functionality. I’ll explain each section in detail, followed by the overall workflow.
1. Imports and Setup
```python
import streamlit as st
import pandas as pd
import pickle
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import GaussianNB
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import f1_score, accuracy_score, confusion_matrix
import os

# Set page configuration
st.set_page_config(page_title="Heart Disease Prediction System", layout="wide")
```
- Libraries:
- streamlit: For building the web app interface.
- pandas, numpy: For data manipulation.
- pickle: For saving/loading models and scalers.
- matplotlib, seaborn: For visualizations (heatmap, feature importance).
- sklearn: For machine learning models, preprocessing, and evaluation metrics.
- os: For file handling.
- Page Configuration:
- Sets the app’s title to “Heart Disease Prediction System” and uses a wide layout so the content spans the full browser width.
2. Title and Description
```python
st.title("Heart Disease Prediction System")
st.markdown("""
This application predicts the likelihood of heart disease based on patient data using machine learning models.
The models were trained on the UCI Heart Disease dataset. Select a model and enter patient details below.
""")
```
- Displays the app’s title and a brief description, informing users about the app’s purpose and the dataset used.
3. Model and Data Loading
```python
@st.cache_resource
def load_model_and_data():
    # Load the dataset; abort gracefully if the file is missing.
    try:
        data = pd.read_csv("heart.csv")
    except FileNotFoundError:
        st.error("heart.csv file not found. Please ensure the dataset is in the same directory as this app.")
        return None, None, None, None, None, None  # one None per expected return value

    # Simple missing-value handling (the UCI dataset typically has none).
    if data.isnull().sum().any():
        data = data.dropna()

    # Split features/target, then train/test with a fixed seed for reproducibility.
    X = data.drop('target', axis=1)
    y = data['target']
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

    # Standardize features (fit on the training set only to avoid leakage).
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_test_scaled = scaler.transform(X_test)

    models = {
        'Random Forest': RandomForestClassifier(n_estimators=100, random_state=42),
        'Logistic Regression': LogisticRegression(random_state=42),
        'Naive Bayes': GaussianNB()
    }

    model_performance = {}
    feature_importances = None
    for name, model in models.items():
        model.fit(X_train_scaled, y_train)
        y_pred = model.predict(X_test_scaled)
        model_performance[name] = {
            'F1 Score': f1_score(y_test, y_pred),
            'Accuracy': accuracy_score(y_test, y_pred)
        }
        if name == 'Random Forest':
            feature_importances = pd.Series(model.feature_importances_, index=X.columns)
        # Persist each trained model for reuse outside the app.
        with open(f'{name.lower().replace(" ", "_")}_model.pkl', 'wb') as f:
            pickle.dump(model, f)

    with open('scaler.pkl', 'wb') as f:
        pickle.dump(scaler, f)

    return models, scaler, X.columns, data, model_performance, feature_importances
```
- Function: load_model_and_data loads the dataset, preprocesses it, trains models, and saves them for reuse.
- Steps:
- Load Dataset: Attempts to read heart.csv. If missing, displays an error and returns None for all outputs.
- Handle Missing Values: Drops rows with missing values (simple handling; the UCI dataset typically has no missing values).
- Split Data: Separates features (X) and target (y), then splits into training (80%) and test (20%) sets with a fixed random seed for reproducibility.
- Scale Features: Uses StandardScaler to standardize each feature to zero mean and unit variance. This mainly matters for scale-sensitive models like Logistic Regression; Random Forest and Gaussian Naive Bayes are largely insensitive to feature scale (see the sketch after this list).
- Train Models: Trains three models (Random Forest, Logistic Regression, Naive Bayes) on scaled training data.
- Evaluate Models: Computes F1 score and accuracy on the test set for each model, storing results in model_performance.
- Feature Importance: For Random Forest, extracts feature importances to show which features (e.g., cp, thalach) are most influential.
- Save Models and Scaler: Saves each model and the scaler as .pkl files so they can be reloaded later (for example, by other scripts or a pretrained-model startup path); within a session, reuse is handled by Streamlit’s cache, not the pickles.
- Caching: @st.cache_resource ensures this function runs only once per session rather than on every Streamlit rerun (note it does not re-run automatically if heart.csv changes on disk), improving performance.
- Returns: Trained models, scaler, feature columns, dataset, performance metrics, and Random Forest feature importances.
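
For concreteness, here is a minimal, self-contained sketch (not part of app.py) of what two of the steps above compute: the z = (x − mean) / std transform behind StandardScaler, and the F1 score, which is the harmonic mean of precision and recall. The tiny arrays are hypothetical, purely for illustration:

```python
# Illustration only: the math behind the scaling and evaluation steps.
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import f1_score

# StandardScaler applies z = (x - mean) / std independently per feature.
X = np.array([[63.0], [37.0], [41.0], [56.0]])  # hypothetical 'age' column
scaler = StandardScaler().fit(X)
manual = (X - X.mean(axis=0)) / X.std(axis=0)   # same ddof=0 convention as sklearn
assert np.allclose(scaler.transform(X), manual)

# F1 = 2 * precision * recall / (precision + recall).
y_true = np.array([1, 0, 1, 1])
y_pred = np.array([1, 0, 0, 1])                 # TP=2, FP=0, FN=1
precision, recall = 2 / 2, 2 / 3
assert np.isclose(f1_score(y_true, y_pred), 2 * precision * recall / (precision + recall))
```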
4. Load Models and Data
```python
models, scaler, feature_columns, data, model_performance, feature_importances = load_model_and_data()
if models is None:
    st.stop()
```
- Calls load_model_and_data to get the necessary objects.
- Stops the app if models is None (e.g., if heart.csv is missing).
5. Feature Descriptions
```python
feature_descriptions = {
    'age': 'Age in years',
    'sex': 'Sex (1 = Male, 0 = Female)',
    'cp': 'Chest pain type (0 = Typical angina, 1 = Atypical angina, 2 = Non-anginal pain, 3 = Asymptomatic)',
    'trestbps': 'Resting blood pressure (mm Hg)',
    'chol': 'Serum cholesterol (mg/dl)',
    'fbs': 'Fasting blood sugar > 120 mg/dl (1 = True, 0 = False)',
    'restecg': 'Resting ECG results (0 = Normal, 1 = ST-T wave abnormality, 2 = Left ventricular hypertrophy)',
    'thalach': 'Maximum heart rate achieved',
    'exang': 'Exercise-induced angina (1 = Yes, 0 = No)',
    'oldpeak': 'ST depression induced by exercise relative to rest',
    'slope': 'Slope of the peak exercise ST segment (0 = Upsloping, 1 = Flat, 2 = Downsloping)',
    'ca': 'Number of major vessels (0-3) colored by fluoroscopy',
    'thal': 'Thalassemia (0 = Unknown, 1 = Fixed defect, 2 = Normal, 3 = Reversible defect)'
}
```
- Defines a dictionary mapping each feature to a human-readable description, used in the input form to guide users.
6. Sidebar for Model Selection and Performance
```python
st.sidebar.header("Model Selection & Performance")
model_choice = st.sidebar.selectbox("Choose Model", list(models.keys()), index=0)

st.sidebar.subheader("Model Performance")
for name, metrics in model_performance.items():
    st.sidebar.write(f"**{name}**:")
    st.sidebar.write(f"- F1 Score: {metrics['F1 Score']:.3f}")
    st.sidebar.write(f"- Accuracy: {metrics['Accuracy']:.3f}")
```
- Sidebar: Creates a sidebar for model selection and performance display.
- Model Selection: A dropdown lets users choose between Random Forest, Logistic Regression, or Naive Bayes (defaults to Random Forest).
- Performance Metrics: Displays F1 score and accuracy for each model, helping users understand model quality (e.g., Random Forest F1 ~0.885).
7. Correlation Heatmap
```python
st.subheader("Feature Correlation Heatmap")
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(data.corr(), annot=True, cmap='coolwarm', ax=ax)
st.pyplot(fig)
```
- Displays a heatmap of feature correlations using seaborn.
- Helps users see relationships between features (e.g., cp and target may have a strong correlation).
- Uses matplotlib for plotting and st.pyplot to render in Streamlit.
8. Feature Importance (Random Forest)
```python
if model_choice == 'Random Forest':
    st.subheader("Feature Importance (Random Forest)")
    fig, ax = plt.subplots(figsize=(10, 6))
    feature_importances.sort_values().plot(kind='barh', ax=ax)
    ax.set_title("Feature Importance")
    st.pyplot(fig)
```
- If the selected model is Random Forest, shows a horizontal bar plot of feature importances.
- Indicates which features (e.g., cp, thalach, ca) most influence predictions, enhancing interpretability.
9. Input Form
```python
st.subheader("Enter Patient Details")
with st.form(key='patient_form'):
    inputs = {}
    cols = st.columns(3)
    for i, feature in enumerate(feature_columns):
        with cols[i % 3]:
            if feature in ['sex', 'fbs', 'exang']:
                inputs[feature] = st.selectbox(f"{feature} ({feature_descriptions[feature]})", [0, 1], key=feature)
            elif feature in ['cp', 'restecg', 'slope', 'thal']:
                options = sorted(data[feature].unique())
                inputs[feature] = st.selectbox(f"{feature} ({feature_descriptions[feature]})", options, key=feature)
            elif feature == 'ca':
                inputs[feature] = st.selectbox(f"{feature} ({feature_descriptions[feature]})", [0, 1, 2, 3, 4], key=feature)
            else:
                min_val, max_val = data[feature].min(), data[feature].max()
                inputs[feature] = st.number_input(
                    f"{feature} ({feature_descriptions[feature]})",
                    min_value=float(min_val),
                    max_value=float(max_val),
                    value=float(data[feature].mean()),
                    step=1.0 if feature == 'age' else 0.1,
                    key=feature
                )
    submit_button = st.form_submit_button(label="Predict Heart Disease")
```
- Form: Creates a form for users to input patient data, organized in three columns for a clean layout.
- Input Types:
- Binary Features (sex, fbs, exang): Dropdowns with options [0, 1].
- Categorical Features (cp, restecg, slope, thal): Dropdowns with unique values from the dataset.
- Special Case (ca): Dropdown with options [0, 1, 2, 3, 4]. Although the description says 0-3, the widely used Kaggle copy of the dataset also contains ca = 4 (likely an artifact of originally missing values), so the extra option avoids rejecting real records.
- Numerical Features (age, trestbps, chol, thalach, oldpeak): Number inputs with min/max constraints from the dataset and default values set to the feature mean.
- Submit Button: Triggers prediction when clicked.
10. Prediction Processing
```python
if submit_button:
    input_data = np.array([inputs[feature] for feature in feature_columns]).reshape(1, -1)
    if np.any(np.isnan(input_data)):
        st.error("Please ensure all inputs are valid numbers.")
    else:
        input_data_scaled = scaler.transform(input_data)
        selected_model = models[model_choice]
        prediction = selected_model.predict(input_data_scaled)[0]
        probability = selected_model.predict_proba(input_data_scaled)[0][1] * 100

        st.subheader("Prediction Result")
        if prediction == 1:
            st.error(f"The model predicts a **high likelihood of heart disease** (Confidence: {probability:.2f}%).")
        else:
            st.success(f"The model predicts a **low likelihood of heart disease** (Confidence: {100 - probability:.2f}%).")

        st.subheader("Input Summary")
        input_df = pd.DataFrame([inputs], columns=feature_columns)
        st.table(input_df)

        st.markdown("""
        **Disclaimer**: This prediction is based on a machine learning model and is not a medical diagnosis.
        Please consult a healthcare professional for accurate diagnosis and treatment.
        """)
```
- Input Validation: Checks for invalid (NaN) inputs and displays an error if found.
- Scaling: Transforms user inputs using the trained StandardScaler.
- Prediction: Uses the selected model to predict (0 or 1) and compute the probability of heart disease (class 1).
- Result Display:
- If prediction == 1: Shows a red error message indicating high likelihood, with the predicted probability of disease as the confidence (e.g., 75%).
- If prediction == 0: Shows a green success message indicating low likelihood, with the complementary probability 100 − P(disease) as the confidence (e.g., 75% when the disease probability is 25%).
- Input Summary: Displays a table of user inputs for reference.
- Disclaimer: Clarifies that the prediction is not a medical diagnosis.
11. Footer
```python
st.markdown("---")
st.write("Built with Streamlit | Heart Disease Prediction System © 2025")
```
- Adds a footer with a separator and app credits.
Workflow of the Project
The workflow describes how the app processes data, trains models, and delivers predictions. Here’s the step-by-step flow:
- Initialization:
- The app starts, loading required libraries and setting up the Streamlit interface (title, description, wide layout).
- Data and Model Loading:
- The load_model_and_data function runs (cached to avoid repetition):
- Loads heart.csv or shows an error if missing.
- Drops any missing values (if present).
- Splits data into features (X) and target (y), then into train (80%) and test (20%) sets.
- Scales features using StandardScaler.
- Trains Random Forest, Logistic Regression, and Naive Bayes models.
- Computes F1 score and accuracy for each model on the test set.
- Saves feature importances for Random Forest.
- Saves models and scaler as .pkl files.
- Returns models, scaler, feature columns, dataset, performance metrics, and feature importances.
- User Interface Setup:
- Sidebar: Displays a model selection dropdown and performance metrics (F1 score, accuracy) for each model.
- Main Page:
- Shows a correlation heatmap of the dataset.
- If Random Forest is selected, shows a feature importance plot.
- Presents a form for entering patient details (13 features) with dropdowns and number inputs.
- User Interaction:
- User selects a model from the sidebar (e.g., Random Forest).
- User fills out the form with patient data (e.g., age=55, sex=1, cp=2).
- User clicks “Predict Heart Disease” to submit.
- Prediction Processing:
- The app validates inputs to ensure no invalid values (NaN).
- Scales the input data using the saved StandardScaler.
- Uses the selected model to predict (0 = no heart disease, 1 = heart disease) and compute the probability of heart disease.
- Displays the result (high/low likelihood with confidence), input summary, and a medical disclaimer.
- Visualization and Interpretability:
- The correlation heatmap shows feature relationships (e.g., thalach vs. target).
- The feature importance plot (for Random Forest) highlights key predictors (e.g., cp, ca).
- Performance metrics in the sidebar provide transparency about model quality.
- Output:
- User sees a clear prediction (e.g., “High likelihood of heart disease, Confidence: 75%”).
- Input summary table confirms entered values.
- Disclaimer emphasizes consulting a doctor.
Key Features and Their Purpose
- Model Selection: Allows users to choose between Random Forest (best F1 score), Logistic Regression, or Naive Bayes, reflecting the notebook’s analysis.
- Performance Metrics: Shows F1 score and accuracy to build trust in model quality (e.g., Random Forest F1 ~0.885).
- Input Form: Simplifies data entry with constrained inputs (e.g., ca limited to 0-4) and descriptive labels.
- Visualizations:
- Correlation Heatmap: Helps users understand feature relationships.
- Feature Importance: Shows which features drive Random Forest predictions.
- Disclaimer: Ensures users don’t treat predictions as medical diagnoses.
- Caching: Improves performance by avoiding repeated model training.
How to Run the App
- Prepare the Dataset:
- Download heart.csv from the UCI Machine Learning Repository or Kaggle.
- Place it in the same directory as app.py.
- Install Dependencies:
```bash
pip install streamlit pandas scikit-learn numpy matplotlib seaborn
```
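Optionally, the same dependencies can live in a requirements.txt so the environment is reproducible (versions are left unpinned here, since the original does not specify any):

```text
streamlit
pandas
scikit-learn
numpy
matplotlib
seaborn
```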
- Run the App:
```bash
streamlit run app.py
```
- Open the provided URL (e.g., http://localhost:8501) in a browser.
- Interact:
- Select a model (e.g., Random Forest).
- View performance metrics in the sidebar.
- Enter patient data in the form.
- Submit to see the prediction, input summary, and visualizations.
Potential Improvements
- Preprocessing: If the notebook includes specific preprocessing (e.g., outlier removal, feature engineering), incorporate those steps.
- Additional Models: Add KNN and Decision Tree (excluded due to lower F1 scores) if desired.
- Visualizations: Include a confusion matrix or ROC curve for deeper model insights (see the first sketch after this list).
- Input Validation: Add stricter checks (e.g., realistic ranges for thalach based on age).
- Pretrained Models: Load pretrained models instead of training on each run for faster startup in production (see the second sketch after this list).
- Data Exploration: Add a section to view dataset statistics or sample records.
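
To make two of these improvements concrete, here are hedged sketches rather than definitive implementations. First, the confusion matrix / ROC curve idea: this assumes app.py’s existing imports (st, plt) and that load_model_and_data is extended to also return the held-out X_test_scaled and y_test; selected_model is the sidebar choice.

```python
# Sketch only: assumes X_test_scaled / y_test are returned by an extended
# load_model_and_data(), and selected_model is the model chosen in the sidebar.
from sklearn.metrics import ConfusionMatrixDisplay, RocCurveDisplay

st.subheader("Model Diagnostics")
fig, (ax_cm, ax_roc) = plt.subplots(1, 2, figsize=(12, 5))
ConfusionMatrixDisplay.from_estimator(selected_model, X_test_scaled, y_test, ax=ax_cm)
RocCurveDisplay.from_estimator(selected_model, X_test_scaled, y_test, ax=ax_roc)
st.pyplot(fig)
```

Second, the pretrained-models idea: reload the .pkl files the app already writes instead of retraining on every cold start, falling back to training when the files are absent.

```python
# Sketch only: loads the pickles written by load_model_and_data(); returns
# (None, None) when any file is missing so the caller can retrain instead.
import os
import pickle

@st.cache_resource
def load_pretrained_models():
    names = ['Random Forest', 'Logistic Regression', 'Naive Bayes']
    paths = {name: f'{name.lower().replace(" ", "_")}_model.pkl' for name in names}
    if not all(os.path.exists(p) for p in paths.values()) or not os.path.exists('scaler.pkl'):
        return None, None
    models = {}
    for name, path in paths.items():
        with open(path, 'rb') as f:
            models[name] = pickle.load(f)
    with open('scaler.pkl', 'rb') as f:
        scaler = pickle.load(f)
    return models, scaler
```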
Alignment with heart_Disease.ipynb
The app aligns with the notebook’s findings:
- Random Forest as Best Model: Used as the default model (F1 = 0.885).
- Multiple Models: Supports Logistic Regression and Naive Bayes, reflecting the notebook’s comparisons.
- Dataset: Uses heart.csv with the same features and target.
- Evaluation: Displays F1 scores, matching the notebook’s focus on F1 as the key metric.
If you have the full notebook or specific preprocessing/training details, sharing them would allow further refinement (e.g., exact train-test split, hyperparameter tuning).
Summary
The app is a robust, user-friendly tool for predicting heart disease using machine learning. It loads and preprocesses the UCI Heart Disease dataset, trains multiple models, and provides an interactive interface for predictions. The workflow involves data loading, model training, user input, prediction, and visualization, with a focus on interpretability and transparency. The code is modular, efficient (with caching), and aligned with the notebook’s findings, making it a practical implementation of a machine learning-based health prediction system.