This notebook demonstrates Generalized Linear Models (GLMs) with Poisson and logistic regression examples.
Notebook Contents
This notebook covers Generalized Linear Models with practical examples:
Poisson Regression: For modeling count data
Logistic Regression: For binary classification
Link Functions: How the log link (Poisson) and the logit link (logistic) connect the linear predictor to the mean response
Model Evaluation: Diagnostic plots for the Poisson fit, plus a confusion matrix, classification report, and odds ratios for the logistic fit
Use the buttons above to download the notebook or open it in your preferred environment.
📓 Notebook Preview
In [20]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import statsmodels.api as sm
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    mean_absolute_percentage_error,
)
from sklearn.model_selection import train_test_split
from statsmodels.genmod.families import Binomial, Poisson

# Set random seed for reproducibility (legacy global NumPy RNG; used by
# np.random.randn / np.random.poisson below)
np.random.seed(42)

# ============================================================================
# EXAMPLE 1: POISSON REGRESSION
# ============================================================================
# Use case: Predicting count data (e.g., number of customer visits per day)
print("=" * 70)
print("POISSON REGRESSION EXAMPLE")
print("=" * 70)

# Generate synthetic data: n_samples rows, two standard-normal covariates
n_samples = 500
X_pois = np.random.randn(n_samples, 2)

# True slope coefficients; the true intercept (2) is added in the linear predictor
beta_pois = np.array([0.5, -0.3])

# Linear predictor: eta = 2 + X @ beta
eta = 2 + X_pois @ beta_pois

# Inverse of the log link: expected count lambda = exp(eta) (always > 0)
lambda_true = np.exp(eta)

# Draw one Poisson-distributed count per row with mean lambda_true
y_pois = np.random.poisson(lambda_true)

# Collect covariates and response into a labelled DataFrame
df_pois = pd.DataFrame({
    'advertising_spend': X_pois[:, 0],
    'competitor_activity': X_pois[:, 1],
    'customer_visits': y_pois,
})

print("\nFirst few rows of Poisson data:")
print(df_pois.head())
print("\nBasic statistics of customer visits:")
print(df_pois['customer_visits'].describe())
======================================================================
POISSON REGRESSION EXAMPLE
======================================================================
First few rows of Poisson data:
advertising_spend competitor_activity customer_visits
0 0.496714 -0.138264 6
1 0.647689 1.523030 11
2 -0.234153 -0.234137 5
3 1.579213 0.767435 9
4 -0.469474 0.542560 5
Basic statistics of customer visits:
count 500.000000
mean 8.490000
std 5.954899
min 0.000000
25% 5.000000
50% 7.000000
75% 11.000000
max 42.000000
Name: customer_visits, dtype: float64
In [4]:
# Histogram of the simulated daily visit counts (response of the Poisson model).
# The trailing ';' suppresses the '<Axes: >' repr in the cell output.
ax = df_pois['customer_visits'].hist()
ax.set_title('Distribution of customer visits')
ax.set_xlabel('Customer visits per day')
ax.set_ylabel('Frequency');
Out[4]:
<Axes: >
In [19]:
# Fit Poisson GLM. The Poisson family (log link by default) matches how y_pois was
# generated; a Gaussian family would fit an ordinary linear model instead.
X_pois_model = sm.add_constant(X_pois)
poisson_model = sm.GLM(y_pois, X_pois_model, family=Poisson())
poisson_results = poisson_model.fit()

print("\n" + "=" * 70)
print("POISSON MODEL SUMMARY")
print("=" * 70)
print(poisson_results.summary())

# Predictions on the response (expected count) scale
y_pred_pois = poisson_results.predict(X_pois_model)

# Visualization
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Plot 1: Actual vs Predicted, with the y = x reference line
axes[0].scatter(y_pois, y_pred_pois, alpha=0.5)
axes[0].plot([y_pois.min(), y_pois.max()], [y_pois.min(), y_pois.max()], 'r--', lw=2)
axes[0].set_xlabel('Actual Count')
axes[0].set_ylabel('Predicted Count')
axes[0].set_title('Poisson GLM: Actual vs Predicted')
axes[0].grid(True, alpha=0.3)

# Plot 2: Pearson residuals (y - mu) / sqrt(mu). For Poisson data the variance
# equals the mean, so raw residuals fan out as predictions grow; standardizing
# makes the spread comparable across fitted values.
residuals = poisson_results.resid_pearson
axes[1].scatter(y_pred_pois, residuals, alpha=0.5)
axes[1].axhline(y=0, color='r', linestyle='--', lw=2)
axes[1].set_xlabel('Predicted Count')
axes[1].set_ylabel('Pearson residuals')
axes[1].set_title('Poisson GLM: Residual Plot')
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('poisson_glm.png', dpi=100, bbox_inches='tight')
print("\nPoisson GLM plots saved as 'poisson_glm.png'")
# ============================================================================
# EXAMPLE 2: LOGISTIC REGRESSION
# ============================================================================
# Use case: binary classification (customer churn vs. no churn).
# X_log / y_log are generated here so this cell does not depend on state from
# any other cell (it must survive Restart Kernel -> Run All).
rng_log = np.random.default_rng(42)
n_log = 1000
X_log = rng_log.standard_normal((n_log, 3))

# True intercept and slopes on the logit scale
beta_log = np.array([1.0, -0.8, 0.5])
eta_log = -0.5 + X_log @ beta_log

# Inverse logit gives the churn probability; draw one Bernoulli label per row
p_churn = 1.0 / (1.0 + np.exp(-eta_log))
y_log = rng_log.binomial(1, p_churn)

# Split data
X_train, X_test, y_train, y_test = train_test_split(
    X_log, y_log, test_size=0.3, random_state=42
)

# Fit Logistic GLM (Binomial family, logit link by default)
X_train_const = sm.add_constant(X_train)
X_test_const = sm.add_constant(X_test)
logistic_model = sm.GLM(y_train, X_train_const, family=Binomial())
logistic_results = logistic_model.fit()

print("\n" + "=" * 70)
print("LOGISTIC MODEL SUMMARY")
print("=" * 70)
print(logistic_results.summary())

# Predictions: probabilities, then hard labels at a 0.5 threshold
y_pred_prob = logistic_results.predict(X_test_const)
y_pred_class = (y_pred_prob > 0.5).astype(int)

# Model evaluation
print("\n" + "=" * 70)
print("LOGISTIC MODEL EVALUATION")
print("=" * 70)
print("\nConfusion Matrix:")
cm = confusion_matrix(y_test, y_pred_class)
print(cm)
print("\nClassification Report:")
print(classification_report(y_test, y_pred_class))

# Odds ratios: exponentiated logit coefficients. Labels are derived from the
# number of columns so they always line up with the fitted parameters.
odds_ratios = np.exp(logistic_results.params)
param_names = ['Intercept'] + [f'Feature {i}' for i in range(1, X_log.shape[1] + 1)]
print("\nOdds Ratios:")
for name, or_val in zip(param_names, odds_ratios):
    print(f"{name}: {or_val:.3f}")

# Visualization
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Plot 1: test-set predicted probabilities in ascending order (not an ROC curve)
sorted_idx = np.argsort(y_pred_prob)
axes[0].plot(y_pred_prob[sorted_idx], marker='o', markersize=2, alpha=0.6)
axes[0].axhline(y=0.5, color='r', linestyle='--', lw=2, label='Threshold=0.5')
axes[0].set_xlabel('Sample (sorted by predicted probability)')
axes[0].set_ylabel('Predicted Probability')
axes[0].set_title('Logistic GLM: Predicted Probabilities')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Plot 2: Probability distribution by actual class
axes[1].hist(y_pred_prob[y_test == 0], bins=30, alpha=0.6, label='Actual: No Churn', color='blue')
axes[1].hist(y_pred_prob[y_test == 1], bins=30, alpha=0.6, label='Actual: Churn', color='red')
axes[1].axvline(x=0.5, color='black', linestyle='--', lw=2, label='Threshold=0.5')
axes[1].set_xlabel('Predicted Probability')
axes[1].set_ylabel('Frequency')
axes[1].set_title('Logistic GLM: Probability Distribution by Class')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('logistic_glm.png', dpi=100, bbox_inches='tight')
print("\nLogistic GLM plots saved as 'logistic_glm.png'")

print("\n" + "=" * 70)
print("ANALYSIS COMPLETE")
print("=" * 70)