LLM-PyTorch

LLM (Large Language Models)

  • Large language model: a type of artificial-intelligence model
  • Artificial intelligence (computer systems with human-level capabilities) -> machine learning (focused on developing and improving learning algorithms) -> deep learning (machine learning built on multilayer neural networks)
  • Machine learning: enables computers to learn from data and make predictions or decisions without being explicitly programmed
  • "Deep": refers to the multiple hidden layers of artificial neurons (nodes), which allow these models to capture complex nonlinear relationships in the data
  • Typical machine learning: the predictive-modeling workflow (supervised learning)
    • A model is trained on a training dataset; the trained model then makes predictions on new observations
    • Example: classifying emails as spam vs. non-spam

PyTorch

Three core components

  • Tensor library: extends NumPy's array-based programming with GPU acceleration, enabling seamless switching of computation between CPU and GPU
  • Automatic differentiation engine (autograd): automatically computes gradients of tensor operations, simplifying backpropagation and model optimization
  • Deep learning library: provides modular, flexible, and efficient building blocks, including pretrained models, loss functions, and optimizers

Setup

  • Install PyTorch

    bash
    pip install torch==2.4.0
  • Quick sanity check

    py
    import torch

    # PyTorch version
    print(torch.__version__)
    # check whether the installation detects an NVIDIA GPU
    print(torch.cuda.is_available())
    # check whether your Mac supports PyTorch acceleration on Apple Silicon
    print(torch.backends.mps.is_available())

Understanding tensors

Tensors are data containers: they store multidimensional data, where each dimension represents a different feature.

Scalars, vectors, matrices, and tensors

  • A scalar is a zero-dimensional tensor (e.g., a single number)
  • A vector is a one-dimensional tensor, e.g. [1 2 3]
  • A matrix is a two-dimensional tensor, e.g. [[1, 4], [2, 5], [3, 6]]
  • There is no special term for higher-dimensional tensors, so a three-dimensional tensor is usually just called a "3D tensor"
py
tensor0d = torch.tensor(1)
tensor1d = torch.tensor([1,2,3])
tensor2d = torch.tensor([[1,2],[3,4]])
tensor3d = torch.tensor([[[1,2],[3,4]],[[5,6],[7,8]]])
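Data types matter in practice (an addition to these notes): integer inputs default to 64-bit integers, while training typically uses 32-bit floats. A small sketch to check and convert:

```python
import torch

tensor1d = torch.tensor([1, 2, 3])
print(tensor1d.dtype)   # torch.int64: Python ints default to 64-bit integers

floatvec = tensor1d.to(torch.float32)  # 32-bit float, the usual training dtype
print(floatvec.dtype)   # torch.float32
```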

Tensor shape

py
tensor2d = torch.tensor([[1, 2, 3], [4, 5, 6]])
print(tensor2d.shape)
# torch.Size([2, 3])

Reshaping

py
# changes the number of rows and columns (the total number of elements must stay the same)
print(tensor2d.reshape(3, 2))
# tensor([[1, 2],
#         [3, 4],
#         [5, 6]])
print(tensor2d.view(3, 2))
# tensor([[1, 2],
#         [3, 4],
#         [5, 6]])
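One difference between the two (an addition to these notes): .view requires the tensor's memory layout to be contiguous, while .reshape falls back to copying when it is not. A minimal sketch:

```python
import torch

tensor2d = torch.tensor([[1, 2, 3], [4, 5, 6]])
t = tensor2d.T               # the transpose is a non-contiguous view
print(t.is_contiguous())     # False

print(t.reshape(2, 3))       # works: reshape copies when the layout requires it
try:
    t.view(2, 3)             # fails: view cannot rearrange a non-contiguous layout
except RuntimeError as e:
    print("view failed:", e)
```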

Transposing tensors

py
# transposing swaps the row and column index of every value: (i, j) -> (j, i)
tensor2d = torch.tensor([[1, 2, 3], [4, 5, 6]])
# before: 1 at (0,0)  2 at (0,1)  3 at (0,2)
#         4 at (1,0)  5 at (1,1)  6 at (1,2)
print(tensor2d.T)
# tensor([[1, 4],
#         [2, 5],
#         [3, 6]])
# after:  1 at (0,0)  4 at (0,1)
#         2 at (1,0)  5 at (1,1)
#         3 at (2,0)  6 at (2,1)

Matrix multiplication

py
# matrix multiplication; A @ B is shorthand for A.matmul(B)
print(tensor2d.matmul(tensor2d.T))
# tensor([[14, 32],
#         [32, 77]])
print(tensor2d @ tensor2d.T)
# tensor([[14, 32],
#         [32, 77]])
  • How to understand it

    A.matmul(B): take each row of A and form the dot product with each column of B

    A = tensor2d (2×3)
    [ 1  2  3 ]
    [ 4  5  6 ]

    A.T (3×2)
    [ 1  4 ]
    [ 2  5 ]
    [ 3  6 ]

    (m × n) @ (n × k) → (m × k)
    (2 × 3) @ (3 × 2) → (2 × 2)
    ✔ the inner dimension 3 matches
    ✔ the result is 2×2

    • 🔹 Element (0,0): row 0 of A · column 0 of A.T
      [1, 2, 3] · [1, 2, 3] = 1*1 + 2*2 + 3*3 = 1 + 4 + 9 = 14
    • 🔹 Element (0,1)
      [1, 2, 3] · [4, 5, 6] = 1*4 + 2*5 + 3*6 = 4 + 10 + 18 = 32
    • 🔹 Element (1,0)
      [4, 5, 6] · [1, 2, 3] = 4*1 + 5*2 + 6*3 = 4 + 10 + 18 = 32
    • 🔹 Element (1,1)
      [4, 5, 6] · [4, 5, 6] = 4*4 + 5*5 + 6*6 = 16 + 25 + 36 = 77
    • ✅ Assembled result
      [ 14  32 ]
      [ 32  77 ]

Automatic differentiation engine

Computation graphs

  • A computation graph lays out the sequence of calculations needed to compute the output of a neural network
  • We need it to compute the gradients required for backpropagation, the main training algorithm for neural networks

Forward pass (prediction step) of a simple logistic-regression classifier

  • The forward pass of logistic regression can be viewed as a computation graph
  • PyTorch builds such a computation graph in the background, and we can use it to compute the gradients of the loss function with respect to the model parameters (here w1 and b). The most common way to compute the loss gradients in the graph is to apply the chain rule from right to left, also known as reverse-mode automatic differentiation, or backpropagation
  • A partial derivative measures the rate at which a function changes with respect to one of its variables
  • A gradient is a vector containing all partial derivatives of a multivariate function (a function with more than one input variable)
  • PyTorch's autograd engine builds the computation graph in the background by tracking every operation performed on tensors. Calling the grad function then computes the gradient of the loss with respect to the model parameter w1
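The logistic-regression forward pass and gradient computation described above can be sketched in a few lines; the concrete values (x1 = 1.1, w1 = 2.2, b = 0.0, label y = 1) are illustrative:

```python
import torch
import torch.nn.functional as F

y  = torch.tensor([1.0])                       # true label
x1 = torch.tensor([1.1])                       # input feature
w1 = torch.tensor([2.2], requires_grad=True)   # weight, tracked by autograd
b  = torch.tensor([0.0], requires_grad=True)   # bias, tracked by autograd

z = x1 * w1 + b               # net input
a = torch.sigmoid(z)          # activation (predicted probability)
loss = F.binary_cross_entropy(a, y)

# compute the gradients of the loss w.r.t. w1 and b through the graph
grad_w1, grad_b = torch.autograd.grad(loss, [w1, b])
print(grad_w1, grad_b)
```

Equivalently, calling loss.backward() populates w1.grad and b.grad in place, which is what the training loop later in these notes relies on.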

Implementing a multilayer network

Neural network

py
# a multilayer perceptron (MLP), i.e., a fully connected neural network
class NeuralNetwork(torch.nn.Module):
    def __init__(self, num_inputs, num_outputs):
        super().__init__()

        self.layers = torch.nn.Sequential(
            # 1st hidden layer
            torch.nn.Linear(num_inputs, 30),
            torch.nn.ReLU(),

            # 2nd hidden layer
            torch.nn.Linear(30, 20),
            torch.nn.ReLU(),

            # output layer
            torch.nn.Linear(20, num_outputs),
        )

    def forward(self, x):
        # x is the input batch; the logits are the raw outputs of the last layer
        logits = self.layers(x)
        return logits

torch.manual_seed(123)  # make the random weight initialization reproducible

Testing the network

py
# test the NeuralNetwork class
def NeuralNetworkTest():
    model = NeuralNetwork(50, 3)
    X = torch.rand((1, 50))
    print(X)
    # tensor([[0.2391, 0.3194, 0.8111, 0.7507, 0.3306, 0.5374, 0.2845, 0.8459, 0.2232,
    #          0.2083, 0.8169, 0.1084, 0.3285, 0.7185, 0.3624, 0.3084, 0.8893, 0.4179,
    #          0.9741, 0.3697, 0.2397, 0.8936, 0.1443, 0.1365, 0.7625, 0.1632, 0.6641,
    #          0.1525, 0.9830, 0.5936, 0.9120, 0.0146, 0.6323, 0.4743, 0.7467, 0.3545,
    #          0.9994, 0.9815, 0.7399, 0.2057, 0.8742, 0.0138, 0.7676, 0.7481, 0.7570,
    #          0.6432, 0.9111, 0.2246, 0.8668, 0.6961]])
    out = model(X)
    print(out)
    # tensor([[-0.1670,  0.1001, -0.1219]], grad_fn=<AddmmBackward0>)

    with torch.no_grad():
        # model(X) returns logits; they are not passed through a softmax or
        # sigmoid activation, because the loss function combines the
        # activation with the negative log-likelihood loss for efficiency
        # and numerical stability.
        # To obtain class-membership probabilities, call softmax manually.
        out = torch.softmax(model(X), dim=1)
    print(out)
    # tensor([[0.2983, 0.3896, 0.3121]])
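The size of such a model can be checked by counting its trainable parameters; for NeuralNetwork(50, 3), the three Linear layers contribute (50*30 + 30) + (30*20 + 20) + (20*3 + 3) = 2213. A sketch using an equivalent Sequential stack:

```python
import torch

# same architecture as NeuralNetwork(50, 3) above
model = torch.nn.Sequential(
    torch.nn.Linear(50, 30), torch.nn.ReLU(),
    torch.nn.Linear(30, 20), torch.nn.ReLU(),
    torch.nn.Linear(20, 3),
)

# each Linear(in, out) layer has in*out weights plus out biases
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(num_params)  # 2213
```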

Datasets and data loaders

py
# define a custom Dataset for the training and test data
class ToyDataset(Dataset):
    def __init__(self, X, y):
        self.features = X
        self.labels = y

    def __getitem__(self, index):
        one_x = self.features[index]
        one_y = self.labels[index]
        return one_x, one_y

    def __len__(self):
        return self.labels.shape[0]


X_train = torch.tensor([
        [-1.2, 3.1],
        [-0.9, 2.9],
        [-0.5, 2.6],
        [2.3, -1.1],
        [2.7, -1.5]
    ])
y_train = torch.tensor([0, 0, 0, 1, 1])

# prepare the datasets and data loaders
def prepare_dataset():
    X_test = torch.tensor([
        [-0.8, 2.8],
        [2.6, -1.6],
    ])
    y_test = torch.tensor([0, 1])

    train_ds = ToyDataset(X_train, y_train)
    test_ds = ToyDataset(X_test, y_test)

    # create the data loaders
    train_loader = DataLoader(
        dataset=train_ds,
        batch_size=2,  # two samples per batch
        shuffle=True,  # shuffle the sample order each epoch
        num_workers=0, # values > 0 start worker processes that load data in parallel
        drop_last=True # drop the last batch if it has fewer than batch_size samples
    )
    test_loader = DataLoader(
        dataset=test_ds,
        batch_size=2,
        shuffle=True,
        num_workers=0
    )

    # for idx, (x, y) in enumerate(train_loader):
    #   print(f"Batch {idx+1}:", x, y)
    # With drop_last=True, the last incomplete batch (a single sample here)
    # is dropped, so only two full batches are produced, e.g.:
    # Batch 1: tensor([[ 2.7000, -1.5000],
    #         [ 2.3000, -1.1000]]) tensor([1, 1])
    # Batch 2: tensor([[-0.9000,  2.9000],
    #         [-1.2000,  3.1000]]) tensor([0, 0])

    return train_loader, test_loader

Training

py
num_epochs = 3
train_loader, test_loader = prepare_dataset()
model = NeuralNetwork(2, 2) # this dataset has two features and two classes
optimizer = torch.optim.SGD(model.parameters(), lr=0.5)

# train for three epochs
for epoch in range(num_epochs):
    # training mode
    model.train()
    for batch_idx, (features, labels) in enumerate(train_loader):
        logits = model(features)
        # the cross_entropy loss function applies softmax internally,
        # which is more efficient and numerically more stable
        loss = F.cross_entropy(logits, labels)  # Loss function

        optimizer.zero_grad() # reset the gradients from the previous step to zero
        loss.backward() # compute the gradients
        optimizer.step() # the optimizer updates the model parameters using the gradients

        # LOGGING
        print(f"Epoch: {epoch+1:03d}/{num_epochs:03d}"
              f" | Batchsize {labels.shape[0]:03d}"
              f" | Train/Val Loss: {loss:.2f}")
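After training, accuracy can be evaluated without tracking gradients. A small helper (not in the original notes), assuming a classification model and data loaders like those above:

```python
import torch

def compute_accuracy(model, dataloader):
    model.eval()                       # evaluation mode
    correct, total = 0, 0
    for features, labels in dataloader:
        with torch.no_grad():          # no gradients needed for evaluation
            logits = model(features)
        predictions = torch.argmax(logits, dim=1)  # class with the largest logit
        correct += (predictions == labels).sum().item()
        total += labels.shape[0]
    return correct / total

# e.g. print(compute_accuracy(model, train_loader))
```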

Saving and loading the model

py
# save
# state_dict is a Python dictionary that maps each layer of the model
# to its trainable parameters (weights and biases)
torch.save(model.state_dict(), "./torch/model.pth")
# load
# The line model = NeuralNetwork(2, 2) is not strictly required here; it is
# included to illustrate that we need a model instance in memory before we
# can apply the saved parameters. The NeuralNetwork(2, 2) architecture must
# exactly match that of the model that was originally saved.
model = NeuralNetwork(2, 2)
# model.load_state_dict() applies the saved parameters to the model,
# effectively restoring its learned state from when it was saved
model.load_state_dict(torch.load("./torch/model.pth", weights_only=True))

GPU

Multi-GPU

  • DDP: PyTorch's DistributedDataParallel strategy
  • Each process keeps its own model replica; the replicas are synchronized during training

Launch modes

  • Launch the processes yourself with mp.spawn;
    • the most common choice for single-machine, multi-GPU training;
    • typically run as: python DDP-script.py (or control the visible GPUs manually with CUDA_VISIBLE_DEVICES=0,1 python ...)
See the code: DDP-script.py
py
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
#   - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch

# Appendix A: Introduction to PyTorch (Part 3)

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# NEW imports:
import os
import platform
import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group


# NEW: function to initialize a distributed process group (1 process / GPU)
# this allows communication among processes
def ddp_setup(rank, world_size):
    """
    Arguments:
        rank: a unique process ID
        world_size: total number of processes in the group
    """
    # rank of machine running rank:0 process
    # here, we assume all GPUs are on the same machine
    os.environ["MASTER_ADDR"] = "localhost"
    # any free port on the machine
    os.environ["MASTER_PORT"] = "12345"

    # initialize process group
    if platform.system() == "Windows":
        # Disable libuv because PyTorch for Windows isn't built with support
        os.environ["USE_LIBUV"] = "0"
        # Windows users may have to use "gloo" instead of "nccl" as backend
        # gloo: Facebook Collective Communication Library
        init_process_group(backend="gloo", rank=rank, world_size=world_size)
    else:
        # nccl: NVIDIA Collective Communication Library
        # rank: the index of the GPU; world_size: the number of GPUs
        init_process_group(backend="nccl", rank=rank, world_size=world_size)
    # set the current GPU device
    torch.cuda.set_device(rank)


class ToyDataset(Dataset):
    def __init__(self, X, y):
        self.features = X
        self.labels = y

    def __getitem__(self, index):
        one_x = self.features[index]
        one_y = self.labels[index]
        return one_x, one_y

    def __len__(self):
        return self.labels.shape[0]


class NeuralNetwork(torch.nn.Module):
    def __init__(self, num_inputs, num_outputs):
        super().__init__()

        self.layers = torch.nn.Sequential(
            # 1st hidden layer
            torch.nn.Linear(num_inputs, 30),
            torch.nn.ReLU(),

            # 2nd hidden layer
            torch.nn.Linear(30, 20),
            torch.nn.ReLU(),

            # output layer
            torch.nn.Linear(20, num_outputs),
        )

    def forward(self, x):
        logits = self.layers(x)
        return logits


def prepare_dataset():
    X_train = torch.tensor([
        [-1.2, 3.1],
        [-0.9, 2.9],
        [-0.5, 2.6],
        [2.3, -1.1],
        [2.7, -1.5]
    ])
    y_train = torch.tensor([0, 0, 0, 1, 1])

    X_test = torch.tensor([
        [-0.8, 2.8],
        [2.6, -1.6],
    ])
    y_test = torch.tensor([0, 1])

    # Uncomment these lines to increase the dataset size to run this script on up to 8 GPUs:
    # factor = 4
    # X_train = torch.cat([X_train + torch.randn_like(X_train) * 0.1 for _ in range(factor)])
    # y_train = y_train.repeat(factor)
    # X_test = torch.cat([X_test + torch.randn_like(X_test) * 0.1 for _ in range(factor)])
    # y_test = y_test.repeat(factor)

    train_ds = ToyDataset(X_train, y_train)
    test_ds = ToyDataset(X_test, y_test)

    train_loader = DataLoader(
        dataset=train_ds,
        batch_size=2,
        shuffle=False,  # NEW: False because of DistributedSampler below
        pin_memory=True, # enables faster host-to-GPU memory transfers
        drop_last=True,
        # NEW: chunk batches across GPUs without overlapping samples:
        # DistributedSampler shuffles the data and splits it into distinct, non-overlapping subsets
        sampler=DistributedSampler(train_ds)  # NEW
    )
    test_loader = DataLoader(
        dataset=test_ds,
        batch_size=2,
        shuffle=False,
    )
    return train_loader, test_loader


# NEW: wrapper
def main(rank, world_size, num_epochs):
    # rank (the process ID, which we use as the GPU ID) is passed automatically by mp.spawn
    ddp_setup(rank, world_size)  # NEW: initialize process groups

    train_loader, test_loader = prepare_dataset()
    model = NeuralNetwork(num_inputs=2, num_outputs=2)
    model.to(rank) # move the model to the GPU for this rank
    optimizer = torch.optim.SGD(model.parameters(), lr=0.5)

    model = DDP(model, device_ids=[rank])  # NEW: wrap model with DDP
    # the core model is now accessible as model.module

    for epoch in range(num_epochs):
        # NEW: Set sampler to ensure each epoch has a different shuffle order
        train_loader.sampler.set_epoch(epoch)

        model.train()
        for features, labels in train_loader:
            # move the data to the GPU for this rank
            features, labels = features.to(rank), labels.to(rank)  # New: use rank
            logits = model(features)
            loss = F.cross_entropy(logits, labels)  # Loss function

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # LOGGING
            print(f"[GPU{rank}] Epoch: {epoch+1:03d}/{num_epochs:03d}"
                  f" | Batchsize {labels.shape[0]:03d}"
                  f" | Train/Val Loss: {loss:.2f}")

    model.eval()

    try:
        train_acc = compute_accuracy(model, train_loader, device=rank)
        print(f"[GPU{rank}] Training accuracy", train_acc)
        test_acc = compute_accuracy(model, test_loader, device=rank)
        print(f"[GPU{rank}] Test accuracy", test_acc)

    ####################################################
    # NEW (not in the book):
    except ZeroDivisionError as e:
        raise ZeroDivisionError(
            f"{e}\n\nThis script is designed for 2 GPUs. You can run it as:\n"
            "CUDA_VISIBLE_DEVICES=0,1 python DDP-script.py\n"
            f"Or, to run it on {torch.cuda.device_count()} GPUs, uncomment the code on lines 103 to 107."
        )
    ####################################################
    # clean up distributed resources
    destroy_process_group()  # NEW: cleanly exit distributed mode


def compute_accuracy(model, dataloader, device):
    model = model.eval()
    correct = 0.0
    total_examples = 0

    for idx, (features, labels) in enumerate(dataloader):
        features, labels = features.to(device), labels.to(device)

        with torch.no_grad():
            logits = model(features)
        predictions = torch.argmax(logits, dim=1)
        compare = labels == predictions
        correct += torch.sum(compare)
        total_examples += len(compare)
    return (correct / total_examples).item()


if __name__ == "__main__":
    # This script may not work for GPUs > 2 due to the small dataset
    # Run `CUDA_VISIBLE_DEVICES=0,1 python DDP-script.py` if you have GPUs > 2
    print("PyTorch version:", torch.__version__)
    print("CUDA available:", torch.cuda.is_available())
    print("Number of GPUs available:", torch.cuda.device_count())
    torch.manual_seed(123)

    # NEW: spawn new processes
    # note that spawn will automatically pass the rank
    num_epochs = 3
    world_size = torch.cuda.device_count()
    # launch the main function in multiple processes
    mp.spawn(main, args=(world_size, num_epochs), nprocs=world_size)
    # nprocs=world_size spawns one process per GPU
  • torchrun
    • an external launcher (torchrun / deepspeed / SLURM, etc.) starts the processes first;
    • your script only reads the environment variables and runs the share of work for its rank
    • run as: torchrun --nproc_per_node=2 DDP-script.py, or on multiple machines: torchrun --nnodes=2 --node_rank=0/1 ...
See the code: DDP-script-torchrun.py
py
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
#   - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch

# Appendix A: Introduction to PyTorch (Part 3)

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# NEW imports:
import os
import platform
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group


# NEW: function to initialize a distributed process group (1 process / GPU)
# this allows communication among processes
def ddp_setup(rank, world_size):
    """
    Arguments:
        rank: a unique process ID
        world_size: total number of processes in the group
    """
    # Only set MASTER_ADDR and MASTER_PORT if not already defined by torchrun
    if "MASTER_ADDR" not in os.environ:
        os.environ["MASTER_ADDR"] = "localhost"
    if "MASTER_PORT" not in os.environ:
        os.environ["MASTER_PORT"] = "12345"

    # initialize process group
    if platform.system() == "Windows":
        # Disable libuv because PyTorch for Windows isn't built with support
        os.environ["USE_LIBUV"] = "0"
        # Windows users may have to use "gloo" instead of "nccl" as backend
        # gloo: Facebook Collective Communication Library
        init_process_group(backend="gloo", rank=rank, world_size=world_size)
    else:
        # nccl: NVIDIA Collective Communication Library
        init_process_group(backend="nccl", rank=rank, world_size=world_size)

    torch.cuda.set_device(rank)


class ToyDataset(Dataset):
    def __init__(self, X, y):
        self.features = X
        self.labels = y

    def __getitem__(self, index):
        one_x = self.features[index]
        one_y = self.labels[index]
        return one_x, one_y

    def __len__(self):
        return self.labels.shape[0]


class NeuralNetwork(torch.nn.Module):
    def __init__(self, num_inputs, num_outputs):
        super().__init__()

        self.layers = torch.nn.Sequential(
            # 1st hidden layer
            torch.nn.Linear(num_inputs, 30),
            torch.nn.ReLU(),

            # 2nd hidden layer
            torch.nn.Linear(30, 20),
            torch.nn.ReLU(),

            # output layer
            torch.nn.Linear(20, num_outputs),
        )

    def forward(self, x):
        logits = self.layers(x)
        return logits


def prepare_dataset():
    X_train = torch.tensor([
        [-1.2, 3.1],
        [-0.9, 2.9],
        [-0.5, 2.6],
        [2.3, -1.1],
        [2.7, -1.5]
    ])
    y_train = torch.tensor([0, 0, 0, 1, 1])

    X_test = torch.tensor([
        [-0.8, 2.8],
        [2.6, -1.6],
    ])
    y_test = torch.tensor([0, 1])

    # Uncomment these lines to increase the dataset size to run this script on up to 8 GPUs:
    # factor = 4
    # X_train = torch.cat([X_train + torch.randn_like(X_train) * 0.1 for _ in range(factor)])
    # y_train = y_train.repeat(factor)
    # X_test = torch.cat([X_test + torch.randn_like(X_test) * 0.1 for _ in range(factor)])
    # y_test = y_test.repeat(factor)

    train_ds = ToyDataset(X_train, y_train)
    test_ds = ToyDataset(X_test, y_test)

    train_loader = DataLoader(
        dataset=train_ds,
        batch_size=2,
        shuffle=False,  # NEW: False because of DistributedSampler below
        pin_memory=True,
        drop_last=True,
        # NEW: chunk batches across GPUs without overlapping samples:
        sampler=DistributedSampler(train_ds)  # NEW
    )
    test_loader = DataLoader(
        dataset=test_ds,
        batch_size=2,
        shuffle=False,
    )
    return train_loader, test_loader


# NEW: wrapper
def main(rank, world_size, num_epochs):

    ddp_setup(rank, world_size)  # NEW: initialize process groups

    train_loader, test_loader = prepare_dataset()
    model = NeuralNetwork(num_inputs=2, num_outputs=2)
    model.to(rank)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.5)

    model = DDP(model, device_ids=[rank])  # NEW: wrap model with DDP
    # the core model is now accessible as model.module

    for epoch in range(num_epochs):
        # NEW: Set sampler to ensure each epoch has a different shuffle order
        train_loader.sampler.set_epoch(epoch)

        model.train()
        for features, labels in train_loader:

            features, labels = features.to(rank), labels.to(rank)  # New: use rank
            logits = model(features)
            loss = F.cross_entropy(logits, labels)  # Loss function

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # LOGGING
            print(f"[GPU{rank}] Epoch: {epoch+1:03d}/{num_epochs:03d}"
                  f" | Batchsize {labels.shape[0]:03d}"
                  f" | Train/Val Loss: {loss:.2f}")

    model.eval()

    try:
        train_acc = compute_accuracy(model, train_loader, device=rank)
        print(f"[GPU{rank}] Training accuracy", train_acc)
        test_acc = compute_accuracy(model, test_loader, device=rank)
        print(f"[GPU{rank}] Test accuracy", test_acc)

    ####################################################
    # NEW (not in the book):
    except ZeroDivisionError as e:
        raise ZeroDivisionError(
            f"{e}\n\nThis script is designed for 2 GPUs. You can run it as:\n"
            "torchrun --nproc_per_node=2 DDP-script-torchrun.py\n"
            f"Or, to run it on {torch.cuda.device_count()} GPUs, uncomment the code on lines 103 to 107."
        )
    ####################################################

    destroy_process_group()  # NEW: cleanly exit distributed mode


def compute_accuracy(model, dataloader, device):
    model = model.eval()
    correct = 0.0
    total_examples = 0

    for idx, (features, labels) in enumerate(dataloader):
        features, labels = features.to(device), labels.to(device)

        with torch.no_grad():
            logits = model(features)
        predictions = torch.argmax(logits, dim=1)
        compare = labels == predictions
        correct += torch.sum(compare)
        total_examples += len(compare)
    return (correct / total_examples).item()


if __name__ == "__main__":
    # NEW: Use environment variables set by torchrun if available, otherwise default to single-process.
    if "WORLD_SIZE" in os.environ:
        world_size = int(os.environ["WORLD_SIZE"])
    else:
        world_size = 1

    if "LOCAL_RANK" in os.environ:
        rank = int(os.environ["LOCAL_RANK"])
    elif "RANK" in os.environ:
        rank = int(os.environ["RANK"])
    else:
        rank = 0

    # Only print on rank 0 to avoid duplicate prints from each GPU process
    if rank == 0:
        print("PyTorch version:", torch.__version__)
        print("CUDA available:", torch.cuda.is_available())
        print("Number of GPUs available:", torch.cuda.device_count())

    torch.manual_seed(123)
    num_epochs = 3
    main(rank, world_size, num_epochs)