Source code for raggie.utils

from typing import List, Optional, Union

import numpy as np
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans

import matplotlib.pyplot as plt
import seaborn as sns

from .types import RaggiePlotterClass

class RaggiePlotter(RaggiePlotterClass):
    """
    Raggie plotter for visualizing keys using t-SNE and optional k-means clustering.

    This class provides methods to reduce embeddings to 2D space and
    visualize them with clustering and annotations.
    """

    def __init__(self, model):
        """
        Initialize the RaggiePlotter instance.

        Args:
            model: The model used to encode keys into embeddings. ``plot``
                calls ``model.model.encode(keys, convert_to_numpy=True)``
                on it.
        """
        self.model = model

    def plot(self,
             keys: List[str],
             perplexity: Optional[float] = None,
             learning_rate: Union[float, str] = 'auto',
             n_iter_without_progress: int = 1000,
             random_state: int = 42,
             n_clusters: Optional[int] = None,
             show: bool = True,
             save_path: Optional[str] = None
             ) -> "plt.Figure":
        """
        Perform t-SNE dimensionality reduction and visualize keys with
        optional k-means clustering.

        Args:
            keys (List[str]): List of keys to visualize (at least two).
            perplexity (Optional[float]): Perplexity parameter for t-SNE.
                When omitted (or falsy) it defaults to
                ``min(30, len(keys) - 1)``.
            learning_rate (Union[float, str]): Learning rate for t-SNE.
            n_iter_without_progress (int): Number of iterations without
                progress before stopping.
            random_state (int): Random seed passed to t-SNE.
            n_clusters (Optional[int]): Number of clusters for k-means.
                When ``None`` no clustering is performed and points are
                coloured per key.
            show (bool): Whether to display the plot.
            save_path (Optional[str]): Path to save the plot (optional).

        Returns:
            The matplotlib figure the plot was drawn on.

        Raises:
            ValueError: If fewer than two keys are given, or if
                ``n_clusters`` is not between 1 and ``len(keys)``.
        """
        # Validate up front with real exceptions: ``assert`` is stripped
        # under ``python -O``, and failing here avoids paying for the
        # encode / t-SNE steps on input that cannot be plotted.
        n_keys = len(keys)
        if n_keys < 2:
            raise ValueError("At least two keys are required for t-SNE visualization.")
        if n_clusters is not None and not 1 <= n_clusters <= n_keys:
            raise ValueError(
                f"n_clusters must be between 1 and the number of keys "
                f"({n_keys}), got {n_clusters}."
            )

        sns.set_theme(style="whitegrid")
        embeddings = self.model.model.encode(keys, convert_to_numpy=True)

        # None (or any falsy value) selects the automatic perplexity; it is
        # capped at n_keys - 1 so it stays below the number of samples.
        if not perplexity:
            perplexity = min(30, n_keys - 1)

        reduced_embeddings = self._compute_tsne_embeddings(
            embeddings, perplexity, learning_rate, n_iter_without_progress, random_state
        )

        fig, ax = plt.subplots(figsize=(12, 8))

        if n_clusters is not None:
            cluster_labels = self._perform_kmeans_clustering(reduced_embeddings, n_clusters)
            palette = sns.color_palette("tab10", n_clusters)
            cluster_colors = dict(enumerate(palette))

            centroids = []
            group_names = []
            for cluster in range(n_clusters):
                members = cluster_labels == cluster
                # Centroid is the mean of the 2D points assigned to the cluster.
                centroids.append(reduced_embeddings[members].mean(axis=0))
                group_names.append(
                    ", ".join(key for key, is_member in zip(keys, members) if is_member)
                )
            centroids = np.array(centroids)

            self._plot_data_points(reduced_embeddings, cluster_labels, palette, ax=ax)
            self._plot_centroids(centroids, group_names, cluster_colors, ax=ax)
        else:
            self._plot_data_points(reduced_embeddings, keys, "viridis", ax=ax)

        ax.set_title("t-SNE Visualization of Keys", fontsize=16)
        ax.set_xlabel("Dimension 1", fontsize=12)
        ax.set_ylabel("Dimension 2", fontsize=12)
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300)
            print(f"Plot saved to {save_path}")
        if show:
            plt.show()
        return fig

    def _compute_tsne_embeddings(self,
                                 embeddings: np.ndarray,
                                 perplexity: Optional[float],
                                 learning_rate: Union[float, str],
                                 n_iter_without_progress: int,
                                 random_state: int
                                 ) -> np.ndarray:
        """Compute 2-component t-SNE embeddings for ``embeddings``."""
        tsne = TSNE(
            n_components=2,
            perplexity=perplexity,
            learning_rate=learning_rate,
            n_iter_without_progress=n_iter_without_progress,
            random_state=random_state,
        )
        return tsne.fit_transform(embeddings)

    def _perform_kmeans_clustering(self,
                                   embeddings: np.ndarray,
                                   n_clusters: int
                                   ) -> np.ndarray:
        """
        Perform k-means clustering on the embeddings.

        Returns the cluster label assigned to each row of ``embeddings``.
        """
        # NOTE: the seed is fixed at 42 here, independently of the
        # ``random_state`` argument accepted by ``plot``.
        kmeans = KMeans(n_clusters=n_clusters, random_state=42)
        return kmeans.fit_predict(embeddings)

    def _plot_data_points(self,
                          reduced_embeddings: np.ndarray,
                          hue: Union[List[str], np.ndarray],
                          palette: Union[str, list],
                          ax: Optional["plt.Axes"] = None
                          ) -> "plt.Axes":
        """
        Scatter-plot the 2D points, coloured by ``hue``.

        Args:
            reduced_embeddings: Array whose first two columns are used as
                the x and y coordinates.
            hue: Per-point grouping values (cluster labels or keys).
            palette: Seaborn palette name or list of colours.
            ax: Axes to draw on; when ``None`` seaborn uses the current axes.

        Returns:
            The object returned by ``sns.scatterplot``.
        """
        return sns.scatterplot(
            x=reduced_embeddings[:, 0],
            y=reduced_embeddings[:, 1],
            hue=hue,
            palette=palette,
            s=100,
            alpha=0.8,
            ax=ax,
        )

    def _plot_centroids(self,
                        centroids: np.ndarray,
                        group_names: list,
                        cluster_colors: dict,
                        ax: Optional["plt.Axes"] = None
                        ) -> None:
        """
        Plot centroids with matching colors and annotate them.

        Args:
            centroids: One (x, y) row per cluster, indexed by cluster id.
            group_names: Per-cluster label strings; each is split on
                ``", "`` and rendered one item per line.
            cluster_colors: Mapping from cluster id to its colour.
            ax: Axes to draw on; defaults to the current pyplot axes.
        """
        target = plt.gca() if ax is None else ax
        for cluster, centroid in enumerate(centroids):
            target.scatter(
                centroid[0], centroid[1],
                c=[cluster_colors[cluster]],
                s=300,
                edgecolors='black'
            )
            wrapped_group_name = "\n".join(group_names[cluster].split(", "))
            target.text(
                centroid[0], centroid[1],
                wrapped_group_name,
                fontsize=10,
                ha='center',
                va='center',
                color='black',
                bbox=dict(facecolor='white', alpha=0.6, edgecolor='black')
            )