Subpackages
Module contents
- inspectus.attention(attn: np.ndarray | torch.Tensor | List[torch.Tensor] | Tuple[torch.Tensor, ...], query_tokens: List[str] | None, key_tokens: List[str] | None = None, *, chart_types: List[str] | None = None, color: str | Dict[str, str] = None, theme: str = 'auto')[source]
Use this to visualize attention maps.
- Parameters:
attn (array-like) – attn should be a numpy array, a PyTorch tensor or an attention output from Huggingface transformers. For numpy arrays or PyTorch tensors, it should have the shape [q_len, k_len] or [heads, q_len, k_len] or [layers, heads, q_len, k_len].
query_tokens ((List[str], optional)) – This is the list of query tokens.
key_tokens ((List[str], optional)) – This is the list of key tokens. If not provided it defaults to query tokens.
chart_types ((List[str], optional)) – A list of chart types to render. If not provided, it defaults to [‘attention_matrix’, ‘query_token_heatmap’, ‘key_token_heatmap’, ‘dimension_heatmap’]. Possible values are ‘attention_matrix’, ‘query_token_heatmap’, ‘key_token_heatmap’, ‘dimension_heatmap’, ‘token_dim_heatmap’, ‘line_grid’.
color ((str, dict)) – A color to use for rendering components: either a single color for all components, or a dictionary of colors with (key: chart_type, value: color). If not provided, it defaults to ‘blue’. See https://observablehq.com/@d3/color-schemes for color options.
theme (str) – The theme to use for the visualization. Possible values are ‘auto’, ‘light’, and ‘dark’. Default is ‘auto’.
- Raises:
ValueError – If query_tokens is None, if attn is empty, or if attn contains an unknown type or shape.
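A minimal sketch of preparing input for inspectus.attention, assuming NumPy is available. The helper names make_attention and show_attention are hypothetical and exist only for this illustration; the random weights stand in for a real model's attention output. The array uses the documented [heads, q_len, k_len] shape.

```python
import numpy as np


def make_attention(heads: int, tokens: list, seed: int = 0) -> np.ndarray:
    """Return random attention weights of shape [heads, q_len, k_len].

    Hypothetical helper: a stand-in for attention taken from a real model.
    """
    rng = np.random.default_rng(seed)
    n = len(tokens)
    logits = rng.normal(size=(heads, n, n))
    # Softmax over the key axis so every query row sums to 1.
    exp = np.exp(logits - logits.max(axis=-1, keepdims=True))
    return exp / exp.sum(axis=-1, keepdims=True)


def show_attention(attn: np.ndarray, tokens: list) -> None:
    """Hypothetical wrapper around the documented call."""
    # Imported lazily so the data preparation runs without inspectus installed.
    import inspectus

    # key_tokens is omitted, so it defaults to the query tokens.
    inspectus.attention(
        attn,
        tokens,
        chart_types=["attention_matrix", "query_token_heatmap"],
        theme="light",
    )


tokens = ["The", "cat", "sat", "down"]
attn = make_attention(heads=2, tokens=tokens)
```

With inspectus installed, calling show_attention(attn, tokens) from a notebook is expected to render the selected charts.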
- inspectus.distribution(data: Dict[str, List[Dict] | List[torch.Tensor] | List[np.ndarray]], *, include_mean: bool = True, include_borders: bool = False, levels=5, alpha=0.5, color_scheme='tableau10', height: int = 500, width: int = 700, height_minimap: int = 100)[source]
Generates a distribution visualization from the given data.
- Parameters:
data (dict) – A dictionary where keys are series names and values are lists of data points. Data points can be dictionaries (output from inspectus.data_logger), numpy arrays, or PyTorch tensors.
include_mean (bool, optional) – If True, includes the mean of the data in the visualization. Default is True.
include_borders (bool, optional) – If True, includes borders at the highest and lowest levels in the visualization. Default is False.
levels (int, optional) – The number of levels in the visualization, an integer between 1 and 5. Default is 5.
alpha (float, optional) – Opacity of the first band. Reduces by powers for each level. Default is 0.5.
color_scheme (str, optional) – The color scheme to use for the visualization. Default is ‘tableau10’.
height (int, optional) – The height of the visualization. Default is 500.
width (int, optional) – The width of the visualization. Default is 700.
height_minimap (int, optional) – The height of the minimap in the visualization. Default is 100.
- Returns:
An Altair Chart object representing the distribution visualization.
- Return type:
alt.Chart
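A minimal sketch of building the data argument for inspectus.distribution, assuming NumPy is available. The helpers make_series and plot_distribution and the series names are hypothetical; each series is a list with one numpy array per step, which is one of the documented data-point types.

```python
import numpy as np


def make_series(num_steps: int, size: int, drift: float, seed: int) -> list:
    """Hypothetical helper: one array of samples per step, with a drifting mean."""
    rng = np.random.default_rng(seed)
    return [
        rng.normal(loc=drift * step, scale=1.0, size=size)
        for step in range(num_steps)
    ]


# Keys are series names, values are lists of data points (here numpy arrays).
data = {
    "layer_1": make_series(num_steps=20, size=100, drift=0.05, seed=0),
    "layer_2": make_series(num_steps=20, size=100, drift=-0.02, seed=1),
}


def plot_distribution(data: dict):
    """Hypothetical wrapper around the documented call."""
    # Imported lazily so the data preparation runs without inspectus installed.
    import inspectus

    # All keyword arguments below appear in the documented signature.
    return inspectus.distribution(
        data,
        levels=3,
        include_borders=True,
        height=400,
        width=600,
    )
```

With inspectus installed, plot_distribution(data) should return the chart object described under Returns.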
- inspectus.series_to_distribution(series: List[Dict] | List[torch.Tensor] | List[np.ndarray], steps: np.ndarray = None)[source]
Converts a series of data points into a distribution table.
- Parameters:
series (Union[List[Dict], List['torch.Tensor'], List['np.ndarray']]) – A list of data points. Data points can be dictionaries, numpy arrays, or PyTorch tensors. Dictionary structure should be {‘values’: [data_points], ‘step’: step_value}.
steps (np.ndarray, optional) – An array of step values. If not provided, step values are inferred from the data.
- Returns:
A list of dictionaries representing the distribution table. Each dictionary contains the step, histogram, and mean of the data at that step.
- Return type:
list
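A minimal sketch of a series in the dictionary form described above, using only the standard library. The helpers make_logged_series, is_valid_series, and to_table are hypothetical, and the step spacing of 10 is an arbitrary choice for illustration.

```python
import random


def make_logged_series(num_steps: int, points_per_step: int, seed: int = 0) -> list:
    """Hypothetical helper: build data points shaped like
    {'values': [data_points], 'step': step_value}."""
    rng = random.Random(seed)
    return [
        {
            "values": [rng.gauss(0.0, 1.0) for _ in range(points_per_step)],
            "step": step * 10,
        }
        for step in range(num_steps)
    ]


def is_valid_series(series: list) -> bool:
    """Check that every data point is a dict with 'values' and 'step' keys."""
    return all(
        isinstance(point, dict) and {"values", "step"} <= set(point)
        for point in series
    )


series = make_logged_series(num_steps=5, points_per_step=50)


def to_table(series: list) -> list:
    """Hypothetical wrapper around the documented call."""
    # Imported lazily so the data preparation runs without inspectus installed.
    import inspectus

    # steps is omitted, so step values are inferred from the data.
    return inspectus.series_to_distribution(series)
```

With inspectus installed, to_table(series) should return the list of per-step dictionaries described under Returns.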
- inspectus.tokens(tokens: List[str], values: torch.Tensor | np.ndarray | List[float] | Dict[str, torch.Tensor | np.ndarray | List[float]], *, token_info: list[str] | None = None, remove_padding: bool = True, colors: Dict[str, str] | None = None, theme: str = 'auto')[source]
Visualize metrics related to tokens.
- Parameters:
tokens (list[str]) – List of tokens.
values ((ArrayLike or dict[str, ArrayLike])) – Values to visualize: either a single array-like of shape [num_tokens], or a dictionary (key: name, value: list of values with shape [num_tokens]).
token_info (Optional[list[str]]) – Additional info about the tokens. Shape [num_tokens].
remove_padding (bool) – Whether to remove padding in the visualization. Default is True.
colors (Optional[Dict[str, str]]) – Colors to use for each metric in the visualization. If not provided, it defaults to the default color scheme.
theme (str) – The theme to use for the visualization. Possible values are ‘auto’, ‘light’, and ‘dark’. Default is ‘auto’.
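A minimal sketch of inputs for inspectus.tokens, using plain Python lists. The metric names, the numbers, and the helpers check_lengths and show_tokens are hypothetical. The length check reflects the documented [num_tokens] shape for values and token_info.

```python
tokens = ["[CLS]", "the", "cat", "sat", "[SEP]"]

# One list of per-token values for each metric name (illustrative numbers).
values = {
    "loss": [0.0, 2.1, 3.4, 1.2, 0.1],
    "entropy": [0.0, 1.5, 2.2, 0.9, 0.2],
}

# Optional extra text per token, same length as tokens.
token_info = [f"position {i}" for i in range(len(tokens))]


def check_lengths(tokens, values, token_info=None) -> bool:
    """Hypothetical helper: confirm every metric has one value per token."""
    series = values.values() if isinstance(values, dict) else [values]
    lengths_match = all(len(v) == len(tokens) for v in series)
    info_matches = token_info is None or len(token_info) == len(tokens)
    return lengths_match and info_matches


def show_tokens(tokens, values, token_info=None) -> None:
    """Hypothetical wrapper around the documented call."""
    # Imported lazily so the data preparation runs without inspectus installed.
    import inspectus

    inspectus.tokens(tokens, values, token_info=token_info, theme="dark")
```

With inspectus installed, calling show_tokens(tokens, values, token_info) from a notebook is expected to render the token view.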