3D teeth instance segmentation. В темноте, но не один

от автора

3D сегментация зубов от поиска данных до конечного результата. Почти.

Дисклеймер

Данная статья не является обучающей в любом понимании этого термина и носит сугубо информативный характер. Автор статьи не несет ответственности за время, потраченное на ее чтение.

Об авторе

Добрый — всем, зовут Андрей(27). Постараюсь коротко. Почему программирование? По образованию — бакалавр электромеханик, профессию знаю. Отработал 2 года на должности инженера-энергетика в буровой компании вполне успешно, вместо повышения написал заявление — сгорел, да не по мне оказалось это всё. Нравится создавать, находить решения сложных задач, с ПК в обнимку с сознательных лет. Выбор очевиден. Вначале (полгода назад), всерьёз думал записаться на курсы от Я или подобные. Начитался отзывов, поговорил с участниками и понял что с получением информацией проблем нет. Так нашел сайт, там получил базу по Python и с ним уже начал свой путь (сейчас там постепенно изучаю всё, что связано с ML). Сразу заинтересовало машинное обучение, CV в частности. Придумал себе задачу и вот здесь (по мне, так отличный способ учиться).

1. Введение

В результате нескольких неудачных попыток, пришел к решению использовать 2 легковесные модели для получения желаемого результата. 1-ая сегментирует все зубы как [1, 0] категорию, а вторая делит их на категории[0, 8]. Но начнем по порядку.

2. Поиск и подготовка данных

Потратив не один вечер на поиск данных для работы, пришел в выводу что в свободном доступе челюсть в хорошем качестве и формате (*.stl, *.nrrd и т.д.) не получится. Лучшее, что мне попалось — это тестовый образец головы пациента после хирургической операции на челюсти в программе 3D Slicer.

Очевидно, мне не нужна голова целиком, поэтому обрезал исходник в той же программе до размера 163*112*120рх (в данном посте {x*y*z = ш-г-в} и 1рх — 0,5мм), оставив только зубы и сопутствующие челюстно-лицевые части.

Уже больше похоже на то что нужно, дальше — интереснее. Теперь нужно создать маски всех необходимых нам объектов. Для тех, кто уже работал с этим — «autothreshold» не то чтобы совсем не работает, просто лишнего много, думаю, исправление заняло бы столько же времени, сколько и разметка вручную(через маски).

- Пиксели(срезы слева)? - Вспоминаем размер изображения
— Пиксели(срезы слева)? — Вспоминаем размер изображения

Размечал часов 12~14. И да, тот факт что я не сразу разметил каждый зуб как категорию стоил мне еще порядка 4 часов. В итоге у нас есть данные, с которыми у же можно работать.

Конечный вариант маски. Smooth 0.5. (сглаживание в обучении не использовалось)
Конечный вариант маски. Smooth 0.5. (сглаживание в обучении не использовалось)

Должен добавить, даже на мой (без опыта) взгляд, этих данных очень мало для обучения и последующей полноценной работы нейронной сети. На данном этапе, единственное что пришло в голову, повернуть имеющиеся данные N-раз и соединить, random-crop использовать не стал.

Код подготовки данных
import nrrd import torch import torchvision.transforms as tf   class DataBuilder:     def __init__(self,                  data_path,                  list_of_categories,                  num_of_chunks: int = 0,                  augmentation_coeff: int = 0,                  num_of_classes: int = 0,                  normalise: bool = False,                  fit: bool = True,                  data_format: int = 0,                  save_data: bool = False                  ):         self.data_path = data_path         self.number_of_chunks = num_of_chunks         self.augmentation_coeff = augmentation_coeff         self.list_of_cats = list_of_categories         self.num_of_cls = num_of_classes         self.normalise = normalise         self.fit = fit         self.data_format = data_format         self.save_data = save_data      def forward(self):         data = self.get_data()         data = self.fit_data(data) if self.fit else data         data = self.pre_normalize(data) if self.normalise else data         data = self.data_augmentation(data, self.augmentation_coeff) if self.augmentation_coeff != 0 else data         data = self.new_chunks(data, self.number_of_chunks) if self.number_of_chunks != 0 else data         data = self.category_splitter(data, self.num_of_cls, self.list_of_cats) if self.num_of_cls != 0 else data         torch.save(data, self.data_path[-14:]+'.pt') if self.save_data else None          return torch.unsqueeze(data, 1)      def get_data(self):         if self.data_format == 0:             return torch.from_numpy(nrrd.read(self.data_path)[0])         elif self.data_format == 1:             return torch.load(self.data_path).cpu()         elif self.data_format == 2:             return torch.unsqueeze(self.data_path, 0).cpu()         else:             print('Available types are: "nrrd", "tensor" or "self.tensor(w/o load)"')      @staticmethod     def fit_data(some_data):         data = torch.movedim(some_data, (1, 0), (0, -1))         data_add_x = torch.nn.ZeroPad2d((5, 0, 0, 0))         data = data_add_x(data)         data = torch.movedim(data, -1, 0)         data_add_z = torch.nn.ZeroPad2d((0, 0, 8, 0))          return data_add_z(data)      @staticmethod     def pre_normalize(some_data):         min_d, max_d = torch.min(some_data), torch.max(some_data)          return (some_data - min_d) / (max_d - min_d)      @staticmethod     def data_augmentation(some_data, aug_n):         torch.manual_seed(17)         tr_data = []         for e in range(aug_n):             transform = tf.RandomRotation(degrees=(20*e, 20*e))             for image in some_data:                 image = torch.unsqueeze(image, 0)                 image = transform(image)                 tr_data.append(image)          return tr_data      def new_chunks(self, some_data, n_ch):         data = torch.stack(some_data, 0) if self.augmentation_coeff != 0 else some_data         data = torch.squeeze(data, 1)         chunks = torch.chunk(data, n_ch, 0)          return torch.stack(chunks)      @staticmethod     def category_splitter(some_data, alpha, list_of_categories):         data, _ = torch.squeeze(some_data, 1).to(torch.int64), alpha         for i in list_of_categories:             data = torch.where(data < i, _, data)             _ += 1          return data - alpha 

Имейте ввиду что это финальная версия кода подготовки данных для 3D U-net. Форвард:

  • Загружаем дату (в зависимости от типа).

  • Добавляем 0 по краям чтобы подогнать размер до 168*120*120 (вместо исходных 163*112*120). *пригодится дальше.

  • Нормализуем входящие данные в 0…1 (исходные ~-2000…16000).

  • Поворачиваем N-раз и соединяем.

  • Полученные данные режем на равные части чтобы забить память видеокарты по максимуму (в моем случае это 1, 1, 72, 120, 120).

  • Эта часть распределяет по категориям 28 имеющихся зубов и фон для облегчения обучения моделей (см. Введение):

    • одну категорию для 1-ой;

    • на 9 категорий (8+фон) для 2-ой.

Dataloader стандартный
import torch.utils.data as tud   class ToothDataset(tud.Dataset):     def __init__(self, images, masks):         self.images = images         self.masks = masks      def __len__(self): return len(self.images)      def __getitem__(self, index):         if self.masks is not None:             return self.images[index, :, :, :, :],\                     self.masks[index, :, :, :, :]         else:             return self.images[index, :, :, :, :]   def get_loaders(images, masks,                 batch_size: int = 1,                 num_workers: int = 1,                 pin_memory: bool = True):      train_ds = ToothDataset(images=images,                             masks=masks)      data_loader = tud.DataLoader(train_ds,                                  batch_size=batch_size,                                  shuffle=False,                                  num_workers=num_workers,                                  pin_memory=pin_memory)      return data_loader 

На выходе имеем следующее:

Semantic

Instance

Predictions

Data

(27*, 1, 56*, 120,120)[0…1]

(27*, 1, 56*, 120,120) [0, 1]

(1, 1, 168, 120, 120)[0…1]

Masks

(27*, 1, 56*, 120,120)[0, 1]

(27*, 1, 56*, 120,120)[0, 8]

*эти размеры менялись, в зависимости от эксперимента, подробности — дальше.

3. Выбор и настройка моделей обучения

Цель работы — обучение. Поэтому взял наиболее простую и понятную для себя модель нейросети архитектуры U-Net. Код не выкладываю, можно посмотреть тут.

2D U-Net
2D U-Net

Подробно рассказывать не буду, информации в достатке в сети. Метод оптимизации — Adam, функция расчета потерь Dice-loss(implement), спусков/подъемов 4, фильтры [64, 128, 256, 512] (знаю, много, об этом — позже). Обучал в среднем 60-80 epochs на эксперимент. Transfer learning не использовал.

model.summary()
model = UNet(dim=2, in_channels=1, out_channels=1, n_blocks=4, start_filters=64).to(device) print(summary(model, (1, 168, 120)))  """ ----------------------------------------------------------------         Layer (type)               Output Shape         Param # ================================================================             Conv2d-1         [-1, 64, 168, 120]             640               ReLU-2         [-1, 64, 168, 120]               0        BatchNorm2d-3         [-1, 64, 168, 120]             128             Conv2d-4         [-1, 64, 168, 120]          36,928               ReLU-5         [-1, 64, 168, 120]               0        BatchNorm2d-6         [-1, 64, 168, 120]             128          MaxPool2d-7           [-1, 64, 84, 60]               0          DownBlock-8  [[-1, 64, 84, 60], [-1, 64, 168, 120]]  0             Conv2d-9          [-1, 128, 84, 60]          73,856              ReLU-10          [-1, 128, 84, 60]               0       BatchNorm2d-11          [-1, 128, 84, 60]             256            Conv2d-12          [-1, 128, 84, 60]         147,584              ReLU-13          [-1, 128, 84, 60]               0       BatchNorm2d-14          [-1, 128, 84, 60]             256         MaxPool2d-15          [-1, 128, 42, 30]               0         DownBlock-16  [[-1, 128, 42, 30], [-1, 128, 84, 60]]  0            Conv2d-17          [-1, 256, 42, 30]         295,168              ReLU-18          [-1, 256, 42, 30]               0       BatchNorm2d-19          [-1, 256, 42, 30]             512            Conv2d-20          [-1, 256, 42, 30]         590,080              ReLU-21          [-1, 256, 42, 30]               0       BatchNorm2d-22          [-1, 256, 42, 30]             512         MaxPool2d-23          [-1, 256, 21, 15]               0         DownBlock-24  [[-1, 256, 21, 15], [-1, 256, 42, 30]]  0            Conv2d-25          [-1, 512, 21, 15]       1,180,160              ReLU-26          [-1, 512, 21, 15]               0       BatchNorm2d-27          [-1, 512, 21, 15]           1,024            Conv2d-28          [-1, 512, 21, 15]       2,359,808              ReLU-29          [-1, 512, 21, 15]               0       BatchNorm2d-30          [-1, 512, 21, 15]           1,024         DownBlock-31  [[-1, 512, 21, 15], [-1, 512, 21, 15]]  0   ConvTranspose2d-32          [-1, 256, 42, 30]         524,544              ReLU-33          [-1, 256, 42, 30]               0       BatchNorm2d-34          [-1, 256, 42, 30]             512       Concatenate-35          [-1, 512, 42, 30]               0            Conv2d-36          [-1, 256, 42, 30]       1,179,904              ReLU-37          [-1, 256, 42, 30]               0       BatchNorm2d-38          [-1, 256, 42, 30]             512            Conv2d-39          [-1, 256, 42, 30]         590,080              ReLU-40          [-1, 256, 42, 30]               0       BatchNorm2d-41          [-1, 256, 42, 30]             512           UpBlock-42          [-1, 256, 42, 30]               0   ConvTranspose2d-43          [-1, 128, 84, 60]         131,200              ReLU-44          [-1, 128, 84, 60]               0       BatchNorm2d-45          [-1, 128, 84, 60]             256       Concatenate-46          [-1, 256, 84, 60]               0            Conv2d-47          [-1, 128, 84, 60]         295,040              ReLU-48          [-1, 128, 84, 60]               0       BatchNorm2d-49          [-1, 128, 84, 60]             256            Conv2d-50          [-1, 128, 84, 60]         147,584              ReLU-51          [-1, 128, 84, 60]               0       BatchNorm2d-52          [-1, 128, 84, 60]             256           UpBlock-53          [-1, 128, 84, 60]               0   ConvTranspose2d-54         [-1, 64, 168, 120]          32,832              ReLU-55         [-1, 64, 168, 120]               0       BatchNorm2d-56         [-1, 64, 168, 120]             128       Concatenate-57        [-1, 128, 168, 120]               0            Conv2d-58         [-1, 64, 168, 120]          73,792              ReLU-59         [-1, 64, 168, 120]               0       BatchNorm2d-60         [-1, 64, 168, 120]             128            Conv2d-61         [-1, 64, 168, 120]          36,928              ReLU-62         [-1, 64, 168, 120]               0       BatchNorm2d-63         [-1, 64, 168, 120]             128           UpBlock-64         [-1, 64, 168, 120]               0            Conv2d-65          [-1, 1, 168, 120]              65 ================================================================ Total params: 7,702,721 Trainable params: 7,702,721 Non-trainable params: 0 ---------------------------------------------------------------- Input size (MB): 0.08 Forward/backward pass size (MB): 7434.08 Params size (MB): 29.38 Estimated Total Size (MB): 7463.54 """
Эксп.№1 2D U-Net, подача изображений покадрово, плоскость [x, z]
Эксп.№1 2D U-Net, подача изображений покадрово, плоскость [x, z]

Определенно, это — зубы. Только кроме зубов есть много всего, нам ненужного. Подробнее о трансформации numpy — *.stl в Главе 6. Посмотрим ещё раз на фактический размер и качество изображений, которые попадают на вход нейросети:

Слева на право: 1. Не видно[x, y]. 2. Немного лучше[x, z]. 3.Ещё лучше[y, z]
Слева на право: 1. Не видно[x, y]. 2. Немного лучше[x, z]. 3.Ещё лучше[y, z]

Если сам не видишь на 100% где там начался зуб а где нет, то как тогда эту работу выполнит нейросеть? Как минимум, необходимо изменить плоскость подачи изображения.

Проведя не один день разбираясь в том, как можно улучшить сложившуюся ситуацию, пришел к тому, что можно составлять каскад и сетей, поочередно обрабатывающих изображение, аналогично работе фильтров грубой и тонкой очистки.

Эксп.№2 Каскад 2-ух 2D U-Net, подача изображений покадрово, плоскость [y, z]
Эксп.№2 Каскад 2-ух 2D U-Net, подача изображений покадрово, плоскость [y, z]

Прогресс виден, однако вместе с помехами пропадают и части зубов, дальнейшее обучение тому подтверждение:

Эксп.№3 Каскад 2-ух 2D U-Net, подача изображений покадрово плоскость [y, z] с увеличением времени обучения на 50%
Эксп.№3 Каскад 2-ух 2D U-Net, подача изображений покадрово плоскость [y, z] с увеличением времени обучения на 50%

Ввиду последних событий было принято решение о переходе на 3D архитектуру нейронной сети. Переподготовил входные данные, а именно разделил на части размером (24*, 120, 120). Почему так? — изначально большая модель обучения (~22млн. параметров). Моя видеокарта(1063gtx) не могла физически вместить больше.

24*

Это размер глубины. Был подобран так чтобы:

  • количество данных(1512, 120, 120) делится нацело на это число — получается 63;

  • в свою очередь получившийся batch size (24, 120, 120) — максимум, вмещающийся в память видеокарты с текущими параметрами сети;

  • само это число (24) делилось на количество спусков/подъемов так же нацело (имеется в виду соответствие выражению 24/2/2/2=3 и 3*2*2*2=24, где количество делений/умножений на 2 соответствует количеству спусков/подъемов минус 1);

  • то же самое не только для глубины данных, но и длинны и ширины. Подробнее в .summary()

model.summary()
model = UNet(dim=3, in_channels=1, out_channels=1, n_blocks=4, start_filters=64).to(device) print(summary(model, (1, 24, 120, 120)))  """   ----------------------------------------------------------------         Layer (type)               Output Shape         Param # ================================================================             Conv3d-1     [-1, 64, 24, 120, 120]             1,792               ReLU-2     [-1, 64, 24, 120, 120]                 0        BatchNorm3d-3     [-1, 64, 24, 120, 120]               128             Conv3d-4     [-1, 64, 24, 120, 120]           110,656               ReLU-5     [-1, 64, 24, 120, 120]                 0        BatchNorm3d-6     [-1, 64, 24, 120, 120]               128          MaxPool3d-7        [-1, 64, 12, 60, 60]                0          DownBlock-8  [[-1, 64, 12, 60, 60], [-1, 64, 24, 120, 120]]               0             Conv3d-9       [-1, 128, 12, 60, 60]          221,312              ReLU-10       [-1, 128, 12, 60, 60]                0       BatchNorm3d-11       [-1, 128, 12, 60, 60]              256            Conv3d-12       [-1, 128, 12, 60, 60]          442,496              ReLU-13       [-1, 128, 12, 60, 60]                0       BatchNorm3d-14       [-1, 128, 12, 60, 60]              256         MaxPool3d-15       [-1, 128, 6, 30, 30]                 0         DownBlock-16  [[-1, 128, 6, 30, 30], [-1, 128, 12, 60, 60]]               0            Conv3d-17       [-1, 256, 6, 30, 30]           884,992              ReLU-18       [-1, 256, 6, 30, 30]                 0       BatchNorm3d-19       [-1, 256, 6, 30, 30]               512            Conv3d-20       [-1, 256, 6, 30, 30]         1,769,728              ReLU-21       [-1, 256, 6, 30, 30]                 0       BatchNorm3d-22       [-1, 256, 6, 30, 30]               512         MaxPool3d-23       [-1, 256, 3, 15, 15]                 0         DownBlock-24  [[-1, 256, 3, 15, 15], [-1, 256, 6, 30, 30]]               0            Conv3d-25       [-1, 512, 3, 15, 15]         3,539,456              ReLU-26       [-1, 512, 3, 15, 15]                 0       BatchNorm3d-27       [-1, 512, 3, 15, 15]             1,024            Conv3d-28       [-1, 512, 3, 15, 15]         7,078,400              ReLU-29       [-1, 512, 3, 15, 15]                 0       BatchNorm3d-30       [-1, 512, 3, 15, 15]             1,024         DownBlock-31  [[-1, 512, 3, 15, 15], [-1, 512, 3, 15, 15]]               0   ConvTranspose3d-32       [-1, 256, 6, 30, 30]         1,048,832              ReLU-33       [-1, 256, 6, 30, 30]                 0       BatchNorm3d-34       [-1, 256, 6, 30, 30]               512       Concatenate-35       [-1, 512, 6, 30, 30]                 0            Conv3d-36       [-1, 256, 6, 30, 30]         3,539,200              ReLU-37       [-1, 256, 6, 30, 30]                 0       BatchNorm3d-38       [-1, 256, 6, 30, 30]               512            Conv3d-39       [-1, 256, 6, 30, 30]         1,769,728              ReLU-40       [-1, 256, 6, 30, 30]                 0       BatchNorm3d-41       [-1, 256, 6, 30, 30]               512           UpBlock-42       [-1, 256, 6, 30, 30]                 0   ConvTranspose3d-43       [-1, 128, 12, 60, 60]          262,272              ReLU-44       [-1, 128, 12, 60, 60]                0       BatchNorm3d-45       [-1, 128, 12, 60, 60]              256       Concatenate-46       [-1, 256, 12, 60, 60]                0            Conv3d-47       [-1, 128, 12, 60, 60]          884,864              ReLU-48       [-1, 128, 12, 60, 60]                0       BatchNorm3d-49       [-1, 128, 12, 60, 60]              256            Conv3d-50       [-1, 128, 12, 60, 60]          442,496              ReLU-51       [-1, 128, 12, 60, 60]                0       BatchNorm3d-52       [-1, 128, 12, 60, 60]              256           UpBlock-53       [-1, 128, 12, 60, 60]                0   ConvTranspose3d-54       [-1, 64, 24, 120, 120]          65,600              ReLU-55       [-1, 64, 24, 120, 120]               0       BatchNorm3d-56       [-1, 64, 24, 120, 120]             128       Concatenate-57      [-1, 128, 24, 120, 120]               0            Conv3d-58       [-1, 64, 24, 120, 120]         221,248              ReLU-59       [-1, 64, 24, 120, 120]               0       BatchNorm3d-60       [-1, 64, 24, 120, 120]             128            Conv3d-61       [-1, 64, 24, 120, 120]         110,656              ReLU-62       [-1, 64, 24, 120, 120]               0       BatchNorm3d-63       [-1, 64, 24, 120, 120]             128           UpBlock-64       [-1, 64, 24, 120, 120]               0            Conv3d-65        [-1, 1, 24, 120, 120]              65 ================================================================ Total params: 22,400,321 Trainable params: 22,400,321 Non-trainable params: 0 ---------------------------------------------------------------- Input size (MB): 0.61 Forward/backward pass size (MB): 15974.12 Params size (MB): 85.45 Estimated Total Size (MB): 16060.18 ---------------------------------------------------------------- """
Эксп.№4 3D U-Net, подача объемом, плоскость [y, z], время*0,38
Эксп.№4 3D U-Net, подача объемом, плоскость [y, z], время*0,38

С учетом сокращенного на ~60% времени обучения(25 epochs) результат меня устроил, продолжаем.

Эксп.№5 3D U-Net, подача объемом, плоскость [y, z],  65 epochs ~ 1,5 часа
Эксп.№5 3D U-Net, подача объемом, плоскость [y, z], 65 epochs ~ 1,5 часа

Особых потерь в искомых зонах не заметил. Решил продолжать, однако результат дальнейшего обучения мы уже где то видели(эксп.№3) — значительное уменьшение искомых зон и появление артефактов:

Эксп.№6 3D U-Net, подача объемом, плоскость [x, z],  105 epochs ~ 2,1 часа
Эксп.№6 3D U-Net, подача объемом, плоскость [x, z], 105 epochs ~ 2,1 часа

«Научный» перебор параметров в течении недели принес результат. Уменьшил количество параметров сети до ~400к (от первоначальных ~22м) путем уменьшения фильтра [18, 32, 64, 128] и спуска/подъема до 3. Изменил метод оптимизации на RSMProp. Уменьшение количества параметров нейросети позволило увеличить объем входных данных в три раза (1, 1, 72*, 120, 120). Посмотрим результат?

model.summary()
model = UNet(dim=3, in_channels=1, out_channels=1, n_blocks=3, start_filters=18).to(device) print(summary(model, (1, 1, 72, 120, 120)))  """ ----------------------------------------------------------------         Layer (type)               Output Shape         Param # ================================================================             Conv3d-1     [-1, 18, 72, 120, 120]             504               ReLU-2     [-1, 18, 72, 120, 120]               0        BatchNorm3d-3     [-1, 18, 72, 120, 120]              36             Conv3d-4     [-1, 18, 72, 120, 120]           8,766               ReLU-5     [-1, 18, 72, 120, 120]               0        BatchNorm3d-6     [-1, 18, 72, 120, 120]              36          MaxPool3d-7       [-1, 18, 36, 60, 60]               0          DownBlock-8  [[-1, 18, 36, 60, 60], [-1, 18, 24, 120, 120]]               0             Conv3d-9       [-1, 36, 36, 60, 60]          17,532              ReLU-10       [-1, 36, 36, 60, 60]               0       BatchNorm3d-11       [-1, 36, 36, 60, 60]              72            Conv3d-12       [-1, 36, 36, 60, 60]          35,028              ReLU-13       [-1, 36, 36, 60, 60]               0       BatchNorm3d-14       [-1, 36, 36, 60, 60]              72         MaxPool3d-15        [-1, 36, 18, 30, 30]              0         DownBlock-16  [[-1, 36, 18, 30, 30], [-1, 36, 36, 60, 60]]               0            Conv3d-17        [-1, 72, 18, 30, 30]         70,056              ReLU-18        [-1, 72, 18, 30, 30]              0       BatchNorm3d-19        [-1, 72, 18, 30, 30]            144            Conv3d-20        [-1, 72, 18, 30, 30]        140,040              ReLU-21        [-1, 72, 18, 30, 30]              0       BatchNorm3d-22        [-1, 72, 18, 30, 30]            144         DownBlock-23  [[-1, 72, 18, 30, 30], [-1, 72, 18, 30, 30]]               0   ConvTranspose3d-24       [-1, 36, 36, 60, 60]          20,772              ReLU-25       [-1, 36, 36, 60, 60]               0       BatchNorm3d-26       [-1, 36, 36, 60, 60]              72       Concatenate-27       [-1, 72, 36, 60, 60]               0            Conv3d-28       [-1, 36, 36, 60, 60]          70,020              ReLU-29       [-1, 36, 36, 60, 60]               0       BatchNorm3d-30       [-1, 36, 36, 60, 60]              72            Conv3d-31       [-1, 36, 36, 60, 60]          35,028              ReLU-32       [-1, 36, 36, 60, 60]               0       BatchNorm3d-33       [-1, 36, 36, 60, 60]              72           UpBlock-34       [-1, 36, 36, 60, 60]               0   ConvTranspose3d-35     [-1, 18, 72, 120, 120]           5,202              ReLU-36     [-1, 18, 72, 120, 120]               0       BatchNorm3d-37     [-1, 18, 72, 120, 120]              36       Concatenate-38     [-1, 36, 72, 120, 120]               0            Conv3d-39     [-1, 18, 72, 120, 120]          17,514              ReLU-40     [-1, 18, 72, 120, 120]               0       BatchNorm3d-41     [-1, 18, 72, 120, 120]              36            Conv3d-42     [-1, 18, 72, 120, 120]           8,766              ReLU-43     [-1, 18, 72, 120, 120]               0       BatchNorm3d-44     [-1, 18, 72, 120, 120]              36           UpBlock-45     [-1, 18, 72, 120, 120]               0            Conv3d-46      [-1, 1, 72, 120, 120]              19 ================================================================ Total params: 430,075 Trainable params: 430,075 Non-trainable params: 0 ---------------------------------------------------------------- Input size (MB): 1.32 Forward/backward pass size (MB): 5744.38 Params size (MB): 1.64 Estimated Total Size (MB): 5747.34 ---------------------------------------------------------------- """
72*

Некоторые из вас подумают, исходные данные (168, 120, 120), а часть (72, 120, 120). Назревает вопрос, как делить. Всё просто, во 2 главе мы увеличивали размер наших данных и затем делили их на части, соответствующие объему памяти видеокарты. Я увеличил данные в 9 раз (1512, 120, 120) т.е. повернул на 9 различных углов относительно одной оси, а затем разделил на 21(batch size) часть по (72, 120, 120). Так же 72 соответствует всем условиям, описанным в 24*(выше).

Эксп.№7 3D U-Net, подача объемом, плоскость [x, z], Маска (слева) и готовая сегментация (справа), оптимизированные параметры сети, время обучения(65 epochs) ~ 14мин.
Эксп.№7 3D U-Net, подача объемом, плоскость [x, z], Маска (слева) и готовая сегментация (справа), оптимизированные параметры сети, время обучения(65 epochs) ~ 14мин.

Результат вполне удовлетворительный, есть недочеты (вроде «похудевших» зубов). Возможно, исправим их в другом посте. Для этапа semantic segmentation я думаю мы сделали достаточно, теперь необходимо задать категории.

О размере подаваемых данных

Первоначальная идея при переходе на 3D архитектуру была в том чтобы делить данные не слайсами (как в данном посте) (1512, 120, 120) —> 21*(1, 72, 120, 120), а кубиками ~х*(30, 30, 30) или около того (результат этой попытки не был сохранен оп понятным причинам). Опытным путем понял 2 вещи: чем большими порциями ты подаешь 3-х мерные объекты, тем лучше результат(для моего конкретного случая); и нужно больше изучать теорию того, с чем работаешь.

О времени обучения и размере модели

Параметры сети подобраны так, что обучение 1 epochs на моей «старушке» занимает ~13сек, а размер конечной модели не превышает 2мб (прошлая>80мб). Время рабочего цикла примерно равно 1 epochs. Однако стоит понимать, это обучение и работа на данных достаточно маленького размера.

Для разделения на категории пришлось немного повозиться с функцией расчета ошибки и визуализацией данных. Первоначально поставил себе задачу разделить на 8 категорий + фон. О loss function и визуализации поговорим подробнее.

Код training loop
import torch from tqdm import tqdm from _loss_f import LossFunction   class TrainFunction:     def __init__(self,                  data_loader,                  device_for_training,                  model_name,                  model_name_pretrained,                  model,                  optimizer,                  scale,                  learning_rate: int = 1e-2,                  num_epochs: int = 1,                  transfer_learning: bool = False,                  binary_loss_f: bool = True                  ):         self.data_loader = data_loader         self.device = device_for_training         self.model_name_pretrained = model_name_pretrained         self.semantic_binary = binary_loss_f         self.num_epochs = num_epochs         self.model_name = model_name         self.transfer = transfer_learning         self.optimizer = optimizer         self.learning_rate = learning_rate         self.model = model         self.scale = scale      def forward(self):         print('Running on the:', torch.cuda.get_device_name(self.device))         self.model.load_state_dict(torch.load(self.model_name_pretrained)) if self.transfer else None         optimizer = self.optimizer(self.model.parameters(), lr=self.learning_rate)         for epoch in range(self.num_epochs):             self.train_loop(self.data_loader, self.model, optimizer, self.scale, epoch)             torch.save(self.model.state_dict(), 'models/' + self.model_name+str(epoch+1)                        + '_epoch.pth') if (epoch + 1) % 10 == 0 else None      def train_loop(self, loader, model, optimizer, scales, i):         loop, epoch_loss = tqdm(loader), 0         loop.set_description('Epoch %i' % (self.num_epochs - i))         for batch_idx, (data, targets) in enumerate(loop):             data, targets = data.to(device=self.device, dtype=torch.float), \                             targets.to(device=self.device, dtype=torch.long)             optimizer.zero_grad()             *тут секрет*             with torch.cuda.amp.autocast():                 predictions = model(data)                 loss = LossFunction(predictions, targets,                                     device_for_training=self.device,                                     semantic_binary=self.semantic_binary                                     ).forward()             scales.scale(loss).backward()             scales.step(optimizer)             scales.update()             epoch_loss += (1 - loss.item())*100             loop.set_postfix(loss=loss.item())         print('Epoch-acc', round(epoch_loss / (batch_idx+1), 2)) 

4. Функция расчета ошибки

Мне в целом понравилось как проявляет себя Dice-loss в сегментации, только ‘проблема’ в том что он работает с форматом данных [0, 1]. Однако, если предварительно разделить данные на категории (а так же привести к формату [0, 1]), и пропускать пары (имеется ввиду «предсказание» и «маска» только одной категории) в стандартную Dice-loss функцию, то это может сработать.

Код categorical_dice_loss
import torch   class LossFunction:     def __init__(self,                  prediction,                  target,                  device_for_training,                  semantic_binary: bool = True,                  ):         self.prediction = prediction         self.device = device_for_training         self.target = target         self.semantic_binary = semantic_binary      def forward(self):         if self.semantic_binary:             return self.dice_loss(self.prediction, self.target)         return self.categorical_dice_loss(self.prediction, self.target)      @staticmethod     def dice_loss(predictions, targets, alpha=1e-5):         intersection = 2. * (predictions * targets).sum()         denomination = (torch.square(predictions) + torch.square(targets)).sum()         dice_loss = 1 - torch.mean((intersection + alpha) / (denomination + alpha))          return dice_loss      def categorical_dice_loss(self, prediction, target):         pr, tr = self.prepare_for_multiclass_loss_f(prediction, target)         target_categories, losses = torch.unique(tr).tolist(), 0         for num_category in target_categories:             categorical_target = torch.where(tr == num_category, 1, 0)             categorical_prediction = pr[num_category][:][:][:]             losses += self.dice_loss(categorical_prediction, categorical_target).to(self.device)          return losses / len(target_categories)      @staticmethod     def prepare_for_multiclass_loss_f(prediction, target):         prediction_prepared = torch.squeeze(prediction, 0)         target_prepared = torch.squeeze(target, 0)         target_prepared = torch.squeeze(target_prepared, 0)          return prediction_prepared, target_prepared 

Тут просто, но всё равно объясню «categorical_dice_loss»:

  • подготовка данных (убираем ненужные в данном расчете измерения);

  • получения списка категорий, которые содержит каждый batch масок;

  • для каждой категории берем «прогноз» и «маску» соответствующих категорий, приводим значения к формату [0, 1] и пропускаем через стандартную Dice-loss;

  • складывая результаты и деля на количество категорий, получаем усредненное значение для каждого batct. Ну а дальше всё без изменений.

Так же, думаю, помог бы перевод данных к one-hot формату, но только не в момент формирования основного дата сета (раздует в размере), а непосредственно перед расчетом ошибки, но я не проверял. Кто в курсе, напишите, пожалуйста, буду рад. Результат работы данной функции будет в Главе(5).

5. Визуализация данных

Так и хочется добавить «..как отдельный вид искусства». Начну с того что прочитать *.nrrd оказалось самым простым.

Код
import nrrd # читает в numpy read = nrrd.read(data_path)  data, meta_data = read[0], read[1]  print(data.shape, np.max(data), np.min(data), meta_data, sep="\n")  (163, 112, 120) 14982 -2254   OrderedDict([('type', 'short'), ('dimension', 3), ('space', 'left-posterior-superior'), ('sizes', array([163, 112, 120])), ('space directions', array([[-0.5,  0. ,  0. ],        [ 0. , -0.5,  0. ],        [ 0. ,  0. ,  0.5]])), ('kinds', ['domain', 'domain', 'domain']), ('endian', 'little'), ('encoding', 'gzip'), ('space origin', array([131.57200623,  80.7661972 ,  32.29940033]))])

Дальше — сложнее, как обратно перевести? Получается, если я ничего не путаю, необходимо из числа получить вершины и грани, между которыми образуется поверхность.

Неправильный путь

Иными словами, чтобы сделать куб нам необходимо 8 вершин и 12 треугольных поверхностей. В этом и состояла первая идея (до применения специальных библиотек) — заменить все пиксели (числа в 3-х мерной матрице) на такие кубики. Код я не сохранил, но смысл прост, рисуем куб на месте «пикселя» со сдвигом -1 по трем направлениям, потом следующий и т.д.

Выглядит это так же бредово, как и звучит
Выглядит это так же бредово, как и звучит

Отрицательный результат — тоже результат, продолжаем. На этом этапе я уже понял, что без сторонних библиотек мне не обойтись. Первой попыткой в была пара Skimage и Stl.

from skimage.measure import marching_cubes import nrrd import numpy as np from stl import mesh  path = 'some_path.nrrd' data = nrrd.read(path)[0]   def three_d_creator(some_data):     vertices, faces, volume, _ = marching_cubes(some_data)     cube = mesh.Mesh(np.full(faces.shape[0], volume.shape[0], dtype=mesh.Mesh.dtype))     for i, f in enumerate(faces):         for j in range(3):             cube.vectors[i][j] = vertices[f[j]]     cube.save('name.stl')      return cube   stl = three_d_creator(datas)

Пользовался этим способом, но иногда файлы «ломались» в процессе сохранения и не открывались. А на те, которые открывались, ругался встроенный в Win 10 3D Builder и постоянно пытался там что-то исправить. Так же еще придется «прикрутить» к коду модуль для просмотра 3D объектов без их сохранения. Решение «из коробки» дальше.

На момент написания статью пользуюсь v3do. Коротко, быстро, удобно и можно сразу осмотреть модель.

Код перевода npy в stl и вывода объекта на дисплей
from vedo import Volume, show, write  prediction = 'some_data_path.npy'  def show_save(data, save=False):     data_multiclass = Volume(data, c='Set2', alpha=(0.1, 1), alphaUnit=0.87, mode=1)     data_multiclass.addScalarBar3D(nlabels=9)     show([(data_multiclass, "Multiclass teeth segmentation prediction")], bg='black', N=1, axes=1).close()     write(data_multiclass.isosurface(), 'some_name_.stl') if save else None      show_save(prediction, save=True)

Названия функций говорят сами за себя.

Пришло время увидеть конечный результат всего вышесказанного. Томить не буду:

model.summary()
model = UNet(dim=3, in_channels=1, out_channels=9, n_blocks=3, start_filters=9).to(device) print(summary(model, (1, 168*, 120, 120)))      """ ----------------------------------------------------------------         Layer (type)               Output Shape         Param # ================================================================             Conv3d-1      [-1, 9, 168, 120, 120]            252               ReLU-2      [-1, 9, 168, 120, 120]              0        BatchNorm3d-3      [-1, 9, 168, 120, 120]             18             Conv3d-4      [-1, 9, 168, 120, 120]          2,196               ReLU-5      [-1, 9, 168, 120, 120]              0        BatchNorm3d-6      [-1, 9, 168, 120, 120]             18          MaxPool3d-7        [-1, 9, 84, 60, 60]               0          DownBlock-8  [[-1, 9, 84, 60, 60], [-1, 9, 168, 120, 120]]               0             Conv3d-9       [-1, 18, 84, 60, 60]           4,392              ReLU-10       [-1, 18, 84, 60, 60]               0       BatchNorm3d-11       [-1, 18, 84, 60, 60]              36            Conv3d-12       [-1, 18, 84, 60, 60]           8,766              ReLU-13       [-1, 18, 84, 60, 60]               0       BatchNorm3d-14       [-1, 18, 84, 60, 60]              36         MaxPool3d-15       [-1, 18, 42, 30, 30]               0         DownBlock-16  [[-1, 18, 18, 42, 30], [-1, 18, 84, 60, 60]]               0            Conv3d-17       [-1, 36, 42, 30, 30]          17,532              ReLU-18       [-1, 36, 42, 30, 30]               0       BatchNorm3d-19       [-1, 36, 42, 30, 30]              72            Conv3d-20       [-1, 36, 42, 30, 30]          35,028              ReLU-21       [-1, 36, 42, 30, 30]               0       BatchNorm3d-22       [-1, 36, 42, 30, 30]              72         DownBlock-23  [[-1, 36, 42, 30, 30], [-1, 36, 42, 30, 30]]               0   ConvTranspose3d-24       [-1, 18, 84, 60, 60]           5,202              ReLU-25       [-1, 18, 84, 60, 60]               0       BatchNorm3d-26       [-1, 18, 84, 60, 60]              36       Concatenate-27       [-1, 36, 84, 60, 60]               0            Conv3d-28       [-1, 18, 84, 60, 60]          17,514              ReLU-29       [-1, 18, 84, 60, 60]               0       BatchNorm3d-30       [-1, 18, 84, 60, 60]              36            Conv3d-31       [-1, 18, 84, 60, 60]           8,766              ReLU-32       [-1, 18, 84, 60, 60]               0       BatchNorm3d-33       [-1, 18, 84, 60, 60]              36           UpBlock-34       [-1, 18, 84, 60, 60]               0   ConvTranspose3d-35      [-1, 9, 168, 120, 120]          1,305              ReLU-36      [-1, 9, 168, 120, 120]              0       BatchNorm3d-37      [-1, 9, 168, 120, 120]             18       Concatenate-38     [-1, 18, 168, 120, 120]              0            Conv3d-39      [-1, 9, 168, 120, 120]          4,383              ReLU-40      [-1, 9, 168, 120, 120]              0       BatchNorm3d-41      [-1, 9, 168, 120, 120]             18            Conv3d-42      [-1, 9, 168, 120, 120]          2,196              ReLU-43      [-1, 9, 168, 120, 120]              0       BatchNorm3d-44      [-1, 9, 168, 120, 120]             18           UpBlock-45      [-1, 9, 168, 120, 120]              0            Conv3d-46      [-1, 9, 168, 120, 120]             90 ================================================================ Total params: 108,036 Trainable params: 108,036 Non-trainable params: 0 ---------------------------------------------------------------- Input size (MB): 3.96 Forward/backward pass size (MB): 12170.30 Params size (MB): 0.41 Estimated Total Size (MB): 12174.66 ----------------------------------------------------------------     """

*Ввиду ещё большего уменьшения параметров сети(фильтр[9, 18, 36, 72]), удалось уместить объект в память видеокарты целиком — 9*(168, 120, 120)

6. After words

Думал, что закончил, а оказалось — только начал. Тут еще есть над чем поработать. Мне, в целом, 2 этап не нравится, хоть он и работает. Зачем заново переопределять каждый пиксель, когда мне нужен целый регион? А если, образно, есть 28 разделенных регионов, зачем мне пытаться определить их все, не проще ли определить один зуб и завязать это всё на «условный» ориентированный/неориентированный граф? Или вместо U-net использовать GCNN и вместо Pytorch — Pytorch3D? Пятна, думаю, можно убрать с помощью выравнивания данных внутри bounding box(ведь один зуб может принадлежать только 1 категории). Но, возможно, это вопросы для следующей публикации.

Прототип (набросок)
Тот самый «условный граф»
Пример неориентированного графа на 28 категорий с "разделителями"
Пример неориентированного графа на 28 категорий с «разделителями»

Отдельное спасибо моей жене — Алёне, за особую поддержку во время этого «погружения в темноту».

Благодарю всех за внимание. Конструктивная критика и предложения, как исправлений, так и новых проектов — приветствуются.

ссылка на оригинал статьи https://habr.com/ru/post/558826/


Комментарии

Добавить комментарий

Ваш адрес email не будет опубликован. Обязательные поля помечены *