热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

Carla使用神经网络训练自动驾驶车辆模型搭建及训练

Carla使用神经网络训练自动驾驶车辆—模型搭建及训练上一节已经搭建好了一个carla仿真环境,并且进行了数据采集现在需要使用采集到的摄像头图片和转角数据进行模型训
Carla 使用神经网络训练自动驾驶车辆—模型搭建及训练

在这里插入图片描述

上一节已经搭建好了一个carla仿真环境,并且进行了数据采集


现在需要使用采集到的摄像头图片和转角数据进行模型训练


创建dataset类

import numpy as np
import config
from torch.utils.data import Datasetclass TrainDataSet(Dataset):def __init__(self):self.x=np.load(config.DATASET_PATH+"x_train.npy")[:config.TRAIN_DATASET_SIZES]self.y=np.load(config.DATASET_PATH+"y_train.npy")[:config.TRAIN_DATASET_SIZES]def __len__(self):return len(self.x)def __getitem__(self,index):return (self.x[index],self.y[index])class TestDataSet(Dataset):def __init__(self):self.x = np.load(config.DATASET_PATH+'x_train.npy')[config.TRAIN_DATASET_SIZES:]self.y = np.load(config.DATASET_PATH+'y_train.npy')[config.TRAIN_DATASET_SIZES:]def __len__(self):return len(self.x)def __getitem__(self, index):return (self.x[index], self.y[index])

模型部分

import torch
import torch.nn as nn
from torchsummary import summaryclass Block(nn.Module):def __init__(self, in_channels, out_channels, maxpool=False):super().__init__()self.conv = nn.Conv2d(in_channels, out_channels, 3, 1, 1)self.relu = nn.ReLU()self.maxpool = maxpoolself.max_pool = nn.MaxPool2d(2, 2)def forward(self, x):x = self.conv(x)x = self.relu(x)if self.maxpool:x = self.max_pool(x)return xclass Model(nn.Module):def __init__(self, in_channels=3, out_size=1):super().__init__()self.conv_block = nn.Sequential(Block(in_channels, 16),Block(16, 16, True),Block(16, 32),Block(32, 32, True),Block(32, 64),Block(64, 64, True),Block(64, 128),Block(128, 128, True),Block(128, 128),Block(128, 128, True),)self.fc_block = nn.Sequential(nn.Linear(1536, 100),nn.ReLU(),nn.Linear(100, 20),nn.ReLU(),nn.Linear(20, out_size))def forward(self, x):x = self.conv_block(x)x = x.view(x.size(0), -1)x = self.fc_block(x)return xif __name__ == '__main__':model = Model()summary(model, input_size=(3, 66, 200), device='cpu')data = torch.ones(1, 3, 66, 200)out = model(data)print(out.shape)

训练部分:

import torch
import torch.nn as nn
import keyboard
import matplotlib.pyplot as plt
import config
from torch.utils.data import DataLoader
from dataset import TrainDataSet
from model import Modelclass Trainer:def __init__(self, model, train_dataset, model_state=None):self.model = modelself.model_state = model_stateself.train_dataset = train_datasetself.loss_list = []self.main()def main(self):torch.manual_seed(config.SEED)torch.cuda.manual_seed(config.SEED)model = self.modelloss_list = self.loss_listmodel.to(config.DEVICE)if self.model_state:state = torch.load(self.model_state)model.load_state_dict(state)model.train()dataloader = DataLoader(self.train_dataset, batch_size=config.BATCH_SIZE, shuffle=True)optimizer = torch.optim.Adam(model.parameters(), lr=config.LEARNING_RATE)criterion = nn.MSELoss()for epoch in range(config.NUM_EPOCH):for x, y in dataloader:x, y = x.to(config.DEVICE).float(), y.to(config.DEVICE).float()x = x.reshape(x.size(0), 3, 66, 200)optimizer.zero_grad()out = model(x)loss = criterion(out, y)loss.backward()optimizer.step()loss_list.append(loss.item())self._show_loss(loss_list, '3')print('epoc[%i/%i] loss=%.5f' % (epoch, config.NUM_EPOCH, loss.item()))torch.save(model.state_dict(), config.MODEL_PATH+'model_state.pth')def _show_loss(self, loss_list, key='3'):if keyboard.is_pressed(key):plt.plot(loss_list)plt.ylim(0, 0.1)plt.show()if __name__ == '__main__':# Trainer(Model(), TrainDataSet(), model_state='model_state.pth')Trainer(Model(), TrainDataSet())

这样就会得到神经网络的权重model_state.pth,下一节使用训练好的神经网络在carla环境中进行自动驾驶测试


推荐阅读
  • 如何自行分析定位SAP BSP错误
    The“BSPtag”Imentionedintheblogtitlemeansforexamplethetagchtmlb:configCelleratorbelowwhichi ... [详细]
  • YOLOv7基于自己的数据集从零构建模型完整训练、推理计算超详细教程
    本文介绍了关于人工智能、神经网络和深度学习的知识点,并提供了YOLOv7基于自己的数据集从零构建模型完整训练、推理计算的详细教程。文章还提到了郑州最低生活保障的话题。对于从事目标检测任务的人来说,YOLO是一个熟悉的模型。文章还提到了yolov4和yolov6的相关内容,以及选择模型的优化思路。 ... [详细]
  • Spring源码解密之默认标签的解析方式分析
    本文分析了Spring源码解密中默认标签的解析方式。通过对命名空间的判断,区分默认命名空间和自定义命名空间,并采用不同的解析方式。其中,bean标签的解析最为复杂和重要。 ... [详细]
  • CSS3选择器的使用方法详解,提高Web开发效率和精准度
    本文详细介绍了CSS3新增的选择器方法,包括属性选择器的使用。通过CSS3选择器,可以提高Web开发的效率和精准度,使得查找元素更加方便和快捷。同时,本文还对属性选择器的各种用法进行了详细解释,并给出了相应的代码示例。通过学习本文,读者可以更好地掌握CSS3选择器的使用方法,提升自己的Web开发能力。 ... [详细]
  • Java容器中的compareto方法排序原理解析
    本文从源码解析Java容器中的compareto方法的排序原理,讲解了在使用数组存储数据时的限制以及存储效率的问题。同时提到了Redis的五大数据结构和list、set等知识点,回忆了作者大学时代的Java学习经历。文章以作者做的思维导图作为目录,展示了整个讲解过程。 ... [详细]
  • 本文讨论了一个关于cuowu类的问题,作者在使用cuowu类时遇到了错误提示和使用AdjustmentListener的问题。文章提供了16个解决方案,并给出了两个可能导致错误的原因。 ... [详细]
  • 不同优化算法的比较分析及实验验证
    本文介绍了神经网络优化中常用的优化方法,包括学习率调整和梯度估计修正,并通过实验验证了不同优化算法的效果。实验结果表明,Adam算法在综合考虑学习率调整和梯度估计修正方面表现较好。该研究对于优化神经网络的训练过程具有指导意义。 ... [详细]
  • 本文详细介绍了Spring的JdbcTemplate的使用方法,包括执行存储过程、存储函数的call()方法,执行任何SQL语句的execute()方法,单个更新和批量更新的update()和batchUpdate()方法,以及单查和列表查询的query()和queryForXXX()方法。提供了经过测试的API供使用。 ... [详细]
  • 也就是|小窗_卷积的特征提取与参数计算
    篇首语:本文由编程笔记#小编为大家整理,主要介绍了卷积的特征提取与参数计算相关的知识,希望对你有一定的参考价值。Dense和Conv2D根本区别在于,Den ... [详细]
  • Python瓦片图下载、合并、绘图、标记的代码示例
    本文提供了Python瓦片图下载、合并、绘图、标记的代码示例,包括下载代码、多线程下载、图像处理等功能。通过参考geoserver,使用PIL、cv2、numpy、gdal、osr等库实现了瓦片图的下载、合并、绘图和标记功能。代码示例详细介绍了各个功能的实现方法,供读者参考使用。 ... [详细]
  • 前景:当UI一个查询条件为多项选择,或录入多个条件的时候,比如查询所有名称里面包含以下动态条件,需要模糊查询里面每一项时比如是这样一个数组条件:newstring[]{兴业银行, ... [详细]
  • 本文讨论了如何使用IF函数从基于有限输入列表的有限输出列表中获取输出,并提出了是否有更快/更有效的执行代码的方法。作者希望了解是否有办法缩短代码,并从自我开发的角度来看是否有更好的方法。提供的代码可以按原样工作,但作者想知道是否有更好的方法来执行这样的任务。 ... [详细]
  • web.py开发web 第八章 Formalchemy 服务端验证方法
    本文介绍了在web.py开发中使用Formalchemy进行服务端表单数据验证的方法。以User表单为例,详细说明了对各字段的验证要求,包括必填、长度限制、唯一性等。同时介绍了如何自定义验证方法来实现验证唯一性和两个密码是否相等的功能。该文提供了相关代码示例。 ... [详细]
  • 本文介绍了在CentOS上安装Python2.7.2的详细步骤,包括下载、解压、编译和安装等操作。同时提供了一些注意事项,以及测试安装是否成功的方法。 ... [详细]
  • Spring学习(4):Spring管理对象之间的关联关系
    本文是关于Spring学习的第四篇文章,讲述了Spring框架中管理对象之间的关联关系。文章介绍了MessageService类和MessagePrinter类的实现,并解释了它们之间的关联关系。通过学习本文,读者可以了解Spring框架中对象之间的关联关系的概念和实现方式。 ... [详细]
author-avatar
ab5212502902861
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有