tqdm makes your loops show a smart progress meter.

tqdmpython中一个快速可扩展的进度条,可以在长循环中添加进度提示信息。

tqdm在阿拉伯语中含义是“进步”(taqadum, تقدّم),在西班牙语中是“我很爱你(I love you so much)”的缩写(te quiero demasiado)。

tqdm作用于迭代器(如列表),可以打印进度条:

from tqdm import tqdm

for i in tqdm(range(1000)):
    pass

for i in tqdm([1,2,3,4]):
    pass

trangetqdm(range)的简单写法:

from tqdm import trange

for i in trange(1000):
    pass

可以分配给tqdm一个变量手动控制更新,此时需要在循环结束后关闭该变量。下述代码表示总进度为$1000$,循环$100$次,则每次更新$10$。

pbar = tqdm(total=1000)
for i in range(100):
    pbar.update(10)
pbar.close()

也可以使用with语句手动控制更新:

with tqdm(total=1000) as pbar:
    for i in range(100):
        pbar.update(10)

例1:打印模型训练过程

若训练集大小为n_train,训练总轮数为epochs。每一轮训练对应一个进度条,每一个batch更新一次进度条。进度条后显示当前batch的训练损失。

使用total参数指定每一个进度条的总长度(对应n_train),使用desc参数描述进度条(传入epochepochs参数表示训练轮数),使用unit参数指定进度更新单位(对应图像img)。

使用set_postfix方法增加显示信息,使用update方法更新进度条。

for epoch in range(epochs):
    with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
        for batch in train_loader:
            pbar.set_postfix({'loss (batch)': loss.item()})
            pbar.update(imgs.shape[0])