Skill

Implement Explainable AI Frameworks

Framework for implementing SHAP, LIME, permutation importance, partial dependence, and counterfactual explanations.

Works with githubtorch

91
Spark score
out of 100
Updated 4 months ago
Version 1.0.0
Models

Add to Favorites

Why it matters

Leverage advanced Explainable AI (XAI) techniques to build transparent and interpretable machine learning systems. Understand model behavior, debug complex decisions, and generate human-readable explanations for AI outputs.

Outcomes

What it gets done

01

Implement SHAP and LIME for local and global model explanations.

02

Utilize deep learning XAI methods like Integrated Gradients and GradientSHAP.

03

Generate counterfactual explanations to understand decision boundaries.

04

Create comprehensive XAI pipelines for various machine learning models.

Install

Add it to your toolbox

Run in your project directory:

curl -fsSL https://spark.entire.vc/get/vb-explainable-ai-framework | bash

Capabilities

What this skill does

Extract

Pulls structured data fields from unstructured text.

Summarize

Condenses long documents or threads into key takeaways.

Generate code

Writes source code or scripts from a description.

Debug

Traces errors to their root cause and suggests fixes.

Overview

Explainable AI Framework

What it does

This framework provides implementations for various Explainable AI (XAI) techniques. It includes methods for generating global explanations such as permutation importance and partial dependence plot data. It also supports local explanations through SHAP and LIME, and can generate counterfactual explanations. For deep learning models, it offers implementations for Integrated Gradients, GradientSHAP, and occlusion analysis.

How it connects

Use this framework when you need to implement specific XAI methods like SHAP, LIME, or deep learning attribution techniques.

Source README

You are an expert in Explainable AI (XAI) frameworks, specializing in designing interpretable machine learning systems, implementing post-hoc explanation methods, and creating comprehensive transparency solutions for AI models across various domains.

Core XAI Principles

Interpretability vs Explainability

  • Interpretability: Degree to which humans can understand model decisions without additional explanation
  • Explainability: Ability to provide human-understandable reasons for model outputs
  • Global explanations: Overall model behavior patterns
  • Local explanations: Individual prediction explanations
  • Counterfactual explanations: "What would need to change for a different outcome?"

XAI Taxonomy

  • Model-agnostic: Works with any ML model (LIME, SHAP, permutation importance)
  • Model-specific: Designed for particular architectures (attention maps, gradient-based)
  • Ante-hoc: Built-in interpretability (linear models, decision trees)
  • Post-hoc: Applied after training (feature importance, example-based)

Implementation Framework

SHAP (SHapley Additive exPlanations)

import shap
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

class SHAPExplainer:
    def __init__(self, model, background_data=None):
        self.model = model
        self.background_data = background_data
        self.explainer = None
    
    def initialize_explainer(self, explainer_type='tree'):
        """Initialize appropriate SHAP explainer"""
        if explainer_type == 'tree':
            self.explainer = shap.TreeExplainer(self.model)
        elif explainer_type == 'kernel':
            self.explainer = shap.KernelExplainer(
                self.model.predict_proba, self.background_data
            )
        elif explainer_type == 'linear':
            self.explainer = shap.LinearExplainer(
                self.model, self.background_data
            )
    
    def explain_instance(self, instance):
        """Generate SHAP values for single instance"""
        shap_values = self.explainer.shap_values(instance)
        return shap_values
    
    def generate_explanation_report(self, X_test, feature_names):
        """Comprehensive explanation report"""
        shap_values = self.explainer.shap_values(X_test)
        
        # Global feature importance
        global_importance = pd.DataFrame({
            'feature': feature_names,
            'importance': abs(shap_values).mean(0)
        }).sort_values('importance', ascending=False)
        
        return {
            'shap_values': shap_values,
            'global_importance': global_importance,
            'base_value': self.explainer.expected_value
        }

LIME Implementation

import lime
import lime.lime_tabular
import numpy as np

class LIMEExplainer:
    def __init__(self, training_data, feature_names, class_names, mode='classification'):
        self.explainer = lime.lime_tabular.LimeTabularExplainer(
            training_data,
            feature_names=feature_names,
            class_names=class_names,
            mode=mode,
            discretize_continuous=True
        )
    
    def explain_instance(self, instance, predict_fn, num_features=10):
        """Generate LIME explanation for instance"""
        explanation = self.explainer.explain_instance(
            instance, 
            predict_fn,
            num_features=num_features,
            num_samples=5000
        )
        return explanation
    
    def extract_feature_importance(self, explanation):
        """Extract and rank feature importance"""
        importance_scores = explanation.as_list()
        return sorted(importance_scores, key=lambda x: abs(x[1]), reverse=True)

Comprehensive XAI Pipeline

Multi-Method Explanation Framework

class XAIFramework:
    def __init__(self, model, X_train, feature_names, class_names=None):
        self.model = model
        self.X_train = X_train
        self.feature_names = feature_names
        self.class_names = class_names
        self.explanations = {}
    
    def generate_global_explanations(self):
        """Generate model-wide explanations"""
        # Permutation importance
        from sklearn.inspection import permutation_importance
        perm_importance = permutation_importance(
            self.model, self.X_train, y_train, n_repeats=10
        )
        
        # Partial dependence plots data
        from sklearn.inspection import partial_dependence
        pdp_results = {}
        for i, feature in enumerate(self.feature_names):
            pd_result = partial_dependence(
                self.model, self.X_train, [i]
            )
            pdp_results[feature] = pd_result
        
        self.explanations['global'] = {
            'permutation_importance': perm_importance,
            'partial_dependence': pdp_results
        }
    
    def generate_local_explanations(self, instance):
        """Generate instance-specific explanations"""
        # SHAP
        shap_explainer = SHAPExplainer(self.model)
        shap_explainer.initialize_explainer()
        shap_values = shap_explainer.explain_instance(instance)
        
        # LIME
        lime_explainer = LIMEExplainer(
            self.X_train, self.feature_names, self.class_names
        )
        lime_explanation = lime_explainer.explain_instance(
            instance, self.model.predict_proba
        )
        
        # Counterfactual explanation
        counterfactual = self.generate_counterfactual(instance)
        
        return {
            'shap': shap_values,
            'lime': lime_explanation,
            'counterfactual': counterfactual
        }
    
    def generate_counterfactual(self, instance, target_class=None):
        """Generate counterfactual explanation"""
        # Simplified counterfactual generation
        current_pred = self.model.predict([instance])[0]
        if target_class is None:
            target_class = 1 - current_pred  # Binary flip
        
        # Feature perturbation approach
        best_counterfactual = None
        min_changes = float('inf')
        
        for feature_idx in range(len(instance)):
            modified_instance = instance.copy()
            # Try different perturbations
            for delta in [-0.1, -0.5, 0.1, 0.5]:
                modified_instance[feature_idx] = instance[feature_idx] + delta
                if self.model.predict([modified_instance])[0] == target_class:
                    changes = np.sum(np.abs(modified_instance - instance))
                    if changes < min_changes:
                        min_changes = changes
                        best_counterfactual = modified_instance.copy()
        
        return best_counterfactual

Model-Specific Techniques

Deep Learning Explanations

import torch
import torch.nn as nn
from captum.attr import IntegratedGradients, GradientShap, Occlusion

class DeepLearningXAI:
    def __init__(self, model):
        self.model = model
        self.model.eval()
    
    def integrated_gradients_explanation(self, input_tensor, target_class):
        """Generate Integrated Gradients attribution"""
        ig = IntegratedGradients(self.model)
        attribution = ig.attribute(input_tensor, target=target_class)
        return attribution
    
    def gradient_shap_explanation(self, input_tensor, baseline_tensor, target_class):
        """Generate GradientSHAP attribution"""
        gradient_shap = GradientShap(self.model)
        attribution = gradient_shap.attribute(
            input_tensor, baseline_tensor, target=target_class
        )
        return attribution
    
    def occlusion_analysis(self, input_tensor, target_class, window_size=(3, 3)):
        """Generate occlusion-based explanations"""
        occlusion = Occlusion(self.model)
        attribution = occlusion.attribute(
            input_tensor,
            sliding_window_shapes=window_size,
            target=target_class
        )
        return attribution

Evaluation and Validation

XAI Metrics Framework

class XAIEvaluator:
    def __init__(self, model, explainer):
        self.model = model
        self.explainer = explainer
    
    def faithfulness_score(self, X_test, explanations, k=5):
        """Evaluate explanation faithfulness"""
        faithfulness_scores = []
        
        for i, (instance, explanation) in enumerate(zip(X_test, explanations)):
            # Get top-k important features
            top_features = np.argsort(np.abs(explanation))[-k:]
            
            # Remove top features and measure prediction change
            modified_instance = instance.copy()
            modified_instance[top_features] = 0
            
            original_pred = self.model.predict_proba([instance])[0]
            modified_pred = self.model.predict_proba([modified_instance])[0]
            
            faithfulness = np.linalg.norm(original_pred - modified_pred)
            faithfulness_scores.append(faithfulness)
        
        return np.mean(faithfulness_scores)
    
    def stability_score(self, instance, num_perturbations=10, noise_level=0.1):
        """Evaluate explanation stability"""
        explanations = []
        
        for _ in range(num_perturbations):
            # Add small noise to instance
            noisy_instance = instance + np.random.normal(0, noise_level, instance.shape)
            explanation = self.explainer.explain_instance(noisy_instance)
            explanations.append(explanation)
        
        # Calculate pairwise correlations
        correlations = []
        for i in range(len(explanations)):
            for j in range(i+1, len(explanations)):
                corr = np.corrcoef(explanations[i], explanations[j])[0, 1]
                correlations.append(corr)
        
        return np.mean(correlations)

Best Practices and Guidelines

Implementation Standards

  • Multi-method validation: Use multiple explanation techniques to validate findings
  • Domain expertise integration: Collaborate with domain experts to validate explanations
  • Stakeholder-appropriate explanations: Tailor explanation complexity to audience
  • Continuous monitoring: Track explanation quality and model behavior over time
  • Bias detection: Use explanations to identify potential model biases

Deployment Considerations

  • Computational efficiency: Balance explanation quality with inference speed
  • Explanation caching: Cache explanations for frequently queried instances
  • Version control: Track explanation methods alongside model versions
  • Regulatory compliance: Ensure explanations meet domain-specific requirements
  • User interface design: Present explanations in intuitive, actionable formats

Common Pitfalls

  • Over-reliance on single methods: Different techniques may give conflicting explanations
  • Explanation cherry-picking: Avoid selecting only favorable explanations
  • Ignoring uncertainty: Acknowledge limitations and confidence bounds
  • Static explanations: Update explanations as models and data evolve
  • Technical jargon: Translate technical explanations for non-technical stakeholders

Discussion

Questions & comments · 0

Sign In Sign in to leave a comment.