import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data.dataset import TensorDataset
from .models import RegDiffusion
from tqdm import tqdm
from .logger import LightLogger
from datetime import datetime
from .grn import GRN
from .evaluator import GRNEvaluator
from .logger import LightLogger
import matplotlib.pyplot as plt
import warnings
def linear_beta_schedule(timesteps, start_noise, end_noise):
    """Build a linearly increasing beta (noise level) schedule.

    Both endpoints are multiplied by ``1000 / timesteps`` before the
    interpolation, so the given ``start_noise``/``end_noise`` are the
    values used when ``timesteps == 1000``.

    Args:
        timesteps (int): Number of diffusion steps (length of the output).
        start_noise (float): Beta at the first step before rescaling.
        end_noise (float): Beta at the last step before rescaling.

    Returns:
        torch.Tensor: 1D float32 tensor of length ``timesteps``.
    """
    rescale = 1000 / timesteps
    first, last = rescale * start_noise, rescale * end_noise
    return torch.linspace(first, last, timesteps, dtype=torch.float)
def power_beta_schedule(timesteps, start_noise, end_noise, power=2):
    """Build a beta schedule that follows a power curve between two endpoints.

    A ramp ``linspace(0, 1, timesteps) ** power`` interpolates between the
    rescaled start and end values. As in ``linear_beta_schedule``, the
    endpoints are first multiplied by ``1000 / timesteps``.

    Args:
        timesteps (int): Number of diffusion steps (length of the output).
        start_noise (float): Beta at the first step before rescaling.
        end_noise (float): Beta at the last step before rescaling.
        power (float): Exponent applied to the [0, 1] ramp. Default: 2.

    Returns:
        torch.Tensor: 1D float32 tensor of length ``timesteps``.
    """
    rescale = 1000 / timesteps
    low = rescale * start_noise
    high = rescale * end_noise
    ramp = torch.linspace(0, 1, timesteps, dtype=torch.float).pow(power)
    return low + (high - low) * ramp
[docs]
class RegDiffusionTrainer:
    """
    Initialize and Train a RegDiffusion model.

    For architecture and training details, please refer to our paper.

    > From noise to knowledge: probabilistic diffusion-based neural inference

    You can access the model through `RegDiffusionTrainer.model`.

    Args:
        exp_array (np.ndarray): 2D numpy array (other array-likes are
            converted with ``np.asarray``). If used on single-cell RNAseq,
            the rows are cells and the columns are genes. Data should be log
            transformed. You may also want to remove all non expressed genes;
            all-zero columns trigger a warning. Each cell (row) is min-max
            scaled and each gene (column) is then standardized; rows or
            columns with zero spread become zeros instead of NaN.
        cell_types (np.ndarray): (Optional) 1D integer array for cell type. If
            you have labels in your cell type, you need to convert them to
            integers. The number of distinct values sizes the cell type
            embedding (presumably labels should be 0..n_celltype-1; verify
            against RegDiffusion). Default is None, meaning all cells share
            a single type.
        T (int): Total number of diffusion steps. Default: 5,000
        start_noise (float): Minimal noise level (beta) to be added. Default:
            0.0001
        end_noise (float): Maximal noise level (beta) to be added. Default:
            0.02
        time_dim (int): Dimension size for the time embedding. Default: 64.
        celltype_dim (int): Dimension size for the cell type embedding.
            Default: 4.
        hidden_dims (list): Dimension sizes for the feature learning layers.
            We use the size of the first layer as the dimension for gene
            embeddings as well. Default: [16, 16, 16].
        init_coef (int): A coefficient to control the value to initialize the
            adjacency matrix. Here we define regulatory norm as 1 over (number
            of genes - 1). The value which we use to initialize the model is
            `init_coef` times of the regulatory norm. Default: 5.
        lr_nn (float): Learning rate for the rest of the neural networks except
            the adjacency matrix. Default: 0.001
        lr_adj (float): Learning rate for the adjacency matrix. By default, it
            equals to 0.02 * gene regulatory norm, which equals 1/(n_gene-1).
        weight_decay_nn (float): L2 regularization coef on the rest of the
            neural networks. Default: 0.1.
        weight_decay_adj (float): L2 regularization coef on the adj matrix.
            Default: 0.01.
        sparse_loss_coef (float): L1 regularization coef on the adj matrix.
            Default: 0.25.
        adj_dropout (float): Probability of an edge to be zeroed. Default: 0.3.
        batch_size (int): Batch size for training. Default: 128.
        n_steps (int): Total number of training iterations. Default: 1000.
        train_split (float): Train partition. Default: 1.0.
        train_split_seed (int): Random seed for train/val partition.
            Default: 123
        device (str or torch.device): Device where the model is running. For
            example, "cpu", "cuda", "cuda:1", and etc. Apple's MPS devices
            are rejected with a ValueError because of unreliable training
            behavior. Default is "cuda" but if CUDA is not available, it
            will switch back to CPU.
        compile (boolean): Whether to compile the model with
            ``torch.compile`` before training (only applied on CUDA devices).
            Compiling the model is a good idea on large datasets and often
            improves inference speed when it works. For smaller datasets,
            eager execution is often good enough.
        evaluator (GRNEvaluator): (Optional) A defined GRNEvaluator if ground
            truth data is available. Evaluation will be done every 100 steps by
            default but you can change this setting through the eval_on_n_steps
            option. Default is None
        eval_on_n_steps (int): If an evaluator is provided, the trainer will
            run evaluation every `eval_on_n_steps` steps. Default: 100.
        logger (LightLogger): (Optional) A LightLogger to log training process.
            The only situation when you need to provide this is when you want
            to save logs from different trainers into the same logger. When
            provided, it is used as-is and `start()` is called on it. Default
            is None, which creates a fresh LightLogger.
    """
    def __init__(
        self, exp_array, cell_types=None,
        T=5000, start_noise=0.0001, end_noise=0.02,
        time_dim=64, celltype_dim=4, hidden_dims=None,
        init_coef=5,
        lr_nn=1e-3, lr_adj=None,
        weight_decay_nn=0.1, weight_decay_adj=0.01,
        sparse_loss_coef=0.25, adj_dropout=0.30,
        batch_size=128, n_steps=1000,
        train_split=1.0, train_split_seed=123,
        device='cuda', compile=False,
        evaluator=None, eval_on_n_steps=100, logger=None
    ):
        # Snapshot the constructor arguments as hyper-parameters. dict() makes
        # a real copy: on CPython < 3.13 locals() is the live frame dict and
        # can be refreshed later by debuggers/tracers. 'self' is dropped to
        # avoid a self -> hp -> self reference cycle.
        hp = dict(locals())
        for key in ('self', 'exp_array', 'cell_types', 'logger'):
            hp.pop(key, None)
        self.hp = hp

        # A None default avoids sharing one mutable list across instances.
        if hidden_dims is None:
            hidden_dims = [16, 16, 16]
        self.hp['hidden_dims'] = hidden_dims

        device = self._resolve_device(device)
        self.device = device
        self.hp['device'] = device

        # Logger ---------------------------------------------------------------
        self.logger = LightLogger() if logger is None else logger
        self.note_id = self.logger.start()

        # Define diffusion schedule
        self.betas = linear_beta_schedule(T, start_noise, end_noise)
        self.alphas = 1. - self.betas
        alpha_bars = torch.cumprod(self.alphas, axis=0)
        self.mean_schedule = torch.sqrt(alpha_bars).to(device)
        self.std_schedule = torch.sqrt(1. - alpha_bars).to(device)

        # Prepare Data ---------------------------------------------------------
        exp_array = np.asarray(exp_array)
        if (exp_array.sum(axis=0) == 0).any():
            warnings.warn(
                "Some columns in the exp_array contains all zero values, "
                "which often causes trouble in inference. Please consider "
                "removing these columns before continuing. "
            )
        if cell_types is None:
            cell_types = np.zeros(exp_array.shape[0], dtype=int)
        else:
            # Boolean-mask indexing below requires an ndarray, not a list.
            cell_types = np.asarray(cell_types)
        self.n_celltype = len(np.unique(cell_types))
        n_cell, n_gene = exp_array.shape
        self.n_cell = n_cell
        self.n_gene = n_gene
        self.evaluator = evaluator

        ## Normalize data
        normalized_X = self._normalize_expression(exp_array)

        ## Train/validation split
        random_state = np.random.RandomState(train_split_seed)
        train_val_split = random_state.rand(normalized_X.shape[0])
        train_index = train_val_split <= train_split
        val_index = train_val_split > train_split
        x_tensor_train = torch.tensor(
            normalized_X[train_index, ], dtype=torch.float32)
        celltype_tensor_train = torch.tensor(
            cell_types[train_index], dtype=int)
        x_tensor_val = torch.tensor(
            normalized_X[val_index, ], dtype=torch.float32)
        celltype_tensor_val = torch.tensor(cell_types[val_index], dtype=int)

        ## Setup dataset and dataloader
        self.train_dataset = torch.utils.data.TensorDataset(
            x_tensor_train, celltype_tensor_train
        )
        # Bootstrap sampler: each pass over train_dataloader yields exactly
        # one batch of `batch_size` cells drawn with replacement, so one
        # "epoch" in train() is one optimization step.
        train_sampler = torch.utils.data.RandomSampler(
            self.train_dataset, replacement=True, num_samples=batch_size)
        self.train_dataloader = torch.utils.data.DataLoader(
            self.train_dataset,
            sampler=train_sampler,
            batch_size=batch_size,
            drop_last=True)

        self.val_dataset = torch.utils.data.TensorDataset(
            x_tensor_val, celltype_tensor_val
        )
        self.val_dataloader = torch.utils.data.DataLoader(
            self.val_dataset,
            shuffle=False,
            batch_size=batch_size,
            drop_last=False)

        # Setup Model ----------------------------------------------------------
        gene_reg_norm = 1/(n_gene-1)
        self.model = RegDiffusion(
            n_gene=n_gene,
            time_dim=time_dim,
            n_celltype=self.n_celltype,
            celltype_dim=celltype_dim,
            hidden_dims=hidden_dims,
            adj_dropout=adj_dropout,
            init_coef=init_coef
        )

        # Setup optimizer ------------------------------------------------------
        if lr_adj is None:
            lr_adj = gene_reg_norm/50
        self.hp['lr_adj'] = lr_adj
        # Two parameter groups so the adjacency matrix gets its own learning
        # rate and weight decay; parameters named '*_nonparam' are excluded.
        adj_params = []
        non_adj_params = []
        for name, param in self.model.named_parameters():
            if name.endswith('adj_A'):
                adj_params.append(param)
            elif not name.endswith('_nonparam'):
                non_adj_params.append(param)
        self.opt = torch.optim.Adam(
            [{'params': non_adj_params}, {'params': adj_params}],
            lr=lr_nn,
            weight_decay=weight_decay_nn, betas=[0.9, 0.99]
        )
        self.opt.param_groups[0]['lr'] = lr_nn
        self.opt.param_groups[1]['lr'] = lr_adj
        self.opt.param_groups[1]['weight_decay'] = weight_decay_adj

        self.model.to(device)
        if self.device.startswith('cuda') and compile:
            self.original_model = self.model
            self.model = torch.compile(self.model)

        self.total_time_cost = 0
        self.losses_on_gene = None
        self.model_name = 'RegDiffusion'

    @staticmethod
    def _resolve_device(device):
        """Normalize ``device`` to a string and apply the CPU fallback.

        Args:
            device (str or torch.device): Requested device.

        Returns:
            str: The device to use. A CUDA request is downgraded to "cpu"
            (with a printed notice) when ``torch.cuda.is_available()`` is
            False.

        Raises:
            ValueError: If an Apple MPS device is requested.
        """
        # str() lets torch.device objects use the same string checks.
        device = str(device)
        if device.startswith('mps'):
            raise ValueError(
                "We noticed unreliable training behavior on "
                "Apple's silicon. Consider using other devices.")
        if device.startswith('cuda') and not torch.cuda.is_available():
            print(
                "You specified cuda as your computing device but apprently",
                "it's not available. Setting device to cpu for now. ")
            device = 'cpu'
        return device

    @staticmethod
    def _normalize_expression(exp_array):
        """Min-max scale every cell (row), then z-score every gene (column).

        Zero ranges (constant cells) and zero standard deviations (constant
        genes, e.g. all-zero columns) are replaced by 1, so those entries
        become 0 instead of NaN. A NaN input would otherwise propagate
        into every gradient.

        Args:
            exp_array (np.ndarray): 2D array, cells x genes.

        Returns:
            np.ndarray: Normalized float array with the same shape.
        """
        cell_min = exp_array.min(axis=1, keepdims=True)
        cell_range = exp_array.max(axis=1, keepdims=True) - cell_min
        cell_range = np.where(cell_range == 0, 1, cell_range)
        scaled = (exp_array - cell_min) / cell_range
        gene_std = scaled.std(0)
        gene_std = np.where(gene_std == 0, 1, gene_std)
        return (scaled - scaled.mean(0)) / gene_std

    @torch.no_grad()
    def forward_pass(self, x_0, t):
        """
        Forward diffusion process

        Args:
            x_0 (torch.FloatTensor): Torch tensor for expression data. Rows are
                cells and columns are genes
            t (torch.LongTensor): Torch tensor for diffusion time steps, one
                per row of ``x_0``.

        Returns:
            tuple: ``(x_t, noise)``, where
            ``x_t = mean_schedule[t] * x_0 + std_schedule[t] * noise`` and
            ``noise`` is standard Gaussian noise shaped like ``x_0``.
        """
        noise = torch.randn_like(x_0)
        mean_coef = self.mean_schedule.gather(dim=-1, index=t)
        std_coef = self.std_schedule.gather(dim=-1, index=t)
        x_t = mean_coef.unsqueeze(-1) * x_0 + std_coef.unsqueeze(-1) * noise
        return x_t, noise

    @torch.no_grad()
    def _validation_loss(self):
        """Mean squared noise-prediction error over the validation split.

        Squared errors are summed across all batches and divided by the total
        element count, so a smaller final batch is weighted correctly.
        Returns NaN when the validation split is empty.
        NOTE(review): the model is not switched to eval mode here (neither
        was it originally), so any train-time stochasticity such as adjacency
        dropout may affect this value.
        """
        total_sq_err = 0.0
        n_elements = 0
        for x_0, ct in self.val_dataloader:
            x_0 = x_0.to(self.device)
            ct = ct.to(self.device)
            t = torch.randint(
                0, self.hp['T'], (x_0.shape[0],),
                device=self.device).long()
            x_noisy, noise = self.forward_pass(x_0, t)
            z = self.model(x_noisy, t, ct)
            total_sq_err += F.mse_loss(noise, z, reduction='sum').item()
            n_elements += noise.numel()
        if n_elements == 0:
            return float('nan')
        return total_sq_err / n_elements

    def train(self, n_steps=None):
        """
        Train the initialized model for a number of steps.

        Each step draws one bootstrap batch, adds noise at random diffusion
        steps and regresses the noise. After 10 warm-up steps of each call,
        an L1 penalty on the adjacency matrix is added. The sampled
        adjacency change, train loss and (every `eval_on_n_steps` steps)
        evaluator metrics / validation loss are logged.

        Args:
            n_steps (int): Number of steps to train. If not provided, it will
                train the model by the n_steps specified in class
                initialization. Please read our paper to see how to identify
                the converge point.
        """
        start_time = datetime.now()
        eval_steps = self.hp['eval_on_n_steps']
        if n_steps is None:
            n_steps = self.hp['n_steps']
        # Per-element loss of the most recent batch; stays None if no step
        # runs (e.g. n_steps == 0).
        loss_ = None
        sampled_adj = self.model.get_sampled_adj_()
        with tqdm(range(n_steps)) as pbar:
            for epoch in pbar:
                epoch_loss = []
                for x_0, ct in self.train_dataloader:
                    x_0 = x_0.to(self.device)
                    ct = ct.to(self.device)
                    self.opt.zero_grad()
                    t = torch.randint(
                        0, self.hp['T'], (x_0.shape[0],),
                        device=self.device
                    ).long()
                    x_noisy, noise = self.forward_pass(x_0, t)
                    z = self.model(x_noisy, t, ct)
                    loss_ = F.mse_loss(noise, z, reduction='none')
                    loss = loss_.mean()
                    adj_m = self.model.get_adj_()
                    loss_sparse = adj_m.mean() * self.hp['sparse_loss_coef']
                    # Sparsity warm-up: the counter restarts on every call.
                    if epoch > 10:
                        loss = loss + loss_sparse
                    loss.backward()
                    self.opt.step()
                    epoch_loss.append(loss.item())
                train_loss = np.mean(epoch_loss)
                sampled_adj_new = self.model.get_sampled_adj_()
                # Mean change of the sampled adjacency, scaled by the
                # regulatory norm 1/(n_gene-1).
                adj_diff = (
                    sampled_adj_new - sampled_adj
                ).mean().item()*(self.n_gene-1)
                sampled_adj = sampled_adj_new
                pbar.set_description(
                    f'Training loss: {train_loss:.3f}, Change on Adj: {adj_diff:.3f}')
                epoch_log = {'train_loss': train_loss, 'adj_change': adj_diff}
                if epoch % eval_steps == eval_steps - 1:
                    if self.evaluator is not None:
                        eval_result = self.evaluator.evaluate(
                            self.model.get_adj()
                        )
                        epoch_log.update(eval_result)
                    if self.hp['train_split'] < 1:
                        epoch_log['val_loss'] = self._validation_loss()
                self.logger.log(epoch_log)
        if loss_ is not None:
            self.losses_on_gene = loss_.detach().mean(0).cpu().numpy()
        self.total_time_cost += int(
            (datetime.now() - start_time).total_seconds())
        return None

    def training_curves(self):
        """
        Plot out the training curves on `train_loss` and `adj_change`. Check
        out our paper for how to use `adj_change` to identify the convergence
        point.
        """
        log_df = self.logger.to_df()
        if 'train_loss' in log_df:
            _, axes = plt.subplots(1, 2, figsize=(8, 3))
            axes[0].plot(log_df['train_loss'])
            axes[0].set_xlabel('Steps')
            axes[0].set_ylabel('Training Loss')
            # The first adj_change is skipped when plotting.
            axes[1].plot(log_df['adj_change'][1:])
            axes[1].set_xlabel('Steps')
            axes[1].set_ylabel('Amount of Change in Adj. Matrix')
            plt.show()
        else:
            print(
                'Training log and Adj Change are not available. Train your ',
                'model using the .train() method.')

    def get_grn(self, gene_names, tf_names=None, top_gene_percentile=None):
        """
        Obtain a GRN object. You need to provide the genes names.

        Args:
            gene_names (np.ndarray): An array of names of all genes. The order
                of genes should be the same as the order used in your expression
                table.
            tf_names (np.ndarray): An array of names of all transcriptional
                factors. The order of genes should be the same as the order
                used in your expression table.
            top_gene_percentile (int): If provided, we will set the value on
                weak links to be zero. It is useful if you want to save the
                regulatory relationship in a GRN object as a sparse matrix.
        """
        adj = self.model.get_adj()
        return GRN(adj, gene_names, tf_names, top_gene_percentile)

    def get_adj(self):
        """
        Obtain the adjacency matrix. The values in this adjacency matrix have
        been scaled using regulatory norm. You may expect strong links to go
        beyond 5 or 10 in most cases.
        """
        return self.model.get_adj()