注意
点击 这里 下载完整的示例代码
PyTorch: 张量¶
创建于:2020年12月03日 | 最后更新:2020年12月03日 | 最后验证:2024年11月05日
一个三阶多项式,通过最小化欧几里得距离的平方来训练,以预测从\(-\pi\)到\(pi\)的\(y=\sin(x)\)。
此实现使用PyTorch张量手动计算前向传播、损失和后向传播。
PyTorch Tensor 基本上与 numpy 数组相同:它不了解深度学习、计算图或梯度,只是一个用于任意数值计算的通用 n 维数组。
numpy数组和PyTorch张量之间最大的区别在于,PyTorch张量可以在CPU或GPU上运行。要在GPU上运行操作,只需将张量转换为cuda数据类型。
99 463.81201171875
199 312.108154296875
299 211.08497619628906
399 143.78140258789062
499 98.92195129394531
599 69.00719451904297
699 49.04833984375
799 35.724693298339844
899 26.82526397705078
999 20.877582550048828
1099 16.900102615356445
1199 14.238385200500488
1299 12.455989837646484
1399 11.26158618927002
1499 10.46059799194336
1599 9.923055648803711
1699 9.561992645263672
1799 9.319270133972168
1899 9.15597152709961
1999 9.045997619628906
Result: y = 0.009059899486601353 + 0.8446162343025208 x + -0.0015629848930984735 x^2 + -0.0916057676076889 x^3
import torch
import math
dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0") # Uncomment this to run on GPU
# Create random input and output data
x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
y = torch.sin(x)
# Randomly initialize weights
a = torch.randn((), device=device, dtype=dtype)
b = torch.randn((), device=device, dtype=dtype)
c = torch.randn((), device=device, dtype=dtype)
d = torch.randn((), device=device, dtype=dtype)
learning_rate = 1e-6
for t in range(2000):
# Forward pass: compute predicted y
y_pred = a + b * x + c * x ** 2 + d * x ** 3
# Compute and print loss
loss = (y_pred - y).pow(2).sum().item()
if t % 100 == 99:
print(t, loss)
# Backprop to compute gradients of a, b, c, d with respect to loss
grad_y_pred = 2.0 * (y_pred - y)
grad_a = grad_y_pred.sum()
grad_b = (grad_y_pred * x).sum()
grad_c = (grad_y_pred * x ** 2).sum()
grad_d = (grad_y_pred * x ** 3).sum()
# Update weights using gradient descent
a -= learning_rate * grad_a
b -= learning_rate * grad_b
c -= learning_rate * grad_c
d -= learning_rate * grad_d
print(f'Result: y = {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3')
脚本总运行时间: ( 0 分钟 0.317 秒)