{ "cells": [ { "cell_type": "code", "id": "initial_id", "metadata": { "collapsed": true, "ExecuteTime": { "end_time": "2026-04-22T07:03:02.177207285Z", "start_time": "2026-04-22T07:02:59.204677901Z" } }, "source": [ "\n", "import torch\n", "import d2l\n", "import numpy\n", "import torch.nn as nn\n", "import torch.nn.functional as F" ], "outputs": [], "execution_count": 2 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:02.454043741Z", "start_time": "2026-04-22T07:03:02.230904947Z" } }, "cell_type": "code", "source": [ "net = nn.Sequential(nn.Linear(20, 256), nn.ReLU(), nn.Linear(256, 10))\n", "X = torch.rand(2, 20)\n", "net(X)" ], "id": "dcd5590e7795eec1", "outputs": [ { "data": { "text/plain": [ "tensor([[ 0.0041, -0.3465, -0.2096, 0.2304, -0.1043, 0.0066, 0.1817, 0.0355,\n", " 0.2685, -0.0461],\n", " [-0.0932, -0.1621, -0.1244, 0.2398, -0.0759, 0.0680, 0.1511, 0.0224,\n", " 0.2522, -0.0228]], grad_fn=)" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 3 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:04.497911379Z", "start_time": "2026-04-22T07:03:03.603349572Z" } }, "cell_type": "code", "source": [ "class MLP(nn.Module):\n", " def __init__(self):\n", " super().__init__()\n", " self.hidden=nn.Linear(20,256)\n", " self.out=nn.Linear(256,10)\n", " def forward(self,X):\n", " return self.out(F.relu(self.hidden(X)))\n" ], "id": "4ae330604b643cb4", "outputs": [], "execution_count": 4 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:05.186050578Z", "start_time": "2026-04-22T07:03:04.689781242Z" } }, "cell_type": "code", "source": [ "net=MLP()\n", "net(X)" ], "id": "cca55c6c0c7da12f", "outputs": [ { "data": { "text/plain": [ "tensor([[-0.2165, 0.1394, 0.0867, 0.0692, 0.2914, -0.1427, 0.2218, -0.0533,\n", " -0.2137, 0.0044],\n", " [-0.2020, 0.0648, 0.0514, 0.0500, 0.2555, -0.1679, 0.1621, -0.1462,\n", " -0.2527, 0.0386]], grad_fn=)" ] }, "execution_count": 5, "metadata": 
{}, "output_type": "execute_result" } ], "execution_count": 5 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:05.664056879Z", "start_time": "2026-04-22T07:03:05.331438994Z" } }, "cell_type": "code", "source": [ "class FixedHiddenMLP(nn.Module):\n", "    def __init__(self):\n", "        super().__init__()\n", "        # Random weights for which no gradient is computed, so they stay constant during training\n", "        self.rand_weight = torch.rand((20, 20), requires_grad=False)\n", "        self.linear = nn.Linear(20, 20)\n", "\n", "    def forward(self, X):\n", "        X = self.linear(X)\n", "        # Combine the constant weights with the relu and mm functions\n", "        X = F.relu(torch.mm(X, self.rand_weight) + 1)\n", "        # Reuse the same fully connected layer: equivalent to two layers sharing parameters\n", "        X = self.linear(X)\n", "        # Control flow: halve X until the sum of its absolute values is at most 1\n", "        while X.abs().sum() > 1:\n", "            X /= 2\n", "        return X.sum()" ], "id": "4518d62611d5e749", "outputs": [], "execution_count": 6 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:06.104535508Z", "start_time": "2026-04-22T07:03:05.824555955Z" } }, "cell_type": "code", "source": [ "net = FixedHiddenMLP()\n", "net(X)" ], "id": "fae0187ece4ed5c6", "outputs": [ { "data": { "text/plain": [ "tensor(-0.0023, grad_fn=<SumBackward0>)" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 7 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:06.273938290Z", "start_time": "2026-04-22T07:03:06.117091179Z" } }, "cell_type": "code", "source": [ "class NestMLP(nn.Module):\n", "    def __init__(self):\n", "        super().__init__()\n", "        self.net = nn.Sequential(nn.Linear(20, 64), nn.ReLU(),\n", "                                 nn.Linear(64, 32), nn.ReLU())\n", "        self.linear = nn.Linear(32, 16)\n", "\n", "    def forward(self, X):\n", "        return self.linear(self.net(X))\n", "\n", "\n", "# Compose the custom blocks at top level, after NestMLP is fully defined\n", "chimera = nn.Sequential(NestMLP(), nn.Linear(16, 20), FixedHiddenMLP())\n", "chimera(X)" ], "id": "407ef13a86453aae", "outputs": [], "execution_count": 8 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:06.462449517Z", "start_time": "2026-04-22T07:03:06.323939028Z" } }, "cell_type": "code", "source": [ "net = nn.Sequential(nn.Linear(4, 8), 
nn.ReLU(), nn.Linear(8, 1))\n", "X = torch.rand(size=(2, 4))\n", "net(X)" ], "id": "9f3526f263c7a249", "outputs": [ { "data": { "text/plain": [ "tensor([[-0.1265],\n", " [-0.0471]], grad_fn=)" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 9 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:06.843325610Z", "start_time": "2026-04-22T07:03:06.539581889Z" } }, "cell_type": "code", "source": "print(net[2].state_dict())", "id": "8c73f8daa02ba28b", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "OrderedDict([('weight', tensor([[ 0.0136, -0.1015, 0.1191, 0.2722, 0.3456, -0.0650, -0.0437, -0.2806]])), ('bias', tensor([-0.0945]))])\n" ] } ], "execution_count": 10 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:07.189943309Z", "start_time": "2026-04-22T07:03:06.962295444Z" } }, "cell_type": "code", "source": "net[2].state_dict()", "id": "b6fee6b64fb96e3c", "outputs": [ { "data": { "text/plain": [ "OrderedDict([('weight',\n", " tensor([[ 0.0136, -0.1015, 0.1191, 0.2722, 0.3456, -0.0650, -0.0437, -0.2806]])),\n", " ('bias', tensor([-0.0945]))])" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 11 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:07.395792068Z", "start_time": "2026-04-22T07:03:07.243437434Z" } }, "cell_type": "code", "source": "print(type(net[2].bias))", "id": "b38e8dc384e038c5", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n" ] } ], "execution_count": 12 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:07.629769183Z", "start_time": "2026-04-22T07:03:07.457413574Z" } }, "cell_type": "code", "source": [ "print(net[2].bias)\n", "print(net[2].bias.data)\n" ], "id": "73f12ca3669d9ede", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Parameter containing:\n", "tensor([-0.0945], requires_grad=True)\n", "tensor([-0.0945])\n" ] } ], 
"execution_count": 13 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:07.873696310Z", "start_time": "2026-04-22T07:03:07.679040535Z" } }, "cell_type": "code", "source": "net[2].weight.grad==None", "id": "db0fe33018c16fac", "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 14 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:07.984798139Z", "start_time": "2026-04-22T07:03:07.896070141Z" } }, "cell_type": "code", "source": [ "print(*[(name, param.shape) for name, param in net[0].named_parameters()])\n", "print(*[(name, param.shape) for name, param in net.named_parameters()])" ], "id": "75847a1c608ee5c7", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "('weight', torch.Size([8, 4])) ('bias', torch.Size([8]))\n", "('0.weight', torch.Size([8, 4])) ('0.bias', torch.Size([8])) ('2.weight', torch.Size([1, 8])) ('2.bias', torch.Size([1]))\n" ] } ], "execution_count": 15 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:08.010084298Z", "start_time": "2026-04-22T07:03:07.991112964Z" } }, "cell_type": "code", "source": "net.state_dict()['2.bias'].data", "id": "cc74913e8742da7d", "outputs": [ { "data": { "text/plain": [ "tensor([-0.0945])" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 16 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:08.088926956Z", "start_time": "2026-04-22T07:03:08.042795645Z" } }, "cell_type": "code", "source": [ "def block1():\n", " return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 4),nn.ReLU())\n", "def block2():\n", " net = nn.Sequential()\n", " for i in range(4):\n", " net.add_module(f'block{i}', block1())\n", " return net" ], "id": "53c39c5e61fa7bf5", "outputs": [], "execution_count": 17 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:08.330228511Z", "start_time": 
"2026-04-22T07:03:08.096053767Z" } }, "cell_type": "code", "source": [ "rgnet = nn.Sequential(block2(),nn.Linear(4,1))\n", "rgnet(X)" ], "id": "d3ac7759b619aca", "outputs": [ { "data": { "text/plain": [ "tensor([[0.0117],\n", " [0.0117]], grad_fn=)" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 18 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:08.645186191Z", "start_time": "2026-04-22T07:03:08.455908607Z" } }, "cell_type": "code", "source": "print(rgnet)", "id": "8fc60f64b07781e6", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sequential(\n", " (0): Sequential(\n", " (block0): Sequential(\n", " (0): Linear(in_features=4, out_features=8, bias=True)\n", " (1): ReLU()\n", " (2): Linear(in_features=8, out_features=4, bias=True)\n", " (3): ReLU()\n", " )\n", " (block1): Sequential(\n", " (0): Linear(in_features=4, out_features=8, bias=True)\n", " (1): ReLU()\n", " (2): Linear(in_features=8, out_features=4, bias=True)\n", " (3): ReLU()\n", " )\n", " (block2): Sequential(\n", " (0): Linear(in_features=4, out_features=8, bias=True)\n", " (1): ReLU()\n", " (2): Linear(in_features=8, out_features=4, bias=True)\n", " (3): ReLU()\n", " )\n", " (block3): Sequential(\n", " (0): Linear(in_features=4, out_features=8, bias=True)\n", " (1): ReLU()\n", " (2): Linear(in_features=8, out_features=4, bias=True)\n", " (3): ReLU()\n", " )\n", " )\n", " (1): Linear(in_features=4, out_features=1, bias=True)\n", ")\n" ] } ], "execution_count": 19 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:08.903696348Z", "start_time": "2026-04-22T07:03:08.733628048Z" } }, "cell_type": "code", "source": "rgnet[0][1][0].bias.data", "id": "e590aaafca787b50", "outputs": [ { "data": { "text/plain": [ "tensor([ 0.2396, -0.2293, -0.3365, 0.0070, -0.0166, -0.2328, -0.1627, 0.3407])" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 20 }, { "metadata": 
{ "ExecuteTime": { "end_time": "2026-04-22T07:03:09.181470689Z", "start_time": "2026-04-22T07:03:08.920938456Z" } }, "cell_type": "code", "source": [ "def init_normal(m):\n", "    # Re-initialise Linear layers: weights ~ N(0, 0.01^2), bias = 0\n", "    if isinstance(m, nn.Linear):\n", "        nn.init.normal_(m.weight, mean=0, std=0.01)\n", "        if m.bias is not None:  # Linear(..., bias=False) has no bias tensor\n", "            nn.init.zeros_(m.bias)\n", "\n", "net.apply(init_normal)\n", "net[0].weight.data[0], net[0].bias.data[0]" ], "id": "925ca33221d0a87e", "outputs": [ { "data": { "text/plain": [ "(tensor([ 0.0166,  0.0092,  0.0013, -0.0031]), tensor(0.))" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 21 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:09.303120289Z", "start_time": "2026-04-22T07:03:09.184866866Z" } }, "cell_type": "code", "source": [ "def init_xavier(m):\n", "    # Xavier-uniform weights for Linear layers; biases are left untouched\n", "    if isinstance(m, nn.Linear):\n", "        nn.init.xavier_uniform_(m.weight)\n", "\n", "def init_42(m):\n", "    # Constant weights (42) for Linear layers, so the effect of apply() is easy to see\n", "    if isinstance(m, nn.Linear):\n", "        nn.init.constant_(m.weight, 42)\n", "\n", "net[0].apply(init_xavier)\n", "net[2].apply(init_42)\n", "print(net[0].weight.data[0])\n", "print(net[2].weight.data)" ], "id": "81e2de84a8c4ef32", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([-0.2085,  0.4344, -0.3960,  0.5868])\n", "tensor([[42., 42., 42., 42., 42., 42., 42., 42.]])\n" ] } ], "execution_count": 22 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:09.411878374Z", "start_time": "2026-04-22T07:03:09.355106030Z" } }, "cell_type": "code", "source": [ "x = torch.arange(4)\n", "torch.save(x, 'x-file')" ], "id": "f05bb378bb60ab9e", "outputs": [], "execution_count": 23 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:09.509836133Z", "start_time": "2026-04-22T07:03:09.427360581Z" } }, "cell_type": "code", "source": [ "# weights_only=True restricts unpickling to tensors and plain containers (no arbitrary objects)\n", "x2 = torch.load('x-file', weights_only=True)\n", "x2" ], "id": "a74ecaaac0d826c6", "outputs": [ { "data": { "text/plain": [ "tensor([0, 1, 2, 3])" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 24 }, { "metadata": {
"ExecuteTime": { "end_time": "2026-04-22T07:03:09.542056671Z", "start_time": "2026-04-22T07:03:09.518568625Z" } }, "cell_type": "code", "source": [ "class MLP(nn.Module):\n", " def __init__(self):\n", " super().__init__()\n", " self.hidden = nn.Linear(20, 256)\n", " self.output = nn.Linear(256, 10)\n", " def forward(self, x):\n", " return self.output(F.relu(self.hidden(x)))\n", "\n", "net = MLP()\n", "X = torch.randn(size=(2, 20))\n", "Y = net(X)" ], "id": "b42598f0c4a8e801", "outputs": [], "execution_count": 25 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:09.603222999Z", "start_time": "2026-04-22T07:03:09.548610614Z" } }, "cell_type": "code", "source": "torch.save(net.state_dict(), 'mlp.params')", "id": "aaa22eef549caa6f", "outputs": [], "execution_count": 26 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:09.699407155Z", "start_time": "2026-04-22T07:03:09.607082306Z" } }, "cell_type": "code", "source": [ "clone = MLP()\n", "clone.load_state_dict(torch.load('mlp.params'))\n", "clone.eval()" ], "id": "b92f920229abeeae", "outputs": [ { "data": { "text/plain": [ "MLP(\n", " (hidden): Linear(in_features=20, out_features=256, bias=True)\n", " (output): Linear(in_features=256, out_features=10, bias=True)\n", ")" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 27 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:09.854186935Z", "start_time": "2026-04-22T07:03:09.721875531Z" } }, "cell_type": "code", "source": [ "Y_clone = clone(X)\n", "Y_clone == Y" ], "id": "646c9eb6d7cc81c2", "outputs": [ { "data": { "text/plain": [ "tensor([[True, True, True, True, True, True, True, True, True, True],\n", " [True, True, True, True, True, True, True, True, True, True]])" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 28 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:10.143072800Z", "start_time": 
"2026-04-22T07:03:09.938713854Z" } }, "cell_type": "code", "source": [ "def corr2d(X,K):\n", " h,w=K.shape\n", " Y=torch.ones((X.shape[0]-h+1,X.shape[1]-w+1))\n", " for i in range(Y.shape[0]):\n", " for j in range(Y.shape[1]):\n", " Y[i,j]=(X[i:i+h,j:j+w]*K).sum()\n", " return Y\n" ], "id": "d45f9adfe47fce20", "outputs": [], "execution_count": 29 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:10.449191223Z", "start_time": "2026-04-22T07:03:10.209878470Z" } }, "cell_type": "code", "source": [ "X = torch.tensor([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]])\n", "K = torch.tensor([[0.0, 1.0], [2.0, 3.0]])\n", "corr2d(X,K)" ], "id": "db7279e13647c315", "outputs": [ { "data": { "text/plain": [ "tensor([[19., 25.],\n", " [37., 43.]])" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 30 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:10.723922087Z", "start_time": "2026-04-22T07:03:10.558883210Z" } }, "cell_type": "code", "source": [ "class Conv2D(nn.Module):\n", " def __init__(self, kernel_size):\n", " super().__init__()\n", " self.weight = nn.Parameter(torch.rand(kernel_size))\n", " self.bias = nn.Parameter(torch.zeros(1))\n", " def forward(self, x):\n", " return corr2d(x, self.weight) + self.bias\n" ], "id": "d60be1bd12a1f37e", "outputs": [], "execution_count": 31 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:11.208027335Z", "start_time": "2026-04-22T07:03:10.944803814Z" } }, "cell_type": "code", "source": [ "X = torch.ones((6, 8))\n", "X[:, 2:6] = 0\n", "X" ], "id": "5083789b7a728442", "outputs": [ { "data": { "text/plain": [ "tensor([[1., 1., 0., 0., 0., 0., 1., 1.],\n", " [1., 1., 0., 0., 0., 0., 1., 1.],\n", " [1., 1., 0., 0., 0., 0., 1., 1.],\n", " [1., 1., 0., 0., 0., 0., 1., 1.],\n", " [1., 1., 0., 0., 0., 0., 1., 1.],\n", " [1., 1., 0., 0., 0., 0., 1., 1.]])" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], 
"execution_count": 32 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:11.547516752Z", "start_time": "2026-04-22T07:03:11.280664423Z" } }, "cell_type": "code", "source": [ "K = torch.tensor([[1.0, -1.0]])\n", "Y = corr2d(X, K)\n", "Y" ], "id": "ee8d6bedbde886ad", "outputs": [ { "data": { "text/plain": [ "tensor([[ 0., 1., 0., 0., 0., -1., 0.],\n", " [ 0., 1., 0., 0., 0., -1., 0.],\n", " [ 0., 1., 0., 0., 0., -1., 0.],\n", " [ 0., 1., 0., 0., 0., -1., 0.],\n", " [ 0., 1., 0., 0., 0., -1., 0.],\n", " [ 0., 1., 0., 0., 0., -1., 0.]])" ] }, "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 33 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:11.992663567Z", "start_time": "2026-04-22T07:03:11.712431704Z" } }, "cell_type": "code", "source": "corr2d(X.t(), K)", "id": "a8278c3837fa9a1c", "outputs": [ { "data": { "text/plain": [ "tensor([[0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0.]])" ] }, "execution_count": 34, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 34 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:12.544730677Z", "start_time": "2026-04-22T07:03:12.187262859Z" } }, "cell_type": "code", "source": "conv2d = nn.Conv2d(1,1, kernel_size=(1, 2), bias=False)", "id": "ec61cdb61a8cabff", "outputs": [], "execution_count": 35 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:12.918278160Z", "start_time": "2026-04-22T07:03:12.663915511Z" } }, "cell_type": "code", "source": [ "X = X.reshape((1, 1, 6, 8))\n", "Y = Y.reshape((1, 1, 6, 7))\n", "lr = 3e-2" ], "id": "d2fc19d84c79a10", "outputs": [], "execution_count": 36 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:14.229826949Z", "start_time": "2026-04-22T07:03:12.942822259Z" } }, "cell_type": "code", 
"source": [ "for i in range(100):\n", " Y_hat = conv2d(X)\n", " l = (Y_hat - Y) ** 2\n", " conv2d.zero_grad()\n", " l.sum().backward()\n", " # 迭代卷积核\n", " conv2d.weight.data[:] -= lr * conv2d.weight.grad\n", " if (i + 1) % 20 == 0:\n", " print(f'epoch {i+1}, loss {l.sum():.3f}')" ], "id": "51fbb2e6398a9bd5", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch 20, loss 0.000\n", "epoch 40, loss 0.000\n", "epoch 60, loss 0.000\n", "epoch 80, loss 0.000\n", "epoch 100, loss 0.000\n" ] } ], "execution_count": 37 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:14.392987505Z", "start_time": "2026-04-22T07:03:14.281161755Z" } }, "cell_type": "code", "source": "conv2d.weight.data.reshape((1, 2))\n", "id": "bf53a423f429dfe4", "outputs": [ { "data": { "text/plain": [ "tensor([[ 1.0000, -1.0000]])" ] }, "execution_count": 38, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 38 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:14.572225657Z", "start_time": "2026-04-22T07:03:14.447504117Z" } }, "cell_type": "code", "source": [ "\n", "# 为了方便起见,我们定义了一个计算卷积层的函数。\n", "# 此函数初始化卷积层权重,并对输入和输出提高和缩减相应的维数\n", "def comp_conv2d(conv2d, X):\n", "# 这里的(1,1)表示批量大小和通道数都是1\n", " X = X.reshape((1, 1) + X.shape)\n", " Y = conv2d(X)\n", " # 省略前两个维度:批量大小和通道\n", " return Y.reshape(Y.shape[2:])\n", "# 请注意,这里每边都填充了1行或1列,因此总共添加了2行或2列\n", "conv2d = nn.Conv2d(1, 1, kernel_size=3, padding=1)" ], "id": "77b61d8c9a2363cc", "outputs": [], "execution_count": 39 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:14.842090802Z", "start_time": "2026-04-22T07:03:14.640382418Z" } }, "cell_type": "code", "source": [ "X = torch.rand(size=(8, 8))\n", "comp_conv2d(conv2d, X).shape" ], "id": "beda6ffa67ec2677", "outputs": [ { "data": { "text/plain": [ "torch.Size([8, 8])" ] }, "execution_count": 40, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 40 }, { "metadata": { "ExecuteTime": { "end_time": 
"2026-04-22T07:03:15.005425705Z", "start_time": "2026-04-22T07:03:14.845097024Z" } }, "cell_type": "code", "source": [ "conv2d = nn.Conv2d(1, 1, kernel_size=(5, 3), padding=(2, 1))\n", "comp_conv2d(conv2d, X).shape" ], "id": "8c51095daea1432d", "outputs": [ { "data": { "text/plain": [ "torch.Size([8, 8])" ] }, "execution_count": 41, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 41 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:15.154120979Z", "start_time": "2026-04-22T07:03:15.063863068Z" } }, "cell_type": "code", "source": [ "conv2d = nn.Conv2d(1, 1, kernel_size=3, padding=1, stride=2)\n", "comp_conv2d(conv2d, X).shape" ], "id": "581bf1b15162cbf6", "outputs": [ { "data": { "text/plain": [ "torch.Size([4, 4])" ] }, "execution_count": 42, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 42 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:15.290065680Z", "start_time": "2026-04-22T07:03:15.156986867Z" } }, "cell_type": "code", "source": [ "conv2d = nn.Conv2d(1, 1, kernel_size=(3, 5), padding=(0, 1), stride=(3, 4))\n", "comp_conv2d(conv2d, X).shape" ], "id": "6f7a2411247baff0", "outputs": [ { "data": { "text/plain": [ "torch.Size([2, 2])" ] }, "execution_count": 43, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 43 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:15.447782218Z", "start_time": "2026-04-22T07:03:15.341665415Z" } }, "cell_type": "code", "source": [ "def corr2d_multi_in(X,K):\n", " return sum(corr2d(x,k) for x,k in zip(X,K))\n", "X = torch.tensor([[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]],\n", "[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]])\n", "K = torch.tensor([[[0.0, 1.0], [2.0, 3.0]], [[1.0, 2.0], [3.0, 4.0]]])\n", "corr2d_multi_in(X, K)" ], "id": "7ac0f17f97b2daa8", "outputs": [ { "data": { "text/plain": [ "tensor([[ 56., 72.],\n", " [104., 120.]])" ] }, "execution_count": 44, "metadata": {}, "output_type": 
"execute_result" } ], "execution_count": 44 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:15.631383700Z", "start_time": "2026-04-22T07:03:15.507675748Z" } }, "cell_type": "code", "source": [ "def corr2d_multi_in_out(X,K) ->torch.Tensor :\n", " return torch.stack([corr2d_multi_in(X,k) for k in K],0)\n" ], "id": "d409110d0d6b4b49", "outputs": [], "execution_count": 45 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:15.955093562Z", "start_time": "2026-04-22T07:03:15.716320703Z" } }, "cell_type": "code", "source": [ "K = torch.stack((K, K + 1, K + 2), 0)\n", "K.shape" ], "id": "4114cd871a627075", "outputs": [ { "data": { "text/plain": [ "torch.Size([3, 2, 2, 2])" ] }, "execution_count": 46, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 46 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:16.110185723Z", "start_time": "2026-04-22T07:03:15.964443180Z" } }, "cell_type": "code", "source": "corr2d_multi_in_out(X, K)", "id": "ce52f41dc9585f8c", "outputs": [ { "data": { "text/plain": [ "tensor([[[ 56., 72.],\n", " [104., 120.]],\n", "\n", " [[ 76., 100.],\n", " [148., 172.]],\n", "\n", " [[ 96., 128.],\n", " [192., 224.]]])" ] }, "execution_count": 47, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 47 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:16.181875274Z", "start_time": "2026-04-22T07:03:16.125158123Z" } }, "cell_type": "code", "source": [ "def corr2d_multi_in_out_1x1(X, K):\n", " h_i,h,w=X.shape\n", " h_o=K.shape[0]\n", " X=X.reshape((h_i,h*w))\n", " print(X.shape)\n", " K=K.reshape((h_o,h_i))\n", " print(K.shape)\n", " Y=torch.matmul(K,X)\n", " return Y.reshape((h_o,h,w))" ], "id": "362d8c692b3c1d75", "outputs": [], "execution_count": 48 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:16.240983284Z", "start_time": "2026-04-22T07:03:16.187079922Z" } }, "cell_type": "code", "source": [ "X = torch.normal(0, 1, (3, 3, 3))\n", "K = 
torch.normal(0, 1, (2, 3, 1, 1))" ], "id": "28e761f677df8b16", "outputs": [], "execution_count": 49 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:16.386442288Z", "start_time": "2026-04-22T07:03:16.245865317Z" } }, "cell_type": "code", "source": "Y1 = corr2d_multi_in_out_1x1(X, K)", "id": "8eb276fed751a6b9", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([3, 9])\n", "torch.Size([2, 3])\n" ] } ], "execution_count": 50 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:16.444217309Z", "start_time": "2026-04-22T07:03:16.422396408Z" } }, "cell_type": "code", "source": [ "Y2 = corr2d_multi_in_out(X, K)\n", "assert float(torch.abs(Y1 - Y2).sum()) < 1e-6" ], "id": "be28e27d30f36e2c", "outputs": [], "execution_count": 51 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:16.504989589Z", "start_time": "2026-04-22T07:03:16.449374426Z" } }, "cell_type": "code", "source": [ "def pool2d(X, pool_size, mode='max'):\n", "    \"\"\"Naive 2-D pooling of a 2-D tensor with stride 1 and no padding.\n", "\n", "    pool_size is (p_h, p_w); mode is 'max' or 'avg'. Any other mode raises\n", "    ValueError instead of silently returning an all-zero result.\n", "    \"\"\"\n", "    # Select the per-window reduction once, outside the loops\n", "    match mode:\n", "        case 'max':\n", "            reduce_window = torch.Tensor.max\n", "        case 'avg':\n", "            reduce_window = torch.Tensor.mean\n", "        case _:\n", "            raise ValueError(f'unknown pooling mode: {mode!r}')\n", "    p_h, p_w = pool_size\n", "    Y = torch.zeros((X.shape[0] - p_h + 1, X.shape[1] - p_w + 1))\n", "    for i in range(Y.shape[0]):\n", "        for j in range(Y.shape[1]):\n", "            Y[i, j] = reduce_window(X[i:i + p_h, j:j + p_w])\n", "    return Y" ], "id": "3c3f71349a2e54c0", "outputs": [], "execution_count": 52 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:16.635764547Z", "start_time": "2026-04-22T07:03:16.510663393Z" } }, "cell_type": "code", "source": [ "X = torch.tensor([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]])\n", "pool2d(X, (2, 2))" ], "id": "a67207c861cf0cfd", "outputs": [ { "data": { "text/plain": [ "tensor([[4., 5.],\n", "        [7., 8.]])" ] }, "execution_count": 53, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 53 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:16.938183687Z", "start_time": "2026-04-22T07:03:16.693509080Z" } }, "cell_type": "code", 
"source": "pool2d(X, (2, 2), 'avg')", "id": "e387b48df3831b85", "outputs": [ { "data": { "text/plain": [ "tensor([[2., 3.],\n", " [5., 6.]])" ] }, "execution_count": 54, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 54 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:17.569350884Z", "start_time": "2026-04-22T07:03:17.198359992Z" } }, "cell_type": "code", "source": [ "X = torch.arange(16, dtype=torch.float32).reshape((1, 1, 4, 4))\n", "X" ], "id": "41b618b3a48522b4", "outputs": [ { "data": { "text/plain": [ "tensor([[[[ 0., 1., 2., 3.],\n", " [ 4., 5., 6., 7.],\n", " [ 8., 9., 10., 11.],\n", " [12., 13., 14., 15.]]]])" ] }, "execution_count": 55, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 55 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:17.840996860Z", "start_time": "2026-04-22T07:03:17.734395761Z" } }, "cell_type": "code", "source": [ "pool2d=nn.MaxPool2d(3)\n", "pool2d(X)" ], "id": "c77484a8d1267259", "outputs": [ { "data": { "text/plain": [ "tensor([[[[10.]]]])" ] }, "execution_count": 56, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 56 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:17.967526280Z", "start_time": "2026-04-22T07:03:17.871240998Z" } }, "cell_type": "code", "source": [ "pool2d = nn.MaxPool2d(3, padding=1, stride=2)\n", "pool2d(X)" ], "id": "847a2bacfb6f2bd7", "outputs": [ { "data": { "text/plain": [ "tensor([[[[ 5., 7.],\n", " [13., 15.]]]])" ] }, "execution_count": 57, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 57 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:18.057661560Z", "start_time": "2026-04-22T07:03:17.998953875Z" } }, "cell_type": "code", "source": [ "pool2d = nn.MaxPool2d((2, 3), stride=(2, 3), padding=(0, 1))\n", "pool2d(X)" ], "id": "5efad1e0b616fff7", "outputs": [ { "data": { "text/plain": [ "tensor([[[[ 5., 7.],\n", " [13., 15.]]]])" ] }, 
"execution_count": 58, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 58 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:18.211656993Z", "start_time": "2026-04-22T07:03:18.081025077Z" } }, "cell_type": "code", "source": [ "X = torch.cat((X, X + 1), 1)\n", "X" ], "id": "386d4b3eb8069328", "outputs": [ { "data": { "text/plain": [ "tensor([[[[ 0., 1., 2., 3.],\n", " [ 4., 5., 6., 7.],\n", " [ 8., 9., 10., 11.],\n", " [12., 13., 14., 15.]],\n", "\n", " [[ 1., 2., 3., 4.],\n", " [ 5., 6., 7., 8.],\n", " [ 9., 10., 11., 12.],\n", " [13., 14., 15., 16.]]]])" ] }, "execution_count": 59, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 59 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:18.254285311Z", "start_time": "2026-04-22T07:03:18.221776346Z" } }, "cell_type": "code", "source": [ "pool2d = nn.MaxPool2d(3, padding=1, stride=2)\n", "pool2d(X)" ], "id": "ba5f57a8ca2a3b06", "outputs": [ { "data": { "text/plain": [ "tensor([[[[ 5., 7.],\n", " [13., 15.]],\n", "\n", " [[ 6., 8.],\n", " [14., 16.]]]])" ] }, "execution_count": 60, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 60 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:18.415322634Z", "start_time": "2026-04-22T07:03:18.283827059Z" } }, "cell_type": "code", "source": [ "net = nn.Sequential(\n", " nn.Conv2d(1,6,kernel_size=5,padding=2), #1*1*28*28 -> 1*6*28*28\n", " nn.Sigmoid(),\n", " nn.AvgPool2d(kernel_size=2, stride=2), #1*6*28*28 -> 1*6*14*14\n", " nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(), #1*6*14*14 -> 1*16*10*10\n", " nn.AvgPool2d(kernel_size=2, stride=2), #1*16*10*10 -> 1*16*5*5\n", " nn.Flatten(),\n", " nn.Linear(16 * 5 * 5, 120), nn.Sigmoid(),\n", " nn.Linear(120, 84), nn.Sigmoid(),\n", " nn.Linear(84, 10)\n", ")\n", "X = torch.rand(size=(1,1,28,28),dtype=torch.float32)\n", "for layer in net:\n", " X=layer(X)\n", " print(layer.__class__.__name__,'output shape: \\t',X.shape)" ], 
"id": "1eabc29f9c838842", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Conv2d output shape: \t torch.Size([1, 6, 28, 28])\n", "Sigmoid output shape: \t torch.Size([1, 6, 28, 28])\n", "AvgPool2d output shape: \t torch.Size([1, 6, 14, 14])\n", "Conv2d output shape: \t torch.Size([1, 16, 10, 10])\n", "Sigmoid output shape: \t torch.Size([1, 16, 10, 10])\n", "AvgPool2d output shape: \t torch.Size([1, 16, 5, 5])\n", "Flatten output shape: \t torch.Size([1, 400])\n", "Linear output shape: \t torch.Size([1, 120])\n", "Sigmoid output shape: \t torch.Size([1, 120])\n", "Linear output shape: \t torch.Size([1, 84])\n", "Sigmoid output shape: \t torch.Size([1, 84])\n", "Linear output shape: \t torch.Size([1, 10])\n" ] } ], "execution_count": 61 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:20.954626561Z", "start_time": "2026-04-22T07:03:18.435497060Z" } }, "cell_type": "code", "source": [ "import d2l.torch as d2l\n", "batch_size = 256\n", "train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)" ], "id": "e372f75817ad4a0f", "outputs": [], "execution_count": 62 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:21.060902242Z", "start_time": "2026-04-22T07:03:21.006364318Z" } }, "cell_type": "code", "source": [ "lr, num_epochs = 0.9, 10\n", "#d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())" ], "id": "9aaeb948f3353955", "outputs": [], "execution_count": 63 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:21.128721611Z", "start_time": "2026-04-22T07:03:21.065860429Z" } }, "cell_type": "code", "source": [ "class Inception(nn.Module):\n", " def __init__(self,in_channels,c1,c2,c3,c4,**kwargs):\n", " super(Inception,self).__init__(**kwargs)\n", " self.p1_1 = nn.Conv2d(in_channels,c1,kernel_size=1)\n", " self.p2_1 = nn.Conv2d(in_channels,c2[0],kernel_size=1)\n", " self.p2_2 = nn.Conv2d(c2[0],c2[1],kernel_size=3,padding=1)\n", " self.p3_1 = 
nn.Conv2d(in_channels,c3[0],kernel_size=1)\n", " self.p3_2 = nn.Conv2d(c3[0],c3[1],kernel_size=5,padding=2)\n", " self.p4_1 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)\n", " self.p4_2 = nn.Conv2d(in_channels, c4, kernel_size=1)\n", " def forward(self,x):\n", " p1 = F.relu(self.p1_1(x))\n", " p2 = F.relu(self.p2_2(F.relu(self.p2_1(x))))\n", " p3 = F.relu(self.p3_2(F.relu(self.p3_1(x))))\n", " p4 = F.relu(self.p4_2(self.p4_1(x)))\n", " return torch.cat((p1,p2,p3,p4),dim=1)" ], "id": "6d3bb3f70f297dba", "outputs": [], "execution_count": 64 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:21.758676619Z", "start_time": "2026-04-22T07:03:21.134168809Z" } }, "cell_type": "code", "source": [ "b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),\n", " nn.ReLU(),\n", " nn.MaxPool2d(kernel_size=3, stride=2, padding=1))\n", "b2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1),\n", " nn.ReLU(),\n", " nn.Conv2d(64, 192, kernel_size=3, padding=1),\n", " nn.ReLU(),\n", " nn.MaxPool2d(kernel_size=3, stride=2, padding=1))\n", "b3 = nn.Sequential(Inception(192, 64, (96, 128), (16, 32), 32),\n", " Inception(256, 128, (128, 192), (32, 96), 64),\n", " nn.MaxPool2d(kernel_size=3, stride=2, padding=1))\n", "b4 = nn.Sequential(Inception(480, 192, (96, 208), (16, 48), 64),\n", " Inception(512, 160, (112, 224), (24, 64), 64),\n", " Inception(512, 128, (128, 256), (24, 64), 64),\n", " Inception(512, 112, (144, 288), (32, 64), 64),\n", " Inception(528, 256, (160, 320), (32, 128), 128),\n", " nn.MaxPool2d(kernel_size=3, stride=2, padding=1))\n", "b5 = nn.Sequential(Inception(832, 256, (160, 320), (32, 128), 128),\n", " Inception(832, 384, (192, 384), (48, 128), 128),\n", " nn.AdaptiveAvgPool2d((1,1)),\n", " nn.Flatten())\n", "net = nn.Sequential(b1, b2, b3, b4, b5, nn.Linear(1024, 10))\n", "X = torch.rand(size=(1, 1, 96, 96))\n", "for layer in net:\n", " X = layer(X)\n", " print(layer.__class__.__name__,'output shape:\\t', X.shape)" ], "id": 
"6ef7022bcb288d65", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sequential output shape:\t torch.Size([1, 64, 24, 24])\n", "Sequential output shape:\t torch.Size([1, 192, 12, 12])\n", "Sequential output shape:\t torch.Size([1, 480, 6, 6])\n", "Sequential output shape:\t torch.Size([1, 832, 3, 3])\n", "Sequential output shape:\t torch.Size([1, 1024])\n", "Linear output shape:\t torch.Size([1, 10])\n" ] } ], "execution_count": 65 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:22.718023030Z", "start_time": "2026-04-22T07:03:21.806456021Z" } }, "cell_type": "code", "source": [ "import torchinfo\n", "torchinfo.summary(net,(1,1,96,96))" ], "id": "acc019ce7afa4470", "outputs": [ { "data": { "text/plain": [ "==========================================================================================\n", "Layer (type:depth-idx) Output Shape Param #\n", "==========================================================================================\n", "Sequential [1, 10] --\n", "├─Sequential: 1-1 [1, 64, 24, 24] --\n", "│ └─Conv2d: 2-1 [1, 64, 48, 48] 3,200\n", "│ └─ReLU: 2-2 [1, 64, 48, 48] --\n", "│ └─MaxPool2d: 2-3 [1, 64, 24, 24] --\n", "├─Sequential: 1-2 [1, 192, 12, 12] --\n", "│ └─Conv2d: 2-4 [1, 64, 24, 24] 4,160\n", "│ └─ReLU: 2-5 [1, 64, 24, 24] --\n", "│ └─Conv2d: 2-6 [1, 192, 24, 24] 110,784\n", "│ └─ReLU: 2-7 [1, 192, 24, 24] --\n", "│ └─MaxPool2d: 2-8 [1, 192, 12, 12] --\n", "├─Sequential: 1-3 [1, 480, 6, 6] --\n", "│ └─Inception: 2-9 [1, 256, 12, 12] --\n", "│ │ └─Conv2d: 3-1 [1, 64, 12, 12] 12,352\n", "│ │ └─Conv2d: 3-2 [1, 96, 12, 12] 18,528\n", "│ │ └─Conv2d: 3-3 [1, 128, 12, 12] 110,720\n", "│ │ └─Conv2d: 3-4 [1, 16, 12, 12] 3,088\n", "│ │ └─Conv2d: 3-5 [1, 32, 12, 12] 12,832\n", "│ │ └─MaxPool2d: 3-6 [1, 192, 12, 12] --\n", "│ │ └─Conv2d: 3-7 [1, 32, 12, 12] 6,176\n", "│ └─Inception: 2-10 [1, 480, 12, 12] --\n", "│ │ └─Conv2d: 3-8 [1, 128, 12, 12] 32,896\n", "│ │ └─Conv2d: 3-9 [1, 128, 12, 12] 32,896\n", "│ │ 
└─Conv2d: 3-10 [1, 192, 12, 12] 221,376\n", "│ │ └─Conv2d: 3-11 [1, 32, 12, 12] 8,224\n", "│ │ └─Conv2d: 3-12 [1, 96, 12, 12] 76,896\n", "│ │ └─MaxPool2d: 3-13 [1, 256, 12, 12] --\n", "│ │ └─Conv2d: 3-14 [1, 64, 12, 12] 16,448\n", "│ └─MaxPool2d: 2-11 [1, 480, 6, 6] --\n", "├─Sequential: 1-4 [1, 832, 3, 3] --\n", "│ └─Inception: 2-12 [1, 512, 6, 6] --\n", "│ │ └─Conv2d: 3-15 [1, 192, 6, 6] 92,352\n", "│ │ └─Conv2d: 3-16 [1, 96, 6, 6] 46,176\n", "│ │ └─Conv2d: 3-17 [1, 208, 6, 6] 179,920\n", "│ │ └─Conv2d: 3-18 [1, 16, 6, 6] 7,696\n", "│ │ └─Conv2d: 3-19 [1, 48, 6, 6] 19,248\n", "│ │ └─MaxPool2d: 3-20 [1, 480, 6, 6] --\n", "│ │ └─Conv2d: 3-21 [1, 64, 6, 6] 30,784\n", "│ └─Inception: 2-13 [1, 512, 6, 6] --\n", "│ │ └─Conv2d: 3-22 [1, 160, 6, 6] 82,080\n", "│ │ └─Conv2d: 3-23 [1, 112, 6, 6] 57,456\n", "│ │ └─Conv2d: 3-24 [1, 224, 6, 6] 226,016\n", "│ │ └─Conv2d: 3-25 [1, 24, 6, 6] 12,312\n", "│ │ └─Conv2d: 3-26 [1, 64, 6, 6] 38,464\n", "│ │ └─MaxPool2d: 3-27 [1, 512, 6, 6] --\n", "│ │ └─Conv2d: 3-28 [1, 64, 6, 6] 32,832\n", "│ └─Inception: 2-14 [1, 512, 6, 6] --\n", "│ │ └─Conv2d: 3-29 [1, 128, 6, 6] 65,664\n", "│ │ └─Conv2d: 3-30 [1, 128, 6, 6] 65,664\n", "│ │ └─Conv2d: 3-31 [1, 256, 6, 6] 295,168\n", "│ │ └─Conv2d: 3-32 [1, 24, 6, 6] 12,312\n", "│ │ └─Conv2d: 3-33 [1, 64, 6, 6] 38,464\n", "│ │ └─MaxPool2d: 3-34 [1, 512, 6, 6] --\n", "│ │ └─Conv2d: 3-35 [1, 64, 6, 6] 32,832\n", "│ └─Inception: 2-15 [1, 528, 6, 6] --\n", "│ │ └─Conv2d: 3-36 [1, 112, 6, 6] 57,456\n", "│ │ └─Conv2d: 3-37 [1, 144, 6, 6] 73,872\n", "│ │ └─Conv2d: 3-38 [1, 288, 6, 6] 373,536\n", "│ │ └─Conv2d: 3-39 [1, 32, 6, 6] 16,416\n", "│ │ └─Conv2d: 3-40 [1, 64, 6, 6] 51,264\n", "│ │ └─MaxPool2d: 3-41 [1, 512, 6, 6] --\n", "│ │ └─Conv2d: 3-42 [1, 64, 6, 6] 32,832\n", "│ └─Inception: 2-16 [1, 832, 6, 6] --\n", "│ │ └─Conv2d: 3-43 [1, 256, 6, 6] 135,424\n", "│ │ └─Conv2d: 3-44 [1, 160, 6, 6] 84,640\n", "│ │ └─Conv2d: 3-45 [1, 320, 6, 6] 461,120\n", "│ │ └─Conv2d: 3-46 [1, 32, 6, 6] 16,928\n", "│ │ 
└─Conv2d: 3-47 [1, 128, 6, 6] 102,528\n", "│ │ └─MaxPool2d: 3-48 [1, 528, 6, 6] --\n", "│ │ └─Conv2d: 3-49 [1, 128, 6, 6] 67,712\n", "│ └─MaxPool2d: 2-17 [1, 832, 3, 3] --\n", "├─Sequential: 1-5 [1, 1024] --\n", "│ └─Inception: 2-18 [1, 832, 3, 3] --\n", "│ │ └─Conv2d: 3-50 [1, 256, 3, 3] 213,248\n", "│ │ └─Conv2d: 3-51 [1, 160, 3, 3] 133,280\n", "│ │ └─Conv2d: 3-52 [1, 320, 3, 3] 461,120\n", "│ │ └─Conv2d: 3-53 [1, 32, 3, 3] 26,656\n", "│ │ └─Conv2d: 3-54 [1, 128, 3, 3] 102,528\n", "│ │ └─MaxPool2d: 3-55 [1, 832, 3, 3] --\n", "│ │ └─Conv2d: 3-56 [1, 128, 3, 3] 106,624\n", "│ └─Inception: 2-19 [1, 1024, 3, 3] --\n", "│ │ └─Conv2d: 3-57 [1, 384, 3, 3] 319,872\n", "│ │ └─Conv2d: 3-58 [1, 192, 3, 3] 159,936\n", "│ │ └─Conv2d: 3-59 [1, 384, 3, 3] 663,936\n", "│ │ └─Conv2d: 3-60 [1, 48, 3, 3] 39,984\n", "│ │ └─Conv2d: 3-61 [1, 128, 3, 3] 153,728\n", "│ │ └─MaxPool2d: 3-62 [1, 832, 3, 3] --\n", "│ │ └─Conv2d: 3-63 [1, 128, 3, 3] 106,624\n", "│ └─AdaptiveAvgPool2d: 2-20 [1, 1024, 1, 1] --\n", "│ └─Flatten: 2-21 [1, 1024] --\n", "├─Linear: 1-6 [1, 10] 10,250\n", "==========================================================================================\n", "Total params: 5,977,530\n", "Trainable params: 5,977,530\n", "Non-trainable params: 0\n", "Total mult-adds (Units.MEGABYTES): 276.66\n", "==========================================================================================\n", "Input size (MB): 0.04\n", "Forward/backward pass size (MB): 4.74\n", "Params size (MB): 23.91\n", "Estimated Total Size (MB): 28.69\n", "==========================================================================================" ] }, "execution_count": 66, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 66 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:22.823562168Z", "start_time": "2026-04-22T07:03:22.774952520Z" } }, "cell_type": "code", "source": [ "lr, num_epochs, batch_size = 0.1, 10, 128\n", "train_iter, test_iter = 
d2l.load_data_fashion_mnist(batch_size, resize=96)\n", "#d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())" ], "id": "3760a5e5813405f7", "outputs": [], "execution_count": 67 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:22.875034764Z", "start_time": "2026-04-22T07:03:22.825694238Z" } }, "cell_type": "code", "source": [ "class Residual(nn.Module):\n", " def __init__(self,input_channels,num_channels,use_1x1conv=False,strides=1):\n", " super().__init__()\n", " self.conv1 = nn.Conv2d(input_channels,num_channels,kernel_size=3,padding=1,stride=strides)\n", " self.conv2 = nn.Conv2d(num_channels,num_channels,kernel_size=3,padding=1)\n", " if use_1x1conv:\n", " self.conv3 = nn.Conv2d(input_channels,num_channels,kernel_size=1,stride=strides)\n", " else:\n", " self.conv3= None\n", " self.bn1=nn.BatchNorm2d(num_channels)\n", " self.bn2=nn.BatchNorm2d(num_channels)\n", " def forward(self,X):\n", " Y=F.relu(self.bn1(self.conv1(X)))\n", " Y=self.bn2(self.conv2(Y))\n", " if self.conv3:\n", " X = self.conv3(X)\n", " Y+=X\n", " return F.relu(Y)\n" ], "id": "9300979845ba6916", "outputs": [], "execution_count": 68 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:22.926048268Z", "start_time": "2026-04-22T07:03:22.876715194Z" } }, "cell_type": "code", "source": [ "blk = Residual(3,3)\n", "X = torch.rand(4, 3, 6, 6)" ], "id": "1248323517ff3228", "outputs": [], "execution_count": 69 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:22.992544732Z", "start_time": "2026-04-22T07:03:22.927760279Z" } }, "cell_type": "code", "source": [ "blk = Residual(3,6, use_1x1conv=True, strides=2)\n", "blk(X).shape" ], "id": "82cdbd71a157b51c", "outputs": [ { "data": { "text/plain": [ "torch.Size([4, 6, 3, 3])" ] }, "execution_count": 70, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 70 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:23.041967126Z", "start_time": 
"2026-04-22T07:03:22.993777596Z" } }, "cell_type": "code", "source": [ "b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),\n", "nn.BatchNorm2d(64), nn.ReLU(),\n", "nn.MaxPool2d(kernel_size=3, stride=2, padding=1))" ], "id": "727da1d2d363ac62", "outputs": [], "execution_count": 71 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:23.093341944Z", "start_time": "2026-04-22T07:03:23.044341294Z" } }, "cell_type": "code", "source": [ "def resnet_block(input_channels, num_channels, num_residuals,\n", " first_block=False):\n", " blk = []\n", " for i in range(num_residuals):\n", " if i == 0 and not first_block:\n", " blk.append(Residual(input_channels, num_channels,\n", " use_1x1conv=True, strides=2))\n", " else:\n", " blk.append(Residual(num_channels, num_channels))\n", " return blk" ], "id": "124134971f8441c0", "outputs": [], "execution_count": 72 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:23.143917227Z", "start_time": "2026-04-22T07:03:23.095134233Z" } }, "cell_type": "code", "source": [ "b2 = nn.Sequential(*resnet_block(64, 64, 2, first_block=True))\n", "b3 = nn.Sequential(*resnet_block(64, 128, 2))\n", "b4 = nn.Sequential(*resnet_block(128, 256, 2))\n", "b5 = nn.Sequential(*resnet_block(256, 512, 2))" ], "id": "ca1f1c69fba3e913", "outputs": [], "execution_count": 73 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:23.192826666Z", "start_time": "2026-04-22T07:03:23.145253995Z" } }, "cell_type": "code", "source": [ "net = nn.Sequential(b1, b2, b3, b4, b5,\n", "nn.AdaptiveAvgPool2d((1,1)),\n", "nn.Flatten(), nn.Linear(512, 10))" ], "id": "f21db27de5dbdec1", "outputs": [], "execution_count": 74 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:23.412254076Z", "start_time": "2026-04-22T07:03:23.195173742Z" } }, "cell_type": "code", "source": [ "X = torch.rand(size=(1, 1, 224, 224))\n", "for layer in net:\n", " X = layer(X)\n", " print(layer.__class__.__name__,'output shape:\\t', 
X.shape)" ], "id": "6f8851a2bfd18c4e", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sequential output shape:\t torch.Size([1, 64, 56, 56])\n", "Sequential output shape:\t torch.Size([1, 64, 56, 56])\n", "Sequential output shape:\t torch.Size([1, 128, 28, 28])\n", "Sequential output shape:\t torch.Size([1, 256, 14, 14])\n", "Sequential output shape:\t torch.Size([1, 512, 7, 7])\n", "AdaptiveAvgPool2d output shape:\t torch.Size([1, 512, 1, 1])\n", "Flatten output shape:\t torch.Size([1, 512])\n", "Linear output shape:\t torch.Size([1, 10])\n" ] } ], "execution_count": 75 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:23.462272089Z", "start_time": "2026-04-22T07:03:23.415437344Z" } }, "cell_type": "code", "source": [ "lr, num_epochs, batch_size = 0.05, 10, 256\n", "train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)\n", "#d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())" ], "id": "e095d74b29dffef6", "outputs": [], "execution_count": 76 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:23.541760019Z", "start_time": "2026-04-22T07:03:23.464100823Z" } }, "cell_type": "code", "source": [ "import torch\n", "import d2l.torch as d2l\n", "import numpy\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "print(torch.version.__version__)" ], "id": "3fd6d22221f87bea", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2.10.0+cu128\n" ] } ], "execution_count": 77 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:23.592662909Z", "start_time": "2026-04-22T07:03:23.543067093Z" } }, "cell_type": "code", "source": [ "A=torch.Tensor([[1,2,0,0],[0,2,0,0],[0,0,2,1],[0,0,0,3]])\n", "C=torch.Tensor([[1,0,0,0],[0,1,0,0],[0,0,-2,3],[0,0,0,-3]])" ], "id": "254f5d3d659dbe0f", "outputs": [], "execution_count": 78 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:23.643655879Z", "start_time": "2026-04-22T07:03:23.594970016Z" } 
}, "cell_type": "code", "source": "B=torch.Tensor([[2,0,0,0],[-2,1,0,0],[0,0,-3,0],[0,0,0,-3]])", "id": "a13d9c27c2fdbfad", "outputs": [], "execution_count": 79 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:23.700556244Z", "start_time": "2026-04-22T07:03:23.645948395Z" } }, "cell_type": "code", "source": "torch.mm(A,C)", "id": "e513a37beaa85f8f", "outputs": [ { "data": { "text/plain": [ "tensor([[ 1., 2., 0., 0.],\n", " [ 0., 2., 0., 0.],\n", " [ 0., 0., -4., 3.],\n", " [ 0., 0., 0., -9.]])" ] }, "execution_count": 80, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 80 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:23.827304753Z", "start_time": "2026-04-22T07:03:23.751870264Z" } }, "cell_type": "code", "source": "torch.det(torch.mm(torch.mm(A,C),B))", "id": "9a85eceac652875f", "outputs": [ { "data": { "text/plain": [ "tensor(1296.)" ] }, "execution_count": 81, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 81 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:23.931211184Z", "start_time": "2026-04-22T07:03:23.879693736Z" } }, "cell_type": "code", "source": "1296**5\n", "id": "6dc27d79722da58f", "outputs": [ { "data": { "text/plain": [ "3656158440062976" ] }, "execution_count": 82, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 82 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:23.997938663Z", "start_time": "2026-04-22T07:03:23.945902637Z" } }, "cell_type": "code", "source": "torch.mm(C,B)", "id": "ec5a170d775f4705", "outputs": [ { "data": { "text/plain": [ "tensor([[ 2., 0., 0., 0.],\n", " [-2., 1., 0., 0.],\n", " [ 0., 0., 6., -9.],\n", " [ 0., 0., 0., 9.]])" ] }, "execution_count": 83, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 83 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:24.170960756Z", "start_time": "2026-04-22T07:03:24.067057874Z" } }, "cell_type": "code", 
"source": [ "T = 1000 # 总共产生1000个点\n", "time = torch.arange(1, T + 1, dtype=torch.float32)\n", "x = torch.sin(0.01 * time) + torch.normal(0, 0.2, (T,))\n", "d2l.plot(time, [x], 'time', 'x', xlim=[1, 1000], figsize=(6, 3))" ], "id": "c3884b10464c6baa", "outputs": [ { "data": { "text/plain": [ "
" ], "image/svg+xml": "\n\n\n \n \n \n \n 2026-04-22T15:03:24.145178\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" }, "metadata": {}, "output_type": "display_data", "jetTransient": { "display_id": null } } ], "execution_count": 84 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:24.309169705Z", "start_time": "2026-04-22T07:03:24.235365355Z" } }, "cell_type": "code", "source": [ "tau = 4\n", "features = torch.zeros((T - tau, tau))\n", "for i in range(tau):\n", " features[:, i] = x[i: T - tau + i]\n", "labels = x[tau:].reshape((-1, 1))\n", "x,features,labels\n" ], "id": "5d450c6f6b724a14", "outputs": [ { "data": { "text/plain": [ "(tensor([-0.0948, 0.2143, -0.2523, -0.1235, -0.1826, 0.1189, -0.1963, 0.2347,\n", " 0.1456, -0.1118, 0.3787, 0.3861, 0.2881, 0.1958, 0.0402, 0.0816,\n", " 0.4793, 0.0351, 0.2378, 0.1459, 0.1108, 0.2544, -0.0127, 0.0733,\n", " 0.3156, 0.0257, 0.3207, 0.3259, 0.3693, 0.0584, 0.1730, 0.3100,\n", " 0.2328, 0.0525, 0.4465, 0.1293, 0.4330, 0.3193, 0.4704, 0.5238,\n", " 0.5323, 
0.4887, 0.0831, 0.5924, 0.6972, 0.3490, 0.7476, 0.6039,\n", " 0.9995, 0.1455, 0.1417, 0.5968, 0.6673, 0.3425, 0.7685, 0.4904,\n", " 0.2203, 0.2109, 0.4600, 0.5055, 0.3558, 0.7020, 0.7435, 0.4713,\n", " 0.4318, 0.5861, 0.3592, 0.7750, 0.6640, 0.7908, 0.2776, 0.5868,\n", " 0.6283, 0.3461, 0.6308, 0.7547, 0.5564, 0.7181, 0.7852, 0.7823,\n", " 0.7238, 0.9294, 0.9023, 0.8100, 0.5561, 0.7124, 1.1566, 0.7628,\n", " 0.9630, 0.4425, 1.0628, 0.7014, 0.4439, 0.7286, 0.8099, 0.5786,\n", " 1.0638, 0.9519, 0.8388, 1.2088, 0.9172, 0.7014, 0.5667, 0.6040,\n", " 0.5549, 0.7959, 0.9167, 0.9074, 0.6108, 0.8999, 0.9197, 0.8539,\n", " 0.6566, 0.9941, 0.6902, 0.8782, 1.4898, 0.9888, 1.1911, 0.5683,\n", " 0.8868, 0.7122, 0.8960, 1.1454, 1.2660, 1.0001, 0.6582, 0.9706,\n", " 1.0110, 1.0355, 0.9761, 0.9439, 1.0824, 1.4095, 1.2544, 0.6541,\n", " 1.0486, 1.0638, 0.9000, 0.9835, 1.2558, 1.1702, 0.8466, 0.7696,\n", " 1.2446, 1.1460, 0.9258, 0.8150, 1.1086, 0.9475, 0.9675, 0.7330,\n", " 1.1263, 1.1718, 0.9413, 1.0272, 0.7733, 0.9831, 0.8759, 0.7970,\n", " 0.6360, 1.1815, 0.9689, 0.6976, 0.9265, 0.8338, 0.7960, 0.7705,\n", " 1.2601, 1.2775, 0.7706, 1.0216, 1.1916, 0.8603, 0.9864, 1.0777,\n", " 0.8930, 1.0063, 0.8376, 0.9923, 0.8081, 0.8020, 1.1461, 1.1018,\n", " 0.8931, 1.0005, 0.8635, 0.7197, 1.2577, 1.0584, 1.4032, 0.8911,\n", " 1.1415, 0.8241, 0.7946, 1.0221, 0.8792, 0.7211, 1.1821, 0.8079,\n", " 0.8926, 1.0765, 0.9949, 0.9159, 0.7329, 0.9950, 0.7491, 0.8750,\n", " 1.1863, 1.0095, 0.8046, 0.6274, 0.8936, 0.7595, 0.8423, 0.8655,\n", " 0.6918, 0.7347, 1.1179, 0.5931, 0.8745, 0.4858, 0.9338, 1.1382,\n", " 0.6084, 0.9479, 0.8726, 0.7202, 0.9596, 0.4386, 1.2525, 0.5120,\n", " 0.7222, 0.6566, 0.8965, 0.7545, 1.1104, 0.6634, 0.5654, 1.0095,\n", " 0.6558, 0.7260, 0.8515, 0.3430, 0.7703, 0.3753, 0.4490, 0.4373,\n", " 0.8283, 0.5455, 0.7584, 0.8197, 0.4781, 0.3350, 0.6714, 0.3969,\n", " 0.7131, 0.5609, 0.4327, 0.4293, 0.3552, 0.5445, 0.5609, 0.4110,\n", " 0.8525, 0.3402, 0.4064, 0.5172, 0.5845, 0.6185, 
0.4719, 0.9092,\n", " 0.6964, 0.7267, 0.6934, 0.4337, 0.2031, 0.0898, 0.4377, 0.4203,\n", " 0.2855, 0.4673, 0.6029, 0.4368, 0.1521, -0.0606, 0.2532, 0.4365,\n", " 0.2989, 0.0743, 0.2734, 0.1060, 0.5543, -0.1211, 0.0968, 0.4911,\n", " 0.5107, 0.4583, 0.2777, -0.0513, 0.1437, 0.0548, 0.0933, 0.1172,\n", " 0.0718, 0.4027, 0.1805, 0.0869, -0.3066, 0.5615, -0.2721, -0.2765,\n", " 0.0850, -0.1473, -0.1622, -0.1335, 0.1328, 0.0703, -0.6712, -0.1121,\n", " -0.1208, 0.0092, -0.0805, -0.2017, 0.2339, -0.3533, -0.4598, 0.0620,\n", " -0.5254, 0.0197, -0.0593, -0.1914, -0.4259, -0.0115, -0.5406, 0.0137,\n", " -0.4240, -0.2822, -0.0796, -0.3495, -0.4475, 0.2453, -0.3729, -0.4086,\n", " -0.2618, -0.4539, -0.6140, -0.2483, -0.4165, -0.3736, -0.0737, -0.0212,\n", " -0.3644, -0.0472, -0.4087, -0.6794, -0.5921, -0.5632, -0.5971, -0.1935,\n", " -0.9543, -0.7976, -0.3485, -0.9538, -0.6171, -0.7755, -0.4651, -0.8194,\n", " -0.3005, -0.5191, -0.5902, -0.2464, -0.6908, -0.5054, -0.5528, -1.1089,\n", " -0.7206, -0.8067, -0.6780, -0.2981, -0.6683, -0.4324, -0.8497, -0.5928,\n", " -0.7203, -0.3751, -0.7423, -0.4109, -0.7345, -0.6653, -0.5752, -0.5198,\n", " -0.7046, -1.1754, -0.9447, -0.7304, -0.6510, -0.5954, -0.7592, -0.5285,\n", " -0.4249, -0.7993, -1.3758, -0.6218, -1.0691, -0.5775, -0.8174, -0.7021,\n", " -0.7784, -0.7553, -1.2137, -0.7302, -0.7253, -0.6819, -1.3077, -1.3472,\n", " -0.7104, -0.8387, -0.5973, -0.8619, -1.1138, -1.1314, -0.9765, -1.2121,\n", " -0.8168, -0.7763, -1.2988, -0.9282, -1.1715, -0.7216, -0.7182, -0.2972,\n", " -0.7471, -1.0089, -1.1431, -1.0396, -1.0381, -0.5979, -0.7363, -0.7808,\n", " -0.9106, -1.1468, -1.1357, -0.6406, -0.9603, -1.2653, -1.5958, -1.0592,\n", " -0.9698, -0.8252, -1.2515, -1.0474, -1.1103, -1.0035, -0.6669, -0.9120,\n", " -0.9146, -1.1079, -0.8379, -0.9123, -0.5831, -1.6515, -0.9385, -1.0699,\n", " -1.1498, -0.7861, -0.8942, -1.0452, -1.0064, -0.9116, -1.1150, -0.7801,\n", " -1.0283, -1.0296, -1.0927, -0.7945, -1.0705, -1.3215, -1.2510, 
-0.9158,\n", " -0.9377, -0.7314, -0.9773, -1.1910, -1.0539, -1.1439, -1.0784, -0.8543,\n", " -1.1323, -1.3193, -0.8014, -0.7318, -0.5805, -0.8239, -1.1228, -1.0473,\n", " -0.8206, -0.6544, -1.2654, -1.0757, -0.5389, -0.9908, -0.7894, -0.7463,\n", " -1.0391, -0.8023, -0.8568, -1.2414, -0.9595, -1.1151, -0.9689, -1.1145,\n", " -0.6853, -0.7547, -1.1000, -0.9054, -1.2262, -1.1359, -1.0174, -0.3782,\n", " -0.8056, -1.1828, -0.8426, -0.9958, -0.9495, -1.2745, -0.7039, -0.5893,\n", " -0.5648, -1.0538, -0.6724, -0.6340, -0.5070, -1.0956, -1.0957, -0.6823,\n", " -0.5258, -0.5777, -0.9268, -0.5280, -0.5989, -0.8364, -0.7439, -0.7619,\n", " -1.0159, -1.0627, -0.9416, -0.6270, -0.4307, -0.8575, -1.0748, -0.5529,\n", " -0.9339, -0.7416, -0.6674, -0.3178, -0.6815, -0.7499, -0.6359, -0.8157,\n", " -0.5582, -0.5083, -0.4527, -0.8350, -0.6317, -0.4338, -0.4875, -0.4046,\n", " -0.3166, -0.3413, -0.4722, -0.7010, -1.2025, -0.2133, -0.3133, -0.4160,\n", " -0.6681, -0.8990, -0.5464, -0.4518, -0.4402, -0.5246, -0.4561, -0.6747,\n", " -0.1833, -0.4466, -0.4671, -0.5509, -0.6235, -0.2100, -0.3368, -0.3083,\n", " -0.5129, -0.2880, -0.4075, -0.2784, 0.0631, -0.4355, -0.4237, -0.2578,\n", " -0.1380, -0.5085, 0.1004, -0.1426, -0.2537, -0.1756, -0.2135, -0.1898,\n", " -0.2947, -0.3934, -0.3412, -0.3343, -0.1450, -0.3178, -0.2156, -0.3232,\n", " -0.3691, -0.2711, 0.1086, -0.2257, -0.0752, -0.0339, -0.0636, 0.0626,\n", " -0.1460, 0.0792, 0.1529, 0.4743, 0.0343, -0.0158, -0.1255, -0.4698,\n", " -0.0489, 0.2622, 0.0619, -0.2243, -0.1318, 0.0214, 0.2690, 0.0497,\n", " 0.3451, -0.1116, 0.0173, 0.0708, 0.4135, 0.3188, 0.4808, -0.0340,\n", " 0.4786, 0.4896, 0.1077, 0.3500, 0.1309, 0.1398, 0.1943, 0.1651,\n", " 0.3227, 0.5541, 0.2688, 0.1892, 0.2509, 0.2078, -0.0140, 0.2443,\n", " 0.3204, 0.5485, 0.4234, 0.3135, 0.4633, 0.0029, 0.2174, 0.6879,\n", " 0.5089, 0.2479, 0.8608, 0.4307, 0.6205, 0.3482, 0.6469, 0.4475,\n", " 0.6595, 0.3450, 0.3781, 0.4451, 0.1883, 0.6707, 0.8667, 0.5218,\n", " 0.4004, 
0.5271, 0.6446, 0.7222, 0.5722, 0.7676, 0.6824, 0.1981,\n", " 0.8089, 0.6296, 0.6748, 0.7515, 0.5103, 0.9052, 0.8405, 0.9092,\n", " 0.6918, 0.6477, 0.5402, 0.6477, 0.4210, 0.6973, 0.6019, 0.5364,\n", " 0.8134, 0.5607, 0.7096, 0.5894, 0.3866, 1.0600, 0.7347, 0.8129,\n", " 1.2088, 0.8825, 0.7179, 1.0115, 0.7013, 1.0128, 0.9747, 1.2759,\n", " 0.7655, 1.0094, 0.7805, 0.6091, 1.2033, 0.9678, 0.8219, 0.8157,\n", " 0.9188, 0.7436, 0.8910, 0.7291, 0.9559, 0.9389, 1.2030, 1.0495,\n", " 1.1811, 0.8884, 0.8390, 0.9894, 0.9238, 0.7628, 0.5421, 1.5147,\n", " 0.6971, 0.6740, 0.8342, 0.6554, 0.7455, 0.6916, 1.2706, 1.1277,\n", " 0.9248, 0.9976, 1.2404, 0.6919, 1.3449, 1.1243, 1.0492, 0.9266,\n", " 1.1194, 1.0304, 1.1323, 1.2372, 0.8300, 1.1916, 1.0923, 0.8313,\n", " 0.8572, 1.1128, 1.0047, 1.1544, 0.9745, 1.0503, 0.9171, 0.8073,\n", " 1.2056, 1.0976, 0.9910, 1.1834, 1.1389, 0.9142, 0.9367, 1.0121,\n", " 0.7704, 1.0558, 0.7306, 0.8117, 0.7061, 1.2315, 0.9015, 0.9339,\n", " 0.5016, 0.9227, 1.2568, 0.9444, 1.1198, 0.9431, 1.0997, 1.3078,\n", " 0.8336, 1.2692, 0.8424, 0.8702, 1.4820, 1.3248, 0.9324, 0.6538,\n", " 1.2011, 1.0170, 0.7863, 1.0178, 0.6519, 0.5970, 0.9052, 0.6846,\n", " 0.7737, 0.9104, 0.8439, 1.0066, 1.0787, 0.9661, 0.9923, 0.7922,\n", " 0.8316, 0.9553, 0.9952, 0.8680, 1.1226, 0.8213, 0.9151, 0.7748,\n", " 0.9953, 0.7773, 0.7916, 0.7321, 0.9130, 1.1433, 0.7060, 0.8066,\n", " 0.8709, 0.7426, 0.8718, 1.0973, 0.7097, 0.9438, 0.8164, 0.8013,\n", " 0.6236, 0.7180, 0.9188, 0.8016, 0.9741, 0.6271, 0.5747, 0.8007,\n", " 0.7754, 0.4877, 0.4746, 0.8654, 0.4743, 0.9015, 0.8082, 0.5449,\n", " 0.9299, 0.2003, 0.5466, 0.4355, 0.7900, 0.4343, 0.7224, 0.8585,\n", " 0.5714, 0.5306, 0.6594, 0.0640, 0.3203, 0.5463, 0.5048, 0.1935,\n", " 0.2883, 0.6778, 0.5014, 0.5235, 0.5718, 0.4587, 0.2808, 0.4073,\n", " 0.8632, 0.8862, 0.5757, 0.3372, 0.2566, 0.7858, 0.3713, 0.1589,\n", " 0.3243, 0.4270, 0.0565, 0.2885, 0.3257, 0.2196, 0.3159, 0.2361,\n", " 0.1087, 0.2224, 0.2633, 0.5037, 0.1980, 0.1530, 
0.2780, -0.1399,\n", " 0.5331, 0.3530, 0.3342, 0.2098, -0.0165, 0.1318, 0.4510, -0.1959,\n", " 0.0966, 0.0789, 0.3381, -0.1917, 0.1518, 0.3640, 0.0956, 0.2535,\n", " -0.3988, -0.3479, 0.3864, -0.2639, -0.2368, 0.0258, 0.2441, 0.0687,\n", " 0.0457, 0.2286, -0.0947, -0.1189, 0.1360, -0.0990, -0.2447, 0.2135,\n", " -0.1830, -0.4583, -0.1795, -0.1361, -0.0553, -0.2864, -0.2307, -0.4651,\n", " -0.1889, -0.3185, -0.5318, -0.3012, 0.0062, 0.1046, -0.2321, -0.2945,\n", " -0.0242, -0.0586, -0.2307, -0.2479, -0.0382, -0.1509, -0.5055, -0.3759,\n", " 0.2139, -0.2129, -0.3605, -0.5222, -0.6530, -0.6716, -0.4330, -0.2577,\n", " -0.2672, -0.1297, -0.9203, -0.5832, -0.2640, -0.4996, -0.2625, -0.4407,\n", " -0.8864, -0.2508, -0.4827, -0.3131, -0.2570, -0.7116, -0.5357, -0.7074]),\n", " tensor([[-0.0948, 0.2143, -0.2523, -0.1235],\n", " [ 0.2143, -0.2523, -0.1235, -0.1826],\n", " [-0.2523, -0.1235, -0.1826, 0.1189],\n", " ...,\n", " [-0.2508, -0.4827, -0.3131, -0.2570],\n", " [-0.4827, -0.3131, -0.2570, -0.7116],\n", " [-0.3131, -0.2570, -0.7116, -0.5357]]),\n", " tensor([[-0.1826],\n", " [ 0.1189],\n", " [-0.1963],\n", " [ 0.2347],\n", " [ 0.1456],\n", " [-0.1118],\n", " [ 0.3787],\n", " [ 0.3861],\n", " [ 0.2881],\n", " [ 0.1958],\n", " [ 0.0402],\n", " [ 0.0816],\n", " [ 0.4793],\n", " [ 0.0351],\n", " [ 0.2378],\n", " [ 0.1459],\n", " [ 0.1108],\n", " [ 0.2544],\n", " [-0.0127],\n", " [ 0.0733],\n", " [ 0.3156],\n", " [ 0.0257],\n", " [ 0.3207],\n", " [ 0.3259],\n", " [ 0.3693],\n", " [ 0.0584],\n", " [ 0.1730],\n", " [ 0.3100],\n", " [ 0.2328],\n", " [ 0.0525],\n", " [ 0.4465],\n", " [ 0.1293],\n", " [ 0.4330],\n", " [ 0.3193],\n", " [ 0.4704],\n", " [ 0.5238],\n", " [ 0.5323],\n", " [ 0.4887],\n", " [ 0.0831],\n", " [ 0.5924],\n", " [ 0.6972],\n", " [ 0.3490],\n", " [ 0.7476],\n", " [ 0.6039],\n", " [ 0.9995],\n", " [ 0.1455],\n", " [ 0.1417],\n", " [ 0.5968],\n", " [ 0.6673],\n", " [ 0.3425],\n", " [ 0.7685],\n", " [ 0.4904],\n", " [ 0.2203],\n", " [ 0.2109],\n", " [ 
0.4600],\n", " [ 0.5055],\n", " [ 0.3558],\n", " [ 0.7020],\n", " [ 0.7435],\n", " [ 0.4713],\n", " [ 0.4318],\n", " [ 0.5861],\n", " [ 0.3592],\n", " [ 0.7750],\n", " [ 0.6640],\n", " [ 0.7908],\n", " [ 0.2776],\n", " [ 0.5868],\n", " [ 0.6283],\n", " [ 0.3461],\n", " [ 0.6308],\n", " [ 0.7547],\n", " [ 0.5564],\n", " [ 0.7181],\n", " [ 0.7852],\n", " [ 0.7823],\n", " [ 0.7238],\n", " [ 0.9294],\n", " [ 0.9023],\n", " [ 0.8100],\n", " [ 0.5561],\n", " [ 0.7124],\n", " [ 1.1566],\n", " [ 0.7628],\n", " [ 0.9630],\n", " [ 0.4425],\n", " [ 1.0628],\n", " [ 0.7014],\n", " [ 0.4439],\n", " [ 0.7286],\n", " [ 0.8099],\n", " [ 0.5786],\n", " [ 1.0638],\n", " [ 0.9519],\n", " [ 0.8388],\n", " [ 1.2088],\n", " [ 0.9172],\n", " [ 0.7014],\n", " [ 0.5667],\n", " [ 0.6040],\n", " [ 0.5549],\n", " [ 0.7959],\n", " [ 0.9167],\n", " [ 0.9074],\n", " [ 0.6108],\n", " [ 0.8999],\n", " [ 0.9197],\n", " [ 0.8539],\n", " [ 0.6566],\n", " [ 0.9941],\n", " [ 0.6902],\n", " [ 0.8782],\n", " [ 1.4898],\n", " [ 0.9888],\n", " [ 1.1911],\n", " [ 0.5683],\n", " [ 0.8868],\n", " [ 0.7122],\n", " [ 0.8960],\n", " [ 1.1454],\n", " [ 1.2660],\n", " [ 1.0001],\n", " [ 0.6582],\n", " [ 0.9706],\n", " [ 1.0110],\n", " [ 1.0355],\n", " [ 0.9761],\n", " [ 0.9439],\n", " [ 1.0824],\n", " [ 1.4095],\n", " [ 1.2544],\n", " [ 0.6541],\n", " [ 1.0486],\n", " [ 1.0638],\n", " [ 0.9000],\n", " [ 0.9835],\n", " [ 1.2558],\n", " [ 1.1702],\n", " [ 0.8466],\n", " [ 0.7696],\n", " [ 1.2446],\n", " [ 1.1460],\n", " [ 0.9258],\n", " [ 0.8150],\n", " [ 1.1086],\n", " [ 0.9475],\n", " [ 0.9675],\n", " [ 0.7330],\n", " [ 1.1263],\n", " [ 1.1718],\n", " [ 0.9413],\n", " [ 1.0272],\n", " [ 0.7733],\n", " [ 0.9831],\n", " [ 0.8759],\n", " [ 0.7970],\n", " [ 0.6360],\n", " [ 1.1815],\n", " [ 0.9689],\n", " [ 0.6976],\n", " [ 0.9265],\n", " [ 0.8338],\n", " [ 0.7960],\n", " [ 0.7705],\n", " [ 1.2601],\n", " [ 1.2775],\n", " [ 0.7706],\n", " [ 1.0216],\n", " [ 1.1916],\n", " [ 0.8603],\n", " [ 0.9864],\n", " [ 
1.0777],\n", " [ 0.8930],\n", " [ 1.0063],\n", " [ 0.8376],\n", " [ 0.9923],\n", " [ 0.8081],\n", " [ 0.8020],\n", " [ 1.1461],\n", " [ 1.1018],\n", " [ 0.8931],\n", " [ 1.0005],\n", " [ 0.8635],\n", " [ 0.7197],\n", " [ 1.2577],\n", " [ 1.0584],\n", " [ 1.4032],\n", " [ 0.8911],\n", " [ 1.1415],\n", " [ 0.8241],\n", " [ 0.7946],\n", " [ 1.0221],\n", " [ 0.8792],\n", " [ 0.7211],\n", " [ 1.1821],\n", " [ 0.8079],\n", " [ 0.8926],\n", " [ 1.0765],\n", " [ 0.9949],\n", " [ 0.9159],\n", " [ 0.7329],\n", " [ 0.9950],\n", " [ 0.7491],\n", " [ 0.8750],\n", " [ 1.1863],\n", " [ 1.0095],\n", " [ 0.8046],\n", " [ 0.6274],\n", " [ 0.8936],\n", " [ 0.7595],\n", " [ 0.8423],\n", " [ 0.8655],\n", " [ 0.6918],\n", " [ 0.7347],\n", " [ 1.1179],\n", " [ 0.5931],\n", " [ 0.8745],\n", " [ 0.4858],\n", " [ 0.9338],\n", " [ 1.1382],\n", " [ 0.6084],\n", " [ 0.9479],\n", " [ 0.8726],\n", " [ 0.7202],\n", " [ 0.9596],\n", " [ 0.4386],\n", " [ 1.2525],\n", " [ 0.5120],\n", " [ 0.7222],\n", " [ 0.6566],\n", " [ 0.8965],\n", " [ 0.7545],\n", " [ 1.1104],\n", " [ 0.6634],\n", " [ 0.5654],\n", " [ 1.0095],\n", " [ 0.6558],\n", " [ 0.7260],\n", " [ 0.8515],\n", " [ 0.3430],\n", " [ 0.7703],\n", " [ 0.3753],\n", " [ 0.4490],\n", " [ 0.4373],\n", " [ 0.8283],\n", " [ 0.5455],\n", " [ 0.7584],\n", " [ 0.8197],\n", " [ 0.4781],\n", " [ 0.3350],\n", " [ 0.6714],\n", " [ 0.3969],\n", " [ 0.7131],\n", " [ 0.5609],\n", " [ 0.4327],\n", " [ 0.4293],\n", " [ 0.3552],\n", " [ 0.5445],\n", " [ 0.5609],\n", " [ 0.4110],\n", " [ 0.8525],\n", " [ 0.3402],\n", " [ 0.4064],\n", " [ 0.5172],\n", " [ 0.5845],\n", " [ 0.6185],\n", " [ 0.4719],\n", " [ 0.9092],\n", " [ 0.6964],\n", " [ 0.7267],\n", " [ 0.6934],\n", " [ 0.4337],\n", " [ 0.2031],\n", " [ 0.0898],\n", " [ 0.4377],\n", " [ 0.4203],\n", " [ 0.2855],\n", " [ 0.4673],\n", " [ 0.6029],\n", " [ 0.4368],\n", " [ 0.1521],\n", " [-0.0606],\n", " [ 0.2532],\n", " [ 0.4365],\n", " [ 0.2989],\n", " [ 0.0743],\n", " [ 0.2734],\n", " [ 0.1060],\n", " [ 
0.5543],\n", " [-0.1211],\n", " [ 0.0968],\n", " [ 0.4911],\n", " [ 0.5107],\n", " [ 0.4583],\n", " [ 0.2777],\n", " [-0.0513],\n", " [ 0.1437],\n", " [ 0.0548],\n", " [ 0.0933],\n", " [ 0.1172],\n", " [ 0.0718],\n", " [ 0.4027],\n", " [ 0.1805],\n", " [ 0.0869],\n", " [-0.3066],\n", " [ 0.5615],\n", " [-0.2721],\n", " [-0.2765],\n", " [ 0.0850],\n", " [-0.1473],\n", " [-0.1622],\n", " [-0.1335],\n", " [ 0.1328],\n", " [ 0.0703],\n", " [-0.6712],\n", " [-0.1121],\n", " [-0.1208],\n", " [ 0.0092],\n", " [-0.0805],\n", " [-0.2017],\n", " [ 0.2339],\n", " [-0.3533],\n", " [-0.4598],\n", " [ 0.0620],\n", " [-0.5254],\n", " [ 0.0197],\n", " [-0.0593],\n", " [-0.1914],\n", " [-0.4259],\n", " [-0.0115],\n", " [-0.5406],\n", " [ 0.0137],\n", " [-0.4240],\n", " [-0.2822],\n", " [-0.0796],\n", " [-0.3495],\n", " [-0.4475],\n", " [ 0.2453],\n", " [-0.3729],\n", " [-0.4086],\n", " [-0.2618],\n", " [-0.4539],\n", " [-0.6140],\n", " [-0.2483],\n", " [-0.4165],\n", " [-0.3736],\n", " [-0.0737],\n", " [-0.0212],\n", " [-0.3644],\n", " [-0.0472],\n", " [-0.4087],\n", " [-0.6794],\n", " [-0.5921],\n", " [-0.5632],\n", " [-0.5971],\n", " [-0.1935],\n", " [-0.9543],\n", " [-0.7976],\n", " [-0.3485],\n", " [-0.9538],\n", " [-0.6171],\n", " [-0.7755],\n", " [-0.4651],\n", " [-0.8194],\n", " [-0.3005],\n", " [-0.5191],\n", " [-0.5902],\n", " [-0.2464],\n", " [-0.6908],\n", " [-0.5054],\n", " [-0.5528],\n", " [-1.1089],\n", " [-0.7206],\n", " [-0.8067],\n", " [-0.6780],\n", " [-0.2981],\n", " [-0.6683],\n", " [-0.4324],\n", " [-0.8497],\n", " [-0.5928],\n", " [-0.7203],\n", " [-0.3751],\n", " [-0.7423],\n", " [-0.4109],\n", " [-0.7345],\n", " [-0.6653],\n", " [-0.5752],\n", " [-0.5198],\n", " [-0.7046],\n", " [-1.1754],\n", " [-0.9447],\n", " [-0.7304],\n", " [-0.6510],\n", " [-0.5954],\n", " [-0.7592],\n", " [-0.5285],\n", " [-0.4249],\n", " [-0.7993],\n", " [-1.3758],\n", " [-0.6218],\n", " [-1.0691],\n", " [-0.5775],\n", " [-0.8174],\n", " [-0.7021],\n", " [-0.7784],\n", " 
[-0.7553],\n", " [-1.2137],\n", " [-0.7302],\n", " [-0.7253],\n", " [-0.6819],\n", " [-1.3077],\n", " [-1.3472],\n", " [-0.7104],\n", " [-0.8387],\n", " [-0.5973],\n", " [-0.8619],\n", " [-1.1138],\n", " [-1.1314],\n", " [-0.9765],\n", " [-1.2121],\n", " [-0.8168],\n", " [-0.7763],\n", " [-1.2988],\n", " [-0.9282],\n", " [-1.1715],\n", " [-0.7216],\n", " [-0.7182],\n", " [-0.2972],\n", " [-0.7471],\n", " [-1.0089],\n", " [-1.1431],\n", " [-1.0396],\n", " [-1.0381],\n", " [-0.5979],\n", " [-0.7363],\n", " [-0.7808],\n", " [-0.9106],\n", " [-1.1468],\n", " [-1.1357],\n", " [-0.6406],\n", " [-0.9603],\n", " [-1.2653],\n", " [-1.5958],\n", " [-1.0592],\n", " [-0.9698],\n", " [-0.8252],\n", " [-1.2515],\n", " [-1.0474],\n", " [-1.1103],\n", " [-1.0035],\n", " [-0.6669],\n", " [-0.9120],\n", " [-0.9146],\n", " [-1.1079],\n", " [-0.8379],\n", " [-0.9123],\n", " [-0.5831],\n", " [-1.6515],\n", " [-0.9385],\n", " [-1.0699],\n", " [-1.1498],\n", " [-0.7861],\n", " [-0.8942],\n", " [-1.0452],\n", " [-1.0064],\n", " [-0.9116],\n", " [-1.1150],\n", " [-0.7801],\n", " [-1.0283],\n", " [-1.0296],\n", " [-1.0927],\n", " [-0.7945],\n", " [-1.0705],\n", " [-1.3215],\n", " [-1.2510],\n", " [-0.9158],\n", " [-0.9377],\n", " [-0.7314],\n", " [-0.9773],\n", " [-1.1910],\n", " [-1.0539],\n", " [-1.1439],\n", " [-1.0784],\n", " [-0.8543],\n", " [-1.1323],\n", " [-1.3193],\n", " [-0.8014],\n", " [-0.7318],\n", " [-0.5805],\n", " [-0.8239],\n", " [-1.1228],\n", " [-1.0473],\n", " [-0.8206],\n", " [-0.6544],\n", " [-1.2654],\n", " [-1.0757],\n", " [-0.5389],\n", " [-0.9908],\n", " [-0.7894],\n", " [-0.7463],\n", " [-1.0391],\n", " [-0.8023],\n", " [-0.8568],\n", " [-1.2414],\n", " [-0.9595],\n", " [-1.1151],\n", " [-0.9689],\n", " [-1.1145],\n", " [-0.6853],\n", " [-0.7547],\n", " [-1.1000],\n", " [-0.9054],\n", " [-1.2262],\n", " [-1.1359],\n", " [-1.0174],\n", " [-0.3782],\n", " [-0.8056],\n", " [-1.1828],\n", " [-0.8426],\n", " [-0.9958],\n", " [-0.9495],\n", " [-1.2745],\n", " 
[-0.7039],\n", " [-0.5893],\n", " [-0.5648],\n", " [-1.0538],\n", " [-0.6724],\n", " [-0.6340],\n", " [-0.5070],\n", " [-1.0956],\n", " [-1.0957],\n", " [-0.6823],\n", " [-0.5258],\n", " [-0.5777],\n", " [-0.9268],\n", " [-0.5280],\n", " [-0.5989],\n", " [-0.8364],\n", " [-0.7439],\n", " [-0.7619],\n", " [-1.0159],\n", " [-1.0627],\n", " [-0.9416],\n", " [-0.6270],\n", " [-0.4307],\n", " [-0.8575],\n", " [-1.0748],\n", " [-0.5529],\n", " [-0.9339],\n", " [-0.7416],\n", " [-0.6674],\n", " [-0.3178],\n", " [-0.6815],\n", " [-0.7499],\n", " [-0.6359],\n", " [-0.8157],\n", " [-0.5582],\n", " [-0.5083],\n", " [-0.4527],\n", " [-0.8350],\n", " [-0.6317],\n", " [-0.4338],\n", " [-0.4875],\n", " [-0.4046],\n", " [-0.3166],\n", " [-0.3413],\n", " [-0.4722],\n", " [-0.7010],\n", " [-1.2025],\n", " [-0.2133],\n", " [-0.3133],\n", " [-0.4160],\n", " [-0.6681],\n", " [-0.8990],\n", " [-0.5464],\n", " [-0.4518],\n", " [-0.4402],\n", " [-0.5246],\n", " [-0.4561],\n", " [-0.6747],\n", " [-0.1833],\n", " [-0.4466],\n", " [-0.4671],\n", " [-0.5509],\n", " [-0.6235],\n", " [-0.2100],\n", " [-0.3368],\n", " [-0.3083],\n", " [-0.5129],\n", " [-0.2880],\n", " [-0.4075],\n", " [-0.2784],\n", " [ 0.0631],\n", " [-0.4355],\n", " [-0.4237],\n", " [-0.2578],\n", " [-0.1380],\n", " [-0.5085],\n", " [ 0.1004],\n", " [-0.1426],\n", " [-0.2537],\n", " [-0.1756],\n", " [-0.2135],\n", " [-0.1898],\n", " [-0.2947],\n", " [-0.3934],\n", " [-0.3412],\n", " [-0.3343],\n", " [-0.1450],\n", " [-0.3178],\n", " [-0.2156],\n", " [-0.3232],\n", " [-0.3691],\n", " [-0.2711],\n", " [ 0.1086],\n", " [-0.2257],\n", " [-0.0752],\n", " [-0.0339],\n", " [-0.0636],\n", " [ 0.0626],\n", " [-0.1460],\n", " [ 0.0792],\n", " [ 0.1529],\n", " [ 0.4743],\n", " [ 0.0343],\n", " [-0.0158],\n", " [-0.1255],\n", " [-0.4698],\n", " [-0.0489],\n", " [ 0.2622],\n", " [ 0.0619],\n", " [-0.2243],\n", " [-0.1318],\n", " [ 0.0214],\n", " [ 0.2690],\n", " [ 0.0497],\n", " [ 0.3451],\n", " [-0.1116],\n", " [ 0.0173],\n", " [ 
0.0708],\n", " [ 0.4135],\n", " [ 0.3188],\n", " [ 0.4808],\n", " [-0.0340],\n", " [ 0.4786],\n", " [ 0.4896],\n", " [ 0.1077],\n", " [ 0.3500],\n", " [ 0.1309],\n", " [ 0.1398],\n", " [ 0.1943],\n", " [ 0.1651],\n", " [ 0.3227],\n", " [ 0.5541],\n", " [ 0.2688],\n", " [ 0.1892],\n", " [ 0.2509],\n", " [ 0.2078],\n", " [-0.0140],\n", " [ 0.2443],\n", " [ 0.3204],\n", " [ 0.5485],\n", " [ 0.4234],\n", " [ 0.3135],\n", " [ 0.4633],\n", " [ 0.0029],\n", " [ 0.2174],\n", " [ 0.6879],\n", " [ 0.5089],\n", " [ 0.2479],\n", " [ 0.8608],\n", " [ 0.4307],\n", " [ 0.6205],\n", " [ 0.3482],\n", " [ 0.6469],\n", " [ 0.4475],\n", " [ 0.6595],\n", " [ 0.3450],\n", " [ 0.3781],\n", " [ 0.4451],\n", " [ 0.1883],\n", " [ 0.6707],\n", " [ 0.8667],\n", " [ 0.5218],\n", " [ 0.4004],\n", " [ 0.5271],\n", " [ 0.6446],\n", " [ 0.7222],\n", " [ 0.5722],\n", " [ 0.7676],\n", " [ 0.6824],\n", " [ 0.1981],\n", " [ 0.8089],\n", " [ 0.6296],\n", " [ 0.6748],\n", " [ 0.7515],\n", " [ 0.5103],\n", " [ 0.9052],\n", " [ 0.8405],\n", " [ 0.9092],\n", " [ 0.6918],\n", " [ 0.6477],\n", " [ 0.5402],\n", " [ 0.6477],\n", " [ 0.4210],\n", " [ 0.6973],\n", " [ 0.6019],\n", " [ 0.5364],\n", " [ 0.8134],\n", " [ 0.5607],\n", " [ 0.7096],\n", " [ 0.5894],\n", " [ 0.3866],\n", " [ 1.0600],\n", " [ 0.7347],\n", " [ 0.8129],\n", " [ 1.2088],\n", " [ 0.8825],\n", " [ 0.7179],\n", " [ 1.0115],\n", " [ 0.7013],\n", " [ 1.0128],\n", " [ 0.9747],\n", " [ 1.2759],\n", " [ 0.7655],\n", " [ 1.0094],\n", " [ 0.7805],\n", " [ 0.6091],\n", " [ 1.2033],\n", " [ 0.9678],\n", " [ 0.8219],\n", " [ 0.8157],\n", " [ 0.9188],\n", " [ 0.7436],\n", " [ 0.8910],\n", " [ 0.7291],\n", " [ 0.9559],\n", " [ 0.9389],\n", " [ 1.2030],\n", " [ 1.0495],\n", " [ 1.1811],\n", " [ 0.8884],\n", " [ 0.8390],\n", " [ 0.9894],\n", " [ 0.9238],\n", " [ 0.7628],\n", " [ 0.5421],\n", " [ 1.5147],\n", " [ 0.6971],\n", " [ 0.6740],\n", " [ 0.8342],\n", " [ 0.6554],\n", " [ 0.7455],\n", " [ 0.6916],\n", " [ 1.2706],\n", " [ 1.1277],\n", " [ 
0.9248],\n", " [ 0.9976],\n", " [ 1.2404],\n", " [ 0.6919],\n", " [ 1.3449],\n", " [ 1.1243],\n", " [ 1.0492],\n", " [ 0.9266],\n", " [ 1.1194],\n", " [ 1.0304],\n", " [ 1.1323],\n", " [ 1.2372],\n", " [ 0.8300],\n", " [ 1.1916],\n", " [ 1.0923],\n", " [ 0.8313],\n", " [ 0.8572],\n", " [ 1.1128],\n", " [ 1.0047],\n", " [ 1.1544],\n", " [ 0.9745],\n", " [ 1.0503],\n", " [ 0.9171],\n", " [ 0.8073],\n", " [ 1.2056],\n", " [ 1.0976],\n", " [ 0.9910],\n", " [ 1.1834],\n", " [ 1.1389],\n", " [ 0.9142],\n", " [ 0.9367],\n", " [ 1.0121],\n", " [ 0.7704],\n", " [ 1.0558],\n", " [ 0.7306],\n", " [ 0.8117],\n", " [ 0.7061],\n", " [ 1.2315],\n", " [ 0.9015],\n", " [ 0.9339],\n", " [ 0.5016],\n", " [ 0.9227],\n", " [ 1.2568],\n", " [ 0.9444],\n", " [ 1.1198],\n", " [ 0.9431],\n", " [ 1.0997],\n", " [ 1.3078],\n", " [ 0.8336],\n", " [ 1.2692],\n", " [ 0.8424],\n", " [ 0.8702],\n", " [ 1.4820],\n", " [ 1.3248],\n", " [ 0.9324],\n", " [ 0.6538],\n", " [ 1.2011],\n", " [ 1.0170],\n", " [ 0.7863],\n", " [ 1.0178],\n", " [ 0.6519],\n", " [ 0.5970],\n", " [ 0.9052],\n", " [ 0.6846],\n", " [ 0.7737],\n", " [ 0.9104],\n", " [ 0.8439],\n", " [ 1.0066],\n", " [ 1.0787],\n", " [ 0.9661],\n", " [ 0.9923],\n", " [ 0.7922],\n", " [ 0.8316],\n", " [ 0.9553],\n", " [ 0.9952],\n", " [ 0.8680],\n", " [ 1.1226],\n", " [ 0.8213],\n", " [ 0.9151],\n", " [ 0.7748],\n", " [ 0.9953],\n", " [ 0.7773],\n", " [ 0.7916],\n", " [ 0.7321],\n", " [ 0.9130],\n", " [ 1.1433],\n", " [ 0.7060],\n", " [ 0.8066],\n", " [ 0.8709],\n", " [ 0.7426],\n", " [ 0.8718],\n", " [ 1.0973],\n", " [ 0.7097],\n", " [ 0.9438],\n", " [ 0.8164],\n", " [ 0.8013],\n", " [ 0.6236],\n", " [ 0.7180],\n", " [ 0.9188],\n", " [ 0.8016],\n", " [ 0.9741],\n", " [ 0.6271],\n", " [ 0.5747],\n", " [ 0.8007],\n", " [ 0.7754],\n", " [ 0.4877],\n", " [ 0.4746],\n", " [ 0.8654],\n", " [ 0.4743],\n", " [ 0.9015],\n", " [ 0.8082],\n", " [ 0.5449],\n", " [ 0.9299],\n", " [ 0.2003],\n", " [ 0.5466],\n", " [ 0.4355],\n", " [ 0.7900],\n", " [ 
0.4343],\n", " [ 0.7224],\n", " [ 0.8585],\n", " [ 0.5714],\n", " [ 0.5306],\n", " [ 0.6594],\n", " [ 0.0640],\n", " [ 0.3203],\n", " [ 0.5463],\n", " [ 0.5048],\n", " [ 0.1935],\n", " [ 0.2883],\n", " [ 0.6778],\n", " [ 0.5014],\n", " [ 0.5235],\n", " [ 0.5718],\n", " [ 0.4587],\n", " [ 0.2808],\n", " [ 0.4073],\n", " [ 0.8632],\n", " [ 0.8862],\n", " [ 0.5757],\n", " [ 0.3372],\n", " [ 0.2566],\n", " [ 0.7858],\n", " [ 0.3713],\n", " [ 0.1589],\n", " [ 0.3243],\n", " [ 0.4270],\n", " [ 0.0565],\n", " [ 0.2885],\n", " [ 0.3257],\n", " [ 0.2196],\n", " [ 0.3159],\n", " [ 0.2361],\n", " [ 0.1087],\n", " [ 0.2224],\n", " [ 0.2633],\n", " [ 0.5037],\n", " [ 0.1980],\n", " [ 0.1530],\n", " [ 0.2780],\n", " [-0.1399],\n", " [ 0.5331],\n", " [ 0.3530],\n", " [ 0.3342],\n", " [ 0.2098],\n", " [-0.0165],\n", " [ 0.1318],\n", " [ 0.4510],\n", " [-0.1959],\n", " [ 0.0966],\n", " [ 0.0789],\n", " [ 0.3381],\n", " [-0.1917],\n", " [ 0.1518],\n", " [ 0.3640],\n", " [ 0.0956],\n", " [ 0.2535],\n", " [-0.3988],\n", " [-0.3479],\n", " [ 0.3864],\n", " [-0.2639],\n", " [-0.2368],\n", " [ 0.0258],\n", " [ 0.2441],\n", " [ 0.0687],\n", " [ 0.0457],\n", " [ 0.2286],\n", " [-0.0947],\n", " [-0.1189],\n", " [ 0.1360],\n", " [-0.0990],\n", " [-0.2447],\n", " [ 0.2135],\n", " [-0.1830],\n", " [-0.4583],\n", " [-0.1795],\n", " [-0.1361],\n", " [-0.0553],\n", " [-0.2864],\n", " [-0.2307],\n", " [-0.4651],\n", " [-0.1889],\n", " [-0.3185],\n", " [-0.5318],\n", " [-0.3012],\n", " [ 0.0062],\n", " [ 0.1046],\n", " [-0.2321],\n", " [-0.2945],\n", " [-0.0242],\n", " [-0.0586],\n", " [-0.2307],\n", " [-0.2479],\n", " [-0.0382],\n", " [-0.1509],\n", " [-0.5055],\n", " [-0.3759],\n", " [ 0.2139],\n", " [-0.2129],\n", " [-0.3605],\n", " [-0.5222],\n", " [-0.6530],\n", " [-0.6716],\n", " [-0.4330],\n", " [-0.2577],\n", " [-0.2672],\n", " [-0.1297],\n", " [-0.9203],\n", " [-0.5832],\n", " [-0.2640],\n", " [-0.4996],\n", " [-0.2625],\n", " [-0.4407],\n", " [-0.8864],\n", " [-0.2508],\n", " 
[-0.4827],\n", " [-0.3131],\n", " [-0.2570],\n", " [-0.7116],\n", " [-0.5357],\n", " [-0.7074]]))" ] }, "execution_count": 85, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 85 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:24.385478317Z", "start_time": "2026-04-22T07:03:24.319263622Z" } }, "cell_type": "code", "source": [
"batch_size, n_train = 16, 600\n",
"# Only the first n_train samples are used for training\n",
"train_iter = d2l.load_array((features[:n_train], labels[:n_train]),\n",
"                            batch_size, is_train=True)\n"
], "id": "239a596b20d40dec", "outputs": [], "execution_count": 86 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:24.452682331Z", "start_time": "2026-04-22T07:03:24.386760825Z" } }, "cell_type": "code", "source": [
"def init_weights(m):\n",
"    \"\"\"Xavier-uniform initialise the weight of an nn.Linear layer (applied via net.apply).\"\"\"\n",
"    if isinstance(m, nn.Linear):\n",
"        nn.init.xavier_uniform_(m.weight)"
], "id": "54d30bd0ee41cb8", "outputs": [], "execution_count": 87 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:24.505029692Z", "start_time": "2026-04-22T07:03:24.454087829Z" } }, "cell_type": "code", "source": [
"def get_net():\n",
"    \"\"\"Build a freshly initialised MLP mapping 4 input features to 1 output.\"\"\"\n",
"    # The input width must equal the number of lagged values the prediction\n",
"    # cells feed in (tau of them); presumably tau == 4 here - TODO confirm\n",
"    net = nn.Sequential(nn.Linear(4, 10),\n",
"                        nn.ReLU(),\n",
"                        nn.Linear(10, 1))\n",
"    net.apply(init_weights)\n",
"    return net\n",
"# Unreduced (per-element) squared error; the training loop sums it before backprop\n",
"loss = nn.MSELoss(reduction='none')"
], "id": "5d095792e3b3681", "outputs": [], "execution_count": 88 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:24.610623781Z", "start_time": "2026-04-22T07:03:24.507553628Z" } }, "cell_type": "code", "source": [
"def train(net, train_iter, loss, epochs, lr):\n",
"    \"\"\"Train net with Adam and print the mean loss over train_iter after each epoch.\"\"\"\n",
"    trainer = torch.optim.Adam(net.parameters(), lr)\n",
"    for epoch in range(epochs):\n",
"        for X, y in train_iter:\n",
"            trainer.zero_grad()\n",
"            l = loss(net(X), y)\n",
"            # loss is expected to be unreduced (reduction='none'), so reduce here\n",
"            l.sum().backward()\n",
"            trainer.step()\n",
"        # Evaluate without building autograd graphs; this also avoids the\n",
"        # requires_grad-to-scalar UserWarning raised inside d2l.evaluate_loss\n",
"        with torch.no_grad():\n",
"            epoch_loss = d2l.evaluate_loss(net, train_iter, loss)\n",
"        print(f'epoch {epoch + 1}, '\n",
"              f'loss: {epoch_loss:f}')\n",
"net = get_net()\n",
"train(net, train_iter, loss, 5, 0.01)"
], "id": "5c1dba484e805335", "outputs": [ { "name": 
"stdout", "output_type": "stream", "text": [ "epoch 1, loss: 0.082748\n", "epoch 2, loss: 0.066498\n", "epoch 3, loss: 0.061828\n", "epoch 4, loss: 0.059193\n", "epoch 5, loss: 0.058498\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/home/yukun/.conda/envs/nn/lib/python3.11/site-packages/d2l/torch.py:3179: UserWarning: Converting a tensor with requires_grad=True to a scalar may lead to unexpected behavior.\n", "Consider using tensor.detach() first. (Triggered internally at /pytorch/torch/csrc/autograd/generated/python_variable_methods.cpp:836.)\n", " self.data = [a + float(b) for a, b in zip(self.data, args)]\n" ] } ], "execution_count": 89 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:24.713050170Z", "start_time": "2026-04-22T07:03:24.616031026Z" } }, "cell_type": "code", "source": [ "onestep_preds = net(features)\n", "d2l.plot([time, time[tau:]],\n", "[x.detach().numpy(), onestep_preds.detach().numpy()], 'time',\n", "'x', legend=['data', '1-step preds'], xlim=[1, 1000],\n", "figsize=(6, 3))" ], "id": "a6efb4a978cd9375", "outputs": [ { "data": { "text/plain": [ "
" ], "image/svg+xml": "\n\n\n \n \n \n \n 2026-04-22T15:03:24.684021\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" }, "metadata": {}, "output_type": "display_data", "jetTransient": { "display_id": null } } ], "execution_count": 90 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:24.848161521Z", "start_time": "2026-04-22T07:03:24.729135740Z" } }, "cell_type": "code", "source": [ "multistep_preds = torch.zeros(T)\n", "multistep_preds[: n_train + tau] = x[: n_train + tau]\n", "for i in range(n_train + tau, T):\n", " multistep_preds[i] = net(\n", " multistep_preds[i - tau:i].reshape((1, -1)))\n", "d2l.plot([time, time[tau:], time[n_train + tau:]],\n", " [x.detach().numpy(), onestep_preds.detach().numpy(),\n", " multistep_preds[n_train + tau:].detach().numpy()], 'time',\n", " 'x', legend=['data', '1-step preds', 'multistep preds'],\n", " xlim=[1, 1000], 
figsize=(6, 3))" ], "id": "12c3a1c3912da4dd", "outputs": [ { "data": { "text/plain": [ "
" ], "image/svg+xml": "\n\n\n \n \n \n \n 2026-04-22T15:03:24.805438\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" }, "metadata": {}, "output_type": "display_data", "jetTransient": { "display_id": null } } ], "execution_count": 91 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:24.898795229Z", "start_time": "2026-04-22T07:03:24.850233508Z" } }, "cell_type": "code", "source": [ "import collections\n", "import re" ], "id": "aab66c10a4c143d2", "outputs": [], "execution_count": 92 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:24.953966099Z", "start_time": "2026-04-22T07:03:24.901011365Z" } }, "cell_type": "code", "source": [ "d2l.DATA_HUB['time_machine'] = (d2l.DATA_URL + 'timemachine.txt',\n", "'090b5e7e70c295757f55df93cb0a180b9691891a')\n", "def 
read_time_machine(): #@save\n",
"    \"\"\"Load the time machine dataset into a list of text lines.\n",
"\n",
"    Each run of non-letter characters becomes a single space, and every line\n",
"    is stripped and lower-cased.\n",
"    \"\"\"\n",
"    # Explicit encoding: open() otherwise falls back to the platform locale\n",
"    with open(d2l.download('time_machine'), 'r', encoding='utf-8') as f:\n",
"        lines = f.readlines()\n",
"    return [re.sub('[^A-Za-z]+', ' ', line).strip().lower() for line in lines]\n",
"lines = read_time_machine()\n",
"print(f'# 文本总行数: {len(lines)}')\n",
"print(lines[0])\n",
"print(lines[10])"
], "id": "1aff117af810525e", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "# 文本总行数: 3221\n", "the time machine by h g wells\n", "twinkled and his usually pale face was flushed and animated the\n" ] } ], "execution_count": 93 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:25.021941150Z", "start_time": "2026-04-22T07:03:24.965670989Z" } }, "cell_type": "code", "source": [
"def tokenize(lines, token='word'): #@save\n",
"    \"\"\"Split text lines into word or character tokens (one token list per line).\n",
"\n",
"    Raises ValueError for any token type other than 'word' or 'char'.\n",
"    \"\"\"\n",
"    if token == 'word':\n",
"        return [line.split() for line in lines]\n",
"    elif token == 'char':\n",
"        return [list(line) for line in lines]\n",
"    else:\n",
"        raise ValueError('错误:未知词元类型:' + token)\n",
"tokens = tokenize(lines)\n",
"for i in range(11):\n",
"    print(tokens[i])"
], "id": "eb4fe9745fbaa5e2", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['the', 'time', 'machine', 'by', 'h', 'g', 'wells']\n", "[]\n", "[]\n", "[]\n", "[]\n", "['i']\n", "[]\n", "[]\n", "['the', 'time', 'traveller', 'for', 'so', 'it', 'will', 'be', 'convenient', 'to', 'speak', 'of', 'him']\n", "['was', 'expounding', 'a', 'recondite', 'matter', 'to', 'us', 'his', 'grey', 'eyes', 'shone', 'and']\n", "['twinkled', 'and', 'his', 'usually', 'pale', 'face', 'was', 'flushed', 'and', 'animated', 'the']\n" ] } ], "execution_count": 94 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:25.086591433Z", "start_time": "2026-04-22T07:03:25.032865323Z" } }, "cell_type": "code", "source": [
"def count_corpus(tokens): #@save\n",
"    \"\"\"Count token frequencies.\n",
"\n",
"    tokens may be a flat list of tokens or a list of per-line token lists;\n",
"    nested lists are flattened before counting.\n",
"    \"\"\"\n",
"    if len(tokens) == 0 or isinstance(tokens[0], list):\n",
"        # Flatten the list of token lists into a single list\n",
"        tokens = [token for line in tokens for token in line]\n",
"    return collections.Counter(tokens)\n",
"class Vocab: #@save\n",
"    \"\"\"Text vocabulary mapping tokens to integer indices and back.\n",
"\n",
"    Index 0 is the unknown token '<unk>', followed by reserved_tokens, then\n",
"    every token with frequency >= min_freq in descending-frequency order.\n",
"    \"\"\"\n",
"    def __init__(self, tokens=None, min_freq=0, reserved_tokens=None):\n",
"        if tokens is None:\n",
"            tokens = []\n",
"        if reserved_tokens is None:\n",
"            reserved_tokens = []\n",
"        # Sort by frequency, most frequent first\n",
"        counter = count_corpus(tokens)\n",
"        self._token_freqs = sorted(counter.items(), key=lambda x: x[1],\n",
"                                   reverse=True)\n",
"        # The unknown token has index 0; give it an explicit label rather than ''\n",
"        self.idx_to_token = ['<unk>'] + reserved_tokens\n",
"        self.token_to_idx = {token: idx\n",
"                             for idx, token in enumerate(self.idx_to_token)}\n",
"        for token, freq in self._token_freqs:\n",
"            if freq < min_freq:\n",
"                break  # frequencies are sorted descending, so the rest are rarer\n",
"            if token not in self.token_to_idx:\n",
"                self.idx_to_token.append(token)\n",
"                self.token_to_idx[token] = len(self.idx_to_token) - 1\n",
"\n",
"    def __len__(self):\n",
"        return len(self.idx_to_token)\n",
"\n",
"    def __getitem__(self, tokens):\n",
"        \"\"\"Map a token (or a list/tuple of tokens) to index/indices; unknown tokens map to unk.\"\"\"\n",
"        if not isinstance(tokens, (list, tuple)):\n",
"            return self.token_to_idx.get(tokens, self.unk)\n",
"        return [self.__getitem__(token) for token in tokens]\n",
"\n",
"    def to_tokens(self, indices):\n",
"        \"\"\"Map an index (or a list/tuple of indices) back to token(s).\"\"\"\n",
"        if not isinstance(indices, (list, tuple)):\n",
"            return self.idx_to_token[indices]\n",
"        return [self.idx_to_token[index] for index in indices]\n",
"\n",
"    @property\n",
"    def unk(self):  # index of the unknown token\n",
"        return 0\n",
"\n",
"    @property\n",
"    def token_freqs(self):\n",
"        return self._token_freqs"
], "id": "bee8e5d7b798c6c", "outputs": [], "execution_count": 95 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:25.242997320Z", "start_time": "2026-04-22T07:03:25.089155138Z" } }, "cell_type": "code", "source": [
"vocab = Vocab(tokens)\n",
"print(list(vocab.token_to_idx.items())[:10])"
], "id": "ff4e8ac2044850b7", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[('<unk>', 0), ('the', 1), ('i', 2), ('and', 3), ('of', 4), ('a', 5), ('to', 6), ('was', 7), ('in', 8), ('that', 9)]\n" ] } ], "execution_count": 96 }, { 
"metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:25.308904788Z", "start_time": "2026-04-22T07:03:25.255767350Z" } }, "cell_type": "code", "source": [
"for i in [0, 100]:\n",
"    print('文本:', tokens[i])\n",
"    print('索引:', vocab[tokens[i]])"
], "id": "a4e569dfbd251608", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "文本: ['the', 'time', 'machine', 'by', 'h', 'g', 'wells']\n", "索引: [1, 19, 50, 40, 2183, 2184, 400]\n", "文本: ['were', 'three', 'dimensional', 'representations', 'of', 'his', 'four', 'dimensioned']\n", "索引: [20, 175, 1452, 2250, 4, 25, 262, 2251]\n" ] } ], "execution_count": 97 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:25.387400952Z", "start_time": "2026-04-22T07:03:25.322962901Z" } }, "cell_type": "code", "source": [
"def load_corpus_time_machine(max_tokens=-1): #@save\n",
"    \"\"\"Return the character-token index list and the vocabulary of the time machine dataset.\n",
"\n",
"    If max_tokens > 0 the index list is truncated to its first max_tokens\n",
"    entries; the vocabulary is always built from the full text.\n",
"    \"\"\"\n",
"    lines = read_time_machine()\n",
"    tokens = tokenize(lines, 'char')\n",
"    vocab = Vocab(tokens)\n",
"    # Lines in the time machine dataset are not necessarily sentences or\n",
"    # paragraphs, so flatten all lines into a single list\n",
"    corpus = [vocab[token] for line in tokens for token in line]\n",
"    if max_tokens > 0:\n",
"        corpus = corpus[:max_tokens]\n",
"    return corpus, vocab\n",
"corpus, vocab = load_corpus_time_machine()\n",
"\n",
"len(corpus), len(vocab)\n"
], "id": "1b5c1776ae47af5c", "outputs": [ { "data": { "text/plain": [ "(170580, 28)" ] }, "execution_count": 98, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 98 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:25.453063309Z", "start_time": "2026-04-22T07:03:25.388764888Z" } }, "cell_type": "code", "source": [
"# Use the notebook's own tokenize/Vocab (defined above) rather than the d2l copies\n",
"tokens = tokenize(read_time_machine())\n",
"# Lines are not necessarily sentences or paragraphs, so concatenate all of them\n",
"corpus = [token for line in tokens for token in line]\n",
"vocab = Vocab(corpus)\n",
"vocab.token_freqs[:10]"
], "id": "99deb85c025e5cdd", "outputs": [ { "data": { "text/plain": [ "[('the', 2261),\n", " ('i', 1267),\n", " ('and', 1245),\n", " ('of', 1155),\n", 
" ('a', 816),\n", " ('to', 695),\n", " ('was', 552),\n", " ('in', 541),\n", " ('that', 443),\n", " ('my', 440)]" ] }, "execution_count": 99, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 99 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:25.731046331Z", "start_time": "2026-04-22T07:03:25.454316415Z" } }, "cell_type": "code", "source": [ "freqs = [freq for token, freq in vocab.token_freqs]\n", "d2l.plot(freqs, xlabel='token: x', ylabel='frequency: n(x)',\n", "xscale='log', yscale='log')" ], "id": "5c0d846673e16c33", "outputs": [ { "data": { "text/plain": [ "
" ], "image/svg+xml": "\n\n\n \n \n \n \n 2026-04-22T15:03:25.667027\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n 
\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" }, "metadata": {}, "output_type": "display_data", "jetTransient": { "display_id": null } } ], "execution_count": 100 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:25.792389729Z", "start_time": "2026-04-22T07:03:25.734951149Z" } }, "cell_type": "code", "source": [ "bigram_tokens = [pair for pair in zip(corpus[:-1], corpus[1:])]\n", "bigram_vocab = Vocab(bigram_tokens)\n", "bigram_vocab.token_freqs[:10]" ], "id": "2826e3ab0863ee64", "outputs": [ { "data": { "text/plain": [ "[(('of', 'the'), 309),\n", " (('in', 'the'), 169),\n", " (('i', 'had'), 130),\n", " (('i', 'was'), 112),\n", " (('and', 'the'), 109),\n", " (('the', 'time'), 102),\n", " (('it', 'was'), 99),\n", " (('to', 'the'), 85),\n", " (('as', 'i'), 78),\n", " (('of', 'a'), 73)]" ] }, "execution_count": 101, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 101 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:25.888424198Z", "start_time": "2026-04-22T07:03:25.808960001Z" } }, "cell_type": "code", "source": [ "trigram_tokens = [triple for triple in zip(\n", "corpus[:-2], corpus[1:-1], corpus[2:])]\n", "trigram_vocab = Vocab(trigram_tokens)\n", "trigram_vocab.token_freqs[:10]" ], "id": "7c8cd0544bf872bb", "outputs": [ { "data": { "text/plain": [ "[(('the', 'time', 'traveller'), 59),\n", " (('the', 'time', 'machine'), 30),\n", " (('the', 'medical', 'man'), 24),\n", " (('it', 'seemed', 'to'), 16),\n", " (('it', 'was', 'a'), 15),\n", " (('here', 'and', 'there'), 15),\n", " (('seemed', 'to', 'me'), 14),\n", " (('i', 'did', 'not'), 14),\n", " (('i', 'saw', 'the'), 13),\n", " (('i', 'began', 'to'), 13)]" ] }, "execution_count": 102, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 102 }, { "metadata": { "ExecuteTime": { "end_time": 
"2026-04-22T07:03:26.133953109Z", "start_time": "2026-04-22T07:03:25.889784531Z" } }, "cell_type": "code", "source": [ "bigram_freqs = [freq for token, freq in bigram_vocab.token_freqs]\n", "trigram_freqs = [freq for token, freq in trigram_vocab.token_freqs]\n", "d2l.plot([freqs, bigram_freqs, trigram_freqs], xlabel='token: x',\n", "ylabel='frequency: n(x)', xscale='log', yscale='log',\n", "legend=['unigram', 'bigram', 'trigram'])" ], "id": "dc3d97dda738613d", "outputs": [ { "data": { "text/plain": [ "
" ], "image/svg+xml": "\n\n\n \n \n \n \n 2026-04-22T15:03:26.075629\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n 
\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" }, "metadata": {}, "output_type": "display_data", "jetTransient": { "display_id": null } } ], "execution_count": 103 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:26.196704442Z", "start_time": "2026-04-22T07:03:26.148119411Z" } }, "cell_type": "code", "source": [ "import random\n", "def seq_data_iter_random(corpus, batch_size, num_steps): #@save\n", " \"\"\"使用随机抽样生成一个小批量子序列\"\"\"\n", " # 从随机偏移量开始对序列进行分区,随机范围包括num_steps-1\n", " corpus = corpus[random.randint(0, num_steps - 1):]\n", " # 减去1,是因为我们需要考虑标签\n", " num_subseqs = (len(corpus) - 1) // num_steps\n", " # 长度为num_steps的子序列的起始索引\n", " initial_indices = list(range(0, num_subseqs * num_steps, num_steps))\n", " # 在随机抽样的迭代过程中,\n", " # 来自两个相邻的、随机的、小批量中的子序列不一定在原始序列上相邻\n", " random.shuffle(initial_indices)\n", " def data(pos):\n", " # 返回从pos位置开始的长度为num_steps的序列\n", " return corpus[pos: pos + num_steps]\n", " num_batches = num_subseqs // batch_size\n", " for i in range(0, batch_size * num_batches, batch_size):\n", " # 在这里,initial_indices包含子序列的随机起始索引\n", " initial_indices_per_batch = initial_indices[i: i + batch_size]\n", " X = [data(j) for j in initial_indices_per_batch]\n", " Y = [data(j + 1) for j in initial_indices_per_batch]\n", " yield torch.tensor(X), torch.tensor(Y)" ], "id": "fd015793938b83ab", "outputs": [], "execution_count": 104 }, { "metadata": { "ExecuteTime": { 
"end_time": "2026-04-22T07:03:26.251762886Z", "start_time": "2026-04-22T07:03:26.199184612Z" } }, "cell_type": "code", "source": [ "my_seq = list(range(35))\n", "for X, Y in seq_data_iter_random(my_seq, batch_size=2, num_steps=5):\n", " print('X: ', X, '\\nY:', Y)" ], "id": "8961f4934b8c7cc", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "X: tensor([[ 3, 4, 5, 6, 7],\n", " [28, 29, 30, 31, 32]]) \n", "Y: tensor([[ 4, 5, 6, 7, 8],\n", " [29, 30, 31, 32, 33]])\n", "X: tensor([[23, 24, 25, 26, 27],\n", " [ 8, 9, 10, 11, 12]]) \n", "Y: tensor([[24, 25, 26, 27, 28],\n", " [ 9, 10, 11, 12, 13]])\n", "X: tensor([[18, 19, 20, 21, 22],\n", " [13, 14, 15, 16, 17]]) \n", "Y: tensor([[19, 20, 21, 22, 23],\n", " [14, 15, 16, 17, 18]])\n" ] } ], "execution_count": 105 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:26.315745320Z", "start_time": "2026-04-22T07:03:26.265138594Z" } }, "cell_type": "code", "source": [ "def seq_data_iter_sequential(corpus, batch_size, num_steps): #@save\n", " \"\"\"使用顺序分区生成一个小批量子序列\"\"\"\n", " # 从随机偏移量开始划分序列\n", " offset = random.randint(0, num_steps)\n", " num_tokens = ((len(corpus) - offset - 1) // batch_size) * batch_size\n", " Xs = torch.tensor(corpus[offset: offset + num_tokens])\n", " Ys = torch.tensor(corpus[offset + 1: offset + 1 + num_tokens])\n", " Xs, Ys = Xs.reshape(batch_size, -1), Ys.reshape(batch_size, -1)\n", " num_batches = Xs.shape[1] // num_steps\n", " for i in range(0, num_steps * num_batches, num_steps):\n", " X = Xs[:, i: i + num_steps]\n", " Y = Ys[:, i: i + num_steps]\n", " yield X, Y" ], "id": "621b66c0614b22da", "outputs": [], "execution_count": 106 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:26.366004915Z", "start_time": "2026-04-22T07:03:26.318858875Z" } }, "cell_type": "code", "source": [ "class SeqDataLoader:\n", " \"\"\"加载序列数据的迭代器\"\"\"\n", " def __init__(self, batch_size, num_steps, use_random_iter, max_tokens):\n", " if use_random_iter:\n", " 
self.data_iter_fn = seq_data_iter_random\n", " else:\n", " self.data_iter_fn = seq_data_iter_sequential\n", " self.corpus, self.vocab = load_corpus_time_machine(max_tokens)\n", " self.batch_size, self.num_steps = batch_size, num_steps\n", " def __iter__(self):\n", " return self.data_iter_fn(self.corpus, self.batch_size, self.num_steps)\n", "def load_data_time_machine(batch_size, num_steps, #@save\n", " use_random_iter=False, max_tokens=10000):\n", " \"\"\"返回时光机器数据集的迭代器和词表\"\"\"\n", " data_iter = SeqDataLoader(\n", " batch_size, num_steps, use_random_iter, max_tokens)\n", " return data_iter, data_iter.vocab" ], "id": "f09fe2507a925fe9", "outputs": [], "execution_count": 107 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:26.415821442Z", "start_time": "2026-04-22T07:03:26.368476221Z" } }, "cell_type": "code", "source": [ "batch_size, num_steps = 32, 35\n", "train_iter, vocab = load_data_time_machine(batch_size, num_steps)" ], "id": "69272a664d3b9ae1", "outputs": [], "execution_count": 108 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:26.469093543Z", "start_time": "2026-04-22T07:03:26.417213406Z" } }, "cell_type": "code", "source": "F.one_hot(torch.tensor([0,2]),len(vocab))", "id": "35806d36e5ec3ca7", "outputs": [ { "data": { "text/plain": [ "tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0],\n", " [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0]])" ] }, "execution_count": 109, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 109 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:26.612564439Z", "start_time": "2026-04-22T07:03:26.520392652Z" } }, "cell_type": "code", "source": [ "X = torch.arange(10).reshape((2, 5))\n", "F.one_hot(X.T, 28).shape" ], "id": "6a4695284b898013", "outputs": [ { "data": { "text/plain": [ "torch.Size([5, 2, 28])" ] }, "execution_count": 110, "metadata": {}, "output_type": 
"execute_result" } ], "execution_count": 110 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:26.703519372Z", "start_time": "2026-04-22T07:03:26.641662463Z" } }, "cell_type": "code", "source": [ "def get_params(vocab_size, num_hiddens, device):\n", " num_inputs = num_outputs = vocab_size\n", " def normal(shape):\n", " return torch.randn(size=shape, device=device) * 0.01\n", " # 隐藏层参数\n", " W_xh = normal((num_inputs, num_hiddens))\n", " W_hh = normal((num_hiddens, num_hiddens))\n", " b_h = torch.zeros(num_hiddens, device=device)\n", " # 输出层参数\n", " W_hq = normal((num_hiddens, num_outputs))\n", " b_q = torch.zeros(num_outputs, device=device)\n", " # 附加梯度\n", " params = [W_xh, W_hh, b_h, W_hq, b_q]\n", " '''\n", " W_xh x->H_t\n", " W_hh H_t-1->H_t\n", " W_hq H_t->opt\n", " '''\n", " for param in params:\n", " param.requires_grad_(True)\n", " return params\n", "def init_rnn_state(batch_size, num_hiddens, device):\n", " return (torch.zeros((batch_size, num_hiddens), device=device), )\n", "def rnn(inputs,state,params):\n", " W_xh,W_hh,b_h,W_hq,b_q = params\n", " H, = state\n", " outputs = []\n", " for X in inputs: # X (batchs,vocab)\n", " H = torch.tanh(torch.mm(X,W_xh)+torch.mm(H,W_hh)+b_h)\n", " Y = torch.mm(H,W_hq)+b_q\n", " outputs.append(Y)\n", " return torch.cat(outputs,dim=0),(H,)\n", "class RNNModelScratch: #@save\n", " \"\"\"从零开始实现的循环神经网络模型\"\"\"\n", " def __init__(self, vocab_size, num_hiddens, device,\n", " get_params, init_state, forward_fn):\n", " self.vocab_size, self.num_hiddens = vocab_size, num_hiddens\n", " self.params = get_params(vocab_size, num_hiddens, device)\n", " self.init_state, self.forward_fn = init_state, forward_fn\n", " def __call__(self, X, state):\n", " X = F.one_hot(X.T, self.vocab_size).type(torch.float32)\n", " return self.forward_fn(X, state, self.params)\n", " def begin_state(self, batch_size, device):\n", " return self.init_state(batch_size, self.num_hiddens, device)" ], "id": "405f7c6af8bdd939", "outputs": [], 
"execution_count": 111 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:26.793241936Z", "start_time": "2026-04-22T07:03:26.704490837Z" } }, "cell_type": "code", "source": [ "num_hiddens = 512\n", "net = RNNModelScratch(len(vocab), num_hiddens, d2l.try_gpu(), get_params,\n", "init_rnn_state, rnn)\n", "state = net.begin_state(X.shape[0], d2l.try_gpu())\n", "Y, new_state = net(X.to(d2l.try_gpu()), state)\n", "Y.shape, len(new_state), new_state[0].shape" ], "id": "7bdd0b37b1458ec2", "outputs": [ { "data": { "text/plain": [ "(torch.Size([10, 28]), 1, torch.Size([2, 512]))" ] }, "execution_count": 112, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 112 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:26.859603241Z", "start_time": "2026-04-22T07:03:26.795131832Z" } }, "cell_type": "code", "source": [ "def predict_ch8(prefix, num_preds, net, vocab, device): #@save\n", " \"\"\"在prefix后面生成新字符\"\"\"\n", " state = net.begin_state(batch_size=1, device=device)\n", " outputs = [vocab[prefix[0]]]\n", " get_input = lambda: torch.tensor([outputs[-1]], device=device).reshape((1, 1))\n", " for y in prefix[1:]: # 预热期\n", " _, state = net(get_input(), state)\n", " outputs.append(vocab[y])\n", " for _ in range(num_preds): # 预测num_preds步\n", " y, state = net(get_input(), state)\n", " outputs.append(int(y.argmax(dim=1).reshape(1)))\n", " return ''.join([vocab.idx_to_token[i] for i in outputs])\n", "predict_ch8('time traveller ', 10, net, vocab, d2l.try_gpu())" ], "id": "f38dbe2b11e02bc2", "outputs": [ { "data": { "text/plain": [ "'time traveller vxmpussss'" ] }, "execution_count": 113, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 113 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:26.909151706Z", "start_time": "2026-04-22T07:03:26.861214463Z" } }, "cell_type": "code", "source": [ "def grad_clipping(net, theta): #@save\n", " \"\"\"裁剪梯度\"\"\"\n", " if isinstance(net, nn.Module):\n", " 
params = [p for p in net.parameters() if p.requires_grad]\n", " else:\n", " params = net.params\n", " norm = torch.sqrt(sum(torch.sum((p.grad ** 2)) for p in params))\n", " if norm > theta:\n", " for param in params:\n", " param.grad[:] *= theta / norm" ], "id": "6c19717736ffbc68", "outputs": [], "execution_count": 114 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:26.964267753Z", "start_time": "2026-04-22T07:03:26.913802768Z" } }, "cell_type": "code", "source": [ "import math\n", "def train_epoch_ch8(net, train_iter, loss, updater, device, use_random_iter):\n", " \"\"\"训练网络一个迭代周期(定义见第8章)\"\"\"\n", " state, timer = None, d2l.Timer()\n", " metric = d2l.Accumulator(2) # 训练损失之和,词元数量\n", " for X, Y in train_iter:\n", " if state is None or use_random_iter:\n", " # 在第一次迭代或使用随机抽样时初始化state\n", " state = net.begin_state(batch_size=X.shape[0], device=device)\n", " else:\n", " if isinstance(net, nn.Module) and not isinstance(state, tuple):\n", " # state对于nn.GRU是个张量\n", " state.detach_()\n", " else:\n", " # state对于nn.LSTM或对于我们从零开始实现的模型是个张量\n", " for s in state:\n", " s.detach_()\n", " y = Y.T.reshape(-1)\n", " X, y = X.to(device), y.to(device)\n", " y_hat, state = net(X, state)\n", " l = loss(y_hat, y.long()).mean()\n", " if isinstance(updater, torch.optim.Optimizer):\n", " updater.zero_grad()\n", " l.backward()\n", " grad_clipping(net, 1)\n", " updater.step()\n", " else:\n", " l.backward()\n", " grad_clipping(net, 1)\n", " # 因为已经调用了mean函数\n", " updater(batch_size=1)\n", " metric.add(l * y.numel(), y.numel())\n", " return math.exp(metric[0]/metric[1]),metric[1]/timer.stop()\n" ], "id": "37ef611dedcb714b", "outputs": [], "execution_count": 115 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:27.028276769Z", "start_time": "2026-04-22T07:03:26.966889257Z" } }, "cell_type": "code", "source": [ "def train_ch8(net, train_iter, vocab, lr, num_epochs, device,\n", "use_random_iter=False):\n", " \"\"\"训练模型(定义见第8章)\"\"\"\n", " loss = 
nn.CrossEntropyLoss()\n", " animator = d2l.Animator(xlabel='epoch', ylabel='perplexity',\n", " legend=['train'], xlim=[10, num_epochs])\n", " # 初始化\n", " if isinstance(net, nn.Module):\n", " updater = torch.optim.SGD(net.parameters(), lr)\n", " else:\n", " updater = lambda batch_size: d2l.sgd(net.params, lr, batch_size)\n", " predict = lambda prefix: predict_ch8(prefix, 50, net, vocab, device)\n", " # 训练和预测\n", " for epoch in range(num_epochs):\n", " ppl, speed = train_epoch_ch8(\n", " net, train_iter, loss, updater, device, use_random_iter)\n", " if (epoch + 1) % 10 == 0:\n", " print(predict('time traveller'))\n", " animator.add(epoch + 1, [ppl])\n", " print(f'困惑度 {ppl:.1f}, {speed:.1f} 词元/秒 {str(device)}')\n", " print(predict('time traveller'))\n", " print(predict('traveller'))" ], "id": "c96b60b55664378a", "outputs": [], "execution_count": 116 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:27.077729796Z", "start_time": "2026-04-22T07:03:27.030591268Z" } }, "cell_type": "code", "source": [ "num_epochs, lr = 500, 1\n", "#train_ch8(net, train_iter, vocab, lr, num_epochs, d2l.try_gpu())" ], "id": "ab4a2fbf4dfd21ef", "outputs": [], "execution_count": 117 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:27.135654266Z", "start_time": "2026-04-22T07:03:27.079991786Z" } }, "cell_type": "code", "source": [ "batch_size, num_steps = 32, 35\n", "train_iter, vocab = load_data_time_machine(batch_size, num_steps)" ], "id": "74d672745751714", "outputs": [], "execution_count": 118 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:27.595345257Z", "start_time": "2026-04-22T07:03:27.137667402Z" } }, "cell_type": "code", "source": [ "num_hiddens = 256\n", "rnn_layer = nn.RNN(len(vocab), num_hiddens)\n", "state = torch.zeros((1,batch_size,num_hiddens))\n", "X = torch.rand(size=(num_steps, batch_size, len(vocab)))\n", "Y, state_new = rnn_layer(X, state)\n", "Y.shape, state_new.shape" ], "id": "9694c029c9e657e8", "outputs": [ { "data": 
{ "text/plain": [ "(torch.Size([35, 32, 256]), torch.Size([1, 32, 256]))" ] }, "execution_count": 119, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 119 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:27.653466458Z", "start_time": "2026-04-22T07:03:27.598468205Z" } }, "cell_type": "code", "source": [ "class RNNModel(nn.Module):\n", " \"\"\"循环神经网络模型\"\"\"\n", " def __init__(self, rnn_layer, vocab_size, **kwargs):\n", " super(RNNModel, self).__init__(**kwargs)\n", " self.rnn = rnn_layer\n", " self.vocab_size = vocab_size\n", " self.num_hiddens = self.rnn.hidden_size\n", " # 如果RNN是双向的(之后将介绍),num_directions应该是2,否则应该是1\n", " if not self.rnn.bidirectional:\n", " self.num_directions = 1\n", " self.linear = nn.Linear(self.num_hiddens, self.vocab_size)\n", " else:\n", " self.num_directions = 2\n", " self.linear = nn.Linear(self.num_hiddens * 2, self.vocab_size)\n", " def forward(self, inputs, state):\n", " X = F.one_hot(inputs.T.long(), self.vocab_size)\n", " X = X.to(torch.float32)\n", " Y, state = self.rnn(X, state)\n", " # 全连接层首先将Y的形状改为(时间步数*批量大小,隐藏单元数)\n", " # 它的输出形状是(时间步数*批量大小,词表大小)。\n", " output = self.linear(Y.reshape((-1, Y.shape[-1])))\n", " return output, state\n", " def begin_state(self, device, batch_size=1):\n", " if not isinstance(self.rnn, nn.LSTM):\n", " # nn.GRU以张量作为隐状态\n", " return torch.zeros((self.num_directions * self.rnn.num_layers,\n", " batch_size, self.num_hiddens),\n", " device=device)\n", " else:\n", " # nn.LSTM以元组作为隐状态\n", " return (torch.zeros((\n", " self.num_directions * self.rnn.num_layers,\n", " batch_size, self.num_hiddens), device=device),\n", " torch.zeros((\n", " self.num_directions * self.rnn.num_layers,\n", " batch_size, self.num_hiddens), device=device))" ], "id": "858873034be01538", "outputs": [], "execution_count": 120 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:27.712481585Z", "start_time": "2026-04-22T07:03:27.663190923Z" } }, "cell_type": "code", "source": [ 
"device = d2l.try_gpu()\n", "net = RNNModel(rnn_layer, vocab_size=len(vocab))\n", "net = net.to(device)\n", "#predict_ch8('time traveller', 10, net, vocab, device)" ], "id": "d59c1599998c8fd4", "outputs": [], "execution_count": 121 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:27.761031426Z", "start_time": "2026-04-22T07:03:27.714738716Z" } }, "cell_type": "code", "source": [ "vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()\n", "num_inputs = vocab_size\n", "gru_layer = nn.GRU(num_inputs, num_hiddens)\n", "model = RNNModel(gru_layer, len(vocab))\n", "model = model.to(device)\n", "#train_ch8(model, train_iter, vocab, lr, num_epochs, device)" ], "id": "adda23bc3664ec6b", "outputs": [], "execution_count": 122 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:27.809310056Z", "start_time": "2026-04-22T07:03:27.763581864Z" } }, "cell_type": "code", "source": [ "num_inputs = vocab_size\n", "lstm_layer = nn.LSTM(num_inputs, num_hiddens)\n", "model = RNNModel(lstm_layer, len(vocab))\n", "model = model.to(device)\n", "#train_ch8(model, train_iter, vocab, lr, num_epochs, device)" ], "id": "b4e30d643d6f755d", "outputs": [], "execution_count": 123 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:27.856847550Z", "start_time": "2026-04-22T07:03:27.811694587Z" } }, "cell_type": "code", "source": [ "d2l.DATA_HUB['fra-eng'] = (d2l.DATA_URL + 'fra-eng.zip',\n", " '94646ad1522d915e7b0f9296181140edcf86a4f5')" ], "id": "50554e839be36011", "outputs": [], "execution_count": 124 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:27.906829883Z", "start_time": "2026-04-22T07:03:27.859224691Z" } }, "cell_type": "code", "source": [ "import os\n", "def read_data_nmt():\n", " \"\"\"载入“英语-法语”数据集\"\"\"\n", " data_dir = d2l.download_extract('fra-eng')\n", " print(data_dir)\n", " with open(os.path.join(data_dir, 'fra.txt'), 'r',\n", " encoding='utf-8') as f:\n", " return f.read()" ], "id": "9cd4287ed84db220", 
"outputs": [], "execution_count": 125 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:27.980494182Z", "start_time": "2026-04-22T07:03:27.909008466Z" } }, "cell_type": "code", "source": [ "raw_text = read_data_nmt()\n", "print(raw_text[:75])" ], "id": "7c4452b3b6a32f91", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "../data/fra-eng\n", "Go.\tVa !\n", "Hi.\tSalut !\n", "Run!\tCours !\n", "Run!\tCourez !\n", "Who?\tQui ?\n", "Wow!\tÇa alors !\n", "\n" ] } ], "execution_count": 126 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:29.415816650Z", "start_time": "2026-04-22T07:03:27.981645004Z" } }, "cell_type": "code", "source": [ "def preprocess_nmt(text):\n", "    \"\"\"Normalize raw English-French text.\n", "\n", "    Replaces U+202F and U+00A0 with plain spaces, lowercases everything, and\n", "    inserts a space before any of , . ! ? that is not already preceded by a space\n", "    (a character at index 0 is never prefixed).\n", "    \"\"\"\n", "    def no_space(char,prev_char):\n", "        return char in set(',.!?') and prev_char != ' '\n", "    text = text.replace('\\u202f',' ').replace('\\xa0',' ').lower()\n", "    out = [' ' + char if i >0 and no_space(char,text[i-1]) else char for i,char in enumerate(text)]\n", "    return ''.join(out)\n", "text = preprocess_nmt(raw_text)\n", "print(text[:80])" ], "id": "1c729da265572287", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "go .\tva !\n", "hi .\tsalut !\n", "run !\tcours !\n", "run !\tcourez !\n", "who ?\tqui ?\n", "wow !\tça alors !\n" ] } ], "execution_count": 127 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:30.160936469Z", "start_time": "2026-04-22T07:03:30.102689660Z" } }, "cell_type": "code", "source": [ "def tokenize_nmt(text,num_examples=None):\n", "    \"\"\"Split preprocessed text into lists of source and target tokens.\n", "\n", "    The text is split into lines, each line on a tab; only lines with exactly\n", "    two tab-separated parts are kept, and each part is split on single spaces.\n", "    If num_examples is truthy, at most num_examples pairs are returned\n", "    (None or 0 means no limit). Returns (source, target), two parallel lists.\n", "    \"\"\"\n", "    source,target = [],[]\n", "    for line in text.split('\\n'):\n", "        # Count kept pairs rather than raw line indices so that exactly\n", "        # num_examples pairs are returned (not num_examples + 1).\n", "        if num_examples and len(source) >= num_examples:\n", "            break\n", "        parts = line.split('\\t')\n", "        if len(parts) == 2:\n", "            source.append(parts[0].split(' '))\n", "            target.append(parts[1].split(' '))\n", "    return source,target\n", "\n" ], "id": "ca16ef22cbe2c02a", "outputs": [], "execution_count": 128 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:30.821851145Z", "start_time": 
"2026-04-22T07:03:30.165447701Z" } }, "cell_type": "code", "source": [ "source, target = tokenize_nmt(text)\n", "source[:5], target[:5]" ], "id": "5ece5cb4b78168d0", "outputs": [ { "data": { "text/plain": [ "([['go', '.'], ['hi', '.'], ['run', '!'], ['run', '!'], ['who', '?']],\n", " [['va', '!'], ['salut', '!'], ['cours', '!'], ['courez', '!'], ['qui', '?']])" ] }, "execution_count": 129, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 129 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:31.001500702Z", "start_time": "2026-04-22T07:03:30.869254698Z" } }, "cell_type": "code", "source": [ "def show_list_len_pair_hist(legend,xlabel,ylabel,xlist,ylist):\n", " d2l.set_figsize()\n", " _,_,patches = d2l.plt.hist([[len(l) for l in xlist],[len(l) for l in ylist]])\n", " d2l.plt.xlabel(xlabel)\n", " d2l.plt.ylabel(ylabel)\n", " for patch in patches[1].patches:\n", " patch.set_hatch('/')\n", " d2l.plt.legend(legend)\n", "\n", "show_list_len_pair_hist(['source','target'],'# tokens per sequence','count',source,target)" ], "id": "518249f852ec54c4", "outputs": [ { "data": { "text/plain": [ "
" ], "image/svg+xml": "\n\n\n \n \n \n \n 2026-04-22T15:03:30.973116\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" }, "metadata": {}, "output_type": "display_data", "jetTransient": { "display_id": null } } ], "execution_count": 130 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:31.126122101Z", "start_time": "2026-04-22T07:03:31.016562308Z" } }, "cell_type": "code", "source": [ "src_vocab=Vocab(source,min_freq=2,reserved_tokens=['','',''])\n", "len(src_vocab)" ], "id": "c2dc82617d5a41a4", "outputs": [ { "data": { "text/plain": [ "10012" ] }, "execution_count": 131, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 131 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:31.181454267Z", "start_time": "2026-04-22T07:03:31.127995134Z" } }, "cell_type": "code", "source": [ "def 
truncate_pad(line,num_steps,padding_token):\n", " if len(line) > num_steps:\n", " return line[:num_steps]\n", " return line + [padding_token] * (num_steps - len(line))\n", "truncate_pad(src_vocab[source[0]], 10, src_vocab[''])" ], "id": "93ae326a3258ecc", "outputs": [ { "data": { "text/plain": [ "[47, 4, 1, 1, 1, 1, 1, 1, 1, 1]" ] }, "execution_count": 132, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 132 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:31.244013357Z", "start_time": "2026-04-22T07:03:31.194631134Z" } }, "cell_type": "code", "source": [ "def build_array_nmt(lines, vocab, num_steps):\n", " \"\"\"将机器翻译的文本序列转换成小批量\"\"\"\n", " lines = [vocab[l] for l in lines]\n", " lines = [l + [vocab['']] for l in lines]\n", " array = torch.tensor([truncate_pad(\n", " l, num_steps, vocab['']) for l in lines])\n", " valid_len = (array != vocab['']).type(torch.int32).sum(1)\n", " return array, valid_len" ], "id": "acd4344e678cf487", "outputs": [], "execution_count": 133 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:31.295907069Z", "start_time": "2026-04-22T07:03:31.246433021Z" } }, "cell_type": "code", "source": [ "def load_data_nmt(batch_size, num_steps, num_examples=600):\n", " \"\"\"返回翻译数据集的迭代器和词表\"\"\"\n", " text = preprocess_nmt(read_data_nmt())\n", " source, target = tokenize_nmt(text, num_examples)\n", " src_vocab = d2l.Vocab(source, min_freq=2,\n", " reserved_tokens=['', '', ''])\n", " tgt_vocab = d2l.Vocab(target, min_freq=2,\n", " reserved_tokens=['', '', ''])\n", " src_array, src_valid_len = build_array_nmt(source, src_vocab, num_steps)\n", " tgt_array, tgt_valid_len = build_array_nmt(target, tgt_vocab, num_steps)\n", " data_arrays = (src_array, src_valid_len, tgt_array, tgt_valid_len)\n", " data_iter = d2l.load_array(data_arrays, batch_size)\n", " return data_iter, src_vocab, tgt_vocab" ], "id": "62586b0175993a4f", "outputs": [], "execution_count": 134 }, { "metadata": { "ExecuteTime": { 
"end_time": "2026-04-22T07:03:32.568811153Z", "start_time": "2026-04-22T07:03:31.298284240Z" } }, "cell_type": "code", "source": [ "train_iter, src_vocab, tgt_vocab = load_data_nmt(batch_size=2, num_steps=8)\n", "for X, X_valid_len, Y, Y_valid_len in train_iter:\n", " print('X:', X.type(torch.int32))\n", " print('X的有效长度:', X_valid_len)\n", " print('Y:', Y.type(torch.int32))\n", " print('Y的有效长度:', Y_valid_len)\n", " break" ], "id": "87a2f147db41a91d", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "../data/fra-eng\n", "X: tensor([[ 83, 163, 2, 4, 5, 5, 5, 5],\n", " [ 29, 69, 2, 4, 5, 5, 5, 5]], dtype=torch.int32)\n", "X的有效长度: tensor([4, 4])\n", "Y: tensor([[100, 171, 6, 2, 4, 5, 5, 5],\n", " [191, 6, 2, 4, 5, 5, 5, 5]], dtype=torch.int32)\n", "Y的有效长度: tensor([5, 4])\n" ] } ], "execution_count": 135 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:32.679822180Z", "start_time": "2026-04-22T07:03:32.616779320Z" } }, "cell_type": "code", "source": [ "class Encoder(nn.Module):\n", " def __init__(self,**kargs):\n", " super(Encoder,self).__init__(**kargs)\n", " def forward(self,X,*args):\n", " raise NotImplementedError(\"必须实现这个方法\")\n", "class Decoder(nn.Module):\n", " \"\"\"编码器-解码器架构的基本解码器接口\"\"\"\n", " def __init__(self, **kwargs):\n", " super(Decoder, self).__init__(**kwargs)\n", " def init_state(self, enc_outputs, *args):\n", " raise NotImplementedError\n", " def forward(self, X, state):\n", " raise NotImplementedError\n", "class EncoderDecoder(nn.Module):\n", " \"\"\"编码器-解码器架构的基类\"\"\"\n", " def __init__(self, encoder, decoder, **kwargs):\n", " super(EncoderDecoder, self).__init__(**kwargs)\n", " self.encoder = encoder\n", " self.decoder = decoder\n", " def forward(self, enc_X, dec_X, *args):\n", " enc_outputs = self.encoder(enc_X, *args)\n", " dec_state = self.decoder.init_state(enc_outputs, *args)\n", " return self.decoder(dec_X, dec_state)\n", "class Seq2SeqEncoder(Encoder):\n", " def 
__init__(self,vocab_size,embed_size,num_hiddens,num_layers,dropout=0,**kwargs):\n", " super(Seq2SeqEncoder, self).__init__(**kwargs)\n", " self.embedding = nn.Embedding(vocab_size, embed_size)\n", " self.rnn = nn.GRU(embed_size,num_hiddens,num_layers,dropout=dropout)\n", " def forward(self,X,*args):\n", " X = self.embedding(X)\n", " X = X.permute(1,0,2) #(batch,steps,embed_size) -> (steps,batch,embed_size)\n", " output,state = self.rnn(X)\n", " return output,state\n", " # shape of output (steps,batch,num_hiddens)\n", " # shape of state (num_layers,batch,num_hiddens)\n", "encoder = Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16,\n", " num_layers=2)\n", "X = torch.zeros((4,7),dtype=torch.long) #one-hot is integer\n", "output,state = encoder(X)\n", "output.shape\n" ], "id": "d0d01aef4857ee9c", "outputs": [ { "data": { "text/plain": [ "torch.Size([7, 4, 16])" ] }, "execution_count": 136, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 136 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:32.745403173Z", "start_time": "2026-04-22T07:03:32.681269832Z" } }, "cell_type": "code", "source": "state.shape", "id": "bba15a040c10cb01", "outputs": [ { "data": { "text/plain": [ "torch.Size([2, 4, 16])" ] }, "execution_count": 137, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 137 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:32.800607701Z", "start_time": "2026-04-22T07:03:32.749059569Z" } }, "cell_type": "code", "source": [ "class Seq2SeqDecoder(Decoder):\n", " def __init__(self,vocab_size,embed_size,num_hiddens,num_layers,dropout=0,**kwargs):\n", " super(Seq2SeqDecoder, self).__init__(**kwargs)\n", " self.embedding = nn.Embedding(vocab_size, embed_size)\n", " self.rnn = nn.GRU(embed_size+num_hiddens,num_hiddens,num_layers,dropout=dropout)\n", " self.dense = nn.Linear(num_hiddens,vocab_size)\n", " def init_state(self,enc_outputs,*args):\n", " return enc_outputs[1]\n", " def 
forward(self,X,state):\n", " X = self.embedding(X).permute(1,0,2)\n", " context = state[-1].repeat(X.shape[0],1,1)\n", " X_and_context = torch.cat((X,context),2)\n", " output,state = self.rnn(X_and_context,state)\n", " output = self.dense(output).permute(1,0,2)\n", " return output,state" ], "id": "b659bfd2fdcabebe", "outputs": [], "execution_count": 138 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:32.863556712Z", "start_time": "2026-04-22T07:03:32.801449588Z" } }, "cell_type": "code", "source": [ "decoder = Seq2SeqDecoder(vocab_size=10, embed_size=8, num_hiddens=16,\n", "num_layers=2)\n", "decoder.eval()\n", "state = decoder.init_state(encoder(X))\n", "output, state = decoder(X, state)\n", "output.shape, state.shape" ], "id": "e9c451e560ce3769", "outputs": [ { "data": { "text/plain": [ "(torch.Size([4, 7, 10]), torch.Size([2, 4, 16]))" ] }, "execution_count": 139, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 139 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:32.926789294Z", "start_time": "2026-04-22T07:03:32.864974421Z" } }, "cell_type": "code", "source": [ "def sequence_mask(X, valid_len, value=0):\n", " \"\"\"在序列中屏蔽不相关的项\"\"\"\n", " maxlen = X.size(1)\n", " mask = torch.arange((maxlen), dtype=torch.float32,\n", " device=X.device)[None, :] < valid_len[:, None]\n", " X[~mask] = value\n", " return X\n", "X = torch.tensor([[1, 2, 3], [4, 5, 6]])\n", "sequence_mask(X, torch.tensor([1, 2]))" ], "id": "9ee2db877c089d48", "outputs": [ { "data": { "text/plain": [ "tensor([[1, 0, 0],\n", " [4, 5, 0]])" ] }, "execution_count": 140, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 140 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:33.055040504Z", "start_time": "2026-04-22T07:03:32.985613772Z" } }, "cell_type": "code", "source": [ "X = torch.ones(2, 3, 4)\n", "sequence_mask(X, torch.tensor([1, 2]), value=-1)" ], "id": "caca450bdfe7f650", "outputs": [ { "data": { 
"text/plain": [ "tensor([[[ 1., 1., 1., 1.],\n", " [-1., -1., -1., -1.],\n", " [-1., -1., -1., -1.]],\n", "\n", " [[ 1., 1., 1., 1.],\n", " [ 1., 1., 1., 1.],\n", " [-1., -1., -1., -1.]]])" ] }, "execution_count": 141, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 141 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:03:33.183816304Z", "start_time": "2026-04-22T07:03:33.114715644Z" } }, "cell_type": "code", "source": [ "class MaskedSoftmaxCELoss(nn.CrossEntropyLoss):\n", " \"\"\"带遮蔽的softmax交叉熵损失函数\"\"\"\n", " # pred的形状:(batch_size,num_steps,vocab_size)\n", " # label的形状:(batch_size,num_steps)\n", " # valid_len的形状:(batch_size,)\n", " def forward(self, pred, label, valid_len):\n", " weights = torch.ones_like(label)\n", " weights = sequence_mask(weights, valid_len)\n", " self.reduction='none'\n", " unweighted_loss = super(MaskedSoftmaxCELoss, self).forward(\n", " pred.permute(0, 2, 1), label)\n", " weighted_loss = (unweighted_loss * weights).mean(dim=1)\n", " return weighted_loss\n", "loss = MaskedSoftmaxCELoss()\n", "loss(torch.ones(3, 4, 10), torch.ones((3, 4), dtype=torch.long),\n", "torch.tensor([4, 2, 0]))" ], "id": "46fc96f0246f32b7", "outputs": [ { "data": { "text/plain": [ "tensor([2.3026, 1.1513, 0.0000])" ] }, "execution_count": 142, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 142 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:11:21.961865073Z", "start_time": "2026-04-22T07:11:21.895381436Z" } }, "cell_type": "code", "source": [ "def train_seq2seq(net, data_iter, lr, num_epochs, tgt_vocab, device):\n", " \"\"\"训练序列到序列模型\"\"\"\n", " def xavier_init_weights(m):\n", " if type(m) == nn.Linear:\n", " nn.init.xavier_uniform_(m.weight)\n", " if type(m) == nn.GRU:\n", " for param in m._flat_weights_names:\n", " if \"weight\" in param:\n", " nn.init.xavier_uniform_(m._parameters[param])\n", " net.apply(xavier_init_weights)\n", " net.to(device)\n", " optimizer = 
torch.optim.Adam(net.parameters(), lr=lr)\n", " loss = MaskedSoftmaxCELoss()\n", " net.train()\n", " animator = d2l.Animator(xlabel='epoch', ylabel='loss',\n", " xlim=[10, num_epochs])\n", " for epoch in range(num_epochs):\n", " timer = d2l.Timer()\n", " metric = d2l.Accumulator(2) # 训练损失总和,词元数量\n", " for batch in data_iter:\n", " optimizer.zero_grad()\n", " X, X_valid_len, Y, Y_valid_len = [x.to(device) for x in batch]\n", " bos = torch.tensor([tgt_vocab['']] * Y.shape[0],\n", " device=device).reshape(-1, 1)\n", " dec_input = torch.cat([bos, Y[:, :-1]], 1) # 强制教学\n", " Y_hat, _ = net(X, dec_input, X_valid_len)\n", " l = loss(Y_hat, Y, Y_valid_len)\n", " l.sum().backward() # 损失函数的标量进行“反向传播”\n", " d2l.grad_clipping(net, 1)\n", " num_tokens = Y_valid_len.sum()\n", " optimizer.step()\n", " with torch.no_grad():\n", " metric.add(l.sum(), num_tokens)\n", " if (epoch + 1) % 10 == 0:\n", " animator.add(epoch + 1, (metric[0] / metric[1],))" ], "id": "69c315b5875fc288", "outputs": [], "execution_count": 153 }, { "metadata": { "ExecuteTime": { "end_time": "2026-04-22T07:13:26.866073447Z", "start_time": "2026-04-22T07:11:51.132170814Z" } }, "cell_type": "code", "source": [ "embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1\n", "batch_size, num_steps = 64, 10\n", "lr, num_epochs, device = 0.005, 300, d2l.try_gpu()\n", "train_iter, src_vocab, tgt_vocab = load_data_nmt(batch_size, num_steps)\n", "encoder = Seq2SeqEncoder(len(src_vocab), embed_size, num_hiddens, num_layers,\n", "dropout)\n", "decoder = Seq2SeqDecoder(len(tgt_vocab), embed_size, num_hiddens, num_layers,\n", "dropout)\n", "net = EncoderDecoder(encoder, decoder)\n", "train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)" ], "id": "58e54d7b6b77205d", "outputs": [ { "data": { "text/plain": [ "
" ], "image/svg+xml": "\n\n\n \n \n \n \n 2026-04-22T15:13:26.821924\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" }, "metadata": {}, "output_type": "display_data", "jetTransient": { "display_id": null } } ], "execution_count": 155 }, { "metadata": {}, "cell_type": "code", "outputs": [], "execution_count": null, "source": "", "id": "15f5b277bf8d51ed" } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.6" } }, "nbformat": 4, "nbformat_minor": 5 }