PyTorch Implementation of SAUCIE algorithm

In [1]:
%matplotlib inline
import os
import glob
import numpy as np
import torch
import pandas as pd
import loompy
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.colors
import seaborn as sns
sns.set(style='whitegrid')
import warnings
warnings.filterwarnings('ignore')
from sklearn.datasets import make_moons, make_s_curve
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
import umap

print('numpy:', np.__version__)
print('torch:', torch.__version__)
print('pandas:', pd.__version__)
print('loompy:', loompy.__version__)
print('matplotlib:', matplotlib.__version__)
print('seaborn:', sns.__version__)

import saucie_pytorch
numpy: 1.16.2
torch: 1.3.0
pandas: 0.24.1
loompy: 3.0.6
matplotlib: 3.0.3
seaborn: 0.9.0

トイデータで実験

トイデータを適当に作って実験する。
3つのバッチ的なデータセットを用意。それぞれのバッチは、丸型、線状、S字型、三日月型の4つのクラスターから構成される。
バッチ効果的な現象を模して、平行移動させたり縮尺を変えたりちょっとだけ回転させたりする。

In [2]:
#test data
# Generate random sample, two components
n_samples = 200
np.random.seed(0)

# Batch X1
X1_1 = 0.5 * np.random.randn(n_samples, 2) + np.array([-1, 1])
X1_2 = 0.4 * (np.dot(make_s_curve(n_samples, noise=0.1)[0][:, [0,2]], np.array([[1, -.5], [1, 1.5]])) + np.array([7, 0]))
x, t = make_moons(n_samples*2, noise=0.2)
X1_3 = 1. * (x[t == 0] + np.array([-4, -1]))
X1_4 = np.dot(np.random.randn(n_samples, 2), np.array([[.5, -.8], [.2, .5]])) + np.array([0, -2])
X1 = np.vstack((X1_1, X1_2, X1_3, X1_4))

# Batch X2
angle = -2.*np.pi/30.
C = np.array([[np.cos(angle), np.sin(angle)],
                     [-np.sin(angle), np.cos(angle)]])
X2 = X1 * np.array([0.5, 1.5])
X2 = np.dot(X2, C) + np.array([8, 4])

# Batch X3
angle = -2.*np.pi * 1. / 18.
C = np.array([[np.cos(angle), np.sin(angle)],
                     [-np.sin(angle), np.cos(angle)]])
X3 = X1 * np.array([1.5, 0.5])
X3 = np.dot(X3, C) + np.array([-8, 5])

mpl.rcParams['figure.dpi'] = 200
fig = plt.figure(figsize=(4,3))
ax = fig.add_subplot(111)
for i, c in enumerate(['r', 'g', 'b', 'c']):
    start = i * n_samples
    end = (i+1) * n_samples
    ax.scatter(X1[start:end, 0], X1[start:end, 1], c=c, s=6, alpha=.5, label=f'cluster-{i+1}')
    ax.scatter(X2[start:end, 0], X2[start:end, 1], c=c, s=6, alpha=.5)
    ax.scatter(X3[start:end, 0], X3[start:end, 1], c=c, s=6, alpha=.5)

coord_m = np.median(X1, axis=0)
ax.annotate(str('X1'), xy=(coord_m), fontsize=13, bbox={"facecolor":'white', "edgecolor":"k", "alpha":0.3})
coord_m = np.median(X2, axis=0)
ax.annotate(str('X2'), xy=(coord_m), fontsize=13, bbox={"facecolor":'white', "edgecolor":"k", "alpha":0.3})
coord_m = np.median(X3, axis=0)
ax.annotate(str('X3'), xy=(coord_m), fontsize=13, bbox={"facecolor":'white', "edgecolor":"k", "alpha":0.3})

ax.set_title('Toy dataset. 3 batches')
ax.legend(fontsize=6)
plt.show()

それぞれPandasのデータフレームを作って保存。

In [3]:
df_X1 = pd.DataFrame(X1, index=[f'Cell{i}' for i in range(4*n_samples)], columns=['Gene1', 'Gene2'])
df_X1.to_csv('./data/X1.csv')
df_X2 = pd.DataFrame(X2, index=[f'Cell{i}' for i in range(4*n_samples)], columns=['Gene1', 'Gene2'])
df_X2.to_csv('./data/X2.csv')
df_X3 = pd.DataFrame(X3, index=[f'Cell{i}' for i in range(4*n_samples)], columns=['Gene1', 'Gene2'])
df_X3.to_csv('./data/X3.csv')

display(df_X1.head())
Gene1 Gene2
Cell0 -0.117974 1.200079
Cell1 -0.510631 2.120447
Cell2 -0.066221 0.511361
Cell3 -0.524956 0.924321
Cell4 -1.051609 1.205299

SAUCIE実行。
学習させたい複数のバッチをそれぞれpandas.DataFrameとしてロードして、それらのリストをsaucie_pytorch.trainに渡す。
まずはバッチ補正モード。

In [4]:
CONFIG = { 
    'latent_size': 2, # embedding dimension.
    'lambda_b': 0.5, # regularization rate of MMD loss.
    'lambda_c': 0.1, # reg. rate of Entropy loss.
    'lambda_d': 0.2, # reg. rate of Intra-cluster distance loss.
    'binmin': 10, # minimum number of cells to be clustered.
    'max_clusters': 100, # max number of clusters.
    'merge_k_nearest': 3, # number of nearest clusteres to search in merging process.
    'layers': [512, 256, 128], # number of nodes in each layer of encoder and decoder.
    'learning_rate': 1e-3, # learning rate of optimizer.
    'use_batchnorm': False, # use batch normalization layer.
    'minibatch_size': 256,
    'max_iterations': 1000, # max iteration steps
    'log_interval': 100, # interval of steps to display loss information.
    'use_gpu': False, 
    'train_dir': './tmp', # dir to save the state of the model
    'data_dir': './data', # input files dir
    'out_dir': './results', # dir to output result files
    'seed':13 # seed of random number generators
}
cfg = CONFIG

device = torch.device("cpu")
torch.manual_seed(cfg['seed'])
np.random.seed(cfg['seed'])

DFs = []
names = []
for csv_file in glob.glob(os.path.join(cfg['data_dir'], 'X*.csv')):
    names.append(os.path.basename(csv_file).split('.')[0])
    DFs.append(pd.read_csv(csv_file, index_col=0))

# train batch correction model
saucie_pytorch.train(DFs, mode='BatchCorrection', cfg=cfg, device=device)

# Batch correction by using trained model
results_embedded, results_reconstructed = saucie_pytorch.output_activations(DFs, mode='BatchCorrection', cfg=cfg)

# write files
for (df, name) in zip(results_embedded, names):
    df.to_csv(os.path.join(cfg['out_dir'], f'{name}_embedding.csv'))
for (df, name) in zip(results_reconstructed, names):
    df.to_csv(os.path.join(cfg['out_dir'], f'{name}_reconstructed.csv'))
Training the BatchCorrection model for 3 datasets
	step:	0	train loss: 1.813693
	step:	100	train loss: 0.342179
	step:	200	train loss: 0.147149
	step:	300	train loss: 0.096193
	step:	400	train loss: 0.070833
	step:	500	train loss: 0.063686
	step:	600	train loss: 0.062334
	step:	700	train loss: 0.061811
	step:	800	train loss: 0.057311
	step:	900	train loss: 0.056425
	step:	1000	train loss: 0.054029
Reconstructing data by using trained model...

次にクラスタリングモードを実行。
パラメータはデータセットに合わせて適宜。

In [5]:
cfg['lambda_c'] = 0.7
cfg['lambda_d'] = 0.5
cfg['merge_k_nearest'] = 5

# train clustering model
saucie_pytorch.train(results_reconstructed, mode='Clustering', cfg=cfg, device=device)

# get clusters and output
results_clusters = saucie_pytorch.output_activations(results_reconstructed, mode='Clustering', cfg=cfg)
for (clusters, name) in zip(results_clusters, names):
    np.savetxt(os.path.join(cfg['out_dir'], f'{name}_clusters.txt'), clusters.astype(int), fmt='%d')
Training the Clustering model for 3 datasets
	step:	0	train loss: 13.944221
	step:	100	train loss: 3.806284
	step:	200	train loss: 3.002613
	step:	300	train loss: 2.600712
	step:	400	train loss: 2.392790
	step:	500	train loss: 2.315872
	step:	600	train loss: 2.198824
	step:	700	train loss: 2.077101
	step:	800	train loss: 1.971114
	step:	900	train loss: 1.866989
	step:	1000	train loss: 1.818996
Reconstructing data by using trained model...
9 clusters detected. Merging clusters...
	9 cells (0.38 % of total) are not clustered.
Merging done. Total 3 clusters.

バッチ補正の結果を見てみる。
バッチ補正モードで実行すると、*_embedding.csv というファイルが出力されている。
autoencoderの真ん中、embedding layerのactivationの値がそのまま記載されていて、これを埋め込みとして扱う。

In [6]:
PL_Colors = [matplotlib.colors.rgb2hex(x) for x in cm.tab10.colors]

# embedding results
cluster_colors = ['r']*n_samples +\
                         ['g']*n_samples +\
                         ['b']*n_samples +\
                         ['c']*n_samples

emb1 = pd.read_csv('./results/X1_embedding.csv', index_col=0)
emb2 = pd.read_csv('./results/X2_embedding.csv', index_col=0)
emb3 = pd.read_csv('./results/X3_embedding.csv', index_col=0)

fig = plt.figure(figsize=(4,3))
ax1 = fig.add_subplot(121)
ax1.scatter(emb1['Dim1'], emb1['Dim2'], c=PL_Colors[0], s=6, alpha=.3, label='X1')
ax1.scatter(emb2['Dim1'], emb2['Dim2'], c=PL_Colors[1], s=6, alpha=.3, label='X2')
ax1.scatter(emb3['Dim1'], emb3['Dim2'], c=PL_Colors[2], s=6, alpha=.3, label='X3')
ax1.legend(fontsize=6)
ax1.set_title('Embedding layer\n(batch effect removed?)', fontsize=6)
ax1.tick_params(labelsize=6)
ax2 = fig.add_subplot(122)
ax2.scatter(emb1['Dim1'], emb1['Dim2'], c=cluster_colors, s=6, alpha=.3)
ax2.scatter(emb2['Dim1'], emb2['Dim2'], c=cluster_colors, s=6, alpha=.3)
ax2.scatter(emb3['Dim1'], emb3['Dim2'], c=cluster_colors, s=6, alpha=.3)
ax2.set_title('Clusters in batches\nare aligned?', fontsize=6)
ax2.tick_params(labelsize=6)
plt.figure()
Out[6]:
<Figure size 1200x800 with 0 Axes>
<Figure size 1200x800 with 0 Axes>

低次元空間では、それぞれのバッチのクラスタがちゃんとアラインされている。

In [7]:
# reconstructing results
recon1 = pd.read_csv('./results/X1_reconstructed.csv', index_col=0)
recon2 = pd.read_csv('./results/X2_reconstructed.csv', index_col=0)
recon3 = pd.read_csv('./results/X3_reconstructed.csv', index_col=0)

fig = plt.figure(figsize=(12,4))
ax1 = fig.add_subplot(141)
ax1.scatter(recon1['Gene1'], recon1['Gene2'], c=PL_Colors[0], s=6, alpha=.3, label='X1')
ax1.scatter(recon2['Gene1'], recon2['Gene2'], c=PL_Colors[1], s=6, alpha=.3, label='X2')
ax1.scatter(recon3['Gene1'], recon3['Gene2'], c=PL_Colors[2], s=6, alpha=.3, label='X3')
ax1.legend(fontsize=6)
ax1.set_title('Reconstructing layer output', fontsize=8)
yb, yu = ax1.get_ylim()
xb, xu = ax1.get_xlim()
ax2 = fig.add_subplot(142)
ax2.scatter(recon1['Gene1'], recon1['Gene2'], c=cluster_colors, s=6, alpha=.3)
ax2.set_title('X1 (with true cluster labels)', fontsize=10)
ax2.set_xlim(xb, xu)
ax2.set_ylim(yb, yu)
ax3 = fig.add_subplot(143)
ax3.scatter(recon2['Gene1'], recon2['Gene2'], c=cluster_colors, s=6, alpha=.3)
ax3.set_title('X2 (with true cluster labels)', fontsize=10)
ax3.set_xlim(xb, xu)
ax3.set_ylim(yb, yu)
ax4 = fig.add_subplot(144)
ax4.scatter(recon3['Gene1'], recon3['Gene2'], c=cluster_colors, s=6, alpha=.3)
ax4.set_title('X3 (with true cluster labels)', fontsize=10)
ax4.set_xlim(xb, xu)
ax4.set_ylim(yb, yu)
plt.show()

また、近い座標にアラインされてはいるものの、それぞれのバッチ固有の構造(細長い・太いとか傾きとか)は維持されている。

In [8]:
CT_Colors = [matplotlib.colors.rgb2hex(x) for x in cm.tab20.colors]

# clusters
cluster1 = np.loadtxt('./results/X1_clusters.txt').astype(int)
cluster2 = np.loadtxt('./results/X2_clusters.txt').astype(int)
cluster3 = np.loadtxt('./results/X3_clusters.txt').astype(int)
print(np.unique(np.hstack((cluster1, cluster2, cluster3)), return_counts=True))
cluster_IDs = np.unique(np.hstack((cluster1, cluster2, cluster3)))
clID2color = {}
for i, cl_id in enumerate(cluster_IDs):
    clID2color[cl_id] = CT_Colors[i]
clID2color[-1] = '#0f0f0f'
x1_colors = [clID2color[x] for x in cluster1]
x2_colors = [clID2color[x] for x in cluster2]
x3_colors = [clID2color[x] for x in cluster3]

fig = plt.figure(figsize=(12,4))
ax1 = fig.add_subplot(141)
ax1.scatter(emb1['Dim1'], emb1['Dim2'], c=x1_colors, s=6, alpha=.3, label='X1')
ax1.scatter(emb2['Dim1'], emb2['Dim2'], c=x2_colors, s=6, alpha=.3, label='X2')
ax1.scatter(emb3['Dim1'], emb3['Dim2'], c=x3_colors, s=6, alpha=.3, label='X3')
ax1.set_title('Reconstructing layer output', fontsize=8)
ax2 = fig.add_subplot(142)
ax2.scatter(emb1['Dim1'], emb1['Dim2'], c=x1_colors, s=6, alpha=.3)
ax2.set_title('X1 (with predicted cluster labels)', fontsize=10)
ax2.set_xlim(ax1.get_xlim())
ax2.set_ylim(ax1.get_ylim())
ax3 = fig.add_subplot(143)
ax3.scatter(emb2['Dim1'], emb2['Dim2'], c=x2_colors, s=6, alpha=.3)
ax3.set_title('X2 (with predicted cluster labels)', fontsize=10)
ax3.set_xlim(ax1.get_xlim())
ax3.set_ylim(ax1.get_ylim())
ax4 = fig.add_subplot(144)
ax4.scatter(emb3['Dim1'], emb3['Dim2'], c=x3_colors, s=6, alpha=.3)
ax4.set_title('X3 (with predicted cluster labels)', fontsize=10)
ax4.set_xlim(ax1.get_xlim())
ax4.set_ylim(ax1.get_ylim())
plt.show()
(array([-1,  0,  1]), array([   9, 2261,  130]))

クラスタリングの結果は微妙。
あえてこのニューラルネットワークで同時にクラスタリングまで学習する必要はない気がする。
良い低次元表現が得られたらそこで別手法でクラスタリングしたほうがいいかも。

膵臓データで実験

Seurat panc8 dataset

データは SeuratData からとってきた。
"panc8" データセットをダウンロードして、"loom" file format に変換、保存した。

In [9]:
ds = loompy.connect('./data/panc8.loom')

以下のような感じで、smartseq, indropなど、様々なプラットフォームで解析された膵臓関連細胞の遺伝子発現テーブルとなっている。

In [10]:
ds
Out[10]:

34363 rows, 14890 columns, 2 layers
(showing up to 10x10)
./data/panc8.loom
name: ['2.0.1']
name: ['RNA']
name: ['(7072, 7072)']
name: ['20191015T084700.981081Z']
name: ['0.2.1.9000']

       CellIDD101_5D101_7D101_10D101_13D101_14D101_17D101_21D101_22D101_25D101_27...
       ClusterID1111111111...
       ClusterNamecelseqcelseqcelseqcelseqcelseqcelseqcelseqcelseqcelseqcelseq...
       assigned_clusterFALSEFALSEFALSEFALSEFALSEFALSEFALSEFALSEFALSEFALSE...
       celltypegammaacinaralphadeltabetaductalendothelialdeltaalphaductal...
       datasetcelseqcelseqcelseqcelseqcelseqcelseqcelseqcelseqcelseqcelseq...
       nCount_RNA4615.809629001.5626707.85648797.2245032.557613474.8669285.7518442.52110089.52615758.838...
       nFeature_RNA1986420924082964226439823575298531974678...
       orig_identD101D101D101D101D101D101D101D101D101D101...
       replicatecelseqcelseqcelseqcelseqcelseqcelseqcelseqcelseqcelseqcelseq...
       techcelseqcelseqcelseqcelseqcelseqcelseqcelseqcelseqcelseqcelseq...
GeneSelectedvst_meanvst_variablevst_variancevst_variance_expectedvst_variance_standardized           ...
A1BG-AS10.00.0791145501.44555373.3155250.43599543 0.00.00.00.00.00.00.00.00.00.0...
A1BG0.03.99280520370.688631000.52780.37049308 0.00.00.00.00.00.00.00.00.00.0...
A1CF0.09.49513503097.82623287.4090.94233066 0.03.0177166730.03.0177166730.00.00.00.00.00.0...
A2M-AS10.00.1400387107.48253257.51730870.9953739 0.00.00.00.00.00.00.00.00.00.0...
A2ML10.00.1146395601.40642365.6201580.25024626 0.00.00.00.00.00.00.00.00.00.0...
A2MP10.00.004097235400.0513644780.042176911.2178341 0.00.00.00.00.00.00.00.00.00.0...
A2M1.02.785155814073.886609.98424.258952 0.00.00.00.00.00.01.0019582260.00.00.0...
A4GALT0.00.3829184022.00370234.7942540.6323947 0.00.00.00.00.00.01.0019582260.00.00.0...
A4GNT0.00.002888634600.018329980.0238426460.76878965 0.00.00.00.00.00.00.00.00.00.0...
AAAS0.06.949800502327.6092121.66041.0970695 0.00.00.00.00.00.00.00.00.00.0...
.........................................................
In [11]:
np.unique(ds.ca['dataset'], return_counts=True)
Out[11]:
(array(['celseq', 'celseq2', 'fluidigmc1', 'indrop1', 'indrop2', 'indrop3',
        'indrop4', 'smartseq2'], dtype=object),
 array([1004, 2285,  638, 1937, 1724, 3605, 1303, 2394]))

データセットごとに分割、それぞれでノーマライズ。

In [12]:
Dataset = np.unique(ds.ca['dataset'])

DFs = {}
for d in Dataset:
    print('Dataset:', d)
    cell_ind = np.where(ds.ca['dataset'] == d)[0]
    df = pd.DataFrame(ds[:, cell_ind], index=ds.ra['Gene'], columns=ds.ca['CellID'][cell_ind])
    normalized = 10000 * df.values / df.values.sum(axis=0)
    lognormalized = np.log1p(normalized)
    df = pd.DataFrame(lognormalized, index=df.index, columns=df.columns)
    print('\t', df.shape)
    DFs[d] = df
Dataset: celseq
	 (34363, 1004)
Dataset: celseq2
	 (34363, 2285)
Dataset: fluidigmc1
	 (34363, 638)
Dataset: indrop1
	 (34363, 1937)
Dataset: indrop2
	 (34363, 1724)
Dataset: indrop3
	 (34363, 3605)
Dataset: indrop4
	 (34363, 1303)
Dataset: smartseq2
	 (34363, 2394)

外れ値的な細胞を雑にフィルタリングする。

In [13]:
fig = plt.figure(figsize=(12,12))
valid_cell_IDs = {}
filtered_DFs = {}
for i, d in enumerate(Dataset):
    df = DFs[d]
    tot_exprs = df.values.sum(axis=0)
    n_genes = (df.values > 0.0).astype(int).sum(axis=0)
    z = np.polyfit(np.log(n_genes), tot_exprs, 1)
    use_cells = np.logical_and((tot_exprs < (np.log(n_genes) * z[0] + z[1]) + 500), (tot_exprs > (np.log(n_genes) * z[0] + z[1]) - 500))
    ax = fig.add_subplot(3, 3, i+1)
    ax.scatter(np.log(n_genes)[~use_cells], tot_exprs[~use_cells], s=3, alpha=.3, c='gray')
    ax.scatter(np.log(n_genes)[use_cells], tot_exprs[use_cells], s=3, alpha=.3, c='blue')
    x_left, x_right = ax.get_xlim()
    x = np.linspace(x_left, x_right, 100)
    y = z[0] * x + z[1]
    ax.plot(x, y, c='red')    
    ax.set_title(d)
    print(f'Dataset {d}: total={len(df.columns)}, selected={use_cells.astype(int).sum()}')
    valid_cell_IDs[d] = list(df.columns[use_cells])
    filtered_DFs[d] = df.loc[:, df.columns[use_cells]]
plt.show()
Dataset celseq: total=1004, selected=938
Dataset celseq2: total=2285, selected=2174
Dataset fluidigmc1: total=638, selected=520
Dataset indrop1: total=1937, selected=1865
Dataset indrop2: total=1724, selected=1708
Dataset indrop3: total=3605, selected=3207
Dataset indrop4: total=1303, selected=1278
Dataset smartseq2: total=2394, selected=1730

細胞間で変動が激しい遺伝子をデータセットごとそれぞれで選んで統合。

In [14]:
n_highly_variable = 2000

HVGs = []
for d in Dataset:
    print('Dataset:', d)
    df = filtered_DFs[d]
    genes = df.index
    mean = np.expm1(df.values).mean(axis=1)
    mean[mean == 0] = 1e-23
    dispersion = np.expm1(df.values).var(axis=1, ddof=1) / mean
    dispersion[dispersion == 0] = np.nan
    HVGs += list(genes[np.argsort(dispersion)[::-1][:n_highly_variable]])
HVGs = list(set(HVGs))
print(len(HVGs))
Dataset: celseq
Dataset: celseq2
Dataset: fluidigmc1
Dataset: indrop1
Dataset: indrop2
Dataset: indrop3
Dataset: indrop4
Dataset: smartseq2
7622

可視化のためにいったんデータセットを統合する。

In [15]:
df = filtered_DFs[Dataset[0]].loc[HVGs, :]
for d in Dataset[1:]:
    df = df.join(filtered_DFs[d].loc[HVGs, :], how='outer')
df.shape
Out[15]:
(7622, 13420)

標準化、PCA、UMAP。

In [16]:
scaler = StandardScaler()
scaled_values = scaler.fit_transform(df.values.T)
scaled_values = np.clip(scaled_values, None, 10)
df_scaled = pd.DataFrame(scaled_values.T, index=df.index, columns=df.columns)

pca = PCA(n_components=200)
pca_coords = pca.fit_transform(df_scaled.values.T)
df_pc = pd.DataFrame(pca_coords, index=df.columns, columns=[f'PC{x+1}' for x in range(pca_coords.shape[1])])

umap_model = umap.UMAP(n_components=2, \
                       n_neighbors=15, min_dist=0.5, metric='cosine',  \
                       random_state=42, verbose=False)
umap_coords = umap_model.fit_transform(df_pc.values)
df_umap = pd.DataFrame(umap_coords, index=df.columns, columns=['x1', 'x2'])
In [17]:
mpl.rcParams['figure.dpi']= 200

fig = plt.figure(figsize=(12, 6))
ax1 = fig.add_subplot(1, 2, 1)
for i, d in enumerate(Dataset):
    cell_IDs = ds.ca['CellID'][ds.ca['dataset'] == d]
    ax1.scatter(df_umap.loc[cell_IDs, 'x1'], df_umap.loc[cell_IDs, 'x2'], \
            s=6, alpha=.5, c=PL_Colors[i], label=d)
ax1.set_title('Differences in datasets and replicates')
ax1.set_xlabel('umap-1'); ax1.set_ylabel('umap-2')
ax1.legend()
ax2 = fig.add_subplot(1, 2, 2)
for i, ct in enumerate(np.unique(ds.ca['celltype'])):
    cell_IDs = ds.ca['CellID'][ds.ca['celltype'] == ct]
    ax2.scatter(df_umap.loc[cell_IDs, 'x1'], df_umap.loc[cell_IDs, 'x2'], \
            s=6, alpha=.5, c=CT_Colors[i], label=ct)
ax2.set_title('Differences in cell types')
ax2.set_xlabel('umap-1'); ax2.set_ylabel('umap-2')
ax2.legend(fontsize=6)
plt.show()

実験プラットフォームによってだいぶ違う。

In [18]:
for dataset in np.unique(ds.ca['dataset']):
    print(dataset)
    cell_IDs = set(df.columns).intersection(set(ds.ca['CellID'][ds.ca['dataset'] == dataset]))
    
    df_dataset = df.transpose().loc[cell_IDs, :]
    print('\t', df_dataset.shape)
    df_dataset.to_csv(f'./data/panc8_{dataset}.csv')
celseq
	 (938, 7622)
celseq2
	 (2174, 7622)
fluidigmc1
	 (520, 7622)
indrop1
	 (1865, 7622)
indrop2
	 (1708, 7622)
indrop3
	 (3207, 7622)
indrop4
	 (1278, 7622)
smartseq2
	 (1730, 7622)

全部同時に補正するのは(可能だが)大変なので、とりあえず "smartseq2", "celseq2", "indrop4" の3バッチで補正してみる。

In [19]:
CONFIG = { 
    'latent_size': 2, # embedding dimension.
    'lambda_b': 0.1, # regularization rate of MMD loss.
    'lambda_c': 0.1, # reg. rate of Entropy loss.
    'lambda_d': 0.1, # reg. rate of Intra-cluster distance loss.
    'binmin': 10, # minimum number of cells to be clustered.
    'max_clusters': 100, # max number of clusters.
    'merge_k_nearest': 3, # number of nearest clusteres to search in merging process.
    'layers': [512, 256, 128], # number of nodes in each layer of encoder and decoder.
    'learning_rate': 1e-4, # learning rate of optimizer.
    'use_batchnorm': True, # use batch normalization layer.
    'minibatch_size': 256,
    'max_iterations': 3000, # max iteration steps
    'log_interval': 100, # interval of steps to display loss information.
    'use_gpu': False, 
    'train_dir': './tmp', # dir to save the state of the model
    'data_dir': './data', # input files dir
    'out_dir': './results', # dir to output result files
    'seed':13 # seed of random number generators
}

cfg = CONFIG

device = torch.device("cpu")
torch.manual_seed(cfg['seed'])
np.random.seed(cfg['seed'])

DFs = []
names = ['panc8_smartseq2', 'panc8_celseq2', 'panc8_indrop4']
DFs.append(pd.read_csv(f'./data/{names[0]}.csv', index_col=0))
DFs.append(pd.read_csv(f'./data/{names[1]}.csv', index_col=0))
DFs.append(pd.read_csv(f'./data/{names[2]}.csv', index_col=0))

# train batch correction model
saucie_pytorch.train(DFs, mode='BatchCorrection', cfg=cfg, device=device)

# Batch correction by using trained model
results_embedded, results_reconstructed = saucie_pytorch.output_activations(DFs, mode='BatchCorrection', cfg=cfg)

# write files
for (df, name) in zip(results_embedded, names):
    df.to_csv(os.path.join(cfg['out_dir'], f'{name}_embedding.csv'))
for (df, name) in zip(results_reconstructed, names):
    df.to_csv(os.path.join(cfg['out_dir'], f'{name}_reconstructed.csv'))
Training the BatchCorrection model for 3 datasets
	step:	0	train loss: 0.884945
	step:	100	train loss: 0.713481
	step:	200	train loss: 0.695819
	step:	300	train loss: 0.693376
	step:	400	train loss: 0.692774
	step:	500	train loss: 0.690956
	step:	600	train loss: 0.691674
	step:	700	train loss: 0.689970
	step:	800	train loss: 0.690027
	step:	900	train loss: 0.689275
	step:	1000	train loss: 0.688493
	step:	1100	train loss: 0.688804
	step:	1200	train loss: 0.687296
	step:	1300	train loss: 0.686949
	step:	1400	train loss: 0.687065
	step:	1500	train loss: 0.687520
	step:	1600	train loss: 0.686558
	step:	1700	train loss: 0.686354
	step:	1800	train loss: 0.685838
	step:	1900	train loss: 0.686098
	step:	2000	train loss: 0.686053
	step:	2100	train loss: 0.686449
	step:	2200	train loss: 0.686578
	step:	2300	train loss: 0.685014
	step:	2400	train loss: 0.685322
	step:	2500	train loss: 0.685045
	step:	2600	train loss: 0.684929
	step:	2700	train loss: 0.684702
	step:	2800	train loss: 0.685270
	step:	2900	train loss: 0.685913
	step:	3000	train loss: 0.684444
Reconstructing data by using trained model...
In [20]:
names = ['panc8_smartseq2', 'panc8_celseq2', 'panc8_indrop4']
emb1 = pd.read_csv(f'./results/{names[0]}_embedding.csv', index_col=0)
emb2 = pd.read_csv(f'./results/{names[1]}_embedding.csv', index_col=0)
emb3 = pd.read_csv(f'./results/{names[2]}_embedding.csv', index_col=0)
In [21]:
PL_Colors = [matplotlib.colors.rgb2hex(x) for x in cm.tab10.colors]
CT_Colors = [matplotlib.colors.rgb2hex(x) for x in cm.tab20.colors]
all_emb = pd.concat([emb1, emb2, emb3])

fig = plt.figure(figsize=(8,4))
ax1 = fig.add_subplot(1, 2, 1)
ax1.scatter(emb1['Dim1'], emb1['Dim2'], s=6, alpha=.3, c=PL_Colors[0], label=names[0])
ax1.scatter(emb2['Dim1'], emb2['Dim2'], s=6, alpha=.3, c=PL_Colors[1], label=names[1])
ax1.scatter(emb3['Dim1'], emb3['Dim2'], s=6, alpha=.3, c=PL_Colors[2], label=names[2])
ax1.legend()
ax1.tick_params(labelsize=6)

ax2 = fig.add_subplot(1, 2, 2)
for i, ct in enumerate(np.unique(ds.ca['celltype'])):
    cell_IDs = ds.ca['CellID'][ds.ca['celltype'] == ct]
    ax2.scatter(all_emb.loc[cell_IDs, 'Dim1'], all_emb.loc[cell_IDs, 'Dim2'], \
            s=6, alpha=.5, c=CT_Colors[i], label=ct)
ax2.legend(fontsize=6)
ax2.set_xlim(ax1.get_xlim())
ax2.set_ylim(ax1.get_ylim())
ax2.tick_params(labelsize=6)

plt.show()

fig = plt.figure(figsize=(8,4))
axb1 = fig.add_subplot(1, 3, 1)
for i, ct in enumerate(np.unique(ds.ca['celltype'])):
    cell_IDs = ds.ca['CellID'][ds.ca['celltype'] == ct]
    axb1.scatter(emb1.loc[cell_IDs, 'Dim1'], emb1.loc[cell_IDs, 'Dim2'], \
            s=6, alpha=.5, c=CT_Colors[i], label=ct)
axb1.tick_params(labelsize=6)
axb1.set_xlim(ax1.get_xlim())
axb1.set_ylim(ax1.get_ylim())
axb2 = fig.add_subplot(1, 3, 2)
for i, ct in enumerate(np.unique(ds.ca['celltype'])):
    cell_IDs = ds.ca['CellID'][ds.ca['celltype'] == ct]
    axb2.scatter(emb2.loc[cell_IDs, 'Dim1'], emb2.loc[cell_IDs, 'Dim2'], \
            s=6, alpha=.5, c=CT_Colors[i], label=ct)
axb2.tick_params(labelsize=6)
axb2.set_xlim(ax1.get_xlim())
axb2.set_ylim(ax1.get_ylim())
axb3 = fig.add_subplot(1, 3, 3)
for i, ct in enumerate(np.unique(ds.ca['celltype'])):
    cell_IDs = ds.ca['CellID'][ds.ca['celltype'] == ct]
    axb3.scatter(emb3.loc[cell_IDs, 'Dim1'], emb3.loc[cell_IDs, 'Dim2'], \
            s=6, alpha=.5, c=CT_Colors[i], label=ct)
axb3.tick_params(labelsize=6)
axb3.set_xlim(ax1.get_xlim())
axb3.set_ylim(ax1.get_ylim())
plt.show()

なんかあまりうまくいっているように見えない...
細胞型の配向がある程度揃えられては、いる。が、それ以上に全体がつぶれてしまっている。
いずれにしても、実データで使うときはパラメータの調整がめちゃめちゃ難しい。とくにMaximum Mean Discrepancyが関わる部分。
ちょっとでもパラメータ変わると、低次元表現がcollapseする(y=x的なラインになってしまったり、ひどい時は一点に収束してしまったり)。結果もあまり安定しない。
データセットに合わせて繊細な微調整が必要になるっぽい。実用としてはちょっと...

Imputation

dropoutの補完。オートエンコーダで再構成したデータを読み込むだけでオーケー。

In [22]:
original = pd.read_csv(f'./data/{names[0]}.csv', index_col=0)
reconstructed = pd.read_csv(f'./results/{names[0]}_reconstructed.csv', index_col=0)
In [23]:
original.columns[ np.argsort(original.values.mean(axis=0))[::-1][:50] ]
Out[23]:
Index(['TMEM66', 'PIGY+PYURF', 'CTSL', 'LINC01420', 'PDIA3P1', 'MRPL57',
       'NEURL1', 'KMT2A', 'TMEM263', 'SC5D', 'TAAR5', 'TMEM261', 'CNIH1',
       'OTUD6B-AS1', 'CCDC186', 'KLHL42', 'EBLN3', 'CCDC104', 'SMCO4',
       'FXYD2.1', 'MICU2', 'DMTN', 'LUZP6+MTPN', 'UNC13A', 'NCBP2-AS2',
       'LINC00998', 'ATG101', 'ZNF271', 'ZSCAN16-AS1', 'SMIM19', 'NSMF',
       'KMT2D', 'SPIDR', 'MS4A8', 'XIST', 'UQCC1', 'NOL4L', 'RPARP-AS1',
       'HEIH', 'CCAR2', 'RPS14P3', 'LOC102724316', 'LINC00948', 'CYB561A3',
       'TUNAR', 'AREL1', 'ZBED5-AS1', 'ZPR1', 'ZNF204P', 'USP24'],
      dtype='object')
In [30]:
gene1 = 'TMEM66'
gene2 = 'PIGY+PYURF'

fig = plt.figure(figsize=(12, 6))
ax1 = fig.add_subplot(121)
ax1.scatter(original[gene1], original[gene2], s=6, alpha=.3)
ax1.set_xlabel(gene1 + ' expressions')
ax1.set_ylabel(gene2 + ' expressions')
ax1.set_title('Original')
ax2 = fig.add_subplot(122)
ax2.scatter(reconstructed[gene1], reconstructed[gene2], s=6, alpha=.3)
ax2.set_title('Reconstructed')
ax2.set_xlabel(gene1 + ' expressions')
ax2.set_ylabel(gene2 + ' expressions')
plt.show()

できてるっぽいけど、妥当な結果なのかは知識がなくて判断できない。

In [ ]: