PyTorch设置可复现/重复实验
admin
2024-01-26 00:41:00

搬来了定型设置的方法,深度学习在训练过程中,由于随机初始化,样本读取的随机性,导致重复的实验结果会有差别,个别情况甚至波动较大。一般论文为了严谨,实验结论能够复现/可重复,通常采取固定随机种子使得结果确定

确定性设置

1 随机种子设置

随机函数是最大的不确定性来源,包含了模型参数的随机初始化,样本的shuffle。

  • PyTorch 随机种子

  • python 随机种子

  • numpy 随机种子

# PyTorch
import torch
torch.manual_seed(0)# python
import random
random.seed(0)# Third part libraries
import numpy as np
np.random.seed(0)

CPU版本下,上述随机种子设置完成之后,基本就可实现实验的可复现了。

对于GPU版本,存在大量算法实现为不确定结果的算法,这种算法实现效率很高,但是每次返回的值会不完全一样。主要是由于浮点精度舍弃,不同浮点数以不同顺序相加,值可能会有很小的差异(小数点最末位)。

2 GPU算法确定性实现

GPU算法的不确定来源有两个

  • CUDA convolution benchmarking

  • nondeterministic algorithms

CUDA convolution benchmarking 是为了提升运行效率,对模型参数试运行后,选取最优实现。不同硬件以及benchmarking本身存在噪音,导致不确定性

nondeterministic algorithms:GPU最大优势就是并行计算,如果能够忽略顺序,就避免了同步要求,能够大大提升运行效率,所以很多算法都有非确定性结果的算法实现。通过设置use_deterministic_algorithms,就可以使得pytorch选择确定性算法。

# 不需要benchmarking
torch.backends.cudnn.benchmark=False# 选择确定性算法
torch.use_deterministic_algorithms()

RUNTIME ERROR

对于一个PyTorch 的函数接口,没有确定性算法实现,只有非确定性算法实现,同时设置了use_deterministic_algorithms(),那么会导致运行时错误。比如:

>>> import torch
>>> torch.use_deterministic_algorithms(True)
>>> torch.randn(2, 2).cuda().index_add_(0, torch.tensor([0, 1]), torch.randn(2, 2))
Traceback (most recent call last):
File "", line 1, in 
RuntimeError: index_add_cuda_ does not have a deterministic implementation, but you set
'torch.use_deterministic_algorithms(True)'. ...

错误原因:

index_add没有确定性的实现,出现这种错误,一般都是因为调用了torch.index_select 这个api接口,或者直接调用tensor.index_add_。

解决方案:

自己定义一个确定性的实现,替换调用的接口。对于torch.index_select 这个接口,可以有如下的实现。

def deterministic_index_select(input_tensor, dim, indices):"""input_tensor: Tensordim: dim indices: 1D tensor"""tensor_transpose = torch.transpose(x, 0, dim)return tensor_transpose[indices].transpose(dim, 0)

样本读取随机

  1. 多线程情况下,设置每个线程读取的随机种子

  2. 设置样本generator

# 设置每个读取线程的随机种子
def seed_worker(worker_id):worker_seed = torch.initial_seed() % 2**32numpy.random.seed(worker_seed)random.seed(worker_seed)g = torch.Generator()
# 设置样本shuffle随机种子,作为DataLoader的参数
g.manual_seed(0)DataLoader(train_dataset,batch_size=batch_size,num_workers=num_workers,worker_init_fn=seed_worker,generator=g,
)

有点短哦~~   whaosoft aiot http://143ai.com 

相关内容

热门资讯

杭州灵隐飞来峰景区12月1日起...   新华社杭州11月19日电(记者段菁菁)为更好地满足市民游客的旅游需求、提升游览品质,杭州西湖风景...
杭州官宣取消灵隐寺门票! 每经编辑|程鹏 11月19日,每经小编从杭州西湖景区官方账号了解到,自2025年12月1日起,灵隐...
装上伊贝莎泳池,民宿营业额反超... “之前旺季靠降价抢单,现在客人主动加价订周末房!” 经营民宿三年的李姐,至今对引入伊贝莎泳池后的变化...
霉腌醉酱的美食哲学,看中国烹饪... 主讲人、图片提供 / 茅天尧 文字整理 / 孙阳 绍兴菜历史悠久,技艺源远流长,以“霉”“腌”“醉...