Source code for inspectus

# Version string of the package.
__version__ = '0.1.5'

from typing import List, Optional, TYPE_CHECKING, Union, Tuple, Dict

from inspectus.utils import get_color_list

if TYPE_CHECKING:
    import numpy as np
    import torch

# Percentile cut points (in percent) handed to `np.percentile` by
# `series_to_distribution`; each row's 'histogram' has one entry per cut point.
# The inner values numerically match the standard normal CDF at
# -1.5, -1, -0.5, 0, +0.5, +1 and +1.5 standard deviations; 0 and 100 are the
# minimum and maximum of the data.
BASIS_POINTS = [
    0,
    6.68,
    15.87,
    30.85,
    50.00,
    69.15,
    84.13,
    93.32,
    100.00
]


def attention(attn: Union[
    'np.ndarray',
    'torch.Tensor',
    List['torch.Tensor'],
    Tuple['torch.Tensor', ...],
],
        query_tokens: Optional[List[str]],
        key_tokens: Optional[List[str]] = None, *,
        chart_types: Optional[List[str]] = None,
        color: Optional[Union[str, Dict[str, str]]] = None,
        theme: str = "auto"):
    """
    Use this to visualize attention maps.

    Parameters
    ----------
    attn : array-like
        `attn` should be a numpy array, a PyTorch tensor or an attention
        output from Huggingface transformers.
        For numpy arrays or PyTorch tensors, it should have the shape
        `[q_len, k_len]` or `[heads, q_len, k_len]` or
        `[layers, heads, q_len, k_len]`.
    query_tokens : (List[str], optional)
        This is the list of query tokens. Although the annotation allows
        `None`, passing `None` raises `ValueError`.
    key_tokens : (List[str], optional)
        This is the list of key tokens. If not provided it defaults to
        `query_tokens`.
    chart_types : (List[str], optional)
        A list of chart types to render. If not provided, it defaults to
        `['attention_matrix', 'query_token_heatmap', 'key_token_heatmap',
        'dimension_heatmap']`.
        Possible values are `'attention_matrix', 'query_token_heatmap',
        'key_token_heatmap', 'dimension_heatmap', 'token_dim_heatmap',
        'line_grid'`.
    color : (str, dict, optional)
        A color to use for rendering components. Single color for all
        components or a dictionary of colors with
        (key: chart_type, value: color). The value is passed through
        `parse_colors`, which handles the `None` case (documented default:
        'blue'). Refer to https://observablehq.com/@d3/color-schemes for
        color options.
    theme : str
        The theme to use for the visualization. Possible values are 'auto',
        'light', and 'dark'. Default is 'auto'.

    Raises
    ------
    ValueError
        If `query_tokens` is None, or if any parsed attention matrix does not
        have the shape `(len(query_tokens), len(key_tokens))`.
        `parse_attn` may raise as well (the original documentation mentions
        empty `attn` and unknown types or shapes).
    """
    if query_tokens is None:
        raise ValueError('Tokens should be provided')
    if key_tokens is None:
        key_tokens = query_tokens

    # Imported lazily so that importing the package does not pull in the
    # visualization code.
    from inspectus.attention_viz import parse_attn, attention_chart, parse_colors

    attn, dimensions = parse_attn(attn)

    # Every parsed matrix must line up with the token lists.
    expected_shape = (len(query_tokens), len(key_tokens))
    for a in attn:
        if a.matrix.shape != expected_shape:
            raise ValueError(f'Attention matrix size should be equal to '
                             f'[query_len, key_len] = [{len(query_tokens)}, {len(key_tokens)}]; got '
                             f'{list(a.matrix.shape)} instead.')

    if chart_types is None:
        chart_types = ['attention_matrix', 'query_token_heatmap',
                       'key_token_heatmap', 'dimension_heatmap']

    attention_chart(
        attn=attn,
        src_tokens=[str(t) for t in query_tokens],
        tgt_tokens=[str(t) for t in key_tokens],
        chart_types=chart_types,
        color=parse_colors(color),
        dimensions=dimensions,
        theme=theme
    )
def compress_series(series, compress_steps=1):
    """
    Merge consecutive entries of a series into blocks spanning
    `compress_steps` steps.

    Args:
        series: iterable of dictionaries with the keys 'step' and 'values';
            a 'values' entry that is not a list is treated as a single value
        compress_steps: an entry opens a new block when its step is at least
            this far from the step of the current block's first entry

    Returns:
        A new list of ``{'step': ..., 'values': [...]}`` dictionaries, one per
        block. 'step' is the step of the first entry in the block and 'values'
        is the concatenation of the values of all entries in the block.
    """
    blocks = []
    for entry in series:
        chunk = entry['values']
        if not isinstance(chunk, list):
            chunk = [chunk]

        opens_new_block = (
            len(blocks) == 0
            or entry['step'] - blocks[-1]['step'] >= compress_steps
        )
        if opens_new_block:
            # Copy so that later merging never touches the caller's list.
            blocks.append({'step': entry['step'], 'values': list(chunk)})
            continue

        blocks[-1]['values'].extend(chunk)
    return blocks
def series_to_distribution(series: Union[
    List[Dict],
    List['torch.Tensor'],
    List['np.ndarray'],
], steps: Optional['np.ndarray'] = None):
    """
    Converts a series of data points into a distribution table.

    Parameters
    ----------
    series : Union[List[Dict], List['torch.Tensor'], List['np.ndarray']]
        A list of data points. Data points can be dictionaries, numpy arrays,
        or PyTorch tensors.
        Dictionary structure should be
        {'values': [data_points], 'step': step_value}.
    steps : np.ndarray, optional
        An array of step values, indexed by position in `series`. It is only
        consulted for non-dictionary data points; when it is not provided the
        position in `series` is used as the step.

    Returns
    -------
    list
        A list of dictionaries representing the distribution table. Each
        dictionary contains the 'step', the 'histogram' (the percentiles of
        the data at `BASIS_POINTS`), and the 'mean' of the data at that step.
    """
    import numpy as np

    # Resolve torch once instead of retrying the import for every element;
    # torch is optional, so its absence just disables the tensor conversion.
    try:
        import torch
    except ImportError:
        torch = None

    table = []
    for i, data in enumerate(series):
        if torch is not None and isinstance(data, torch.Tensor):
            data = data.detach().cpu().numpy()

        if isinstance(data, dict):
            values = data['values']
            step = data['step']
        else:
            values = data
            step = steps[i] if steps is not None else i

        # One percentile per entry of BASIS_POINTS.
        dist = np.percentile(values, BASIS_POINTS)
        table.append({
            'step': step,
            'histogram': [dist[j] for j in range(len(BASIS_POINTS))],
            'mean': np.mean(values),
        })

    return table
def distribution(data: Dict[str, Union[
    List[Dict],
    List['torch.Tensor'],
    List['np.ndarray'],
]], *,
        include_mean: bool = True,
        include_borders: bool = False,
        levels: int = 5,
        alpha: float = 0.5,
        color_scheme: str = 'tableau10',
        height: int = 500,
        width: int = 700,
        height_minimap: int = 100):
    """
    Generates a distribution visualization from the given data.

    Parameters
    ----------
    data : dict
        A dictionary where keys are series names and values are lists of data
        points. Data points can be dictionaries (output from the
        inspectus.data_logger), numpy arrays, or PyTorch tensors. Empty series
        are skipped. A series whose first element is a dictionary containing
        a 'histogram' key is used as is; any other series is first summarized
        with `series_to_distribution`.
    include_mean : bool, optional
        If True, includes the mean of the data in the visualization.
        Default is True.
    include_borders : bool, optional
        If True, includes borders at the highest and lowest levels in the
        visualization. Default is False.
    levels : int, optional
        The number of levels in the visualization. Values outside the range
        1 to 5 are clamped to that range. Default is 5.
    alpha : float, optional
        Opacity of the first band (the original documentation notes that it
        reduces by powers for each level). Default is 0.5.
    color_scheme : str, optional
        The color scheme to use for the visualization. Default is 'tableau10'.
    height : int, optional
        The height of the visualization. Default is 500.
    width : int, optional
        The width of the visualization. Default is 700.
    height_minimap : int, optional
        The height of the minimap in the visualization. Default is 100.

    Returns
    -------
    The value returned by `inspectus.distribution_viz.render` (documented
    upstream as an Altair Chart object representing the distribution
    visualization).
    """
    # Imported lazily so that importing the package does not pull in the
    # visualization code.
    from inspectus.distribution_viz import render, _histogram_to_table

    # Clamp the number of levels to the supported range.
    if levels > 5:
        levels = 5
    if levels < 1:
        levels = 1

    table = []
    for name, series in data.items():
        # `len` rather than truthiness: a series may be a numpy array.
        if len(series) == 0:
            continue

        already_summarized = isinstance(series[0], dict) and 'histogram' in series[0]
        rows = series if already_summarized else series_to_distribution(series)
        table += _histogram_to_table(rows, name)

    return render(table,
                  levels=levels,
                  alpha=alpha,
                  include_borders=include_borders,
                  include_mean=include_mean,
                  color_scheme=color_scheme,
                  height=height,
                  width=width,
                  height_minimap=height_minimap)
# One-dimensional container of numeric values accepted by `tokens`.
ArrayLike = Union['torch.Tensor', 'np.ndarray', List[float]]


def tokens(tokens: List[str],
           values: Union[ArrayLike, Dict[str, ArrayLike]], *,
           token_info: Optional[List[str]] = None,
           remove_padding: bool = True,
           colors: Optional[Dict[str, str]] = None,
           theme: str = "auto"):
    """
    Visualize metrics related to tokens

    Parameters
    ----------
    tokens : list[str]
        List of tokens
    values : (ArrayLike or dict[str, ArrayLike])
        Values to visualize. (key: name, value: list of values with shape
        [num_tokens]). A bare array is treated as a single metric named
        'value'.
    token_info : Optional[list[str]]
        Additional info about the tokens. Shape [num_tokens]
    remove_padding : bool
        Whether to remove padding in the visualization
    colors : Optional[Dict[str, str]]
        Colors to use for each metric in the visualization. Metrics without
        an entry get a color from `get_color_list()`, cycling through it in
        order. The dictionary passed in is not modified.
    theme : str
        The theme to use for the visualization. Possible values are 'auto',
        'light', and 'dark'. Default is 'auto'.

    Raises
    ------
    ValueError
        If any value list, or `token_info`, differs in length from `tokens`.
    """
    if not isinstance(values, dict):
        values = {'value': values}

    # Validate first so that invalid input fails before any other work.
    for value_list in values.values():
        if len(value_list) != len(tokens):
            raise ValueError("All value lists must have the same length as the tokens list")

    if token_info is not None and len(token_info) != len(tokens):
        raise ValueError("token_info must have the same length as the tokens list")

    # Work on a copy so the caller's dictionary is left untouched.
    resolved_colors = dict(colors) if colors is not None else {}
    uncolored = [name for name in values if name not in resolved_colors]
    if uncolored:
        palette = get_color_list()
        for index, name in enumerate(uncolored):
            resolved_colors[name] = palette[index % len(palette)]

    # Imported lazily so that importing the package does not pull in the
    # visualization code.
    from inspectus.token_viz import visualize_tokens
    visualize_tokens(tokens, values, token_info=token_info,
                     remove_padding=remove_padding, colors=resolved_colors,
                     theme=theme)
# Names exported by `from inspectus import *`; `compress_series` is not listed.
__all__ = ['attention', 'series_to_distribution', 'distribution', 'tokens']