作者:bb | 来源:互联网 | 2023-08-30 20:16
摘要:1. 编写:先初始化 `from torch.utils.tensorboard import SummaryWriter`,再创建 `writer = SummaryWriter(comment=f'LR_{lr}_BS_{batch_size}_SCALE_{img_scale}')`……
1.编写
先初始化
from torch.utils.tensorboard import SummaryWriter
# Suffix appended to the auto-generated run directory name, so each
# hyperparameter combination (learning rate, batch size, image scale)
# shows up as a separately named run.
run_suffix = f'LR_{lr}_BS_{batch_size}_SCALE_{img_scale}'
writer = SummaryWriter(comment=run_suffix)
标量曲线(add_scalar,例如每个 batch 的 loss)
# Record the per-batch training loss under the "Loss/train" tag, indexed by step.
writer.add_scalar(tag='Loss/train', scalar_value=loss.item(), global_step=global_step)
直方图(add_histogram,例如各层权重的分布)
# Log the parameter's current values as a histogram under "weights/<tag>".
# detach() reads the tensor outside autograd; it is the recommended
# replacement for the legacy `.data` attribute.
writer.add_histogram('weights/' + tag, value.detach().cpu().numpy(), global_step)
图像(add_images,例如真实掩码与预测掩码)
# Threshold the sigmoid output at 0.5 to get a binary predicted mask that can
# be viewed next to the ground-truth mask for the same step.
pred_masks = torch.sigmoid(masks_pred) > 0.5
writer.add_images('masks/true', true_masks, global_step)
writer.add_images('masks/pred', pred_masks, global_step)
2.查看
# Serve the event files found under ./runs (where SummaryWriter writes by default)
tensorboard --logdir runs
这里的 runs 是 TensorBoard 读取日志的目录;SummaryWriter 默认把日志写到当前工作目录下的 runs 文件夹中。
然后在浏览器中打开 http://localhost:6006/ 即可。
![](https://img2.php1.cn/3cdc5/3984/8fd/6a8bd1ff48e83471.png)
![](https://img2.php1.cn/3cdc5/3984/8fd/31cb0871c0c97596.png)
3.删除log
在命令行中按 Ctrl+C 关闭正在运行的 TensorBoard,然后删除 log 文件(runs 目录下对应的运行子目录),再重新启动 TensorBoard。
4. 详细样例如下(摘自 pytorch-unet 项目的训练代码):
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter(comment=f'LR_{lr}_BS_{batch_size}_SCALE_{img_scale}')
# Step counter shared by every writer.add_* call; it must exist before the
# loop below increments it.
global_step = 0
# Histograms / validation are logged every `log_interval` batches, i.e. about
# three times per epoch. max(1, ...) prevents a ZeroDivisionError in the
# modulo below when n_train < 3 * batch_size.
log_interval = max(1, n_train // (3 * batch_size))

try:
    for epoch in range(epochs):
        net.train()
        epoch_loss = 0
        with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
            for batch in train_loader:
                imgs = batch['image']
                true_masks = batch['mask']
                # Explicit check rather than `assert`, which is stripped under `python -O`.
                if imgs.shape[1] != net.n_channels:
                    raise ValueError(
                        f'Network has been defined with {net.n_channels} input channels, '
                        f'but loaded images have {imgs.shape[1]} channels. Please check that '
                        'the images are loaded correctly.'
                    )

                imgs = imgs.to(device=device, dtype=torch.float32)
                # Masks become float32 for a single-class network, long otherwise.
                mask_type = torch.float32 if net.n_classes == 1 else torch.long
                true_masks = true_masks.to(device=device, dtype=mask_type)

                masks_pred = net(imgs)
                # NOTE(review): only the first mask channel is used, kept with a
                # singleton channel dim. That suits a one-output-channel loss; if
                # n_classes > 1 uses a cross-entropy criterion, the target usually
                # needs to drop that dim -- confirm against how `criterion` is built.
                loss = criterion(masks_pred, true_masks[:, 0, :, :].unsqueeze(1))
                epoch_loss += loss.item()
                writer.add_scalar('Loss/train', loss.item(), global_step)
                pbar.set_postfix(**{'loss (batch)': loss.item()})

                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_value_(net.parameters(), 0.1)
                optimizer.step()

                pbar.update(imgs.shape[0])
                global_step += 1

                if global_step % log_interval == 0:
                    for tag, value in net.named_parameters():
                        tag = tag.replace('.', '/')
                        writer.add_histogram('weights/' + tag, value.detach().cpu().numpy(), global_step)
                        # Frozen or unused parameters have no gradient; skip them
                        # instead of failing on `None`.
                        if value.grad is not None:
                            writer.add_histogram('grads/' + tag, value.grad.detach().cpu().numpy(), global_step)

                    val_score = eval_net(net, val_loader, device)
                    # The scheduler is only stepped once more than 500 steps have run.
                    if global_step > 500:
                        scheduler.step(val_score)
                    writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], global_step)

                    if net.n_classes > 1:
                        logging.info('Validation cross entropy: {}'.format(val_score))
                        writer.add_scalar('Loss/test', val_score, global_step)
                    else:
                        logging.info('Validation Dice Coeff: {}'.format(val_score))
                        writer.add_scalar('Dice/test', val_score, global_step)

                    # Images/masks from the most recent training batch.
                    writer.add_images('images', imgs, global_step)
                    if net.n_classes == 1:
                        writer.add_images('masks/true', true_masks, global_step)
                        writer.add_images('masks/pred', torch.sigmoid(masks_pred) > 0.5, global_step)
finally:
    # Close the writer even if training is interrupted (e.g. Ctrl+C or an error).
    writer.close()