Source code for regdiffusion.models.regdiffusion

import torch
import numpy as np
from torch import nn
import torch.nn.functional as F
from torch.distributions import Normal
import math
from typing import List


torch.set_float32_matmul_precision('high')

class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim, theta=10000):
        super().__init__()
        self.dim = dim
        self.theta = theta

    def forward(self, time):
        # Standard transformer-style sinusoidal encoding: half of the
        # output dimensions hold sines and half hold cosines over a
        # geometric range of frequencies controlled by `theta`.
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(self.theta) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(
            half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings
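
# Illustrative sketch (not part of the original module): the embedding maps a
# batch of scalar timesteps to `dim`-dimensional vectors. The `dim=8` and the
# timestep values below are arbitrary choices for demonstration.
def _demo_sinusoidal_embeddings():
    emb = SinusoidalPositionEmbeddings(dim=8)
    t = torch.tensor([0.0, 1.0, 500.0])  # three timesteps
    out = emb(t)                         # (3, 8): 4 sine + 4 cosine channels
    assert out.shape == (3, 8)
    return out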

class Block(nn.Module):
    def __init__(self, in_dim, out_dim, time_dim, celltype_dim):
        super().__init__()
        self.time_mlp = nn.Linear(time_dim, out_dim)
        self.celltype_mlp = nn.Linear(celltype_dim, out_dim)
        self.l1 = nn.Linear(in_dim, out_dim)
        self.l2 = nn.Linear(out_dim, out_dim)
        self.do1 = nn.Dropout(0.1)
        self.act = nn.Tanh()
        
    def forward(self, x, t, ct):
        h = self.do1(self.act(self.l1(x)))
        # Broadcast the time and cell type embeddings across the gene axis
        # (dim 1) before adding them to the per-gene hidden features.
        time_emb = self.act(self.time_mlp(t)).unsqueeze(1)
        celltype_emb = self.act(self.celltype_mlp(ct)).unsqueeze(1)
        h = h + time_emb + celltype_emb
        h = self.act(self.l2(h))
        return h
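
# Illustrative sketch (not part of the original module): a Block consumes a
# per-gene feature tensor of shape (batch, n_gene, in_dim) and injects the
# time and cell type embeddings by broadcasting across the gene axis. The
# sizes below mirror the default hidden_dims=[16, 16, 16] configuration but
# are otherwise arbitrary.
def _demo_block():
    block = Block(in_dim=16, out_dim=15, time_dim=8, celltype_dim=4)
    x = torch.randn(2, 100, 16)   # (batch, n_gene, in_dim)
    t = torch.randn(2, 8)         # time embedding, (batch, time_dim)
    ct = torch.randn(2, 4)        # cell type embedding, (batch, celltype_dim)
    h = block(x, t, ct)           # (2, 100, 15)
    assert h.shape == (2, 100, 15)
    return h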
    
class GeneEmbeddings(nn.Module):
    def __init__(self, n_gene, gene_dim):
        super().__init__()
        # One of the gene_dim channels is reserved for the observed
        # expression value, which is concatenated in forward().
        gene_emb = torch.randn(n_gene, gene_dim - 1)
        self.gene_emb = nn.Parameter(gene_emb, requires_grad=True)

    def forward(self, x):
        n_cell = x.shape[0]
        batch_gene_emb = self.gene_emb.unsqueeze(0).repeat(n_cell, 1, 1)
        batch_gene_emb = torch.concat([x.unsqueeze(-1), batch_gene_emb], dim=-1)
        return batch_gene_emb
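
# Illustrative sketch (not part of the original module): GeneEmbeddings
# prepends each gene's observed expression value to that gene's learnable
# (gene_dim - 1)-dimensional embedding, giving (n_cell, n_gene, gene_dim).
# The sizes below are arbitrary.
def _demo_gene_embeddings():
    ge = GeneEmbeddings(n_gene=100, gene_dim=16)
    x = torch.randn(5, 100)  # (n_cell, n_gene) expression matrix
    out = ge(x)              # (5, 100, 16); out[..., 0] equals x
    assert out.shape == (5, 100, 16)
    return out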
    
class RegDiffusion(nn.Module):
    """
    A RegDiffusion model. For architecture details, please refer to our
    paper "From noise to knowledge: probabilistic diffusion-based neural
    inference".

    Args:
        n_gene (int): Number of genes.
        time_dim (int): Dimension of the time step embedding.
        n_celltype (int): Number of expected cell types. If it is not
            provided, there will be no cell type embedding. Default is None.
        celltype_dim (int): Dimension of the cell type embedding.
        hidden_dims (list[int]): List of integers for the dimensions of the
            hidden layers. The first hidden dimension is used as the size
            of the gene embedding.
        adj_dropout (float): A single number between 0 and 1 specifying the
            fraction of values in the adjacency matrix that are dropped
            during training.
        init_coef (int): Coefficient multiplied with the gene regulation
            norm (1/(n_gene - 1)) to initialize the adjacency matrix.
    """
    def __init__(
        self, n_gene, time_dim, n_celltype=None, celltype_dim=4,
        hidden_dims=[16, 16, 16], adj_dropout=0.3, init_coef=5
    ):
        super().__init__()
        self.n_gene = n_gene
        self.gene_dim = hidden_dims[0]
        self.adj_dropout = adj_dropout
        self.gene_reg_norm = 1 / (n_gene - 1)

        # Adjacency matrix starts at init_coef times the gene regulation
        # norm, above the soft-thresholding cutoff of gene_reg_norm / 2.
        adj_A = torch.ones(n_gene, n_gene) * self.gene_reg_norm * init_coef
        self.adj_A = nn.Parameter(adj_A, requires_grad=True)

        # Fixed random sample of 10,000 adjacency entries, used to monitor
        # adjacency dynamics without materializing the full matrix.
        self.sampled_adj_row_nonparam = nn.Parameter(
            torch.randint(0, n_gene, size=(10000,)), requires_grad=False)
        self.sampled_adj_col_nonparam = nn.Parameter(
            torch.randint(0, n_gene, size=(10000,)), requires_grad=False)

        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_dim),
            nn.Linear(time_dim, time_dim),
            nn.Tanh()
        )

        self.gene_emb = nn.Sequential(
            GeneEmbeddings(n_gene, self.gene_dim),
            nn.Tanh()
        )

        if n_celltype is not None:
            self.celltype_emb = nn.Embedding(n_celltype, celltype_dim)

        self.blocks = nn.ModuleList([
            Block(
                hidden_dims[i], hidden_dims[i+1]-1, time_dim, celltype_dim
            ) for i in range(len(hidden_dims) - 1)
        ])
        self.final = nn.Linear(hidden_dims[-1]-1, 1)

        self.zeros_nonparam = nn.Parameter(
            torch.zeros(n_gene, n_gene), requires_grad=False)
        self.eye_nonparam = nn.Parameter(
            torch.eye(n_gene), requires_grad=False)
        self.mask_nonparam = nn.Parameter(
            1 - torch.eye(n_gene), requires_grad=False)

    def soft_thresholding(self, x, tau):
        return torch.sign(x) * torch.max(
            self.zeros_nonparam, torch.abs(x) - tau)

    def I_minus_A(self):
        mask = self.mask_nonparam
        # Apply inverted dropout to the adjacency matrix only in
        # training mode.
        if self.training:
            A_dropout = (torch.rand_like(self.adj_A) > self.adj_dropout).float()
            A_dropout /= (1 - self.adj_dropout)
            mask = mask * A_dropout
        clean_A = self.soft_thresholding(
            self.adj_A, self.gene_reg_norm / 2) * mask
        return self.eye_nonparam - clean_A

    def get_adj_(self):
        return self.soft_thresholding(
            self.adj_A, self.gene_reg_norm / 2) * self.mask_nonparam

    def get_adj(self):
        adj = self.get_adj_().detach().cpu().numpy() / self.gene_reg_norm
        return adj.astype(np.float16)

    @torch.no_grad()
    def get_sampled_adj_(self):
        return self.get_adj_()[
            self.sampled_adj_row_nonparam,
            self.sampled_adj_col_nonparam
        ].detach()

    def get_gene_emb(self):
        return self.gene_emb[0].gene_emb.data.cpu().detach().numpy()
    def forward(self, x, t, ct):
        h_time = self.time_mlp(t)
        h_celltype = self.celltype_emb(ct)
        original = x.unsqueeze(-1)
        h_x = self.gene_emb(x)
        for i, block in enumerate(self.blocks):
            if i != 0:
                h_x = torch.concat([h_x, original], dim=-1)
            h_x = block(h_x, h_time, h_celltype)
        I_minus_A = self.I_minus_A()
        hz = torch.einsum('ogd,gh->ohd', h_x, I_minus_A)
        z = self.final(hz)
        return z.squeeze(-1)
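
# Illustrative end-to-end sketch (not part of the original module): the
# numbers below (50 genes, 3 cell types, batch of 4, timestep range 0-1000)
# are arbitrary assumptions for demonstration, not package defaults.
def _demo_regdiffusion():
    model = RegDiffusion(
        n_gene=50, time_dim=8, n_celltype=3,
        celltype_dim=4, hidden_dims=[16, 16, 16]
    )
    x = torch.randn(4, 50)                    # noisy expression, (batch, n_gene)
    t = torch.randint(0, 1000, (4,)).float()  # diffusion timesteps
    ct = torch.randint(0, 3, (4,))            # cell type indices
    out = model(x, t, ct)                     # per-gene denoising output, (4, 50)
    adj = model.get_adj()                     # (50, 50) float16 adjacency matrix
    return out, adj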