LETR: Line Segment Detection Using Transformers without Edges
A network framework based on vision transformers / DETR for extracting wireframes. At the time of writing it achieved state-of-the-art (SOTA) performance.
Paper: https://arxiv.org/abs/2101.01909
Code: https://github.com/mlpc-ucsd/LETR
Hands-on walkthrough:
(1) Set up the environment:
git clone https://github.com/mlpc-ucsd/LETR.git
mkdir -p data
mkdir -p evaluation/data
mkdir -p exp
conda create -n letr
conda activate letr
conda install -c pytorch pytorch torchvision
conda install cython scipy
pip install -U 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI'
pip install docopt
(2) Prepare the data: the Wireframe raw dataset and the YorkUrban dataset.
(3) Train the model
# Usage: bash script/*/*.sh [exp name]
bash script/train/a0_train_stage1_res50.sh res50_stage1 # LETR-R50
bash script/train/a1_train_stage1_res101.sh res101_stage1 # LETR-R101
Debugging notes and fixes
(1) ImportError: cannot import name '_new_empty_tensor'
Traceback (most recent call last):
  File "src/main.py", line 13, in <module>
    import datasets
  File "/home/wsx/0A_DATA/LETR/src/datasets/__init__.py", line 5, in <module>
    from .coco import build as build_coco
  File "/home/wsx/0A_DATA/LETR/src/datasets/coco.py", line 10, in <module>
    import datasets.transforms as T
  File "/home/wsx/0A_DATA/LETR/src/datasets/transforms.py", line 18, in <module>
    from util.misc import interpolate
  File "/home/wsx/0A_DATA/LETR/src/util/misc.py", line 22, in <module>
    from torchvision.ops import _new_empty_tensor
ImportError: cannot import name '_new_empty_tensor'
Solution:
The traceback points to File "/home/wsx/0A_DATA/LETR/src/util/misc.py", line 22. Comment out the lines in misc.py that reference "_new_empty_tensor", as follows:
# if float(torchvision.__version__[:3]) < 0.7:
#     from torchvision.ops import _new_empty_tensor
#     from torchvision.ops.misc import _output_size
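A likely reason this guard misfires on newer installs is that `torchvision.__version__[:3]` truncates a version such as `0.10.0` to `0.1`, which then compares as less than 0.7. The sketch below shows a tuple-based comparison that avoids this. The helper name `version_tuple` is hypothetical and is not part of LETR or torchvision.

```python
import re


def version_tuple(version: str) -> tuple:
    """Return (major, minor) as integers from a string such as '0.10.0+cu102'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


# The three-character slice turns "0.10.0" into "0.1".
sliced = float("0.10.0"[:3])

# Integer tuples compare component by component, so (0, 10) is not below (0, 7).
needs_legacy_import = version_tuple("0.10.0+cu102") < (0, 7)
```

With a check like this, the legacy import would only be attempted on torchvision releases older than 0.7, so the lines would not need to be commented out by hand.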
(2) Run again: ImportError: cannot import name '_LinearWithBias'
Traceback (most recent call last):
  File "src/main.py", line 17, in <module>
    from models import build_model
  File "/home/wsx/0A_DATA/LETR/src/models/__init__.py", line 2, in <module>
    from .letr import build
  File "/home/wsx/0A_DATA/LETR/src/models/letr.py", line 15, in <module>
    from .transformer import build_transformer
  File "/home/wsx/0A_DATA/LETR/src/models/transformer.py", line 16, in <module>
    from .multi_head_attention import MultiheadAttention
  File "/home/wsx/0A_DATA/LETR/src/models/multi_head_attention.py", line 11, in <module>
    from torch.nn.modules.linear import _LinearWithBias
ImportError: cannot import name '_LinearWithBias'
Solution:
The traceback points to File "/home/wsx/0A_DATA/LETR/src/models/multi_head_attention.py", line 11. Comment out the import of _LinearWithBias on line 11
and import Linear instead, as follows:
# from torch.nn.modules.linear import _LinearWithBias
from torch.nn.modules.linear import Linear
Also update the place where it is used, line 440 of multi_head_attention.py. Change
self.out_proj = _LinearWithBias(embed_dim, embed_dim) to: self.out_proj = Linear(embed_dim, embed_dim)
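Instead of editing the import by hand for each PyTorch version, a fallback lookup could pick whichever class exists. This is only a sketch. The helpers `first_available` and `load_out_proj_class` are hypothetical names, not part of LETR or PyTorch, and the second one assumes PyTorch is installed.

```python
import importlib


def first_available(module_name, candidate_names):
    """Return the first attribute in candidate_names that module_name defines."""
    module = importlib.import_module(module_name)
    for name in candidate_names:
        if hasattr(module, name):
            return getattr(module, name)
    raise ImportError(f"none of {candidate_names} found in {module_name}")


def load_out_proj_class():
    """Prefer _LinearWithBias when present, else fall back to Linear (needs PyTorch)."""
    return first_available("torch.nn.modules.linear", ["_LinearWithBias", "Linear"])
```

In multi_head_attention.py, the result of `load_out_proj_class()` could then be used wherever `_LinearWithBias` was referenced, so the same source would run on both older and newer PyTorch releases.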
(3) Run again: RuntimeError: The NVIDIA driver on your system is too old (found version 10010).
Traceback (most recent call last):
  File "src/main.py", line 214, in <module>
    main(args)
  File "src/main.py", line 21, in main
    utils.init_distributed_mode(args)
  File "/home/wsx/0A_DATA/LETR/src/util/misc.py", line 421, in init_distributed_mode
    torch.cuda.set_device(args.gpu)
  File "/home/wsx/anaconda3/envs/letr/lib/python3.6/site-packages/torch/cuda/__init__.py", line 264, in set_device
    torch._C._cuda_setDevice(device)
  File "/home/wsx/anaconda3/envs/letr/lib/python3.6/site-packages/torch/cuda/__init__.py", line 172, in _lazy_init
    torch._C._cuda_init()
RuntimeError: The NVIDIA driver on your system is too old (found version 10010). Please update your GPU driver by downloading and installing a new version from the URL: http://www.nvidia.com/Download/index.aspx Alternatively, go to: https://pytorch.org to install a PyTorch version that has been compiled with your version of the CUDA driver.
Solution:
Check the PyTorch version
(letr) wsx@hello:~/0A_DATA/LETR$ python
Python 3.6.4 |Anaconda, Inc.| (default, Mar 13 2018, 01:15:57)
[GCC 7.2.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> print(torch.__version__)
1.9.0+cu102
Check the CUDA version
nvidia-smi
The installed PyTorch is a CUDA 10.2 build, but the machine has CUDA 10.1, so uninstall PyTorch and reinstall a PyTorch build that matches CUDA 10.1.
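The mismatch can also be read directly from the two version strings seen so far. The driver number in the error message (10010) uses the CUDA encoding 1000 * major + 10 * minor, and the PyTorch version carries a `+cu102` tag. The sketch below compares the two. The function names are hypothetical, and the rule "driver CUDA version must be at least the build's CUDA version" is a simplifying assumption.

```python
import re


def decode_driver_cuda(encoded: int) -> tuple:
    """Decode 1000 * major + 10 * minor, e.g. 10010 -> (10, 1)."""
    return encoded // 1000, (encoded % 1000) // 10


def build_cuda(torch_version: str) -> tuple:
    """Read the CUDA tag from a version string such as '1.9.0+cu102'."""
    match = re.search(r"\+cu(\d+)$", torch_version)
    if match is None:
        raise ValueError(f"no CUDA tag in {torch_version!r}")
    digits = match.group(1)
    # The last digit is the minor version; everything before it is the major.
    return int(digits[:-1]), int(digits[-1])


def driver_supports(encoded_driver: int, torch_version: str) -> bool:
    """Assumed rule: the driver's CUDA version must be >= the build's CUDA version."""
    return decode_driver_cuda(encoded_driver) >= build_cuda(torch_version)


mismatch = not driver_supports(10010, "1.9.0+cu102")
```

Under these assumptions, a `+cu101` build such as `1.8.1+cu101` would pass the same check against driver version 10010.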
The pip installation below can be skipped; go straight to conda.
pip3 install torch==1.8.1+cu101 torchvision==0.9.1+cu101 -f https://download.pytorch.org/whl/cu101/torch_stable.html
(letr) wsx@hello:~/0A_DATA/LETR$ python -m pip install torch==1.8.1+cu101 torchvision==0.9.1+cu101
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
ERROR: Could not find a version that satisfies the requirement torch==1.8.1+cu101 (from versions: 1.0.0, 1.0.1, 1.0.1.post2, 1.1.0, 1.2.0, 1.3.0, 1.3.1, 1.4.0, 1.5.0, 1.5.1, 1.6.0, 1.7.0, 1.7.1, 1.8.0, 1.8.1, 1.9.0, 1.9.1, 1.10.0, 1.10.1, 1.10.2)
ERROR: No matching distribution found for torch==1.8.1+cu101
Install with conda instead
conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=10.1 -c pytorch
All done.
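To confirm that the reinstalled build actually sees the GPU, a quick check along the following lines may help. It is a minimal sketch. The function name `cuda_report` is hypothetical, and it falls back to empty values when PyTorch is not importable.

```python
def cuda_report() -> dict:
    """Collect the PyTorch version, its CUDA build, and whether CUDA is usable."""
    try:
        import torch
    except ImportError:
        # PyTorch is not installed in this environment.
        return {"torch": None, "cuda_build": None, "cuda_available": False}
    return {
        "torch": torch.__version__,
        "cuda_build": torch.version.cuda,
        "cuda_available": torch.cuda.is_available(),
    }


if __name__ == "__main__":
    print(cuda_report())
```

If the driver and the build agree, `cuda_available` should be `True` and `cuda_build` should match the `cudatoolkit` version chosen during installation.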