NMTF module
- class NMTF.NMTF(verbose=True, max_iter=100, seed=1001, term_tol=1e-05, max_l_u=0, max_l_v=0, max_a_u=0, max_a_v=0, k1=2, k2=2, var_lambda=False, var_alpha=False, shape_param=10, mid_epoch_param=5, init_style='random', save_clust=False, draw_intermediate_graph=False, save_intermediate=False, track_objective=False, kill_factors=False, device='cpu', out_path=None, legacy=False)
Bases: object
Base class for the NMTF model. Provides minimal support functionality and returns the factorized matrices.
- Parameters:
k1 (int, optional) -- Number of components for the U matrix. (Default: 2)
k2 (int, optional) -- Number of components for the V matrix. (Default: 2)
verbose (bool, optional) -- If True, displays progress messages. (Default: True)
max_iter (int, optional) -- Maximum number of iterations for optimization. (Default: 100)
seed (int, optional) -- Seed for random number generation. (Default: 1001)
term_tol (float, optional) -- Convergence tolerance, measured as the relative change in error between iterations. (Default: 1e-5)
max_l_u (float, optional) -- Maximum orthogonality regularization weight (lambda) for the U matrix. (Default: 0)
max_l_v (float, optional) -- Maximum orthogonality regularization weight (lambda) for the V matrix. (Default: 0)
max_a_u (float, optional) -- Maximum sparsity regularization weight (alpha) for the U matrix. (Default: 0)
max_a_v (float, optional) -- Maximum sparsity regularization weight (alpha) for the V matrix. (Default: 0)
var_lambda (bool, optional) -- If True, lambda increases over epochs following a sigmoid schedule. (Default: False)
var_alpha (bool, optional) -- If True, alpha increases over epochs following a sigmoid schedule. (Default: False)
shape_param (float, optional) -- Controls the steepness of the sigmoid schedule for both alpha and lambda. (Default: 10)
mid_epoch_param (int, optional) -- Epoch at which the sigmoid schedule reaches its midpoint, i.e. half of its maximum value. (Default: 5)
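init_style (str, optional) -- Initialization method for the factors; either "nnsvd" or "random". (Default: "random")
save_clust (bool, optional) -- If True, saves cluster assignments after every iteration. (Default: False)
draw_intermediate_graph (bool, optional) -- If True, draws intermediate graphs of the factors during fitting; these frames are required by write_gif. (Default: False)
save_intermediate (bool, optional) -- If True, saves intermediate values of U, S, and V during fitting. (Default: False)
track_objective (bool, optional) -- If True, tracks the objective function. (Default: False)
kill_factors (bool, optional) -- If True, halts updates if factor values go to zero. (Default: False)
device (str, optional) -- Device for computation, either "cpu" or "cuda". (Default: "cpu")
out_path (str, optional) -- Path to save output files. (Default: None)
legacy (bool, optional) -- If True, fit uses the legacy update rules (update) instead of the unit-based rules (update_unit). (Default: False)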
- Return type:
NMTF object
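When var_lambda or var_alpha is True, the corresponding weight presumably ramps up toward its maximum (max_l_u, max_l_v, max_a_u, or max_a_v) over the epochs. The sketch below shows one plausible logistic schedule; the function name sigmoid_schedule and the exact formula max_value / (1 + exp(-shape_param * (epoch - mid_epoch_param))) are assumptions, not the library's code. Under this form the weight is exactly half its maximum at mid_epoch_param, consistent with the description above, and a larger shape_param gives a steeper ramp.

```python
import math


def sigmoid_schedule(epoch, max_value, shape_param=10.0, mid_epoch_param=5):
    """Hypothetical logistic ramp from ~0 up to max_value, centred on mid_epoch_param.

    Assumed form: max_value / (1 + exp(-shape_param * (epoch - mid_epoch_param))).
    """
    z = shape_param * (epoch - mid_epoch_param)
    # Split on the sign of z so math.exp never overflows for epochs far from the midpoint.
    if z >= 0:
        return max_value / (1.0 + math.exp(-z))
    e = math.exp(z)
    return max_value * e / (1.0 + e)


# Example: lambda for U at each of the first 10 epochs, with max_l_u = 0.5.
lambdas = [sigmoid_schedule(epoch, 0.5) for epoch in range(10)]
```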
- assign_X_data(X)
Adds a Torch data object to SCOTCH. The input X must be a two-dimensional, non-negative Torch tensor.
- Parameters:
X (torch.Tensor) -- Torch data object to add to SCOTCH. Must be a two-dimensional, non-negative Torch tensor.
- assign_cluster()
Assign clusters based on the lower-dimensional embedding matrices U and V.
This method assigns clusters by taking the argmax along the appropriate dimensions of the lower-dimensional embedding matrices U and V. Specifically, it assigns clusters to each data point based on the maximum value in the corresponding row of U (for the U assignments) and the maximum value in the corresponding column of V (for the V assignments).
The cluster assignments are stored in U_assign and V_assign.
- Returns:
None
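A minimal sketch of the argmax rule described above, using NumPy arrays in place of Torch tensors (torch.argmax with dim=1 and dim=0 would be the PyTorch equivalents). The function name assign_clusters is hypothetical; it assumes U has shape (n, k1) and V has shape (k2, m), so that rows of U and columns of V index the data points.

```python
import numpy as np


def assign_clusters(U, V):
    """Hard cluster labels: argmax over each row of U and each column of V."""
    U_assign = np.argmax(U, axis=1)  # one label per row of U (length n)
    V_assign = np.argmax(V, axis=0)  # one label per column of V (length m)
    return U_assign, V_assign
```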
- fit()
Fits the model to the data.
This method runs the full optimization. It initializes the factors, normalizes and scales U and V, updates S, and sets up objective-function tracking. It then iterates the NMTF updates, updating the model's factors at each iteration until convergence.
Steps:
Initializes the factors (U, V, and S).
Normalizes and scales the U and V factors.
Updates the S matrix.
Sets up tracking of the objective function.
Begins the NMTF optimization algorithm.
- During each iteration:
Updates U, V, and S using the specified update method (legacy or unit-based).
Calculates the objective value.
Optionally prints detailed information about the iteration, including time, objective value, and reconstruction error.
Optionally saves intermediate values of U, S, and V.
Optionally tracks cluster convergence using the Jaccard Index for both U and V assignments.
Optionally visualizes and saves intermediate graphical representations of the factors.
Stops when the relative error falls below the tolerance term_tol, or when max_iter iterations have been run.
- Returns:
None
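The stopping rule in the steps above can be sketched as a generic loop. This is an illustration, not SCOTCH's implementation: step and error are hypothetical callables standing in for one update (update or update_unit) and the reconstruction error, and the relative error is assumed to be |previous - current| / previous.

```python
def run_until_converged(step, error, max_iter=100, term_tol=1e-5):
    """Call step() until the relative change in error() drops below term_tol.

    Returns the error history, starting with the error before the first step.
    """
    prev = error()
    history = [prev]
    for _ in range(max_iter):
        step()
        curr = error()
        history.append(curr)
        # Relative change between consecutive iterations (guarded against prev == 0).
        rel = abs(prev - curr) / max(abs(prev), 1e-12)
        if rel < term_tol:
            break
        prev = curr
    return history
```

A constant error stops the loop after one iteration, while an error that keeps changing by a large fraction runs until max_iter.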
- print_USV(file_pre)
Write the lower-dimensional matrices (U, V, and S) to tab-delimited text files.
This method saves the U, V, and S matrices to text files with names based on the provided prefix. The matrices are saved in tab-delimited format and will be named file_pre_U.txt, file_pre_V.txt, and file_pre_S.txt.
- Parameters:
file_pre (str) -- Prefix to prepend to the file names.
- Returns:
None
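A minimal sketch of the file layout described above, using numpy.savetxt with a tab delimiter. The function name write_usv is hypothetical, and it assumes the factors are already NumPy arrays (a Torch tensor on "cuda" would first need .cpu().numpy()).

```python
import numpy as np


def write_usv(U, S, V, file_pre):
    """Write U, S, V as tab-delimited text files named <file_pre>_U.txt, etc."""
    paths = {}
    for name, mat in (("U", U), ("S", S), ("V", V)):
        path = f"{file_pre}_{name}.txt"
        np.savetxt(path, np.asarray(mat), delimiter="\t")
        paths[name] = path
    return paths
```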
- print_output(out_path)
Write output files related to the factorization and clustering results.
This method writes multiple output files, including the lower-dimensional matrices (U, S, V), terms associated with the objective function (e.g., reconstruction error, lambda regularization terms), and the assignment of U and V at every iteration. It also tracks the stepwise convergence of cluster assignments.
- The output files include:
reconstruction_error.txt: The reconstruction error over iterations.
lU_error.txt: The lambda regularization error for U.
lV_error.txt: The lambda regularization error for V.
relative_error.txt: The relative error over iterations.
U_assign.txt: The U assignments at each iteration (if save_clust is enabled).
V_assign.txt: The V assignments at each iteration (if save_clust is enabled).
V_JI.txt: The Jaccard Index for V assignments (if save_clust is enabled).
U_JI.txt: The Jaccard Index for U assignments (if save_clust is enabled).
- Parameters:
out_path (str) -- The path where the output files will be saved.
- Returns:
None
- recluster_V(linkage_type='average', dist_metric='euclidean')
Clusters the V matrix using hierarchical clustering, with the specified linkage type and distance metric. Afterward, it reapplies SCOTCH based on the cluster representations to remove overly redundant factors from S.
This process involves performing hierarchical clustering on the V matrix to group similar factors and reduce redundancy. SCOTCH is then reapplied to the clustered data to improve the factorization.
- Parameters:
linkage_type (str) -- The type of linkage method to use for hierarchical clustering. Must be one of the following: 'single', 'complete', 'average', or 'ward'. Default is 'average'.
dist_metric (str or int) -- The distance metric used for calculating pairwise distances in clustering. It can be one of the following: 'cosine', 'euclidean', 'city_block', 'chebyshev', or an integer for a p-metric. Default is 'euclidean'.
- Returns:
None
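The hierarchical-clustering step can be sketched with SciPy under stated assumptions: the rows of V are treated as the factors to group, the tree is cut at a hypothetical distance threshold, the documented name 'city_block' is mapped to SciPy's 'cityblock', and an integer dist_metric is read as the p of a Minkowski distance. The function name group_redundant_factors is hypothetical, and reapplying SCOTCH to the merged factors is not shown. Note that 'ward' linkage is only meaningful with Euclidean distances.

```python
import numpy as np
from scipy.cluster.hierarchy import fcluster, linkage
from scipy.spatial.distance import pdist


def group_redundant_factors(V, linkage_type="average", dist_metric="euclidean", threshold=0.5):
    """Group near-duplicate rows (factors) of V; returns one 1-based cluster label per row."""
    V = np.asarray(V, dtype=float)
    if isinstance(dist_metric, int):
        # Integer metric: Minkowski (p-norm) distance with that p.
        dists = pdist(V, metric="minkowski", p=dist_metric)
    else:
        metric = {"city_block": "cityblock"}.get(dist_metric, dist_metric)
        dists = pdist(V, metric=metric)
    # linkage accepts the condensed distance vector produced by pdist.
    Z = linkage(dists, method=linkage_type)
    # Factors merged below the threshold share a label and are candidates for merging.
    return fcluster(Z, t=threshold, criterion="distance")
```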
- save_cluster()
Save cluster assignments and errors for each iteration of the algorithm.
This method initializes tensors to store the cluster assignments for both U and V matrices at each iteration of the algorithm. It also initializes tensors for the Jaccard Index (JI) for both U and V and tracks the relative error over iterations.
- Steps:
Initializes tensors for storing U cluster assignments (U_assign) and the Jaccard Index (U_JI).
Initializes tensors for storing V cluster assignments (V_assign) and the Jaccard Index (V_JI).
Initializes a tensor to store the relative error over iterations (relative_error).
- Returns:
None
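The documentation does not say which Jaccard variant is used to compare cluster assignments between iterations. One common choice is the pair-counting Jaccard index: among all pairs of points that share a cluster in at least one of the two assignments, the fraction that share a cluster in both. The sketch below implements that variant as an assumption; the function name pair_jaccard is hypothetical. It is invariant to relabeling, so a convergent run would show values approaching 1.

```python
import numpy as np


def pair_jaccard(labels_a, labels_b):
    """Pair-counting Jaccard index between two hard cluster assignments."""
    a = np.asarray(labels_a)
    b = np.asarray(labels_b)
    same_a = a[:, None] == a[None, :]  # True where points i and j share a cluster in a
    same_b = b[:, None] == b[None, :]
    iu = np.triu_indices(len(a), k=1)  # each unordered pair (i < j) once
    sa, sb = same_a[iu], same_b[iu]
    either = np.sum(sa | sb)
    if either == 0:
        return 1.0  # both assignments are all singletons
    return float(np.sum(sa & sb) / either)
```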
- update()
Defines one update step for the U, V, and S factors.
This method updates the U, V, and S matrices in one iteration by performing the necessary operations for each matrix, including applying regularization, sparsity constraints, and other updates to ensure the factors are optimized. It also updates the residual matrix (R) as part of the optimization process.
- Steps:
Updates the U matrix using the '_update_U' method.
Updates the P matrix.
If lU or aU is greater than 0, recalculates the residual matrix R.
Updates the V matrix using the '_update_V' method.
Updates the Q matrix.
Recalculates the residual matrix R if necessary.
Updates the S matrix.
Normalizes and scales U and V matrices.
Re-updates the P and Q matrices.
- Returns:
None
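For orientation, the sketch below shows the standard multiplicative updates for an unregularized tri-factorization X ≈ U S V under the Frobenius loss, with U of shape (n, k1), S of shape (k1, k2), and V of shape (k2, m). It deliberately omits the orthogonality (lambda) and sparsity (alpha) terms, the residual R, the P and Q matrices, and the normalization step, so it is a simplified stand-in rather than SCOTCH's update; the function name nmtf_update and the small EPS guard are assumptions.

```python
import numpy as np

EPS = 1e-10  # avoids division by zero in the multiplicative ratios


def nmtf_update(X, U, S, V):
    """One round of multiplicative updates for X ~= U @ S @ V (no regularization)."""
    SV = S @ V  # (k1, m)
    U = U * (X @ SV.T) / (U @ SV @ SV.T + EPS)
    US = U @ S  # (n, k2), recomputed with the new U
    V = V * (US.T @ X) / (US.T @ US @ V + EPS)
    S = S * (U.T @ X @ V.T) / (U.T @ U @ S @ V @ V.T + EPS)
    return U, S, V
```

Because every factor is multiplied by a ratio of non-negative quantities, non-negative starting factors stay non-negative.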
- update_unit()
Defines one update step for U, V, and S, using the unit rules.
This method updates the U, V, and S matrices in one iteration using the unit-based update rules, applying regularization, sparsity constraints, and the other updates according to the unit rules.
- Steps:
Updates the U matrix using the '_update_U_unit' method.
Updates the P matrix.
Updates the V matrix using the '_update_V_unit' method.
Updates the Q matrix.
Updates the S matrix.
Re-updates the P and Q matrices.
- Returns:
None
- visualize_clusters(cmap='viridis', interp='nearest', max_x=1)
Visualizes the factors from the NMTF model.
This function generates a visualization of the factors resulting from the NMTF model. It supports customizing the color scheme, interpolation method, and the scaling of the visualization.
- Parameters:
cmap (str, optional) -- The colormap to use for the visualization. Default is 'viridis'.
interp (str, optional) -- The interpolation method for rendering. Default is 'nearest'.
max_x (float, optional) -- Upper limit of the color scale, as a fraction in [0, 1] of the maximum value in X. Default is 1.
- Returns:
The matplotlib figure object representing the factor visualization.
- Return type:
matplotlib.figure.Figure
- visualize_clusters_sorted(cmap='viridis', interp='nearest', max_x=1)
Visualizes the clusters by ordering elements of the matrix based on their cluster assignments.
- The function sorts the elements of the matrix by cluster and alternates the color of each cluster between grey and black. This avoids relying on a large color palette and keeps adjacent clusters visually distinct.
- Parameters:
cmap (str, optional) -- The colormap to be used for visualization. Defaults to 'viridis'.
interp (str, optional) -- The interpolation method for rendering the image. Defaults to 'nearest'.
max_x (float, optional) -- Upper limit of the color scale, as a fraction in [0, 1] of the maximum value in X. Default is 1.
- Returns:
A heatmap of the matrix with its elements ordered by cluster assignment.
- Return type:
matplotlib.figure.Figure
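The sorting and alternating-color idea can be sketched without any plotting code. The helper below is hypothetical: it computes a stable ordering that places points of the same cluster next to each other, plus a grey/black band color for each sorted position, alternating from one cluster to the next.

```python
import numpy as np


def cluster_order_and_bands(labels):
    """Return a stable ordering that groups equal labels, plus alternating band colors."""
    labels = np.asarray(labels)
    order = np.argsort(labels, kind="stable")  # indices grouped by cluster label
    sorted_labels = labels[order]
    # Rank of each cluster in sorted order (0, 1, 2, ...), used to alternate colors.
    _, rank = np.unique(sorted_labels, return_inverse=True)
    bands = np.where(rank % 2 == 0, "grey", "black")
    return order, bands
```

With matplotlib, the reordered matrix X[np.ix_(u_order, v_order)] could then be drawn with plt.imshow(..., cmap=cmap, interpolation=interp, vmax=max_x * X.max()), which is one reading of how max_x scales the color range.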
- visualize_factors(cmap='viridis', interp='nearest', max_u=1, max_v=1, max_x=1)
This function generates heatmaps of the NMTF factors U, S, and V alongside X and the reconstructed product, with a configurable colormap and interpolation method.
- Parameters:
cmap (str, optional) -- The colormap to be used for visualization. Default is 'viridis'.
interp (str, optional) -- The interpolation method to be used for image display. Default is 'nearest'.
max_u (float, optional) -- Upper limit of the color scale, as a fraction in [0, 1] of the maximum value in U. Default is 1.
max_v (float, optional) -- Upper limit of the color scale, as a fraction in [0, 1] of the maximum value in V. Default is 1.
max_x (float, optional) -- Upper limit of the color scale, as a fraction in [0, 1] of the maximum value in X. Default is 1.
- Returns:
Heatmaps of U, S, and V, together with X and the product U S V.
- Return type:
matplotlib.figure.Figure
- visualize_factors_sorted(cmap='viridis', interp='nearest', max_u=1, max_v=1, max_x=1)
Like visualize_factors, this function generates heatmaps of the NMTF factors with a configurable colormap and interpolation method, but orders the elements by their cluster assignments before display.
- Parameters:
cmap (str, optional) -- Colormap for the visualization. Default is 'viridis'.
interp (str, optional) -- Interpolation method for image display. Default is 'nearest'.
max_u (float, optional) -- Upper limit of the color scale, as a fraction in [0, 1] of the maximum value in U. Default is 1.
max_v (float, optional) -- Upper limit of the color scale, as a fraction in [0, 1] of the maximum value in V. Default is 1.
max_x (float, optional) -- Upper limit of the color scale, as a fraction in [0, 1] of the maximum value in X. Default is 1.
- Returns:
Heatmaps of U, S, and V, together with X and the product U S V.
- Return type:
matplotlib.figure.Figure
- write_gif(filename='NMTF_fit.gif', fps=5)
Save frames of NMTF fit to a GIF figure.
This method generates and saves a GIF showing the intermediate steps of the NMTF fitting process. The draw_intermediate_graph parameter must be set to True during the fit so that these frames are captured.
- Parameters:
filename (str, optional) -- The file name to save the GIF. Default is "NMTF_fit.gif".
fps (int, optional) -- The desired frames per second for the GIF. Default is 5.
- Returns:
None