Прогнозы по ДНК для новичков на PyTorch

от автора


Видимо, астрологи Хабра объявили день Python в биологии. Сегодня мы приготовили для вас материал, в котором аспирант Эрин Уилсон делится кодом умозрительного примера прогнозирования количественной оценки последовательностей ДНК. За подробностями приглашаем под кат — к старту нашего флагманского курса по Data Science.

ДНК — поток комплексных данных, который можно представить в виде нуклеотидной последовательности ACTG. Тем не менее многие сложные паттерны и структурные нюансы трудно понять человеку, который смотрит на необработанную последовательность нуклеотидов.

В последние годы в моделировании ДНК достигли заметного прогресса. Этот прогресс стал возможным благодаря технологиям глубокого обучения.

Для моделирования ДНК исследователи применяли различные методы. В их числе свёрточные нейросети (CNN), сети с долговременной и кратковременной памятью (LSTM) и даже трансформеры для прогноза различных геномных измерений непосредственно по последовательностям ДНК. Самое полезное в этих моделях — автоматический поиск паттернов и мотивов последовательностей ДНК при достаточном количестве качественных обучающих данных. При этом исследователь не должен заранее указывать искомые паттерны. Нейросеть может их прогнозировать. Полезность глубокого обучения в геномике прибавляет энтузиазма всем, кому нужно сопоставлять последовательности ДНК с биологическими функциями!

Как аспирант, увлечённая математическими методами решения задач рационального природопользования и синтетической биологии, я с интересом изучила возможности PyTorch для исследования паттернов последовательностей ДНК.

Обучающих материалов по PyTorch немало, но многие из них посвящены работе с графикой и языковыми данными. ДНК использовалась в качестве входных данных в целом ряде интересных проектов, которые разработали в PyTorch фреймворки для моделирования различных биологических явлений 1, 2, 3, но они могут быть довольно трудными для новичка.

Мне было нелегко найти примеры в PyTorch для новичков, которые также были бы связаны с ДНК, поэтому я составила это краткое руководство на случай, если будущие архитекторы моделей ДНК найдут его полезным для начала работы.

Само руководство можно запустить в интерактивном режиме блокнота Jupyter. Проследить его план можно по краткому описанию ключевых понятий и Github gists в остальной части этой статьи.

Баннер
Научим решать интересные задачи, чтобы вы прокачали карьеру или стали востребованным IT-специалистом:

Построим модель для прогноза количественных показателей по последовательности ДНК в PyTorch

В этом руководстве рассказывается, как PyTorch использует необработанные последовательности ДНК в качестве входных данных и вводит их в модель нейросети, чтобы прогнозировать количественный балл по самой последовательности.

План руководства

  1. Синтез данных ДНК.
  2. Подготовка данных для обучения в PyTorch.
  3. Определение модели в PyTorch.
  4. Определение функций обучающего цикла.
  5. Запуск моделей.
  6. Проверка прогнозов модели на тестовом наборе данных.
  7. Визуализация свёрточных фильтров.
  8. Итоги.

Полагаю, что читатель знаком со следующими аспектами машинного обучения:

  • Нейросеть и свёрточная нейросеть (CNN).
  • Эпохи обучения моделей.
  • Разделение данных на наборы train/val/test.
  • Функции потерь и сравнение кривых потерь train и val.

Также полагаю, что читатель знаком с некоторыми понятиями биологии. Это:

  • Нуклеотиды в ДНК.
  • Регуляторный мотив.
  • Визуализация мотивов ДНК.

Я не утверждаю, что эти методы оптимальны. Я всё ещё учусь, и это лишь моё представление о решении задачи. Уверена, что есть решения элегантнее.

1. Синтез данных ДНК

Обычно учёные хотят прогнозировать балл связывания, силу экспрессии генов, классифицировать события связывания фактора транскрипции и тому подобное. У нас всё проще. В этом руководстве мы пытаемся научить модель обнаруживать в последовательности ДНК очень маленький и простой паттерн и присваивать ему соответствующий балл. Для этого мы используем глубокое обучение. Но практическая задача даже не в этом. Хочется просто убедиться, что правильно настроили элементы PyTorch и модель может учиться на входных данных, похожих на последовательность ДНК).

Допустим, у нас есть октамерная последовательность ДНК. Начислим ей определённое число баллов за каждый нуклеотид:

  • A (аденин) = +20 баллов.
  • C (цитозин) = +17 баллов.
  • G (гуанин) = +14 баллов.
  • T (тимин) = +11 баллов.

Для каждого октамера суммируем балл и усредняем его. Например:

AAAAAAAA даёт нам 20,0

mean(20 + 20 + 20 + 20 + 20 + 20 + 20 + 20) = 20,0

ACAAAAAA даёт 19,625

mean(20 + 17 + 20 + 20 + 20 + 20 + 20 + 20) = 19,625

Баллы выбраны для нуклеотидов произвольно — никакой настоящей биологии в этом нет! Мы просто назначаем последовательностям баллы, чтобы попрактиковаться в работе с PyTorch.

Во многих работах свёрточные нейросети (CNN) определяют «мотивы» или короткие паттерны ДНК. Они могут активировать или подавлять биологические реакции. Поэтому добавим в систему начисления баллов ещё один элемент: иногда нам нужно моделировать мотивы, влияющие на экспрессию генов. Допустим, последовательность получает +10 баллов при наличии мотива TAT в любой части октамера или -10 баллов при наличии мотива GCG. Повторюсь, к реальности это никакого отношения не имеет. Я просто показала вам на примере, как работает простой механизм моделирования активации и подавления.


Простая система подсчёта баллов для октамерных последовательностей ДНК

Вот реализация этой простой системы:

def kmers(k):     '''Generate a list of all k-mers for a given k'''          return [''.join(x) for x in product(['A','C','G','T'], repeat=k)]      seqs8 = kmers(8) print('Total 8mers:',len(seqs8)) # prints: Total 8mers: 65536  score_dict = {     'A':20,     'C':17,     'G':14,     'T':11 }                    def score_seqs_motif(seqs):     '''     Calculate the scores for a list of sequences based on      the above score_dict     '''     data = []     for seq in seqs:         # get the average score by nucleotide         score = np.mean([score_dict[base] for base in seq])                  # give a + or - bump if this k-mer has a specific motif         if 'TAT' in seq:             score += 10         if 'GCG' in seq:             score -= 10         data.append([seq,score])              df = pd.DataFrame(data, columns=['seq','score'])     return df      mer8 = score_seqs_motif(seqs8) 

Строим распределение баллов для октамерных последовательностей и видим, что они разделились на 3 группы:

  • последовательности с мотивом GCG (~5 баллов);
  • последовательности без мотивов (~15 баллов);
  • последовательности с мотивом TAT (~25 баллов).


Распределение баллов для октамерных последовательностей

Теперь нам предстоит научить модель прогнозировать балл по последовательности ДНК.

2. Подготовка данных для обучения в PyTorch

Чтобы научить нейросеть прогнозировать балл, нужно дать ей входные данные в виде матрицы чисел. Поясню на примере. Пусть мы хотим разделить фотографии с кошками и без кошек на два набора. Сеть распознаёт изображение как матрицу значений пикселей и изучает паттерны, связанные с относительным расположением пикселей (например, паттерны, соответствующие кошачьим ушам или носу с усами).

Аналогичным образом в матрицу чисел нужно превратить последовательности ДНК (строки ACGT). Но как же нам сделать из нашей ДНК «кошку»?

Одна из распространённых стратегий — прямое кодирование [one-hot encoding] ДНК с одним активным состоянием. Для такого кодирования рассмотрим каждый нуклеотид как вектор длиной 4, где в трёх позициях стоят нули, а в одной — единица. Место единицы зависит от типа нуклеотида.


Благодаря прямому кодированию компьютер определит ДНК так же, как кошку на картинке!

def one_hot_encode(seq):     """     Given a DNA sequence, return its one-hot encoding     """     # Make sure seq has only allowed bases     allowed = set("ACTGN")     if not set(seq).issubset(allowed):         invalid = set(seq) - allowed         raise ValueError(f"Sequence contains chars not in allowed DNA alphabet (ACGTN): {invalid}")              # Dictionary returning one-hot encoding for each nucleotide      nuc_d = {'A':[1.0,0.0,0.0,0.0],              'C':[0.0,1.0,0.0,0.0],              'G':[0.0,0.0,1.0,0.0],              'T':[0.0,0.0,0.0,1.0],              'N':[0.0,0.0,0.0,0.0]}          # Create array from nucleotide sequence     vec=np.array([nuc_d[x] for x in seq])              return vec    # look at DNA seq of 8 As a8 = one_hot_encode("AAAAAAAA") print("AAAAAAAA:\n",a8)  # prints: # AAAAAAAA: # [[1. 0. 0. 0.] # [1. 0. 0. 0.] # [1. 0. 0. 0.] # [1. 0. 0. 0.] # [1. 0. 0. 0.] # [1. 0. 0. 0.] # [1. 0. 0. 0.] # [1. 0. 0. 0.]]  # look at DNA seq of random nucleotides s = one_hot_encode("AGGTACCT") print("AGGTACCT:\n",s) print("shape:",s.shape)  # prints: # AGGTACCT: # [[1. 0. 0. 0.] # [0. 0. 1. 0.] # [0. 0. 1. 0.] # [0. 0. 0. 1.] # [1. 0. 0. 0.] # [0. 1. 0. 0.] # [0. 1. 0. 0.] # [0. 0. 0. 1.]] # shape: (8, 4) 

При помощи прямого кодирования можно подготовить наборы train, val и test. Соответствующая функция quick_split случайным образом выбирает индексы в кадре данных pandas для разделения (функция для этого есть и в sklearn).

В реальных задачах, не связанных с синтезом, может потребоваться более интеллектуальная стратегия разделения. Всё зависит от вашей прогностической задачи. В научных статьях разделение train/test часто производится по хромосоме или по другим особенностям расположения генома.

def quick_split(df, split_frac=0.8):     '''     Given a df of samples, randomly split indices between     train and test at the desired fraction     '''     cols = df.columns # original columns, use to clean up reindexed cols     df = df.reset_index()      # shuffle indices     idxs = list(range(df.shape[0]))     random.shuffle(idxs)      # split shuffled index list by split_frac     split = int(len(idxs)*split_frac)     train_idxs = idxs[:split]     test_idxs = idxs[split:]          # split dfs and return     train_df = df[df.index.isin(train_idxs)]     test_df = df[df.index.isin(test_idxs)]              return train_df[cols], test_df[cols]           full_train_df, test_df = quick_split(mer8) train_df, val_df = quick_split(full_train_df)  print("Train:", train_df.shape) print("Val:", val_df.shape) print("Test:", test_df.shape)  # prints:  # Train: (41942, 2) # Val: (10486, 2) # Test: (13108, 2) 

При подготовке данных в PyTorch важным шагом является применение объектов DataLoader и Dataset. Чтобы разобраться в этом, мне пришлось изрядно погуглить, а чтобы найти решение, пришлось прочесать кипу документов и постов со StackOverflow!

Dataset заключает данные в объект, которому можно легко задать отформатированные примеры X и маркировку Y для обучаемой модели. DataLoader получает Dataset и ряд других деталей о том, как объединять данные в серии и упрощает прохождение итераций обучения.

from torch.utils.data import Dataset, DataLoader  ## Here is a custom defined Dataset object specialized for one-hot encoded DNA:  class SeqDatasetOHE(Dataset):     '''     Dataset for one-hot-encoded sequences     '''     def __init__(self,                  df,                  seq_col='seq',                  target_col='score'                 ):         # +--------------------+         # | Get the X examples |         # +--------------------+         # extract the DNA from the appropriate column in the df         self.seqs = list(df[seq_col].values)         self.seq_len = len(self.seqs[0])                  # one-hot encode sequences, then stack in a torch tensor         self.ohe_seqs = torch.stack([torch.tensor(one_hot_encode(x)) for x in self.seqs])              # +------------------+         # | Get the Y labels |         # +------------------+         self.labels = torch.tensor(list(df[target_col].values)).unsqueeze(1)              def __len__(self): return len(self.seqs)          def __getitem__(self,idx):         # Given an index, return a tuple of an X with it's associated Y         # This is called inside DataLoader         seq = self.ohe_seqs[idx]         label = self.labels[idx]                  return seq, label   ## Here is how I constructed DataLoaders from Datasets.  def build_dataloaders(train_df,                       test_df,                       seq_col='seq',                       target_col='score',                       batch_size=128,                       shuffle=True                      ):     '''     Given a train and test df with some batch construction     details, put them into custom SeqDatasetOHE() objects.      Give the Datasets to the DataLoaders and return.     '''          # create Datasets         train_ds = SeqDatasetOHE(train_df,seq_col=seq_col,target_col=target_col)     test_ds = SeqDatasetOHE(test_df,seq_col=seq_col,target_col=target_col)      # Put DataSets into DataLoaders     train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=shuffle)     test_dl = DataLoader(test_ds, batch_size=batch_size)          return train_dl,test_dl           train_dl, val_dl = build_dataloaders(train_df, val_df) 

Теперь наши объекты DataLoader можно использовать в обучающем цикле!

3. Определение модели PyTorch

Сначала мне хотелось попробовать метод свёрточных нейросетей, потому что в изучении данных генома он показал себя отлично. Для сравнения я взяла простую линейную модель. Вот несколько определений модели:

# very simple linear model class DNA_Linear(nn.Module):     def __init__(self, seq_len):         super().__init__()         self.seq_len = seq_len         # the 4 is for our one-hot encoded vector length 4!         self.lin = nn.Linear(4*seq_len, 1)      def forward(self, xb):         # reshape to flatten sequence dimension         xb = xb.view(xb.shape[0],self.seq_len*4)         # Linear wraps up the weights/bias dot product operations         out = self.lin(xb)         return out  # basic CNN model class DNA_CNN(nn.Module):     def __init__(self,                  seq_len,                  num_filters=32,                  kernel_size=3):         super().__init__()         self.seq_len = seq_len                  self.conv_net = nn.Sequential(             # 4 is for the 4 nucleotides             nn.Conv1d(4, num_filters, kernel_size=kernel_size),             nn.ReLU(inplace=True),             nn.Flatten(),             nn.Linear(num_filters*(seq_len-kernel_size+1), 1)         )       def forward(self, xb):         # permute to put channel in correct order         # (batch_size x 4channel x seq_len)         xb = xb.permute(0,2,1)                   #print(xb.shape)         out = self.conv_net(xb)         return out 

Это не оптимальные модели, но надо же с чего-то начинать (ещё раз, мы просто учимся применять PyTorch в контексте ДНК).

  • Линейная модель пытается прогнозировать балл по весу нуклеотидов на каждой позиции.
  • Модель свёрточных нейросетей использует 32 фильтра с длиной (kernel_size) 3 для сканирования всей октамерной последовательности и поиска в ней паттернов-тримеров.

4. Определение функций обучающего цикла

Теперь нужно определить цикл обучения. Скажу честно, я не слишком-то уверена в этом решении. Я долго продиралась через ошибки несоответствия размерности матрицы. Наверное, есть красивые подходы! Но, может быть, и так сойдёт? Напишите мне пару слов, если у вас есть свои мысли на этот счёт.

Так или иначе, я пришла к такому определению стека функций:

# adds default optimizer and loss function run_model()     # loops through epochs     fit()         # loop through batches         train_step()             # calc train loss for batch             loss_batch()         val_step()             # calc val loss for batch             loss_batch() 

# +--------------------------------+ # | Training and fitting functions | # +--------------------------------+  def loss_batch(model, loss_func, xb, yb, opt=None,verbose=False):     '''     Apply loss function to a batch of inputs. If no optimizer     is provided, skip the back prop step.     '''     if verbose:         print('loss batch ****')         print("xb shape:",xb.shape)         print("yb shape:",yb.shape)         print("yb shape:",yb.squeeze(1).shape)         #print("yb",yb)      # get the batch output from the model given your input batch      # ** This is the model's prediction for the y labels! **     xb_out = model(xb.float())          if verbose:         print("model out pre loss", xb_out.shape)         #print('xb_out', xb_out)         print("xb_out:",xb_out.shape)         print("yb:",yb.shape)         print("yb.long:",yb.long().shape)          loss = loss_func(xb_out, yb.float()) # for MSE/regression     # __FOOTNOTE 2__          if opt is not None: # if opt         loss.backward()         opt.step()         opt.zero_grad()      return loss.item(), len(xb)      def train_step(model, train_dl, loss_func, device, opt):     '''     Execute 1 set of batched training within an epoch     '''     # Set model to Training mode     model.train()     tl = [] # train losses     ns = [] # batch sizes, n          # loop through train DataLoader     for xb, yb in train_dl:         # put on GPU         xb, yb = xb.to(device),yb.to(device)                  # provide opt so backprop happens         t, n = loss_batch(model, loss_func, xb, yb, opt=opt)                  # collect train loss and batch sizes         tl.append(t)         ns.append(n)          # average the losses over all batches         train_loss = np.sum(np.multiply(tl, ns)) / np.sum(ns)          return train_loss  def val_step(model, val_dl, loss_func, device):     '''     Execute 1 set of batched validation within an epoch     '''     # Set model to Evaluation mode     model.eval()     with torch.no_grad():         vl = [] # val losses         ns = [] # batch sizes, n                  # loop through validation DataLoader         for xb, yb in val_dl:             # put on GPU             xb, yb = xb.to(device),yb.to(device)              # Do NOT provide opt here, so backprop does not happen             v, n = loss_batch(model, loss_func, xb, yb)              # collect val loss and batch sizes             vl.append(v)             ns.append(n)      # average the losses over all batches     val_loss = np.sum(np.multiply(vl, ns)) / np.sum(ns)          return val_loss      def fit(epochs, model, loss_func, opt, train_dl, val_dl,device,patience=1000):     '''     Fit the model params to the training data, eval on unseen data.     Loop for a number of epochs and keep train of train and val losses      along the way     '''     # keep track of losses     train_losses = []         val_losses = []          # loop through epochs     for epoch in range(epochs):         # take a training step         train_loss = train_step(model,train_dl,loss_func,device,opt)         train_losses.append(train_loss)          # take a validation step         val_loss = val_step(model,val_dl,loss_func,device)         val_losses.append(val_loss)                  print(f"E{epoch} | train loss: {train_loss:.3f} | val loss: {val_loss:.3f}")      return train_losses, val_losses   def run_model(train_dl,val_dl,model,device,               lr=0.01, epochs=50,                lossf=None,opt=None              ):     '''     Given train and val DataLoaders and a NN model, fit the mode to the training     data. By default, use MSE loss and an SGD optimizer     '''     # define optimizer     if opt:         optimizer = opt     else: # if no opt provided, just use SGD         optimizer = torch.optim.SGD(model.parameters(), lr=lr)          # define loss function     if lossf:         loss_func = lossf     else: # if no loss function provided, just use MSE         loss_func = torch.nn.MSELoss()          # run the training loop     train_losses, val_losses = fit(                                 epochs,                                  model,                                  loss_func,                                  optimizer,                                  train_dl,                                  val_dl,                                  device)      return train_losses, val_losses 

5. Запуск моделей

Для начала испытаем на наших октамерных последовательностях линейный подход.

# use GPU if available DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # get the sequence length from the first seq in the df seq_len = len(train_df['seq'].values[0])  # create Linear model object model_lin = DNA_Linear(seq_len) model_lin.to(DEVICE) # put on GPU  # run the model with default settings! lin_train_losses, lin_val_losses = run_model(     train_dl,      val_dl,      model_lin,     DEVICE ) 

Соберём данные о потерях train и val и набросаем график:

def quick_loss_plot(data_label_list,loss_type="MSE Loss",sparse_n=0):     '''     For each train/test loss trajectory, plot loss by epoch     '''     for i,(train_data,test_data,label) in enumerate(data_label_list):             plt.plot(train_data,linestyle='--',color=f"C{i}", label=f"{label} Train")         plt.plot(test_data,color=f"C{i}", label=f"{label} Val",linewidth=3.0)      plt.legend()     plt.ylabel(loss_type)     plt.xlabel("Epoch")     plt.legend(bbox_to_anchor=(1,1),loc='upper left')     plt.show()       lin_data_label = (lin_train_losses,lin_val_losses,"Lin") quick_loss_plot([lin_data_label]) 


Обучение линейной модели и оценка кривых потерь

На первый взгляд, такой подход привнёс в обучение немного.

Теперь попробуем метод свёрточных нейросетей и построим график для него.

seq_len = len(train_df['seq'].values[0])  # create Linear model object model_cnn = DNA_CNN(seq_len) model_cnn.to(DEVICE) # put on GPU  # run the model with default settings! cnn_train_losses, cnn_val_losses = run_model(     train_dl,      val_dl,      model_cnn,     DEVICE )  cnn_data_label = (cnn_train_losses,cnn_val_losses,"CNN") quick_loss_plot([lin_data_label,cnn_data_label]) 


Кривые потерь для линейного метода и метода свёрточных нейросетей (CNN)

Очевидно, что CNN может уловить паттерн, а линейный метод не может. Прогоним несколько последовательностей и посмотрим, что получится.

# oracle dict of true score for each seq oracle = dict(mer8[['seq','score']].values)  def quick_seq_pred(model, desc, seqs, oracle):     '''     Given a model and some sequences, get the model's predictions     for those sequences and compare to the oracle (true) output     '''     print(f"__{desc}__")     for dna in seqs:         s = torch.tensor(one_hot_encode(dna)).unsqueeze(0).to(DEVICE)         pred = model(s.float())         actual = oracle[dna]         diff = pred.item() - actual         print(f"{dna}: pred:{pred.item():.3f} actual:{actual:.3f} ({diff:.3f})")  def quick_8mer_pred(model, oracle):     seqs1 = ("poly-X seqs",['AAAAAAAA', 'CCCCCCCC','GGGGGGGG','TTTTTTTT'])     seqs2 = ("other seqs", ['AACCAACA','CCGGTGAG','GGGTAAGG', 'TTTCGTTT'])     seqsTAT = ("with TAT motif", ['TATAAAAA','CCTATCCC','GTATGGGG','TTTATTTT'])     seqsGCG = ("with GCG motif", ['AAGCGAAA','CGCGCCCC','GGGCGGGG','TTGCGTTT'])     TATGCG =  ("both TAT and GCG",['ATATGCGA','TGCGTATT'])      for desc,seqs in [seqs1, seqs2, seqsTAT, seqsGCG, TATGCG]:         quick_seq_pred(model, desc, seqs, oracle)         print()          # Ask the trained Linear model to make  # predictions for some 8-mers quick_8mer_pred(model_lin, oracle)  # prints: # __poly-X seqs__ # AAAAAAAA: pred:23.230 actual:20.000 (3.230) # CCCCCCCC: pred:13.582 actual:17.000 (-3.418) # GGGGGGGG: pred:7.006 actual:14.000 (-6.994) # TTTTTTTT: pred:17.767 actual:11.000 (6.767)  # __other seqs__ # AACCAACA: pred:18.818 actual:18.875 (-0.057) # CCGGTGAG: pred:12.205 actual:15.125 (-2.920) # GGGTAAGG: pred:13.826 actual:15.125 (-1.299) # TTTCGTTT: pred:14.815 actual:12.125 (2.690)  # __with TAT motif__ # TATAAAAA: pred:22.146 actual:27.750 (-5.604) # CCTATCCC: pred:16.931 actual:25.875 (-8.944) # GTATGGGG: pred:12.141 actual:24.000 (-11.859) # TTTATTTT: pred:18.266 actual:22.125 (-3.859)  # __with GCG motif__ # AAGCGAAA: pred:16.736 actual:8.125 (8.611) # CGCGCCCC: pred:12.346 actual:6.250 (6.096) # GGGCGGGG: pred:7.907 actual:4.375 (3.532) # TTGCGTTT: pred:12.839 actual:2.500 (10.339)  # __both TAT and GCG__ # ATATGCGA: pred:15.664 actual:15.875 (-0.211) # TGCGTATT: pred:14.771 actual:13.625 (1.146) 

Из приведённых выше примеров видно, что линейная модель занижает прогноз для последовательности с большим количеством G и завышает его для последовательности с большим количеством T. По-видимому, она просто не замечает, что последовательности с GCG имеют необычно низкие баллы, а последовательности с TAT — необычно высокие. Но, поскольку в линейной модели невозможно учесть различия между GCG и GAG, она просто понижает балл для последовательности с G. Из нашей схемы подсчёта баллов известно, что это не так. Дело тут не в том, что G вредны или, если хотите, субвитальны, а в том, что вредны именно мотивы GCG.

# Ask the trained CNN model to make  # predictions for some 8-mers quick_8mer_pred(model_cnn, oracle)  # prints:  # __poly-X seqs__ # AAAAAAAA: pred:19.722 actual:20.000 (-0.278) # CCCCCCCC: pred:16.620 actual:17.000 (-0.380) # GGGGGGGG: pred:13.771 actual:14.000 (-0.229) # TTTTTTTT: pred:10.767 actual:11.000 (-0.233)  # __other seqs__ # AACCAACA: pred:18.530 actual:18.875 (-0.345) # CCGGTGAG: pred:14.925 actual:15.125 (-0.200) # GGGTAAGG: pred:14.900 actual:15.125 (-0.225) # TTTCGTTT: pred:11.789 actual:12.125 (-0.336)  # __with TAT motif__ # TATAAAAA: pred:26.154 actual:27.750 (-1.596) # CCTATCCC: pred:24.321 actual:25.875 (-1.554) # GTATGGGG: pred:22.870 actual:24.000 (-1.130) # TTTATTTT: pred:20.581 actual:22.125 (-1.544)  # __with GCG motif__ # AAGCGAAA: pred:8.689 actual:8.125 (0.564) # CGCGCCCC: pred:6.531 actual:6.250 (0.281) # GGGCGGGG: pred:5.029 actual:4.375 (0.654) # TTGCGTTT: pred:3.055 actual:2.500 (0.555)  # __both TAT and GCG__ # ATATGCGA: pred:15.178 actual:15.875 (-0.697) # TGCGTATT: pred:12.853 actual:13.625 (-0.772) 

CNN лучше улавливает разницу между тримерными мотивами! При этом он хорошо переваривает последовательности и с мотивами и без них.

6. Проверка прогнозов модели на тестовом наборе данных

В любой задаче машинного обучения при оценке результатов важно убедиться, что модель может делать хорошие прогнозы по тестовым наборам, с которыми она никогда не сталкивалась в ходе обучения. Разница между реальными тестовыми последовательностями и прогнозами модели лучше всего видна на сравнительных графиках.

import altair as alt from sklearn.metrics import r2_score  def parity_plot(model_name,df,r2):     '''     Given a dataframe of samples with their true and predicted values,     make a scatterplot.     '''     plt.scatter(df['truth'].values, df['pred'].values, alpha=0.2)          # y=x line     xpoints = ypoints = plt.xlim()     plt.plot(xpoints, ypoints, linestyle='--', color='k', lw=2, scalex=False, scaley=False)      plt.ylim(xpoints)     plt.ylabel("Predicted Score",fontsize=14)     plt.xlabel("Actual Score",fontsize=14)     plt.title(f"{model_name} (r2:{r2:.3f})",fontsize=20)     plt.show()      def alt_parity_plot(model,df, r2):     '''     Make an interactive parity plot with altair     '''     chart = alt.Chart(df).mark_circle(size=100,opacity=0.4).encode(         alt.X('truth:Q'),         alt.Y('pred:Q'),         tooltip=['seq:N']     ).properties(         title=f'{model} (r2:{r2:.3f})'     ).interactive()          chart.save(f'alt_out/parity_plot_{model}.html')     display(chart)  def parity_pred(models, seqs, oracle,alt=False):     '''Given some sequences, get the model's predictions '''     dfs = {} # key: model name, value: parity_df          for model_name,model in models:         print(f"Running {model_name}")         data = []         for dna in seqs:             s = torch.tensor(one_hot_encode(dna)).unsqueeze(0).to(DEVICE)             actual = oracle[dna]             pred = model(s.float())             data.append([dna,actual,pred.item()])         df = pd.DataFrame(data, columns=['seq','truth','pred'])         r2 = r2_score(df['truth'],df['pred'])         dfs[model_name] = (r2,df)                  #plot parity plot         if alt: # make an altair plot             alt_parity_plot(model_name, df, r2)         else:             parity_plot(model_name, df, r2)              # generate plots seqs = test_df['seq'].values models = [     ("Linear", model_lin),     ("CNN", model_cnn) ] parity_pred(models, seqs, oracle) 

Разница между реальными тестовыми последовательностями и прогнозами модели.

Сравнение графиков показывает, насколько точно модель прогнозирует отдельные последовательности. У идеальной модели все графики точно лягут на диагональ y=x. Это покажет, что модель точно прогнозирует последовательности. Отклонения от диагонали y=x означают, что прогнозы модели завышены или занижены.

Как мы видим, линейная модель может в какой-то мере прогнозировать линию тренда тестовых последовательностей, но обилие последовательностей в нижней и верхней областях распределения (при наличии мотивов) легко сбивают эту модель с толку.

CNN работает гораздо лучше и ожидаемо предсказывает значения, близкие к реальным, ведь архитектура нашего CNN использует ядра тримеров для поиска ключевых мотивов вдоль последовательности.

Впрочем, CNN — тоже не идеал. Можно было бы обучать его дольше или корректировать гиперпараметры, но цель здесь не в совершенстве. Мы сильно упростили задачу по сравнению с реальными регулятивными грамматиками. И я подумала, что для интерактивного выявления последовательностей, с которыми модели не справляются, было бы интересно использовать библиотеку визуализации Altair:

Заметьте, что вне диагонали наши последовательности, как правило, имеют множественные экземпляры мотивов! При подсчёте баллов мы просто давали последовательности ± поправки за один мотив. В действительности же разумно было бы несколько таких «бонусов» за несколько мотивов. В своём примере я просто даю «бонус» за 1 экземпляр мотива, но функцию расчёта баллов можно было бы составить иначе.

В любом случае здорово, когда модель замечает множественные экземпляры и прогнозирует их важность. Кажется, здесь я немного обманула её… Впрочем, R2 0,95 — это уже не так плохо.

7. Визуализация свёрточных фильтров

При обучении моделей CNN полезно визуализировать свёрточные фильтры первого слоя, чтобы лучше понять процесс обучения. При работе с изображениями свёрточные фильтры первого слоя часто изучают паттерны границ, цвета или текстуры — основные элементы изображения, при объединении которых мы получаем более сложные признаки.

В случае ДНК мы исходим из того, что свёрточные фильтры можно рассматривать как сканеры мотивов. По аналогии с матрицей весовых коэффициентов позиций для визуализации логотипов последовательностей, свёрточный фильтр похож на матрицу, показывающую определённый паттерн ДНК. Однако вместо того, чтобы быть точной последовательностью, он может содержать некоторую неопределённость относительно того, какие нуклеотиды показываются в той или иной части паттерна. Некоторые позиции могут быть строго определены (например, в позиции 2 всегда есть A; информативность высокая), в других же с почти равной вероятностью может быть множество нуклеотидов (энтропия высокая, информативность низкая).

Вычисления на скрытых слоях нейросетей могут быть очень сложными, и не каждый свёрточный фильтр будет представлять собой очевидный паттерн. Однако фильтры всё же выявляют информативные для объяснения прогноза паттерны.

Ниже приведены некоторые функции для визуализации свёрточных фильтров первого слоя как в виде необработанной тепловой карты, так и в виде логотипа.

import logomaker  def get_conv_layers_from_model(model):     '''     Given a trained model, extract its convolutional layers     '''     model_children = list(model.children())          # counter to keep count of the conv layers     model_weights = [] # we will save the conv layer weights in this list     conv_layers = [] # we will save the actual conv layers in this list     bias_weights = []     counter = 0       # append all the conv layers and their respective weights to the list     for i in range(len(model_children)):         # get model type of Conv1d         if type(model_children[i]) == nn.Conv1d:             counter += 1             model_weights.append(model_children[i].weight)             conv_layers.append(model_children[i])             bias_weights.append(model_children[i].bias)          # also check sequential objects' children for conv1d         elif type(model_children[i]) == nn.Sequential:             for child in model_children[i]:                 if type(child) == nn.Conv1d:                     counter += 1                     model_weights.append(child.weight)                     conv_layers.append(child)                     bias_weights.append(child.bias)      print(f"Total convolutional layers: {counter}")     return conv_layers, model_weights, bias_weights  def view_filters(model_weights, num_cols=8):     model_weights = model_weights[0]     num_filt = model_weights.shape[0]     filt_width = model_weights[0].shape[1]     num_rows = int(np.ceil(num_filt/num_cols))          # visualize the first conv layer filters     plt.figure(figsize=(20, 17))      for i, filter in enumerate(model_weights):         ax = plt.subplot(num_rows, num_cols, i+1)         ax.imshow(filter.cpu().detach(), cmap='gray')         ax.set_yticks(np.arange(4))         ax.set_yticklabels(['A', 'C', 'G','T'])         ax.set_xticks(np.arange(filt_width))         ax.set_title(f"Filter {i}")      plt.tight_layout()     plt.show()      conv_layers, model_weights, bias_weights = get_conv_layers_from_model(model_cnn) view_filters(model_weights) 

Возможно, это полезная информация, но обычно люди любят визуализировать последовательности с некоторой неопределённостью в виде логотипов мотивов, откладывая позиции мотива по X, а вероятности появления нуклеотидов в каждой из них — по Y. Часто для удобства визуализации эти вероятности преобразуются в биты (информацию).

Для преобразования необработанных свёрточных фильтров в визуальные матрицы весовых коэффициентов позиций обычно собирают активации фильтров: применяют веса фильтра вдоль кодируемой прямо последовательности и измеряют активацию фильтра (а также то, насколько хорошо веса соответствуют последовательности).

Матрицы весовых коэффициентов фильтра, соответствующие близкому совпадению с последовательностью, будут высокоактивированными (то есть дадут более высокие оценки совпадения). Собирая подпоследовательности ДНК, которые дают самые высокие оценки активации, мы можем создать весовую матрицу «высокоактивированных последовательностей» для каждого фильтра и, таким образом, визуализировать свёрточный фильтр в виде логотипа мотива.


Схема, показывающая степень сбора активированных участков последовательности и их преобразования в логотипы мотивов для фильтра свёртки

def get_conv_output_for_seq(seq, conv_layer):     '''     Given an input sequeunce and a convolutional layer,      get the output tensor containing the conv filter      activations along each position in the sequence     '''     # format seq for input to conv layer (OHE, reshape)     seq = torch.tensor(one_hot_encode(seq)).unsqueeze(0).permute(0,2,1).to(DEVICE)      # run seq through conv layer     with torch.no_grad(): # don't want as part of gradient graph         # apply learned filters to input seq         res = conv_layer(seq.float())         return res[0]       def get_filter_activations(seqs, conv_layer,act_thresh=0):     '''     Given a set of input sequences and a trained convolutional layer,      determine the subsequences for which each filter in the conv layer      activate most strongly.           1.) Run seq inputs through conv layer.      2.) Loop through filter activations of the resulting tensor, saving the             position where filter activations were > act_thresh.      3.) Compile a count matrix for each filter by accumulating subsequences which             activate the filter above the threshold act_thresh     '''     # initialize dict of pwms for each filter in the conv layer     # pwm shape: 4 nucleotides X filter width, initialize to 0.0s     num_filters = conv_layer.out_channels     filt_width = conv_layer.kernel_size[0]     filter_pwms = dict((i,torch.zeros(4,filt_width)) for i in range(num_filters))          print("Num filters", num_filters)     print("filt_width", filt_width)          # loop through a set of sequences and collect subseqs where each filter activated     for seq in seqs:         # get a tensor of each conv filter activation along the input seq         res = get_conv_output_for_seq(seq, conv_layer)          # for each filter and it's activation vector         for filt_id,act_vec in enumerate(res):             # collect the indices where the activation level              # was above the threshold             act_idxs = torch.where(act_vec>act_thresh)[0]             activated_positions = [x.item() for x in act_idxs]              # use activated indicies to extract the actual DNA             # subsequences that caused filter to activate             for pos in activated_positions:                 subseq = seq[pos:pos+filt_width]                 #print("subseq",pos, subseq)                 # transpose OHE to match PWM orientation                 subseq_tensor = torch.tensor(one_hot_encode(subseq)).T                  # add this subseq to the pwm count for this filter                 filter_pwms[filt_id] += subseq_tensor                              return filter_pwms  def view_filters_and_logos(model_weights,filter_activations, num_cols=8):     '''     Given some convolutional model weights and filter activation PWMs,      visualize the heatmap and motif logo pairs in a simple grid     '''     model_weights = model_weights[0].squeeze(1)     print(model_weights.shape)      # make sure the model weights agree with the number of filters     assert(model_weights.shape[0] == len(filter_activations))          num_filts = len(filter_activations)     num_rows = int(np.ceil(num_filts/num_cols))*2+1      # ^ not sure why +1 is needed... complained otherwise          plt.figure(figsize=(20, 17))      j=0 # use to make sure a filter and it's logo end up vertically paired     for i, filter in enumerate(model_weights):         if (i)%num_cols == 0:             j += num_cols          # display raw filter         ax1 = plt.subplot(num_rows, num_cols, i+j+1)         ax1.imshow(filter.cpu().detach(), cmap='gray')         ax1.set_yticks(np.arange(4))         ax1.set_yticklabels(['A', 'C', 'G','T'])         ax1.set_xticks(np.arange(model_weights.shape[2]))         ax1.set_title(f"Filter {i}")          # display sequence logo         ax2 = plt.subplot(num_rows, num_cols, i+j+1+num_cols)         filt_df = pd.DataFrame(filter_activations[i].T.numpy(),columns=['A','C','G','T'])         filt_df_info = logomaker.transform_matrix(filt_df,from_type='counts',to_type='information')         logo = logomaker.Logo(filt_df_info,ax=ax2)         ax2.set_ylim(0,2)         ax2.set_title(f"Filter {i}")      plt.tight_layout()       # just use some seqs from test_df to activate filters some_seqs = random.choices(seqs, k=3000)  filter_activations = get_filter_activations(some_seqs, conv_layers[0],act_thresh=1) view_filters_and_logos(model_weights,filter_activations)  # prints: # Num filters 32 # filt_width 3 # torch.Size([32, 4, 3]) 

Из конкретно этого обучения CNN видно, что несколько фильтров уловили сильные мотивы TAT и GCG, в то время как другие фильтры сфокусировались и на других паттернах.

О важности визуализации свёрточных фильтров для интерпретации модели нет единого мнения. В глубоких моделях с несколькими свёрточными слоями свёрточные фильтры могут комбинироваться в скрытых слоях более сложными способами, поэтому фильтры первого слоя сами по себе могут быть не такими информативными (Ку и Эдди, 2019]). Большая часть этой области с тех пор перешла к механизмам внимания и другим методам объяснения результатов. Однако, если вам интересно визуализировать фильтры как потенциальные мотивы, эти функции могут быть полезны!

8. Итоги

В руководстве даются основы построения моделей CNN в PyTorch для обработки последовательностей ДНК. Практическая часть работы не отражает реалии биологических сигналов. Наш метод подсчёта баллов, скорее, моделирует наличие регуляторных мотивов в очень коротких последовательностях, которые легко понять человеку. Поведение PyTorch соответствовало ожиданиям. В этом маленьком примере мы проверили способность базовой CNN со сдвиговыми фильтрами к прогнозу расчёта баллов. CNN показала себя лучше, чем линейный метод, выявивший только абсолютные положения нуклеотидов (без учёта местных особенностей).

Подробнее о применении CNN к ДНК в естественных условиях можно узнать из следующих работ:

Надеюсь, эта статья поможет с PyTorch всем, кто интересуется биологией и ДНК, но мало знаком с машинным обучением.

9. Примечания

Первое

В нашем примере свёрточной нейросети используется одномерный свёрточный слой. Поскольку ДНК — не двухмерная картинка, достаточно прогнать Conv1D по всей длине последовательности, не сканируя «вниз» и «вверх». (Сдвиг фильтра «вверх» и «вниз» неприменим при кодировании матриц ДНК с одним активным состоянием. Нет смысла отделять ряды A и C от G и T — все четыре ряда должны быть строго представлены в последовательности).

Однако мне однажды понадобился инструмент анализа от keras, и тогда я нашла скрипт преобразования pytorch2keras. Скрипт преобразования умеет работать только со слоями Conv2d и выдаёт ошибки по слоям Conv1d в модели.

Если он понадобится и вам, вот пример переформатирования CNN в Conv2D, который обеспечивает беспроблемное сканирование, как для Conv1D:

class DNA_CNN_2D(nn.Module):     def __init__(self,                  seq_len,                  num_filters=31,                  kernel_size=3,                 ):         super().__init__()         self.seq_len = seq_len                  self.conv_net = nn.Sequential(             nn.Conv2d(1, num_filters, kernel_size=(4,kernel_size)),             # ^^ changed from 4 to 1 channel, moved 4 to kernel_size             nn.ReLU(),             nn.Flatten(),             nn.Linear(num_filters*(seq_len-kernel_size+1), 1)         )       def forward(self, xb):         # reshape view to batch_ssize x 4channel x seq_len         # permute to put channel in correct order         xb = xb.permute(0,2,1).unsqueeze(1)         # ^^ Conv2D input fix                  #print(xb.shape)         out = self.conv_net(xb)         return out 

Второе

Если ваша задача связана с классификацией, а не с регрессией, возможно, вы предпочтёте использовать CrossEntropyLoss. Учтите, что CrossEntropyLoss требуется формат, отличный от MSELoss. Попробуйте такой вариант:

loss = loss_func(xb_out, yb.long().squeeze (1)) 

А пока учёные моделируют ДНК, мы научим вас разрабатывать системы ИИ, которые решают широкий круг проблем:

Чтобы увидеть все курсы, кликните по баннеру:


ссылка на оригинал статьи https://habr.com/ru/company/skillfactory/blog/690562/


Комментарии

Добавить комментарий

Ваш адрес email не будет опубликован. Обязательные поля помечены *