SCOTCH module

class SCOTCH.SCOTCH(k1, k2, verbose=True, max_iter=100, seed=1001, term_tol=1e-05, max_l_u=0, max_l_v=0, max_a_u=0, max_a_v=0, var_lambda=False, var_alpha=False, shape_param=10, mid_epoch_param=5, init_style='random', save_clust=False, draw_intermediate_graph=False, save_intermediate=False, track_objective=False, kill_factors=False, device='cpu', out_path='.')[source]

Bases: NMTF

SCOTCH Class

The SCOTCH class extends the NMTF class. Its __init__ method accepts many parameters, of which only k1 and k2 are required.

__init__ Input Parameters:

  • k1, k2 (int): Lower dimension size of U and V. (required)

  • verbose (bool, optional): If True, prints messages. (default: True)

  • max_iter (int, optional): Maximum number of iterations. (default: 100)

  • seed (int, optional): Random seed for initialization. (default: 1001)

  • term_tol (float, optional): Relative error threshold for convergence. (default: 1e-5)

  • max_l_u (float, optional): Maximum regularization on U. (default: 0)

  • max_l_v (float, optional): Maximum regularization on V. (default: 0)

  • max_a_u (float, optional): Maximum sparse regularization on U. (default: 0, change at own risk)

  • max_a_v (float, optional): Maximum sparse regularization on V. (default: 0, change at own risk)

  • var_lambda (bool, optional): If True, the regularization parameters l_U and l_V increase to max value using a sigmoid scheduler. Generally set to False. (default: False)

  • var_alpha (bool, optional): If True, the regularization parameters a_U and a_V increase to max value using a sigmoid scheduler. Generally set to False. (default: False)

  • shape_param (float, optional): Controls the rate of increase for l_U, l_V, a_U, and a_V when the corresponding scheduler (var_lambda or var_alpha) is enabled. (default: 10)

  • mid_epoch_param (int, optional): Sets the epoch at which l_U, l_V, a_U, and a_V reach half of their max values when the corresponding scheduler (var_lambda or var_alpha) is enabled. (default: 5)

  • init_style (str, optional): Initialization method for SCOTCH. Should be either "random" or "nnsvd". (default: "random")

  • save_clust (bool, optional): Whether to save cluster assignments after each epoch. (default: False)

  • draw_intermediate_graph (bool, optional): If True, draws and saves the matrix representation after each epoch. These can be saved as a GIF. (default: False)

  • track_objective (bool, deprecated): Deprecated. (default: False)

  • kill_factors (bool, optional): If True, SCOTCH will halt updates if any factors in U and V reach zero. (default: False)

  • device (str, optional): Specifies the device to run SCOTCH on: "cpu" or "cuda:". (default: "cpu")

  • out_path (str, optional): Directory to save SCOTCH output files. (default: '.')
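
The sigmoid scheduler behind var_lambda and var_alpha is not spelled out above. The following is a minimal sketch of one plausible form, not the SCOTCH implementation: the function name sigmoid_schedule and the exact formula are assumptions. It is chosen only so that the value equals half of the maximum at mid_epoch_param, as the parameter description states; dividing by shape_param (rather than multiplying) is likewise an assumption.

```python
import math


def sigmoid_schedule(epoch, max_value, shape_param=10.0, mid_epoch_param=5):
    """Hypothetical sigmoid ramp of a regularization weight.

    Assumed form: max_value / (1 + exp(-(epoch - mid_epoch_param) / shape_param)).
    At epoch == mid_epoch_param the result is exactly max_value / 2.
    """
    exponent = -(epoch - mid_epoch_param) / shape_param
    return max_value / (1.0 + math.exp(exponent))


# Print the assumed schedule at a few epochs for a maximum weight of 1.0.
for epoch in (0, 5, 10, 50):
    print(epoch, round(sigmoid_schedule(epoch, max_value=1.0), 3))
```

Under this assumed form, a larger shape_param gives a slower ramp; the real scheduler may scale differently.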

add_data_from_adata(adata)[source]

Loads data from AnnData object into SCOTCH framework.

Parameters:

adata (anndata.AnnData) -- The AnnData object to extract data from. adata.X is converted to a PyTorch object.

add_data_from_file(file)[source]

Loads matrix representation into PyTorch tensor object to run with SCOTCH.

Parameters:

file (str) -- Path of the file to load data from. It should have a valid extension such as '.pt', '.txt', or '.h5ad'.
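
As an illustration of the extension requirement, the sketch below validates a path before it is handed to add_data_from_file. The helper check_input_extension is hypothetical and not part of SCOTCH. It accepts only the three extensions named above, and lower-casing the suffix is a choice made for this example.

```python
from pathlib import Path

# Extensions named in the parameter description; SCOTCH may accept others.
SUPPORTED_EXTENSIONS = {".pt", ".txt", ".h5ad"}


def check_input_extension(file):
    """Return the lower-cased extension of `file`, or raise ValueError.

    Hypothetical pre-flight check; it does not load any data.
    """
    ext = Path(file).suffix.lower()
    if ext not in SUPPORTED_EXTENSIONS:
        expected = ", ".join(sorted(SUPPORTED_EXTENSIONS))
        raise ValueError(f"Unsupported extension {ext!r}; expected one of: {expected}")
    return ext


print(check_input_extension("data/counts.h5ad"))
```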

add_scotch_embeddings_to_adata(adata, prefix='')[source]

Adds SCOTCH objects to an AnnData object.

Parameters:
  • prefix (str) -- Prefix to add to AnnData objects created by SCOTCH.

  • adata (anndata.AnnData) -- The AnnData object to which SCOTCH embeddings will be added.

combined_embedding_visualization(adata, gene_cluster_id='gene_clusters', gene_embedding_id='gene_embedding', top_k=5, max_point_size=100, palette='viridis', var1='cell_clusters', var2='sample', S_matrix_id='S_matrix', prefix=None)[source]

Generate a combined visualization of embeddings from the data, with options to color by metadata.

This function is designed to visualize embeddings (e.g., UMAP, PCA, t-SNE) stored in adata, optionally allowing users to color the points by specific metadata columns (like cell type or condition). It arranges one or more plots in a grid layout.

Parameters:
  • adata (anndata.AnnData) -- An AnnData object containing single-cell data. Must contain embeddings in adata.obsm.

  • gene_embedding_id (str, optional) -- The key in adata.obsm where the embedding is stored (e.g., 'X_umap' for UMAP). Default is 'gene_embedding'.

  • top_k (int, optional (default=5)) -- The number of top features to display per gene cluster. Must be a positive integer.

  • max_point_size (int, optional (default=100)) -- Max point size for V bubble plot.

  • palette (str, optional (default='viridis')) -- The color palette to use for the scatter plot when coloring points.

  • var1 (str, optional) -- The primary variable (from adata.obs) to use for heatmap plotting (e.g., 'cell_clusters').

  • var2 (str, optional) -- The secondary variable (from adata.obs) for co-occurrence and proportions (e.g., 'sample').

  • S_matrix_id (str, optional) -- Key in adata.uns for an externally referenced matrix (e.g., factor matrix or count matrix). Default is 'S_matrix'.

Returns:

A matplotlib figure containing the generated visualizations.

Return type:

matplotlib.figure.Figure

combined_enrichment_visualization(adata, enrich_object_id, top_k=5, max_point_size=100, palette='viridis', var1='cell_clusters', var2='sample', S_matrix_id=None)[source]

Create a 2x3 subplot visualization combining enrichment bubble plots, element count heatmaps, co-occurrence proportions, and S matrix heatmaps.

This visualization provides insights into cellular data enrichment and relationships between variables.

Parameters:
  • adata (anndata.AnnData) -- An AnnData object containing single-cell data. Must contain obs, uns, and required data matrices.

  • enrich_object_id (str) -- The identifier for enrichment data in adata.uns.

  • top_k (int, optional (default=5)) -- The number of top enrichment terms to display in visualizations. Must be a positive integer.

  • max_point_size (int, optional (default=100)) -- The maximum size for the points in the bubble plot. Determines the largest bubble size.

  • palette (str, optional (default='viridis')) -- The color palette used for the bubble plot.

  • var1 (str, optional) -- The first variable for co-occurrence proportions. Must exist in adata.obs.columns.

  • var2 (str, optional) -- The second variable for co-occurrence proportions. Must exist in adata.obs.columns.

  • S_matrix_id (str, optional) -- The key in adata.uns for the additional heatmap matrix. The associated value must be a 2D matrix.

Returns:

A matplotlib figure object containing the generated subplots.

Return type:

matplotlib.figure.Figure

make_adata_from_scotch(prefix='')[source]

Create an AnnData object from the given data.

Parameters:
  • self (object) -- The instance of the class containing the data.

  • prefix (str) -- A string appended to the generated AnnData objects.

Returns:

An AnnData object containing the processed data.

Return type:

anndata.AnnData

make_top_regulators_list(adata, gene_cluster_id='gene_clusters', gene_embedding_id='gene_embedding', prefix=None, top_k=5)[source]

Create a list of top regulators for each gene cluster based on gene embeddings.

Parameters:
  • self (object) -- The instance of the class containing the data.

  • adata (anndata.AnnData) -- An AnnData object containing single-cell gene expression data.

  • gene_cluster_id (str) -- The key for the gene clusters stored in adata.var. (default is "gene_clusters")

  • gene_embedding_id (str) -- The key for the gene embedding matrix stored in adata.varm. (default is "gene_embedding")

  • prefix (str) -- The string utilized when adding SCOTCH data to anndata. Use instead of gene_cluster_id and gene_embedding_id. (default is None)

  • top_k (int, optional) -- The number of top genes to select per cluster (default is 5).

Returns:

A list of tuples, each containing the cluster index and the top top_k genes for that cluster.

Return type:

list of tuples
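
The return shape described above can be sketched in plain Python. This is an illustration under assumptions, not the SCOTCH implementation: the function top_genes_per_cluster is hypothetical, and it assumes that genes are ranked within their assigned cluster by the embedding value in that cluster's column.

```python
def top_genes_per_cluster(gene_names, gene_clusters, gene_embedding, top_k=5):
    """Return [(cluster_index, [top genes])] for each gene cluster.

    gene_clusters[i] is the cluster index of gene i, and gene_embedding[i]
    is that gene's embedding row (one value per gene cluster).
    Assumption: a gene's score is its embedding value in its own cluster.
    """
    by_cluster = {}
    for name, cluster, row in zip(gene_names, gene_clusters, gene_embedding):
        by_cluster.setdefault(cluster, []).append((row[cluster], name))

    result = []
    for cluster in sorted(by_cluster):
        # Highest score first; ties keep their input order.
        ranked = sorted(by_cluster[cluster], key=lambda pair: pair[0], reverse=True)
        result.append((cluster, [name for _, name in ranked[:top_k]]))
    return result


genes = ["A", "B", "C", "D"]
clusters = [0, 0, 1, 0]
embedding = [[0.9, 0.1], [0.2, 0.0], [0.1, 0.8], [0.5, 0.3]]
print(top_genes_per_cluster(genes, clusters, embedding, top_k=2))
# prints [(0, ['A', 'D']), (1, ['C'])]
```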

plot_S_matrix(adata, S_matrix_id='S_matrix', palette='viridis', ax=None)[source]

Plot a heatmap of the S matrix stored in adata.uns under the given key (S_matrix_id).

Parameters:
  • adata (anndata.AnnData) -- An AnnData object containing the single-cell data.

  • S_matrix_id (str) -- The key in adata.uns where the S matrix is stored. Default is 'S_matrix'.

  • palette (str) -- The color palette to use for the heatmap. Default is 'viridis'.

  • ax (matplotlib.axes.Axes, optional) -- A matplotlib Axes object to plot the heatmap on. If not provided, a new figure and Axes will be created.

Returns:

A matplotlib Figure object containing the heatmap, or None if ax is provided.

Return type:

matplotlib.figure.Figure or None

plot_cooccurrence_proportions(adata, field_1='cell_clusters', field_2='sample', cmap='Reds', ax=None)[source]

Generate a heatmap of co-occurrence proportions between two categorical variables.

This function creates a heatmap to visualize the co-occurrence proportions of two categorical variables stored in the adata.obs dataframe. The rows correspond to values from field_1 and the columns correspond to values from field_2. The heatmap displays normalized proportions per row.

Parameters:
  • adata (anndata.AnnData) -- An AnnData object containing the single-cell data. It must include adata.obs with field_1 and field_2 as categorical variables.

  • field_1 (str, optional) -- The name of the first categorical variable in adata.obs (used for heatmap rows). Its values will define the rows in the heatmap. (default is 'cell_clusters')

  • field_2 (str, optional) -- The name of the second categorical variable in adata.obs (used for heatmap columns). Its values will define the columns in the heatmap. (default is 'sample')

  • cmap (str, optional) -- The color map used for the heatmap. It must be a valid Matplotlib colormap, with "Reds" as the default. This determines the gradient colors representing value intensity in the heatmap.

  • ax (matplotlib.axes.Axes or None, optional) -- A Matplotlib Axes object on which the heatmap will be plotted. If None, a new figure and axis are created, and the function returns the generated figure. If provided, the heatmap is plotted on the existing axis, and no figure is returned.

Returns:

A Matplotlib Figure object containing the heatmap of co-occurrence proportions, if a new figure is generated. If ax is provided, the function returns None.

Return type:

matplotlib.figure.Figure or None

Raises:
  • TypeError -- If adata is not an AnnData object, or field_1 and field_2 are not strings.

  • ValueError -- If field_1 or field_2 are not found in adata.obs.columns.
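
The quantity shown in the heatmap, proportions normalized per row, can be sketched without any plotting. The function cooccurrence_proportions below is a hypothetical stand-in that takes two plain lists in place of the adata.obs columns; it illustrates the normalization only and is not the SCOTCH code.

```python
from collections import Counter, defaultdict


def cooccurrence_proportions(field_1_values, field_2_values):
    """Return {row_label: {column_label: proportion}} with each row summing to 1."""
    counts = defaultdict(Counter)
    for row_label, col_label in zip(field_1_values, field_2_values):
        counts[row_label][col_label] += 1

    columns = sorted(set(field_2_values))
    table = {}
    for row_label, counter in counts.items():
        total = sum(counter.values())
        # Missing combinations count as zero.
        table[row_label] = {col: counter[col] / total for col in columns}
    return table


cell_clusters = ["c0", "c0", "c0", "c1"]
samples = ["s1", "s1", "s2", "s2"]
for row_label, proportions in cooccurrence_proportions(cell_clusters, samples).items():
    print(row_label, proportions)
```

With pandas available, pandas.crosstab(a, b, normalize='index') computes the same row-normalized table.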

plot_element_count_heatmap(adata, field='cell_clusters', orientation='vertical', cmap='Blues', v_min=0, ax=None)[source]

Produce a heatmap showing how many times each unique value occurs in a given column of adata.obs. The orientation, color map, and target axis of the heatmap can be customized.

Parameters:
  • adata (anndata.AnnData) -- An AnnData object containing the single-cell data. This parameter stores observations and variables, including metadata used for the analysis.

  • field (str, optional) -- The column in adata.obs for which the counts should be calculated. The unique values in this column are counted and visualized in the heatmap. Default is 'cell_clusters'.

  • orientation (str, optional) -- The orientation of the heatmap (either rows or columns represent the elements being counted). Acceptable values are 'vertical' (default) or 'horizontal'.

  • cmap (str, optional) -- The color map used to style the heatmap. For example, use 'Blues' for a blue shade gradient. Defaults to 'Blues'.

  • v_min (int, optional) -- The minimum value for the heatmap color scale. This is useful to set a threshold for visualization. Default is 0.

  • ax (matplotlib.axes.Axes, optional) -- A matplotlib Axes object onto which the heatmap will be drawn. If not provided, a new figure and axes will be created.

Returns:

A matplotlib Figure object containing the heatmap. If ax is provided, the plot is drawn on the given Axes and None is returned.

Return type:

matplotlib.figure.Figure or None
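
The data behind this heatmap is a per-value count. The sketch below shows one way to build it; element_count_matrix is a hypothetical helper, and mapping 'vertical' to a single column and 'horizontal' to a single row is an assumption about how orientation is interpreted.

```python
from collections import Counter


def element_count_matrix(values, orientation="vertical"):
    """Return (labels, matrix) holding the count of each unique value.

    Assumption: 'vertical' yields an n x 1 matrix and 'horizontal' a 1 x n matrix.
    """
    counts = Counter(values)
    labels = sorted(counts)
    if orientation == "vertical":
        matrix = [[counts[label]] for label in labels]
    elif orientation == "horizontal":
        matrix = [[counts[label] for label in labels]]
    else:
        raise ValueError("orientation must be 'vertical' or 'horizontal'")
    return labels, matrix


labels, matrix = element_count_matrix(["a", "b", "a", "c", "a"])
print(labels, matrix)
# prints ['a', 'b', 'c'] [[3], [1], [1]]
```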

plot_reconstruction_error(adata)[source]

run_enrich_analyzer(adata, gene_cluster_id, go_regnet_file, fdr=0.05, test_type='persg', prefix='GO')[source]

Perform gene ontology (GO) enrichment analysis and store results in the AnnData object.

Parameters:
  • self (object) -- The instance of the class containing the data.

  • adata (anndata.AnnData) -- An AnnData object containing single-cell gene expression data.

  • gene_cluster_id (str) -- Identifier for the gene cluster to be analyzed.

  • go_regnet_file (str) -- File path to the gene ontology (GO) enrichment file.

  • fdr (float, optional) -- The false discovery rate threshold (default is 0.05).

  • test_type (str, optional) -- Type of statistical test to be performed (default is 'persg'). Valid options are 'persg' and 'fullgraph'.

  • prefix (str, optional) -- Prefix for storing enrichment results in the AnnData object (default is 'GO').

Returns:

None. Enrichment results are stored in adata.uns[prefix + "enrichment"].

Return type:

None

visualize_enrichment_bubbleplots(adata, enrich_object_id, gene_cluster_id='gene cluster', term_id='TermName', FC_id='Foldenr', q_val_id='CorrPval', top_k=5, max_point_size=100, palette='viridis', gene_cluster_set=None, ax=None)[source]

Visualize enrichment results as a bubble plot, where the size of the bubbles represents log2 fold change and the color represents -log10 p-values.

Parameters:
  • self (object) -- The instance of the class containing the data.

  • adata (anndata.AnnData) -- An AnnData object containing single-cell gene expression data.

  • enrich_object_id (str) -- The key in adata.uns containing the enrichment data.

  • gene_cluster_id (str, optional) -- The column name representing gene clusters in the enrichment data (default is 'gene cluster').

  • term_id (str, optional) -- The column name representing terms in the enrichment data (default is 'TermName').

  • FC_id (str, optional) -- The column name representing fold change values in the enrichment data (default is 'Foldenr').

  • q_val_id (str, optional) -- The column name representing the corrected p-values in the enrichment data (default is 'CorrPval').

  • top_k (int, optional) -- The number of top terms to select per gene cluster (default is 5).

  • max_point_size (int, optional) -- The maximum size of the bubbles in the plot (default is 100).

  • palette (str, optional) -- The color palette used for the plot (default is 'viridis').

Returns:

A matplotlib.figure.Figure object containing the bubble plot.

Return type:

matplotlib.figure.Figure
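
The two plotted quantities, log2 fold change for bubble size and -log10 of the corrected p-value for color, can be computed as in the sketch below. The function bubble_values is hypothetical and works on a list of dicts rather than the object stored in adata.uns. Ranking terms within a cluster by smallest corrected p-value is an assumption, since the selection rule for top_k is not stated above.

```python
import math


def bubble_values(rows, top_k=5, gene_cluster_id="gene cluster",
                  term_id="TermName", FC_id="Foldenr", q_val_id="CorrPval"):
    """Return size/color values for the top_k terms of each gene cluster.

    Assumption: 'top' means smallest corrected p-value within a cluster.
    Fold changes and p-values must be strictly positive for the logarithms.
    """
    by_cluster = {}
    for row in rows:
        by_cluster.setdefault(row[gene_cluster_id], []).append(row)

    points = []
    for cluster in sorted(by_cluster):
        best = sorted(by_cluster[cluster], key=lambda r: r[q_val_id])[:top_k]
        for r in best:
            points.append({
                "cluster": cluster,
                "term": r[term_id],
                "size": math.log2(r[FC_id]),        # bubble size
                "color": -math.log10(r[q_val_id]),  # bubble color
            })
    return points


rows = [
    {"gene cluster": 0, "TermName": "t1", "Foldenr": 4.0, "CorrPval": 0.01},
    {"gene cluster": 0, "TermName": "t2", "Foldenr": 2.0, "CorrPval": 0.001},
    {"gene cluster": 1, "TermName": "t3", "Foldenr": 8.0, "CorrPval": 0.1},
]
for point in bubble_values(rows, top_k=1):
    print(point)
```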

visualize_marker_gene_bubbleplot_per_cell_cluster(adata, cell_cluster_id='cell_clusters', gene_cluster_id='gene_clusters', gene_embedding_id='gene_embedding', prefix=None, top_k=5, max_point_size=300, palette='viridis', ax=None)[source]

Visualize marker gene expression as a bubble plot, where the size of the bubbles represents the percent of non-zero counts and the color represents the mean marker expression. The top_k genes are selected for each gene cluster. Gene names are followed by the gene cluster that each gene corresponds to.

Parameters:
  • self (object) -- The instance of the class containing the data.

  • adata (anndata.AnnData) -- An AnnData object containing single-cell gene expression data.

  • cell_cluster_id (str) -- The column name representing cell clusters in adata.obs. (default is 'cell_clusters')

  • gene_cluster_id (str) -- The column name representing gene clusters in adata.var. (default is 'gene_clusters')

  • gene_embedding_id (str) -- The identifier for the gene embedding used for selecting top markers. (default is 'gene_embedding')

  • prefix (str) -- The string utilized when adding SCOTCH data to anndata. Use instead of gene_cluster_id and gene_embedding_id. (default is None)

  • top_k (int, optional) -- The number of top markers to consider per gene cluster (default is 5).

  • max_point_size (int, optional) -- The maximum size of the bubbles in the plot (default is 300).

  • palette (str, optional) -- The color palette used for the plot (default is 'viridis').

  • ax (matplotlib.axes.Axes, optional) -- The matplotlib.axes.Axes object to plot on. If None, a new figure is generated and returned. (default is None)

Returns:

A matplotlib.figure.Figure object containing the bubble plot if ax is not passed; otherwise None.

Return type:

matplotlib.figure.Figure or None
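
The two bubble quantities, percent of non-zero counts and mean expression, can be sketched for a single gene as follows. The function marker_bubble_stats is hypothetical and takes plain lists instead of an AnnData object. Averaging over all cells in a cluster, zeros included, is an assumption.

```python
def marker_bubble_stats(expression, cell_clusters):
    """Return per-cluster bubble statistics for one gene.

    expression[i] is the gene's count in cell i and cell_clusters[i] is that
    cell's cluster label. Assumption: the mean includes zero-count cells.
    """
    grouped = {}
    for value, cluster in zip(expression, cell_clusters):
        grouped.setdefault(cluster, []).append(value)

    stats = {}
    for cluster, values in grouped.items():
        nonzero = sum(1 for v in values if v != 0)
        stats[cluster] = {
            "pct_nonzero": 100.0 * nonzero / len(values),   # bubble size
            "mean_expression": sum(values) / len(values),   # bubble color
        }
    return stats


counts = [0, 2, 4, 0, 0, 3]
clusters = ["a", "a", "a", "b", "b", "b"]
for cluster, values in marker_bubble_stats(counts, clusters).items():
    print(cluster, values)
```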

visualize_marker_gene_bubbleplot_per_gene_cluster(adata, gene_cluster_id='gene_clusters', gene_embedding_id='gene_embedding', prefix=None, top_k=5, max_point_size=300, palette='viridis', ax=None)[source]

Visualize marker gene expression as a bubble plot, where the size of the bubbles represents the percent of non-zero counts and the color represents the mean marker expression. The top_k genes are selected for each gene cluster. Gene names are followed by the gene cluster that each gene corresponds to.

Parameters:
  • self (object) -- The instance of the class containing the data.

  • adata (anndata.AnnData) -- An AnnData object containing single-cell gene expression data.

  • gene_cluster_id (str, optional) -- The identifier for the gene cluster used for selecting top markers.

  • gene_embedding_id (str, optional) -- The identifier for the gene embedding used for selecting top markers.

  • prefix (str, optional) -- The prefix string used when adding SCOTCH data to anndata object.

  • top_k (int, optional) -- The number of top markers to consider per gene cluster (default is 5).

  • max_point_size (int, optional) -- The maximum size of the bubbles in the plot (default is 300).

  • palette (str, optional) -- The color palette used for the plot (default is 'viridis').

  • ax (matplotlib.axes.Axes, optional) -- A matplotlib.axes.Axes object on which to draw the bubble plot. If not provided, a new figure is generated.

Returns:

A matplotlib.figure.Figure object containing the bubble plot.

Return type:

matplotlib.figure.Figure