# Source code for exploralytics.visualize.visualizer

import plotly.graph_objects as go
import pandas as pd
import numpy as np
import plotly.figure_factory as ff
import plotly.express as px
from plotly.subplots import make_subplots

from .utils import (
    identify_num_rows,
    highlight_bars_colors,
    add_footer
)

class Visualizer:
    """
    Create and customize data visualizations using Plotly.

    Every plotting method returns a ``plotly.graph_objects.Figure`` styled
    with the colors, dimensions, template and fonts stored on the instance.

    Attributes
    ----------
    color : str
        Color code for plot elements (default: "#94C973")
    height : int
        Height of the plot in pixels (default: 768)
    width : int
        Width of the plot in pixels (default: 1366)
    template : str
        Plotly template name (default: "simple_white")
    colorscale : list
        Colorscale for the correlation heatmap
        (default: px.colors.diverging.Earth)
    font_family : str or None
        Font family applied to text elements; populated from the
        ``texts_font_style`` constructor argument (default: None)
    title_bold : bool
        Whether plot titles are rendered in bold (default: False)
    """

    def __init__(
        self,
        color: str = "#94C973",
        height: int = 768,
        width: int = 1366,
        template: str = "simple_white",
        colorscale=px.colors.diverging.Earth,
        texts_font_style: str = None,
        title_bold: bool = False,
    ):
        """
        Store the styling defaults shared by every plot.

        Parameters
        ----------
        color : str, optional
            Default color for plot elements
        height : int, optional
            Default figure height in pixels
        width : int, optional
            Default figure width in pixels
        template : str, optional
            Plotly template name
        colorscale : list, optional
            Colorscale used by the correlation heatmap
        texts_font_style : str, optional
            Font family for text elements; stored as ``self.font_family``
        title_bold : bool, optional
            Whether titles are wrapped in ``<b>`` tags
        """
        self.color = color
        self.height = height
        self.width = width
        self.template = template
        self.colorscale = colorscale
        # The constructor argument and the attribute are named differently.
        self.font_family = texts_font_style
        self.title_bold = title_bold

    def _format_title(self, title: str, subtitle: str) -> str:
        """Return the HTML title string: optional bold title plus subtitle."""
        heading = f"<b>{title}</b>" if self.title_bold else title
        return f"{heading}<br><sup>{subtitle}</sup>"

    def _pick_marker_colors(self, highlight_top_n, highlight_low_n, n_items: int):
        """
        Return marker colors for ``n_items`` marks.

        Delegates to ``highlight_bars_colors`` when either highlight option
        is truthy; otherwise returns the instance's single default color.
        """
        if highlight_top_n or highlight_low_n:
            return highlight_bars_colors(highlight_top_n, highlight_low_n, n_items)
        return self.color

    def _get_base_layout_config(self, title: str, subtitle: str) -> dict:
        """
        Create a base layout configuration for Plotly figures.

        Parameters
        ----------
        title : str
            Title of the plot
        subtitle : str
            Subtitle of the plot

        Returns
        -------
        dict
            Keyword arguments suitable for ``fig.update_layout(**config)``
        """
        return dict(
            # Title configuration
            title_text=self._format_title(title, subtitle),
            title_x=0.5,  # Center title horizontally

            # General figure settings
            showlegend=False,
            height=self.height,
            width=self.width,
            template=self.template,

            # Hover label settings
            hoverlabel_font_family=self.font_family,

            # X-axis settings
            xaxis_title_font_family=self.font_family,
            xaxis_tickfont_family=self.font_family,
            xaxis_ticks="",  # Remove x-axis tick marks
            xaxis_showline=False,  # Remove x-axis line
            xaxis_ticklabelstandoff=10,  # Space between tick labels and axis
            xaxis_zeroline=False,  # Hide x-axis zero line

            # Y-axis settings
            yaxis_title_font_family=self.font_family,
            yaxis_tickfont_family=self.font_family,
            yaxis_ticks="",  # Remove y-axis tick marks
            yaxis_showline=False,  # Remove y-axis line
            yaxis_ticklabelstandoff=10,  # Space between tick labels and axis
            yaxis_zeroline=False,  # Hide y-axis zero line
        )

    def plot_histograms(
        self,
        df: pd.DataFrame,
        specific_cols: list[str] = None,
        num_cols: int = 1,
        title: str = 'How distributed the numerical values are?',
        subtitle: str = 'Histogram of each column with numerical data type',
        footer: str = None,
        show_mean: bool = False,
        show_median: bool = False,
    ) -> go.Figure:
        """
        Create multiple histogram subplots for numerical columns in a dataframe.

        Creates a grid of histograms showing the distribution of numerical
        data. Can optionally show mean and median lines on each histogram.

        Parameters
        ----------
        df : pandas.DataFrame
            Input dataframe containing the data to plot
        specific_cols : list, optional
            Columns to plot; non-numerical entries are skipped. When omitted
            or empty, all numerical columns are plotted.
        num_cols : int, optional
            Number of columns in the subplot grid (default: 1)
        title : str, optional
            Main title of the plot
        subtitle : str, optional
            Subtitle of the plot
        footer : str, optional
            Text to show at bottom of plot
        show_mean : bool, optional
            Whether to draw a solid grey line at each column's mean
            (default: False)
        show_median : bool, optional
            Whether to draw a dashed grey line at each column's median
            (default: False)

        Returns
        -------
        plotly.graph_objects.Figure
            Figure containing the histogram subplots
        """
        # Both paths use the same np.number filter so an explicitly listed
        # int32/float32 column is treated like it is when auto-detected.
        candidates = df[list(specific_cols)] if specific_cols else df
        numerical_columns = candidates.select_dtypes(include=np.number).columns.tolist()

        # Number of grid rows needed for the selected columns
        num_rows = identify_num_rows(numerical_columns, desired_num_col=num_cols)

        fig = make_subplots(
            rows=num_rows,
            cols=num_cols,
            subplot_titles=numerical_columns,
        )

        for index, col_name in enumerate(numerical_columns):
            # Fill the grid row by row; plotly subplot indices are 1-based.
            row_offset, col_offset = divmod(index, num_cols)
            row, col = row_offset + 1, col_offset + 1

            fig.add_trace(
                go.Histogram(
                    x=df[col_name],
                    autobinx=True,
                    marker_color=self.color,
                    name=col_name,
                    hovertemplate="Interval: %{x}<br>Count: %{y}<extra></extra>"
                ),
                row=row,
                col=col,
            )

            if show_mean:
                fig.add_vline(
                    x=df[col_name].mean(),
                    line_color="grey",
                    line_dash="solid",
                    line_width=2,
                    row=row,
                    col=col,
                )

            if show_median:
                fig.add_vline(
                    x=df[col_name].median(),
                    line_color="grey",
                    line_dash="dash",
                    line_width=2,
                    row=row,
                    col=col,
                )

        # Only a subset of the base layout is applied here: the axis keys
        # below address the first subplot's x-axis only.
        fig.update_layout(
            title_text=self._format_title(title, subtitle),
            title_x=0.5,  # Center title horizontally
            showlegend=False,
            height=self.height,
            width=self.width,
            template=self.template,
            hoverlabel_font_family=self.font_family,
            xaxis_title_font_family=self.font_family,
            xaxis_tickfont_family=self.font_family,
        )

        if footer is not None:
            add_footer(fig, footer, font_family=self.font_family)

        return fig

    def plot_correlation_map(
        self,
        df: pd.DataFrame,
        title: str = 'How correlated the numerical values are?',
        subtitle: str = 'Correlation matrix of columns with numerical data type',
        footer: str = None
    ) -> go.Figure:
        """
        Create a correlation matrix heatmap of the numerical columns.

        Only the strict lower triangle is shown; the diagonal and upper
        triangle are masked out to avoid redundancy. Each visible cell is
        annotated with its correlation rounded to two decimals and has
        custom hover text.

        Parameters
        ----------
        df : pandas.DataFrame
            Input dataframe containing the numerical columns to correlate
        title : str, optional
            Main title for the plot
            (default: 'How correlated the numerical values are?')
        subtitle : str, optional
            Subtitle to display below main title (default: 'Correlation
            matrix of columns with numerical data type')
        footer : str, optional
            Text to show at bottom of plot

        Returns
        -------
        plotly.graph_objects.Figure
            Figure containing the correlation heatmap
        """
        # np.number matches the column selection used by plot_histograms.
        corr = df.select_dtypes(include=np.number).corr().round(2)

        # True on and above the diagonal; those cells become NaN.
        mask = np.triu(np.ones_like(corr, dtype=bool))
        df_mask = corr.mask(mask)

        # One hover string per cell; masked cells get an empty string.
        hovertext = [
            [
                f"{col1} vs {col2}<br>Correlation: {val:.2f}" if not pd.isna(val) else ""
                for col2, val in zip(df_mask.columns, row)
            ]
            for col1, row in zip(df_mask.index, df_mask.values)
        ]

        labels = df_mask.columns.tolist()
        fig = ff.create_annotated_heatmap(
            z=df_mask.to_numpy(),
            x=labels,
            y=labels,
            colorscale=self.colorscale,
            showscale=True,
            ygap=1,
            xgap=1,
            hoverongaps=False,  # No hover on masked cells
            hoverinfo='text',
            text=hovertext
        )

        # Format the cell annotations while they are the only annotations on
        # the figure. NOTE(review): add_footer presumably adds its own
        # annotation; formatting first ensures a footer such as "2023" or
        # "nan" is never rewritten as if it were a cell value.
        for annotation in fig.layout.annotations:
            if annotation.text == 'nan':
                annotation.text = ""
                continue
            try:
                value = float(annotation.text)
            except ValueError:
                continue  # Leave non-numeric text untouched
            annotation.text = f"{value:.2f}"

        if footer is not None:
            add_footer(fig, footer, font_family=self.font_family)

        # Move x-axis labels to bottom of plot
        fig.update_xaxes(side="bottom")

        layout_config = self._get_base_layout_config(title, subtitle)
        layout_config.update({
            'xaxis_showgrid': False,
            'yaxis_showgrid': False,
            'yaxis_autorange': 'reversed',  # First column at the top
            'xaxis_tickangle': 90,
        })
        fig.update_layout(**layout_config)

        return fig

    def plot_correlation_with_target(
        self,
        df: pd.DataFrame,
        target_column: str,
        title: str = 'How correlated the features are with the target?',
        subtitle: str = 'Correlation coefficient of each feature',
        footer: str = None
    ) -> go.Figure:
        """
        Create a horizontal bar chart of each feature's correlation with a target.

        Bars are sorted by correlation value. Negative correlations are
        drawn in light red ('#FF9999'); all others in blue ('#2E75B6').
        The figure height grows with the number of features and therefore
        overrides the instance height.

        Parameters
        ----------
        df : pandas.DataFrame
            Input dataframe containing numerical columns to analyze
        target_column : str
            Name of the numerical column to compare other features against
        title : str, optional
            Main title for the plot
            (default: 'How correlated the features are with the target?')
        subtitle : str, optional
            Subtitle to display below main title
            (default: 'Correlation coefficient of each feature')
        footer : str, optional
            Text to show at bottom of plot

        Returns
        -------
        plotly.graph_objects.Figure
            Interactive bar chart showing correlation coefficients

        Raises
        ------
        KeyError
            If ``target_column`` is not among the numerical columns of ``df``
        """
        # np.number matches the column selection used by plot_histograms.
        numeric_df = df.select_dtypes(include=np.number)
        if target_column not in numeric_df.columns:
            raise KeyError(
                f"target_column {target_column!r} is not a numerical column of df"
            )

        # Sorted correlations with the target, minus the target's own entry.
        correlations = (
            numeric_df.corr()[target_column]
            .sort_values()
            .drop(target_column)
        )

        colors = ['#FF9999' if c < 0 else '#2E75B6' for c in correlations]

        fig = go.Figure()
        fig.add_trace(
            go.Bar(
                y=correlations.index,
                x=correlations.values,
                orientation='h',
                marker_color=colors,
                text=[f'{x:.2f}' for x in correlations.values],
                textposition='outside',
            )
        )

        if footer is not None:
            add_footer(fig, footer, font_family=self.font_family)

        layout_config = self._get_base_layout_config(title, subtitle)
        layout_config.update({
            'height': max(400, len(correlations) * 30),  # Dynamic height
            'xaxis_title': 'Correlation Coefficient',
            'xaxis_gridwidth': 1,
            'xaxis_tickformat': '.2f',
            'yaxis_title': 'Features',
            'yaxis_autorange': "reversed"
        })
        fig.update_layout(**layout_config)

        return fig

    def plot_hbar(
        self,
        df: pd.DataFrame,
        x_col: str,
        y_col: str = None,
        title: str = "How distributed are the categories?",
        subtitle: str = "Horizontal bar plot of categories",
        footer: str = None,
        add_hline: bool = False,
        top_n: int = None,
        highlight_top_n: tuple[int, str] = None,  # (n, hex_color)
        highlight_low_n: tuple[int, str] = None  # (n, hex_color)
    ) -> go.Figure:
        """
        Create a horizontal bar plot with optional highlighting and mean line.

        Shows either the value counts of a single column or the values of
        ``y_col`` per ``x_col`` entry, largest first.

        Parameters
        ----------
        df : pandas.DataFrame
            The data to plot
        x_col : str
            Name of column whose values label the bars (plot y-axis)
        y_col : str, optional
            Name of column for bar lengths. If None, shows value counts
            of ``x_col``
        title : str, optional
            Main title of plot
            (default: "How distributed are the categories?")
        subtitle : str, optional
            Subtitle shown below main title
            (default: "Horizontal bar plot of categories")
        footer : str, optional
            Text to show at bottom of plot
        add_hline : bool, optional
            Whether to add a dashed grey line at the mean bar length. On
            this horizontal chart the line is drawn vertically
            (default: False)
        top_n : int, optional
            Number of bars to show. If None (or 0), shows all
        highlight_top_n : tuple, optional
            Tuple of (n, hex color code) to highlight top n bars
        highlight_low_n : tuple, optional
            Tuple of (n, hex color code) to highlight bottom n bars

        Returns
        -------
        plotly.graph_objects.Figure
            Interactive horizontal bar plot
        """
        if y_col is None:
            # Frequency table: one row per category with a 'Count' column.
            # value_counts() already orders it from most to least frequent.
            plot_data = (
                df[x_col]
                .value_counts()
                .rename_axis(x_col)
                .reset_index(name='Count')
            )
            value_col = 'Count'
            value_format = ",.0f"  # Counts are whole numbers
        else:
            plot_data = df.sort_values(y_col, ascending=False)
            value_col = y_col
            value_format = ",.2f"

        if top_n:
            plot_data = plot_data.head(top_n)

        colors = self._pick_marker_colors(
            highlight_top_n, highlight_low_n, len(plot_data)
        )

        fig = go.Figure()
        fig.add_trace(
            go.Bar(
                y=plot_data[x_col],
                x=plot_data[value_col],
                orientation='h',
                marker_color=colors,
                hovertemplate=(
                    f"<b>{x_col}</b>: " + "%{y}<br>" +
                    f"{value_col}: " + "%{x:" + value_format + "}<br>" +
                    "<extra></extra>"
                )
            )
        )

        if add_hline:
            fig.add_vline(
                x=plot_data[value_col].mean(),
                line_color="grey",
                line_dash="dash",
                line_width=2,
            )

        if footer is not None:
            add_footer(fig, footer, font_family=self.font_family)

        layout_config = self._get_base_layout_config(title, subtitle)
        layout_config.update({
            'xaxis_title': value_col,
            'yaxis_title': x_col,
            'yaxis_autorange': "reversed"  # First row of plot_data on top
        })
        fig.update_layout(**layout_config)

        return fig

    def plot_dot(
        self,
        df: pd.DataFrame,
        x_col: str,
        y_col: str,
        title: str = "How are the categories distributed?",
        subtitle: str = "Dot plot of categories",
        footer: str = None,
        add_hline_at: tuple[str, float] = None,  # (label, value)
        top_n: int = None,
        highlight_top_n: tuple[int, str] = None,  # (n, hex_color)
        highlight_low_n: tuple[int, str] = None  # (n, hex_color)
    ) -> go.Figure:
        """
        Create a dot plot showing values as dots with lines down to the x-axis.

        Rows are sorted by ``y_col`` in descending order. Each dot is
        labelled with its value rounded to one decimal and connected to
        y=0 by a light grey vertical line.

        Parameters
        ----------
        df : pandas.DataFrame
            Data to plot
        x_col : str
            Column name for categories on x-axis
        y_col : str
            Column name for values shown as dot heights
        title : str, optional
            Main title of plot
        subtitle : str, optional
            Subtitle shown below main title
        footer : str, optional
            Text to show at bottom of plot
        add_hline_at : tuple[str, float], optional
            Reference line (label, value) to add
        top_n : int, optional
            Number of dots to show, ordered by value
        highlight_top_n : tuple[int, str], optional
            (number, color) for highlighting highest values
        highlight_low_n : tuple[int, str], optional
            (number, color) for highlighting lowest values

        Returns
        -------
        plotly.graph_objects.Figure
            Interactive dot plot figure
        """
        plot_data = df.sort_values(y_col, ascending=False)
        if top_n:
            plot_data = plot_data.head(top_n)

        colors = self._pick_marker_colors(
            highlight_top_n, highlight_low_n, len(plot_data)
        )

        fig = go.Figure()
        fig.add_trace(
            go.Scatter(
                x=plot_data[x_col],
                y=plot_data[y_col],
                mode='markers+text',
                marker=dict(
                    color=colors,
                    size=18
                ),
                text=plot_data[y_col].round(1),
                textposition='top center',
                hovertemplate=(
                    f"<b>{x_col}</b>: " + "%{x}<br>" +
                    f"{y_col}: " + "%{y:,.2f}<br>" +
                    "<extra></extra>"
                )
            )
        )

        # One stem per dot. Zipping the two columns avoids building a Series
        # per row, which is slower and can upcast values of mixed dtype.
        for category, value in zip(plot_data[x_col], plot_data[y_col]):
            fig.add_shape(
                type='line',
                x0=category,
                x1=category,
                y0=0,
                y1=value,
                line=dict(color='lightgrey', width=2),
                layer="below",
            )

        if add_hline_at is not None:
            fig.add_hline(
                y=add_hline_at[-1],
                line=dict(color='grey', dash='dot', width=0.5),
                annotation_text=f'{add_hline_at[0]}: {add_hline_at[-1]:.1f}',
                annotation_position='top right',
                layer="below",
            )

        if footer is not None:
            add_footer(fig, footer, font_family=self.font_family)

        layout_config = self._get_base_layout_config(title, subtitle)
        layout_config.update({
            'yaxis_title': "",
            'yaxis_visible': False,  # Values are printed on the dots instead
            'xaxis_title': "",
            'xaxis_showline': True,
            'xaxis_linecolor': 'lightgrey',
            'xaxis_linewidth': 2,
            'xaxis_type': 'category',
            'margin': dict(t=100, pad=0)
        })
        fig.update_layout(**layout_config)

        return fig