# Source code for regdiffusion.trainer

import numpy as np
import pandas as pd
import scipy.sparse as sp
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataset import TensorDataset
from .models import RegDiffusion, RegDiffusionME
from tqdm import tqdm
from .logger import LightLogger
from datetime import datetime
from .grn import GRN
from .evaluator import GRNEvaluator
from .logger import LightLogger
import matplotlib.pyplot as plt
import warnings

def linear_beta_schedule(timesteps, start_noise, end_noise):
    """Return ``timesteps`` betas spaced linearly between two noise levels.

    Both endpoints are rescaled by ``1000 / timesteps`` so that the overall
    amount of noise stays comparable when the number of diffusion steps
    differs from 1,000.

    Args:
        timesteps (int): Number of diffusion steps (length of the output).
        start_noise (float): Beta for the first step at the 1,000-step scale.
        end_noise (float): Beta for the last step at the 1,000-step scale.

    Returns:
        torch.FloatTensor: 1D tensor of length ``timesteps``.
    """
    step_factor = 1000 / timesteps
    first, last = step_factor * start_noise, step_factor * end_noise
    return torch.linspace(first, last, timesteps, dtype=torch.float)

def power_beta_schedule(timesteps, start_noise, end_noise, power=2):
    """Return ``timesteps`` betas following a power curve between two levels.

    A uniform grid on [0, 1] is raised to ``power`` and mapped onto
    ``[start, end]``, where both endpoints are rescaled by
    ``1000 / timesteps`` (as in ``linear_beta_schedule``).

    Args:
        timesteps (int): Number of diffusion steps (length of the output).
        start_noise (float): Beta for the first step at the 1,000-step scale.
        end_noise (float): Beta for the last step at the 1,000-step scale.
        power (float): Exponent applied to the [0, 1] grid. Default: 2.

    Returns:
        torch.FloatTensor: 1D tensor of length ``timesteps``.
    """
    step_factor = 1000 / timesteps
    low = step_factor * start_noise
    high = step_factor * end_noise
    ramp = torch.linspace(0, 1, timesteps, dtype=torch.float).pow(power)
    return low + (high - low) * ramp

class _SparseExpressionDataset(Dataset):
    """Sparse expression rows, densified and normalized one sample at a time.

    The full sparse matrix and the per-cell / per-gene statistics are held by
    reference; ``indices`` selects which rows belong to this split. Only the
    requested row is converted to a dense vector in ``__getitem__``, so the
    full dense normalized matrix (which can be 100+ GB for 1M cells) is
    never built.

    Normalization applied per sample:
        1. Min-max per cell:  (x - cell_min) / cell_range
        2. Z-score per gene:  (x - gene_mean) / gene_std
    """

    def __init__(self, sparse_X, cell_types, cell_min, cell_range,
                 gene_mean, gene_std, indices):
        # Full-size arrays are shared between splits; they are indexed with
        # the original row id, never with the position inside ``indices``.
        self.X = sparse_X
        self.cell_types = cell_types
        self.cell_min, self.cell_range = cell_min, cell_range
        self.gene_mean, self.gene_std = gene_mean, gene_std
        # Row ids (into the full matrix) that make up this split.
        self.indices = indices

    def __len__(self):
        return len(self.indices)

    def _normalized_row(self, row_id):
        """Densify row ``row_id`` as float32 and apply both normalizations."""
        dense = self.X[row_id].toarray().reshape(-1).astype(np.float32)
        scaled = (dense - self.cell_min[row_id]) / self.cell_range[row_id]
        return (scaled - self.gene_mean) / self.gene_std

    def __getitem__(self, idx):
        row_id = self.indices[idx]
        features = torch.from_numpy(self._normalized_row(row_id))
        label = torch.tensor(int(self.cell_types[row_id]), dtype=torch.long)
        return features, label


class RegDiffusionTrainer:
    """
    Initialize and Train a RegDiffusion model.

    For architecture and training details, please refer to our paper.

    > From noise to knowledge: probabilistic diffusion-based neural inference

    You can access the model through `RegDiffusionTrainer.model`.

    Args:
        exp_array (np.ndarray or scipy.sparse matrix): 2D expression matrix.
            If used on single-cell RNAseq, the rows are cells and the columns
            are genes. Data should be log transformed. You may also want to
            remove all non expressed genes. Sparse matrices (e.g. from
            ``adata.X``) are supported and handled memory-efficiently — the
            full dense normalized matrix is never materialized. Cells whose
            values are identical across all genes cannot be min-max scaled
            and are dropped with a warning.
        cell_types (np.ndarray): (Optional) 1D integer array for cell type.
            If you have labels in your cell type, you need to convert them to
            integers. Default is None (all cells share a single type).
        T (int): Total number of diffusion steps. Default: 5,000
        start_noise (float): Minimal noise level (beta) to be added.
            Default: 0.0001
        end_noise (float): Maximal noise level (beta) to be added.
            Default: 0.02
        time_dim (int): Dimension size for the time embedding. Default: 64.
        celltype_dim (int): Dimension size for the cell type embedding.
            Default: 4.
        hidden_dims (list): Dimension sizes for the feature learning layers.
            We use the size of the first layer as the dimension for gene
            embeddings as well. Default: [16, 16, 16].
        init_coef (int): A coefficient to control the value used to
            initialize the adjacency matrix. Here we define the regulatory
            norm as 1 over (number of genes - 1). The model is initialized
            with `init_coef` times the regulatory norm. Default: 5.
        lr_nn (float): Learning rate for the rest of the neural networks
            except the adjacency matrix. Default: 0.001
        lr_adj (float): Learning rate for the adjacency matrix. By default,
            it equals 0.02 * the gene regulatory norm, i.e.
            0.02 / (n_gene - 1).
        weight_decay_nn (float): L2 regularization coef on the rest of the
            neural networks. Default: 0.1.
        weight_decay_adj (float): L2 regularization coef on the adj matrix.
            Default: 0.01.
        sparse_loss_coef (float): L1 regularization coef on the adj matrix.
            Default: 0.25.
        adj_dropout (float): Probability of an edge to be zeroed.
            Default: 0.3.
        batch_size (int): Batch size for training. Default: 128.
        n_steps (int): Total number of training iterations. Default: 1000.
        train_split (float): Train partition. Default: 1.0.
        train_split_seed (int): Random seed for train/val partition.
            Default: 123
        device (str or torch.device): Device where the model is running. For
            example, "cpu", "cuda", "cuda:1", etc. Apple's MPS devices are
            rejected because we observed unreliable training there. Default
            is "cuda"; if CUDA is unavailable, the trainer falls back to CPU.
        compile (boolean): Whether to compile the model before training
            (only applied on CUDA devices). Compiling is a good idea on large
            datasets and often improves inference speed when it works. For
            smaller datasets, eager execution is often good enough.
        evaluator (GRNEvaluator): (Optional) A defined GRNEvaluator if ground
            truth data is available. Evaluation will be done every 100 steps
            by default but you can change this setting through the
            eval_on_n_steps option. Default is None
        eval_on_n_steps (int): If an evaluator is provided, the trainer will
            run evaluation every `eval_on_n_steps` steps. Default: 100.
        logger (LightLogger): (Optional) A LightLogger to log training
            process. The only situation when you need to provide this is when
            you want to save logs from different trainers into the same
            logger. Default is None.
        gradient_accumulation (boolean): Whether to train with gradient
            accumulation. This is useful when the number of genes is
            extremely large.
        use_amp (bool): Whether to use automatic mixed precision (bfloat16)
            during training. This reduces memory usage by storing activations
            in half precision while keeping model parameters in float32.
            Requires a GPU with bfloat16 support (Ampere or newer).
            Default: False.
        memory_efficient (bool): Whether to use the memory-efficient model
            variant (RegDiffusionME). This uses a custom autograd function
            for soft thresholding (saves boolean masks instead of float32
            tensors) and sampled sparse loss (avoids materializing full
            n_gene x n_gene matrix for L1 regularization). Default: False.

    Raises:
        Exception: If ``device`` is an Apple MPS device.
        ValueError: If no cell has varying expression, or if some genes have
            zero variance after per-cell scaling.
    """

    def __init__(
        self, exp_array, cell_types=None, T=5000, start_noise=0.0001,
        end_noise=0.02, time_dim=64, celltype_dim=4, hidden_dims=None,
        init_coef=5, lr_nn=1e-3, lr_adj=None, weight_decay_nn=0.1,
        weight_decay_adj=0.01, sparse_loss_coef=0.25, adj_dropout=0.30,
        batch_size=128, n_steps=1000, train_split=1.0, train_split_seed=123,
        device='cuda', compile=False, evaluator=None, eval_on_n_steps=100,
        logger=None, gradient_accumulation=False, use_amp=False,
        memory_efficient=False
    ):
        # A fresh list per trainer; a list default would be shared by every
        # instance (and by every ``hp`` dict that stores it).
        if hidden_dims is None:
            hidden_dims = [16, 16, 16]
        # torch.device is documented as accepted; the checks below need str.
        device = str(device)

        # Record all constructor arguments as hyper-parameters, excluding the
        # data, the logger, and ``self`` (which would make ``hp`` reference
        # the trainer itself).
        hp = dict(locals())
        for excluded in ('self', 'exp_array', 'cell_types', 'logger'):
            hp.pop(excluded)
        self.hp = hp

        if device.startswith('mps'):
            raise Exception(
                "We noticed unreliable training behavior on "
                "Apple's silicon. Consider using other devices.")
        if device.startswith('cuda') and not torch.cuda.is_available():
            print(
                "You specified cuda as your computing device but apprently",
                "it's not available. Setting device to cpu for now. ")
            device = 'cpu'
        self.device = device
        self.hp['device'] = device

        # Logger ---------------------------------------------------------------
        self.logger = LightLogger() if logger is None else logger
        self.note_id = self.logger.start()

        # Define diffusion schedule --------------------------------------------
        self.betas = linear_beta_schedule(T, start_noise, end_noise)
        self.alphas = 1. - self.betas
        alpha_bars = torch.cumprod(self.alphas, axis=0)
        self.mean_schedule = torch.sqrt(alpha_bars).to(device)
        self.std_schedule = torch.sqrt(1. - alpha_bars).to(device)

        # Prepare Data ---------------------------------------------------------
        n_cell, n_gene = exp_array.shape
        if n_cell < batch_size:
            warnings.warn(
                "Batch size needs to be smaller than the number of cells. "
            )
        if cell_types is None:
            cell_types = np.zeros(n_cell, dtype=int)
        # Accept lists/tuples too; the data preparation uses mask indexing.
        cell_types = np.asarray(cell_types)
        # NOTE(review): counts distinct labels, so labels presumably need to be
        # 0..n_celltype-1 for the model's cell type embedding — confirm.
        self.n_celltype = len(np.unique(cell_types))
        self.n_cell = n_cell
        self.n_gene = n_gene
        self.evaluator = evaluator
        if sp.issparse(exp_array):
            self._prepare_sparse_data(
                exp_array, cell_types, batch_size, train_split,
                train_split_seed)
        else:
            self._prepare_dense_data(
                exp_array, cell_types, batch_size, train_split,
                train_split_seed)

        # Setup Model ----------------------------------------------------------
        gene_reg_norm = 1 / (n_gene - 1)
        ModelClass = RegDiffusionME if memory_efficient else RegDiffusion
        self.model = ModelClass(
            n_gene=n_gene,
            time_dim=time_dim,
            n_celltype=self.n_celltype,
            celltype_dim=celltype_dim,
            hidden_dims=hidden_dims,
            adj_dropout=adj_dropout,
            init_coef=init_coef
        )

        # Setup optimizer ------------------------------------------------------
        if lr_adj is None:
            lr_adj = gene_reg_norm / 50
        self.hp['lr_adj'] = lr_adj
        # The adjacency matrix gets its own learning rate and weight decay;
        # parameters whose name ends with '_nonparam' are not optimized.
        adj_params, non_adj_params = [], []
        for name, param in self.model.named_parameters():
            if name.endswith('adj_A'):
                adj_params.append(param)
            elif not name.endswith('_nonparam'):
                non_adj_params.append(param)
        self.opt = torch.optim.Adam(
            [
                {'params': non_adj_params},
                {'params': adj_params, 'lr': lr_adj,
                 'weight_decay': weight_decay_adj},
            ],
            lr=lr_nn,
            weight_decay=weight_decay_nn,
            betas=[0.9, 0.99]
        )
        self.model.to(device)
        if self.device.startswith('cuda') and compile:
            self.original_model = self.model
            self.model = torch.compile(self.model)
        self.total_time_cost = 0
        self.losses_on_gene = None
        self.model_name = 'RegDiffusionME' if memory_efficient else 'RegDiffusion'

        # AMP setup
        self.use_amp = use_amp
        self.amp_device_type = 'cuda' if device.startswith('cuda') else 'cpu'

    def _set_datasets(self, train_dataset, val_dataset, batch_size):
        """Store train/val datasets and wrap them in dataloaders.

        The training loader samples ``batch_size`` rows with replacement, so
        one pass over it yields exactly one full batch. The validation loader
        iterates the validation set in order.
        """
        self.train_dataset = train_dataset
        train_sampler = torch.utils.data.RandomSampler(
            train_dataset, replacement=True, num_samples=batch_size)
        self.train_dataloader = DataLoader(
            train_dataset, sampler=train_sampler, batch_size=batch_size,
            drop_last=True)
        self.val_dataset = val_dataset
        self.val_dataloader = DataLoader(
            val_dataset, shuffle=False, batch_size=batch_size,
            drop_last=False)

    def _prepare_dense_data(self, exp_array, cell_types, batch_size,
                            train_split, train_split_seed):
        """Normalize a dense expression matrix and build TensorDatasets.

        Each cell (row) is min-max scaled, then each gene (column) is
        z-scored. Cells whose values are constant across genes (zero range)
        are dropped with a warning, matching the sparse path.

        Args:
            exp_array (array-like): 2D (n_cell, n_gene) expression matrix.
            cell_types (array-like): 1D integer cell type per row.
            batch_size (int): Batch size for both dataloaders.
            train_split (float): Fraction of cells used for training.
            train_split_seed (int): Seed for the train/validation split.

        Raises:
            ValueError: If no cell remains after filtering, or if some genes
                have zero variance after per-cell scaling.
        """
        exp_array = np.asarray(exp_array)
        cell_types = np.asarray(cell_types)

        cell_min = exp_array.min(axis=1, keepdims=True)
        cell_range = exp_array.max(axis=1, keepdims=True) - cell_min
        valid_mask = cell_range.ravel() > 0
        n_zero_cells = (~valid_mask).sum()
        if n_zero_cells > 0:
            warnings.warn(
                f'{n_zero_cells} cells are removed from analysis where no genes are expressed.')
            # Drop them: keeping them would produce 0/0 = NaN rows, which
            # turn every gene's mean and std into NaN.
            exp_array = exp_array[valid_mask]
            cell_types = cell_types[valid_mask]
            cell_min = cell_min[valid_mask]
            cell_range = cell_range[valid_mask]
        if exp_array.shape[0] == 0:
            raise ValueError(
                'No cells with varying expression remain after filtering.')

        normalized_X = (exp_array - cell_min) / cell_range
        gene_std = normalized_X.std(0)
        n_zero_genes = (gene_std == 0).sum()
        if n_zero_genes > 0:
            raise ValueError(
                f'{n_zero_genes} genes have 0 variance. Please remove these genes from your data.')
        normalized_X = (normalized_X - normalized_X.mean(0)) / gene_std

        random_state = np.random.RandomState(train_split_seed)
        split_draws = random_state.rand(normalized_X.shape[0])
        train_index = split_draws <= train_split
        val_index = split_draws > train_split

        train_dataset = TensorDataset(
            torch.tensor(normalized_X[train_index], dtype=torch.float32),
            torch.tensor(cell_types[train_index], dtype=torch.long))
        val_dataset = TensorDataset(
            torch.tensor(normalized_X[val_index], dtype=torch.float32),
            torch.tensor(cell_types[val_index], dtype=torch.long))
        self._set_datasets(train_dataset, val_dataset, batch_size)

    def _prepare_sparse_data(self, exp_array, cell_types, batch_size,
                             train_split, train_split_seed):
        """Compute normalization stats from a sparse matrix, then build
        _SparseExpressionDataset objects that normalize on-the-fly per sample.

        The full dense normalized matrix is never materialized. The
        statistics pass densifies rows in chunks whose size is capped in
        bytes, so peak memory stays near (sparse matrix size + one bounded
        dense chunk) rather than (n_cell * n_gene * 4 bytes).

        Raises:
            ValueError: If no cell has varying expression, or if some genes
                have zero variance after per-cell scaling.
        """
        n_cell, n_gene = exp_array.shape
        exp_array = sp.csr_matrix(exp_array)  # ensure CSR for fast row slicing
        cell_types = np.asarray(cell_types)

        # --- Per-cell min/max on sparse row chunks (no densification) ---
        chunk_size = max(1, min(10000, n_cell))
        cell_min = np.empty(n_cell, dtype=np.float64)
        cell_max = np.empty(n_cell, dtype=np.float64)
        for start in range(0, n_cell, chunk_size):
            end = min(start + chunk_size, n_cell)
            chunk = exp_array[start:end]
            cell_min[start:end] = np.asarray(
                chunk.min(axis=1).todense()).ravel()
            cell_max[start:end] = np.asarray(
                chunk.max(axis=1).todense()).ravel()
        cell_range = cell_max - cell_min
        valid_mask = cell_range > 0
        n_zero_cells = (~valid_mask).sum()
        if n_zero_cells > 0:
            warnings.warn(
                f'{n_zero_cells} cells are removed from analysis where no '
                'genes are expressed.')
        n_valid = int(valid_mask.sum())
        if n_valid == 0:
            raise ValueError(
                'No cells with varying expression remain after filtering.')

        # --- Per-gene mean/std of the min-max-normalized data ---
        # This pass densifies rows, so bound the chunk by bytes instead of
        # rows (10,000 float64 rows x 30k genes would already be ~2.4 GB).
        max_chunk_bytes = 256 * 2 ** 20
        dense_rows = max(
            1, min(chunk_size, max_chunk_bytes // (8 * max(n_gene, 1))))
        gene_sum = np.zeros(n_gene, dtype=np.float64)
        gene_sum_sq = np.zeros(n_gene, dtype=np.float64)
        for start in range(0, n_cell, dense_rows):
            end = min(start + dense_rows, n_cell)
            chunk_valid = valid_mask[start:end]
            if not chunk_valid.any():
                continue
            # Only materialize valid rows in this chunk.
            chunk_dense = exp_array[start:end][chunk_valid].toarray().astype(
                np.float64)
            c_min = cell_min[start:end][chunk_valid][:, None]
            c_range = cell_range[start:end][chunk_valid][:, None]
            chunk_norm = (chunk_dense - c_min) / c_range
            gene_sum += chunk_norm.sum(axis=0)
            gene_sum_sq += (chunk_norm ** 2).sum(axis=0)
        # Keep the mean in float64 for the variance: E[x^2] - E[x]^2 cancels,
        # and a float32-rounded mean adds avoidable error.
        gene_mean64 = gene_sum / n_valid
        gene_var = gene_sum_sq / n_valid - gene_mean64 ** 2
        gene_std = np.sqrt(np.maximum(gene_var, 0)).astype(np.float32)
        gene_mean = gene_mean64.astype(np.float32)
        n_zero_genes = (gene_std == 0).sum()
        if n_zero_genes > 0:
            raise ValueError(
                f'{n_zero_genes} genes have 0 variance. Please remove these '
                'genes from your data.')

        # --- Train/validation split (only among valid cells) ---
        valid_indices = np.where(valid_mask)[0]
        random_state = np.random.RandomState(train_split_seed)
        split_vals = random_state.rand(len(valid_indices))
        train_indices = valid_indices[split_vals <= train_split]
        val_indices = valid_indices[split_vals > train_split]

        cell_min_f32 = cell_min.astype(np.float32)
        cell_range_f32 = cell_range.astype(np.float32)

        # --- Build datasets ---
        train_dataset = _SparseExpressionDataset(
            exp_array, cell_types, cell_min_f32, cell_range_f32,
            gene_mean, gene_std, train_indices)
        val_dataset = _SparseExpressionDataset(
            exp_array, cell_types, cell_min_f32, cell_range_f32,
            gene_mean, gene_std, val_indices)
        self._set_datasets(train_dataset, val_dataset, batch_size)
[docs] @torch.no_grad() def forward_pass(self, x_0, t): """ Forward diffusion process Args: x_0 (torch.FloatTensor): Torch tensor for expression data. Rows are cells and columns are genes t (torch.LongTensor): Torch tensor for diffusion time steps. """ noise = torch.randn_like(x_0) mean_coef = self.mean_schedule.gather(dim=-1, index=t) std_coef = self.std_schedule.gather(dim=-1, index=t) x_t = mean_coef.unsqueeze(-1) * x_0 + std_coef.unsqueeze(-1) * noise return x_t, noise
def train(self, n_steps=None): if self.hp['gradient_accumulation']: self._train_with_gradient_accumulation(n_steps=None) else: self._train_normal(n_steps=None) def _train_normal(self, n_steps=None): """ Train the initialized model for a number of steps. Args: n_steps (int): Number of steps to train. If not provided, it will train the model by the n_steps sepcified in class initialization. Please read our paper to see how to identify the converge point. """ start_time = datetime.now() eval_steps = self.hp['eval_on_n_steps'] if n_steps is None: n_steps = self.hp['n_steps'] sampled_adj = self.model.get_sampled_adj_() with tqdm(range(n_steps)) as pbar: for epoch in pbar: epoch_loss = [] for step, batch in enumerate(self.train_dataloader): x_0, ct = batch x_0 = x_0.to(self.device) ct = ct.to(self.device) self.opt.zero_grad() t = torch.randint( 0, self.hp['T'], (x_0.shape[0],), device=self.device ).long() x_noisy, noise = self.forward_pass(x_0, t) with torch.autocast( device_type=self.amp_device_type, dtype=torch.bfloat16, enabled=self.use_amp ): z = self.model(x_noisy, t, ct) loss_ = F.mse_loss(noise, z, reduction='none') loss = loss_.mean() if hasattr(self.model, 'get_sampled_sparse_loss'): loss_sparse = self.model.get_sampled_sparse_loss() * self.hp['sparse_loss_coef'] else: adj_m = self.model.get_adj_() loss_sparse = adj_m.mean() * self.hp['sparse_loss_coef'] if epoch > 10: loss = loss + loss_sparse loss.backward() self.opt.step() epoch_loss.append(loss.item()) train_loss = np.mean(epoch_loss) sampled_adj_new = self.model.get_sampled_adj_() adj_diff = ( sampled_adj_new - sampled_adj ).mean().item()*(self.n_gene-1) sampled_adj = sampled_adj_new pbar.set_description( f'Training loss: {train_loss:.3f}, Change on Adj: {adj_diff:.3f}') epoch_log = {'train_loss': train_loss, 'adj_change': adj_diff} if epoch % eval_steps == eval_steps - 1: if self.evaluator is not None: eval_result = self.evaluator.evaluate( self.model.get_adj() ) for k in eval_result.keys(): epoch_log[k] = 
eval_result[k] if self.hp['train_split'] < 1: with torch.no_grad(): val_epoch_loss = [] for step, batch in enumerate(self.val_dataloader): x_0, ct = batch x_0 = x_0.to(self.device) ct = ct.to(self.device) t = torch.randint( 0, self.hp['T'], (x_0.shape[0],), device=self.device).long() x_noisy, noise = self.forward_pass(x_0, t) with torch.autocast( device_type=self.amp_device_type, dtype=torch.bfloat16, enabled=self.use_amp ): z = self.model(x_noisy, t, ct) step_val_loss = F.mse_loss( noise, z, reduction='mean').item() val_epoch_loss.append(step_val_loss) epoch_log['val_loss'] = np.mean(val_epoch_loss) self.logger.log(epoch_log) self.losses_on_gene = loss_.detach().mean(0).cpu().numpy() self.total_time_cost += int( (datetime.now() - start_time).total_seconds()) return None def _train_with_gradient_accumulation(self, n_steps=None): """ Train the initialized model using gradient accumulation. For each batch, the gradient is computed for each sample individually and accumulated before the optimizer step. This reduces memory but increases training time. Args: n_steps (int): Number of steps to train. If not provided, it will train the model by the n_steps specified in class initialization. 
""" start_time = datetime.now() eval_steps = self.hp['eval_on_n_steps'] if n_steps is None: n_steps = self.hp['n_steps'] batch_size = self.hp['batch_size'] sampled_adj = self.model.get_sampled_adj_() # This will hold the per-gene losses from the final step final_losses_on_gene = torch.zeros(self.n_gene, device=self.device) with tqdm(range(n_steps)) as pbar: for step_idx in pbar: # Fetch a single batch for the current step x_0_batch, ct_batch = next(iter(self.train_dataloader)) x_0_batch = x_0_batch.to(self.device) ct_batch = ct_batch.to(self.device) # Zero gradients before starting the accumulation for the batch self.opt.zero_grad() batch_mse_loss = 0.0 accumulated_loss_on_gene = torch.zeros(self.n_gene, device=self.device) # Loop over each sample in the batch to accumulate gradients for i in range(batch_size): x_0 = x_0_batch[i:i+1] ct = ct_batch[i:i+1] # Generate timestep for the single sample t = torch.randint(0, self.hp['T'], (1,), device=self.device).long() x_noisy, noise = self.forward_pass(x_0, t) with torch.autocast( device_type=self.amp_device_type, dtype=torch.bfloat16, enabled=self.use_amp ): z = self.model(x_noisy, t, ct) # Calculate MSE loss, get per-gene loss with reduction='none' loss_ = F.mse_loss(noise, z, reduction='none') # Accumulate per-gene loss values for the last step with torch.no_grad(): accumulated_loss_on_gene += loss_.float().squeeze() # Calculate mean loss for the sample loss_sample = loss_.mean() batch_mse_loss += loss_sample.item() # Scale loss for gradient accumulation scaled_loss = loss_sample / batch_size scaled_loss.backward() # Handle sparse loss once per effective batch with torch.autocast( device_type=self.amp_device_type, dtype=torch.bfloat16, enabled=self.use_amp ): if hasattr(self.model, 'get_sampled_sparse_loss'): loss_sparse = self.model.get_sampled_sparse_loss() * self.hp['sparse_loss_coef'] else: adj_m = self.model.get_adj_() loss_sparse = adj_m.mean() * self.hp['sparse_loss_coef'] if step_idx > 10: 
loss_sparse.backward() # Perform the optimizer step after accumulating all gradients self.opt.step() # Logging and evaluation logic train_loss = (batch_mse_loss / batch_size) if step_idx > 10: train_loss += loss_sparse.item() sampled_adj_new = self.model.get_sampled_adj_() adj_diff = (sampled_adj_new - sampled_adj).mean().item() * (self.n_gene - 1) sampled_adj = sampled_adj_new pbar.set_description( f'Training loss: {train_loss:.3f}, Change on Adj: {adj_diff:.3f}') epoch_log = {'train_loss': train_loss, 'adj_change': adj_diff} if step_idx % eval_steps == eval_steps - 1: if self.evaluator is not None: eval_result = self.evaluator.evaluate(self.model.get_adj()) for k in eval_result.keys(): epoch_log[k] = eval_result[k] if self.hp['train_split'] < 1: with torch.no_grad(): val_epoch_loss = [] for step, batch in enumerate(self.val_dataloader): x_0, ct = batch x_0 = x_0.to(self.device) ct = ct.to(self.device) t = torch.randint( 0, self.hp['T'], (x_0.shape[0],), device=self.device).long() x_noisy, noise = self.forward_pass(x_0, t) with torch.autocast( device_type=self.amp_device_type, dtype=torch.bfloat16, enabled=self.use_amp ): z = self.model(x_noisy, t, ct) step_val_loss = F.mse_loss( noise, z, reduction='mean').item() val_epoch_loss.append(step_val_loss) epoch_log['val_loss'] = np.mean(val_epoch_loss) self.logger.log(epoch_log) final_losses_on_gene = accumulated_loss_on_gene # Save from the last step # Set the final per-gene losses, averaged over the batch self.losses_on_gene = (final_losses_on_gene / batch_size).detach().cpu().numpy() self.total_time_cost += int((datetime.now() - start_time).total_seconds()) return None
[docs] def training_curves(self): """ Plot out the training curves on `train_loss` and `adj_change`. Check out our paper for how to use `adj_change` to identify the convergence point. """ log_df = self.logger.to_df() if 'train_loss' in log_df: figure, axes = plt.subplots(1, 2, figsize=(8, 3)) axes[0].plot(log_df['train_loss']) axes[0].set_xlabel('Steps') axes[0].set_ylabel('Training Loss') axes[1].plot(log_df['adj_change'][1:]) axes[1].set_xlabel('Steps') axes[1].set_ylabel('Amount of Change in Adj. Matrix') plt.show() else: print( 'Training log and Adj Change are not available. Train your ', 'model using the .train() method.')
[docs] def get_grn(self, gene_names, tf_names=None, top_gene_percentile=None): """ Obtain a GRN object. You need to provide the genes names. Args: gene_names (np.ndarray): An array of names of all genes. The order of genes should be the same as the order used in your expression table. tf_names (np.ndarray):An array of names of all transcriptional factors. The order of genes should be the same as the order used in your expression table. top_gene_percentile (int): If provided, we will set the value on weak links to be zero. It is useful if you want to save the regulatory relationship in a GRN object as a sparse matrix. """ adj = self.model.get_adj() return GRN(adj, gene_names, tf_names, top_gene_percentile)
[docs] def get_adj(self): """ Obtain the adjacency matrix. The values in this adjacency matix has been scaled using regulatory norm. You may expect strong links to go beyond 5 or 10 in most cases. """ return self.model.get_adj()