🔥 PyTorch 到底是什么?为什么这么火?
PyTorch 是一个 开源的深度学习框架,它让神经网络的构建和训练变得 像搭积木一样简单。如果你想用 AI 训练模型,比如识别图片、生成文本、甚至做聊天机器人,PyTorch 都能帮你快速实现!💡
💡 为什么 PyTorch 这么受欢迎?
🔹 代码像 Python,一学就会!
- PyTorch 语法简单,跟 NumPy 很像,适合 Python 程序员上手。
- 代码可读性高,调试方便,适合新手和研究人员。
🔹 动态图(Dynamic Computation Graph)超级灵活!
- 你可以像写普通 Python 代码一样写神经网络,不需要提前定义整个计算过程。
- 适合 复杂网络,比如 RNN、Transformer 等。
🔹 GPU 加速,训练飞快!
- PyTorch 可以自动调用 GPU,让训练速度提升 几十倍!💨
🔹 强大的自动微分(Autograd)!
- PyTorch 能自动求导数,帮你计算梯度,省去手动推导的麻烦。
- 反向传播(Backpropagation)只要 一行代码 就能实现!
🤖 PyTorch 的核心概念
1️⃣ Tensor(张量):
- PyTorch 里的数据结构,类似于 NumPy 数组,但可以跑在 GPU 上。
- 举个例子,下面创建了一个 2x3 的 Tensor:
import torch
x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
print(x)
2️⃣ Autograd(自动求导):
- PyTorch 可以 自动计算梯度,只要
requires_grad=True
,就能追踪计算。
pythonx = torch.tensor(2.0, requires_grad=True) y = x ** 2 # y = x^2 y.backward() # 计算 dy/dx print(x.grad) # 输出 4.0
3️⃣ 神经网络(torch.nn):
- PyTorch 提供了 现成的神经网络模块,比如全连接层、卷积层、RNN 等。
pythonimport torch.nn as nn model = nn.Linear(10, 1) # 10 个输入 -> 1 个输出
4️⃣ 优化器(torch.optim):
- 负责更新参数,比如常见的 SGD、Adam 等。
pythonoptimizer = torch.optim.SGD(model.parameters(), lr=0.01)
5️⃣ 数据加载(torch.utils.data):
- 轻松加载大规模数据,比如 图片、文本 等。
pythonfrom torch.utils.data import DataLoader, TensorDataset dataset = TensorDataset(torch.randn(100, 10), torch.randn(100, 1)) dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
🎯 PyTorch 训练神经网络的基本流程
- 准备数据(加载数据集、转换成 Tensor)
- 定义模型(搭建神经网络)
- 选择损失函数 & 优化器
- 训练模型(循环更新参数)
- 测试 & 预测
完整代码示例(训练一个简单的线性回归模型):
pythonimport torch import torch.nn as nn import torch.optim as optim # 1. 生成数据 x = torch.randn(100, 1) y = 3 * x + 2 + 0.1 * torch.randn(100, 1) # 目标函数:y = 3x + 2 # 2. 定义模型 model = nn.Linear(1, 1) # 线性回归 # 3. 选择损失函数 & 优化器 loss_fn = nn.MSELoss() # 均方误差 optimizer = optim.SGD(model.parameters(), lr=0.1) # 4. 训练模型 for epoch in range(100): y_pred = model(x) # 前向传播 loss = loss_fn(y_pred, y) # 计算损失 optimizer.zero_grad() # 梯度清零 loss.backward() # 反向传播 optimizer.step() # 更新参数 if epoch % 10 == 0: print(f'Epoch {epoch}, Loss: {loss.item()}') # 5. 预测 test_x = torch.tensor([[5.0]]) predicted_y = model(test_x) print(f'预测值: {predicted_y.item()}')
全部 0条评论