nn/chapter5.ipynb

4667 lines
455 KiB
Text
Raw Normal View History

{
"cells": [
{
"cell_type": "code",
"id": "initial_id",
"metadata": {
"collapsed": true,
"ExecuteTime": {
2026-03-25 15:07:28 +00:00
"end_time": "2026-03-25T12:53:19.596857715Z",
"start_time": "2026-03-25T12:53:16.588300896Z"
}
},
"source": [
2026-03-25 15:07:28 +00:00
"\n",
"import torch\n",
"import d2l\n",
"import numpy\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F"
],
"outputs": [],
"execution_count": 1
},
{
"metadata": {
"ExecuteTime": {
2026-03-25 15:07:28 +00:00
"end_time": "2026-03-25T12:53:19.882978594Z",
"start_time": "2026-03-25T12:53:19.641379604Z"
}
},
"cell_type": "code",
"source": [
"# Two-layer MLP built with the Sequential container: 20 -> 256 -> ReLU -> 10.\n",
"net = nn.Sequential(\n",
"    nn.Linear(20, 256),\n",
"    nn.ReLU(),\n",
"    nn.Linear(256, 10),\n",
")\n",
"# A random minibatch: 2 examples with 20 features each.\n",
"X = torch.rand(size=(2, 20))\n",
"net(X)"
],
"id": "dcd5590e7795eec1",
"outputs": [
{
"data": {
"text/plain": [
2026-03-25 15:07:28 +00:00
"tensor([[-0.0824, 0.0285, 0.1192, 0.0922, 0.0465, 0.2007, -0.0262, 0.1639,\n",
" -0.0899, 0.1057],\n",
" [-0.0524, 0.0180, 0.0952, 0.0921, -0.0702, 0.2043, 0.0393, 0.0629,\n",
" -0.1250, 0.0537]], grad_fn=<AddmmBackward0>)"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 2
},
{
"metadata": {
"ExecuteTime": {
2026-03-25 15:07:28 +00:00
"end_time": "2026-03-25T12:53:22.230024033Z",
"start_time": "2026-03-25T12:53:21.253445153Z"
}
},
"cell_type": "code",
"source": [
"class MLP(nn.Module):\n",
"    \"\"\"Custom MLP block: 20 inputs -> 256 hidden units (ReLU) -> 10 outputs.\"\"\"\n",
"\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.hidden = nn.Linear(20, 256)  # hidden layer\n",
"        self.out = nn.Linear(256, 10)  # output layer\n",
"\n",
"    def forward(self, X):\n",
"        \"\"\"Hidden layer, then ReLU, then the output layer.\"\"\"\n",
"        hidden_act = F.relu(self.hidden(X))\n",
"        return self.out(hidden_act)\n"
],
"id": "4ae330604b643cb4",
"outputs": [],
"execution_count": 3
},
{
"metadata": {
"ExecuteTime": {
2026-03-25 15:07:28 +00:00
"end_time": "2026-03-25T12:53:22.637036131Z",
"start_time": "2026-03-25T12:53:22.314296739Z"
}
},
"cell_type": "code",
"source": [
"net = MLP()  # instance of the custom block defined above\n",
"net(X)"
],
"id": "cca55c6c0c7da12f",
"outputs": [
{
"data": {
"text/plain": [
2026-03-25 15:07:28 +00:00
"tensor([[-0.1096, 0.0395, 0.1076, 0.0112, 0.1523, 0.0678, -0.4146, 0.1690,\n",
" 0.0085, -0.0510],\n",
" [-0.0863, 0.0353, 0.0677, -0.0226, 0.1161, 0.0591, -0.3184, 0.1216,\n",
" -0.0316, -0.1315]], grad_fn=<AddmmBackward0>)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 4
},
{
"metadata": {
"ExecuteTime": {
2026-03-25 15:07:28 +00:00
"end_time": "2026-03-25T12:53:23.093212653Z",
"start_time": "2026-03-25T12:53:22.726300762Z"
}
},
"cell_type": "code",
"source": [
"class FixedHiddenMLP(nn.Module):\n",
"    \"\"\"Block mixing a constant weight, a reused layer and Python control flow.\n",
"\n",
"    forward: linear -> relu(X @ rand_weight + 1) -> the same linear again ->\n",
"    halve while the L1 norm exceeds 1 -> sum to a scalar.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        # Random weight that never receives gradients, so it stays fixed during\n",
"        # training. Registered as a non-persistent buffer so .to()/.cuda()/.double()\n",
"        # move and cast it with the module, while state_dict() is unchanged\n",
"        # (it still holds only linear.weight and linear.bias).\n",
"        self.register_buffer('rand_weight', torch.rand((20, 20)), persistent=False)\n",
"        self.linear = nn.Linear(20, 20)\n",
"\n",
"    def forward(self, X):\n",
"        X = self.linear(X)\n",
"        # Use the constant weight together with relu and mm.\n",
"        X = F.relu(torch.mm(X, self.rand_weight) + 1)\n",
"        # Reuse the fully connected layer: both applications share parameters.\n",
"        X = self.linear(X)\n",
"        # Control flow: halve until the L1 norm is at most 1.\n",
"        while X.abs().sum() > 1:\n",
"            X /= 2\n",
"        return X.sum()"
],
"id": "4518d62611d5e749",
"outputs": [],
"execution_count": 5
},
{
"metadata": {
"ExecuteTime": {
2026-03-25 15:07:28 +00:00
"end_time": "2026-03-25T12:53:23.496786702Z",
"start_time": "2026-03-25T12:53:23.216055780Z"
}
},
"cell_type": "code",
"source": [
"net = FixedHiddenMLP()  # forward returns a scalar (sum of the final activations)\n",
"net(X)"
],
"id": "fae0187ece4ed5c6",
"outputs": [
{
"data": {
"text/plain": [
2026-03-25 15:07:28 +00:00
"tensor(0.2039, grad_fn=<SumBackward0>)"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 6
},
{
"metadata": {
"ExecuteTime": {
2026-03-25 15:07:28 +00:00
"end_time": "2026-03-25T12:53:23.718703932Z",
"start_time": "2026-03-25T12:53:23.576457566Z"
}
},
"cell_type": "code",
"source": [
"class NestMLP(nn.Module):\n",
"    \"\"\"Nested block: a Sequential MLP (20 -> 64 -> 32, ReLU) followed by Linear(32, 16).\"\"\"\n",
"\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.net = nn.Sequential(nn.Linear(20, 64), nn.ReLU(),\n",
"                                 nn.Linear(64, 32), nn.ReLU())\n",
"        self.linear = nn.Linear(32, 16)\n",
"\n",
"    def forward(self, X):\n",
"        return self.linear(self.net(X))\n",
"\n",
"\n",
"# Mix and match blocks. These lines live at top level so the combined network is\n",
"# actually built and evaluated (inside forward, after the return, they never run).\n",
"chimera = nn.Sequential(NestMLP(), nn.Linear(16, 20), FixedHiddenMLP())\n",
"chimera(X)"
],
"id": "407ef13a86453aae",
"outputs": [],
"execution_count": 7
},
{
"metadata": {
"ExecuteTime": {
2026-03-25 15:07:28 +00:00
"end_time": "2026-03-25T12:53:24.088583749Z",
"start_time": "2026-03-25T12:53:23.724853929Z"
}
},
"cell_type": "code",
"source": [
"# Single-hidden-layer MLP used for the parameter-access examples below.\n",
"net = nn.Sequential(\n",
"    nn.Linear(4, 8),\n",
"    nn.ReLU(),\n",
"    nn.Linear(8, 1),\n",
")\n",
"X = torch.rand(size=(2, 4))\n",
"net(X)"
],
"id": "9f3526f263c7a249",
"outputs": [
{
"data": {
"text/plain": [
2026-03-25 15:07:28 +00:00
"tensor([[0.3055],\n",
" [0.0396]], grad_fn=<AddmmBackward0>)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 8
},
{
"metadata": {
"ExecuteTime": {
2026-03-25 15:07:28 +00:00
"end_time": "2026-03-25T12:53:24.533854996Z",
"start_time": "2026-03-25T12:53:24.227646164Z"
}
},
"cell_type": "code",
"source": [
"# net[2] is the output Linear(8, 1) layer; its state_dict holds 'weight' and 'bias'.\n",
"print(net[2].state_dict())"
],
"id": "8c73f8daa02ba28b",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2026-03-25 15:07:28 +00:00
"OrderedDict([('weight', tensor([[-0.0619, -0.2581, -0.0887, 0.1497, 0.3016, 0.0745, 0.3351, -0.2275]])), ('bias', tensor([0.1878]))])\n"
]
}
],
"execution_count": 9
},
{
"metadata": {
"ExecuteTime": {
2026-03-25 15:07:28 +00:00
"end_time": "2026-03-25T12:53:24.907381106Z",
"start_time": "2026-03-25T12:53:24.595749565Z"
}
},
"cell_type": "code",
"source": [
"# Same state_dict, shown via the notebook's rich display of the last expression.\n",
"net[2].state_dict()"
],
"id": "b6fee6b64fb96e3c",
"outputs": [
{
"data": {
"text/plain": [
"OrderedDict([('weight',\n",
2026-03-25 15:07:28 +00:00
" tensor([[-0.0619, -0.2581, -0.0887, 0.1497, 0.3016, 0.0745, 0.3351, -0.2275]])),\n",
" ('bias', tensor([0.1878]))])"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 10
},
{
"metadata": {
"ExecuteTime": {
2026-03-25 15:07:28 +00:00
"end_time": "2026-03-25T12:53:25.145444931Z",
"start_time": "2026-03-25T12:53:24.912612304Z"
}
},
"cell_type": "code",
"source": [
"# Layer parameters are nn.Parameter objects.\n",
"print(type(net[2].bias))"
],
"id": "b38e8dc384e038c5",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'torch.nn.parameter.Parameter'>\n"
]
}
],
"execution_count": 11
},
{
"metadata": {
"ExecuteTime": {
2026-03-25 15:07:28 +00:00
"end_time": "2026-03-25T12:53:25.261894811Z",
"start_time": "2026-03-25T12:53:25.163843129Z"
}
},
"cell_type": "code",
"source": [
"# First the Parameter wrapper (shows requires_grad), then its plain value tensor.\n",
"print(net[2].bias, net[2].bias.data, sep='\\n')\n"
],
"id": "73f12ca3669d9ede",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Parameter containing:\n",
2026-03-25 15:07:28 +00:00
"tensor([0.1878], requires_grad=True)\n",
"tensor([0.1878])\n"
]
}
],
"execution_count": 12
},
{
"metadata": {
"ExecuteTime": {
2026-03-25 15:07:28 +00:00
"end_time": "2026-03-25T12:53:25.341935137Z",
"start_time": "2026-03-25T12:53:25.264357977Z"
}
},
"cell_type": "code",
"source": [
"# No backward pass has run yet, so no gradient has been stored.\n",
"net[2].weight.grad is None"
],
"id": "db0fe33018c16fac",
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 13
},
{
"metadata": {
"ExecuteTime": {
2026-03-25 15:07:28 +00:00
"end_time": "2026-03-25T12:53:25.433915822Z",
"start_time": "2026-03-25T12:53:25.357825225Z"
}
},
"cell_type": "code",
"source": [
"# Parameters of the first layer only, then of the whole network\n",
"# (whole-network names are prefixed with the layer index).\n",
"print(*((name, p.shape) for name, p in net[0].named_parameters()))\n",
"print(*((name, p.shape) for name, p in net.named_parameters()))"
],
"id": "75847a1c608ee5c7",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"('weight', torch.Size([8, 4])) ('bias', torch.Size([8]))\n",
"('0.weight', torch.Size([8, 4])) ('0.bias', torch.Size([8])) ('2.weight', torch.Size([1, 8])) ('2.bias', torch.Size([1]))\n"
]
}
],
"execution_count": 14
},
{
"metadata": {
"ExecuteTime": {
2026-03-25 15:07:28 +00:00
"end_time": "2026-03-25T12:53:25.543917851Z",
"start_time": "2026-03-25T12:53:25.460879914Z"
}
},
"cell_type": "code",
"source": [
"# Parameters can also be looked up by their qualified name in the state_dict.\n",
"net.state_dict()['2.bias'].data"
],
"id": "cc74913e8742da7d",
"outputs": [
{
"data": {
"text/plain": [
2026-03-25 15:07:28 +00:00
"tensor([0.1878])"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 15
},
{
"metadata": {
"ExecuteTime": {
2026-03-25 15:07:28 +00:00
"end_time": "2026-03-25T12:53:25.617010198Z",
"start_time": "2026-03-25T12:53:25.559343957Z"
}
},
"cell_type": "code",
"source": [
"def block1():\n",
"    \"\"\"Return a small MLP block: Linear(4, 8) -> ReLU -> Linear(8, 4) -> ReLU.\"\"\"\n",
"    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 4), nn.ReLU())\n",
"\n",
"\n",
"def block2(num_blocks=4):\n",
"    \"\"\"Return a Sequential of num_blocks block1() instances named 'block0', 'block1', ...\n",
"\n",
"    num_blocks defaults to 4, the count used in this notebook.\n",
"    \"\"\"\n",
"    net = nn.Sequential()\n",
"    for i in range(num_blocks):\n",
"        # add_module lets us pick readable names instead of integer indices.\n",
"        net.add_module(f'block{i}', block1())\n",
"    return net"
],
"id": "53c39c5e61fa7bf5",
"outputs": [],
"execution_count": 16
},
{
"metadata": {
"ExecuteTime": {
2026-03-25 15:07:28 +00:00
"end_time": "2026-03-25T12:53:25.713767703Z",
"start_time": "2026-03-25T12:53:25.621699911Z"
}
},
"cell_type": "code",
"source": [
"# Nest the 4-block network and add a scalar output layer.\n",
"rgnet = nn.Sequential(block2(), nn.Linear(4, 1))\n",
"rgnet(X)"
],
"id": "d3ac7759b619aca",
"outputs": [
{
"data": {
"text/plain": [
2026-03-25 15:07:28 +00:00
"tensor([[-0.3406],\n",
" [-0.3406]], grad_fn=<AddmmBackward0>)"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 17
},
{
"metadata": {
"ExecuteTime": {
2026-03-25 15:07:28 +00:00
"end_time": "2026-03-25T12:53:26.096212878Z",
"start_time": "2026-03-25T12:53:25.758161035Z"
}
},
"cell_type": "code",
"source": [
"# Print the nested module hierarchy.\n",
"print(rgnet)"
],
"id": "8fc60f64b07781e6",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sequential(\n",
" (0): Sequential(\n",
" (block0): Sequential(\n",
" (0): Linear(in_features=4, out_features=8, bias=True)\n",
" (1): ReLU()\n",
" (2): Linear(in_features=8, out_features=4, bias=True)\n",
" (3): ReLU()\n",
" )\n",
" (block1): Sequential(\n",
" (0): Linear(in_features=4, out_features=8, bias=True)\n",
" (1): ReLU()\n",
" (2): Linear(in_features=8, out_features=4, bias=True)\n",
" (3): ReLU()\n",
" )\n",
" (block2): Sequential(\n",
" (0): Linear(in_features=4, out_features=8, bias=True)\n",
" (1): ReLU()\n",
" (2): Linear(in_features=8, out_features=4, bias=True)\n",
" (3): ReLU()\n",
" )\n",
" (block3): Sequential(\n",
" (0): Linear(in_features=4, out_features=8, bias=True)\n",
" (1): ReLU()\n",
" (2): Linear(in_features=8, out_features=4, bias=True)\n",
" (3): ReLU()\n",
" )\n",
" )\n",
" (1): Linear(in_features=4, out_features=1, bias=True)\n",
")\n"
]
}
],
"execution_count": 18
},
{
"metadata": {
"ExecuteTime": {
2026-03-25 15:07:28 +00:00
"end_time": "2026-03-25T12:53:26.459785465Z",
"start_time": "2026-03-25T12:53:26.247775930Z"
}
},
"cell_type": "code",
"source": [
"# Bias of the first Linear layer inside sub-block 'block1' of the nested network.\n",
"rgnet[0][1][0].bias.data"
],
"id": "e590aaafca787b50",
"outputs": [
{
"data": {
"text/plain": [
2026-03-25 15:07:28 +00:00
"tensor([ 0.3709, -0.2778, -0.1532, -0.4749, 0.4300, -0.0282, -0.0499, 0.3819])"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 19
},
{
"metadata": {
"ExecuteTime": {
2026-03-25 15:07:28 +00:00
"end_time": "2026-03-25T12:53:26.559898609Z",
"start_time": "2026-03-25T12:53:26.465566578Z"
}
},
"cell_type": "code",
"source": [
"def init_normal(m):\n",
"    \"\"\"Module.apply hook: N(0, 0.01^2) weights and zero bias for nn.Linear layers.\"\"\"\n",
"    if type(m) is nn.Linear:\n",
"        nn.init.normal_(m.weight, mean=0, std=0.01)\n",
"        # Linear layers built with bias=False have m.bias set to None.\n",
"        if m.bias is not None:\n",
"            nn.init.zeros_(m.bias)\n",
"\n",
"\n",
"net.apply(init_normal)\n",
"net[0].weight.data[0], net[0].bias.data[0]"
],
"id": "925ca33221d0a87e",
"outputs": [
{
"data": {
"text/plain": [
2026-03-25 15:07:28 +00:00
"(tensor([-0.0090, 0.0195, 0.0008, 0.0062]), tensor(0.))"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 20
},
{
"metadata": {
"ExecuteTime": {
2026-03-25 15:07:28 +00:00
"end_time": "2026-03-25T12:53:26.624188061Z",
"start_time": "2026-03-25T12:53:26.561739279Z"
}
},
"cell_type": "code",
"source": [
"def init_xavier(m):\n",
"    \"\"\"Module.apply hook: Xavier-uniform init of the weights of nn.Linear layers.\"\"\"\n",
"    if type(m) is not nn.Linear:\n",
"        return\n",
"    nn.init.xavier_uniform_(m.weight)\n",
"\n",
"\n",
"def init_42(m):\n",
"    \"\"\"Module.apply hook: set every weight of nn.Linear layers to the constant 42.\"\"\"\n",
"    if type(m) is not nn.Linear:\n",
"        return\n",
"    nn.init.constant_(m.weight, 42)\n",
"\n",
"\n",
"# Different initializers for different layers.\n",
"net[0].apply(init_xavier)\n",
"net[2].apply(init_42)\n",
"print(net[0].weight.data[0])\n",
"print(net[2].weight.data)"
],
"id": "81e2de84a8c4ef32",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2026-03-25 15:07:28 +00:00
"tensor([-0.0184, 0.4366, -0.5272, 0.1226])\n",
"tensor([[42., 42., 42., 42., 42., 42., 42., 42.]])\n"
]
}
],
"execution_count": 21
},
{
"metadata": {
"ExecuteTime": {
2026-03-25 15:07:28 +00:00
"end_time": "2026-03-25T12:53:26.720471236Z",
"start_time": "2026-03-25T12:53:26.641527865Z"
}
},
"cell_type": "code",
"source": [
"# Save a single tensor to a file in the working directory.\n",
"x = torch.arange(4)\n",
"torch.save(x, 'x-file')"
],
"id": "f05bb378bb60ab9e",
"outputs": [],
"execution_count": 22
},
{
"metadata": {
"ExecuteTime": {
2026-03-25 15:07:28 +00:00
"end_time": "2026-03-25T12:53:26.986419396Z",
"start_time": "2026-03-25T12:53:26.727284472Z"
}
},
"cell_type": "code",
"source": [
"# Read the tensor back from disk.\n",
"x2 = torch.load('x-file')\n",
"x2"
],
"id": "a74ecaaac0d826c6",
"outputs": [
{
"data": {
"text/plain": [
"tensor([0, 1, 2, 3])"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 23
},
{
"metadata": {
"ExecuteTime": {
2026-03-25 15:07:28 +00:00
"end_time": "2026-03-25T12:53:27.038073696Z",
"start_time": "2026-03-25T12:53:26.998499395Z"
}
},
"cell_type": "code",
"source": [
"# NOTE: redefines MLP (shadowing the earlier class of the same name); this version\n",
"# names its output layer 'output', which is what the saved state_dict keys use.\n",
"class MLP(nn.Module):\n",
"    \"\"\"20 -> 256 (ReLU) -> 10 MLP whose parameters are saved and reloaded below.\"\"\"\n",
"\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.hidden = nn.Linear(20, 256)\n",
"        self.output = nn.Linear(256, 10)\n",
"\n",
"    def forward(self, x):\n",
"        h = F.relu(self.hidden(x))\n",
"        return self.output(h)\n",
"\n",
"\n",
"net = MLP()\n",
"X = torch.randn(size=(2, 20))\n",
"Y = net(X)"
],
"id": "b42598f0c4a8e801",
"outputs": [],
"execution_count": 24
},
{
"metadata": {
"ExecuteTime": {
2026-03-25 15:07:28 +00:00
"end_time": "2026-03-25T12:53:27.097281119Z",
"start_time": "2026-03-25T12:53:27.040823019Z"
}
},
"cell_type": "code",
"source": "torch.save(net.state_dict(), 'mlp.params')  # save the parameters only, not the model object",
"id": "aaa22eef549caa6f",
"outputs": [],
"execution_count": 25
},
{
"metadata": {
"ExecuteTime": {
2026-03-25 15:07:28 +00:00
"end_time": "2026-03-25T12:53:27.229604563Z",
"start_time": "2026-03-25T12:53:27.100431141Z"
}
},
"cell_type": "code",
"source": [
"# Recreate the architecture, then load the saved parameters into it.\n",
"clone = MLP()\n",
"clone.load_state_dict(torch.load('mlp.params'))\n",
"clone.eval()"
],
"id": "b92f920229abeeae",
"outputs": [
{
"data": {
"text/plain": [
"MLP(\n",
" (hidden): Linear(in_features=20, out_features=256, bias=True)\n",
" (output): Linear(in_features=256, out_features=10, bias=True)\n",
")"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 26
},
{
"metadata": {
"ExecuteTime": {
2026-03-25 15:07:28 +00:00
"end_time": "2026-03-25T12:53:27.413849570Z",
"start_time": "2026-03-25T12:53:27.245765495Z"
}
},
"cell_type": "code",
"source": [
"# The clone holds identical parameters, so it reproduces the original outputs.\n",
"Y_clone = clone(X)\n",
"Y_clone == Y"
],
"id": "646c9eb6d7cc81c2",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[True, True, True, True, True, True, True, True, True, True],\n",
" [True, True, True, True, True, True, True, True, True, True]])"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 27
},
{
"metadata": {
"ExecuteTime": {
2026-03-25 15:07:28 +00:00
"end_time": "2026-03-25T12:53:27.819912529Z",
"start_time": "2026-03-25T12:53:27.491059110Z"
}
},
"cell_type": "code",
"source": [
"def corr2d(X, K):\n",
"    \"\"\"2-D cross-correlation of a 2-D input X with a 2-D kernel K (no padding, stride 1).\n",
"\n",
"    Returns a tensor of shape (X.shape[0] - h + 1, X.shape[1] - w + 1) with\n",
"    (h, w) = K.shape, allocated on X's device. Its dtype is the promotion of X, K\n",
"    and the default float dtype, so float64 inputs keep float64 precision while\n",
"    integer inputs still produce the default float dtype.\n",
"    \"\"\"\n",
"    h, w = K.shape\n",
"    out_dtype = torch.promote_types(torch.promote_types(X.dtype, K.dtype),\n",
"                                    torch.get_default_dtype())\n",
"    Y = torch.zeros((X.shape[0] - h + 1, X.shape[1] - w + 1),\n",
"                    dtype=out_dtype, device=X.device)\n",
"    for i in range(Y.shape[0]):\n",
"        for j in range(Y.shape[1]):\n",
"            # Elementwise product of the window with the kernel, then summed.\n",
"            Y[i, j] = (X[i:i + h, j:j + w] * K).sum()\n",
"    return Y\n"
],
"id": "d45f9adfe47fce20",
"outputs": [],
"execution_count": 28
},
{
"metadata": {
"ExecuteTime": {
2026-03-25 15:07:28 +00:00
"end_time": "2026-03-25T12:53:28.107369770Z",
"start_time": "2026-03-25T12:53:27.908345229Z"
}
},
"cell_type": "code",
"source": [
"# Worked example: a 3x3 input correlated with a 2x2 kernel gives a 2x2 output.\n",
"X = torch.tensor([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]])\n",
"K = torch.tensor([[0.0, 1.0], [2.0, 3.0]])\n",
"corr2d(X, K)"
],
"id": "db7279e13647c315",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[19., 25.],\n",
" [37., 43.]])"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 29
},
{
"metadata": {
"ExecuteTime": {
2026-03-25 15:07:28 +00:00
"end_time": "2026-03-25T12:53:28.483070674Z",
"start_time": "2026-03-25T12:53:28.179399310Z"
}
},
"cell_type": "code",
"source": [
"class Conv2D(nn.Module):\n",
"    \"\"\"Minimal 2-D convolution layer built on corr2d: a learnable kernel plus a scalar bias.\"\"\"\n",
"\n",
"    def __init__(self, kernel_size):\n",
"        super().__init__()\n",
"        # Kernel drawn uniformly from [0, 1); corr2d unpacks its shape as (h, w).\n",
"        self.weight = nn.Parameter(torch.rand(kernel_size))\n",
"        self.bias = nn.Parameter(torch.zeros(1))\n",
"\n",
"    def forward(self, x):\n",
"        response = corr2d(x, self.weight)\n",
"        return response + self.bias\n"
],
"id": "d60be1bd12a1f37e",
"outputs": [],
"execution_count": 30
},
{
"metadata": {
"ExecuteTime": {
2026-03-25 15:07:28 +00:00
"end_time": "2026-03-25T12:53:28.940265905Z",
"start_time": "2026-03-25T12:53:28.769072795Z"
}
},
"cell_type": "code",
"source": [
"# 6x8 image: columns 2..5 are 0, all other pixels are 1.\n",
"X = torch.ones((6, 8))\n",
"X[:, 2:6] = 0\n",
"X"
],
"id": "5083789b7a728442",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[1., 1., 0., 0., 0., 0., 1., 1.],\n",
" [1., 1., 0., 0., 0., 0., 1., 1.],\n",
" [1., 1., 0., 0., 0., 0., 1., 1.],\n",
" [1., 1., 0., 0., 0., 0., 1., 1.],\n",
" [1., 1., 0., 0., 0., 0., 1., 1.],\n",
" [1., 1., 0., 0., 0., 0., 1., 1.]])"
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 31
},
{
"metadata": {
"ExecuteTime": {
2026-03-25 15:07:28 +00:00
"end_time": "2026-03-25T12:53:29.204080403Z",
"start_time": "2026-03-25T12:53:29.030462580Z"
}
},
"cell_type": "code",
"source": [
"# Edge-detection kernel: nonzero output only where horizontally adjacent pixels differ.\n",
"K = torch.tensor([[1.0, -1.0]])\n",
"Y = corr2d(X, K)\n",
"Y"
],
"id": "ee8d6bedbde886ad",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 0., 1., 0., 0., 0., -1., 0.],\n",
" [ 0., 1., 0., 0., 0., -1., 0.],\n",
" [ 0., 1., 0., 0., 0., -1., 0.],\n",
" [ 0., 1., 0., 0., 0., -1., 0.],\n",
" [ 0., 1., 0., 0., 0., -1., 0.],\n",
" [ 0., 1., 0., 0., 0., -1., 0.]])"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 32
},
{
"metadata": {
"ExecuteTime": {
2026-03-25 15:07:28 +00:00
"end_time": "2026-03-25T12:53:29.646525680Z",
"start_time": "2026-03-25T12:53:29.347929714Z"
}
},
"cell_type": "code",
"source": [
"# On the transposed image the edges are horizontal, so this kernel detects nothing.\n",
"corr2d(X.t(), K)"
],
"id": "a8278c3837fa9a1c",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0.]])"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 33
},
{
"metadata": {
"ExecuteTime": {
2026-03-25 15:07:28 +00:00
"end_time": "2026-03-25T12:53:30.137467722Z",
"start_time": "2026-03-25T12:53:29.924865950Z"
}
},
"cell_type": "code",
"source": [
"# Learnable 1x2 kernel: 1 input channel, 1 output channel, no bias.\n",
"conv2d = nn.Conv2d(1, 1, kernel_size=(1, 2), bias=False)"
],
"id": "ec61cdb61a8cabff",
"outputs": [],
"execution_count": 34
},
{
"metadata": {
"ExecuteTime": {
2026-03-25 15:07:28 +00:00
"end_time": "2026-03-25T12:53:30.333573187Z",
"start_time": "2026-03-25T12:53:30.191438502Z"
}
},
"cell_type": "code",
"source": [
"# nn.Conv2d works on 4-D (batch, channel, height, width) tensors.\n",
"X = X.reshape(1, 1, 6, 8)\n",
"Y = Y.reshape(1, 1, 6, 7)\n",
"lr = 3e-2  # learning rate"
],
"id": "d2fc19d84c79a10",
"outputs": [],
"execution_count": 35
},
{
"metadata": {
"ExecuteTime": {
2026-03-25 15:07:28 +00:00
"end_time": "2026-03-25T12:53:31.737006650Z",
"start_time": "2026-03-25T12:53:30.344801829Z"
}
},
"cell_type": "code",
"source": [
"for i in range(100):\n",
"    Y_hat = conv2d(X)\n",
"    # Squared error summed over all output positions.\n",
"    loss = ((Y_hat - Y) ** 2).sum()\n",
"    conv2d.zero_grad()\n",
"    loss.backward()\n",
"    # Gradient-descent step on the kernel. Updating under no_grad (rather than\n",
"    # writing through .data) keeps autograd's bookkeeping intact.\n",
"    with torch.no_grad():\n",
"        conv2d.weight -= lr * conv2d.weight.grad\n",
"    if (i + 1) % 20 == 0:\n",
"        print(f'epoch {i+1}, loss {loss:.3f}')"
],
"id": "51fbb2e6398a9bd5",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2026-03-25 15:07:28 +00:00
"epoch 20, loss 0.003\n",
"epoch 40, loss 0.000\n",
"epoch 60, loss 0.000\n",
"epoch 80, loss 0.000\n",
"epoch 100, loss 0.000\n"
]
}
],
"execution_count": 36
},
{
"metadata": {
"ExecuteTime": {
2026-03-25 15:07:28 +00:00
"end_time": "2026-03-25T12:53:31.954198352Z",
"start_time": "2026-03-25T12:53:31.789083268Z"
}
},
"cell_type": "code",
"source": [
"# The learned kernel, compared against the edge kernel [[1, -1]] used to make Y.\n",
"conv2d.weight.data.reshape((1, 2))\n"
],
"id": "bf53a423f429dfe4",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 1.0000, -1.0000]])"
]
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 37
},
{
"metadata": {
"ExecuteTime": {
2026-03-25 15:07:28 +00:00
"end_time": "2026-03-25T12:53:32.286815333Z",
"start_time": "2026-03-25T12:53:32.015795016Z"
}
},
"cell_type": "code",
"source": [
"def comp_conv2d(conv2d, X):\n",
"    \"\"\"Apply a 2-D layer to a single-channel 2-D input and return a 2-D output.\n",
"\n",
"    Adds leading batch and channel dimensions of size 1 before calling conv2d and\n",
"    strips them from the result. The layer's weights are used as they are.\n",
"    \"\"\"\n",
"    batched = X[None, None]  # (1, 1, h, w): batch size 1, one channel\n",
"    out = conv2d(batched)\n",
"    # Drop the batch and channel dimensions.\n",
"    return out.reshape(out.shape[2:])\n",
"\n",
"\n",
"# Padding of 1 on every side adds 2 rows and 2 columns in total, so a 3x3 kernel\n",
"# keeps the spatial size unchanged.\n",
"conv2d = nn.Conv2d(1, 1, kernel_size=3, padding=1)"
],
"id": "77b61d8c9a2363cc",
"outputs": [],
"execution_count": 38
},
{
"metadata": {
"ExecuteTime": {
2026-03-25 15:07:28 +00:00
"end_time": "2026-03-25T12:53:32.508857009Z",
"start_time": "2026-03-25T12:53:32.348271053Z"
}
},
"cell_type": "code",
"source": [
"X = torch.rand(size=(8, 8))\n",
"# Output size per dimension: 8 + 2*1 - 3 + 1 = 8.\n",
"comp_conv2d(conv2d, X).shape"
],
"id": "beda6ffa67ec2677",
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([8, 8])"
]
},
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 39
},
{
"metadata": {
"ExecuteTime": {
2026-03-25 15:07:28 +00:00
"end_time": "2026-03-25T12:53:32.590919691Z",
"start_time": "2026-03-25T12:53:32.513906871Z"
}
},
"cell_type": "code",
"source": [
"# Height: 8 + 2*2 - 5 + 1 = 8; width: 8 + 2*1 - 3 + 1 = 8.\n",
"conv2d = nn.Conv2d(1, 1, kernel_size=(5, 3), padding=(2, 1))\n",
"comp_conv2d(conv2d, X).shape"
],
"id": "8c51095daea1432d",
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([8, 8])"
]
},
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 40
},
{
"metadata": {
"ExecuteTime": {
2026-03-25 15:07:28 +00:00
"end_time": "2026-03-25T12:53:32.717183431Z",
"start_time": "2026-03-25T12:53:32.611335875Z"
}
},
"cell_type": "code",
"source": [
"# Stride 2 halves each dimension: floor((8 + 2*1 - 3) / 2) + 1 = 4.\n",
"conv2d = nn.Conv2d(1, 1, kernel_size=3, padding=1, stride=2)\n",
"comp_conv2d(conv2d, X).shape"
],
"id": "581bf1b15162cbf6",
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([4, 4])"
]
},
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 41
},
{
"metadata": {
"ExecuteTime": {
2026-03-25 15:07:28 +00:00
"end_time": "2026-03-25T12:53:32.787967768Z",
"start_time": "2026-03-25T12:53:32.720890025Z"
}
},
"cell_type": "code",
"source": [
"# Height: floor((8 + 0 - 3) / 3) + 1 = 2; width: floor((8 + 2*1 - 5) / 4) + 1 = 2.\n",
"conv2d = nn.Conv2d(1, 1, kernel_size=(3, 5), padding=(0, 1), stride=(3, 4))\n",
"comp_conv2d(conv2d, X).shape"
],
"id": "6f7a2411247baff0",
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([2, 2])"
]
},
"execution_count": 42,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 42
},
{
"metadata": {
"ExecuteTime": {
2026-03-25 15:07:28 +00:00
"end_time": "2026-03-25T12:53:32.925224132Z",
"start_time": "2026-03-25T12:53:32.820587683Z"
}
},
"cell_type": "code",
"source": [
"def corr2d_multi_in(X, K):\n",
"    \"\"\"Multi-input-channel cross-correlation: correlate each channel of X with the\n",
"    matching channel of K, then add up the per-channel results.\"\"\"\n",
"    per_channel = (corr2d(x_c, k_c) for x_c, k_c in zip(X, K))\n",
"    return sum(per_channel)\n",
"\n",
"\n",
"X = torch.tensor([[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]],\n",
"                  [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]])\n",
"K = torch.tensor([[[0.0, 1.0], [2.0, 3.0]], [[1.0, 2.0], [3.0, 4.0]]])\n",
"corr2d_multi_in(X, K)"
],
"id": "7ac0f17f97b2daa8",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 56., 72.],\n",
" [104., 120.]])"
]
},
2026-03-25 15:07:28 +00:00
"execution_count": 43,
"metadata": {},
"output_type": "execute_result"
}
],
2026-03-25 15:07:28 +00:00
"execution_count": 43
},
{
"metadata": {
"ExecuteTime": {
2026-03-25 15:07:28 +00:00
"end_time": "2026-03-25T12:53:33.185821813Z",
"start_time": "2026-03-25T12:53:32.981013074Z"
}
},
"cell_type": "code",
"source": [
"def corr2d_multi_in_out(X, K) -> torch.Tensor:\n",
"    \"\"\"Multi-input, multi-output cross-correlation.\n",
"\n",
"    Each K[o] is a multi-input-channel kernel; corr2d_multi_in(X, K[o]) becomes\n",
"    output channel o, and the results are stacked along dimension 0.\n",
"    \"\"\"\n",
"    outputs = [corr2d_multi_in(X, kernel) for kernel in K]\n",
"    return torch.stack(outputs, dim=0)\n"
],
"id": "d409110d0d6b4b49",
"outputs": [],
2026-03-25 15:07:28 +00:00
"execution_count": 44
},
{
"metadata": {
"ExecuteTime": {
2026-03-25 15:07:28 +00:00
"end_time": "2026-03-25T12:53:33.377776100Z",
"start_time": "2026-03-25T12:53:33.230963272Z"
}
},
"cell_type": "code",
"source": [
"# Rebuild the 2-channel kernel here so re-running this cell does not keep stacking.\n",
"K = torch.tensor([[[0.0, 1.0], [2.0, 3.0]], [[1.0, 2.0], [3.0, 4.0]]])\n",
"# Three output channels: K, K + 1 and K + 2 -> shape (3, 2, 2, 2).\n",
"K = torch.stack((K, K + 1, K + 2), 0)\n",
"K.shape"
],
"id": "4114cd871a627075",
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([3, 2, 2, 2])"
]
},
2026-03-25 15:07:28 +00:00
"execution_count": 45,
"metadata": {},
"output_type": "execute_result"
}
],
2026-03-25 15:07:28 +00:00
"execution_count": 45
},
{
"metadata": {
"ExecuteTime": {
2026-03-25 15:07:28 +00:00
"end_time": "2026-03-25T12:53:33.404945478Z",
"start_time": "2026-03-25T12:53:33.381591903Z"
}
},
"cell_type": "code",
"source": [
"# One output channel per kernel in K.\n",
"corr2d_multi_in_out(X, K)"
],
"id": "ce52f41dc9585f8c",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[ 56., 72.],\n",
" [104., 120.]],\n",
"\n",
" [[ 76., 100.],\n",
" [148., 172.]],\n",
"\n",
" [[ 96., 128.],\n",
" [192., 224.]]])"
]
},
2026-03-25 15:07:28 +00:00
"execution_count": 46,
"metadata": {},
"output_type": "execute_result"
}
],
2026-03-25 15:07:28 +00:00
"execution_count": 46
},
{
"metadata": {
"ExecuteTime": {
2026-03-25 15:07:28 +00:00
"end_time": "2026-03-25T12:53:33.482281628Z",
"start_time": "2026-03-25T12:53:33.427984889Z"
}
},
"cell_type": "code",
"source": [
"def corr2d_multi_in_out_1x1(X, K):\n",
"    \"\"\"1x1 multi-channel cross-correlation computed as a single matrix product.\n",
"\n",
"    X has shape (c_i, h, w) and K has shape (c_o, c_i, 1, 1); returns (c_o, h, w).\n",
"    Each output pixel is a linear combination of the input channels at that pixel.\n",
"    \"\"\"\n",
"    c_i, h, w = X.shape\n",
"    c_o = K.shape[0]\n",
"    X = X.reshape((c_i, h * w))\n",
"    K = K.reshape((c_o, c_i))\n",
"    Y = torch.matmul(K, X)  # (c_o, c_i) @ (c_i, h*w) -> (c_o, h*w)\n",
"    return Y.reshape((c_o, h, w))"
],
"id": "362d8c692b3c1d75",
"outputs": [],
2026-03-25 15:07:28 +00:00
"execution_count": 47
},
{
"metadata": {
"ExecuteTime": {
2026-03-25 15:07:28 +00:00
"end_time": "2026-03-25T12:53:33.539310405Z",
"start_time": "2026-03-25T12:53:33.485273899Z"
}
},
"cell_type": "code",
"source": [
"# Random 3-channel 3x3 input and a 1x1 kernel mapping 3 input channels to 2 outputs.\n",
"X = torch.normal(0, 1, (3, 3, 3))\n",
"K = torch.normal(0, 1, (2, 3, 1, 1))"
],
"id": "28e761f677df8b16",
"outputs": [],
2026-03-25 15:07:28 +00:00
"execution_count": 48
},
{
"metadata": {
"ExecuteTime": {
2026-03-25 15:07:28 +00:00
"end_time": "2026-03-25T12:53:33.644382591Z",
"start_time": "2026-03-25T12:53:33.542526101Z"
}
},
"cell_type": "code",
"source": "Y1 = corr2d_multi_in_out_1x1(X, K)  # 1x1 convolution via a single matmul",
"id": "8eb276fed751a6b9",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([3, 9])\n",
"torch.Size([2, 3])\n"
]
}
],
2026-03-25 15:07:28 +00:00
"execution_count": 49
},
{
"metadata": {
"ExecuteTime": {
2026-03-25 15:07:28 +00:00
"end_time": "2026-03-25T12:53:33.715840072Z",
"start_time": "2026-03-25T12:53:33.685724533Z"
}
},
"cell_type": "code",
"source": [
"# The matrix-product version must agree with the generic multi-channel correlation.\n",
"Y2 = corr2d_multi_in_out(X, K)\n",
"assert float((Y1 - Y2).abs().sum()) < 1e-6"
],
"id": "be28e27d30f36e2c",
"outputs": [],
2026-03-25 15:07:28 +00:00
"execution_count": 50
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T12:53:33.833762326Z",
"start_time": "2026-03-25T12:53:33.752003546Z"
}
},
"cell_type": "code",
"source": [
"def pool2d(X, pool_size, mode='max'):\n",
"    \"\"\"2-D max or average pooling of a 2-D input X (no padding, stride 1).\n",
"\n",
"    pool_size is (p_h, p_w) and mode is 'max' or 'avg'. The output has shape\n",
"    (X.shape[0] - p_h + 1, X.shape[1] - p_w + 1).\n",
"\n",
"    Raises ValueError for any other mode.\n",
"    \"\"\"\n",
"    reducers = {'max': torch.max, 'avg': torch.mean}\n",
"    if mode not in reducers:\n",
"        raise ValueError(f\"mode must be 'max' or 'avg', got {mode!r}\")\n",
"    reduce = reducers[mode]  # chosen once, outside the loops\n",
"    p_h, p_w = pool_size\n",
"    Y = torch.zeros((X.shape[0] - p_h + 1, X.shape[1] - p_w + 1))\n",
"    for i in range(Y.shape[0]):\n",
"        for j in range(Y.shape[1]):\n",
"            Y[i, j] = reduce(X[i:i + p_h, j:j + p_w])\n",
"    return Y"
],
"id": "3c3f71349a2e54c0",
"outputs": [],
"execution_count": 51
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T12:53:33.941263303Z",
"start_time": "2026-03-25T12:53:33.837650145Z"
}
},
"cell_type": "code",
"source": [
"X = torch.tensor([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]])\n",
"pool2d(X, (2, 2))  # default mode is 'max'"
],
"id": "a67207c861cf0cfd",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[4., 5.],\n",
" [7., 8.]])"
]
},
"execution_count": 52,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 52
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T12:53:34.291316821Z",
"start_time": "2026-03-25T12:53:34.048006688Z"
}
},
"cell_type": "code",
"source": "pool2d(X, (2, 2), 'avg')  # average pooling over each 2x2 window",
"id": "e387b48df3831b85",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[2., 3.],\n",
" [5., 6.]])"
]
},
"execution_count": 53,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 53
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T12:53:34.937892470Z",
"start_time": "2026-03-25T12:53:34.524932637Z"
}
},
"cell_type": "code",
"source": [
"# 4-D input (batch, channel, height, width) for the built-in pooling layers.\n",
"X = torch.arange(16, dtype=torch.float32).reshape((1, 1, 4, 4))\n",
"X"
],
"id": "41b618b3a48522b4",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[[ 0., 1., 2., 3.],\n",
" [ 4., 5., 6., 7.],\n",
" [ 8., 9., 10., 11.],\n",
" [12., 13., 14., 15.]]]])"
]
},
"execution_count": 54,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 54
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T12:53:35.364929618Z",
"start_time": "2026-03-25T12:53:35.250760743Z"
}
},
"cell_type": "code",
"source": [
"# NOTE: this rebinds the name pool2d (the hand-written function above) to a module.\n",
"# Stride defaults to the window size, so a 3x3 window fits once into the 4x4 input.\n",
"pool2d = nn.MaxPool2d(3)\n",
"pool2d(X)"
],
"id": "c77484a8d1267259",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[[10.]]]])"
]
},
"execution_count": 55,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 55
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T12:53:35.405212051Z",
"start_time": "2026-03-25T12:53:35.381419347Z"
}
},
"cell_type": "code",
"source": [
"# Explicit padding and stride: floor((4 + 2*1 - 3) / 2) + 1 = 2 per dimension.\n",
"pool2d = nn.MaxPool2d(3, padding=1, stride=2)\n",
"pool2d(X)"
],
"id": "847a2bacfb6f2bd7",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[[ 5., 7.],\n",
" [13., 15.]]]])"
]
},
"execution_count": 56,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 56
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T12:53:35.511354814Z",
"start_time": "2026-03-25T12:53:35.428268478Z"
}
},
"cell_type": "code",
"source": [
"# Rectangular window with per-dimension stride and padding.\n",
"pool2d = nn.MaxPool2d((2, 3), stride=(2, 3), padding=(0, 1))\n",
"pool2d(X)"
],
"id": "5efad1e0b616fff7",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[[ 5., 7.],\n",
" [13., 15.]]]])"
]
},
"execution_count": 57,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 57
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T12:53:35.571965177Z",
"start_time": "2026-03-25T12:53:35.523510246Z"
}
},
"cell_type": "code",
"source": [
"# Rebuild the single-channel input here so re-running this cell does not keep\n",
"# concatenating, then stack X and X + 1 along the channel dimension.\n",
"X = torch.arange(16, dtype=torch.float32).reshape((1, 1, 4, 4))\n",
"X = torch.cat((X, X + 1), 1)\n",
"X"
],
"id": "386d4b3eb8069328",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[[ 0., 1., 2., 3.],\n",
" [ 4., 5., 6., 7.],\n",
" [ 8., 9., 10., 11.],\n",
" [12., 13., 14., 15.]],\n",
"\n",
" [[ 1., 2., 3., 4.],\n",
" [ 5., 6., 7., 8.],\n",
" [ 9., 10., 11., 12.],\n",
" [13., 14., 15., 16.]]]])"
]
},
"execution_count": 58,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 58
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T12:53:35.636903197Z",
"start_time": "2026-03-25T12:53:35.575376191Z"
}
},
"cell_type": "code",
"source": [
"# Pooling works on each channel separately, so the output keeps 2 channels.\n",
"pool2d = nn.MaxPool2d(3, padding=1, stride=2)\n",
"pool2d(X)"
],
"id": "ba5f57a8ca2a3b06",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[[ 5., 7.],\n",
" [13., 15.]],\n",
"\n",
" [[ 6., 8.],\n",
" [14., 16.]]]])"
]
},
"execution_count": 59,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 59
},
2026-03-25 15:07:28 +00:00
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T12:53:35.808804883Z",
"start_time": "2026-03-25T12:53:35.667844728Z"
}
},
"cell_type": "code",
"source": [
"# LeNet-style network; shape comments assume a 1x1x28x28 input.\n",
"net = nn.Sequential(\n",
"    nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Sigmoid(),  # 1*1*28*28 -> 1*6*28*28\n",
"    nn.AvgPool2d(kernel_size=2, stride=2),                    # 1*6*28*28 -> 1*6*14*14\n",
"    nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(),            # 1*6*14*14 -> 1*16*10*10\n",
"    nn.AvgPool2d(kernel_size=2, stride=2),                    # 1*16*10*10 -> 1*16*5*5\n",
"    nn.Flatten(),                                             # 1*16*5*5 -> 1*400\n",
"    nn.Linear(16 * 5 * 5, 120), nn.Sigmoid(),\n",
"    nn.Linear(120, 84), nn.Sigmoid(),\n",
"    nn.Linear(84, 10),\n",
")\n",
"# Push a random image through layer by layer and report each output shape.\n",
"X = torch.rand(size=(1, 1, 28, 28), dtype=torch.float32)\n",
"for layer in net:\n",
"    X = layer(X)\n",
"    print(type(layer).__name__, 'output shape: \\t', X.shape)"
],
"id": "1eabc29f9c838842",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Conv2d output shape: \t torch.Size([1, 6, 28, 28])\n",
"Sigmoid output shape: \t torch.Size([1, 6, 28, 28])\n",
"AvgPool2d output shape: \t torch.Size([1, 6, 14, 14])\n",
"Conv2d output shape: \t torch.Size([1, 16, 10, 10])\n",
"Sigmoid output shape: \t torch.Size([1, 16, 10, 10])\n",
"AvgPool2d output shape: \t torch.Size([1, 16, 5, 5])\n",
"Flatten output shape: \t torch.Size([1, 400])\n",
"Linear output shape: \t torch.Size([1, 120])\n",
"Sigmoid output shape: \t torch.Size([1, 120])\n",
"Linear output shape: \t torch.Size([1, 84])\n",
"Sigmoid output shape: \t torch.Size([1, 84])\n",
"Linear output shape: \t torch.Size([1, 10])\n"
]
}
],
"execution_count": 60
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T12:54:36.172500455Z",
"start_time": "2026-03-25T12:54:33.846313406Z"
}
},
"cell_type": "code",
"source": [
"import d2l.torch as d2l\n",
"batch_size = 256\n",
"train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)"
],
"id": "e372f75817ad4a0f",
"outputs": [],
"execution_count": 62
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T12:54:36.218171897Z",
"start_time": "2026-03-25T12:54:36.201338102Z"
}
},
"cell_type": "code",
"source": [
"lr, num_epochs = 0.9, 10\n",
"#d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())"
],
"id": "9aaeb948f3353955",
"outputs": [],
"execution_count": 63
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T12:54:36.375790030Z",
"start_time": "2026-03-25T12:54:36.223406778Z"
}
},
"cell_type": "code",
"source": [
"class Inception(nn.Module):\n",
" def __init__(self,in_channels,c1,c2,c3,c4,**kwargs):\n",
" super(Inception,self).__init__(**kwargs)\n",
" self.p1_1 = nn.Conv2d(in_channels,c1,kernel_size=1)\n",
" self.p2_1 = nn.Conv2d(in_channels,c2[0],kernel_size=1)\n",
" self.p2_2 = nn.Conv2d(c2[0],c2[1],kernel_size=3,padding=1)\n",
" self.p3_1 = nn.Conv2d(in_channels,c3[0],kernel_size=1)\n",
" self.p3_2 = nn.Conv2d(c3[0],c3[1],kernel_size=5,padding=2)\n",
" self.p4_1 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)\n",
" self.p4_2 = nn.Conv2d(in_channels, c4, kernel_size=1)\n",
" def forward(self,x):\n",
" p1 = F.relu(self.p1_1(x))\n",
" p2 = F.relu(self.p2_2(F.relu(self.p2_1(x))))\n",
" p3 = F.relu(self.p3_2(F.relu(self.p3_1(x))))\n",
" p4 = F.relu(self.p4_2(self.p4_1(x)))\n",
" return torch.cat((p1,p2,p3,p4),dim=1)"
],
"id": "6d3bb3f70f297dba",
"outputs": [],
"execution_count": 64
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T12:54:37.582867237Z",
"start_time": "2026-03-25T12:54:36.386034018Z"
}
},
"cell_type": "code",
"source": [
"b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),\n",
" nn.ReLU(),\n",
" nn.MaxPool2d(kernel_size=3, stride=2, padding=1))\n",
"b2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1),\n",
" nn.ReLU(),\n",
" nn.Conv2d(64, 192, kernel_size=3, padding=1),\n",
" nn.ReLU(),\n",
" nn.MaxPool2d(kernel_size=3, stride=2, padding=1))\n",
"b3 = nn.Sequential(Inception(192, 64, (96, 128), (16, 32), 32),\n",
" Inception(256, 128, (128, 192), (32, 96), 64),\n",
" nn.MaxPool2d(kernel_size=3, stride=2, padding=1))\n",
"b4 = nn.Sequential(Inception(480, 192, (96, 208), (16, 48), 64),\n",
" Inception(512, 160, (112, 224), (24, 64), 64),\n",
" Inception(512, 128, (128, 256), (24, 64), 64),\n",
" Inception(512, 112, (144, 288), (32, 64), 64),\n",
" Inception(528, 256, (160, 320), (32, 128), 128),\n",
" nn.MaxPool2d(kernel_size=3, stride=2, padding=1))\n",
"b5 = nn.Sequential(Inception(832, 256, (160, 320), (32, 128), 128),\n",
" Inception(832, 384, (192, 384), (48, 128), 128),\n",
" nn.AdaptiveAvgPool2d((1,1)),\n",
" nn.Flatten())\n",
"net = nn.Sequential(b1, b2, b3, b4, b5, nn.Linear(1024, 10))\n",
"X = torch.rand(size=(1, 1, 96, 96))\n",
"for layer in net:\n",
" X = layer(X)\n",
" print(layer.__class__.__name__,'output shape:\\t', X.shape)"
],
"id": "6ef7022bcb288d65",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sequential output shape:\t torch.Size([1, 64, 24, 24])\n",
"Sequential output shape:\t torch.Size([1, 192, 12, 12])\n",
"Sequential output shape:\t torch.Size([1, 480, 6, 6])\n",
"Sequential output shape:\t torch.Size([1, 832, 3, 3])\n",
"Sequential output shape:\t torch.Size([1, 1024])\n",
"Linear output shape:\t torch.Size([1, 10])\n"
]
}
],
"execution_count": 65
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T12:54:39.007400793Z",
"start_time": "2026-03-25T12:54:37.697851296Z"
}
},
"cell_type": "code",
"source": [
"import torchinfo\n",
"torchinfo.summary(net,(1,1,96,96))"
],
"id": "acc019ce7afa4470",
"outputs": [
{
"data": {
"text/plain": [
"==========================================================================================\n",
"Layer (type:depth-idx) Output Shape Param #\n",
"==========================================================================================\n",
"Sequential [1, 10] --\n",
"├─Sequential: 1-1 [1, 64, 24, 24] --\n",
"│ └─Conv2d: 2-1 [1, 64, 48, 48] 3,200\n",
"│ └─ReLU: 2-2 [1, 64, 48, 48] --\n",
"│ └─MaxPool2d: 2-3 [1, 64, 24, 24] --\n",
"├─Sequential: 1-2 [1, 192, 12, 12] --\n",
"│ └─Conv2d: 2-4 [1, 64, 24, 24] 4,160\n",
"│ └─ReLU: 2-5 [1, 64, 24, 24] --\n",
"│ └─Conv2d: 2-6 [1, 192, 24, 24] 110,784\n",
"│ └─ReLU: 2-7 [1, 192, 24, 24] --\n",
"│ └─MaxPool2d: 2-8 [1, 192, 12, 12] --\n",
"├─Sequential: 1-3 [1, 480, 6, 6] --\n",
"│ └─Inception: 2-9 [1, 256, 12, 12] --\n",
"│ │ └─Conv2d: 3-1 [1, 64, 12, 12] 12,352\n",
"│ │ └─Conv2d: 3-2 [1, 96, 12, 12] 18,528\n",
"│ │ └─Conv2d: 3-3 [1, 128, 12, 12] 110,720\n",
"│ │ └─Conv2d: 3-4 [1, 16, 12, 12] 3,088\n",
"│ │ └─Conv2d: 3-5 [1, 32, 12, 12] 12,832\n",
"│ │ └─MaxPool2d: 3-6 [1, 192, 12, 12] --\n",
"│ │ └─Conv2d: 3-7 [1, 32, 12, 12] 6,176\n",
"│ └─Inception: 2-10 [1, 480, 12, 12] --\n",
"│ │ └─Conv2d: 3-8 [1, 128, 12, 12] 32,896\n",
"│ │ └─Conv2d: 3-9 [1, 128, 12, 12] 32,896\n",
"│ │ └─Conv2d: 3-10 [1, 192, 12, 12] 221,376\n",
"│ │ └─Conv2d: 3-11 [1, 32, 12, 12] 8,224\n",
"│ │ └─Conv2d: 3-12 [1, 96, 12, 12] 76,896\n",
"│ │ └─MaxPool2d: 3-13 [1, 256, 12, 12] --\n",
"│ │ └─Conv2d: 3-14 [1, 64, 12, 12] 16,448\n",
"│ └─MaxPool2d: 2-11 [1, 480, 6, 6] --\n",
"├─Sequential: 1-4 [1, 832, 3, 3] --\n",
"│ └─Inception: 2-12 [1, 512, 6, 6] --\n",
"│ │ └─Conv2d: 3-15 [1, 192, 6, 6] 92,352\n",
"│ │ └─Conv2d: 3-16 [1, 96, 6, 6] 46,176\n",
"│ │ └─Conv2d: 3-17 [1, 208, 6, 6] 179,920\n",
"│ │ └─Conv2d: 3-18 [1, 16, 6, 6] 7,696\n",
"│ │ └─Conv2d: 3-19 [1, 48, 6, 6] 19,248\n",
"│ │ └─MaxPool2d: 3-20 [1, 480, 6, 6] --\n",
"│ │ └─Conv2d: 3-21 [1, 64, 6, 6] 30,784\n",
"│ └─Inception: 2-13 [1, 512, 6, 6] --\n",
"│ │ └─Conv2d: 3-22 [1, 160, 6, 6] 82,080\n",
"│ │ └─Conv2d: 3-23 [1, 112, 6, 6] 57,456\n",
"│ │ └─Conv2d: 3-24 [1, 224, 6, 6] 226,016\n",
"│ │ └─Conv2d: 3-25 [1, 24, 6, 6] 12,312\n",
"│ │ └─Conv2d: 3-26 [1, 64, 6, 6] 38,464\n",
"│ │ └─MaxPool2d: 3-27 [1, 512, 6, 6] --\n",
"│ │ └─Conv2d: 3-28 [1, 64, 6, 6] 32,832\n",
"│ └─Inception: 2-14 [1, 512, 6, 6] --\n",
"│ │ └─Conv2d: 3-29 [1, 128, 6, 6] 65,664\n",
"│ │ └─Conv2d: 3-30 [1, 128, 6, 6] 65,664\n",
"│ │ └─Conv2d: 3-31 [1, 256, 6, 6] 295,168\n",
"│ │ └─Conv2d: 3-32 [1, 24, 6, 6] 12,312\n",
"│ │ └─Conv2d: 3-33 [1, 64, 6, 6] 38,464\n",
"│ │ └─MaxPool2d: 3-34 [1, 512, 6, 6] --\n",
"│ │ └─Conv2d: 3-35 [1, 64, 6, 6] 32,832\n",
"│ └─Inception: 2-15 [1, 528, 6, 6] --\n",
"│ │ └─Conv2d: 3-36 [1, 112, 6, 6] 57,456\n",
"│ │ └─Conv2d: 3-37 [1, 144, 6, 6] 73,872\n",
"│ │ └─Conv2d: 3-38 [1, 288, 6, 6] 373,536\n",
"│ │ └─Conv2d: 3-39 [1, 32, 6, 6] 16,416\n",
"│ │ └─Conv2d: 3-40 [1, 64, 6, 6] 51,264\n",
"│ │ └─MaxPool2d: 3-41 [1, 512, 6, 6] --\n",
"│ │ └─Conv2d: 3-42 [1, 64, 6, 6] 32,832\n",
"│ └─Inception: 2-16 [1, 832, 6, 6] --\n",
"│ │ └─Conv2d: 3-43 [1, 256, 6, 6] 135,424\n",
"│ │ └─Conv2d: 3-44 [1, 160, 6, 6] 84,640\n",
"│ │ └─Conv2d: 3-45 [1, 320, 6, 6] 461,120\n",
"│ │ └─Conv2d: 3-46 [1, 32, 6, 6] 16,928\n",
"│ │ └─Conv2d: 3-47 [1, 128, 6, 6] 102,528\n",
"│ │ └─MaxPool2d: 3-48 [1, 528, 6, 6] --\n",
"│ │ └─Conv2d: 3-49 [1, 128, 6, 6] 67,712\n",
"│ └─MaxPool2d: 2-17 [1, 832, 3, 3] --\n",
"├─Sequential: 1-5 [1, 1024] --\n",
"│ └─Inception: 2-18 [1, 832, 3, 3] --\n",
"│ │ └─Conv2d: 3-50 [1, 256, 3, 3] 213,248\n",
"│ │ └─Conv2d: 3-51 [1, 160, 3, 3] 133,280\n",
"│ │ └─Conv2d: 3-52 [1, 320, 3, 3] 461,120\n",
"│ │ └─Conv2d: 3-53 [1, 32, 3, 3] 26,656\n",
"│ │ └─Conv2d: 3-54 [1, 128, 3, 3] 102,528\n",
"│ │ └─MaxPool2d: 3-55 [1, 832, 3, 3] --\n",
"│ │ └─Conv2d: 3-56 [1, 128, 3, 3] 106,624\n",
"│ └─Inception: 2-19 [1, 1024, 3, 3] --\n",
"│ │ └─Conv2d: 3-57 [1, 384, 3, 3] 319,872\n",
"│ │ └─Conv2d: 3-58 [1, 192, 3, 3] 159,936\n",
"│ │ └─Conv2d: 3-59 [1, 384, 3, 3] 663,936\n",
"│ │ └─Conv2d: 3-60 [1, 48, 3, 3] 39,984\n",
"│ │ └─Conv2d: 3-61 [1, 128, 3, 3] 153,728\n",
"│ │ └─MaxPool2d: 3-62 [1, 832, 3, 3] --\n",
"│ │ └─Conv2d: 3-63 [1, 128, 3, 3] 106,624\n",
"│ └─AdaptiveAvgPool2d: 2-20 [1, 1024, 1, 1] --\n",
"│ └─Flatten: 2-21 [1, 1024] --\n",
"├─Linear: 1-6 [1, 10] 10,250\n",
"==========================================================================================\n",
"Total params: 5,977,530\n",
"Trainable params: 5,977,530\n",
"Non-trainable params: 0\n",
"Total mult-adds (Units.MEGABYTES): 276.66\n",
"==========================================================================================\n",
"Input size (MB): 0.04\n",
"Forward/backward pass size (MB): 4.74\n",
"Params size (MB): 23.91\n",
"Estimated Total Size (MB): 28.69\n",
"=========================================================================================="
]
},
"execution_count": 66,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 66
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T12:54:40.559617829Z",
"start_time": "2026-03-25T12:54:39.231004686Z"
}
},
"cell_type": "code",
"source": [
"lr, num_epochs, batch_size = 0.1, 10, 128\n",
"train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)\n",
"#d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())"
],
"id": "3760a5e5813405f7",
"outputs": [],
"execution_count": 67
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T12:54:41.392761415Z",
"start_time": "2026-03-25T12:54:40.984224963Z"
}
},
"cell_type": "code",
"source": [
"class Residual(nn.Module):\n",
" def __init__(self,input_channels,num_channels,use_1x1conv=False,strides=1):\n",
" super().__init__()\n",
" self.conv1 = nn.Conv2d(input_channels,num_channels,kernel_size=3,padding=1,stride=strides)\n",
" self.conv2 = nn.Conv2d(num_channels,num_channels,kernel_size=3,padding=1)\n",
" if use_1x1conv:\n",
" self.conv3 = nn.Conv2d(input_channels,num_channels,kernel_size=1,stride=strides)\n",
" else:\n",
" self.conv3= None\n",
" self.bn1=nn.BatchNorm2d(num_channels)\n",
" self.bn2=nn.BatchNorm2d(num_channels)\n",
" def forward(self,X):\n",
" Y=F.relu(self.bn1(self.conv1(X)))\n",
" Y=self.bn2(self.conv2(Y))\n",
" if self.conv3:\n",
" X = self.conv3(X)\n",
" Y+=X\n",
" return F.relu(Y)\n"
],
"id": "9300979845ba6916",
"outputs": [],
"execution_count": 68
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T12:54:41.501972047Z",
"start_time": "2026-03-25T12:54:41.455568607Z"
}
},
"cell_type": "code",
"source": [
"blk = Residual(3,3)\n",
"X = torch.rand(4, 3, 6, 6)"
],
"id": "1248323517ff3228",
"outputs": [],
"execution_count": 69
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T12:54:41.680548549Z",
"start_time": "2026-03-25T12:54:41.504484009Z"
}
},
"cell_type": "code",
"source": [
"blk = Residual(3,6, use_1x1conv=True, strides=2)\n",
"blk(X).shape"
],
"id": "82cdbd71a157b51c",
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([4, 6, 3, 3])"
]
},
"execution_count": 70,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 70
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T12:54:41.713889764Z",
"start_time": "2026-03-25T12:54:41.697804378Z"
}
},
"cell_type": "code",
"source": [
"b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),\n",
"nn.BatchNorm2d(64), nn.ReLU(),\n",
"nn.MaxPool2d(kernel_size=3, stride=2, padding=1))"
],
"id": "727da1d2d363ac62",
"outputs": [],
"execution_count": 71
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T12:54:41.769559409Z",
"start_time": "2026-03-25T12:54:41.715465374Z"
}
},
"cell_type": "code",
"source": [
"def resnet_block(input_channels, num_channels, num_residuals,\n",
" first_block=False):\n",
" blk = []\n",
" for i in range(num_residuals):\n",
" if i == 0 and not first_block:\n",
" blk.append(Residual(input_channels, num_channels,\n",
" use_1x1conv=True, strides=2))\n",
" else:\n",
" blk.append(Residual(num_channels, num_channels))\n",
" return blk"
],
"id": "124134971f8441c0",
"outputs": [],
"execution_count": 72
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T12:54:41.849526581Z",
"start_time": "2026-03-25T12:54:41.773164092Z"
}
},
"cell_type": "code",
"source": [
"b2 = nn.Sequential(*resnet_block(64, 64, 2, first_block=True))\n",
"b3 = nn.Sequential(*resnet_block(64, 128, 2))\n",
"b4 = nn.Sequential(*resnet_block(128, 256, 2))\n",
"b5 = nn.Sequential(*resnet_block(256, 512, 2))"
],
"id": "ca1f1c69fba3e913",
"outputs": [],
"execution_count": 73
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T12:54:41.945086842Z",
"start_time": "2026-03-25T12:54:41.850992491Z"
}
},
"cell_type": "code",
"source": [
"net = nn.Sequential(b1, b2, b3, b4, b5,\n",
"nn.AdaptiveAvgPool2d((1,1)),\n",
"nn.Flatten(), nn.Linear(512, 10))"
],
"id": "f21db27de5dbdec1",
"outputs": [],
"execution_count": 74
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T12:54:43.021161834Z",
"start_time": "2026-03-25T12:54:41.947258256Z"
}
},
"cell_type": "code",
"source": [
"X = torch.rand(size=(1, 1, 224, 224))\n",
"for layer in net:\n",
" X = layer(X)\n",
" print(layer.__class__.__name__,'output shape:\\t', X.shape)"
],
"id": "6f8851a2bfd18c4e",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sequential output shape:\t torch.Size([1, 64, 56, 56])\n",
"Sequential output shape:\t torch.Size([1, 64, 56, 56])\n",
"Sequential output shape:\t torch.Size([1, 128, 28, 28])\n",
"Sequential output shape:\t torch.Size([1, 256, 14, 14])\n",
"Sequential output shape:\t torch.Size([1, 512, 7, 7])\n",
"AdaptiveAvgPool2d output shape:\t torch.Size([1, 512, 1, 1])\n",
"Flatten output shape:\t torch.Size([1, 512])\n",
"Linear output shape:\t torch.Size([1, 10])\n"
]
}
],
"execution_count": 75
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T12:54:43.115617383Z",
"start_time": "2026-03-25T12:54:43.046016884Z"
}
},
"cell_type": "code",
"source": [
"lr, num_epochs, batch_size = 0.05, 10, 256\n",
"train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)\n",
"#d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())"
],
"id": "e095d74b29dffef6",
"outputs": [],
"execution_count": 76
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T12:54:43.222301703Z",
"start_time": "2026-03-25T12:54:43.119038826Z"
}
},
"cell_type": "code",
"source": [
"import torch\n",
"import d2l.torch as d2l\n",
"import numpy\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"print(torch.version.__version__)"
],
"id": "3fd6d22221f87bea",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2.10.0+cu128\n"
]
}
],
"execution_count": 77
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T12:54:43.326030676Z",
"start_time": "2026-03-25T12:54:43.296999050Z"
}
},
"cell_type": "code",
"source": [
"A=torch.Tensor([[1,2,0,0],[0,2,0,0],[0,0,2,1],[0,0,0,3]])\n",
"C=torch.Tensor([[1,0,0,0],[0,1,0,0],[0,0,-2,3],[0,0,0,-3]])"
],
"id": "254f5d3d659dbe0f",
"outputs": [],
"execution_count": 78
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T12:54:43.395079338Z",
"start_time": "2026-03-25T12:54:43.332112143Z"
}
},
"cell_type": "code",
"source": "B=torch.Tensor([[2,0,0,0],[-2,1,0,0],[0,0,-3,0],[0,0,0,-3]])",
"id": "a13d9c27c2fdbfad",
"outputs": [],
"execution_count": 79
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T12:54:43.526691390Z",
"start_time": "2026-03-25T12:54:43.405681993Z"
}
},
"cell_type": "code",
"source": "torch.mm(A,C)",
"id": "e513a37beaa85f8f",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 1., 2., 0., 0.],\n",
" [ 0., 2., 0., 0.],\n",
" [ 0., 0., -4., 3.],\n",
" [ 0., 0., 0., -9.]])"
]
},
"execution_count": 80,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 80
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T12:54:44.183397512Z",
"start_time": "2026-03-25T12:54:43.839736167Z"
}
},
"cell_type": "code",
"source": "torch.det(torch.mm(torch.mm(A,C),B))",
"id": "9a85eceac652875f",
"outputs": [
{
"data": {
"text/plain": [
"tensor(1296.)"
]
},
"execution_count": 81,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 81
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T12:54:44.470157150Z",
"start_time": "2026-03-25T12:54:44.297718949Z"
}
},
"cell_type": "code",
"source": "1296**5\n",
"id": "6dc27d79722da58f",
"outputs": [
{
"data": {
"text/plain": [
"3656158440062976"
]
},
"execution_count": 82,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 82
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T12:54:44.641680344Z",
"start_time": "2026-03-25T12:54:44.488342487Z"
}
},
"cell_type": "code",
"source": "torch.mm(C,B)",
"id": "ec5a170d775f4705",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 2., 0., 0., 0.],\n",
" [-2., 1., 0., 0.],\n",
" [ 0., 0., 6., -9.],\n",
" [ 0., 0., 0., 9.]])"
]
},
"execution_count": 83,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 83
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T12:54:45.007541962Z",
"start_time": "2026-03-25T12:54:44.731090643Z"
}
},
"cell_type": "code",
"source": [
"T = 1000 # 总共产生1000个点\n",
"time = torch.arange(1, T + 1, dtype=torch.float32)\n",
"x = torch.sin(0.01 * time) + torch.normal(0, 0.2, (T,))\n",
"d2l.plot(time, [x], 'time', 'x', xlim=[1, 1000], figsize=(6, 3))"
],
"id": "c3884b10464c6baa",
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 600x300 with 1 Axes>"
],
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"406.885938pt\" height=\"211.07625pt\" viewBox=\"0 0 406.885938 211.07625\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n <metadata>\n <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n <cc:Work>\n <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n <dc:date>2026-03-25T20:54:44.863294</dc:date>\n <dc:format>image/svg+xml</dc:format>\n <dc:creator>\n <cc:Agent>\n <dc:title>Matplotlib v3.7.2, https://matplotlib.org/</dc:title>\n </cc:Agent>\n </dc:creator>\n </cc:Work>\n </rdf:RDF>\n </metadata>\n <defs>\n <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n </defs>\n <g id=\"figure_1\">\n <g id=\"patch_1\">\n <path d=\"M 0 211.07625 \nL 406.885938 211.07625 \nL 406.885938 0 \nL 0 0 \nz\n\"/>\n </g>\n <g id=\"axes_1\">\n <g id=\"patch_2\">\n <path d=\"M 52.160938 173.52 \nL 386.960938 173.52 \nL 386.960938 7.2 \nL 52.160938 7.2 \nz\n\"/>\n </g>\n <g id=\"matplotlib.axis_1\">\n <g id=\"xtick_1\">\n <g id=\"line2d_1\">\n <path d=\"M 118.852829 173.52 \nL 118.852829 7.2 \n\" clip-path=\"url(#p5cf86932f7)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_2\">\n <defs>\n <path id=\"m49149f9012\" d=\"M 0 0 \nL 0 3.5 \n\" style=\"stroke: #ffffff; stroke-width: 0.8\"/>\n </defs>\n <g>\n <use xlink:href=\"#m49149f9012\" x=\"118.852829\" y=\"173.52\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_1\">\n <!-- 200 -->\n <g style=\"fill: #ffffff\" transform=\"translate(109.309079 188.118438) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-32\" d=\"M 1228 531 \nL 3431 531 
\nL 3431 0 \nL 469 0 \nL 469 531 \nQ 828 903 1448 1529 \nQ 2069 2156 2228 2338 \nQ 2531 2678 2651 2914 \nQ 2772 3150 2772 3378 \nQ 2772 3750 2511 3984 \nQ 2250 4219 1831 4219 \nQ 1534 4219 1204 4116 \nQ 875 4013 500 3803 \nL 500 4441 \nQ 881 4594 1212 4672 \nQ 1544 4750 1819 4750 \nQ 2544 4750 2975 4387 \nQ 3406 4025 3406 3419 \nQ 3406 3131 3298 2873 \nQ 3191 2616 2906 2266 \nQ 2828 2175 2409 1742 \nQ 1991 1309 1228 531 \nz\n\" transform=\"scale(0.015625)\"/>\n <path id=\"DejaVuSans-30\" d=\"M 2034 4250 \nQ 1547 4250 1301 3770 \nQ 1056 3291 1056 2328 \nQ 1056 1369 1301 889 \nQ 1547 409 2034 409 \nQ 2525 409 2770 889 \nQ 3016 1369 3016 2328 \nQ 3016 3291 2770 3770 \nQ 2525 4250 2034 4250 \nz\nM 2034 4750 \nQ 2819 4750 3233 4129 \nQ 3647 3509 3647 2328 \nQ 3647 1150 3233 529 \nQ 2819 -91 2034 -91 \nQ 1250 -91 836 529 \nQ 422 1150 422 2328 \nQ 422 3509 836 4129 \nQ 1250 4750 2034 4750 \nz\n\" transform=\"scale(0.015625)\"/>\n </defs>\n <use xlink:href=\"#DejaVuSans-32\"/>\n <use xlink:href=\"#DejaVuSans-30\" x=\"63.623047\"/>\n <use xlink:href=\"#DejaVuSans-30\" x=\"127.246094\"/>\n </g>\n </g>\n </g>\n <g id=\"xtick_2\">\n <g id=\"line2d_3\">\n <path d=\"M 185.879856 173.52 \nL 185.879856 7.2 \n\" clip-path=\"url(#p5cf86932f7)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_4\">\n <g>\n <use xlink:href=\"#m49149f9012\" x=\"185.879856\" y=\"173.52\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_2\">\n <!-- 400 -->\n <g style=\"fill: #ffffff\" transform=\"translate(176.336106 188.118438) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-34\" d=\"M 2419 4116 \nL 825 1625 \nL 2419 1625 \nL 2419 4116 \nz\nM 2253 4666 \nL 3047 4666 \nL 3047 1625 \nL 3713 1625 \nL 3713 1100 \nL 3047 1100 \nL 3047 0 \nL 2419 0 \nL 2419 11
},
"metadata": {},
"output_type": "display_data",
"jetTransient": {
"display_id": null
}
}
],
"execution_count": 84
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T12:54:45.287386804Z",
"start_time": "2026-03-25T12:54:45.068746693Z"
}
},
"cell_type": "code",
"source": [
"tau = 4\n",
"features = torch.zeros((T - tau, tau))\n",
"for i in range(tau):\n",
" features[:, i] = x[i: T - tau + i]\n",
"labels = x[tau:].reshape((-1, 1))\n",
"x,features,labels\n"
],
"id": "5d450c6f6b724a14",
"outputs": [
{
"data": {
"text/plain": [
"(tensor([-0.1257, 0.4977, 0.1275, 0.0113, 0.1759, 0.1263, 0.0984, 0.0670,\n",
" 0.3374, -0.3129, 0.3756, 0.0234, -0.0841, 0.4951, 0.3441, -0.0585,\n",
" -0.2159, 0.0357, 0.0667, -0.0126, 0.6966, -0.0548, 0.0864, 0.5669,\n",
" 0.2040, 0.2158, 0.1378, 0.2790, 0.4541, 0.3656, 0.3050, 0.3321,\n",
" 0.3818, 0.3404, 0.3803, 0.3527, 0.5237, 0.7250, 0.3400, 0.3136,\n",
" 0.6944, 0.3985, 0.9682, 0.5841, 0.5376, 0.2229, 0.6266, 0.1417,\n",
" 0.2132, 0.6786, 0.3201, 0.5340, 0.7747, 0.7968, 0.7266, 0.7018,\n",
" 0.8106, 0.6221, 0.2093, 0.3683, 0.5998, 0.5546, 0.6686, 0.4981,\n",
" 0.6079, 0.3726, 0.9469, 0.6261, 0.4213, 0.5943, 1.2487, 0.5027,\n",
" 0.6524, 0.6218, 0.4721, 0.7688, 0.8629, 0.5897, 0.3414, 1.0822,\n",
" 0.9223, 0.8020, 0.6607, 0.4673, 0.7155, 0.6349, 0.4676, 0.9303,\n",
" 0.6977, 0.7986, 0.5661, 0.9401, 0.8111, 1.0929, 0.5887, 0.8674,\n",
" 0.8081, 0.8682, 0.7049, 1.0303, 0.5297, 0.8990, 0.6131, 1.1693,\n",
" 1.0146, 1.1179, 0.8550, 0.6801, 0.9054, 0.9622, 0.8227, 0.6969,\n",
" 0.8629, 0.9992, 0.9735, 0.9114, 0.5090, 0.9698, 1.1530, 1.2176,\n",
" 1.1019, 1.0681, 0.6768, 1.0307, 0.9873, 1.1988, 1.1947, 0.8704,\n",
" 0.8378, 0.7581, 1.2643, 1.2095, 0.7556, 1.0024, 0.8649, 1.1953,\n",
" 0.8106, 1.2512, 1.1907, 0.8453, 1.0807, 0.7710, 0.9226, 0.8100,\n",
" 1.0641, 0.9683, 0.7675, 1.2630, 0.9153, 1.0170, 1.3423, 0.8989,\n",
" 1.2243, 1.3355, 0.9849, 0.6055, 0.4062, 0.8255, 1.1904, 0.7565,\n",
" 1.0362, 0.8106, 0.8765, 1.1825, 1.0300, 1.1883, 0.7432, 0.7962,\n",
" 0.7900, 0.9459, 1.0081, 1.1498, 1.0555, 1.4386, 0.9888, 0.7890,\n",
" 0.9454, 0.9568, 0.9832, 0.7835, 0.8084, 0.7282, 1.1450, 1.2708,\n",
" 1.1315, 0.6742, 0.6001, 0.6483, 0.8992, 1.0016, 1.0392, 0.5630,\n",
" 1.3330, 0.9323, 0.6719, 0.9954, 1.0855, 1.0105, 0.6578, 1.0974,\n",
" 0.9163, 1.0161, 1.0866, 0.8661, 0.5516, 1.0398, 1.0476, 0.8525,\n",
" 0.8723, 1.0883, 0.5629, 0.3963, 0.7161, 1.2104, 1.0025, 1.0816,\n",
" 0.7881, 0.7980, 0.6719, 0.5641, 0.7839, 0.7183, 0.6777, 1.1626,\n",
" 0.6991, 0.7296, 0.9149, 0.4818, 0.3593, 0.8057, 0.9782, 0.6981,\n",
" 0.8359, 0.5616, 0.8751, 0.4524, 0.9480, 0.4057, 0.6413, 0.6728,\n",
" 0.8040, 1.1152, 0.6752, 0.7030, 0.5862, 0.7373, 0.6680, 0.6739,\n",
" 0.7372, 1.0807, 0.8491, 0.4628, 0.5695, 0.4675, 0.8295, 0.7881,\n",
" 0.6622, 0.3701, 0.3987, 0.6082, 0.4924, 0.6136, 0.4755, 0.7166,\n",
" 0.4721, 0.2420, 0.2503, 0.5961, 0.5344, 0.6053, 0.5369, 0.2291,\n",
" 0.3503, 0.2833, 0.1630, 0.0821, 0.1769, 0.5129, 0.2650, 0.1519,\n",
" 0.2660, 0.1505, 0.2407, 0.1766, 0.2215, 0.3759, 0.0643, 0.2909,\n",
" 0.0220, 0.5878, 0.1559, 0.2339, 0.3533, -0.1447, 0.5657, 0.0656,\n",
" -0.1913, 0.1975, -0.0296, 0.3531, 0.0032, 0.1607, 0.2249, 0.0783,\n",
" 0.1663, -0.0781, -0.0607, 0.3047, 0.2461, -0.0380, 0.0481, -0.0040,\n",
" 0.0110, -0.0221, 0.1001, 0.0754, 0.2153, -0.1584, 0.0033, -0.2072,\n",
" 0.1622, -0.1114, -0.0954, -0.2582, -0.0575, -0.0883, 0.3422, -0.1808,\n",
" -0.2768, -0.1964, 0.1526, -0.1362, 0.0674, -0.5093, -0.0344, -0.3681,\n",
" -0.2217, -0.1733, -0.0589, -0.1194, -0.0979, -0.2122, -0.5427, -0.5028,\n",
" 0.0059, -0.2044, -0.2778, -0.3447, -0.0537, -0.4030, -0.7130, -0.5167,\n",
" -0.4477, -0.4382, 0.0076, -0.1804, -0.1491, 0.1210, -0.4279, -0.6204,\n",
" -0.7309, -0.1835, -0.9354, -0.6655, -0.7265, -0.5585, -0.8215, -0.3998,\n",
" -0.6667, -0.4026, -0.3606, -0.2286, -0.5571, -0.8246, -0.2567, -0.8022,\n",
" -0.3873, -0.6781, -0.8021, -0.7463, -0.6887, -0.5723, -0.6661, -0.4324,\n",
" -0.6482, -0.5130, -0.6848, -0.5460, -0.8493, -0.1809, -0.5165, -0.4671,\n",
" -0.8529, -0.9896, -0.8904, -0.4498, -1.0809, -0.9123, -0.7125, -0.4627,\n",
" -0.5643, -0.7416, -0.8990, -0.8161, -0.5500, -0.9439, -0.8327, -0.7132,\n",
" -0.8250, -0.9772, -0.8947, -0.4970, -0.4945, -0.4604, -0.7029, -0.7518,\n",
" -0.7635, -0.8060, -0.8300, -1.1194, -1.2429, -0.7834, -0.3628, -1.1099,\n",
" -0.8337, -1.0767, -0.7193, -0.6253, -0.9703, -0.5913, -1.0695, -0.9610,\n",
" -0.7796, -0.8729, -1.1516, -0.8974, -1.1277, -0.8297, -0.6336, -1.5144,\n",
" -1.0980, -1.0812, -0.5136, -0.6882, -0.9138, -0.9021, -1.0671, -1.1456,\n",
" -0.9467, -0.6042, -0.8922, -0.9499, -0.6512, -1.0729, -1.1589, -1.1675,\n",
" -0.9637, -0.7511, -0.8479, -0.8410, -1.1934, -0.8869, -0.9340, -1.0252,\n",
" -0.8195, -1.3040, -0.6508, -1.0083, -1.1282, -0.9536, -1.0764, -1.2750,\n",
" -1.0073, -1.0259, -0.8144, -1.2082, -0.9558, -0.9895, -1.0417, -1.0077,\n",
" -0.7460, -0.7199, -1.1118, -0.7411, -1.2156, -0.8967, -0.8194, -1.1041,\n",
" -0.9286, -0.9155, -0.7483, -0.9874, -1.0476, -0.9132, -0.7950, -0.8823,\n",
" -0.8565, -1.0017, -0.9736, -0.8743, -0.9509, -1.3399, -0.8861, -1.0557,\n",
" -0.8494, -0.6369, -1.0813, -0.7510, -0.8624, -1.1163, -0.9114, -0.7323,\n",
" -0.9083, -0.8352, -0.6851, -0.9174, -0.9412, -1.3040, -0.6257, -0.7814,\n",
" -0.7670, -1.0620, -0.9168, -1.0231, -0.5532, -0.7955, -0.9293, -0.7984,\n",
" -0.9475, -0.8074, -1.0046, -0.7866, -0.8110, -0.8169, -0.7929, -0.9577,\n",
" -0.7490, -0.6953, -0.7600, -0.6348, -0.5752, -0.6600, -1.1377, -1.0344,\n",
" -0.6518, -0.7506, -0.9227, -0.7814, -0.9301, -0.4463, -0.8153, -0.7221,\n",
" -0.6543, -1.0062, -0.4462, -0.5389, -0.3644, -0.3854, -0.5175, -0.3598,\n",
" -0.7745, -0.8278, -0.6843, -0.5519, -0.6849, -0.6662, -0.8282, -0.5927,\n",
" -0.8346, -0.5149, -0.0033, -0.7285, -0.8659, -0.4320, -0.5433, -0.5551,\n",
" -0.4936, -0.3990, -0.2697, -0.5388, -0.5527, -0.5663, -0.4017, -0.2667,\n",
" -0.3446, -0.3117, -0.3110, -0.8562, -0.2726, -0.5014, -0.4719, -0.5338,\n",
" -0.7666, -0.1854, -0.5822, -0.4734, -0.2585, -0.2755, -0.4047, -0.0902,\n",
" -0.0984, -0.3434, -0.0755, -0.5209, -0.2434, -0.3536, -0.0617, 0.1276,\n",
" -0.0150, -0.5196, -0.2691, -0.8314, 0.1469, -0.0438, -0.4816, 0.1779,\n",
" -0.1709, -0.2126, -0.2875, -0.4329, -0.0967, -0.5540, -0.2296, -0.0021,\n",
" -0.1871, 0.0261, -0.0573, 0.3196, 0.1587, 0.1620, -0.3062, 0.1800,\n",
" -0.0216, -0.0861, 0.3876, 0.2574, 0.2573, 0.3694, 0.1312, 0.6010,\n",
" 0.0274, 0.0227, -0.1395, 0.0214, 0.3586, 0.0331, 0.2754, 0.4699,\n",
" 0.3533, -0.0946, 0.1566, 0.2768, 0.6166, 0.3522, 0.2357, 0.2673,\n",
" 0.2506, 0.4461, 0.6163, 0.1398, 0.3288, 0.4211, 0.3313, 0.1029,\n",
" 0.4284, 0.1385, 0.1132, 0.0989, 0.3567, 0.2329, 0.4514, 0.7074,\n",
" 0.3183, 0.2934, 0.4533, 0.2790, 0.4807, 0.8162, 0.6992, 0.1948,\n",
" 0.5107, 0.8306, 0.2990, 0.2718, 0.7156, 0.8072, 0.6706, 0.5840,\n",
" 0.8009, 0.5367, 0.8542, 0.4551, 0.6621, 0.6004, 0.6589, 0.4726,\n",
" 0.5991, 0.8084, 0.5788, 0.7125, 0.6552, 0.9191, 0.3361, 0.8335,\n",
" 0.2599, 0.6830, 0.6857, 0.4505, 0.7303, 0.5562, 0.3135, 0.7432,\n",
" 0.8188, 0.7189, 0.6228, 0.8273, 0.6486, 0.9803, 0.6484, 0.7697,\n",
" 1.1531, 0.9866, 1.3931, 0.9747, 1.2460, 1.0597, 0.7014, 0.9013,\n",
" 0.9571, 0.7041, 1.0944, 1.1762, 1.1356, 1.0760, 1.0171, 0.8546,\n",
" 0.9204, 0.9524, 1.3716, 0.7630, 0.9069, 1.0180, 1.0366, 1.0358,\n",
" 0.8609, 0.8634, 0.8047, 0.7477, 0.9808, 1.0275, 1.2071, 0.5799,\n",
" 0.8834, 0.8784, 1.1447, 1.0891, 0.5811, 0.9703, 1.2833, 0.9937,\n",
" 1.1356, 0.8306, 0.9129, 1.0194, 1.4320, 1.2589, 0.9175, 0.8849,\n",
" 1.1727, 0.9605, 0.7599, 0.8099, 1.0688, 0.7013, 1.0260, 0.7066,\n",
" 0.8967, 1.0578, 0.8639, 1.0968, 0.9553, 1.0410, 0.7809, 0.8928,\n",
" 0.9644, 0.8980, 0.9744, 0.6657, 1.0549, 0.9716, 1.0272, 0.9510,\n",
" 1.0992, 0.8345, 1.0305, 1.0269, 0.9503, 1.0622, 0.9953, 1.3019,\n",
" 1.0447, 0.9759, 0.9953, 1.0697, 0.9619, 1.0681, 1.0844, 0.6814,\n",
" 0.7774, 1.1827, 1.1599, 0.7436, 0.8570, 0.7392, 1.2210, 0.8350,\n",
" 0.7613, 0.7885, 1.0991, 0.6867, 0.5461, 1.1209, 1.1265, 0.9876,\n",
" 0.8403, 0.9892, 0.7838, 0.5770, 0.7996, 1.1023, 1.1888, 0.8290,\n",
" 0.9919, 0.7272, 0.6149, 0.8744, 0.7331, 0.9389, 0.8888, 0.4813,\n",
" 1.1600, 0.6871, 0.7780, 0.9699, 0.3082, 0.8391, 0.5978, 0.5697,\n",
" 0.9227, 0.4502, 0.5293, 0.7309, 0.7579, 0.5995, 0.5698, 0.5490,\n",
" 0.7483, 0.9721, 0.9419, 0.5393, 0.9869, 0.9892, 0.5714, 0.7620,\n",
" 0.6800, 0.8412, 0.6070, 0.1774, 0.6198, 0.7153, 0.7985, 0.5209,\n",
" 1.1309, 0.6716, 0.7221, 0.5309, 0.6143, 0.9212, 0.6585, 0.5518,\n",
" 0.7676, 0.7002, 0.5711, 0.5491, 0.7280, 1.2188, 0.3206, 0.5493,\n",
" 0.7454, 0.5868, 0.6143, 0.8513, 0.1876, 0.5672, 0.4292, 0.5437,\n",
" 0.4909, 0.7139, 0.5861, 0.3725, 0.5194, 0.4843, 0.0279, 0.3152,\n",
" 0.4333, 0.5915, 0.2709, 0.4861, 0.1708, -0.0844, 0.1523, -0.2092,\n",
" 0.2965, -0.1280, 0.4479, 0.4392, 0.1969, 0.1989, -0.0969, 0.2829,\n",
" 0.1741, -0.1890, -0.0512, 0.4777, 0.0458, 0.0724, 0.1996, 0.2772,\n",
" -0.0650, 0.4351, 0.2693, -0.0298, -0.1171, 0.3714, 0.0992, 0.0090,\n",
" 0.0618, 0.1225, 0.1389, 0.1166, 0.0821, 0.0435, -0.1259, -0.1045,\n",
" 0.1779, -0.2051, -0.2457, -0.1619, -0.0991, 0.1651, 0.1712, -0.1440,\n",
" -0.0499, -0.0943, 0.1058, -0.3224, -0.2115, -0.1307, -0.2432, -0.1935,\n",
" -0.1462, -0.3798, -0.3857, -0.3871, 0.1132, -0.5729, 0.1458, -0.5250,\n",
" -0.1113, -0.1085, -0.3974, -0.2798, -0.2995, -0.0517, -0.1601, -0.5213,\n",
" -0.3897, -0.5143, -0.4268, -0.4268, -0.1593, -0.3720, -0.2030, -0.5328,\n",
" -0.8009, -0.5220, -0.5291, -0.3730, -0.4571, -0.3859, -0.3053, -0.3744,\n",
" -0.7439, -0.7338, -0.2856, -0.3440, -0.6041, -0.7940, -0.6112, -0.1943]),\n",
" tensor([[-0.1257, 0.4977, 0.1275, 0.0113],\n",
" [ 0.4977, 0.1275, 0.0113, 0.1759],\n",
" [ 0.1275, 0.0113, 0.1759, 0.1263],\n",
" ...,\n",
" [-0.7338, -0.2856, -0.3440, -0.6041],\n",
" [-0.2856, -0.3440, -0.6041, -0.7940],\n",
" [-0.3440, -0.6041, -0.7940, -0.6112]]),\n",
" tensor([[ 0.1759],\n",
" [ 0.1263],\n",
" [ 0.0984],\n",
" [ 0.0670],\n",
" [ 0.3374],\n",
" [-0.3129],\n",
" [ 0.3756],\n",
" [ 0.0234],\n",
" [-0.0841],\n",
" [ 0.4951],\n",
" [ 0.3441],\n",
" [-0.0585],\n",
" [-0.2159],\n",
" [ 0.0357],\n",
" [ 0.0667],\n",
" [-0.0126],\n",
" [ 0.6966],\n",
" [-0.0548],\n",
" [ 0.0864],\n",
" [ 0.5669],\n",
" [ 0.2040],\n",
" [ 0.2158],\n",
" [ 0.1378],\n",
" [ 0.2790],\n",
" [ 0.4541],\n",
" [ 0.3656],\n",
" [ 0.3050],\n",
" [ 0.3321],\n",
" [ 0.3818],\n",
" [ 0.3404],\n",
" [ 0.3803],\n",
" [ 0.3527],\n",
" [ 0.5237],\n",
" [ 0.7250],\n",
" [ 0.3400],\n",
" [ 0.3136],\n",
" [ 0.6944],\n",
" [ 0.3985],\n",
" [ 0.9682],\n",
" [ 0.5841],\n",
" [ 0.5376],\n",
" [ 0.2229],\n",
" [ 0.6266],\n",
" [ 0.1417],\n",
" [ 0.2132],\n",
" [ 0.6786],\n",
" [ 0.3201],\n",
" [ 0.5340],\n",
" [ 0.7747],\n",
" [ 0.7968],\n",
" [ 0.7266],\n",
" [ 0.7018],\n",
" [ 0.8106],\n",
" [ 0.6221],\n",
" [ 0.2093],\n",
" [ 0.3683],\n",
" [ 0.5998],\n",
" [ 0.5546],\n",
" [ 0.6686],\n",
" [ 0.4981],\n",
" [ 0.6079],\n",
" [ 0.3726],\n",
" [ 0.9469],\n",
" [ 0.6261],\n",
" [ 0.4213],\n",
" [ 0.5943],\n",
" [ 1.2487],\n",
" [ 0.5027],\n",
" [ 0.6524],\n",
" [ 0.6218],\n",
" [ 0.4721],\n",
" [ 0.7688],\n",
" [ 0.8629],\n",
" [ 0.5897],\n",
" [ 0.3414],\n",
" [ 1.0822],\n",
" [ 0.9223],\n",
" [ 0.8020],\n",
" [ 0.6607],\n",
" [ 0.4673],\n",
" [ 0.7155],\n",
" [ 0.6349],\n",
" [ 0.4676],\n",
" [ 0.9303],\n",
" [ 0.6977],\n",
" [ 0.7986],\n",
" [ 0.5661],\n",
" [ 0.9401],\n",
" [ 0.8111],\n",
" [ 1.0929],\n",
" [ 0.5887],\n",
" [ 0.8674],\n",
" [ 0.8081],\n",
" [ 0.8682],\n",
" [ 0.7049],\n",
" [ 1.0303],\n",
" [ 0.5297],\n",
" [ 0.8990],\n",
" [ 0.6131],\n",
" [ 1.1693],\n",
" [ 1.0146],\n",
" [ 1.1179],\n",
" [ 0.8550],\n",
" [ 0.6801],\n",
" [ 0.9054],\n",
" [ 0.9622],\n",
" [ 0.8227],\n",
" [ 0.6969],\n",
" [ 0.8629],\n",
" [ 0.9992],\n",
" [ 0.9735],\n",
" [ 0.9114],\n",
" [ 0.5090],\n",
" [ 0.9698],\n",
" [ 1.1530],\n",
" [ 1.2176],\n",
" [ 1.1019],\n",
" [ 1.0681],\n",
" [ 0.6768],\n",
" [ 1.0307],\n",
" [ 0.9873],\n",
" [ 1.1988],\n",
" [ 1.1947],\n",
" [ 0.8704],\n",
" [ 0.8378],\n",
" [ 0.7581],\n",
" [ 1.2643],\n",
" [ 1.2095],\n",
" [ 0.7556],\n",
" [ 1.0024],\n",
" [ 0.8649],\n",
" [ 1.1953],\n",
" [ 0.8106],\n",
" [ 1.2512],\n",
" [ 1.1907],\n",
" [ 0.8453],\n",
" [ 1.0807],\n",
" [ 0.7710],\n",
" [ 0.9226],\n",
" [ 0.8100],\n",
" [ 1.0641],\n",
" [ 0.9683],\n",
" [ 0.7675],\n",
" [ 1.2630],\n",
" [ 0.9153],\n",
" [ 1.0170],\n",
" [ 1.3423],\n",
" [ 0.8989],\n",
" [ 1.2243],\n",
" [ 1.3355],\n",
" [ 0.9849],\n",
" [ 0.6055],\n",
" [ 0.4062],\n",
" [ 0.8255],\n",
" [ 1.1904],\n",
" [ 0.7565],\n",
" [ 1.0362],\n",
" [ 0.8106],\n",
" [ 0.8765],\n",
" [ 1.1825],\n",
" [ 1.0300],\n",
" [ 1.1883],\n",
" [ 0.7432],\n",
" [ 0.7962],\n",
" [ 0.7900],\n",
" [ 0.9459],\n",
" [ 1.0081],\n",
" [ 1.1498],\n",
" [ 1.0555],\n",
" [ 1.4386],\n",
" [ 0.9888],\n",
" [ 0.7890],\n",
" [ 0.9454],\n",
" [ 0.9568],\n",
" [ 0.9832],\n",
" [ 0.7835],\n",
" [ 0.8084],\n",
" [ 0.7282],\n",
" [ 1.1450],\n",
" [ 1.2708],\n",
" [ 1.1315],\n",
" [ 0.6742],\n",
" [ 0.6001],\n",
" [ 0.6483],\n",
" [ 0.8992],\n",
" [ 1.0016],\n",
" [ 1.0392],\n",
" [ 0.5630],\n",
" [ 1.3330],\n",
" [ 0.9323],\n",
" [ 0.6719],\n",
" [ 0.9954],\n",
" [ 1.0855],\n",
" [ 1.0105],\n",
" [ 0.6578],\n",
" [ 1.0974],\n",
" [ 0.9163],\n",
" [ 1.0161],\n",
" [ 1.0866],\n",
" [ 0.8661],\n",
" [ 0.5516],\n",
" [ 1.0398],\n",
" [ 1.0476],\n",
" [ 0.8525],\n",
" [ 0.8723],\n",
" [ 1.0883],\n",
" [ 0.5629],\n",
" [ 0.3963],\n",
" [ 0.7161],\n",
" [ 1.2104],\n",
" [ 1.0025],\n",
" [ 1.0816],\n",
" [ 0.7881],\n",
" [ 0.7980],\n",
" [ 0.6719],\n",
" [ 0.5641],\n",
" [ 0.7839],\n",
" [ 0.7183],\n",
" [ 0.6777],\n",
" [ 1.1626],\n",
" [ 0.6991],\n",
" [ 0.7296],\n",
" [ 0.9149],\n",
" [ 0.4818],\n",
" [ 0.3593],\n",
" [ 0.8057],\n",
" [ 0.9782],\n",
" [ 0.6981],\n",
" [ 0.8359],\n",
" [ 0.5616],\n",
" [ 0.8751],\n",
" [ 0.4524],\n",
" [ 0.9480],\n",
" [ 0.4057],\n",
" [ 0.6413],\n",
" [ 0.6728],\n",
" [ 0.8040],\n",
" [ 1.1152],\n",
" [ 0.6752],\n",
" [ 0.7030],\n",
" [ 0.5862],\n",
" [ 0.7373],\n",
" [ 0.6680],\n",
" [ 0.6739],\n",
" [ 0.7372],\n",
" [ 1.0807],\n",
" [ 0.8491],\n",
" [ 0.4628],\n",
" [ 0.5695],\n",
" [ 0.4675],\n",
" [ 0.8295],\n",
" [ 0.7881],\n",
" [ 0.6622],\n",
" [ 0.3701],\n",
" [ 0.3987],\n",
" [ 0.6082],\n",
" [ 0.4924],\n",
" [ 0.6136],\n",
" [ 0.4755],\n",
" [ 0.7166],\n",
" [ 0.4721],\n",
" [ 0.2420],\n",
" [ 0.2503],\n",
" [ 0.5961],\n",
" [ 0.5344],\n",
" [ 0.6053],\n",
" [ 0.5369],\n",
" [ 0.2291],\n",
" [ 0.3503],\n",
" [ 0.2833],\n",
" [ 0.1630],\n",
" [ 0.0821],\n",
" [ 0.1769],\n",
" [ 0.5129],\n",
" [ 0.2650],\n",
" [ 0.1519],\n",
" [ 0.2660],\n",
" [ 0.1505],\n",
" [ 0.2407],\n",
" [ 0.1766],\n",
" [ 0.2215],\n",
" [ 0.3759],\n",
" [ 0.0643],\n",
" [ 0.2909],\n",
" [ 0.0220],\n",
" [ 0.5878],\n",
" [ 0.1559],\n",
" [ 0.2339],\n",
" [ 0.3533],\n",
" [-0.1447],\n",
" [ 0.5657],\n",
" [ 0.0656],\n",
" [-0.1913],\n",
" [ 0.1975],\n",
" [-0.0296],\n",
" [ 0.3531],\n",
" [ 0.0032],\n",
" [ 0.1607],\n",
" [ 0.2249],\n",
" [ 0.0783],\n",
" [ 0.1663],\n",
" [-0.0781],\n",
" [-0.0607],\n",
" [ 0.3047],\n",
" [ 0.2461],\n",
" [-0.0380],\n",
" [ 0.0481],\n",
" [-0.0040],\n",
" [ 0.0110],\n",
" [-0.0221],\n",
" [ 0.1001],\n",
" [ 0.0754],\n",
" [ 0.2153],\n",
" [-0.1584],\n",
" [ 0.0033],\n",
" [-0.2072],\n",
" [ 0.1622],\n",
" [-0.1114],\n",
" [-0.0954],\n",
" [-0.2582],\n",
" [-0.0575],\n",
" [-0.0883],\n",
" [ 0.3422],\n",
" [-0.1808],\n",
" [-0.2768],\n",
" [-0.1964],\n",
" [ 0.1526],\n",
" [-0.1362],\n",
" [ 0.0674],\n",
" [-0.5093],\n",
" [-0.0344],\n",
" [-0.3681],\n",
" [-0.2217],\n",
" [-0.1733],\n",
" [-0.0589],\n",
" [-0.1194],\n",
" [-0.0979],\n",
" [-0.2122],\n",
" [-0.5427],\n",
" [-0.5028],\n",
" [ 0.0059],\n",
" [-0.2044],\n",
" [-0.2778],\n",
" [-0.3447],\n",
" [-0.0537],\n",
" [-0.4030],\n",
" [-0.7130],\n",
" [-0.5167],\n",
" [-0.4477],\n",
" [-0.4382],\n",
" [ 0.0076],\n",
" [-0.1804],\n",
" [-0.1491],\n",
" [ 0.1210],\n",
" [-0.4279],\n",
" [-0.6204],\n",
" [-0.7309],\n",
" [-0.1835],\n",
" [-0.9354],\n",
" [-0.6655],\n",
" [-0.7265],\n",
" [-0.5585],\n",
" [-0.8215],\n",
" [-0.3998],\n",
" [-0.6667],\n",
" [-0.4026],\n",
" [-0.3606],\n",
" [-0.2286],\n",
" [-0.5571],\n",
" [-0.8246],\n",
" [-0.2567],\n",
" [-0.8022],\n",
" [-0.3873],\n",
" [-0.6781],\n",
" [-0.8021],\n",
" [-0.7463],\n",
" [-0.6887],\n",
" [-0.5723],\n",
" [-0.6661],\n",
" [-0.4324],\n",
" [-0.6482],\n",
" [-0.5130],\n",
" [-0.6848],\n",
" [-0.5460],\n",
" [-0.8493],\n",
" [-0.1809],\n",
" [-0.5165],\n",
" [-0.4671],\n",
" [-0.8529],\n",
" [-0.9896],\n",
" [-0.8904],\n",
" [-0.4498],\n",
" [-1.0809],\n",
" [-0.9123],\n",
" [-0.7125],\n",
" [-0.4627],\n",
" [-0.5643],\n",
" [-0.7416],\n",
" [-0.8990],\n",
" [-0.8161],\n",
" [-0.5500],\n",
" [-0.9439],\n",
" [-0.8327],\n",
" [-0.7132],\n",
" [-0.8250],\n",
" [-0.9772],\n",
" [-0.8947],\n",
" [-0.4970],\n",
" [-0.4945],\n",
" [-0.4604],\n",
" [-0.7029],\n",
" [-0.7518],\n",
" [-0.7635],\n",
" [-0.8060],\n",
" [-0.8300],\n",
" [-1.1194],\n",
" [-1.2429],\n",
" [-0.7834],\n",
" [-0.3628],\n",
" [-1.1099],\n",
" [-0.8337],\n",
" [-1.0767],\n",
" [-0.7193],\n",
" [-0.6253],\n",
" [-0.9703],\n",
" [-0.5913],\n",
" [-1.0695],\n",
" [-0.9610],\n",
" [-0.7796],\n",
" [-0.8729],\n",
" [-1.1516],\n",
" [-0.8974],\n",
" [-1.1277],\n",
" [-0.8297],\n",
" [-0.6336],\n",
" [-1.5144],\n",
" [-1.0980],\n",
" [-1.0812],\n",
" [-0.5136],\n",
" [-0.6882],\n",
" [-0.9138],\n",
" [-0.9021],\n",
" [-1.0671],\n",
" [-1.1456],\n",
" [-0.9467],\n",
" [-0.6042],\n",
" [-0.8922],\n",
" [-0.9499],\n",
" [-0.6512],\n",
" [-1.0729],\n",
" [-1.1589],\n",
" [-1.1675],\n",
" [-0.9637],\n",
" [-0.7511],\n",
" [-0.8479],\n",
" [-0.8410],\n",
" [-1.1934],\n",
" [-0.8869],\n",
" [-0.9340],\n",
" [-1.0252],\n",
" [-0.8195],\n",
" [-1.3040],\n",
" [-0.6508],\n",
" [-1.0083],\n",
" [-1.1282],\n",
" [-0.9536],\n",
" [-1.0764],\n",
" [-1.2750],\n",
" [-1.0073],\n",
" [-1.0259],\n",
" [-0.8144],\n",
" [-1.2082],\n",
" [-0.9558],\n",
" [-0.9895],\n",
" [-1.0417],\n",
" [-1.0077],\n",
" [-0.7460],\n",
" [-0.7199],\n",
" [-1.1118],\n",
" [-0.7411],\n",
" [-1.2156],\n",
" [-0.8967],\n",
" [-0.8194],\n",
" [-1.1041],\n",
" [-0.9286],\n",
" [-0.9155],\n",
" [-0.7483],\n",
" [-0.9874],\n",
" [-1.0476],\n",
" [-0.9132],\n",
" [-0.7950],\n",
" [-0.8823],\n",
" [-0.8565],\n",
" [-1.0017],\n",
" [-0.9736],\n",
" [-0.8743],\n",
" [-0.9509],\n",
" [-1.3399],\n",
" [-0.8861],\n",
" [-1.0557],\n",
" [-0.8494],\n",
" [-0.6369],\n",
" [-1.0813],\n",
" [-0.7510],\n",
" [-0.8624],\n",
" [-1.1163],\n",
" [-0.9114],\n",
" [-0.7323],\n",
" [-0.9083],\n",
" [-0.8352],\n",
" [-0.6851],\n",
" [-0.9174],\n",
" [-0.9412],\n",
" [-1.3040],\n",
" [-0.6257],\n",
" [-0.7814],\n",
" [-0.7670],\n",
" [-1.0620],\n",
" [-0.9168],\n",
" [-1.0231],\n",
" [-0.5532],\n",
" [-0.7955],\n",
" [-0.9293],\n",
" [-0.7984],\n",
" [-0.9475],\n",
" [-0.8074],\n",
" [-1.0046],\n",
" [-0.7866],\n",
" [-0.8110],\n",
" [-0.8169],\n",
" [-0.7929],\n",
" [-0.9577],\n",
" [-0.7490],\n",
" [-0.6953],\n",
" [-0.7600],\n",
" [-0.6348],\n",
" [-0.5752],\n",
" [-0.6600],\n",
" [-1.1377],\n",
" [-1.0344],\n",
" [-0.6518],\n",
" [-0.7506],\n",
" [-0.9227],\n",
" [-0.7814],\n",
" [-0.9301],\n",
" [-0.4463],\n",
" [-0.8153],\n",
" [-0.7221],\n",
" [-0.6543],\n",
" [-1.0062],\n",
" [-0.4462],\n",
" [-0.5389],\n",
" [-0.3644],\n",
" [-0.3854],\n",
" [-0.5175],\n",
" [-0.3598],\n",
" [-0.7745],\n",
" [-0.8278],\n",
" [-0.6843],\n",
" [-0.5519],\n",
" [-0.6849],\n",
" [-0.6662],\n",
" [-0.8282],\n",
" [-0.5927],\n",
" [-0.8346],\n",
" [-0.5149],\n",
" [-0.0033],\n",
" [-0.7285],\n",
" [-0.8659],\n",
" [-0.4320],\n",
" [-0.5433],\n",
" [-0.5551],\n",
" [-0.4936],\n",
" [-0.3990],\n",
" [-0.2697],\n",
" [-0.5388],\n",
" [-0.5527],\n",
" [-0.5663],\n",
" [-0.4017],\n",
" [-0.2667],\n",
" [-0.3446],\n",
" [-0.3117],\n",
" [-0.3110],\n",
" [-0.8562],\n",
" [-0.2726],\n",
" [-0.5014],\n",
" [-0.4719],\n",
" [-0.5338],\n",
" [-0.7666],\n",
" [-0.1854],\n",
" [-0.5822],\n",
" [-0.4734],\n",
" [-0.2585],\n",
" [-0.2755],\n",
" [-0.4047],\n",
" [-0.0902],\n",
" [-0.0984],\n",
" [-0.3434],\n",
" [-0.0755],\n",
" [-0.5209],\n",
" [-0.2434],\n",
" [-0.3536],\n",
" [-0.0617],\n",
" [ 0.1276],\n",
" [-0.0150],\n",
" [-0.5196],\n",
" [-0.2691],\n",
" [-0.8314],\n",
" [ 0.1469],\n",
" [-0.0438],\n",
" [-0.4816],\n",
" [ 0.1779],\n",
" [-0.1709],\n",
" [-0.2126],\n",
" [-0.2875],\n",
" [-0.4329],\n",
" [-0.0967],\n",
" [-0.5540],\n",
" [-0.2296],\n",
" [-0.0021],\n",
" [-0.1871],\n",
" [ 0.0261],\n",
" [-0.0573],\n",
" [ 0.3196],\n",
" [ 0.1587],\n",
" [ 0.1620],\n",
" [-0.3062],\n",
" [ 0.1800],\n",
" [-0.0216],\n",
" [-0.0861],\n",
" [ 0.3876],\n",
" [ 0.2574],\n",
" [ 0.2573],\n",
" [ 0.3694],\n",
" [ 0.1312],\n",
" [ 0.6010],\n",
" [ 0.0274],\n",
" [ 0.0227],\n",
" [-0.1395],\n",
" [ 0.0214],\n",
" [ 0.3586],\n",
" [ 0.0331],\n",
" [ 0.2754],\n",
" [ 0.4699],\n",
" [ 0.3533],\n",
" [-0.0946],\n",
" [ 0.1566],\n",
" [ 0.2768],\n",
" [ 0.6166],\n",
" [ 0.3522],\n",
" [ 0.2357],\n",
" [ 0.2673],\n",
" [ 0.2506],\n",
" [ 0.4461],\n",
" [ 0.6163],\n",
" [ 0.1398],\n",
" [ 0.3288],\n",
" [ 0.4211],\n",
" [ 0.3313],\n",
" [ 0.1029],\n",
" [ 0.4284],\n",
" [ 0.1385],\n",
" [ 0.1132],\n",
" [ 0.0989],\n",
" [ 0.3567],\n",
" [ 0.2329],\n",
" [ 0.4514],\n",
" [ 0.7074],\n",
" [ 0.3183],\n",
" [ 0.2934],\n",
" [ 0.4533],\n",
" [ 0.2790],\n",
" [ 0.4807],\n",
" [ 0.8162],\n",
" [ 0.6992],\n",
" [ 0.1948],\n",
" [ 0.5107],\n",
" [ 0.8306],\n",
" [ 0.2990],\n",
" [ 0.2718],\n",
" [ 0.7156],\n",
" [ 0.8072],\n",
" [ 0.6706],\n",
" [ 0.5840],\n",
" [ 0.8009],\n",
" [ 0.5367],\n",
" [ 0.8542],\n",
" [ 0.4551],\n",
" [ 0.6621],\n",
" [ 0.6004],\n",
" [ 0.6589],\n",
" [ 0.4726],\n",
" [ 0.5991],\n",
" [ 0.8084],\n",
" [ 0.5788],\n",
" [ 0.7125],\n",
" [ 0.6552],\n",
" [ 0.9191],\n",
" [ 0.3361],\n",
" [ 0.8335],\n",
" [ 0.2599],\n",
" [ 0.6830],\n",
" [ 0.6857],\n",
" [ 0.4505],\n",
" [ 0.7303],\n",
" [ 0.5562],\n",
" [ 0.3135],\n",
" [ 0.7432],\n",
" [ 0.8188],\n",
" [ 0.7189],\n",
" [ 0.6228],\n",
" [ 0.8273],\n",
" [ 0.6486],\n",
" [ 0.9803],\n",
" [ 0.6484],\n",
" [ 0.7697],\n",
" [ 1.1531],\n",
" [ 0.9866],\n",
" [ 1.3931],\n",
" [ 0.9747],\n",
" [ 1.2460],\n",
" [ 1.0597],\n",
" [ 0.7014],\n",
" [ 0.9013],\n",
" [ 0.9571],\n",
" [ 0.7041],\n",
" [ 1.0944],\n",
" [ 1.1762],\n",
" [ 1.1356],\n",
" [ 1.0760],\n",
" [ 1.0171],\n",
" [ 0.8546],\n",
" [ 0.9204],\n",
" [ 0.9524],\n",
" [ 1.3716],\n",
" [ 0.7630],\n",
" [ 0.9069],\n",
" [ 1.0180],\n",
" [ 1.0366],\n",
" [ 1.0358],\n",
" [ 0.8609],\n",
" [ 0.8634],\n",
" [ 0.8047],\n",
" [ 0.7477],\n",
" [ 0.9808],\n",
" [ 1.0275],\n",
" [ 1.2071],\n",
" [ 0.5799],\n",
" [ 0.8834],\n",
" [ 0.8784],\n",
" [ 1.1447],\n",
" [ 1.0891],\n",
" [ 0.5811],\n",
" [ 0.9703],\n",
" [ 1.2833],\n",
" [ 0.9937],\n",
" [ 1.1356],\n",
" [ 0.8306],\n",
" [ 0.9129],\n",
" [ 1.0194],\n",
" [ 1.4320],\n",
" [ 1.2589],\n",
" [ 0.9175],\n",
" [ 0.8849],\n",
" [ 1.1727],\n",
" [ 0.9605],\n",
" [ 0.7599],\n",
" [ 0.8099],\n",
" [ 1.0688],\n",
" [ 0.7013],\n",
" [ 1.0260],\n",
" [ 0.7066],\n",
" [ 0.8967],\n",
" [ 1.0578],\n",
" [ 0.8639],\n",
" [ 1.0968],\n",
" [ 0.9553],\n",
" [ 1.0410],\n",
" [ 0.7809],\n",
" [ 0.8928],\n",
" [ 0.9644],\n",
" [ 0.8980],\n",
" [ 0.9744],\n",
" [ 0.6657],\n",
" [ 1.0549],\n",
" [ 0.9716],\n",
" [ 1.0272],\n",
" [ 0.9510],\n",
" [ 1.0992],\n",
" [ 0.8345],\n",
" [ 1.0305],\n",
" [ 1.0269],\n",
" [ 0.9503],\n",
" [ 1.0622],\n",
" [ 0.9953],\n",
" [ 1.3019],\n",
" [ 1.0447],\n",
" [ 0.9759],\n",
" [ 0.9953],\n",
" [ 1.0697],\n",
" [ 0.9619],\n",
" [ 1.0681],\n",
" [ 1.0844],\n",
" [ 0.6814],\n",
" [ 0.7774],\n",
" [ 1.1827],\n",
" [ 1.1599],\n",
" [ 0.7436],\n",
" [ 0.8570],\n",
" [ 0.7392],\n",
" [ 1.2210],\n",
" [ 0.8350],\n",
" [ 0.7613],\n",
" [ 0.7885],\n",
" [ 1.0991],\n",
" [ 0.6867],\n",
" [ 0.5461],\n",
" [ 1.1209],\n",
" [ 1.1265],\n",
" [ 0.9876],\n",
" [ 0.8403],\n",
" [ 0.9892],\n",
" [ 0.7838],\n",
" [ 0.5770],\n",
" [ 0.7996],\n",
" [ 1.1023],\n",
" [ 1.1888],\n",
" [ 0.8290],\n",
" [ 0.9919],\n",
" [ 0.7272],\n",
" [ 0.6149],\n",
" [ 0.8744],\n",
" [ 0.7331],\n",
" [ 0.9389],\n",
" [ 0.8888],\n",
" [ 0.4813],\n",
" [ 1.1600],\n",
" [ 0.6871],\n",
" [ 0.7780],\n",
" [ 0.9699],\n",
" [ 0.3082],\n",
" [ 0.8391],\n",
" [ 0.5978],\n",
" [ 0.5697],\n",
" [ 0.9227],\n",
" [ 0.4502],\n",
" [ 0.5293],\n",
" [ 0.7309],\n",
" [ 0.7579],\n",
" [ 0.5995],\n",
" [ 0.5698],\n",
" [ 0.5490],\n",
" [ 0.7483],\n",
" [ 0.9721],\n",
" [ 0.9419],\n",
" [ 0.5393],\n",
" [ 0.9869],\n",
" [ 0.9892],\n",
" [ 0.5714],\n",
" [ 0.7620],\n",
" [ 0.6800],\n",
" [ 0.8412],\n",
" [ 0.6070],\n",
" [ 0.1774],\n",
" [ 0.6198],\n",
" [ 0.7153],\n",
" [ 0.7985],\n",
" [ 0.5209],\n",
" [ 1.1309],\n",
" [ 0.6716],\n",
" [ 0.7221],\n",
" [ 0.5309],\n",
" [ 0.6143],\n",
" [ 0.9212],\n",
" [ 0.6585],\n",
" [ 0.5518],\n",
" [ 0.7676],\n",
" [ 0.7002],\n",
" [ 0.5711],\n",
" [ 0.5491],\n",
" [ 0.7280],\n",
" [ 1.2188],\n",
" [ 0.3206],\n",
" [ 0.5493],\n",
" [ 0.7454],\n",
" [ 0.5868],\n",
" [ 0.6143],\n",
" [ 0.8513],\n",
" [ 0.1876],\n",
" [ 0.5672],\n",
" [ 0.4292],\n",
" [ 0.5437],\n",
" [ 0.4909],\n",
" [ 0.7139],\n",
" [ 0.5861],\n",
" [ 0.3725],\n",
" [ 0.5194],\n",
" [ 0.4843],\n",
" [ 0.0279],\n",
" [ 0.3152],\n",
" [ 0.4333],\n",
" [ 0.5915],\n",
" [ 0.2709],\n",
" [ 0.4861],\n",
" [ 0.1708],\n",
" [-0.0844],\n",
" [ 0.1523],\n",
" [-0.2092],\n",
" [ 0.2965],\n",
" [-0.1280],\n",
" [ 0.4479],\n",
" [ 0.4392],\n",
" [ 0.1969],\n",
" [ 0.1989],\n",
" [-0.0969],\n",
" [ 0.2829],\n",
" [ 0.1741],\n",
" [-0.1890],\n",
" [-0.0512],\n",
" [ 0.4777],\n",
" [ 0.0458],\n",
" [ 0.0724],\n",
" [ 0.1996],\n",
" [ 0.2772],\n",
" [-0.0650],\n",
" [ 0.4351],\n",
" [ 0.2693],\n",
" [-0.0298],\n",
" [-0.1171],\n",
" [ 0.3714],\n",
" [ 0.0992],\n",
" [ 0.0090],\n",
" [ 0.0618],\n",
" [ 0.1225],\n",
" [ 0.1389],\n",
" [ 0.1166],\n",
" [ 0.0821],\n",
" [ 0.0435],\n",
" [-0.1259],\n",
" [-0.1045],\n",
" [ 0.1779],\n",
" [-0.2051],\n",
" [-0.2457],\n",
" [-0.1619],\n",
" [-0.0991],\n",
" [ 0.1651],\n",
" [ 0.1712],\n",
" [-0.1440],\n",
" [-0.0499],\n",
" [-0.0943],\n",
" [ 0.1058],\n",
" [-0.3224],\n",
" [-0.2115],\n",
" [-0.1307],\n",
" [-0.2432],\n",
" [-0.1935],\n",
" [-0.1462],\n",
" [-0.3798],\n",
" [-0.3857],\n",
" [-0.3871],\n",
" [ 0.1132],\n",
" [-0.5729],\n",
" [ 0.1458],\n",
" [-0.5250],\n",
" [-0.1113],\n",
" [-0.1085],\n",
" [-0.3974],\n",
" [-0.2798],\n",
" [-0.2995],\n",
" [-0.0517],\n",
" [-0.1601],\n",
" [-0.5213],\n",
" [-0.3897],\n",
" [-0.5143],\n",
" [-0.4268],\n",
" [-0.4268],\n",
" [-0.1593],\n",
" [-0.3720],\n",
" [-0.2030],\n",
" [-0.5328],\n",
" [-0.8009],\n",
" [-0.5220],\n",
" [-0.5291],\n",
" [-0.3730],\n",
" [-0.4571],\n",
" [-0.3859],\n",
" [-0.3053],\n",
" [-0.3744],\n",
" [-0.7439],\n",
" [-0.7338],\n",
" [-0.2856],\n",
" [-0.3440],\n",
" [-0.6041],\n",
" [-0.7940],\n",
" [-0.6112],\n",
" [-0.1943]]))"
]
},
"execution_count": 85,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 85
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T12:54:45.359640179Z",
"start_time": "2026-03-25T12:54:45.318414005Z"
}
},
"cell_type": "code",
"source": [
"n_train = 600  # number of leading samples used for training\n",
"batch_size = 16\n",
"# Only the first n_train (feature, label) pairs go into the training iterator.\n",
"train_iter = d2l.load_array((features[:n_train], labels[:n_train]),\n",
"                            batch_size, is_train=True)\n"
],
"id": "239a596b20d40dec",
"outputs": [],
"execution_count": 86
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T12:54:45.428803377Z",
"start_time": "2026-03-25T12:54:45.361568655Z"
}
},
"cell_type": "code",
"source": [
"def init_weights(m):\n",
"    \"\"\"Xavier-uniform initialise the weight of ``m`` if it is exactly an ``nn.Linear``.\n",
"\n",
"    Meant to be passed to ``nn.Module.apply``; other module types (and the bias)\n",
"    are left untouched.\n",
"    \"\"\"\n",
"    if type(m) is not nn.Linear:\n",
"        return\n",
"    nn.init.xavier_uniform_(m.weight)"
],
"id": "54d30bd0ee41cb8",
"outputs": [],
"execution_count": 87
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T12:54:45.522687266Z",
"start_time": "2026-03-25T12:54:45.432893670Z"
}
},
"cell_type": "code",
"source": [
"def get_net():\n",
"    \"\"\"Build a 4 -> 10 -> 1 MLP (ReLU hidden layer) with Xavier-initialised weights.\"\"\"\n",
"    layers = [nn.Linear(4, 10), nn.ReLU(), nn.Linear(10, 1)]\n",
"    net = nn.Sequential(*layers)\n",
"    net.apply(init_weights)\n",
"    return net\n",
"\n",
"\n",
"# Unreduced squared error: one loss value per element, reduced by the caller.\n",
"loss = nn.MSELoss(reduction='none')"
],
"id": "5d095792e3b3681",
"outputs": [],
"execution_count": 88
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T12:54:45.821846960Z",
"start_time": "2026-03-25T12:54:45.527750802Z"
}
},
"cell_type": "code",
"source": [
"def train(net, train_iter, loss, epochs, lr):\n",
"    \"\"\"Fit ``net`` with Adam and print the mean training loss after every epoch.\n",
"\n",
"    Args:\n",
"        net: model to optimise; its parameters are updated in place.\n",
"        train_iter: iterable of ``(X, y)`` mini-batches, re-iterated each epoch.\n",
"        loss: loss module returning unreduced per-element losses (this notebook\n",
"            uses ``nn.MSELoss(reduction='none')``); they are summed before backprop.\n",
"        epochs: number of passes over ``train_iter``.\n",
"        lr: Adam learning rate.\n",
"    \"\"\"\n",
"    trainer = torch.optim.Adam(net.parameters(), lr)\n",
"    for epoch in range(epochs):\n",
"        for X, y in train_iter:\n",
"            trainer.zero_grad()\n",
"            l = loss(net(X), y)\n",
"            l.sum().backward()\n",
"            trainer.step()\n",
"        # Evaluation needs no gradients. Running it under no_grad avoids building an\n",
"        # autograd graph over the whole training set and silences the\n",
"        # \"requires_grad=True to a scalar\" warning raised inside d2l.evaluate_loss.\n",
"        with torch.no_grad():\n",
"            train_loss = d2l.evaluate_loss(net, train_iter, loss)\n",
"        print(f'epoch {epoch + 1}, '\n",
"              f'loss: {train_loss:f}')\n",
"\n",
"\n",
"net = get_net()\n",
"train(net, train_iter, loss, 5, 0.01)"
],
"id": "5c1dba484e805335",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 1, loss: 0.069361\n",
"epoch 2, loss: 0.057280\n",
"epoch 3, loss: 0.054714\n",
"epoch 4, loss: 0.054167\n",
"epoch 5, loss: 0.050941\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/yukun/.conda/envs/nn/lib/python3.11/site-packages/d2l/torch.py:3179: UserWarning: Converting a tensor with requires_grad=True to a scalar may lead to unexpected behavior.\n",
"Consider using tensor.detach() first. (Triggered internally at /pytorch/torch/csrc/autograd/generated/python_variable_methods.cpp:836.)\n",
" self.data = [a + float(b) for a, b in zip(self.data, args)]\n"
]
}
],
"execution_count": 89
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T12:54:46.022554993Z",
"start_time": "2026-03-25T12:54:45.826680707Z"
}
},
"cell_type": "code",
"source": [
"# One-step-ahead predictions for every row of `features` (built in an earlier cell).\n",
"# This is pure inference, so skip autograd bookkeeping.\n",
"with torch.no_grad():\n",
"    onestep_preds = net(features)\n",
"d2l.plot([time, time[tau:]],\n",
"         [x.detach().numpy(), onestep_preds.detach().numpy()], 'time',\n",
"         'x', legend=['data', '1-step preds'], xlim=[1, 1000],\n",
"         figsize=(6, 3))"
],
"id": "a6efb4a978cd9375",
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 600x300 with 1 Axes>"
],
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"406.885938pt\" height=\"211.07625pt\" viewBox=\"0 0 406.885938 211.07625\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n <metadata>\n <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n <cc:Work>\n <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n <dc:date>2026-03-25T20:54:45.966103</dc:date>\n <dc:format>image/svg+xml</dc:format>\n <dc:creator>\n <cc:Agent>\n <dc:title>Matplotlib v3.7.2, https://matplotlib.org/</dc:title>\n </cc:Agent>\n </dc:creator>\n </cc:Work>\n </rdf:RDF>\n </metadata>\n <defs>\n <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n </defs>\n <g id=\"figure_1\">\n <g id=\"patch_1\">\n <path d=\"M 0 211.07625 \nL 406.885938 211.07625 \nL 406.885938 0 \nL 0 0 \nz\n\"/>\n </g>\n <g id=\"axes_1\">\n <g id=\"patch_2\">\n <path d=\"M 52.160938 173.52 \nL 386.960938 173.52 \nL 386.960938 7.2 \nL 52.160938 7.2 \nz\n\"/>\n </g>\n <g id=\"matplotlib.axis_1\">\n <g id=\"xtick_1\">\n <g id=\"line2d_1\">\n <path d=\"M 118.852829 173.52 \nL 118.852829 7.2 \n\" clip-path=\"url(#p37412f0ba3)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_2\">\n <defs>\n <path id=\"m7724595e6b\" d=\"M 0 0 \nL 0 3.5 \n\" style=\"stroke: #ffffff; stroke-width: 0.8\"/>\n </defs>\n <g>\n <use xlink:href=\"#m7724595e6b\" x=\"118.852829\" y=\"173.52\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_1\">\n <!-- 200 -->\n <g style=\"fill: #ffffff\" transform=\"translate(109.309079 188.118438) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-32\" d=\"M 1228 531 \nL 3431 531 
\nL 3431 0 \nL 469 0 \nL 469 531 \nQ 828 903 1448 1529 \nQ 2069 2156 2228 2338 \nQ 2531 2678 2651 2914 \nQ 2772 3150 2772 3378 \nQ 2772 3750 2511 3984 \nQ 2250 4219 1831 4219 \nQ 1534 4219 1204 4116 \nQ 875 4013 500 3803 \nL 500 4441 \nQ 881 4594 1212 4672 \nQ 1544 4750 1819 4750 \nQ 2544 4750 2975 4387 \nQ 3406 4025 3406 3419 \nQ 3406 3131 3298 2873 \nQ 3191 2616 2906 2266 \nQ 2828 2175 2409 1742 \nQ 1991 1309 1228 531 \nz\n\" transform=\"scale(0.015625)\"/>\n <path id=\"DejaVuSans-30\" d=\"M 2034 4250 \nQ 1547 4250 1301 3770 \nQ 1056 3291 1056 2328 \nQ 1056 1369 1301 889 \nQ 1547 409 2034 409 \nQ 2525 409 2770 889 \nQ 3016 1369 3016 2328 \nQ 3016 3291 2770 3770 \nQ 2525 4250 2034 4250 \nz\nM 2034 4750 \nQ 2819 4750 3233 4129 \nQ 3647 3509 3647 2328 \nQ 3647 1150 3233 529 \nQ 2819 -91 2034 -91 \nQ 1250 -91 836 529 \nQ 422 1150 422 2328 \nQ 422 3509 836 4129 \nQ 1250 4750 2034 4750 \nz\n\" transform=\"scale(0.015625)\"/>\n </defs>\n <use xlink:href=\"#DejaVuSans-32\"/>\n <use xlink:href=\"#DejaVuSans-30\" x=\"63.623047\"/>\n <use xlink:href=\"#DejaVuSans-30\" x=\"127.246094\"/>\n </g>\n </g>\n </g>\n <g id=\"xtick_2\">\n <g id=\"line2d_3\">\n <path d=\"M 185.879856 173.52 \nL 185.879856 7.2 \n\" clip-path=\"url(#p37412f0ba3)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_4\">\n <g>\n <use xlink:href=\"#m7724595e6b\" x=\"185.879856\" y=\"173.52\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_2\">\n <!-- 400 -->\n <g style=\"fill: #ffffff\" transform=\"translate(176.336106 188.118438) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-34\" d=\"M 2419 4116 \nL 825 1625 \nL 2419 1625 \nL 2419 4116 \nz\nM 2253 4666 \nL 3047 4666 \nL 3047 1625 \nL 3713 1625 \nL 3713 1100 \nL 3047 1100 \nL 3047 0 \nL 2419 0 \nL 2419 11
},
"metadata": {},
"output_type": "display_data",
"jetTransient": {
"display_id": null
}
}
],
"execution_count": 90
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T12:54:46.309387018Z",
"start_time": "2026-03-25T12:54:46.060491885Z"
}
},
"cell_type": "code",
"source": [
"# Multi-step prediction: values up to n_train + tau are copied from the data; each\n",
"# later value is predicted from the previous tau entries of multistep_preds, so the\n",
"# model increasingly feeds on its own outputs.\n",
"multistep_preds = torch.zeros(T)\n",
"multistep_preds[: n_train + tau] = x[: n_train + tau]\n",
"# Inference only: without no_grad every in-place write links the new prediction into\n",
"# one ever-growing autograd graph spanning all previous steps.\n",
"with torch.no_grad():\n",
"    for i in range(n_train + tau, T):\n",
"        multistep_preds[i] = net(\n",
"            multistep_preds[i - tau:i].reshape((1, -1)))\n",
"d2l.plot([time, time[tau:], time[n_train + tau:]],\n",
"         [x.detach().numpy(), onestep_preds.detach().numpy(),\n",
"          multistep_preds[n_train + tau:].detach().numpy()], 'time',\n",
"         'x', legend=['data', '1-step preds', 'multistep preds'],\n",
"         xlim=[1, 1000], figsize=(6, 3))"
],
"id": "12c3a1c3912da4dd",
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 600x300 with 1 Axes>"
],
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"406.885938pt\" height=\"211.07625pt\" viewBox=\"0 0 406.885938 211.07625\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n <metadata>\n <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n <cc:Work>\n <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n <dc:date>2026-03-25T20:54:46.220729</dc:date>\n <dc:format>image/svg+xml</dc:format>\n <dc:creator>\n <cc:Agent>\n <dc:title>Matplotlib v3.7.2, https://matplotlib.org/</dc:title>\n </cc:Agent>\n </dc:creator>\n </cc:Work>\n </rdf:RDF>\n </metadata>\n <defs>\n <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n </defs>\n <g id=\"figure_1\">\n <g id=\"patch_1\">\n <path d=\"M 0 211.07625 \nL 406.885938 211.07625 \nL 406.885938 0 \nL 0 0 \nz\n\"/>\n </g>\n <g id=\"axes_1\">\n <g id=\"patch_2\">\n <path d=\"M 52.160938 173.52 \nL 386.960938 173.52 \nL 386.960938 7.2 \nL 52.160938 7.2 \nz\n\"/>\n </g>\n <g id=\"matplotlib.axis_1\">\n <g id=\"xtick_1\">\n <g id=\"line2d_1\">\n <path d=\"M 118.852829 173.52 \nL 118.852829 7.2 \n\" clip-path=\"url(#pd9c266d709)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_2\">\n <defs>\n <path id=\"mbf15bb4269\" d=\"M 0 0 \nL 0 3.5 \n\" style=\"stroke: #ffffff; stroke-width: 0.8\"/>\n </defs>\n <g>\n <use xlink:href=\"#mbf15bb4269\" x=\"118.852829\" y=\"173.52\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_1\">\n <!-- 200 -->\n <g style=\"fill: #ffffff\" transform=\"translate(109.309079 188.118438) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-32\" d=\"M 1228 531 \nL 3431 531 
\nL 3431 0 \nL 469 0 \nL 469 531 \nQ 828 903 1448 1529 \nQ 2069 2156 2228 2338 \nQ 2531 2678 2651 2914 \nQ 2772 3150 2772 3378 \nQ 2772 3750 2511 3984 \nQ 2250 4219 1831 4219 \nQ 1534 4219 1204 4116 \nQ 875 4013 500 3803 \nL 500 4441 \nQ 881 4594 1212 4672 \nQ 1544 4750 1819 4750 \nQ 2544 4750 2975 4387 \nQ 3406 4025 3406 3419 \nQ 3406 3131 3298 2873 \nQ 3191 2616 2906 2266 \nQ 2828 2175 2409 1742 \nQ 1991 1309 1228 531 \nz\n\" transform=\"scale(0.015625)\"/>\n <path id=\"DejaVuSans-30\" d=\"M 2034 4250 \nQ 1547 4250 1301 3770 \nQ 1056 3291 1056 2328 \nQ 1056 1369 1301 889 \nQ 1547 409 2034 409 \nQ 2525 409 2770 889 \nQ 3016 1369 3016 2328 \nQ 3016 3291 2770 3770 \nQ 2525 4250 2034 4250 \nz\nM 2034 4750 \nQ 2819 4750 3233 4129 \nQ 3647 3509 3647 2328 \nQ 3647 1150 3233 529 \nQ 2819 -91 2034 -91 \nQ 1250 -91 836 529 \nQ 422 1150 422 2328 \nQ 422 3509 836 4129 \nQ 1250 4750 2034 4750 \nz\n\" transform=\"scale(0.015625)\"/>\n </defs>\n <use xlink:href=\"#DejaVuSans-32\"/>\n <use xlink:href=\"#DejaVuSans-30\" x=\"63.623047\"/>\n <use xlink:href=\"#DejaVuSans-30\" x=\"127.246094\"/>\n </g>\n </g>\n </g>\n <g id=\"xtick_2\">\n <g id=\"line2d_3\">\n <path d=\"M 185.879856 173.52 \nL 185.879856 7.2 \n\" clip-path=\"url(#pd9c266d709)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_4\">\n <g>\n <use xlink:href=\"#mbf15bb4269\" x=\"185.879856\" y=\"173.52\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_2\">\n <!-- 400 -->\n <g style=\"fill: #ffffff\" transform=\"translate(176.336106 188.118438) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-34\" d=\"M 2419 4116 \nL 825 1625 \nL 2419 1625 \nL 2419 4116 \nz\nM 2253 4666 \nL 3047 4666 \nL 3047 1625 \nL 3713 1625 \nL 3713 1100 \nL 3047 1100 \nL 3047 0 \nL 2419 0 \nL 2419 11
},
"metadata": {},
"output_type": "display_data",
"jetTransient": {
"display_id": null
}
}
],
"execution_count": 91
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T12:54:46.389774559Z",
"start_time": "2026-03-25T12:54:46.333788807Z"
}
},
"cell_type": "code",
"source": [
"import collections\n",
"import re"
],
"id": "aab66c10a4c143d2",
"outputs": [],
"execution_count": 92
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T12:54:46.511151557Z",
"start_time": "2026-03-25T12:54:46.396367087Z"
}
},
"cell_type": "code",
"source": [
"d2l.DATA_HUB['time_machine'] = (d2l.DATA_URL + 'timemachine.txt',\n",
"                                '090b5e7e70c295757f55df93cb0a180b9691891a')\n",
"\n",
"\n",
"def read_time_machine(): #@save\n",
"    \"\"\"Load the Time Machine dataset into a list of cleaned text lines.\n",
"\n",
"    Each run of characters other than A-Z/a-z becomes a single space, then the\n",
"    line is stripped and lower-cased.\n",
"    \"\"\"\n",
"    # Pin the encoding so the result does not depend on the platform's locale\n",
"    # default (e.g. cp1252 or GBK on Windows).\n",
"    with open(d2l.download('time_machine'), 'r', encoding='utf-8') as f:\n",
"        lines = f.readlines()\n",
"    return [re.sub('[^A-Za-z]+', ' ', line).strip().lower() for line in lines]\n",
"\n",
"\n",
"lines = read_time_machine()\n",
"print(f'# 文本总行数: {len(lines)}')\n",
"print(lines[0])\n",
"print(lines[10])"
],
"id": "1aff117af810525e",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"# 文本总行数: 3221\n",
"the time machine by h g wells\n",
"twinkled and his usually pale face was flushed and animated the\n"
]
}
],
"execution_count": 93
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T12:54:46.589883911Z",
"start_time": "2026-03-25T12:54:46.524677647Z"
}
},
"cell_type": "code",
"source": [
"def tokenize(lines, token='word'): #@save\n",
"    \"\"\"Split text lines into word or character tokens.\n",
"\n",
"    Args:\n",
"        lines: list of strings.\n",
"        token: ``'word'`` to split each line on whitespace, ``'char'`` to split it\n",
"            into individual characters.\n",
"\n",
"    Returns:\n",
"        A list containing one token list per input line.\n",
"\n",
"    Raises:\n",
"        ValueError: if ``token`` is neither ``'word'`` nor ``'char'``.\n",
"    \"\"\"\n",
"    if token == 'word':\n",
"        return [line.split() for line in lines]\n",
"    if token == 'char':\n",
"        return [list(line) for line in lines]\n",
"    # Fail loudly instead of printing and implicitly returning None, which would\n",
"    # only surface later as a confusing error in the caller.\n",
"    raise ValueError('错误:未知词元类型:' + token)\n",
"\n",
"\n",
"tokens = tokenize(lines)\n",
"for i in range(11):\n",
"    print(tokens[i])"
],
"id": "eb4fe9745fbaa5e2",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['the', 'time', 'machine', 'by', 'h', 'g', 'wells']\n",
"[]\n",
"[]\n",
"[]\n",
"[]\n",
"['i']\n",
"[]\n",
"[]\n",
"['the', 'time', 'traveller', 'for', 'so', 'it', 'will', 'be', 'convenient', 'to', 'speak', 'of', 'him']\n",
"['was', 'expounding', 'a', 'recondite', 'matter', 'to', 'us', 'his', 'grey', 'eyes', 'shone', 'and']\n",
"['twinkled', 'and', 'his', 'usually', 'pale', 'face', 'was', 'flushed', 'and', 'animated', 'the']\n"
]
}
],
"execution_count": 94
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T12:54:46.684454112Z",
"start_time": "2026-03-25T12:54:46.620352826Z"
}
},
"cell_type": "code",
"source": [
"def count_corpus(tokens): #@save\n",
"    \"\"\"Count how often each token occurs.\n",
"\n",
"    ``tokens`` is either a flat list of tokens or a list of token lists (one per\n",
"    line); the nested form is flattened before counting. Returns a Counter.\n",
"    \"\"\"\n",
"    if len(tokens) == 0 or isinstance(tokens[0], list):\n",
"        flat = []\n",
"        for line in tokens:\n",
"            flat.extend(line)\n",
"        tokens = flat\n",
"    return collections.Counter(tokens)\n",
"\n",
"\n",
"class Vocab: #@save\n",
"    \"\"\"Text vocabulary mapping tokens to integer indices and back.\n",
"\n",
"    Index 0 is the unknown token '<unk>', followed by ``reserved_tokens``, then\n",
"    corpus tokens seen at least ``min_freq`` times, most frequent first (ties keep\n",
"    first-seen order).\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, tokens=None, min_freq=0, reserved_tokens=None):\n",
"        tokens = [] if tokens is None else tokens\n",
"        reserved_tokens = [] if reserved_tokens is None else reserved_tokens\n",
"        # (token, count) pairs, most frequent first.\n",
"        counter = count_corpus(tokens)\n",
"        self._token_freqs = sorted(counter.items(), key=lambda item: item[1],\n",
"                                   reverse=True)\n",
"        self.idx_to_token = ['<unk>'] + reserved_tokens\n",
"        self.token_to_idx = {tok: idx for idx, tok in enumerate(self.idx_to_token)}\n",
"        for tok, freq in self._token_freqs:\n",
"            # Counts are sorted, so the first too-rare token ends the scan.\n",
"            if freq < min_freq:\n",
"                break\n",
"            if tok in self.token_to_idx:\n",
"                continue\n",
"            self.token_to_idx[tok] = len(self.idx_to_token)\n",
"            self.idx_to_token.append(tok)\n",
"\n",
"    def __len__(self):\n",
"        return len(self.idx_to_token)\n",
"\n",
"    def __getitem__(self, tokens):\n",
"        \"\"\"Token -> index (``unk`` if unseen); lists/tuples are mapped recursively.\"\"\"\n",
"        if isinstance(tokens, (list, tuple)):\n",
"            return [self[tok] for tok in tokens]\n",
"        return self.token_to_idx.get(tokens, self.unk)\n",
"\n",
"    def to_tokens(self, indices):\n",
"        \"\"\"Index -> token; a flat list/tuple of indices gives a list of tokens.\"\"\"\n",
"        if isinstance(indices, (list, tuple)):\n",
"            return [self.idx_to_token[idx] for idx in indices]\n",
"        return self.idx_to_token[indices]\n",
"\n",
"    @property\n",
"    def unk(self):\n",
"        \"\"\"Index of the unknown token, always 0.\"\"\"\n",
"        return 0\n",
"\n",
"    @property\n",
"    def token_freqs(self):\n",
"        \"\"\"(token, count) pairs sorted by descending count.\"\"\"\n",
"        return self._token_freqs"
],
"id": "bee8e5d7b798c6c",
"outputs": [],
"execution_count": 95
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T12:54:46.844584420Z",
"start_time": "2026-03-25T12:54:46.709821817Z"
}
},
"cell_type": "code",
"source": [
"# Build the word-level vocabulary and show its first ten (token, index) pairs.\n",
"vocab = Vocab(tokens)\n",
"print([*vocab.token_to_idx.items()][:10])"
],
"id": "ff4e8ac2044850b7",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[('<unk>', 0), ('the', 1), ('i', 2), ('and', 3), ('of', 4), ('a', 5), ('to', 6), ('was', 7), ('in', 8), ('that', 9)]\n"
]
}
],
"execution_count": 96
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T12:54:46.914064266Z",
"start_time": "2026-03-25T12:54:46.846539992Z"
}
},
"cell_type": "code",
"source": [
"# Encode two sample lines with the vocabulary.\n",
"for i in (0, 100):\n",
"    print('文本:', tokens[i])\n",
"    print('索引:', vocab[tokens[i]])"
],
"id": "a4e569dfbd251608",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"文本: ['the', 'time', 'machine', 'by', 'h', 'g', 'wells']\n",
"索引: [1, 19, 50, 40, 2183, 2184, 400]\n",
"文本: ['were', 'three', 'dimensional', 'representations', 'of', 'his', 'four', 'dimensioned']\n",
"索引: [20, 175, 1452, 2250, 4, 25, 262, 2251]\n"
]
}
],
"execution_count": 97
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T12:54:47.087077440Z",
"start_time": "2026-03-25T12:54:46.945943096Z"
}
},
"cell_type": "code",
"source": [
"def load_corpus_time_machine(max_tokens=-1):  #@save\n",
"    \"\"\"Return the Time Machine corpus as character indices, plus its vocabulary.\n",
"\n",
"    Args:\n",
"        max_tokens: when positive, keep only the first `max_tokens` indices.\n",
"\n",
"    Returns:\n",
"        (corpus, vocab): a flat list of token indices and the Vocab built\n",
"        from the character-level tokens.\n",
"    \"\"\"\n",
"    char_lines = tokenize(read_time_machine(), 'char')\n",
"    vocab = Vocab(char_lines)\n",
"    # A text line is not necessarily a sentence or paragraph, so all lines\n",
"    # are flattened into a single sequence.\n",
"    corpus = [vocab[ch] for row in char_lines for ch in row]\n",
"    if max_tokens > 0:\n",
"        corpus = corpus[:max_tokens]\n",
"    return corpus, vocab\n",
"\n",
"\n",
"corpus, vocab = load_corpus_time_machine()\n",
"len(corpus), len(vocab)\n"
],
"id": "1b5c1776ae47af5c",
"outputs": [
{
"data": {
"text/plain": [
"(170580, 28)"
]
},
"execution_count": 98,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 98
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T12:54:47.245364838Z",
"start_time": "2026-03-25T12:54:47.111205224Z"
}
},
"cell_type": "code",
"source": [
"# Word-level tokens; lines are not necessarily sentences or paragraphs,\n",
"# so all tokens are concatenated into one corpus.\n",
"tokens = d2l.tokenize(read_time_machine())\n",
"corpus = [tok for row in tokens for tok in row]\n",
"vocab = d2l.Vocab(corpus)\n",
"vocab.token_freqs[:10]"
],
"id": "99deb85c025e5cdd",
"outputs": [
{
"data": {
"text/plain": [
"[('the', 2261),\n",
" ('i', 1267),\n",
" ('and', 1245),\n",
" ('of', 1155),\n",
" ('a', 816),\n",
" ('to', 695),\n",
" ('was', 552),\n",
" ('in', 541),\n",
" ('that', 443),\n",
" ('my', 440)]"
]
},
"execution_count": 99,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 99
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T12:54:47.857887916Z",
"start_time": "2026-03-25T12:54:47.247469243Z"
}
},
"cell_type": "code",
"source": [
"# Token frequency against frequency rank, on log-log axes.\n",
"freqs = [freq for _, freq in vocab.token_freqs]\n",
"d2l.plot(freqs, xlabel='token: x', ylabel='frequency: n(x)',\n",
"         xscale='log', yscale='log')"
],
"id": "5c0d846673e16c33",
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 350x250 with 1 Axes>"
],
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"247.978125pt\" height=\"183.35625pt\" viewBox=\"0 0 247.978125 183.35625\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n <metadata>\n <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n <cc:Work>\n <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n <dc:date>2026-03-25T20:54:47.734963</dc:date>\n <dc:format>image/svg+xml</dc:format>\n <dc:creator>\n <cc:Agent>\n <dc:title>Matplotlib v3.7.2, https://matplotlib.org/</dc:title>\n </cc:Agent>\n </dc:creator>\n </cc:Work>\n </rdf:RDF>\n </metadata>\n <defs>\n <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n </defs>\n <g id=\"figure_1\">\n <g id=\"patch_1\">\n <path d=\"M 0 183.35625 \nL 247.978125 183.35625 \nL 247.978125 0 \nL 0 0 \nz\n\"/>\n </g>\n <g id=\"axes_1\">\n <g id=\"patch_2\">\n <path d=\"M 45.478125 145.8 \nL 240.778125 145.8 \nL 240.778125 7.2 \nL 45.478125 7.2 \nz\n\"/>\n </g>\n <g id=\"matplotlib.axis_1\">\n <g id=\"xtick_1\">\n <g id=\"line2d_1\">\n <path d=\"M 54.355398 145.8 \nL 54.355398 7.2 \n\" clip-path=\"url(#p34995c6764)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_2\">\n <defs>\n <path id=\"m5bbba6b94a\" d=\"M 0 0 \nL 0 3.5 \n\" style=\"stroke: #ffffff; stroke-width: 0.8\"/>\n </defs>\n <g>\n <use xlink:href=\"#m5bbba6b94a\" x=\"54.355398\" y=\"145.8\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_1\">\n <!-- $\\mathdefault{10^{0}}$ -->\n <g style=\"fill: #ffffff\" transform=\"translate(45.555398 160.398438) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-31\" d=\"M 794 531 \nL 
1825 531 \nL 1825 4091 \nL 703 3866 \nL 703 4441 \nL 1819 4666 \nL 2450 4666 \nL 2450 531 \nL 3481 531 \nL 3481 0 \nL 794 0 \nL 794 531 \nz\n\" transform=\"scale(0.015625)\"/>\n <path id=\"DejaVuSans-30\" d=\"M 2034 4250 \nQ 1547 4250 1301 3770 \nQ 1056 3291 1056 2328 \nQ 1056 1369 1301 889 \nQ 1547 409 2034 409 \nQ 2525 409 2770 889 \nQ 3016 1369 3016 2328 \nQ 3016 3291 2770 3770 \nQ 2525 4250 2034 4250 \nz\nM 2034 4750 \nQ 2819 4750 3233 4129 \nQ 3647 3509 3647 2328 \nQ 3647 1150 3233 529 \nQ 2819 -91 2034 -91 \nQ 1250 -91 836 529 \nQ 422 1150 422 2328 \nQ 422 3509 836 4129 \nQ 1250 4750 2034 4750 \nz\n\" transform=\"scale(0.015625)\"/>\n </defs>\n <use xlink:href=\"#DejaVuSans-31\" transform=\"translate(0 0.765625)\"/>\n <use xlink:href=\"#DejaVuSans-30\" transform=\"translate(63.623047 0.765625)\"/>\n <use xlink:href=\"#DejaVuSans-30\" transform=\"translate(128.203125 39.046875) scale(0.7)\"/>\n </g>\n </g>\n </g>\n <g id=\"xtick_2\">\n <g id=\"line2d_3\">\n <path d=\"M 102.85613 145.8 \nL 102.85613 7.2 \n\" clip-path=\"url(#p34995c6764)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_4\">\n <g>\n <use xlink:href=\"#m5bbba6b94a\" x=\"102.85613\" y=\"145.8\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_2\">\n <!-- $\\mathdefault{10^{1}}$ -->\n <g style=\"fill: #ffffff\" transform=\"translate(94.05613 160.398438) scale(0.1 -0.1)\">\n <use xlink:href=\"#DejaVuSans-31\" transform=\"translate(0 0.684375)\"/>\n <use xlink:href=\"#DejaVuSans-30\" transform=\"translate(63.623047 0.684375)\"/>\n <use xlink:href=\"#DejaVuSans-31\" transform=\"translate(128.203125 38.965625) scale(0.7)\"/>\n </g>\n </g>\n </g>\n <g id=\"xtick_3\">\n <g id=\"line2d_5\">\n <path d=\"M 151.356861 145.8 \nL 15
},
"metadata": {},
"output_type": "display_data",
"jetTransient": {
"display_id": null
}
}
],
"execution_count": 100
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T12:54:48.002182407Z",
"start_time": "2026-03-25T12:54:47.912592290Z"
}
},
"cell_type": "code",
"source": [
"# Bigrams: every pair of adjacent words in the corpus.\n",
"bigram_tokens = list(zip(corpus[:-1], corpus[1:]))\n",
"bigram_vocab = Vocab(bigram_tokens)\n",
"bigram_vocab.token_freqs[:10]"
],
"id": "2826e3ab0863ee64",
"outputs": [
{
"data": {
"text/plain": [
"[(('of', 'the'), 309),\n",
" (('in', 'the'), 169),\n",
" (('i', 'had'), 130),\n",
" (('i', 'was'), 112),\n",
" (('and', 'the'), 109),\n",
" (('the', 'time'), 102),\n",
" (('it', 'was'), 99),\n",
" (('to', 'the'), 85),\n",
" (('as', 'i'), 78),\n",
" (('of', 'a'), 73)]"
]
},
"execution_count": 101,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 101
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T12:54:48.095243055Z",
"start_time": "2026-03-25T12:54:48.005003927Z"
}
},
"cell_type": "code",
"source": [
"# Trigrams: every run of three consecutive words.\n",
"trigram_tokens = list(zip(corpus[:-2], corpus[1:-1], corpus[2:]))\n",
"trigram_vocab = Vocab(trigram_tokens)\n",
"trigram_vocab.token_freqs[:10]"
],
"id": "7c8cd0544bf872bb",
"outputs": [
{
"data": {
"text/plain": [
"[(('the', 'time', 'traveller'), 59),\n",
" (('the', 'time', 'machine'), 30),\n",
" (('the', 'medical', 'man'), 24),\n",
" (('it', 'seemed', 'to'), 16),\n",
" (('it', 'was', 'a'), 15),\n",
" (('here', 'and', 'there'), 15),\n",
" (('seemed', 'to', 'me'), 14),\n",
" (('i', 'did', 'not'), 14),\n",
" (('i', 'saw', 'the'), 13),\n",
" (('i', 'began', 'to'), 13)]"
]
},
"execution_count": 102,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 102
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T12:54:48.748052730Z",
"start_time": "2026-03-25T12:54:48.097673111Z"
}
},
"cell_type": "code",
"source": [
"# Compare unigram, bigram and trigram frequency curves on log-log axes.\n",
"bigram_freqs = [freq for _, freq in bigram_vocab.token_freqs]\n",
"trigram_freqs = [freq for _, freq in trigram_vocab.token_freqs]\n",
"d2l.plot([freqs, bigram_freqs, trigram_freqs], xlabel='token: x',\n",
"         ylabel='frequency: n(x)', xscale='log', yscale='log',\n",
"         legend=['unigram', 'bigram', 'trigram'])"
],
"id": "dc3d97dda738613d",
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 350x250 with 1 Axes>"
],
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"247.978125pt\" height=\"183.35625pt\" viewBox=\"0 0 247.978125 183.35625\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n <metadata>\n <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n <cc:Work>\n <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n <dc:date>2026-03-25T20:54:48.552018</dc:date>\n <dc:format>image/svg+xml</dc:format>\n <dc:creator>\n <cc:Agent>\n <dc:title>Matplotlib v3.7.2, https://matplotlib.org/</dc:title>\n </cc:Agent>\n </dc:creator>\n </cc:Work>\n </rdf:RDF>\n </metadata>\n <defs>\n <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n </defs>\n <g id=\"figure_1\">\n <g id=\"patch_1\">\n <path d=\"M 0 183.35625 \nL 247.978125 183.35625 \nL 247.978125 0 \nL 0 0 \nz\n\"/>\n </g>\n <g id=\"axes_1\">\n <g id=\"patch_2\">\n <path d=\"M 45.478125 145.8 \nL 240.778125 145.8 \nL 240.778125 7.2 \nL 45.478125 7.2 \nz\n\"/>\n </g>\n <g id=\"matplotlib.axis_1\">\n <g id=\"xtick_1\">\n <g id=\"line2d_1\">\n <path d=\"M 54.355398 145.8 \nL 54.355398 7.2 \n\" clip-path=\"url(#pe624ffde45)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_2\">\n <defs>\n <path id=\"m5244db54f0\" d=\"M 0 0 \nL 0 3.5 \n\" style=\"stroke: #ffffff; stroke-width: 0.8\"/>\n </defs>\n <g>\n <use xlink:href=\"#m5244db54f0\" x=\"54.355398\" y=\"145.8\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_1\">\n <!-- $\\mathdefault{10^{0}}$ -->\n <g style=\"fill: #ffffff\" transform=\"translate(45.555398 160.398438) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-31\" d=\"M 794 531 \nL 
1825 531 \nL 1825 4091 \nL 703 3866 \nL 703 4441 \nL 1819 4666 \nL 2450 4666 \nL 2450 531 \nL 3481 531 \nL 3481 0 \nL 794 0 \nL 794 531 \nz\n\" transform=\"scale(0.015625)\"/>\n <path id=\"DejaVuSans-30\" d=\"M 2034 4250 \nQ 1547 4250 1301 3770 \nQ 1056 3291 1056 2328 \nQ 1056 1369 1301 889 \nQ 1547 409 2034 409 \nQ 2525 409 2770 889 \nQ 3016 1369 3016 2328 \nQ 3016 3291 2770 3770 \nQ 2525 4250 2034 4250 \nz\nM 2034 4750 \nQ 2819 4750 3233 4129 \nQ 3647 3509 3647 2328 \nQ 3647 1150 3233 529 \nQ 2819 -91 2034 -91 \nQ 1250 -91 836 529 \nQ 422 1150 422 2328 \nQ 422 3509 836 4129 \nQ 1250 4750 2034 4750 \nz\n\" transform=\"scale(0.015625)\"/>\n </defs>\n <use xlink:href=\"#DejaVuSans-31\" transform=\"translate(0 0.765625)\"/>\n <use xlink:href=\"#DejaVuSans-30\" transform=\"translate(63.623047 0.765625)\"/>\n <use xlink:href=\"#DejaVuSans-30\" transform=\"translate(128.203125 39.046875) scale(0.7)\"/>\n </g>\n </g>\n </g>\n <g id=\"xtick_2\">\n <g id=\"line2d_3\">\n <path d=\"M 94.026857 145.8 \nL 94.026857 7.2 \n\" clip-path=\"url(#pe624ffde45)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_4\">\n <g>\n <use xlink:href=\"#m5244db54f0\" x=\"94.026857\" y=\"145.8\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_2\">\n <!-- $\\mathdefault{10^{1}}$ -->\n <g style=\"fill: #ffffff\" transform=\"translate(85.226857 160.398438) scale(0.1 -0.1)\">\n <use xlink:href=\"#DejaVuSans-31\" transform=\"translate(0 0.684375)\"/>\n <use xlink:href=\"#DejaVuSans-30\" transform=\"translate(63.623047 0.684375)\"/>\n <use xlink:href=\"#DejaVuSans-31\" transform=\"translate(128.203125 38.965625) scale(0.7)\"/>\n </g>\n </g>\n </g>\n <g id=\"xtick_3\">\n <g id=\"line2d_5\">\n <path d=\"M 133.698316 145.8 \nL 1
},
"metadata": {},
"output_type": "display_data",
"jetTransient": {
"display_id": null
}
}
],
"execution_count": 103
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T12:54:48.898844389Z",
"start_time": "2026-03-25T12:54:48.827838521Z"
}
},
"cell_type": "code",
"source": [
"import random\n",
"\n",
"\n",
"def seq_data_iter_random(corpus, batch_size, num_steps):  #@save\n",
"    \"\"\"Yield minibatches of subsequences using random sampling.\n",
"\n",
"    Each (X, Y) pair holds `batch_size` subsequences of length `num_steps`;\n",
"    Y is X shifted one position to the right. Subsequences in consecutive\n",
"    minibatches are not necessarily adjacent in the original sequence.\n",
"    \"\"\"\n",
"    # Drop a random prefix of 0 .. num_steps-1 tokens so partitions vary.\n",
"    corpus = corpus[random.randint(0, num_steps - 1):]\n",
"    # Minus 1: each subsequence needs one extra token for its label.\n",
"    num_subseqs = (len(corpus) - 1) // num_steps\n",
"    starts = list(range(0, num_subseqs * num_steps, num_steps))\n",
"    random.shuffle(starts)\n",
"\n",
"    def window(pos):\n",
"        # The length-num_steps subsequence beginning at `pos`.\n",
"        return corpus[pos: pos + num_steps]\n",
"\n",
"    num_batches = num_subseqs // batch_size\n",
"    for b in range(num_batches):\n",
"        batch_starts = starts[b * batch_size: (b + 1) * batch_size]\n",
"        X = [window(p) for p in batch_starts]\n",
"        Y = [window(p + 1) for p in batch_starts]\n",
"        yield torch.tensor(X), torch.tensor(Y)"
],
"id": "fd015793938b83ab",
"outputs": [],
"execution_count": 104
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T12:54:49.053830957Z",
"start_time": "2026-03-25T12:54:48.902512374Z"
}
},
"cell_type": "code",
"source": [
"# Toy sequence 0..34 to show how random sampling partitions it.\n",
"my_seq = [*range(35)]\n",
"for X, Y in seq_data_iter_random(my_seq, batch_size=2, num_steps=5):\n",
"    print('X: ', X, '\\nY:', Y)"
],
"id": "8961f4934b8c7cc",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"X: tensor([[27, 28, 29, 30, 31],\n",
" [ 2, 3, 4, 5, 6]]) \n",
"Y: tensor([[28, 29, 30, 31, 32],\n",
" [ 3, 4, 5, 6, 7]])\n",
"X: tensor([[17, 18, 19, 20, 21],\n",
" [22, 23, 24, 25, 26]]) \n",
"Y: tensor([[18, 19, 20, 21, 22],\n",
" [23, 24, 25, 26, 27]])\n",
"X: tensor([[12, 13, 14, 15, 16],\n",
" [ 7, 8, 9, 10, 11]]) \n",
"Y: tensor([[13, 14, 15, 16, 17],\n",
" [ 8, 9, 10, 11, 12]])\n"
]
}
],
"execution_count": 105
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T12:54:49.085900056Z",
"start_time": "2026-03-25T12:54:49.065678273Z"
}
},
"cell_type": "code",
"source": [
"def seq_data_iter_sequential(corpus, batch_size, num_steps):  #@save\n",
"    \"\"\"Yield minibatches of subsequences using sequential partitioning.\n",
"\n",
"    The trimmed corpus is laid out as `batch_size` rows; successive\n",
"    minibatches take successive `num_steps`-wide column windows, so row k\n",
"    of one minibatch continues where row k of the previous one stopped.\n",
"    \"\"\"\n",
"    # Random starting offset in [0, num_steps] (randint is inclusive).\n",
"    offset = random.randint(0, num_steps)\n",
"    # Largest token count (leaving one for the labels) divisible by batch_size.\n",
"    num_tokens = (len(corpus) - offset - 1) // batch_size * batch_size\n",
"    Xs = torch.tensor(corpus[offset: offset + num_tokens]).reshape(batch_size, -1)\n",
"    Ys = torch.tensor(corpus[offset + 1: offset + 1 + num_tokens]).reshape(batch_size, -1)\n",
"    num_batches = Xs.shape[1] // num_steps\n",
"    for b in range(num_batches):\n",
"        lo = b * num_steps\n",
"        yield Xs[:, lo: lo + num_steps], Ys[:, lo: lo + num_steps]"
],
"id": "621b66c0614b22da",
"outputs": [],
"execution_count": 106
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T13:13:22.060317841Z",
"start_time": "2026-03-25T13:13:22.007573330Z"
}
},
"cell_type": "code",
"source": [
"class SeqDataLoader:\n",
"    \"\"\"Iterable over Time Machine minibatches (random or sequential sampling).\"\"\"\n",
"\n",
"    def __init__(self, batch_size, num_steps, use_random_iter, max_tokens):\n",
"        self.data_iter_fn = (seq_data_iter_random if use_random_iter\n",
"                             else seq_data_iter_sequential)\n",
"        self.corpus, self.vocab = load_corpus_time_machine(max_tokens)\n",
"        self.batch_size, self.num_steps = batch_size, num_steps\n",
"\n",
"    def __iter__(self):\n",
"        # A fresh generator each time, so every epoch draws a new random offset.\n",
"        return self.data_iter_fn(self.corpus, self.batch_size, self.num_steps)\n",
"\n",
"\n",
"def load_data_time_machine(batch_size, num_steps,  #@save\n",
"                           use_random_iter=False, max_tokens=10000):\n",
"    \"\"\"Return the Time Machine data iterator together with its vocabulary.\"\"\"\n",
"    data_iter = SeqDataLoader(batch_size, num_steps, use_random_iter, max_tokens)\n",
"    return data_iter, data_iter.vocab"
],
"id": "f09fe2507a925fe9",
"outputs": [],
"execution_count": 111
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T13:13:25.998507190Z",
"start_time": "2026-03-25T13:13:25.902583940Z"
}
},
"cell_type": "code",
"source": [
"batch_size = 32\n",
"num_steps = 35\n",
"train_iter, vocab = load_data_time_machine(batch_size, num_steps)"
],
"id": "69272a664d3b9ae1",
"outputs": [],
"execution_count": 112
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T13:14:32.469873183Z",
"start_time": "2026-03-25T13:14:32.366446697Z"
}
},
"cell_type": "code",
"source": "F.one_hot(torch.tensor([0, 2]), num_classes=len(vocab))",
"id": "35806d36e5ec3ca7",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0],\n",
" [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0]])"
]
},
"execution_count": 115,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 115
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T13:16:48.087534183Z",
"start_time": "2026-03-25T13:16:47.919685613Z"
}
},
"cell_type": "code",
"source": [
"X = torch.arange(10).reshape(2, 5)\n",
"F.one_hot(X.T, num_classes=28).shape"
],
"id": "6a4695284b898013",
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([5, 2, 28])"
]
},
"execution_count": 119,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 119
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T13:37:28.354220059Z",
"start_time": "2026-03-25T13:37:28.262974202Z"
}
},
"cell_type": "code",
"source": [
"def get_params(vocab_size, num_hiddens, device):\n",
"    \"\"\"Create the trainable parameters of the from-scratch RNN.\n",
"\n",
"    Returns [W_xh, W_hh, b_h, W_hq, b_q], each with requires_grad=True:\n",
"      W_xh (vocab_size, num_hiddens): input x_t -> hidden H_t\n",
"      W_hh (num_hiddens, num_hiddens): hidden H_{t-1} -> hidden H_t\n",
"      b_h  (num_hiddens,)\n",
"      W_hq (num_hiddens, vocab_size): hidden H_t -> output\n",
"      b_q  (vocab_size,)\n",
"    \"\"\"\n",
"    num_inputs = num_outputs = vocab_size\n",
"\n",
"    def normal(shape):\n",
"        # Small Gaussian initialisation (standard deviation 0.01).\n",
"        return torch.randn(size=shape, device=device) * 0.01\n",
"\n",
"    # Hidden layer (creation order fixes the order of RNG draws).\n",
"    W_xh = normal((num_inputs, num_hiddens))\n",
"    W_hh = normal((num_hiddens, num_hiddens))\n",
"    b_h = torch.zeros(num_hiddens, device=device)\n",
"    # Output layer.\n",
"    W_hq = normal((num_hiddens, num_outputs))\n",
"    b_q = torch.zeros(num_outputs, device=device)\n",
"    params = [W_xh, W_hh, b_h, W_hq, b_q]\n",
"    for p in params:\n",
"        p.requires_grad_(True)\n",
"    return params\n",
"\n",
"\n",
"def init_rnn_state(batch_size, num_hiddens, device):\n",
"    \"\"\"Initial hidden state: a 1-tuple holding a zero (batch_size, num_hiddens) tensor.\"\"\"\n",
"    return (torch.zeros((batch_size, num_hiddens), device=device),)\n",
"\n",
"\n",
"def rnn(inputs, state, params):\n",
"    \"\"\"Run a tanh RNN over `inputs`, one time step per leading-axis slice.\n",
"\n",
"    As called from RNNModelScratch, `inputs` is (num_steps, batch_size,\n",
"    vocab_size). Returns the step outputs concatenated along dim 0, i.e.\n",
"    (num_steps * batch_size, vocab_size), and the final state as a 1-tuple.\n",
"    \"\"\"\n",
"    W_xh, W_hh, b_h, W_hq, b_q = params\n",
"    (H,) = state\n",
"    step_outputs = []\n",
"    for X in inputs:  # X: (batch_size, vocab_size)\n",
"        H = torch.tanh(torch.mm(X, W_xh) + torch.mm(H, W_hh) + b_h)\n",
"        step_outputs.append(torch.mm(H, W_hq) + b_q)\n",
"    return torch.cat(step_outputs, dim=0), (H,)\n",
"\n",
"\n",
"class RNNModelScratch:  #@save\n",
"    \"\"\"Recurrent neural network implemented from scratch.\"\"\"\n",
"\n",
"    def __init__(self, vocab_size, num_hiddens, device,\n",
"                 get_params, init_state, forward_fn):\n",
"        self.vocab_size, self.num_hiddens = vocab_size, num_hiddens\n",
"        self.params = get_params(vocab_size, num_hiddens, device)\n",
"        self.init_state, self.forward_fn = init_state, forward_fn\n",
"\n",
"    def __call__(self, X, state):\n",
"        # (batch_size, num_steps) indices -> (num_steps, batch_size, vocab_size)\n",
"        # one-hot floats, so iterating over the result walks the time axis.\n",
"        one_hot = F.one_hot(X.T, self.vocab_size).type(torch.float32)\n",
"        return self.forward_fn(one_hot, state, self.params)\n",
"\n",
"    def begin_state(self, batch_size, device):\n",
"        return self.init_state(batch_size, self.num_hiddens, device)"
],
"id": "405f7c6af8bdd939",
"outputs": [],
"execution_count": 120
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T13:50:48.023098387Z",
"start_time": "2026-03-25T13:50:47.712386114Z"
}
},
"cell_type": "code",
"source": [
"num_hiddens = 512\n",
"net = RNNModelScratch(len(vocab), num_hiddens, d2l.try_gpu(),\n",
"                      get_params=get_params, init_state=init_rnn_state,\n",
"                      forward_fn=rnn)\n",
"state = net.begin_state(batch_size=X.shape[0], device=d2l.try_gpu())\n",
"Y, new_state = net(X.to(d2l.try_gpu()), state)\n",
"Y.shape, len(new_state), new_state[0].shape"
],
"id": "7bdd0b37b1458ec2",
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([10, 28]), 1, torch.Size([2, 512]))"
]
},
"execution_count": 123,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 123
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T14:06:25.042155295Z",
"start_time": "2026-03-25T14:06:24.593900625Z"
}
},
"cell_type": "code",
"source": [
"def predict_ch8(prefix, num_preds, net, vocab, device):  #@save\n",
"    \"\"\"Generate `num_preds` characters after `prefix`; return prefix + generation.\"\"\"\n",
"    state = net.begin_state(batch_size=1, device=device)\n",
"    outputs = [vocab[prefix[0]]]\n",
"\n",
"    def last_token():\n",
"        # Most recent token index as a (1, 1) input batch.\n",
"        return torch.tensor([outputs[-1]], device=device).reshape((1, 1))\n",
"\n",
"    # Warm-up: run the remaining prefix through the model to build the\n",
"    # state, recording the true prefix tokens rather than predictions.\n",
"    for ch in prefix[1:]:\n",
"        _, state = net(last_token(), state)\n",
"        outputs.append(vocab[ch])\n",
"    # Generate num_preds tokens, greedily taking the highest-scoring one.\n",
"    for _ in range(num_preds):\n",
"        logits, state = net(last_token(), state)\n",
"        outputs.append(int(logits.argmax(dim=1).reshape(1)))\n",
"    return ''.join(vocab.idx_to_token[i] for i in outputs)\n",
"\n",
"\n",
"predict_ch8('time traveller ', 10, net, vocab, d2l.try_gpu())"
],
"id": "f38dbe2b11e02bc2",
"outputs": [
{
"data": {
"text/plain": [
"'time traveller slgm sl sl'"
]
},
"execution_count": 124,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 124
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T14:14:43.788897281Z",
"start_time": "2026-03-25T14:14:43.657537187Z"
}
},
"cell_type": "code",
"source": [
"def grad_clipping(net, theta):  #@save\n",
"    \"\"\"Clip gradients in place so their global L2 norm is at most `theta`.\n",
"\n",
"    Works with nn.Module models and with from-scratch models exposing a\n",
"    `params` list. Parameters whose `.grad` is None (they took no part in\n",
"    the last backward pass) are skipped; previously they raised a\n",
"    TypeError, as did a model with no gradients at all.\n",
"    \"\"\"\n",
"    if isinstance(net, nn.Module):\n",
"        params = [p for p in net.parameters() if p.requires_grad]\n",
"    else:\n",
"        params = net.params\n",
"    grads = [p.grad for p in params if p.grad is not None]\n",
"    if not grads:\n",
"        return\n",
"    norm = torch.sqrt(sum(torch.sum(g ** 2) for g in grads))\n",
"    if norm > theta:\n",
"        for g in grads:\n",
"            g[:] *= theta / norm"
],
"id": "6c19717736ffbc68",
"outputs": [],
"execution_count": 125
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T14:21:41.754553299Z",
"start_time": "2026-03-25T14:21:41.701446937Z"
}
},
"cell_type": "code",
"source": [
"import math\n",
"\n",
"\n",
"def train_epoch_ch8(net, train_iter, loss, updater, device, use_random_iter):\n",
"    \"\"\"Train `net` for one epoch (defined in Chapter 8).\n",
"\n",
"    Returns:\n",
"        (perplexity, tokens per second) for this epoch.\n",
"\n",
"    Raises:\n",
"        ValueError: if `train_iter` yields no tokens, so perplexity is undefined.\n",
"    \"\"\"\n",
"    state, timer = None, d2l.Timer()\n",
"    metric = d2l.Accumulator(2)  # sum of training loss, number of tokens\n",
"    for X, Y in train_iter:\n",
"        if state is None or use_random_iter:\n",
"            # Fresh state on the first minibatch, and on every minibatch when\n",
"            # sampling randomly (consecutive minibatches are not adjacent).\n",
"            state = net.begin_state(batch_size=X.shape[0], device=device)\n",
"        else:\n",
"            # Sequential partitioning: keep the state's values but cut the\n",
"            # autograd graph at the minibatch boundary.\n",
"            if isinstance(net, nn.Module) and not isinstance(state, tuple):\n",
"                # For nn.GRU / nn.RNN the state is a single tensor.\n",
"                state.detach_()\n",
"            else:\n",
"                # For nn.LSTM and the from-scratch models the state is a\n",
"                # tuple of tensors.\n",
"                for s in state:\n",
"                    s.detach_()\n",
"        # Time-major labels, matching the (num_steps * batch_size) row order of y_hat.\n",
"        y = Y.T.reshape(-1)\n",
"        X, y = X.to(device), y.to(device)\n",
"        y_hat, state = net(X, state)\n",
"        l = loss(y_hat, y.long()).mean()\n",
"        if isinstance(updater, torch.optim.Optimizer):\n",
"            updater.zero_grad()\n",
"            l.backward()\n",
"            grad_clipping(net, 1)\n",
"            updater.step()\n",
"        else:\n",
"            l.backward()\n",
"            grad_clipping(net, 1)\n",
"            # batch_size=1 because `l` has already been averaged by mean().\n",
"            updater(batch_size=1)\n",
"        # Accumulate a plain float so no autograd node is created per batch.\n",
"        metric.add(l.item() * y.numel(), y.numel())\n",
"    if metric[1] == 0:\n",
"        raise ValueError('train_iter yielded no tokens; perplexity is undefined')\n",
"    return math.exp(metric[0] / metric[1]), metric[1] / timer.stop()\n"
],
"id": "37ef611dedcb714b",
"outputs": [],
"execution_count": 132
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T15:00:52.669519274Z",
"start_time": "2026-03-25T15:00:52.618407065Z"
}
},
"cell_type": "code",
"source": [
"def train_ch8(net, train_iter, vocab, lr, num_epochs, device,\n",
"              use_random_iter=False):\n",
"    \"\"\"Train a model (defined in Chapter 8).\n",
"\n",
"    Every 10 epochs the perplexity is added to a live plot and a 50-character\n",
"    continuation of 'time traveller' is printed; at the end the final\n",
"    perplexity, throughput and two sample generations are printed.\n",
"    \"\"\"\n",
"    loss = nn.CrossEntropyLoss()\n",
"    animator = d2l.Animator(xlabel='epoch', ylabel='perplexity',\n",
"                            legend=['train'], xlim=[10, num_epochs])\n",
"    # torch.optim.SGD for nn.Module models, d2l.sgd on the raw params otherwise.\n",
"    if isinstance(net, nn.Module):\n",
"        updater = torch.optim.SGD(net.parameters(), lr)\n",
"    else:\n",
"        def updater(batch_size):\n",
"            return d2l.sgd(net.params, lr, batch_size)\n",
"\n",
"    def predict(prefix):\n",
"        return predict_ch8(prefix, 50, net, vocab, device)\n",
"\n",
"    for epoch in range(1, num_epochs + 1):\n",
"        ppl, speed = train_epoch_ch8(\n",
"            net, train_iter, loss, updater, device, use_random_iter)\n",
"        if epoch % 10 == 0:\n",
"            print(predict('time traveller'))\n",
"            animator.add(epoch, [ppl])\n",
"    print(f'困惑度 {ppl:.1f}, {speed:.1f} 词元/秒 {str(device)}')\n",
"    print(predict('time traveller'))\n",
"    print(predict('traveller'))"
],
"id": "c96b60b55664378a",
"outputs": [],
"execution_count": 146
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T15:02:40.728313868Z",
"start_time": "2026-03-25T15:00:53.106582989Z"
}
},
"cell_type": "code",
"source": [
"num_epochs = 500\n",
"lr = 1\n",
"train_ch8(net, train_iter, vocab, lr=lr, num_epochs=num_epochs,\n",
"          device=d2l.try_gpu())"
],
"id": "ab4a2fbf4dfd21ef",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"困惑度 1.3, 128113.6 词元/秒 cpu\n",
"time travelleris thene by in psmed the k waile to se pas of sour\n",
"traveller tore asmethe which we canle wey thard abthof spar\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 350x250 with 1 Axes>"
],
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"246.284375pt\" height=\"183.35625pt\" viewBox=\"0 0 246.284375 183.35625\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n <metadata>\n <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n <cc:Work>\n <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n <dc:date>2026-03-25T23:02:40.690662</dc:date>\n <dc:format>image/svg+xml</dc:format>\n <dc:creator>\n <cc:Agent>\n <dc:title>Matplotlib v3.7.2, https://matplotlib.org/</dc:title>\n </cc:Agent>\n </dc:creator>\n </cc:Work>\n </rdf:RDF>\n </metadata>\n <defs>\n <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n </defs>\n <g id=\"figure_1\">\n <g id=\"patch_1\">\n <path d=\"M 0 183.35625 \nL 246.284375 183.35625 \nL 246.284375 0 \nL 0 0 \nz\n\"/>\n </g>\n <g id=\"axes_1\">\n <g id=\"patch_2\">\n <path d=\"M 34.240625 145.8 \nL 229.540625 145.8 \nL 229.540625 7.2 \nL 34.240625 7.2 \nz\n\"/>\n </g>\n <g id=\"matplotlib.axis_1\">\n <g id=\"xtick_1\">\n <g id=\"line2d_1\">\n <path d=\"M 70.112054 145.8 \nL 70.112054 7.2 \n\" clip-path=\"url(#p1a9dfa12b4)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_2\">\n <defs>\n <path id=\"mb00c30088f\" d=\"M 0 0 \nL 0 3.5 \n\" style=\"stroke: #ffffff; stroke-width: 0.8\"/>\n </defs>\n <g>\n <use xlink:href=\"#mb00c30088f\" x=\"70.112054\" y=\"145.8\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_1\">\n <!-- 100 -->\n <g style=\"fill: #ffffff\" transform=\"translate(60.568304 160.398438) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-31\" d=\"M 794 531 \nL 1825 531 \nL 1825 
4091 \nL 703 3866 \nL 703 4441 \nL 1819 4666 \nL 2450 4666 \nL 2450 531 \nL 3481 531 \nL 3481 0 \nL 794 0 \nL 794 531 \nz\n\" transform=\"scale(0.015625)\"/>\n <path id=\"DejaVuSans-30\" d=\"M 2034 4250 \nQ 1547 4250 1301 3770 \nQ 1056 3291 1056 2328 \nQ 1056 1369 1301 889 \nQ 1547 409 2034 409 \nQ 2525 409 2770 889 \nQ 3016 1369 3016 2328 \nQ 3016 3291 2770 3770 \nQ 2525 4250 2034 4250 \nz\nM 2034 4750 \nQ 2819 4750 3233 4129 \nQ 3647 3509 3647 2328 \nQ 3647 1150 3233 529 \nQ 2819 -91 2034 -91 \nQ 1250 -91 836 529 \nQ 422 1150 422 2328 \nQ 422 3509 836 4129 \nQ 1250 4750 2034 4750 \nz\n\" transform=\"scale(0.015625)\"/>\n </defs>\n <use xlink:href=\"#DejaVuSans-31\"/>\n <use xlink:href=\"#DejaVuSans-30\" x=\"63.623047\"/>\n <use xlink:href=\"#DejaVuSans-30\" x=\"127.246094\"/>\n </g>\n </g>\n </g>\n <g id=\"xtick_2\">\n <g id=\"line2d_3\">\n <path d=\"M 109.969196 145.8 \nL 109.969196 7.2 \n\" clip-path=\"url(#p1a9dfa12b4)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_4\">\n <g>\n <use xlink:href=\"#mb00c30088f\" x=\"109.969196\" y=\"145.8\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_2\">\n <!-- 200 -->\n <g style=\"fill: #ffffff\" transform=\"translate(100.425446 160.398438) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-32\" d=\"M 1228 531 \nL 3431 531 \nL 3431 0 \nL 469 0 \nL 469 531 \nQ 828 903 1448 1529 \nQ 2069 2156 2228 2338 \nQ 2531 2678 2651 2914 \nQ 2772 3150 2772 3378 \nQ 2772 3750 2511 3984 \nQ 2250 4219 1831 4219 \nQ 1534 4219 1204 4116 \nQ 875 4013 500 3803 \nL 500 4441 \nQ 881 4594 1212 4672 \nQ 1544 4750 1819 4750 \nQ 2544 4750 2975 4387 \nQ 3406 4025 3406 3419 \nQ 3406 3131 3298 2873 \nQ 3191 2616 2906 2266 \nQ 2828 2175 2409 1742 \nQ 1991 1309 1228 531 \nz\n\" transform=\"scale(0.01562
},
"metadata": {},
"output_type": "display_data",
"jetTransient": {
"display_id": null
}
}
],
"execution_count": 147
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T15:02:40.837383063Z",
"start_time": "2026-03-25T15:02:40.788068670Z"
}
},
"cell_type": "code",
"source": [
"batch_size = 32\n",
"num_steps = 35\n",
"train_iter, vocab = load_data_time_machine(batch_size, num_steps)"
],
"id": "74d672745751714",
"outputs": [],
"execution_count": 148
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T15:02:40.907873399Z",
"start_time": "2026-03-25T15:02:40.840147997Z"
}
},
"cell_type": "code",
"source": [
"num_hiddens = 256\n",
"rnn_layer = nn.RNN(input_size=len(vocab), hidden_size=num_hiddens)\n",
"# Hidden state: (num_layers, batch_size, num_hiddens).\n",
"state = torch.zeros((1, batch_size, num_hiddens))\n",
"X = torch.rand(size=(num_steps, batch_size, len(vocab)))\n",
"Y, state_new = rnn_layer(X, state)\n",
"Y.shape, state_new.shape"
],
"id": "9694c029c9e657e8",
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([35, 32, 256]), torch.Size([1, 32, 256]))"
]
},
"execution_count": 149,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 149
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T15:02:40.960240439Z",
"start_time": "2026-03-25T15:02:40.909528103Z"
}
},
"cell_type": "code",
"source": [
"class RNNModel(nn.Module):\n",
" \"\"\"循环神经网络模型\"\"\"\n",
" def __init__(self, rnn_layer, vocab_size, **kwargs):\n",
" super(RNNModel, self).__init__(**kwargs)\n",
" self.rnn = rnn_layer\n",
" self.vocab_size = vocab_size\n",
" self.num_hiddens = self.rnn.hidden_size\n",
" # 如果RNN是双向的(之后将介绍),num_directions应该是2,否则应该是1\n",
" if not self.rnn.bidirectional:\n",
" self.num_directions = 1\n",
" self.linear = nn.Linear(self.num_hiddens, self.vocab_size)\n",
" else:\n",
" self.num_directions = 2\n",
" self.linear = nn.Linear(self.num_hiddens * 2, self.vocab_size)\n",
" def forward(self, inputs, state):\n",
" X = F.one_hot(inputs.T.long(), self.vocab_size)\n",
" X = X.to(torch.float32)\n",
" Y, state = self.rnn(X, state)\n",
" # 全连接层首先将Y的形状改为(时间步数*批量大小,隐藏单元数)\n",
" # 它的输出形状是(时间步数*批量大小,词表大小)。\n",
" output = self.linear(Y.reshape((-1, Y.shape[-1])))\n",
" return output, state\n",
" def begin_state(self, device, batch_size=1):\n",
" if not isinstance(self.rnn, nn.LSTM):\n",
" # nn.GRU以张量作为隐状态\n",
" return torch.zeros((self.num_directions * self.rnn.num_layers,\n",
" batch_size, self.num_hiddens),\n",
" device=device)\n",
" else:\n",
" # nn.LSTM以元组作为隐状态\n",
" return (torch.zeros((\n",
" self.num_directions * self.rnn.num_layers,\n",
" batch_size, self.num_hiddens), device=device),\n",
" torch.zeros((\n",
" self.num_directions * self.rnn.num_layers,\n",
" batch_size, self.num_hiddens), device=device))"
],
"id": "858873034be01538",
"outputs": [],
"execution_count": 150
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T15:02:41.082450468Z",
"start_time": "2026-03-25T15:02:40.962991493Z"
}
},
"cell_type": "code",
"source": [
"device = d2l.try_gpu()\n",
"net = RNNModel(rnn_layer, vocab_size=len(vocab))\n",
"net = net.to(device)\n",
"predict_ch8('time traveller', 10, net, vocab, device)"
],
"id": "d59c1599998c8fd4",
"outputs": [
{
"data": {
"text/plain": [
"'time travellerunuuuuuuuu'"
]
},
"execution_count": 151,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 151
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-25T15:04:22.185320871Z",
"start_time": "2026-03-25T15:02:41.084179495Z"
}
},
"cell_type": "code",
"source": [
"num_epochs, lr = 500, 1\n",
"train_ch8(net, train_iter, vocab, lr, num_epochs, device)"
],
"id": "460f80bcf15ffd50",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"困惑度 1.3, 146307.4 词元/秒 cpu\n",
"time travellerit would be revery erance for any hemptanef re has\n",
"travellery uplagstoot somethacongacout in anly fale tard ap\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 350x250 with 1 Axes>"
],
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"252.646875pt\" height=\"183.35625pt\" viewBox=\"0 0 252.646875 183.35625\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n <metadata>\n <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n <cc:Work>\n <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n <dc:date>2026-03-25T23:04:22.133483</dc:date>\n <dc:format>image/svg+xml</dc:format>\n <dc:creator>\n <cc:Agent>\n <dc:title>Matplotlib v3.7.2, https://matplotlib.org/</dc:title>\n </cc:Agent>\n </dc:creator>\n </cc:Work>\n </rdf:RDF>\n </metadata>\n <defs>\n <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n </defs>\n <g id=\"figure_1\">\n <g id=\"patch_1\">\n <path d=\"M 0 183.35625 \nL 252.646875 183.35625 \nL 252.646875 0 \nL 0 0 \nz\n\"/>\n </g>\n <g id=\"axes_1\">\n <g id=\"patch_2\">\n <path d=\"M 40.603125 145.8 \nL 235.903125 145.8 \nL 235.903125 7.2 \nL 40.603125 7.2 \nz\n\"/>\n </g>\n <g id=\"matplotlib.axis_1\">\n <g id=\"xtick_1\">\n <g id=\"line2d_1\">\n <path d=\"M 76.474554 145.8 \nL 76.474554 7.2 \n\" clip-path=\"url(#p204033faea)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_2\">\n <defs>\n <path id=\"m277b52c217\" d=\"M 0 0 \nL 0 3.5 \n\" style=\"stroke: #ffffff; stroke-width: 0.8\"/>\n </defs>\n <g>\n <use xlink:href=\"#m277b52c217\" x=\"76.474554\" y=\"145.8\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_1\">\n <!-- 100 -->\n <g style=\"fill: #ffffff\" transform=\"translate(66.930804 160.398438) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-31\" d=\"M 794 531 \nL 1825 531 \nL 1825 
4091 \nL 703 3866 \nL 703 4441 \nL 1819 4666 \nL 2450 4666 \nL 2450 531 \nL 3481 531 \nL 3481 0 \nL 794 0 \nL 794 531 \nz\n\" transform=\"scale(0.015625)\"/>\n <path id=\"DejaVuSans-30\" d=\"M 2034 4250 \nQ 1547 4250 1301 3770 \nQ 1056 3291 1056 2328 \nQ 1056 1369 1301 889 \nQ 1547 409 2034 409 \nQ 2525 409 2770 889 \nQ 3016 1369 3016 2328 \nQ 3016 3291 2770 3770 \nQ 2525 4250 2034 4250 \nz\nM 2034 4750 \nQ 2819 4750 3233 4129 \nQ 3647 3509 3647 2328 \nQ 3647 1150 3233 529 \nQ 2819 -91 2034 -91 \nQ 1250 -91 836 529 \nQ 422 1150 422 2328 \nQ 422 3509 836 4129 \nQ 1250 4750 2034 4750 \nz\n\" transform=\"scale(0.015625)\"/>\n </defs>\n <use xlink:href=\"#DejaVuSans-31\"/>\n <use xlink:href=\"#DejaVuSans-30\" x=\"63.623047\"/>\n <use xlink:href=\"#DejaVuSans-30\" x=\"127.246094\"/>\n </g>\n </g>\n </g>\n <g id=\"xtick_2\">\n <g id=\"line2d_3\">\n <path d=\"M 116.331696 145.8 \nL 116.331696 7.2 \n\" clip-path=\"url(#p204033faea)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_4\">\n <g>\n <use xlink:href=\"#m277b52c217\" x=\"116.331696\" y=\"145.8\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_2\">\n <!-- 200 -->\n <g style=\"fill: #ffffff\" transform=\"translate(106.787946 160.398438) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-32\" d=\"M 1228 531 \nL 3431 531 \nL 3431 0 \nL 469 0 \nL 469 531 \nQ 828 903 1448 1529 \nQ 2069 2156 2228 2338 \nQ 2531 2678 2651 2914 \nQ 2772 3150 2772 3378 \nQ 2772 3750 2511 3984 \nQ 2250 4219 1831 4219 \nQ 1534 4219 1204 4116 \nQ 875 4013 500 3803 \nL 500 4441 \nQ 881 4594 1212 4672 \nQ 1544 4750 1819 4750 \nQ 2544 4750 2975 4387 \nQ 3406 4025 3406 3419 \nQ 3406 3131 3298 2873 \nQ 3191 2616 2906 2266 \nQ 2828 2175 2409 1742 \nQ 1991 1309 1228 531 \nz\n\" transform=\"scale(0.01562
},
"metadata": {},
"output_type": "display_data",
"jetTransient": {
"display_id": null
}
}
],
"execution_count": 152
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"source": "",
2026-03-25 15:07:28 +00:00
"id": "adda23bc3664ec6b"
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}