
[Vision Transformer] LETR: Paper Walkthrough and Hands-On Code (Part 2)


LETR: Line Segment Detection Using Transformers without Edges

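A network framework built on vision transformer/DETR for extracting wireframes; it achieved state-of-the-art (SOTA) performance at the time of writing.

Paper: https://arxiv.org/abs/2101.01909

Code: https://github.com/mlpc-ucsd/LETR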


Hands-on project:

(1) Set up the environment:

git clone https://github.com/mlpc-ucsd/LETR.git

mkdir -p data
mkdir -p evaluation/data
mkdir -p exp
conda create -n letr
conda activate letr
conda install -c pytorch pytorch torchvision
conda install cython scipy
pip install -U 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI'
pip install docopt

(2) Prepare the data: the Wireframe raw dataset and the YorkUrban dataset.

(3) Train the model

# Usage: bash script/*/*.sh [exp name]
bash script/train/a0_train_stage1_res50.sh res50_stage1 # LETR-R50
bash script/train/a1_train_stage1_res101.sh res101_stage1 # LETR-R101

Debugging and troubleshooting

 (1)ImportError: cannot import name '_new_empty_tensor'

Traceback (most recent call last):
  File "src/main.py", line 13, in <module>
    import datasets
  File "/home/wsx/0A_DATA/LETR/src/datasets/__init__.py", line 5, in <module>
    from .coco import build as build_coco
  File "/home/wsx/0A_DATA/LETR/src/datasets/coco.py", line 10, in <module>
    import datasets.transforms as T
  File "/home/wsx/0A_DATA/LETR/src/datasets/transforms.py", line 18, in <module>
    from util.misc import interpolate
  File "/home/wsx/0A_DATA/LETR/src/util/misc.py", line 22, in <module>
    from torchvision.ops import _new_empty_tensor
ImportError: cannot import name '_new_empty_tensor'

Solution:

The traceback points to File "/home/wsx/0A_DATA/LETR/src/util/misc.py", line 22, in <module>. Comment out the lines in misc.py that import _new_empty_tensor, as follows:

# if float(torchvision.__version__[:3]) < 0.7:
#     from torchvision.ops import _new_empty_tensor
#     from torchvision.ops.misc import _output_size
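A likely reason this guarded import ran at all: the check slices the first three characters of the version string, so a two-digit minor version such as "0.10.0" is read as 0.1 and compares as older than 0.7. The installed torchvision version is not stated in this post, so treat that as an assumption. A minimal sketch of a more robust comparison, using the hypothetical helper names version_tuple and needs_legacy_ops (neither is part of LETR or torchvision):

```python
import re


def version_tuple(version: str) -> tuple:
    """Turn '0.10.0+cu102' into (0, 10); patch and local build tags are ignored."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def needs_legacy_ops(torchvision_version: str) -> bool:
    """True only for torchvision releases older than 0.7."""
    return version_tuple(torchvision_version) < (0, 7)


# The slice-based check misreads two-digit minor versions.
naive = float("0.10.0+cu102"[:3]) < 0.7    # slices to "0.1", so it looks older than 0.7
robust = needs_legacy_ops("0.10.0+cu102")  # compares (0, 10) against (0, 7)
print(naive, robust)
```

Comparing integer tuples avoids the string-slicing pitfall, so the guard could stay in place instead of being commented out.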

(2) Run again: ImportError: cannot import name '_LinearWithBias'

Traceback (most recent call last):
  File "src/main.py", line 17, in <module>
    from models import build_model
  File "/home/wsx/0A_DATA/LETR/src/models/__init__.py", line 2, in <module>
    from .letr import build
  File "/home/wsx/0A_DATA/LETR/src/models/letr.py", line 15, in <module>
    from .transformer import build_transformer
  File "/home/wsx/0A_DATA/LETR/src/models/transformer.py", line 16, in <module>
    from .multi_head_attention import MultiheadAttention
  File "/home/wsx/0A_DATA/LETR/src/models/multi_head_attention.py", line 11, in <module>
    from torch.nn.modules.linear import _LinearWithBias
ImportError: cannot import name '_LinearWithBias'

Solution:

Locate the failing file, File "/home/wsx/0A_DATA/LETR/src/models/multi_head_attention.py", line 11, in <module>. Comment out the import of _LinearWithBias on line 11 and import Linear instead, as follows:

# from torch.nn.modules.linear import _LinearWithBias
from torch.nn.modules.linear import Linear

Also update the place that uses it, line 440 of multi_head_attention.py:

self.out_proj = _LinearWithBias(embed_dim, embed_dim)

becomes:

self.out_proj = Linear(embed_dim, embed_dim)
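If the same checkout has to run on both older and newer PyTorch releases, the import can fall back instead of being hard-edited. The sketch below uses a hypothetical helper, first_available (not part of LETR or PyTorch), that returns the first name a module actually defines. It is demonstrated on the standard library so that it runs without torch; the torch usage is shown only as a comment.

```python
import importlib


def first_available(module_name: str, candidates):
    """Return the first attribute in `candidates` that `module_name` defines."""
    module = importlib.import_module(module_name)
    for name in candidates:
        if hasattr(module, name):
            return getattr(module, name)
    raise ImportError(f"none of {list(candidates)} found in {module_name}")


# Demonstration on the standard library: the first name is missing, the second exists.
ordered_dict_cls = first_available("collections", ["NoSuchName", "OrderedDict"])

# Hypothetical use in multi_head_attention.py (requires torch to be installed):
# OutProj = first_available("torch.nn.modules.linear", ["_LinearWithBias", "Linear"])
# self.out_proj = OutProj(embed_dim, embed_dim)
```

As far as I know, _LinearWithBias was only a Linear subclass that forced bias=True, which is also the default for Linear, so the substitution should not change the layer's behavior.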

(3) Run again: RuntimeError: The NVIDIA driver on your system is too old (found version 10010).

Traceback (most recent call last):
  File "src/main.py", line 214, in <module>
    main(args)
  File "src/main.py", line 21, in main
    utils.init_distributed_mode(args)
  File "/home/wsx/0A_DATA/LETR/src/util/misc.py", line 421, in init_distributed_mode
    torch.cuda.set_device(args.gpu)
  File "/home/wsx/anaconda3/envs/letr/lib/python3.6/site-packages/torch/cuda/__init__.py", line 264, in set_device
    torch._C._cuda_setDevice(device)
  File "/home/wsx/anaconda3/envs/letr/lib/python3.6/site-packages/torch/cuda/__init__.py", line 172, in _lazy_init
    torch._C._cuda_init()
RuntimeError: The NVIDIA driver on your system is too old (found version 10010). Please update your GPU driver by downloading and installing a new version from the URL: http://www.nvidia.com/Download/index.aspx Alternatively, go to: https://pytorch.org to install a PyTorch version that has been compiled with your version of the CUDA driver.

Solution:

Check the PyTorch version:

(letr) wsx@hello:~/0A_DATA/LETR$ python
Python 3.6.4 |Anaconda, Inc.| (default, Mar 13 2018, 01:15:57)
[GCC 7.2.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> print(torch.__version__)
1.9.0+cu102

Check the CUDA version:

nvidia-smi

The installed PyTorch was built for CUDA 10.2, but the machine's CUDA is 10.1, so uninstall PyTorch and reinstall a PyTorch build that matches CUDA 10.1.
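The "found version 10010" in the error message is CUDA's integer version encoding (1000 * major + 10 * minor), i.e. 10.1, while the "+cu102" suffix of the PyTorch version names the CUDA release the wheel was built for. A minimal sketch of that comparison follows; the helper names are hypothetical, and the "driver must be at least as new as the build" rule is a simplification (CUDA 11 and later relax it somewhat through minor-version compatibility).

```python
import re


def decode_cuda_version(code: int) -> tuple:
    """Decode CUDA's integer encoding (1000 * major + 10 * minor) into (major, minor)."""
    return code // 1000, (code % 1000) // 10


def build_cuda_version(torch_version: str):
    """Extract the CUDA tag from a version such as '1.9.0+cu102'; None if there is no tag."""
    match = re.search(r"\+cu(\d+)(\d)$", torch_version)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


def driver_supports(torch_version: str, driver_code: int) -> bool:
    """True when the driver's CUDA version is at least the one the wheel was built for."""
    built_for = build_cuda_version(torch_version)
    return built_for is None or decode_cuda_version(driver_code) >= built_for


print(decode_cuda_version(10010))              # driver side, from the error message
print(build_cuda_version("1.9.0+cu102"))       # wheel side, from torch.__version__
print(driver_supports("1.9.0+cu102", 10010))   # the failing combination in this post
print(driver_supports("1.8.1+cu101", 10010))   # a build matching CUDA 10.1
```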


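The pip route below can be skipped in favor of conda. The second attempt, which omits the -f wheel index, fails because the PyPI mirror carries no +cu101 builds.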

pip3 install torch==1.8.1+cu101 torchvision==0.9.1+cu101 -f https://download.pytorch.org/whl/cu101/torch_stable.html

(letr) wsx@hello:~/0A_DATA/LETR$ python -m pip install torch==1.8.1+cu101 torchvision==0.9.1+cu101
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
ERROR: Could not find a version that satisfies the requirement torch==1.8.1+cu101 (from versions: 1.0.0, 1.0.1, 1.0.1.post2, 1.1.0, 1.2.0, 1.3.0, 1.3.1, 1.4.0, 1.5.0, 1.5.1, 1.6.0, 1.7.0, 1.7.1, 1.8.0, 1.8.1, 1.9.0, 1.9.1, 1.10.0, 1.10.1, 1.10.2)
ERROR: No matching distribution found for torch==1.8.1+cu101


Install with conda:

conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=10.1 -c pytorch
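Once the install finishes, a quick check along these lines can confirm that the new build sees the GPU before training is relaunched. The function name torch_cuda_report is hypothetical; it returns None when torch is not importable, so the snippet also runs outside the letr environment.

```python
import importlib.util


def torch_cuda_report():
    """Return PyTorch/CUDA details, or None when torch is not installed."""
    if importlib.util.find_spec("torch") is None:
        return None
    import torch

    return {
        "torch": torch.__version__,
        "built_for_cuda": torch.version.cuda,  # None on CPU-only builds
        "cuda_available": torch.cuda.is_available(),
    }


print(torch_cuda_report())
```

If cuda_available comes back False, the driver and the installed build are probably still mismatched.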

And with that, everything works.

 

