首页 > 技术知识 > 正文

【深度学习】U型的Transfomer网络(Swin-Unet)和Swin-Transformer分类

文章目录 1 概述 2 Swin-Transformer分类源码 3 训练 4 关于复杂度降低问题 5 关于SW-MSA的操作问题 6 总结 1 概述

代码暂时还未开源。 【深度学习】U型的Transfomer网络和Swin-Transformer分类 在过去的几年中,卷积神经网络(CNN)在医学图像分析中取得了里程碑式的进展。尤其是,基于U形架构和跳跃连接的深度神经网络已广泛应用于各种医学图像任务中。但是,尽管CNN取得了出色的性能,但是由于卷积操作的局限性,它无法很好地学习全局和远程语义信息交互。

在本文中,我们提出了Swin-Unet,它是用于医学图像分割的类似Unet的纯Transformer。标记化的图像块通过跳跃连接被馈送到基于Transformer的U形En-Decoder架构中,以进行局部全局语义特征学习。

【深度学习】U型的Transfomer网络和Swin-Transformer分类1

具体来说,我们使用带有偏移窗口的分层Swin Transformer作为编码器来提取上下文特征。 【深度学习】U型的Transfomer网络和Swin-Transformer分类2

并设计了具有补丁扩展层的基于对称Swin Transformer的解码器来执行上采样操作,以恢复特征图的空间分辨率。

2 Swin-Transformer分类源码

最近swin-transformer大火,代码开源两天,girhub直接飙到1.9k。估计接下来关于和swin-transformer相结合的各种网络结构paper就要出来了,哈哈,我也是其中的一员,拼手速吧各位。它的原理网上的博客已经讲的非常的细致了,甚至还有带着读源代码的。这些大佬真的很强,下面会放一些本人读过的非常有助于理解的博客。我在这里主要分享的是官方源码如何跑通,跑通它的代码还是非常不容易的,有很多的小坑。对于我们小白而言,跑通代码才能给我们继续了解原理的信心,然后也可以大胆的debug,去验证里面的代码是否与论文所述的一致。 配置环境

把代码clone到你的服务器上,或者本地 git clone https://github.com/microsoft/Swin-Transformer.git cd Swin-Transformer

创建运行环境,并进入环境 conda create -n swin python=3.7 -y conda activate swin

安装需要的环境

conda install pytorch==1.7.1 torchvision==0.8.2 cudatoolkit=10.1 -c pytorch

这里注意一下自己的环境,我这边的cuda是10.1的,所以可以直接按着官方给的这个来。怎么看自己的cuda环境呢,有很多种方法,最靠谱的是这个:

cat /usr/local/cuda/version.txt

别看nvidia-smi的那个,那个不准。

安装 timm==0.3.2: pip install timm==0.3.2 cd apex pip install -v –disable-pip-version-check –no-cache-dir –global-option=”–cpp_ext” –global-option=”–cuda_ext” ./ 3 训练

首先是训练的运行方式:

python -m torch.distributed.launch –nproc_per_node 4 –master_port 12345 main.py –cfg configs/swin_tiny_patch4_window7_224.yaml –data-path imagenet –batch-size 64

–data-path对应的就是数据文件所在的位置

然后是测试的运行方式:

python -m torch.distributed.launch –nproc_per_node 1 –master_port 12345 main.py –eval –cfg configs/swin_tiny_patch4_window7_224.yaml –resume /pth/swin_tiny_patch4_window7_224.pth –data-path imagenet

补充一个你可能遇到的bug,CalledProcessError & RuntimeError。遇到这个bug的原因就是上面的命令里面写的有问题,比如路径写错了或者压根儿路径对应的文件就不存在,那么就会报这样的错误,大家注意!!!

swin-transformer文件目录 【深度学习】U型的Transfomer网络和Swin-Transformer分类3 这就是从官方源码那边clone下来的,区别在于我这里加了个pth文件,就是模型文件,我提供的百度网盘文件里面有,还有imagenet文件夹,这里面放着数据。

4 关于复杂度降低问题

【深度学习】U型的Transfomer网络和Swin-Transformer分类4 Swin-transformer是怎么把复杂度降低的呢? Swin Transformer Block这个模块和普通的transformer的区别就在于W-MSA,而它就是降低复杂度计算的大功臣。 关于复杂度的计算,我简单的给大家介绍一下,首先是transformer本身基于全局的复杂度计算,这一块儿讲起来有点复杂,感兴趣的同学我们可以会后一起探讨推导过程。在这里,我们假设已知MSA的复杂度是图像大小的平方,根据MSA的复杂度,我们可以得出A的复杂度是(3×3)²,最后复杂度是81。Swin transformer是在每个local windows(红色部分)计算self-attention,根据MSA的复杂度我们可以得出每个红色窗口的复杂度是1×1的平方,也就是1的四次方。然后9个窗口,这些窗口的复杂度加和,最后B的复杂度为9。

5 关于SW-MSA的操作问题

W-MSA虽然降低了计算复杂度,但是不重合的window之间缺乏信息交流,所以想要窗口之间的信息有所交流,那么就可以把左图演化成右图这样,但是这就产生了一个问题,如此操作,会产生更多的windows,并且其中一部分window小于普通的window,比如4个window -> 9个window,windows数量增加了一倍多。这计算量又上来了。因此我们有两个目的,Windows数量不能多,window之间信息得有交流。

【深度学习】U型的Transfomer网络和Swin-Transformer分类5

6 总结

transformer的出现并不是为了替代CNN。因为transformer有着CNN没有的功能性,它不仅可以提取特征,还可以做很多CNN做不到的事情,比如多模态融合。而swin transformer就是一个趋势,将CNN与transformer各自的优势有效的结合了起来。 skip connections数量的影响? Swin-UNet在1/4,1/8和1/16的降采样尺度上添加了skip connections。通过将skip connections数分别更改为0、1、2和3,实验了不同skip connections数量对模型分割性能的影响。从下表中可以看出,模型的性能随着skip connections数的增加而提高。因此,为了使模型更加鲁棒,本工作中设置skip connections数为3。

猜你喜欢