2026-03-22 08:28:55 +00:00
{
"cells": [
{
"cell_type": "code",
"id": "initial_id",
"metadata": {
"collapsed": true,
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:02.177207285Z",
"start_time": "2026-04-22T07:02:59.204677901Z"
2026-03-22 08:28:55 +00:00
}
},
"source": [
2026-03-25 15:07:28 +00:00
"\n",
2026-03-22 08:28:55 +00:00
"import torch\n",
"import d2l\n",
"import numpy\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F"
],
"outputs": [],
2026-04-22 07:23:35 +00:00
"execution_count": 2
2026-03-22 08:28:55 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:02.454043741Z",
"start_time": "2026-04-22T07:03:02.230904947Z"
2026-03-22 08:28:55 +00:00
}
},
"cell_type": "code",
"source": [
"net = nn.Sequential(nn.Linear(20, 256), nn.ReLU(), nn.Linear(256, 10))\n",
"X = torch.rand(2, 20)\n",
"net(X)"
],
"id": "dcd5590e7795eec1",
"outputs": [
{
"data": {
"text/plain": [
2026-04-22 07:23:35 +00:00
"tensor([[ 0.0041, -0.3465, -0.2096, 0.2304, -0.1043, 0.0066, 0.1817, 0.0355,\n",
" 0.2685, -0.0461],\n",
" [-0.0932, -0.1621, -0.1244, 0.2398, -0.0759, 0.0680, 0.1511, 0.0224,\n",
" 0.2522, -0.0228]], grad_fn=<AddmmBackward0>)"
2026-03-22 08:28:55 +00:00
]
},
2026-04-22 07:23:35 +00:00
"execution_count": 3,
2026-03-22 08:28:55 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 3
2026-03-22 08:28:55 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:04.497911379Z",
"start_time": "2026-04-22T07:03:03.603349572Z"
2026-03-22 08:28:55 +00:00
}
},
"cell_type": "code",
"source": [
"class MLP(nn.Module):\n",
"    \"\"\"Two-layer perceptron: 20 inputs -> 256 hidden units (ReLU) -> 10 outputs.\"\"\"\n",
"\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        # Submodules are registered under the names `hidden` and `out`.\n",
"        self.hidden = nn.Linear(20, 256)\n",
"        self.out = nn.Linear(256, 10)\n",
"\n",
"    def forward(self, X):\n",
"        \"\"\"Return `out(relu(hidden(X)))`.\"\"\"\n",
"        activated = F.relu(self.hidden(X))\n",
"        return self.out(activated)\n"
],
"id": "4ae330604b643cb4",
"outputs": [],
2026-04-22 07:23:35 +00:00
"execution_count": 4
2026-03-22 08:28:55 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:05.186050578Z",
"start_time": "2026-04-22T07:03:04.689781242Z"
2026-03-22 08:28:55 +00:00
}
},
"cell_type": "code",
"source": [
"net=MLP()\n",
"net(X)"
],
"id": "cca55c6c0c7da12f",
"outputs": [
{
"data": {
"text/plain": [
2026-04-22 07:23:35 +00:00
"tensor([[-0.2165, 0.1394, 0.0867, 0.0692, 0.2914, -0.1427, 0.2218, -0.0533,\n",
" -0.2137, 0.0044],\n",
" [-0.2020, 0.0648, 0.0514, 0.0500, 0.2555, -0.1679, 0.1621, -0.1462,\n",
" -0.2527, 0.0386]], grad_fn=<AddmmBackward0>)"
2026-03-22 08:28:55 +00:00
]
},
2026-04-22 07:23:35 +00:00
"execution_count": 5,
2026-03-22 08:28:55 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 5
2026-03-22 08:28:55 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:05.664056879Z",
"start_time": "2026-04-22T07:03:05.331438994Z"
2026-03-22 08:28:55 +00:00
}
},
"cell_type": "code",
"source": [
"class FixedHiddenMLP(nn.Module):\n",
"    \"\"\"Block mixing a shared linear layer, a constant random matrix and Python control flow.\"\"\"\n",
"\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        # Random weights created with requires_grad=False: no gradient is\n",
"        # computed for them, so they stay constant during training.\n",
"        self.rand_weight = torch.rand((20, 20), requires_grad=False)\n",
"        self.linear = nn.Linear(20, 20)\n",
"\n",
"    def forward(self, X):\n",
"        \"\"\"Return a scalar: the sum of the transformed, repeatedly halved activations.\"\"\"\n",
"        projected = self.linear(X)\n",
"        # Combine with the constant matrix via mm, add 1, then apply relu.\n",
"        mixed = F.relu(torch.mm(projected, self.rand_weight) + 1)\n",
"        # Reuse the same fully connected layer: both applications share parameters.\n",
"        out = self.linear(mixed)\n",
"        # Control flow: halve in place until the sum of absolute values is at most 1.\n",
"        while out.abs().sum() > 1:\n",
"            out /= 2\n",
"        return out.sum()"
],
"id": "4518d62611d5e749",
"outputs": [],
2026-04-22 07:23:35 +00:00
"execution_count": 6
2026-03-22 08:28:55 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:06.104535508Z",
"start_time": "2026-04-22T07:03:05.824555955Z"
2026-03-22 08:28:55 +00:00
}
},
"cell_type": "code",
"source": [
"net = FixedHiddenMLP()\n",
"net(X)"
],
"id": "fae0187ece4ed5c6",
"outputs": [
{
"data": {
"text/plain": [
2026-04-22 07:23:35 +00:00
"tensor(-0.0023, grad_fn=<SumBackward0>)"
2026-03-22 08:28:55 +00:00
]
},
2026-04-22 07:23:35 +00:00
"execution_count": 7,
2026-03-22 08:28:55 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 7
2026-03-22 08:28:55 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:06.273938290Z",
"start_time": "2026-04-22T07:03:06.117091179Z"
2026-03-22 08:28:55 +00:00
}
},
"cell_type": "code",
"source": [
"class NestMLP(nn.Module):\n",
"    \"\"\"Block that nests an `nn.Sequential` (20 -> 64 -> 32, ReLU after each) before a 32 -> 16 linear layer.\"\"\"\n",
"\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.net = nn.Sequential(nn.Linear(20, 64), nn.ReLU(),\n",
"                                 nn.Linear(64, 32), nn.ReLU())\n",
"        self.linear = nn.Linear(32, 16)\n",
"\n",
"    def forward(self, X):\n",
"        \"\"\"Run `X` through the nested stack, then the final linear layer.\"\"\"\n",
"        return self.linear(self.net(X))\n",
"\n",
"\n",
"# These statements must sit at cell level, outside the class: placed after the\n",
"# `return` in `forward` they would be unreachable and the cell would show nothing.\n",
"chimera = nn.Sequential(NestMLP(), nn.Linear(16, 20), FixedHiddenMLP())\n",
"chimera(X)"
],
"id": "407ef13a86453aae",
"outputs": [],
2026-04-22 07:23:35 +00:00
"execution_count": 8
2026-03-22 08:28:55 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:06.462449517Z",
"start_time": "2026-04-22T07:03:06.323939028Z"
2026-03-22 08:28:55 +00:00
}
},
"cell_type": "code",
"source": [
"net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))\n",
"X = torch.rand(size=(2, 4))\n",
"net(X)"
],
"id": "9f3526f263c7a249",
"outputs": [
{
"data": {
"text/plain": [
2026-04-22 07:23:35 +00:00
"tensor([[-0.1265],\n",
" [-0.0471]], grad_fn=<AddmmBackward0>)"
2026-03-22 08:28:55 +00:00
]
},
2026-04-22 07:23:35 +00:00
"execution_count": 9,
2026-03-22 08:28:55 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 9
2026-03-22 08:28:55 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:06.843325610Z",
"start_time": "2026-04-22T07:03:06.539581889Z"
2026-03-22 08:28:55 +00:00
}
},
"cell_type": "code",
"source": "print(net[2].state_dict())",
"id": "8c73f8daa02ba28b",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2026-04-22 07:23:35 +00:00
"OrderedDict([('weight', tensor([[ 0.0136, -0.1015, 0.1191, 0.2722, 0.3456, -0.0650, -0.0437, -0.2806]])), ('bias', tensor([-0.0945]))])\n"
2026-03-22 08:28:55 +00:00
]
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 10
2026-03-22 08:28:55 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:07.189943309Z",
"start_time": "2026-04-22T07:03:06.962295444Z"
2026-03-22 08:28:55 +00:00
}
},
"cell_type": "code",
"source": "net[2].state_dict()",
"id": "b6fee6b64fb96e3c",
"outputs": [
{
"data": {
"text/plain": [
"OrderedDict([('weight',\n",
2026-04-22 07:23:35 +00:00
" tensor([[ 0.0136, -0.1015, 0.1191, 0.2722, 0.3456, -0.0650, -0.0437, -0.2806]])),\n",
" ('bias', tensor([-0.0945]))])"
2026-03-22 08:28:55 +00:00
]
},
2026-04-22 07:23:35 +00:00
"execution_count": 11,
2026-03-22 08:28:55 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 11
2026-03-22 08:28:55 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:07.395792068Z",
"start_time": "2026-04-22T07:03:07.243437434Z"
2026-03-22 08:28:55 +00:00
}
},
"cell_type": "code",
"source": "print(type(net[2].bias))",
"id": "b38e8dc384e038c5",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'torch.nn.parameter.Parameter'>\n"
]
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 12
2026-03-22 08:28:55 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:07.629769183Z",
"start_time": "2026-04-22T07:03:07.457413574Z"
2026-03-22 08:28:55 +00:00
}
},
"cell_type": "code",
"source": [
"print(net[2].bias)\n",
"print(net[2].bias.data)\n"
],
"id": "73f12ca3669d9ede",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Parameter containing:\n",
2026-04-22 07:23:35 +00:00
"tensor([-0.0945], requires_grad=True)\n",
"tensor([-0.0945])\n"
2026-03-22 08:28:55 +00:00
]
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 13
2026-03-22 08:28:55 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:07.873696310Z",
"start_time": "2026-04-22T07:03:07.679040535Z"
2026-03-22 08:28:55 +00:00
}
},
"cell_type": "code",
"source": "net[2].weight.grad is None",
"id": "db0fe33018c16fac",
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
2026-04-22 07:23:35 +00:00
"execution_count": 14,
2026-03-22 08:28:55 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 14
2026-03-22 08:28:55 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:07.984798139Z",
"start_time": "2026-04-22T07:03:07.896070141Z"
2026-03-22 08:28:55 +00:00
}
},
"cell_type": "code",
"source": [
"print(*[(name, param.shape) for name, param in net[0].named_parameters()])\n",
"print(*[(name, param.shape) for name, param in net.named_parameters()])"
],
"id": "75847a1c608ee5c7",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"('weight', torch.Size([8, 4])) ('bias', torch.Size([8]))\n",
"('0.weight', torch.Size([8, 4])) ('0.bias', torch.Size([8])) ('2.weight', torch.Size([1, 8])) ('2.bias', torch.Size([1]))\n"
]
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 15
2026-03-22 08:28:55 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:08.010084298Z",
"start_time": "2026-04-22T07:03:07.991112964Z"
2026-03-22 08:28:55 +00:00
}
},
"cell_type": "code",
"source": "net.state_dict()['2.bias'].data",
"id": "cc74913e8742da7d",
"outputs": [
{
"data": {
"text/plain": [
2026-04-22 07:23:35 +00:00
"tensor([-0.0945])"
2026-03-22 08:28:55 +00:00
]
},
2026-04-22 07:23:35 +00:00
"execution_count": 16,
2026-03-22 08:28:55 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 16
2026-03-22 08:28:55 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:08.088926956Z",
"start_time": "2026-04-22T07:03:08.042795645Z"
2026-03-22 08:28:55 +00:00
}
},
"cell_type": "code",
"source": [
"def block1():\n",
"    \"\"\"Build one Linear(4, 8) -> ReLU -> Linear(8, 4) -> ReLU stack.\"\"\"\n",
"    layers = [nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 4), nn.ReLU()]\n",
"    return nn.Sequential(*layers)\n",
"\n",
"\n",
"def block2():\n",
"    \"\"\"Chain four separate `block1` stacks, registered as block0 .. block3.\"\"\"\n",
"    stacked = nn.Sequential()\n",
"    for idx in range(4):\n",
"        stacked.add_module(f'block{idx}', block1())\n",
"    return stacked"
],
"id": "53c39c5e61fa7bf5",
"outputs": [],
2026-04-22 07:23:35 +00:00
"execution_count": 17
2026-03-22 08:28:55 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:08.330228511Z",
"start_time": "2026-04-22T07:03:08.096053767Z"
2026-03-22 08:28:55 +00:00
}
},
"cell_type": "code",
"source": [
"rgnet = nn.Sequential(block2(),nn.Linear(4,1))\n",
"rgnet(X)"
],
"id": "d3ac7759b619aca",
"outputs": [
{
"data": {
"text/plain": [
2026-04-22 07:23:35 +00:00
"tensor([[0.0117],\n",
" [0.0117]], grad_fn=<AddmmBackward0>)"
2026-03-22 08:28:55 +00:00
]
},
2026-04-22 07:23:35 +00:00
"execution_count": 18,
2026-03-22 08:28:55 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 18
2026-03-22 08:28:55 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:08.645186191Z",
"start_time": "2026-04-22T07:03:08.455908607Z"
2026-03-22 08:28:55 +00:00
}
},
"cell_type": "code",
"source": "print(rgnet)",
"id": "8fc60f64b07781e6",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sequential(\n",
" (0): Sequential(\n",
" (block0): Sequential(\n",
" (0): Linear(in_features=4, out_features=8, bias=True)\n",
" (1): ReLU()\n",
" (2): Linear(in_features=8, out_features=4, bias=True)\n",
" (3): ReLU()\n",
" )\n",
" (block1): Sequential(\n",
" (0): Linear(in_features=4, out_features=8, bias=True)\n",
" (1): ReLU()\n",
" (2): Linear(in_features=8, out_features=4, bias=True)\n",
" (3): ReLU()\n",
" )\n",
" (block2): Sequential(\n",
" (0): Linear(in_features=4, out_features=8, bias=True)\n",
" (1): ReLU()\n",
" (2): Linear(in_features=8, out_features=4, bias=True)\n",
" (3): ReLU()\n",
" )\n",
" (block3): Sequential(\n",
" (0): Linear(in_features=4, out_features=8, bias=True)\n",
" (1): ReLU()\n",
" (2): Linear(in_features=8, out_features=4, bias=True)\n",
" (3): ReLU()\n",
" )\n",
" )\n",
" (1): Linear(in_features=4, out_features=1, bias=True)\n",
")\n"
]
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 19
2026-03-22 08:28:55 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:08.903696348Z",
"start_time": "2026-04-22T07:03:08.733628048Z"
2026-03-22 08:28:55 +00:00
}
},
"cell_type": "code",
"source": "rgnet[0][1][0].bias.data",
"id": "e590aaafca787b50",
"outputs": [
{
"data": {
"text/plain": [
2026-04-22 07:23:35 +00:00
"tensor([ 0.2396, -0.2293, -0.3365, 0.0070, -0.0166, -0.2328, -0.1627, 0.3407])"
2026-03-22 08:28:55 +00:00
]
},
2026-04-22 07:23:35 +00:00
"execution_count": 20,
2026-03-22 08:28:55 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 20
2026-03-22 08:28:55 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:09.181470689Z",
"start_time": "2026-04-22T07:03:08.920938456Z"
2026-03-22 08:28:55 +00:00
}
},
"cell_type": "code",
"source": [
"def init_normal(m):\n",
"    \"\"\"Initialise an `nn.Linear` in place: weight ~ N(0, 0.01**2), bias = 0.\n",
"\n",
"    Written for use with `Module.apply`; modules whose type is not exactly\n",
"    `nn.Linear` are left untouched.\n",
"    \"\"\"\n",
"    if type(m) is not nn.Linear:\n",
"        return\n",
"    nn.init.normal_(m.weight, mean=0, std=0.01)\n",
"    # A layer built with bias=False has `bias` set to None; nothing to zero then.\n",
"    if m.bias is not None:\n",
"        nn.init.zeros_(m.bias)\n",
"net.apply(init_normal)\n",
"net[0].weight.data[0], net[0].bias.data[0]"
],
"id": "925ca33221d0a87e",
"outputs": [
{
"data": {
"text/plain": [
2026-04-22 07:23:35 +00:00
"(tensor([ 0.0166, 0.0092, 0.0013, -0.0031]), tensor(0.))"
2026-03-22 08:28:55 +00:00
]
},
2026-04-22 07:23:35 +00:00
"execution_count": 21,
2026-03-22 08:28:55 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 21
2026-03-22 08:28:55 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:09.303120289Z",
"start_time": "2026-04-22T07:03:09.184866866Z"
2026-03-22 08:28:55 +00:00
}
},
"cell_type": "code",
"source": [
"def init_xavier(m):\n",
"    \"\"\"Xavier-uniform initialise the weight of a module whose type is exactly `nn.Linear`.\"\"\"\n",
"    if type(m) is not nn.Linear:\n",
"        return\n",
"    nn.init.xavier_uniform_(m.weight)\n",
"\n",
"\n",
"def init_42(m):\n",
"    \"\"\"Fill the weight of a module whose type is exactly `nn.Linear` with the constant 42.\"\"\"\n",
"    if type(m) is not nn.Linear:\n",
"        return\n",
"    nn.init.constant_(m.weight, 42)\n",
"\n",
"net[0].apply(init_xavier)\n",
"net[2].apply(init_42)\n",
"print(net[0].weight.data[0])\n",
"print(net[2].weight.data)"
],
"id": "81e2de84a8c4ef32",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2026-04-22 07:23:35 +00:00
"tensor([-0.2085, 0.4344, -0.3960, 0.5868])\n",
2026-03-22 08:28:55 +00:00
"tensor([[42., 42., 42., 42., 42., 42., 42., 42.]])\n"
]
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 22
2026-03-22 08:28:55 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:09.411878374Z",
"start_time": "2026-04-22T07:03:09.355106030Z"
2026-03-22 08:28:55 +00:00
}
},
"cell_type": "code",
"source": [
"x = torch.arange(4)\n",
"torch.save(x, 'x-file')"
],
"id": "f05bb378bb60ab9e",
"outputs": [],
2026-04-22 07:23:35 +00:00
"execution_count": 23
2026-03-22 08:28:55 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:09.509836133Z",
"start_time": "2026-04-22T07:03:09.427360581Z"
2026-03-22 08:28:55 +00:00
}
},
"cell_type": "code",
"source": [
"x2 = torch.load('x-file')\n",
"x2"
],
"id": "a74ecaaac0d826c6",
"outputs": [
{
"data": {
"text/plain": [
"tensor([0, 1, 2, 3])"
]
},
2026-04-22 07:23:35 +00:00
"execution_count": 24,
2026-03-22 08:28:55 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 24
2026-03-22 08:28:55 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:09.542056671Z",
"start_time": "2026-04-22T07:03:09.518568625Z"
2026-03-22 08:28:55 +00:00
}
},
"cell_type": "code",
"source": [
"class MLP(nn.Module):\n",
"    \"\"\"Perceptron with one hidden layer (20 -> 256 -> 10).\n",
"\n",
"    The layers are stored as `hidden` and `output`; those names become the\n",
"    key prefixes in `state_dict()`.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.hidden = nn.Linear(20, 256)\n",
"        self.output = nn.Linear(256, 10)\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Apply the hidden layer, ReLU, then the output layer.\"\"\"\n",
"        hidden_act = F.relu(self.hidden(x))\n",
"        return self.output(hidden_act)\n",
"\n",
"net = MLP()\n",
"X = torch.randn(size=(2, 20))\n",
"Y = net(X)"
],
"id": "b42598f0c4a8e801",
"outputs": [],
2026-04-22 07:23:35 +00:00
"execution_count": 25
2026-03-22 08:28:55 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:09.603222999Z",
"start_time": "2026-04-22T07:03:09.548610614Z"
2026-03-22 08:28:55 +00:00
}
},
"cell_type": "code",
"source": "torch.save(net.state_dict(), 'mlp.params')",
"id": "aaa22eef549caa6f",
"outputs": [],
2026-04-22 07:23:35 +00:00
"execution_count": 26
2026-03-22 08:28:55 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:09.699407155Z",
"start_time": "2026-04-22T07:03:09.607082306Z"
2026-03-22 08:28:55 +00:00
}
},
"cell_type": "code",
"source": [
"clone = MLP()\n",
"clone.load_state_dict(torch.load('mlp.params'))\n",
"clone.eval()"
],
"id": "b92f920229abeeae",
"outputs": [
{
"data": {
"text/plain": [
"MLP(\n",
" (hidden): Linear(in_features=20, out_features=256, bias=True)\n",
" (output): Linear(in_features=256, out_features=10, bias=True)\n",
")"
]
},
2026-04-22 07:23:35 +00:00
"execution_count": 27,
2026-03-22 08:28:55 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 27
2026-03-22 08:28:55 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:09.854186935Z",
"start_time": "2026-04-22T07:03:09.721875531Z"
2026-03-22 08:28:55 +00:00
}
},
"cell_type": "code",
"source": [
"Y_clone = clone(X)\n",
"Y_clone == Y"
],
"id": "646c9eb6d7cc81c2",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[True, True, True, True, True, True, True, True, True, True],\n",
" [True, True, True, True, True, True, True, True, True, True]])"
]
},
2026-04-22 07:23:35 +00:00
"execution_count": 28,
2026-03-22 08:28:55 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 28
2026-03-22 08:28:55 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:10.143072800Z",
"start_time": "2026-04-22T07:03:09.938713854Z"
2026-03-22 08:28:55 +00:00
}
},
"cell_type": "code",
"source": [
"def corr2d(X, K):\n",
"    \"\"\"Compute the 2-D cross-correlation of matrix `X` with kernel `K`.\n",
"\n",
"    The result has shape (X.shape[0] - h + 1, X.shape[1] - w + 1), where\n",
"    (h, w) is the kernel shape (no padding, stride 1).\n",
"    \"\"\"\n",
"    k_rows, k_cols = K.shape\n",
"    out_rows = X.shape[0] - k_rows + 1\n",
"    out_cols = X.shape[1] - k_cols + 1\n",
"    # Every entry is overwritten below, so the initial fill value is irrelevant.\n",
"    Y = torch.ones((out_rows, out_cols))\n",
"    for r in range(out_rows):\n",
"        for c in range(out_cols):\n",
"            window = X[r:r + k_rows, c:c + k_cols]\n",
"            Y[r, c] = (window * K).sum()\n",
"    return Y\n"
],
"id": "d45f9adfe47fce20",
"outputs": [],
2026-04-22 07:23:35 +00:00
"execution_count": 29
2026-03-22 08:28:55 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:10.449191223Z",
"start_time": "2026-04-22T07:03:10.209878470Z"
2026-03-22 08:28:55 +00:00
}
},
"cell_type": "code",
"source": [
"X = torch.tensor([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]])\n",
"K = torch.tensor([[0.0, 1.0], [2.0, 3.0]])\n",
"corr2d(X,K)"
],
"id": "db7279e13647c315",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[19., 25.],\n",
" [37., 43.]])"
]
},
2026-04-22 07:23:35 +00:00
"execution_count": 30,
2026-03-22 08:28:55 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 30
2026-03-22 08:28:55 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:10.723922087Z",
"start_time": "2026-04-22T07:03:10.558883210Z"
2026-03-22 08:28:55 +00:00
}
},
"cell_type": "code",
"source": [
"class Conv2D(nn.Module):\n",
"    \"\"\"Convolution layer built on `corr2d`, with a learnable kernel and a one-element bias.\"\"\"\n",
"\n",
"    def __init__(self, kernel_size):\n",
"        super().__init__()\n",
"        # Kernel entries start uniform in [0, 1); the bias starts at zero.\n",
"        self.weight = nn.Parameter(torch.rand(kernel_size))\n",
"        self.bias = nn.Parameter(torch.zeros(1))\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Cross-correlate `x` with the kernel, then add the bias.\"\"\"\n",
"        correlated = corr2d(x, self.weight)\n",
"        return correlated + self.bias\n"
],
"id": "d60be1bd12a1f37e",
"outputs": [],
2026-04-22 07:23:35 +00:00
"execution_count": 31
2026-03-22 08:28:55 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:11.208027335Z",
"start_time": "2026-04-22T07:03:10.944803814Z"
2026-03-22 08:28:55 +00:00
}
},
"cell_type": "code",
"source": [
"X = torch.ones((6, 8))\n",
"X[:, 2:6] = 0\n",
"X"
],
"id": "5083789b7a728442",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[1., 1., 0., 0., 0., 0., 1., 1.],\n",
" [1., 1., 0., 0., 0., 0., 1., 1.],\n",
" [1., 1., 0., 0., 0., 0., 1., 1.],\n",
" [1., 1., 0., 0., 0., 0., 1., 1.],\n",
" [1., 1., 0., 0., 0., 0., 1., 1.],\n",
" [1., 1., 0., 0., 0., 0., 1., 1.]])"
]
},
2026-04-22 07:23:35 +00:00
"execution_count": 32,
2026-03-22 08:28:55 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 32
2026-03-22 08:28:55 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:11.547516752Z",
"start_time": "2026-04-22T07:03:11.280664423Z"
2026-03-22 08:28:55 +00:00
}
},
"cell_type": "code",
"source": [
"K = torch.tensor([[1.0, -1.0]])\n",
"Y = corr2d(X, K)\n",
"Y"
],
"id": "ee8d6bedbde886ad",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 0., 1., 0., 0., 0., -1., 0.],\n",
" [ 0., 1., 0., 0., 0., -1., 0.],\n",
" [ 0., 1., 0., 0., 0., -1., 0.],\n",
" [ 0., 1., 0., 0., 0., -1., 0.],\n",
" [ 0., 1., 0., 0., 0., -1., 0.],\n",
" [ 0., 1., 0., 0., 0., -1., 0.]])"
]
},
2026-04-22 07:23:35 +00:00
"execution_count": 33,
2026-03-22 08:28:55 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 33
2026-03-22 08:28:55 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:11.992663567Z",
"start_time": "2026-04-22T07:03:11.712431704Z"
2026-03-22 08:28:55 +00:00
}
},
"cell_type": "code",
"source": "corr2d(X.t(), K)",
"id": "a8278c3837fa9a1c",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0.]])"
]
},
2026-04-22 07:23:35 +00:00
"execution_count": 34,
2026-03-22 08:28:55 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 34
2026-03-22 08:28:55 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:12.544730677Z",
"start_time": "2026-04-22T07:03:12.187262859Z"
2026-03-22 08:28:55 +00:00
}
},
"cell_type": "code",
"source": "conv2d = nn.Conv2d(1,1, kernel_size=(1, 2), bias=False)",
"id": "ec61cdb61a8cabff",
"outputs": [],
2026-04-22 07:23:35 +00:00
"execution_count": 35
2026-03-22 08:28:55 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:12.918278160Z",
"start_time": "2026-04-22T07:03:12.663915511Z"
2026-03-22 08:28:55 +00:00
}
},
"cell_type": "code",
"source": [
"X = X.reshape((1, 1, 6, 8))\n",
"Y = Y.reshape((1, 1, 6, 7))\n",
"lr = 3e-2"
],
"id": "d2fc19d84c79a10",
"outputs": [],
2026-04-22 07:23:35 +00:00
"execution_count": 36
2026-03-22 08:28:55 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:14.229826949Z",
"start_time": "2026-04-22T07:03:12.942822259Z"
2026-03-22 08:28:55 +00:00
}
},
"cell_type": "code",
"source": [
"for i in range(100):\n",
"    Y_hat = conv2d(X)\n",
"    l = (Y_hat - Y) ** 2  # element-wise squared error\n",
"    total_loss = l.sum()\n",
"    conv2d.zero_grad()\n",
"    total_loss.backward()\n",
"    # Gradient-descent step on the kernel, performed outside autograd tracking.\n",
"    with torch.no_grad():\n",
"        conv2d.weight.sub_(lr * conv2d.weight.grad)\n",
"    if (i + 1) % 20 == 0:\n",
"        print(f'epoch {i+1}, loss {total_loss:.3f}')"
],
"id": "51fbb2e6398a9bd5",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2026-04-01 15:21:14 +00:00
"epoch 20, loss 0.000\n",
2026-03-22 08:28:55 +00:00
"epoch 40, loss 0.000\n",
"epoch 60, loss 0.000\n",
"epoch 80, loss 0.000\n",
"epoch 100, loss 0.000\n"
]
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 37
2026-03-22 08:28:55 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:14.392987505Z",
"start_time": "2026-04-22T07:03:14.281161755Z"
2026-03-22 08:28:55 +00:00
}
},
"cell_type": "code",
"source": "conv2d.weight.data.reshape((1, 2))\n",
"id": "bf53a423f429dfe4",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 1.0000, -1.0000]])"
]
},
2026-04-22 07:23:35 +00:00
"execution_count": 38,
2026-03-22 08:28:55 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 38
2026-03-22 08:28:55 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:14.572225657Z",
"start_time": "2026-04-22T07:03:14.447504117Z"
2026-03-22 08:28:55 +00:00
}
},
"cell_type": "code",
"source": [
"\n",
"def comp_conv2d(conv2d, X):\n",
"    \"\"\"Apply `conv2d` to `X` and return the result without its first two dimensions.\n",
"\n",
"    `X` is reshaped to (1, 1) + X.shape before the call (batch size 1 and\n",
"    1 channel), and the first two dimensions of the output are dropped again.\n",
"    \"\"\"\n",
"    batched = X.reshape((1, 1) + X.shape)\n",
"    out = conv2d(batched)\n",
"    # Drop the first two dimensions: batch size and channel.\n",
"    return out.reshape(out.shape[2:])\n",
"\n",
"\n",
"# One row/column of padding on every side, i.e. 2 extra rows and 2 extra columns in total.\n",
"conv2d = nn.Conv2d(1, 1, kernel_size=3, padding=1)"
],
"id": "77b61d8c9a2363cc",
"outputs": [],
2026-04-22 07:23:35 +00:00
"execution_count": 39
2026-03-22 08:28:55 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:14.842090802Z",
"start_time": "2026-04-22T07:03:14.640382418Z"
2026-03-22 08:28:55 +00:00
}
},
"cell_type": "code",
"source": [
"X = torch.rand(size=(8, 8))\n",
"comp_conv2d(conv2d, X).shape"
],
"id": "beda6ffa67ec2677",
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([8, 8])"
]
},
2026-04-22 07:23:35 +00:00
"execution_count": 40,
2026-03-22 08:28:55 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 40
2026-03-22 08:28:55 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:15.005425705Z",
"start_time": "2026-04-22T07:03:14.845097024Z"
2026-03-22 08:28:55 +00:00
}
},
"cell_type": "code",
"source": [
"conv2d = nn.Conv2d(1, 1, kernel_size=(5, 3), padding=(2, 1))\n",
"comp_conv2d(conv2d, X).shape"
],
"id": "8c51095daea1432d",
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([8, 8])"
]
},
2026-04-22 07:23:35 +00:00
"execution_count": 41,
2026-03-22 08:28:55 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 41
2026-03-22 08:28:55 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:15.154120979Z",
"start_time": "2026-04-22T07:03:15.063863068Z"
2026-03-22 08:28:55 +00:00
}
},
"cell_type": "code",
"source": [
"conv2d = nn.Conv2d(1, 1, kernel_size=3, padding=1, stride=2)\n",
"comp_conv2d(conv2d, X).shape"
],
"id": "581bf1b15162cbf6",
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([4, 4])"
]
},
2026-04-22 07:23:35 +00:00
"execution_count": 42,
2026-03-22 08:28:55 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 42
2026-03-22 08:28:55 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:15.290065680Z",
"start_time": "2026-04-22T07:03:15.156986867Z"
2026-03-22 08:28:55 +00:00
}
},
"cell_type": "code",
"source": [
"conv2d = nn.Conv2d(1, 1, kernel_size=(3, 5), padding=(0, 1), stride=(3, 4))\n",
"comp_conv2d(conv2d, X).shape"
],
"id": "6f7a2411247baff0",
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([2, 2])"
]
},
2026-04-22 07:23:35 +00:00
"execution_count": 43,
2026-03-22 08:28:55 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 43
2026-03-22 08:28:55 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:15.447782218Z",
"start_time": "2026-04-22T07:03:15.341665415Z"
2026-03-22 08:28:55 +00:00
}
},
"cell_type": "code",
"source": [
"def corr2d_multi_in(X, K):\n",
"    \"\"\"Cross-correlation over multiple input channels.\n",
"\n",
"    Pairs the slices of `X` and `K` along their first dimension, applies\n",
"    `corr2d` to each pair and adds the results together.\n",
"    \"\"\"\n",
"    total = 0\n",
"    for x, k in zip(X, K):\n",
"        total = total + corr2d(x, k)\n",
"    return total\n",
"X = torch.tensor([[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]],\n",
"[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]])\n",
"K = torch.tensor([[[0.0, 1.0], [2.0, 3.0]], [[1.0, 2.0], [3.0, 4.0]]])\n",
"corr2d_multi_in(X, K)"
],
"id": "7ac0f17f97b2daa8",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 56., 72.],\n",
" [104., 120.]])"
]
},
2026-04-22 07:23:35 +00:00
"execution_count": 44,
2026-03-22 08:28:55 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 44
2026-03-22 08:28:55 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:15.631383700Z",
"start_time": "2026-04-22T07:03:15.507675748Z"
2026-03-22 08:28:55 +00:00
}
},
"cell_type": "code",
"source": [
"def corr2d_multi_in_out(X, K) -> torch.Tensor:\n",
"    \"\"\"Cross-correlation with multiple output channels.\n",
"\n",
"    Iterates over the first dimension of `K`, runs `corr2d_multi_in` on `X`\n",
"    with each slice, and stacks the results along a new leading dimension.\n",
"    \"\"\"\n",
"    per_output = []\n",
"    for k in K:\n",
"        per_output.append(corr2d_multi_in(X, k))\n",
"    return torch.stack(per_output, dim=0)\n"
],
"id": "d409110d0d6b4b49",
"outputs": [],
2026-04-22 07:23:35 +00:00
"execution_count": 45
2026-03-22 08:28:55 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:15.955093562Z",
"start_time": "2026-04-22T07:03:15.716320703Z"
2026-03-22 08:28:55 +00:00
}
},
"cell_type": "code",
"source": [
"K = torch.stack((K, K + 1, K + 2), 0)\n",
"K.shape"
],
"id": "4114cd871a627075",
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([3, 2, 2, 2])"
]
},
2026-04-22 07:23:35 +00:00
"execution_count": 46,
2026-03-22 08:28:55 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 46
2026-03-22 08:28:55 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:16.110185723Z",
"start_time": "2026-04-22T07:03:15.964443180Z"
2026-03-22 08:28:55 +00:00
}
},
"cell_type": "code",
"source": "corr2d_multi_in_out(X, K)",
"id": "ce52f41dc9585f8c",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[ 56., 72.],\n",
" [104., 120.]],\n",
"\n",
" [[ 76., 100.],\n",
" [148., 172.]],\n",
"\n",
" [[ 96., 128.],\n",
" [192., 224.]]])"
]
},
2026-04-22 07:23:35 +00:00
"execution_count": 47,
2026-03-22 08:28:55 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 47
2026-03-22 08:28:55 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:16.181875274Z",
"start_time": "2026-04-22T07:03:16.125158123Z"
2026-03-22 08:28:55 +00:00
}
},
"cell_type": "code",
"source": [
"def corr2d_multi_in_out_1x1(X, K):\n",
"    \"\"\"Multi-channel cross-correlation with a 1x1 kernel, computed as one matrix product.\n",
"\n",
"    Args:\n",
"        X: tensor unpacked as (c_i, h, w).\n",
"        K: tensor whose first dimension is c_o and which holds c_o * c_i\n",
"            elements (it is reshaped to (c_o, c_i)).\n",
"\n",
"    Returns:\n",
"        Tensor of shape (c_o, h, w).\n",
"    \"\"\"\n",
"    c_i, h, w = X.shape\n",
"    c_o = K.shape[0]\n",
"    X = X.reshape((c_i, h * w))\n",
"    K = K.reshape((c_o, c_i))\n",
"    # (c_o, c_i) @ (c_i, h*w) -> (c_o, h*w)\n",
"    Y = torch.matmul(K, X)\n",
"    return Y.reshape((c_o, h, w))"
],
"id": "362d8c692b3c1d75",
"outputs": [],
2026-04-22 07:23:35 +00:00
"execution_count": 48
2026-03-22 08:28:55 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:16.240983284Z",
"start_time": "2026-04-22T07:03:16.187079922Z"
2026-03-22 08:28:55 +00:00
}
},
"cell_type": "code",
"source": [
"X = torch.normal(0, 1, (3, 3, 3))\n",
"K = torch.normal(0, 1, (2, 3, 1, 1))"
],
"id": "28e761f677df8b16",
"outputs": [],
2026-04-22 07:23:35 +00:00
"execution_count": 49
2026-03-22 08:28:55 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:16.386442288Z",
"start_time": "2026-04-22T07:03:16.245865317Z"
2026-03-22 08:28:55 +00:00
}
},
"cell_type": "code",
"source": "Y1 = corr2d_multi_in_out_1x1(X, K)",
"id": "8eb276fed751a6b9",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([3, 9])\n",
"torch.Size([2, 3])\n"
]
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 50
2026-03-22 08:28:55 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:16.444217309Z",
"start_time": "2026-04-22T07:03:16.422396408Z"
2026-03-22 08:28:55 +00:00
}
},
"cell_type": "code",
"source": [
"Y2 = corr2d_multi_in_out(X, K)\n",
"assert float(torch.abs(Y1 - Y2).sum()) < 1e-6"
],
"id": "be28e27d30f36e2c",
"outputs": [],
2026-04-22 07:23:35 +00:00
"execution_count": 51
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:16.504989589Z",
"start_time": "2026-04-22T07:03:16.449374426Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": [
"def pool2d(X,pool_size,mode='max'):\n",
" p_h,p_w =pool_size\n",
" Y = torch.zeros((X.shape[0]-p_h+1,X.shape[1]-p_w+1))\n",
" for i in range(Y.shape[0]):\n",
" for j in range(Y.shape[1]):\n",
" match mode:\n",
" case 'max':\n",
" Y[i,j]=X[i:i+p_h,j:j+p_w].max()\n",
" case 'avg':\n",
" Y[i,j]=X[i:i+p_h,j:j+p_w].mean()\n",
"\n",
" return Y"
],
"id": "3c3f71349a2e54c0",
"outputs": [],
2026-04-22 07:23:35 +00:00
"execution_count": 52
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:16.635764547Z",
"start_time": "2026-04-22T07:03:16.510663393Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": [
"X = torch.tensor([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]])\n",
"pool2d(X, (2, 2))"
],
"id": "a67207c861cf0cfd",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[4., 5.],\n",
" [7., 8.]])"
]
},
2026-04-22 07:23:35 +00:00
"execution_count": 53,
2026-03-25 15:07:28 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 53
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:16.938183687Z",
"start_time": "2026-04-22T07:03:16.693509080Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": "pool2d(X, (2, 2), 'avg')",
"id": "e387b48df3831b85",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[2., 3.],\n",
" [5., 6.]])"
]
},
2026-04-22 07:23:35 +00:00
"execution_count": 54,
2026-03-25 15:07:28 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 54
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:17.569350884Z",
"start_time": "2026-04-22T07:03:17.198359992Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": [
"X = torch.arange(16, dtype=torch.float32).reshape((1, 1, 4, 4))\n",
"X"
],
"id": "41b618b3a48522b4",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[[ 0., 1., 2., 3.],\n",
" [ 4., 5., 6., 7.],\n",
" [ 8., 9., 10., 11.],\n",
" [12., 13., 14., 15.]]]])"
]
},
2026-04-22 07:23:35 +00:00
"execution_count": 55,
2026-03-25 15:07:28 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 55
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:17.840996860Z",
"start_time": "2026-04-22T07:03:17.734395761Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": [
"pool2d=nn.MaxPool2d(3)\n",
"pool2d(X)"
],
"id": "c77484a8d1267259",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[[10.]]]])"
]
},
2026-04-22 07:23:35 +00:00
"execution_count": 56,
2026-03-25 15:07:28 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 56
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:17.967526280Z",
"start_time": "2026-04-22T07:03:17.871240998Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": [
"pool2d = nn.MaxPool2d(3, padding=1, stride=2)\n",
"pool2d(X)"
],
"id": "847a2bacfb6f2bd7",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[[ 5., 7.],\n",
" [13., 15.]]]])"
]
},
2026-04-22 07:23:35 +00:00
"execution_count": 57,
2026-03-25 15:07:28 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 57
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:18.057661560Z",
"start_time": "2026-04-22T07:03:17.998953875Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": [
"pool2d = nn.MaxPool2d((2, 3), stride=(2, 3), padding=(0, 1))\n",
"pool2d(X)"
],
"id": "5efad1e0b616fff7",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[[ 5., 7.],\n",
" [13., 15.]]]])"
]
},
2026-04-22 07:23:35 +00:00
"execution_count": 58,
2026-03-25 15:07:28 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 58
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:18.211656993Z",
"start_time": "2026-04-22T07:03:18.081025077Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": [
"X = torch.cat((X, X + 1), 1)\n",
"X"
],
"id": "386d4b3eb8069328",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[[ 0., 1., 2., 3.],\n",
" [ 4., 5., 6., 7.],\n",
" [ 8., 9., 10., 11.],\n",
" [12., 13., 14., 15.]],\n",
"\n",
" [[ 1., 2., 3., 4.],\n",
" [ 5., 6., 7., 8.],\n",
" [ 9., 10., 11., 12.],\n",
" [13., 14., 15., 16.]]]])"
]
},
2026-04-22 07:23:35 +00:00
"execution_count": 59,
2026-03-25 15:07:28 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 59
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:18.254285311Z",
"start_time": "2026-04-22T07:03:18.221776346Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": [
"pool2d = nn.MaxPool2d(3, padding=1, stride=2)\n",
"pool2d(X)"
],
"id": "ba5f57a8ca2a3b06",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[[ 5., 7.],\n",
" [13., 15.]],\n",
"\n",
" [[ 6., 8.],\n",
" [14., 16.]]]])"
]
},
2026-04-22 07:23:35 +00:00
"execution_count": 60,
2026-03-25 15:07:28 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 60
2026-03-22 08:28:55 +00:00
},
2026-03-25 15:07:28 +00:00
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:18.415322634Z",
"start_time": "2026-04-22T07:03:18.283827059Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": [
"net = nn.Sequential(\n",
" nn.Conv2d(1,6,kernel_size=5,padding=2), #1*1*28*28 -> 1*6*28*28\n",
" nn.Sigmoid(),\n",
" nn.AvgPool2d(kernel_size=2, stride=2), #1*6*28*28 -> 1*6*14*14\n",
" nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(), #1*6*14*14 -> 1*16*10*10\n",
" nn.AvgPool2d(kernel_size=2, stride=2), #1*16*10*10 -> 1*16*5*5\n",
" nn.Flatten(),\n",
" nn.Linear(16 * 5 * 5, 120), nn.Sigmoid(),\n",
" nn.Linear(120, 84), nn.Sigmoid(),\n",
" nn.Linear(84, 10)\n",
")\n",
"X = torch.rand(size=(1,1,28,28),dtype=torch.float32)\n",
"for layer in net:\n",
" X=layer(X)\n",
" print(layer.__class__.__name__,'output shape: \\t',X.shape)"
],
"id": "1eabc29f9c838842",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Conv2d output shape: \t torch.Size([1, 6, 28, 28])\n",
"Sigmoid output shape: \t torch.Size([1, 6, 28, 28])\n",
"AvgPool2d output shape: \t torch.Size([1, 6, 14, 14])\n",
"Conv2d output shape: \t torch.Size([1, 16, 10, 10])\n",
"Sigmoid output shape: \t torch.Size([1, 16, 10, 10])\n",
"AvgPool2d output shape: \t torch.Size([1, 16, 5, 5])\n",
"Flatten output shape: \t torch.Size([1, 400])\n",
"Linear output shape: \t torch.Size([1, 120])\n",
"Sigmoid output shape: \t torch.Size([1, 120])\n",
"Linear output shape: \t torch.Size([1, 84])\n",
"Sigmoid output shape: \t torch.Size([1, 84])\n",
"Linear output shape: \t torch.Size([1, 10])\n"
]
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 61
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:20.954626561Z",
"start_time": "2026-04-22T07:03:18.435497060Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": [
"import d2l.torch as d2l\n",
"batch_size = 256\n",
"train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)"
],
"id": "e372f75817ad4a0f",
"outputs": [],
2026-04-22 07:23:35 +00:00
"execution_count": 62
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:21.060902242Z",
"start_time": "2026-04-22T07:03:21.006364318Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": [
"lr, num_epochs = 0.9, 10\n",
"#d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())"
],
"id": "9aaeb948f3353955",
"outputs": [],
2026-04-22 07:23:35 +00:00
"execution_count": 63
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:21.128721611Z",
"start_time": "2026-04-22T07:03:21.065860429Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": [
"class Inception(nn.Module):\n",
" def __init__(self,in_channels,c1,c2,c3,c4,**kwargs):\n",
" super(Inception,self).__init__(**kwargs)\n",
" self.p1_1 = nn.Conv2d(in_channels,c1,kernel_size=1)\n",
" self.p2_1 = nn.Conv2d(in_channels,c2[0],kernel_size=1)\n",
" self.p2_2 = nn.Conv2d(c2[0],c2[1],kernel_size=3,padding=1)\n",
" self.p3_1 = nn.Conv2d(in_channels,c3[0],kernel_size=1)\n",
" self.p3_2 = nn.Conv2d(c3[0],c3[1],kernel_size=5,padding=2)\n",
" self.p4_1 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)\n",
" self.p4_2 = nn.Conv2d(in_channels, c4, kernel_size=1)\n",
" def forward(self,x):\n",
" p1 = F.relu(self.p1_1(x))\n",
" p2 = F.relu(self.p2_2(F.relu(self.p2_1(x))))\n",
" p3 = F.relu(self.p3_2(F.relu(self.p3_1(x))))\n",
" p4 = F.relu(self.p4_2(self.p4_1(x)))\n",
" return torch.cat((p1,p2,p3,p4),dim=1)"
],
"id": "6d3bb3f70f297dba",
"outputs": [],
2026-04-22 07:23:35 +00:00
"execution_count": 64
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:21.758676619Z",
"start_time": "2026-04-22T07:03:21.134168809Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": [
"b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),\n",
" nn.ReLU(),\n",
" nn.MaxPool2d(kernel_size=3, stride=2, padding=1))\n",
"b2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1),\n",
" nn.ReLU(),\n",
" nn.Conv2d(64, 192, kernel_size=3, padding=1),\n",
" nn.ReLU(),\n",
" nn.MaxPool2d(kernel_size=3, stride=2, padding=1))\n",
"b3 = nn.Sequential(Inception(192, 64, (96, 128), (16, 32), 32),\n",
" Inception(256, 128, (128, 192), (32, 96), 64),\n",
" nn.MaxPool2d(kernel_size=3, stride=2, padding=1))\n",
"b4 = nn.Sequential(Inception(480, 192, (96, 208), (16, 48), 64),\n",
" Inception(512, 160, (112, 224), (24, 64), 64),\n",
" Inception(512, 128, (128, 256), (24, 64), 64),\n",
" Inception(512, 112, (144, 288), (32, 64), 64),\n",
" Inception(528, 256, (160, 320), (32, 128), 128),\n",
" nn.MaxPool2d(kernel_size=3, stride=2, padding=1))\n",
"b5 = nn.Sequential(Inception(832, 256, (160, 320), (32, 128), 128),\n",
" Inception(832, 384, (192, 384), (48, 128), 128),\n",
" nn.AdaptiveAvgPool2d((1,1)),\n",
" nn.Flatten())\n",
"net = nn.Sequential(b1, b2, b3, b4, b5, nn.Linear(1024, 10))\n",
"X = torch.rand(size=(1, 1, 96, 96))\n",
"for layer in net:\n",
" X = layer(X)\n",
" print(layer.__class__.__name__,'output shape:\\t', X.shape)"
],
"id": "6ef7022bcb288d65",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sequential output shape:\t torch.Size([1, 64, 24, 24])\n",
"Sequential output shape:\t torch.Size([1, 192, 12, 12])\n",
"Sequential output shape:\t torch.Size([1, 480, 6, 6])\n",
"Sequential output shape:\t torch.Size([1, 832, 3, 3])\n",
"Sequential output shape:\t torch.Size([1, 1024])\n",
"Linear output shape:\t torch.Size([1, 10])\n"
]
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 65
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:22.718023030Z",
"start_time": "2026-04-22T07:03:21.806456021Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": [
"import torchinfo\n",
"torchinfo.summary(net,(1,1,96,96))"
],
"id": "acc019ce7afa4470",
"outputs": [
{
"data": {
"text/plain": [
"==========================================================================================\n",
"Layer (type:depth-idx) Output Shape Param #\n",
"==========================================================================================\n",
"Sequential [1, 10] --\n",
"├─Sequential: 1-1 [1, 64, 24, 24] --\n",
"│ └─Conv2d: 2-1 [1, 64, 48, 48] 3,200\n",
"│ └─ReLU: 2-2 [1, 64, 48, 48] --\n",
"│ └─MaxPool2d: 2-3 [1, 64, 24, 24] --\n",
"├─Sequential: 1-2 [1, 192, 12, 12] --\n",
"│ └─Conv2d: 2-4 [1, 64, 24, 24] 4,160\n",
"│ └─ReLU: 2-5 [1, 64, 24, 24] --\n",
"│ └─Conv2d: 2-6 [1, 192, 24, 24] 110,784\n",
"│ └─ReLU: 2-7 [1, 192, 24, 24] --\n",
"│ └─MaxPool2d: 2-8 [1, 192, 12, 12] --\n",
"├─Sequential: 1-3 [1, 480, 6, 6] --\n",
"│ └─Inception: 2-9 [1, 256, 12, 12] --\n",
"│ │ └─Conv2d: 3-1 [1, 64, 12, 12] 12,352\n",
"│ │ └─Conv2d: 3-2 [1, 96, 12, 12] 18,528\n",
"│ │ └─Conv2d: 3-3 [1, 128, 12, 12] 110,720\n",
"│ │ └─Conv2d: 3-4 [1, 16, 12, 12] 3,088\n",
"│ │ └─Conv2d: 3-5 [1, 32, 12, 12] 12,832\n",
"│ │ └─MaxPool2d: 3-6 [1, 192, 12, 12] --\n",
"│ │ └─Conv2d: 3-7 [1, 32, 12, 12] 6,176\n",
"│ └─Inception: 2-10 [1, 480, 12, 12] --\n",
"│ │ └─Conv2d: 3-8 [1, 128, 12, 12] 32,896\n",
"│ │ └─Conv2d: 3-9 [1, 128, 12, 12] 32,896\n",
"│ │ └─Conv2d: 3-10 [1, 192, 12, 12] 221,376\n",
"│ │ └─Conv2d: 3-11 [1, 32, 12, 12] 8,224\n",
"│ │ └─Conv2d: 3-12 [1, 96, 12, 12] 76,896\n",
"│ │ └─MaxPool2d: 3-13 [1, 256, 12, 12] --\n",
"│ │ └─Conv2d: 3-14 [1, 64, 12, 12] 16,448\n",
"│ └─MaxPool2d: 2-11 [1, 480, 6, 6] --\n",
"├─Sequential: 1-4 [1, 832, 3, 3] --\n",
"│ └─Inception: 2-12 [1, 512, 6, 6] --\n",
"│ │ └─Conv2d: 3-15 [1, 192, 6, 6] 92,352\n",
"│ │ └─Conv2d: 3-16 [1, 96, 6, 6] 46,176\n",
"│ │ └─Conv2d: 3-17 [1, 208, 6, 6] 179,920\n",
"│ │ └─Conv2d: 3-18 [1, 16, 6, 6] 7,696\n",
"│ │ └─Conv2d: 3-19 [1, 48, 6, 6] 19,248\n",
"│ │ └─MaxPool2d: 3-20 [1, 480, 6, 6] --\n",
"│ │ └─Conv2d: 3-21 [1, 64, 6, 6] 30,784\n",
"│ └─Inception: 2-13 [1, 512, 6, 6] --\n",
"│ │ └─Conv2d: 3-22 [1, 160, 6, 6] 82,080\n",
"│ │ └─Conv2d: 3-23 [1, 112, 6, 6] 57,456\n",
"│ │ └─Conv2d: 3-24 [1, 224, 6, 6] 226,016\n",
"│ │ └─Conv2d: 3-25 [1, 24, 6, 6] 12,312\n",
"│ │ └─Conv2d: 3-26 [1, 64, 6, 6] 38,464\n",
"│ │ └─MaxPool2d: 3-27 [1, 512, 6, 6] --\n",
"│ │ └─Conv2d: 3-28 [1, 64, 6, 6] 32,832\n",
"│ └─Inception: 2-14 [1, 512, 6, 6] --\n",
"│ │ └─Conv2d: 3-29 [1, 128, 6, 6] 65,664\n",
"│ │ └─Conv2d: 3-30 [1, 128, 6, 6] 65,664\n",
"│ │ └─Conv2d: 3-31 [1, 256, 6, 6] 295,168\n",
"│ │ └─Conv2d: 3-32 [1, 24, 6, 6] 12,312\n",
"│ │ └─Conv2d: 3-33 [1, 64, 6, 6] 38,464\n",
"│ │ └─MaxPool2d: 3-34 [1, 512, 6, 6] --\n",
"│ │ └─Conv2d: 3-35 [1, 64, 6, 6] 32,832\n",
"│ └─Inception: 2-15 [1, 528, 6, 6] --\n",
"│ │ └─Conv2d: 3-36 [1, 112, 6, 6] 57,456\n",
"│ │ └─Conv2d: 3-37 [1, 144, 6, 6] 73,872\n",
"│ │ └─Conv2d: 3-38 [1, 288, 6, 6] 373,536\n",
"│ │ └─Conv2d: 3-39 [1, 32, 6, 6] 16,416\n",
"│ │ └─Conv2d: 3-40 [1, 64, 6, 6] 51,264\n",
"│ │ └─MaxPool2d: 3-41 [1, 512, 6, 6] --\n",
"│ │ └─Conv2d: 3-42 [1, 64, 6, 6] 32,832\n",
"│ └─Inception: 2-16 [1, 832, 6, 6] --\n",
"│ │ └─Conv2d: 3-43 [1, 256, 6, 6] 135,424\n",
"│ │ └─Conv2d: 3-44 [1, 160, 6, 6] 84,640\n",
"│ │ └─Conv2d: 3-45 [1, 320, 6, 6] 461,120\n",
"│ │ └─Conv2d: 3-46 [1, 32, 6, 6] 16,928\n",
"│ │ └─Conv2d: 3-47 [1, 128, 6, 6] 102,528\n",
"│ │ └─MaxPool2d: 3-48 [1, 528, 6, 6] --\n",
"│ │ └─Conv2d: 3-49 [1, 128, 6, 6] 67,712\n",
"│ └─MaxPool2d: 2-17 [1, 832, 3, 3] --\n",
"├─Sequential: 1-5 [1, 1024] --\n",
"│ └─Inception: 2-18 [1, 832, 3, 3] --\n",
"│ │ └─Conv2d: 3-50 [1, 256, 3, 3] 213,248\n",
"│ │ └─Conv2d: 3-51 [1, 160, 3, 3] 133,280\n",
"│ │ └─Conv2d: 3-52 [1, 320, 3, 3] 461,120\n",
"│ │ └─Conv2d: 3-53 [1, 32, 3, 3] 26,656\n",
"│ │ └─Conv2d: 3-54 [1, 128, 3, 3] 102,528\n",
"│ │ └─MaxPool2d: 3-55 [1, 832, 3, 3] --\n",
"│ │ └─Conv2d: 3-56 [1, 128, 3, 3] 106,624\n",
"│ └─Inception: 2-19 [1, 1024, 3, 3] --\n",
"│ │ └─Conv2d: 3-57 [1, 384, 3, 3] 319,872\n",
"│ │ └─Conv2d: 3-58 [1, 192, 3, 3] 159,936\n",
"│ │ └─Conv2d: 3-59 [1, 384, 3, 3] 663,936\n",
"│ │ └─Conv2d: 3-60 [1, 48, 3, 3] 39,984\n",
"│ │ └─Conv2d: 3-61 [1, 128, 3, 3] 153,728\n",
"│ │ └─MaxPool2d: 3-62 [1, 832, 3, 3] --\n",
"│ │ └─Conv2d: 3-63 [1, 128, 3, 3] 106,624\n",
"│ └─AdaptiveAvgPool2d: 2-20 [1, 1024, 1, 1] --\n",
"│ └─Flatten: 2-21 [1, 1024] --\n",
"├─Linear: 1-6 [1, 10] 10,250\n",
"==========================================================================================\n",
"Total params: 5,977,530\n",
"Trainable params: 5,977,530\n",
"Non-trainable params: 0\n",
"Total mult-adds (Units.MEGABYTES): 276.66\n",
"==========================================================================================\n",
"Input size (MB): 0.04\n",
"Forward/backward pass size (MB): 4.74\n",
"Params size (MB): 23.91\n",
"Estimated Total Size (MB): 28.69\n",
"=========================================================================================="
]
},
2026-04-22 07:23:35 +00:00
"execution_count": 66,
2026-03-25 15:07:28 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 66
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:22.823562168Z",
"start_time": "2026-04-22T07:03:22.774952520Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": [
"lr, num_epochs, batch_size = 0.1, 10, 128\n",
"train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)\n",
"#d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())"
],
"id": "3760a5e5813405f7",
"outputs": [],
2026-04-22 07:23:35 +00:00
"execution_count": 67
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:22.875034764Z",
"start_time": "2026-04-22T07:03:22.825694238Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": [
"class Residual(nn.Module):\n",
" def __init__(self,input_channels,num_channels,use_1x1conv=False,strides=1):\n",
" super().__init__()\n",
" self.conv1 = nn.Conv2d(input_channels,num_channels,kernel_size=3,padding=1,stride=strides)\n",
" self.conv2 = nn.Conv2d(num_channels,num_channels,kernel_size=3,padding=1)\n",
" if use_1x1conv:\n",
" self.conv3 = nn.Conv2d(input_channels,num_channels,kernel_size=1,stride=strides)\n",
" else:\n",
" self.conv3= None\n",
" self.bn1=nn.BatchNorm2d(num_channels)\n",
" self.bn2=nn.BatchNorm2d(num_channels)\n",
" def forward(self,X):\n",
" Y=F.relu(self.bn1(self.conv1(X)))\n",
" Y=self.bn2(self.conv2(Y))\n",
" if self.conv3:\n",
" X = self.conv3(X)\n",
" Y+=X\n",
" return F.relu(Y)\n"
],
"id": "9300979845ba6916",
"outputs": [],
2026-04-22 07:23:35 +00:00
"execution_count": 68
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:22.926048268Z",
"start_time": "2026-04-22T07:03:22.876715194Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": [
"blk = Residual(3,3)\n",
"X = torch.rand(4, 3, 6, 6)"
],
"id": "1248323517ff3228",
"outputs": [],
2026-04-22 07:23:35 +00:00
"execution_count": 69
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:22.992544732Z",
"start_time": "2026-04-22T07:03:22.927760279Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": [
"blk = Residual(3,6, use_1x1conv=True, strides=2)\n",
"blk(X).shape"
],
"id": "82cdbd71a157b51c",
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([4, 6, 3, 3])"
]
},
2026-04-22 07:23:35 +00:00
"execution_count": 70,
2026-03-25 15:07:28 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 70
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:23.041967126Z",
"start_time": "2026-04-22T07:03:22.993777596Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": [
"b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),\n",
"nn.BatchNorm2d(64), nn.ReLU(),\n",
"nn.MaxPool2d(kernel_size=3, stride=2, padding=1))"
],
"id": "727da1d2d363ac62",
"outputs": [],
2026-04-22 07:23:35 +00:00
"execution_count": 71
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:23.093341944Z",
"start_time": "2026-04-22T07:03:23.044341294Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": [
"def resnet_block(input_channels, num_channels, num_residuals,\n",
" first_block=False):\n",
" blk = []\n",
" for i in range(num_residuals):\n",
" if i == 0 and not first_block:\n",
" blk.append(Residual(input_channels, num_channels,\n",
" use_1x1conv=True, strides=2))\n",
" else:\n",
" blk.append(Residual(num_channels, num_channels))\n",
" return blk"
],
"id": "124134971f8441c0",
"outputs": [],
2026-04-22 07:23:35 +00:00
"execution_count": 72
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:23.143917227Z",
"start_time": "2026-04-22T07:03:23.095134233Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": [
"b2 = nn.Sequential(*resnet_block(64, 64, 2, first_block=True))\n",
"b3 = nn.Sequential(*resnet_block(64, 128, 2))\n",
"b4 = nn.Sequential(*resnet_block(128, 256, 2))\n",
"b5 = nn.Sequential(*resnet_block(256, 512, 2))"
],
"id": "ca1f1c69fba3e913",
"outputs": [],
2026-04-22 07:23:35 +00:00
"execution_count": 73
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:23.192826666Z",
"start_time": "2026-04-22T07:03:23.145253995Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": [
"net = nn.Sequential(b1, b2, b3, b4, b5,\n",
"nn.AdaptiveAvgPool2d((1,1)),\n",
"nn.Flatten(), nn.Linear(512, 10))"
],
"id": "f21db27de5dbdec1",
"outputs": [],
2026-04-22 07:23:35 +00:00
"execution_count": 74
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:23.412254076Z",
"start_time": "2026-04-22T07:03:23.195173742Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": [
"X = torch.rand(size=(1, 1, 224, 224))\n",
"for layer in net:\n",
" X = layer(X)\n",
" print(layer.__class__.__name__,'output shape:\\t', X.shape)"
],
"id": "6f8851a2bfd18c4e",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sequential output shape:\t torch.Size([1, 64, 56, 56])\n",
"Sequential output shape:\t torch.Size([1, 64, 56, 56])\n",
"Sequential output shape:\t torch.Size([1, 128, 28, 28])\n",
"Sequential output shape:\t torch.Size([1, 256, 14, 14])\n",
"Sequential output shape:\t torch.Size([1, 512, 7, 7])\n",
"AdaptiveAvgPool2d output shape:\t torch.Size([1, 512, 1, 1])\n",
"Flatten output shape:\t torch.Size([1, 512])\n",
"Linear output shape:\t torch.Size([1, 10])\n"
]
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 75
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:23.462272089Z",
"start_time": "2026-04-22T07:03:23.415437344Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": [
"lr, num_epochs, batch_size = 0.05, 10, 256\n",
"train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)\n",
"#d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())"
],
"id": "e095d74b29dffef6",
"outputs": [],
2026-04-22 07:23:35 +00:00
"execution_count": 76
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:23.541760019Z",
"start_time": "2026-04-22T07:03:23.464100823Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": [
"import torch\n",
"import d2l.torch as d2l\n",
"import numpy\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"print(torch.version.__version__)"
],
"id": "3fd6d22221f87bea",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2.10.0+cu128\n"
]
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 77
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:23.592662909Z",
"start_time": "2026-04-22T07:03:23.543067093Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": [
"A=torch.Tensor([[1,2,0,0],[0,2,0,0],[0,0,2,1],[0,0,0,3]])\n",
"C=torch.Tensor([[1,0,0,0],[0,1,0,0],[0,0,-2,3],[0,0,0,-3]])"
],
"id": "254f5d3d659dbe0f",
"outputs": [],
2026-04-22 07:23:35 +00:00
"execution_count": 78
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:23.643655879Z",
"start_time": "2026-04-22T07:03:23.594970016Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": "B=torch.Tensor([[2,0,0,0],[-2,1,0,0],[0,0,-3,0],[0,0,0,-3]])",
"id": "a13d9c27c2fdbfad",
"outputs": [],
2026-04-22 07:23:35 +00:00
"execution_count": 79
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:23.700556244Z",
"start_time": "2026-04-22T07:03:23.645948395Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": "torch.mm(A,C)",
"id": "e513a37beaa85f8f",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 1., 2., 0., 0.],\n",
" [ 0., 2., 0., 0.],\n",
" [ 0., 0., -4., 3.],\n",
" [ 0., 0., 0., -9.]])"
]
},
2026-04-22 07:23:35 +00:00
"execution_count": 80,
2026-03-25 15:07:28 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 80
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:23.827304753Z",
"start_time": "2026-04-22T07:03:23.751870264Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": "torch.det(torch.mm(torch.mm(A,C),B))",
"id": "9a85eceac652875f",
"outputs": [
{
"data": {
"text/plain": [
"tensor(1296.)"
]
},
2026-04-22 07:23:35 +00:00
"execution_count": 81,
2026-03-25 15:07:28 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 81
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:23.931211184Z",
"start_time": "2026-04-22T07:03:23.879693736Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": "1296**5\n",
"id": "6dc27d79722da58f",
"outputs": [
{
"data": {
"text/plain": [
"3656158440062976"
]
},
2026-04-22 07:23:35 +00:00
"execution_count": 82,
2026-03-25 15:07:28 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 82
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:23.997938663Z",
"start_time": "2026-04-22T07:03:23.945902637Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": "torch.mm(C,B)",
"id": "ec5a170d775f4705",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 2., 0., 0., 0.],\n",
" [-2., 1., 0., 0.],\n",
" [ 0., 0., 6., -9.],\n",
" [ 0., 0., 0., 9.]])"
]
},
2026-04-22 07:23:35 +00:00
"execution_count": 83,
2026-03-25 15:07:28 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 83
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:24.170960756Z",
"start_time": "2026-04-22T07:03:24.067057874Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": [
"T = 1000 # 总共产生1000个点\n",
"time = torch.arange(1, T + 1, dtype=torch.float32)\n",
"x = torch.sin(0.01 * time) + torch.normal(0, 0.2, (T,))\n",
"d2l.plot(time, [x], 'time', 'x', xlim=[1, 1000], figsize=(6, 3))"
],
"id": "c3884b10464c6baa",
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 600x300 with 1 Axes>"
],
2026-04-22 07:23:35 +00:00
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"406.885938pt\" height=\"211.07625pt\" viewBox=\"0 0 406.885938 211.07625\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n <metadata>\n <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n <cc:Work>\n <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n <dc:date>2026-04-22T15:03:24.145178</dc:date>\n <dc:format>image/svg+xml</dc:format>\n <dc:creator>\n <cc:Agent>\n <dc:title>Matplotlib v3.7.2, https://matplotlib.org/</dc:title>\n </cc:Agent>\n </dc:creator>\n </cc:Work>\n </rdf:RDF>\n </metadata>\n <defs>\n <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n </defs>\n <g id=\"figure_1\">\n <g id=\"patch_1\">\n <path d=\"M 0 211.07625 \nL 406.885938 211.07625 \nL 406.885938 0 \nL 0 0 \nz\n\"/>\n </g>\n <g id=\"axes_1\">\n <g id=\"patch_2\">\n <path d=\"M 52.160938 173.52 \nL 386.960938 173.52 \nL 386.960938 7.2 \nL 52.160938 7.2 \nz\n\"/>\n </g>\n <g id=\"matplotlib.axis_1\">\n <g id=\"xtick_1\">\n <g id=\"line2d_1\">\n <path d=\"M 118.852829 173.52 \nL 118.852829 7.2 \n\" clip-path=\"url(#p991db88ff6)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_2\">\n <defs>\n <path id=\"m5e40d2494d\" d=\"M 0 0 \nL 0 3.5 \n\" style=\"stroke: #ffffff; stroke-width: 0.8\"/>\n </defs>\n <g>\n <use xlink:href=\"#m5e40d2494d\" x=\"118.852829\" y=\"173.52\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_1\">\n <!-- 200 -->\n <g style=\"fill: #ffffff\" transform=\"translate(109.309079 188.118438) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-32\" d=\"M 1228 531 \nL 3431 531 
\nL 3431 0 \nL 469 0 \nL 469 531 \nQ 828 903 1448 1529 \nQ 2069 2156 2228 2338 \nQ 2531 2678 2651 2914 \nQ 2772 3150 2772 3378 \nQ 2772 3750 2511 3984 \nQ 2250 4219 1831 4219 \nQ 1534 4219 1204 4116 \nQ 875 4013 500 3803 \nL 500 4441 \nQ 881 4594 1212 4672 \nQ 1544 4750 1819 4750 \nQ 2544 4750 2975 4387 \nQ 3406 4025 3406 3419 \nQ 3406 3131 3298 2873 \nQ 3191 2616 2906 2266 \nQ 2828 2175 2409 1742 \nQ 1991 1309 1228 531 \nz\n\" transform=\"scale(0.015625)\"/>\n <path id=\"DejaVuSans-30\" d=\"M 2034 4250 \nQ 1547 4250 1301 3770 \nQ 1056 3291 1056 2328 \nQ 1056 1369 1301 889 \nQ 1547 409 2034 409 \nQ 2525 409 2770 889 \nQ 3016 1369 3016 2328 \nQ 3016 3291 2770 3770 \nQ 2525 4250 2034 4250 \nz\nM 2034 4750 \nQ 2819 4750 3233 4129 \nQ 3647 3509 3647 2328 \nQ 3647 1150 3233 529 \nQ 2819 -91 2034 -91 \nQ 1250 -91 836 529 \nQ 422 1150 422 2328 \nQ 422 3509 836 4129 \nQ 1250 4750 2034 4750 \nz\n\" transform=\"scale(0.015625)\"/>\n </defs>\n <use xlink:href=\"#DejaVuSans-32\"/>\n <use xlink:href=\"#DejaVuSans-30\" x=\"63.623047\"/>\n <use xlink:href=\"#DejaVuSans-30\" x=\"127.246094\"/>\n </g>\n </g>\n </g>\n <g id=\"xtick_2\">\n <g id=\"line2d_3\">\n <path d=\"M 185.879856 173.52 \nL 185.879856 7.2 \n\" clip-path=\"url(#p991db88ff6)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_4\">\n <g>\n <use xlink:href=\"#m5e40d2494d\" x=\"185.879856\" y=\"173.52\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_2\">\n <!-- 400 -->\n <g style=\"fill: #ffffff\" transform=\"translate(176.336106 188.118438) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-34\" d=\"M 2419 4116 \nL 825 1625 \nL 2419 1625 \nL 2419 4116 \nz\nM 2253 4666 \nL 3047 4666 \nL 3047 1625 \nL 3713 1625 \nL 3713 1100 \nL 3047 1100 \nL 3047 0 \nL 2419 0 \nL 2419 11
2026-03-25 15:07:28 +00:00
},
"metadata": {},
"output_type": "display_data",
"jetTransient": {
"display_id": null
}
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 84
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:24.309169705Z",
"start_time": "2026-04-22T07:03:24.235365355Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": [
"# Sliding-window dataset: row t of `features` holds x[t], ..., x[t + tau - 1]\n",
"# and `labels[t]` is the value that follows that window, x[t + tau].\n",
"tau = 4\n",
"features = torch.zeros((T - tau, tau))\n",
"for i in range(tau):\n",
"    # Column i is the series shifted by i steps, truncated to T - tau rows.\n",
"    features[:, i] = x[i: T - tau + i]\n",
"labels = x[tau:].reshape((-1, 1))\n",
"# Show shapes only: echoing the full tensors floods the notebook output.\n",
"x.shape, features.shape, labels.shape\n"
],
"id": "5d450c6f6b724a14",
"outputs": [
{
"data": {
"text/plain": [
2026-04-22 07:23:35 +00:00
"(tensor([-0.0948, 0.2143, -0.2523, -0.1235, -0.1826, 0.1189, -0.1963, 0.2347,\n",
" 0.1456, -0.1118, 0.3787, 0.3861, 0.2881, 0.1958, 0.0402, 0.0816,\n",
" 0.4793, 0.0351, 0.2378, 0.1459, 0.1108, 0.2544, -0.0127, 0.0733,\n",
" 0.3156, 0.0257, 0.3207, 0.3259, 0.3693, 0.0584, 0.1730, 0.3100,\n",
" 0.2328, 0.0525, 0.4465, 0.1293, 0.4330, 0.3193, 0.4704, 0.5238,\n",
" 0.5323, 0.4887, 0.0831, 0.5924, 0.6972, 0.3490, 0.7476, 0.6039,\n",
" 0.9995, 0.1455, 0.1417, 0.5968, 0.6673, 0.3425, 0.7685, 0.4904,\n",
" 0.2203, 0.2109, 0.4600, 0.5055, 0.3558, 0.7020, 0.7435, 0.4713,\n",
" 0.4318, 0.5861, 0.3592, 0.7750, 0.6640, 0.7908, 0.2776, 0.5868,\n",
" 0.6283, 0.3461, 0.6308, 0.7547, 0.5564, 0.7181, 0.7852, 0.7823,\n",
" 0.7238, 0.9294, 0.9023, 0.8100, 0.5561, 0.7124, 1.1566, 0.7628,\n",
" 0.9630, 0.4425, 1.0628, 0.7014, 0.4439, 0.7286, 0.8099, 0.5786,\n",
" 1.0638, 0.9519, 0.8388, 1.2088, 0.9172, 0.7014, 0.5667, 0.6040,\n",
" 0.5549, 0.7959, 0.9167, 0.9074, 0.6108, 0.8999, 0.9197, 0.8539,\n",
" 0.6566, 0.9941, 0.6902, 0.8782, 1.4898, 0.9888, 1.1911, 0.5683,\n",
" 0.8868, 0.7122, 0.8960, 1.1454, 1.2660, 1.0001, 0.6582, 0.9706,\n",
" 1.0110, 1.0355, 0.9761, 0.9439, 1.0824, 1.4095, 1.2544, 0.6541,\n",
" 1.0486, 1.0638, 0.9000, 0.9835, 1.2558, 1.1702, 0.8466, 0.7696,\n",
" 1.2446, 1.1460, 0.9258, 0.8150, 1.1086, 0.9475, 0.9675, 0.7330,\n",
" 1.1263, 1.1718, 0.9413, 1.0272, 0.7733, 0.9831, 0.8759, 0.7970,\n",
" 0.6360, 1.1815, 0.9689, 0.6976, 0.9265, 0.8338, 0.7960, 0.7705,\n",
" 1.2601, 1.2775, 0.7706, 1.0216, 1.1916, 0.8603, 0.9864, 1.0777,\n",
" 0.8930, 1.0063, 0.8376, 0.9923, 0.8081, 0.8020, 1.1461, 1.1018,\n",
" 0.8931, 1.0005, 0.8635, 0.7197, 1.2577, 1.0584, 1.4032, 0.8911,\n",
" 1.1415, 0.8241, 0.7946, 1.0221, 0.8792, 0.7211, 1.1821, 0.8079,\n",
" 0.8926, 1.0765, 0.9949, 0.9159, 0.7329, 0.9950, 0.7491, 0.8750,\n",
" 1.1863, 1.0095, 0.8046, 0.6274, 0.8936, 0.7595, 0.8423, 0.8655,\n",
" 0.6918, 0.7347, 1.1179, 0.5931, 0.8745, 0.4858, 0.9338, 1.1382,\n",
" 0.6084, 0.9479, 0.8726, 0.7202, 0.9596, 0.4386, 1.2525, 0.5120,\n",
" 0.7222, 0.6566, 0.8965, 0.7545, 1.1104, 0.6634, 0.5654, 1.0095,\n",
" 0.6558, 0.7260, 0.8515, 0.3430, 0.7703, 0.3753, 0.4490, 0.4373,\n",
" 0.8283, 0.5455, 0.7584, 0.8197, 0.4781, 0.3350, 0.6714, 0.3969,\n",
" 0.7131, 0.5609, 0.4327, 0.4293, 0.3552, 0.5445, 0.5609, 0.4110,\n",
" 0.8525, 0.3402, 0.4064, 0.5172, 0.5845, 0.6185, 0.4719, 0.9092,\n",
" 0.6964, 0.7267, 0.6934, 0.4337, 0.2031, 0.0898, 0.4377, 0.4203,\n",
" 0.2855, 0.4673, 0.6029, 0.4368, 0.1521, -0.0606, 0.2532, 0.4365,\n",
" 0.2989, 0.0743, 0.2734, 0.1060, 0.5543, -0.1211, 0.0968, 0.4911,\n",
" 0.5107, 0.4583, 0.2777, -0.0513, 0.1437, 0.0548, 0.0933, 0.1172,\n",
" 0.0718, 0.4027, 0.1805, 0.0869, -0.3066, 0.5615, -0.2721, -0.2765,\n",
" 0.0850, -0.1473, -0.1622, -0.1335, 0.1328, 0.0703, -0.6712, -0.1121,\n",
" -0.1208, 0.0092, -0.0805, -0.2017, 0.2339, -0.3533, -0.4598, 0.0620,\n",
" -0.5254, 0.0197, -0.0593, -0.1914, -0.4259, -0.0115, -0.5406, 0.0137,\n",
" -0.4240, -0.2822, -0.0796, -0.3495, -0.4475, 0.2453, -0.3729, -0.4086,\n",
" -0.2618, -0.4539, -0.6140, -0.2483, -0.4165, -0.3736, -0.0737, -0.0212,\n",
" -0.3644, -0.0472, -0.4087, -0.6794, -0.5921, -0.5632, -0.5971, -0.1935,\n",
" -0.9543, -0.7976, -0.3485, -0.9538, -0.6171, -0.7755, -0.4651, -0.8194,\n",
" -0.3005, -0.5191, -0.5902, -0.2464, -0.6908, -0.5054, -0.5528, -1.1089,\n",
" -0.7206, -0.8067, -0.6780, -0.2981, -0.6683, -0.4324, -0.8497, -0.5928,\n",
" -0.7203, -0.3751, -0.7423, -0.4109, -0.7345, -0.6653, -0.5752, -0.5198,\n",
" -0.7046, -1.1754, -0.9447, -0.7304, -0.6510, -0.5954, -0.7592, -0.5285,\n",
" -0.4249, -0.7993, -1.3758, -0.6218, -1.0691, -0.5775, -0.8174, -0.7021,\n",
" -0.7784, -0.7553, -1.2137, -0.7302, -0.7253, -0.6819, -1.3077, -1.3472,\n",
" -0.7104, -0.8387, -0.5973, -0.8619, -1.1138, -1.1314, -0.9765, -1.2121,\n",
" -0.8168, -0.7763, -1.2988, -0.9282, -1.1715, -0.7216, -0.7182, -0.2972,\n",
" -0.7471, -1.0089, -1.1431, -1.0396, -1.0381, -0.5979, -0.7363, -0.7808,\n",
" -0.9106, -1.1468, -1.1357, -0.6406, -0.9603, -1.2653, -1.5958, -1.0592,\n",
" -0.9698, -0.8252, -1.2515, -1.0474, -1.1103, -1.0035, -0.6669, -0.9120,\n",
" -0.9146, -1.1079, -0.8379, -0.9123, -0.5831, -1.6515, -0.9385, -1.0699,\n",
" -1.1498, -0.7861, -0.8942, -1.0452, -1.0064, -0.9116, -1.1150, -0.7801,\n",
" -1.0283, -1.0296, -1.0927, -0.7945, -1.0705, -1.3215, -1.2510, -0.9158,\n",
" -0.9377, -0.7314, -0.9773, -1.1910, -1.0539, -1.1439, -1.0784, -0.8543,\n",
" -1.1323, -1.3193, -0.8014, -0.7318, -0.5805, -0.8239, -1.1228, -1.0473,\n",
" -0.8206, -0.6544, -1.2654, -1.0757, -0.5389, -0.9908, -0.7894, -0.7463,\n",
" -1.0391, -0.8023, -0.8568, -1.2414, -0.9595, -1.1151, -0.9689, -1.1145,\n",
" -0.6853, -0.7547, -1.1000, -0.9054, -1.2262, -1.1359, -1.0174, -0.3782,\n",
" -0.8056, -1.1828, -0.8426, -0.9958, -0.9495, -1.2745, -0.7039, -0.5893,\n",
" -0.5648, -1.0538, -0.6724, -0.6340, -0.5070, -1.0956, -1.0957, -0.6823,\n",
" -0.5258, -0.5777, -0.9268, -0.5280, -0.5989, -0.8364, -0.7439, -0.7619,\n",
" -1.0159, -1.0627, -0.9416, -0.6270, -0.4307, -0.8575, -1.0748, -0.5529,\n",
" -0.9339, -0.7416, -0.6674, -0.3178, -0.6815, -0.7499, -0.6359, -0.8157,\n",
" -0.5582, -0.5083, -0.4527, -0.8350, -0.6317, -0.4338, -0.4875, -0.4046,\n",
" -0.3166, -0.3413, -0.4722, -0.7010, -1.2025, -0.2133, -0.3133, -0.4160,\n",
" -0.6681, -0.8990, -0.5464, -0.4518, -0.4402, -0.5246, -0.4561, -0.6747,\n",
" -0.1833, -0.4466, -0.4671, -0.5509, -0.6235, -0.2100, -0.3368, -0.3083,\n",
" -0.5129, -0.2880, -0.4075, -0.2784, 0.0631, -0.4355, -0.4237, -0.2578,\n",
" -0.1380, -0.5085, 0.1004, -0.1426, -0.2537, -0.1756, -0.2135, -0.1898,\n",
" -0.2947, -0.3934, -0.3412, -0.3343, -0.1450, -0.3178, -0.2156, -0.3232,\n",
" -0.3691, -0.2711, 0.1086, -0.2257, -0.0752, -0.0339, -0.0636, 0.0626,\n",
" -0.1460, 0.0792, 0.1529, 0.4743, 0.0343, -0.0158, -0.1255, -0.4698,\n",
" -0.0489, 0.2622, 0.0619, -0.2243, -0.1318, 0.0214, 0.2690, 0.0497,\n",
" 0.3451, -0.1116, 0.0173, 0.0708, 0.4135, 0.3188, 0.4808, -0.0340,\n",
" 0.4786, 0.4896, 0.1077, 0.3500, 0.1309, 0.1398, 0.1943, 0.1651,\n",
" 0.3227, 0.5541, 0.2688, 0.1892, 0.2509, 0.2078, -0.0140, 0.2443,\n",
" 0.3204, 0.5485, 0.4234, 0.3135, 0.4633, 0.0029, 0.2174, 0.6879,\n",
" 0.5089, 0.2479, 0.8608, 0.4307, 0.6205, 0.3482, 0.6469, 0.4475,\n",
" 0.6595, 0.3450, 0.3781, 0.4451, 0.1883, 0.6707, 0.8667, 0.5218,\n",
" 0.4004, 0.5271, 0.6446, 0.7222, 0.5722, 0.7676, 0.6824, 0.1981,\n",
" 0.8089, 0.6296, 0.6748, 0.7515, 0.5103, 0.9052, 0.8405, 0.9092,\n",
" 0.6918, 0.6477, 0.5402, 0.6477, 0.4210, 0.6973, 0.6019, 0.5364,\n",
" 0.8134, 0.5607, 0.7096, 0.5894, 0.3866, 1.0600, 0.7347, 0.8129,\n",
" 1.2088, 0.8825, 0.7179, 1.0115, 0.7013, 1.0128, 0.9747, 1.2759,\n",
" 0.7655, 1.0094, 0.7805, 0.6091, 1.2033, 0.9678, 0.8219, 0.8157,\n",
" 0.9188, 0.7436, 0.8910, 0.7291, 0.9559, 0.9389, 1.2030, 1.0495,\n",
" 1.1811, 0.8884, 0.8390, 0.9894, 0.9238, 0.7628, 0.5421, 1.5147,\n",
" 0.6971, 0.6740, 0.8342, 0.6554, 0.7455, 0.6916, 1.2706, 1.1277,\n",
" 0.9248, 0.9976, 1.2404, 0.6919, 1.3449, 1.1243, 1.0492, 0.9266,\n",
" 1.1194, 1.0304, 1.1323, 1.2372, 0.8300, 1.1916, 1.0923, 0.8313,\n",
" 0.8572, 1.1128, 1.0047, 1.1544, 0.9745, 1.0503, 0.9171, 0.8073,\n",
" 1.2056, 1.0976, 0.9910, 1.1834, 1.1389, 0.9142, 0.9367, 1.0121,\n",
" 0.7704, 1.0558, 0.7306, 0.8117, 0.7061, 1.2315, 0.9015, 0.9339,\n",
" 0.5016, 0.9227, 1.2568, 0.9444, 1.1198, 0.9431, 1.0997, 1.3078,\n",
" 0.8336, 1.2692, 0.8424, 0.8702, 1.4820, 1.3248, 0.9324, 0.6538,\n",
" 1.2011, 1.0170, 0.7863, 1.0178, 0.6519, 0.5970, 0.9052, 0.6846,\n",
" 0.7737, 0.9104, 0.8439, 1.0066, 1.0787, 0.9661, 0.9923, 0.7922,\n",
" 0.8316, 0.9553, 0.9952, 0.8680, 1.1226, 0.8213, 0.9151, 0.7748,\n",
" 0.9953, 0.7773, 0.7916, 0.7321, 0.9130, 1.1433, 0.7060, 0.8066,\n",
" 0.8709, 0.7426, 0.8718, 1.0973, 0.7097, 0.9438, 0.8164, 0.8013,\n",
" 0.6236, 0.7180, 0.9188, 0.8016, 0.9741, 0.6271, 0.5747, 0.8007,\n",
" 0.7754, 0.4877, 0.4746, 0.8654, 0.4743, 0.9015, 0.8082, 0.5449,\n",
" 0.9299, 0.2003, 0.5466, 0.4355, 0.7900, 0.4343, 0.7224, 0.8585,\n",
" 0.5714, 0.5306, 0.6594, 0.0640, 0.3203, 0.5463, 0.5048, 0.1935,\n",
" 0.2883, 0.6778, 0.5014, 0.5235, 0.5718, 0.4587, 0.2808, 0.4073,\n",
" 0.8632, 0.8862, 0.5757, 0.3372, 0.2566, 0.7858, 0.3713, 0.1589,\n",
" 0.3243, 0.4270, 0.0565, 0.2885, 0.3257, 0.2196, 0.3159, 0.2361,\n",
" 0.1087, 0.2224, 0.2633, 0.5037, 0.1980, 0.1530, 0.2780, -0.1399,\n",
" 0.5331, 0.3530, 0.3342, 0.2098, -0.0165, 0.1318, 0.4510, -0.1959,\n",
" 0.0966, 0.0789, 0.3381, -0.1917, 0.1518, 0.3640, 0.0956, 0.2535,\n",
" -0.3988, -0.3479, 0.3864, -0.2639, -0.2368, 0.0258, 0.2441, 0.0687,\n",
" 0.0457, 0.2286, -0.0947, -0.1189, 0.1360, -0.0990, -0.2447, 0.2135,\n",
" -0.1830, -0.4583, -0.1795, -0.1361, -0.0553, -0.2864, -0.2307, -0.4651,\n",
" -0.1889, -0.3185, -0.5318, -0.3012, 0.0062, 0.1046, -0.2321, -0.2945,\n",
" -0.0242, -0.0586, -0.2307, -0.2479, -0.0382, -0.1509, -0.5055, -0.3759,\n",
" 0.2139, -0.2129, -0.3605, -0.5222, -0.6530, -0.6716, -0.4330, -0.2577,\n",
" -0.2672, -0.1297, -0.9203, -0.5832, -0.2640, -0.4996, -0.2625, -0.4407,\n",
" -0.8864, -0.2508, -0.4827, -0.3131, -0.2570, -0.7116, -0.5357, -0.7074]),\n",
" tensor([[-0.0948, 0.2143, -0.2523, -0.1235],\n",
" [ 0.2143, -0.2523, -0.1235, -0.1826],\n",
" [-0.2523, -0.1235, -0.1826, 0.1189],\n",
2026-03-25 15:07:28 +00:00
" ...,\n",
2026-04-22 07:23:35 +00:00
" [-0.2508, -0.4827, -0.3131, -0.2570],\n",
" [-0.4827, -0.3131, -0.2570, -0.7116],\n",
" [-0.3131, -0.2570, -0.7116, -0.5357]]),\n",
" tensor([[-0.1826],\n",
" [ 0.1189],\n",
" [-0.1963],\n",
" [ 0.2347],\n",
" [ 0.1456],\n",
" [-0.1118],\n",
" [ 0.3787],\n",
" [ 0.3861],\n",
" [ 0.2881],\n",
" [ 0.1958],\n",
" [ 0.0402],\n",
" [ 0.0816],\n",
" [ 0.4793],\n",
" [ 0.0351],\n",
" [ 0.2378],\n",
" [ 0.1459],\n",
" [ 0.1108],\n",
" [ 0.2544],\n",
" [-0.0127],\n",
" [ 0.0733],\n",
" [ 0.3156],\n",
" [ 0.0257],\n",
" [ 0.3207],\n",
" [ 0.3259],\n",
" [ 0.3693],\n",
" [ 0.0584],\n",
" [ 0.1730],\n",
" [ 0.3100],\n",
" [ 0.2328],\n",
" [ 0.0525],\n",
" [ 0.4465],\n",
" [ 0.1293],\n",
" [ 0.4330],\n",
" [ 0.3193],\n",
" [ 0.4704],\n",
" [ 0.5238],\n",
" [ 0.5323],\n",
" [ 0.4887],\n",
" [ 0.0831],\n",
" [ 0.5924],\n",
" [ 0.6972],\n",
" [ 0.3490],\n",
" [ 0.7476],\n",
" [ 0.6039],\n",
" [ 0.9995],\n",
" [ 0.1455],\n",
" [ 0.1417],\n",
" [ 0.5968],\n",
" [ 0.6673],\n",
" [ 0.3425],\n",
" [ 0.7685],\n",
" [ 0.4904],\n",
" [ 0.2203],\n",
" [ 0.2109],\n",
" [ 0.4600],\n",
" [ 0.5055],\n",
" [ 0.3558],\n",
" [ 0.7020],\n",
" [ 0.7435],\n",
" [ 0.4713],\n",
" [ 0.4318],\n",
" [ 0.5861],\n",
" [ 0.3592],\n",
" [ 0.7750],\n",
" [ 0.6640],\n",
" [ 0.7908],\n",
" [ 0.2776],\n",
" [ 0.5868],\n",
" [ 0.6283],\n",
" [ 0.3461],\n",
" [ 0.6308],\n",
" [ 0.7547],\n",
" [ 0.5564],\n",
" [ 0.7181],\n",
" [ 0.7852],\n",
" [ 0.7823],\n",
" [ 0.7238],\n",
" [ 0.9294],\n",
" [ 0.9023],\n",
" [ 0.8100],\n",
" [ 0.5561],\n",
" [ 0.7124],\n",
" [ 1.1566],\n",
" [ 0.7628],\n",
" [ 0.9630],\n",
" [ 0.4425],\n",
" [ 1.0628],\n",
" [ 0.7014],\n",
" [ 0.4439],\n",
" [ 0.7286],\n",
" [ 0.8099],\n",
" [ 0.5786],\n",
" [ 1.0638],\n",
" [ 0.9519],\n",
" [ 0.8388],\n",
" [ 1.2088],\n",
" [ 0.9172],\n",
" [ 0.7014],\n",
" [ 0.5667],\n",
" [ 0.6040],\n",
" [ 0.5549],\n",
" [ 0.7959],\n",
" [ 0.9167],\n",
" [ 0.9074],\n",
" [ 0.6108],\n",
" [ 0.8999],\n",
" [ 0.9197],\n",
" [ 0.8539],\n",
" [ 0.6566],\n",
" [ 0.9941],\n",
" [ 0.6902],\n",
" [ 0.8782],\n",
" [ 1.4898],\n",
" [ 0.9888],\n",
" [ 1.1911],\n",
" [ 0.5683],\n",
" [ 0.8868],\n",
" [ 0.7122],\n",
" [ 0.8960],\n",
" [ 1.1454],\n",
" [ 1.2660],\n",
" [ 1.0001],\n",
" [ 0.6582],\n",
" [ 0.9706],\n",
" [ 1.0110],\n",
" [ 1.0355],\n",
" [ 0.9761],\n",
" [ 0.9439],\n",
" [ 1.0824],\n",
" [ 1.4095],\n",
" [ 1.2544],\n",
" [ 0.6541],\n",
" [ 1.0486],\n",
" [ 1.0638],\n",
" [ 0.9000],\n",
" [ 0.9835],\n",
" [ 1.2558],\n",
" [ 1.1702],\n",
" [ 0.8466],\n",
" [ 0.7696],\n",
" [ 1.2446],\n",
" [ 1.1460],\n",
" [ 0.9258],\n",
" [ 0.8150],\n",
" [ 1.1086],\n",
" [ 0.9475],\n",
" [ 0.9675],\n",
" [ 0.7330],\n",
" [ 1.1263],\n",
" [ 1.1718],\n",
" [ 0.9413],\n",
" [ 1.0272],\n",
" [ 0.7733],\n",
" [ 0.9831],\n",
" [ 0.8759],\n",
" [ 0.7970],\n",
" [ 0.6360],\n",
" [ 1.1815],\n",
" [ 0.9689],\n",
" [ 0.6976],\n",
" [ 0.9265],\n",
" [ 0.8338],\n",
" [ 0.7960],\n",
" [ 0.7705],\n",
" [ 1.2601],\n",
" [ 1.2775],\n",
" [ 0.7706],\n",
" [ 1.0216],\n",
" [ 1.1916],\n",
" [ 0.8603],\n",
" [ 0.9864],\n",
" [ 1.0777],\n",
" [ 0.8930],\n",
" [ 1.0063],\n",
" [ 0.8376],\n",
" [ 0.9923],\n",
" [ 0.8081],\n",
" [ 0.8020],\n",
" [ 1.1461],\n",
" [ 1.1018],\n",
" [ 0.8931],\n",
" [ 1.0005],\n",
" [ 0.8635],\n",
" [ 0.7197],\n",
" [ 1.2577],\n",
" [ 1.0584],\n",
" [ 1.4032],\n",
" [ 0.8911],\n",
" [ 1.1415],\n",
" [ 0.8241],\n",
" [ 0.7946],\n",
" [ 1.0221],\n",
" [ 0.8792],\n",
" [ 0.7211],\n",
" [ 1.1821],\n",
" [ 0.8079],\n",
" [ 0.8926],\n",
" [ 1.0765],\n",
" [ 0.9949],\n",
" [ 0.9159],\n",
" [ 0.7329],\n",
" [ 0.9950],\n",
" [ 0.7491],\n",
" [ 0.8750],\n",
" [ 1.1863],\n",
" [ 1.0095],\n",
" [ 0.8046],\n",
" [ 0.6274],\n",
" [ 0.8936],\n",
" [ 0.7595],\n",
" [ 0.8423],\n",
" [ 0.8655],\n",
" [ 0.6918],\n",
" [ 0.7347],\n",
" [ 1.1179],\n",
" [ 0.5931],\n",
" [ 0.8745],\n",
" [ 0.4858],\n",
" [ 0.9338],\n",
" [ 1.1382],\n",
" [ 0.6084],\n",
" [ 0.9479],\n",
" [ 0.8726],\n",
" [ 0.7202],\n",
" [ 0.9596],\n",
" [ 0.4386],\n",
" [ 1.2525],\n",
" [ 0.5120],\n",
" [ 0.7222],\n",
" [ 0.6566],\n",
" [ 0.8965],\n",
" [ 0.7545],\n",
" [ 1.1104],\n",
" [ 0.6634],\n",
" [ 0.5654],\n",
" [ 1.0095],\n",
" [ 0.6558],\n",
" [ 0.7260],\n",
" [ 0.8515],\n",
" [ 0.3430],\n",
" [ 0.7703],\n",
" [ 0.3753],\n",
" [ 0.4490],\n",
" [ 0.4373],\n",
" [ 0.8283],\n",
" [ 0.5455],\n",
" [ 0.7584],\n",
" [ 0.8197],\n",
" [ 0.4781],\n",
" [ 0.3350],\n",
" [ 0.6714],\n",
" [ 0.3969],\n",
" [ 0.7131],\n",
" [ 0.5609],\n",
" [ 0.4327],\n",
" [ 0.4293],\n",
" [ 0.3552],\n",
" [ 0.5445],\n",
" [ 0.5609],\n",
" [ 0.4110],\n",
" [ 0.8525],\n",
" [ 0.3402],\n",
" [ 0.4064],\n",
" [ 0.5172],\n",
" [ 0.5845],\n",
" [ 0.6185],\n",
" [ 0.4719],\n",
" [ 0.9092],\n",
" [ 0.6964],\n",
" [ 0.7267],\n",
" [ 0.6934],\n",
" [ 0.4337],\n",
" [ 0.2031],\n",
" [ 0.0898],\n",
" [ 0.4377],\n",
" [ 0.4203],\n",
" [ 0.2855],\n",
" [ 0.4673],\n",
" [ 0.6029],\n",
" [ 0.4368],\n",
" [ 0.1521],\n",
" [-0.0606],\n",
" [ 0.2532],\n",
" [ 0.4365],\n",
" [ 0.2989],\n",
" [ 0.0743],\n",
" [ 0.2734],\n",
" [ 0.1060],\n",
" [ 0.5543],\n",
" [-0.1211],\n",
" [ 0.0968],\n",
" [ 0.4911],\n",
" [ 0.5107],\n",
" [ 0.4583],\n",
" [ 0.2777],\n",
" [-0.0513],\n",
" [ 0.1437],\n",
" [ 0.0548],\n",
" [ 0.0933],\n",
" [ 0.1172],\n",
" [ 0.0718],\n",
" [ 0.4027],\n",
" [ 0.1805],\n",
" [ 0.0869],\n",
" [-0.3066],\n",
" [ 0.5615],\n",
" [-0.2721],\n",
" [-0.2765],\n",
" [ 0.0850],\n",
" [-0.1473],\n",
" [-0.1622],\n",
" [-0.1335],\n",
" [ 0.1328],\n",
" [ 0.0703],\n",
" [-0.6712],\n",
" [-0.1121],\n",
" [-0.1208],\n",
" [ 0.0092],\n",
" [-0.0805],\n",
" [-0.2017],\n",
" [ 0.2339],\n",
" [-0.3533],\n",
" [-0.4598],\n",
" [ 0.0620],\n",
" [-0.5254],\n",
" [ 0.0197],\n",
" [-0.0593],\n",
" [-0.1914],\n",
" [-0.4259],\n",
" [-0.0115],\n",
" [-0.5406],\n",
" [ 0.0137],\n",
" [-0.4240],\n",
" [-0.2822],\n",
" [-0.0796],\n",
" [-0.3495],\n",
" [-0.4475],\n",
" [ 0.2453],\n",
" [-0.3729],\n",
" [-0.4086],\n",
" [-0.2618],\n",
" [-0.4539],\n",
" [-0.6140],\n",
" [-0.2483],\n",
" [-0.4165],\n",
" [-0.3736],\n",
" [-0.0737],\n",
" [-0.0212],\n",
" [-0.3644],\n",
" [-0.0472],\n",
" [-0.4087],\n",
" [-0.6794],\n",
" [-0.5921],\n",
" [-0.5632],\n",
" [-0.5971],\n",
" [-0.1935],\n",
" [-0.9543],\n",
" [-0.7976],\n",
" [-0.3485],\n",
" [-0.9538],\n",
" [-0.6171],\n",
" [-0.7755],\n",
" [-0.4651],\n",
" [-0.8194],\n",
" [-0.3005],\n",
" [-0.5191],\n",
" [-0.5902],\n",
" [-0.2464],\n",
" [-0.6908],\n",
" [-0.5054],\n",
" [-0.5528],\n",
" [-1.1089],\n",
" [-0.7206],\n",
" [-0.8067],\n",
" [-0.6780],\n",
" [-0.2981],\n",
" [-0.6683],\n",
" [-0.4324],\n",
" [-0.8497],\n",
" [-0.5928],\n",
" [-0.7203],\n",
" [-0.3751],\n",
" [-0.7423],\n",
" [-0.4109],\n",
" [-0.7345],\n",
" [-0.6653],\n",
" [-0.5752],\n",
" [-0.5198],\n",
" [-0.7046],\n",
" [-1.1754],\n",
" [-0.9447],\n",
" [-0.7304],\n",
" [-0.6510],\n",
" [-0.5954],\n",
" [-0.7592],\n",
" [-0.5285],\n",
" [-0.4249],\n",
" [-0.7993],\n",
" [-1.3758],\n",
" [-0.6218],\n",
" [-1.0691],\n",
" [-0.5775],\n",
" [-0.8174],\n",
" [-0.7021],\n",
" [-0.7784],\n",
" [-0.7553],\n",
" [-1.2137],\n",
" [-0.7302],\n",
" [-0.7253],\n",
" [-0.6819],\n",
" [-1.3077],\n",
" [-1.3472],\n",
" [-0.7104],\n",
" [-0.8387],\n",
" [-0.5973],\n",
" [-0.8619],\n",
" [-1.1138],\n",
" [-1.1314],\n",
" [-0.9765],\n",
" [-1.2121],\n",
" [-0.8168],\n",
" [-0.7763],\n",
" [-1.2988],\n",
" [-0.9282],\n",
" [-1.1715],\n",
" [-0.7216],\n",
" [-0.7182],\n",
" [-0.2972],\n",
" [-0.7471],\n",
" [-1.0089],\n",
" [-1.1431],\n",
" [-1.0396],\n",
" [-1.0381],\n",
" [-0.5979],\n",
" [-0.7363],\n",
" [-0.7808],\n",
" [-0.9106],\n",
" [-1.1468],\n",
" [-1.1357],\n",
" [-0.6406],\n",
" [-0.9603],\n",
" [-1.2653],\n",
" [-1.5958],\n",
" [-1.0592],\n",
" [-0.9698],\n",
" [-0.8252],\n",
" [-1.2515],\n",
" [-1.0474],\n",
" [-1.1103],\n",
" [-1.0035],\n",
" [-0.6669],\n",
" [-0.9120],\n",
" [-0.9146],\n",
" [-1.1079],\n",
" [-0.8379],\n",
" [-0.9123],\n",
" [-0.5831],\n",
" [-1.6515],\n",
" [-0.9385],\n",
" [-1.0699],\n",
" [-1.1498],\n",
" [-0.7861],\n",
" [-0.8942],\n",
" [-1.0452],\n",
" [-1.0064],\n",
" [-0.9116],\n",
" [-1.1150],\n",
" [-0.7801],\n",
" [-1.0283],\n",
" [-1.0296],\n",
" [-1.0927],\n",
" [-0.7945],\n",
" [-1.0705],\n",
" [-1.3215],\n",
" [-1.2510],\n",
" [-0.9158],\n",
" [-0.9377],\n",
" [-0.7314],\n",
" [-0.9773],\n",
" [-1.1910],\n",
" [-1.0539],\n",
" [-1.1439],\n",
" [-1.0784],\n",
" [-0.8543],\n",
" [-1.1323],\n",
" [-1.3193],\n",
" [-0.8014],\n",
" [-0.7318],\n",
" [-0.5805],\n",
" [-0.8239],\n",
" [-1.1228],\n",
" [-1.0473],\n",
" [-0.8206],\n",
" [-0.6544],\n",
" [-1.2654],\n",
" [-1.0757],\n",
" [-0.5389],\n",
" [-0.9908],\n",
" [-0.7894],\n",
" [-0.7463],\n",
" [-1.0391],\n",
" [-0.8023],\n",
" [-0.8568],\n",
" [-1.2414],\n",
" [-0.9595],\n",
" [-1.1151],\n",
" [-0.9689],\n",
" [-1.1145],\n",
" [-0.6853],\n",
" [-0.7547],\n",
" [-1.1000],\n",
" [-0.9054],\n",
" [-1.2262],\n",
" [-1.1359],\n",
" [-1.0174],\n",
" [-0.3782],\n",
" [-0.8056],\n",
" [-1.1828],\n",
" [-0.8426],\n",
" [-0.9958],\n",
" [-0.9495],\n",
" [-1.2745],\n",
" [-0.7039],\n",
" [-0.5893],\n",
" [-0.5648],\n",
" [-1.0538],\n",
" [-0.6724],\n",
" [-0.6340],\n",
" [-0.5070],\n",
" [-1.0956],\n",
" [-1.0957],\n",
" [-0.6823],\n",
" [-0.5258],\n",
" [-0.5777],\n",
" [-0.9268],\n",
" [-0.5280],\n",
" [-0.5989],\n",
" [-0.8364],\n",
" [-0.7439],\n",
" [-0.7619],\n",
" [-1.0159],\n",
" [-1.0627],\n",
" [-0.9416],\n",
" [-0.6270],\n",
" [-0.4307],\n",
" [-0.8575],\n",
" [-1.0748],\n",
" [-0.5529],\n",
" [-0.9339],\n",
" [-0.7416],\n",
" [-0.6674],\n",
" [-0.3178],\n",
" [-0.6815],\n",
" [-0.7499],\n",
" [-0.6359],\n",
" [-0.8157],\n",
" [-0.5582],\n",
" [-0.5083],\n",
" [-0.4527],\n",
" [-0.8350],\n",
" [-0.6317],\n",
" [-0.4338],\n",
" [-0.4875],\n",
" [-0.4046],\n",
" [-0.3166],\n",
" [-0.3413],\n",
" [-0.4722],\n",
" [-0.7010],\n",
" [-1.2025],\n",
" [-0.2133],\n",
" [-0.3133],\n",
" [-0.4160],\n",
" [-0.6681],\n",
" [-0.8990],\n",
" [-0.5464],\n",
" [-0.4518],\n",
" [-0.4402],\n",
" [-0.5246],\n",
" [-0.4561],\n",
" [-0.6747],\n",
" [-0.1833],\n",
" [-0.4466],\n",
" [-0.4671],\n",
" [-0.5509],\n",
" [-0.6235],\n",
" [-0.2100],\n",
" [-0.3368],\n",
" [-0.3083],\n",
" [-0.5129],\n",
" [-0.2880],\n",
" [-0.4075],\n",
" [-0.2784],\n",
" [ 0.0631],\n",
" [-0.4355],\n",
" [-0.4237],\n",
" [-0.2578],\n",
" [-0.1380],\n",
" [-0.5085],\n",
" [ 0.1004],\n",
" [-0.1426],\n",
" [-0.2537],\n",
" [-0.1756],\n",
" [-0.2135],\n",
" [-0.1898],\n",
" [-0.2947],\n",
" [-0.3934],\n",
" [-0.3412],\n",
" [-0.3343],\n",
" [-0.1450],\n",
" [-0.3178],\n",
" [-0.2156],\n",
" [-0.3232],\n",
" [-0.3691],\n",
" [-0.2711],\n",
" [ 0.1086],\n",
" [-0.2257],\n",
" [-0.0752],\n",
" [-0.0339],\n",
" [-0.0636],\n",
" [ 0.0626],\n",
" [-0.1460],\n",
" [ 0.0792],\n",
" [ 0.1529],\n",
" [ 0.4743],\n",
" [ 0.0343],\n",
" [-0.0158],\n",
" [-0.1255],\n",
" [-0.4698],\n",
" [-0.0489],\n",
" [ 0.2622],\n",
" [ 0.0619],\n",
" [-0.2243],\n",
" [-0.1318],\n",
" [ 0.0214],\n",
" [ 0.2690],\n",
" [ 0.0497],\n",
" [ 0.3451],\n",
" [-0.1116],\n",
" [ 0.0173],\n",
" [ 0.0708],\n",
" [ 0.4135],\n",
" [ 0.3188],\n",
" [ 0.4808],\n",
" [-0.0340],\n",
" [ 0.4786],\n",
" [ 0.4896],\n",
" [ 0.1077],\n",
" [ 0.3500],\n",
" [ 0.1309],\n",
" [ 0.1398],\n",
" [ 0.1943],\n",
" [ 0.1651],\n",
" [ 0.3227],\n",
" [ 0.5541],\n",
" [ 0.2688],\n",
" [ 0.1892],\n",
" [ 0.2509],\n",
" [ 0.2078],\n",
" [-0.0140],\n",
" [ 0.2443],\n",
" [ 0.3204],\n",
" [ 0.5485],\n",
" [ 0.4234],\n",
" [ 0.3135],\n",
" [ 0.4633],\n",
" [ 0.0029],\n",
" [ 0.2174],\n",
" [ 0.6879],\n",
" [ 0.5089],\n",
" [ 0.2479],\n",
" [ 0.8608],\n",
" [ 0.4307],\n",
" [ 0.6205],\n",
" [ 0.3482],\n",
" [ 0.6469],\n",
" [ 0.4475],\n",
" [ 0.6595],\n",
" [ 0.3450],\n",
" [ 0.3781],\n",
" [ 0.4451],\n",
" [ 0.1883],\n",
" [ 0.6707],\n",
" [ 0.8667],\n",
" [ 0.5218],\n",
" [ 0.4004],\n",
" [ 0.5271],\n",
" [ 0.6446],\n",
" [ 0.7222],\n",
" [ 0.5722],\n",
" [ 0.7676],\n",
" [ 0.6824],\n",
" [ 0.1981],\n",
" [ 0.8089],\n",
" [ 0.6296],\n",
" [ 0.6748],\n",
" [ 0.7515],\n",
" [ 0.5103],\n",
" [ 0.9052],\n",
" [ 0.8405],\n",
" [ 0.9092],\n",
" [ 0.6918],\n",
" [ 0.6477],\n",
" [ 0.5402],\n",
" [ 0.6477],\n",
" [ 0.4210],\n",
" [ 0.6973],\n",
" [ 0.6019],\n",
" [ 0.5364],\n",
" [ 0.8134],\n",
" [ 0.5607],\n",
" [ 0.7096],\n",
" [ 0.5894],\n",
" [ 0.3866],\n",
" [ 1.0600],\n",
" [ 0.7347],\n",
" [ 0.8129],\n",
" [ 1.2088],\n",
" [ 0.8825],\n",
" [ 0.7179],\n",
" [ 1.0115],\n",
" [ 0.7013],\n",
" [ 1.0128],\n",
" [ 0.9747],\n",
" [ 1.2759],\n",
" [ 0.7655],\n",
" [ 1.0094],\n",
" [ 0.7805],\n",
" [ 0.6091],\n",
" [ 1.2033],\n",
" [ 0.9678],\n",
" [ 0.8219],\n",
" [ 0.8157],\n",
" [ 0.9188],\n",
" [ 0.7436],\n",
" [ 0.8910],\n",
" [ 0.7291],\n",
" [ 0.9559],\n",
" [ 0.9389],\n",
" [ 1.2030],\n",
" [ 1.0495],\n",
" [ 1.1811],\n",
" [ 0.8884],\n",
" [ 0.8390],\n",
" [ 0.9894],\n",
" [ 0.9238],\n",
" [ 0.7628],\n",
" [ 0.5421],\n",
" [ 1.5147],\n",
" [ 0.6971],\n",
" [ 0.6740],\n",
" [ 0.8342],\n",
" [ 0.6554],\n",
" [ 0.7455],\n",
" [ 0.6916],\n",
" [ 1.2706],\n",
" [ 1.1277],\n",
" [ 0.9248],\n",
" [ 0.9976],\n",
" [ 1.2404],\n",
" [ 0.6919],\n",
" [ 1.3449],\n",
" [ 1.1243],\n",
" [ 1.0492],\n",
" [ 0.9266],\n",
" [ 1.1194],\n",
" [ 1.0304],\n",
" [ 1.1323],\n",
" [ 1.2372],\n",
" [ 0.8300],\n",
" [ 1.1916],\n",
" [ 1.0923],\n",
" [ 0.8313],\n",
" [ 0.8572],\n",
" [ 1.1128],\n",
" [ 1.0047],\n",
" [ 1.1544],\n",
" [ 0.9745],\n",
" [ 1.0503],\n",
" [ 0.9171],\n",
" [ 0.8073],\n",
" [ 1.2056],\n",
" [ 1.0976],\n",
" [ 0.9910],\n",
" [ 1.1834],\n",
" [ 1.1389],\n",
" [ 0.9142],\n",
" [ 0.9367],\n",
" [ 1.0121],\n",
" [ 0.7704],\n",
" [ 1.0558],\n",
" [ 0.7306],\n",
" [ 0.8117],\n",
" [ 0.7061],\n",
" [ 1.2315],\n",
" [ 0.9015],\n",
" [ 0.9339],\n",
" [ 0.5016],\n",
" [ 0.9227],\n",
" [ 1.2568],\n",
" [ 0.9444],\n",
" [ 1.1198],\n",
" [ 0.9431],\n",
" [ 1.0997],\n",
" [ 1.3078],\n",
" [ 0.8336],\n",
" [ 1.2692],\n",
" [ 0.8424],\n",
" [ 0.8702],\n",
" [ 1.4820],\n",
" [ 1.3248],\n",
" [ 0.9324],\n",
" [ 0.6538],\n",
" [ 1.2011],\n",
" [ 1.0170],\n",
" [ 0.7863],\n",
" [ 1.0178],\n",
" [ 0.6519],\n",
" [ 0.5970],\n",
" [ 0.9052],\n",
" [ 0.6846],\n",
" [ 0.7737],\n",
" [ 0.9104],\n",
" [ 0.8439],\n",
" [ 1.0066],\n",
" [ 1.0787],\n",
" [ 0.9661],\n",
" [ 0.9923],\n",
" [ 0.7922],\n",
" [ 0.8316],\n",
" [ 0.9553],\n",
" [ 0.9952],\n",
" [ 0.8680],\n",
" [ 1.1226],\n",
" [ 0.8213],\n",
" [ 0.9151],\n",
" [ 0.7748],\n",
" [ 0.9953],\n",
" [ 0.7773],\n",
" [ 0.7916],\n",
" [ 0.7321],\n",
" [ 0.9130],\n",
" [ 1.1433],\n",
" [ 0.7060],\n",
" [ 0.8066],\n",
" [ 0.8709],\n",
" [ 0.7426],\n",
" [ 0.8718],\n",
" [ 1.0973],\n",
" [ 0.7097],\n",
" [ 0.9438],\n",
" [ 0.8164],\n",
" [ 0.8013],\n",
" [ 0.6236],\n",
" [ 0.7180],\n",
" [ 0.9188],\n",
" [ 0.8016],\n",
" [ 0.9741],\n",
" [ 0.6271],\n",
" [ 0.5747],\n",
" [ 0.8007],\n",
" [ 0.7754],\n",
" [ 0.4877],\n",
" [ 0.4746],\n",
" [ 0.8654],\n",
" [ 0.4743],\n",
" [ 0.9015],\n",
" [ 0.8082],\n",
" [ 0.5449],\n",
" [ 0.9299],\n",
" [ 0.2003],\n",
" [ 0.5466],\n",
" [ 0.4355],\n",
" [ 0.7900],\n",
" [ 0.4343],\n",
" [ 0.7224],\n",
" [ 0.8585],\n",
" [ 0.5714],\n",
" [ 0.5306],\n",
" [ 0.6594],\n",
" [ 0.0640],\n",
" [ 0.3203],\n",
" [ 0.5463],\n",
" [ 0.5048],\n",
" [ 0.1935],\n",
" [ 0.2883],\n",
" [ 0.6778],\n",
" [ 0.5014],\n",
" [ 0.5235],\n",
" [ 0.5718],\n",
" [ 0.4587],\n",
" [ 0.2808],\n",
" [ 0.4073],\n",
" [ 0.8632],\n",
" [ 0.8862],\n",
" [ 0.5757],\n",
" [ 0.3372],\n",
" [ 0.2566],\n",
" [ 0.7858],\n",
" [ 0.3713],\n",
" [ 0.1589],\n",
" [ 0.3243],\n",
" [ 0.4270],\n",
" [ 0.0565],\n",
" [ 0.2885],\n",
" [ 0.3257],\n",
" [ 0.2196],\n",
" [ 0.3159],\n",
" [ 0.2361],\n",
" [ 0.1087],\n",
" [ 0.2224],\n",
" [ 0.2633],\n",
" [ 0.5037],\n",
" [ 0.1980],\n",
" [ 0.1530],\n",
" [ 0.2780],\n",
" [-0.1399],\n",
" [ 0.5331],\n",
" [ 0.3530],\n",
" [ 0.3342],\n",
" [ 0.2098],\n",
" [-0.0165],\n",
" [ 0.1318],\n",
" [ 0.4510],\n",
" [-0.1959],\n",
" [ 0.0966],\n",
" [ 0.0789],\n",
" [ 0.3381],\n",
" [-0.1917],\n",
" [ 0.1518],\n",
" [ 0.3640],\n",
" [ 0.0956],\n",
" [ 0.2535],\n",
" [-0.3988],\n",
" [-0.3479],\n",
" [ 0.3864],\n",
" [-0.2639],\n",
" [-0.2368],\n",
" [ 0.0258],\n",
" [ 0.2441],\n",
" [ 0.0687],\n",
" [ 0.0457],\n",
" [ 0.2286],\n",
" [-0.0947],\n",
" [-0.1189],\n",
" [ 0.1360],\n",
" [-0.0990],\n",
" [-0.2447],\n",
" [ 0.2135],\n",
" [-0.1830],\n",
" [-0.4583],\n",
" [-0.1795],\n",
" [-0.1361],\n",
" [-0.0553],\n",
" [-0.2864],\n",
" [-0.2307],\n",
" [-0.4651],\n",
" [-0.1889],\n",
" [-0.3185],\n",
" [-0.5318],\n",
" [-0.3012],\n",
" [ 0.0062],\n",
" [ 0.1046],\n",
" [-0.2321],\n",
" [-0.2945],\n",
" [-0.0242],\n",
" [-0.0586],\n",
" [-0.2307],\n",
" [-0.2479],\n",
" [-0.0382],\n",
" [-0.1509],\n",
" [-0.5055],\n",
" [-0.3759],\n",
" [ 0.2139],\n",
" [-0.2129],\n",
" [-0.3605],\n",
" [-0.5222],\n",
" [-0.6530],\n",
" [-0.6716],\n",
" [-0.4330],\n",
" [-0.2577],\n",
" [-0.2672],\n",
" [-0.1297],\n",
" [-0.9203],\n",
" [-0.5832],\n",
" [-0.2640],\n",
" [-0.4996],\n",
" [-0.2625],\n",
" [-0.4407],\n",
" [-0.8864],\n",
" [-0.2508],\n",
" [-0.4827],\n",
" [-0.3131],\n",
" [-0.2570],\n",
" [-0.7116],\n",
" [-0.5357],\n",
" [-0.7074]]))"
2026-03-25 15:07:28 +00:00
]
},
2026-04-22 07:23:35 +00:00
"execution_count": 85,
2026-03-25 15:07:28 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 85
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:24.385478317Z",
"start_time": "2026-04-22T07:03:24.319263622Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": [
"# Only the first n_train samples are used for training.\n",
"batch_size = 16\n",
"n_train = 600\n",
"train_iter = d2l.load_array(\n",
"    (features[:n_train], labels[:n_train]), batch_size, is_train=True)\n"
],
"id": "239a596b20d40dec",
"outputs": [],
2026-04-22 07:23:35 +00:00
"execution_count": 86
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:24.452682331Z",
"start_time": "2026-04-22T07:03:24.386760825Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": [
"def init_weights(m):\n",
"    \"\"\"Apply Xavier-uniform initialisation to the weight of an ``nn.Linear`` layer.\n",
"\n",
"    Modules of any other type are left untouched, so the function can be\n",
"    handed to ``Module.apply`` and run over every submodule of a network.\n",
"    Only ``m.weight`` is re-initialised; the bias is not modified.\n",
"    \"\"\"\n",
"    # Exact type match on purpose: subclasses of nn.Linear are skipped too.\n",
"    if type(m) is not nn.Linear:\n",
"        return\n",
"    nn.init.xavier_uniform_(m.weight)"
],
"id": "54d30bd0ee41cb8",
"outputs": [],
2026-04-22 07:23:35 +00:00
"execution_count": 87
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:24.505029692Z",
"start_time": "2026-04-22T07:03:24.454087829Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": [
"def get_net(num_inputs=4, num_hiddens=10):\n",
"    \"\"\"Build a one-hidden-layer MLP that maps `num_inputs` features to one output.\n",
"\n",
"    Args:\n",
"        num_inputs: Width of the input layer. It must equal the number of\n",
"            feature columns fed to the network; the default of 4 is the value\n",
"            that used to be hard-coded here.\n",
"        num_hiddens: Number of units in the hidden layer (default 10).\n",
"\n",
"    Returns:\n",
"        An ``nn.Sequential`` (Linear -> ReLU -> Linear) to which\n",
"        ``init_weights`` has been applied via ``Module.apply``.\n",
"    \"\"\"\n",
"    net = nn.Sequential(nn.Linear(num_inputs, num_hiddens),\n",
"                        nn.ReLU(),\n",
"                        nn.Linear(num_hiddens, 1))\n",
"    net.apply(init_weights)\n",
"    return net\n",
"\n",
"# reduction='none' keeps one loss value per element instead of averaging them.\n",
"loss = nn.MSELoss(reduction='none')"
],
"id": "5d095792e3b3681",
"outputs": [],
2026-04-22 07:23:35 +00:00
"execution_count": 88
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:24.610623781Z",
"start_time": "2026-04-22T07:03:24.507553628Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": [
"def train(net, train_iter, loss, epochs, lr):\n",
"    \"\"\"Fit `net` with Adam and print the training loss after every epoch.\n",
"\n",
"    Args:\n",
"        net: Model to optimise; its parameters are updated in place.\n",
"        train_iter: Iterable yielding ``(X, y)`` mini-batches.\n",
"        loss: Callable ``loss(y_hat, y)``; its result is summed before\n",
"            back-propagation, so an unreduced loss is supported.\n",
"        epochs: Number of passes over `train_iter`.\n",
"        lr: Learning rate passed to ``torch.optim.Adam``.\n",
"    \"\"\"\n",
"    trainer = torch.optim.Adam(net.parameters(), lr)\n",
"    for epoch in range(epochs):\n",
"        for X, y in train_iter:\n",
"            trainer.zero_grad()\n",
"            l = loss(net(X), y)\n",
"            # Reduce to a scalar so backward() can be called on it.\n",
"            l.sum().backward()\n",
"            trainer.step()\n",
"        # Evaluation needs no autograd graph. Disabling it also avoids the\n",
"        # requires_grad-to-scalar UserWarning raised inside d2l's accumulator.\n",
"        with torch.no_grad():\n",
"            epoch_loss = d2l.evaluate_loss(net, train_iter, loss)\n",
"        print(f'epoch {epoch + 1}, loss: {epoch_loss:f}')\n",
"\n",
"net = get_net()\n",
"train(net, train_iter, loss, 5, 0.01)"
],
"id": "5c1dba484e805335",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2026-04-22 07:23:35 +00:00
"epoch 1, loss: 0.082748\n",
"epoch 2, loss: 0.066498\n",
"epoch 3, loss: 0.061828\n",
"epoch 4, loss: 0.059193\n",
"epoch 5, loss: 0.058498\n"
2026-03-25 15:07:28 +00:00
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/yukun/.conda/envs/nn/lib/python3.11/site-packages/d2l/torch.py:3179: UserWarning: Converting a tensor with requires_grad=True to a scalar may lead to unexpected behavior.\n",
"Consider using tensor.detach() first. (Triggered internally at /pytorch/torch/csrc/autograd/generated/python_variable_methods.cpp:836.)\n",
" self.data = [a + float(b) for a, b in zip(self.data, args)]\n"
]
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 89
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:24.713050170Z",
"start_time": "2026-04-22T07:03:24.616031026Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": [
"# One-step-ahead prediction: every row of `features` is a window of observed\n",
"# values, so each prediction is made from real data.\n",
"onestep_preds = net(features)\n",
"# Predictions are plotted from time[tau] onward because the first target is x[tau].\n",
"d2l.plot(\n",
"    [time, time[tau:]],\n",
"    [x.detach().numpy(), onestep_preds.detach().numpy()],\n",
"    'time', 'x',\n",
"    legend=['data', '1-step preds'],\n",
"    xlim=[1, 1000],\n",
"    figsize=(6, 3),\n",
")"
],
"id": "a6efb4a978cd9375",
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 600x300 with 1 Axes>"
],
2026-04-22 07:23:35 +00:00
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"406.885938pt\" height=\"211.07625pt\" viewBox=\"0 0 406.885938 211.07625\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n <metadata>\n <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n <cc:Work>\n <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n <dc:date>2026-04-22T15:03:24.684021</dc:date>\n <dc:format>image/svg+xml</dc:format>\n <dc:creator>\n <cc:Agent>\n <dc:title>Matplotlib v3.7.2, https://matplotlib.org/</dc:title>\n </cc:Agent>\n </dc:creator>\n </cc:Work>\n </rdf:RDF>\n </metadata>\n <defs>\n <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n </defs>\n <g id=\"figure_1\">\n <g id=\"patch_1\">\n <path d=\"M 0 211.07625 \nL 406.885938 211.07625 \nL 406.885938 0 \nL 0 0 \nz\n\"/>\n </g>\n <g id=\"axes_1\">\n <g id=\"patch_2\">\n <path d=\"M 52.160938 173.52 \nL 386.960938 173.52 \nL 386.960938 7.2 \nL 52.160938 7.2 \nz\n\"/>\n </g>\n <g id=\"matplotlib.axis_1\">\n <g id=\"xtick_1\">\n <g id=\"line2d_1\">\n <path d=\"M 118.852829 173.52 \nL 118.852829 7.2 \n\" clip-path=\"url(#pada85592f5)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_2\">\n <defs>\n <path id=\"m22554b858e\" d=\"M 0 0 \nL 0 3.5 \n\" style=\"stroke: #ffffff; stroke-width: 0.8\"/>\n </defs>\n <g>\n <use xlink:href=\"#m22554b858e\" x=\"118.852829\" y=\"173.52\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_1\">\n <!-- 200 -->\n <g style=\"fill: #ffffff\" transform=\"translate(109.309079 188.118438) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-32\" d=\"M 1228 531 \nL 3431 531 
\nL 3431 0 \nL 469 0 \nL 469 531 \nQ 828 903 1448 1529 \nQ 2069 2156 2228 2338 \nQ 2531 2678 2651 2914 \nQ 2772 3150 2772 3378 \nQ 2772 3750 2511 3984 \nQ 2250 4219 1831 4219 \nQ 1534 4219 1204 4116 \nQ 875 4013 500 3803 \nL 500 4441 \nQ 881 4594 1212 4672 \nQ 1544 4750 1819 4750 \nQ 2544 4750 2975 4387 \nQ 3406 4025 3406 3419 \nQ 3406 3131 3298 2873 \nQ 3191 2616 2906 2266 \nQ 2828 2175 2409 1742 \nQ 1991 1309 1228 531 \nz\n\" transform=\"scale(0.015625)\"/>\n <path id=\"DejaVuSans-30\" d=\"M 2034 4250 \nQ 1547 4250 1301 3770 \nQ 1056 3291 1056 2328 \nQ 1056 1369 1301 889 \nQ 1547 409 2034 409 \nQ 2525 409 2770 889 \nQ 3016 1369 3016 2328 \nQ 3016 3291 2770 3770 \nQ 2525 4250 2034 4250 \nz\nM 2034 4750 \nQ 2819 4750 3233 4129 \nQ 3647 3509 3647 2328 \nQ 3647 1150 3233 529 \nQ 2819 -91 2034 -91 \nQ 1250 -91 836 529 \nQ 422 1150 422 2328 \nQ 422 3509 836 4129 \nQ 1250 4750 2034 4750 \nz\n\" transform=\"scale(0.015625)\"/>\n </defs>\n <use xlink:href=\"#DejaVuSans-32\"/>\n <use xlink:href=\"#DejaVuSans-30\" x=\"63.623047\"/>\n <use xlink:href=\"#DejaVuSans-30\" x=\"127.246094\"/>\n </g>\n </g>\n </g>\n <g id=\"xtick_2\">\n <g id=\"line2d_3\">\n <path d=\"M 185.879856 173.52 \nL 185.879856 7.2 \n\" clip-path=\"url(#pada85592f5)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_4\">\n <g>\n <use xlink:href=\"#m22554b858e\" x=\"185.879856\" y=\"173.52\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_2\">\n <!-- 400 -->\n <g style=\"fill: #ffffff\" transform=\"translate(176.336106 188.118438) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-34\" d=\"M 2419 4116 \nL 825 1625 \nL 2419 1625 \nL 2419 4116 \nz\nM 2253 4666 \nL 3047 4666 \nL 3047 1625 \nL 3713 1625 \nL 3713 1100 \nL 3047 1100 \nL 3047 0 \nL 2419 0 \nL 2419 11
2026-03-25 15:07:28 +00:00
},
"metadata": {},
"output_type": "display_data",
"jetTransient": {
"display_id": null
}
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 90
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:24.848161521Z",
"start_time": "2026-04-22T07:03:24.729135740Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": [
"multistep_preds = torch.zeros(T)\n",
"multistep_preds[: n_train + tau] = x[: n_train + tau]\n",
"for i in range(n_train + tau, T):\n",
" multistep_preds[i] = net(\n",
" multistep_preds[i - tau:i].reshape((1, -1)))\n",
"d2l.plot([time, time[tau:], time[n_train + tau:]],\n",
" [x.detach().numpy(), onestep_preds.detach().numpy(),\n",
" multistep_preds[n_train + tau:].detach().numpy()], 'time',\n",
" 'x', legend=['data', '1-step preds', 'multistep preds'],\n",
" xlim=[1, 1000], figsize=(6, 3))"
],
"id": "12c3a1c3912da4dd",
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 600x300 with 1 Axes>"
],
2026-04-22 07:23:35 +00:00
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"406.885938pt\" height=\"211.07625pt\" viewBox=\"0 0 406.885938 211.07625\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n <metadata>\n <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n <cc:Work>\n <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n <dc:date>2026-04-22T15:03:24.805438</dc:date>\n <dc:format>image/svg+xml</dc:format>\n <dc:creator>\n <cc:Agent>\n <dc:title>Matplotlib v3.7.2, https://matplotlib.org/</dc:title>\n </cc:Agent>\n </dc:creator>\n </cc:Work>\n </rdf:RDF>\n </metadata>\n <defs>\n <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n </defs>\n <g id=\"figure_1\">\n <g id=\"patch_1\">\n <path d=\"M 0 211.07625 \nL 406.885938 211.07625 \nL 406.885938 0 \nL 0 0 \nz\n\"/>\n </g>\n <g id=\"axes_1\">\n <g id=\"patch_2\">\n <path d=\"M 52.160938 173.52 \nL 386.960938 173.52 \nL 386.960938 7.2 \nL 52.160938 7.2 \nz\n\"/>\n </g>\n <g id=\"matplotlib.axis_1\">\n <g id=\"xtick_1\">\n <g id=\"line2d_1\">\n <path d=\"M 118.852829 173.52 \nL 118.852829 7.2 \n\" clip-path=\"url(#p9136e6403c)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_2\">\n <defs>\n <path id=\"m5e61501fd7\" d=\"M 0 0 \nL 0 3.5 \n\" style=\"stroke: #ffffff; stroke-width: 0.8\"/>\n </defs>\n <g>\n <use xlink:href=\"#m5e61501fd7\" x=\"118.852829\" y=\"173.52\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_1\">\n <!-- 200 -->\n <g style=\"fill: #ffffff\" transform=\"translate(109.309079 188.118438) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-32\" d=\"M 1228 531 \nL 3431 531 
\nL 3431 0 \nL 469 0 \nL 469 531 \nQ 828 903 1448 1529 \nQ 2069 2156 2228 2338 \nQ 2531 2678 2651 2914 \nQ 2772 3150 2772 3378 \nQ 2772 3750 2511 3984 \nQ 2250 4219 1831 4219 \nQ 1534 4219 1204 4116 \nQ 875 4013 500 3803 \nL 500 4441 \nQ 881 4594 1212 4672 \nQ 1544 4750 1819 4750 \nQ 2544 4750 2975 4387 \nQ 3406 4025 3406 3419 \nQ 3406 3131 3298 2873 \nQ 3191 2616 2906 2266 \nQ 2828 2175 2409 1742 \nQ 1991 1309 1228 531 \nz\n\" transform=\"scale(0.015625)\"/>\n <path id=\"DejaVuSans-30\" d=\"M 2034 4250 \nQ 1547 4250 1301 3770 \nQ 1056 3291 1056 2328 \nQ 1056 1369 1301 889 \nQ 1547 409 2034 409 \nQ 2525 409 2770 889 \nQ 3016 1369 3016 2328 \nQ 3016 3291 2770 3770 \nQ 2525 4250 2034 4250 \nz\nM 2034 4750 \nQ 2819 4750 3233 4129 \nQ 3647 3509 3647 2328 \nQ 3647 1150 3233 529 \nQ 2819 -91 2034 -91 \nQ 1250 -91 836 529 \nQ 422 1150 422 2328 \nQ 422 3509 836 4129 \nQ 1250 4750 2034 4750 \nz\n\" transform=\"scale(0.015625)\"/>\n </defs>\n <use xlink:href=\"#DejaVuSans-32\"/>\n <use xlink:href=\"#DejaVuSans-30\" x=\"63.623047\"/>\n <use xlink:href=\"#DejaVuSans-30\" x=\"127.246094\"/>\n </g>\n </g>\n </g>\n <g id=\"xtick_2\">\n <g id=\"line2d_3\">\n <path d=\"M 185.879856 173.52 \nL 185.879856 7.2 \n\" clip-path=\"url(#p9136e6403c)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_4\">\n <g>\n <use xlink:href=\"#m5e61501fd7\" x=\"185.879856\" y=\"173.52\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_2\">\n <!-- 400 -->\n <g style=\"fill: #ffffff\" transform=\"translate(176.336106 188.118438) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-34\" d=\"M 2419 4116 \nL 825 1625 \nL 2419 1625 \nL 2419 4116 \nz\nM 2253 4666 \nL 3047 4666 \nL 3047 1625 \nL 3713 1625 \nL 3713 1100 \nL 3047 1100 \nL 3047 0 \nL 2419 0 \nL 2419 11
2026-03-25 15:07:28 +00:00
},
"metadata": {},
"output_type": "display_data",
"jetTransient": {
"display_id": null
}
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 91
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:24.898795229Z",
"start_time": "2026-04-22T07:03:24.850233508Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": [
"import collections\n",
"import re"
],
"id": "aab66c10a4c143d2",
"outputs": [],
2026-04-22 07:23:35 +00:00
"execution_count": 92
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:24.953966099Z",
"start_time": "2026-04-22T07:03:24.901011365Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": [
"d2l.DATA_HUB['time_machine'] = (d2l.DATA_URL + 'timemachine.txt',\n",
"'090b5e7e70c295757f55df93cb0a180b9691891a')\n",
"def read_time_machine(): #@save\n",
" \"\"\"将时间机器数据集加载到文本行的列表中\"\"\"\n",
" with open(d2l.download('time_machine'), 'r') as f:\n",
" lines = f.readlines()\n",
" return [re.sub('[^A-Za-z]+', ' ', line).strip().lower() for line in lines]\n",
"lines = read_time_machine()\n",
"print(f'# 文本总行数: {len(lines)}')\n",
"print(lines[0])\n",
"print(lines[10])"
],
"id": "1aff117af810525e",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"# 文本总行数: 3221\n",
"the time machine by h g wells\n",
"twinkled and his usually pale face was flushed and animated the\n"
]
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 93
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:25.021941150Z",
"start_time": "2026-04-22T07:03:24.965670989Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": [
"def tokenize(lines, token='word'): #@save\n",
" \"\"\"将文本行拆分为单词或字符词元\"\"\"\n",
" if token == 'word':\n",
" return [line.split() for line in lines]\n",
" elif token == 'char':\n",
" return [list(line) for line in lines]\n",
" else:\n",
" print('错误:未知词元类型:' + token)\n",
"tokens = tokenize(lines)\n",
"for i in range(11):\n",
" print(tokens[i])"
],
"id": "eb4fe9745fbaa5e2",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['the', 'time', 'machine', 'by', 'h', 'g', 'wells']\n",
"[]\n",
"[]\n",
"[]\n",
"[]\n",
"['i']\n",
"[]\n",
"[]\n",
"['the', 'time', 'traveller', 'for', 'so', 'it', 'will', 'be', 'convenient', 'to', 'speak', 'of', 'him']\n",
"['was', 'expounding', 'a', 'recondite', 'matter', 'to', 'us', 'his', 'grey', 'eyes', 'shone', 'and']\n",
"['twinkled', 'and', 'his', 'usually', 'pale', 'face', 'was', 'flushed', 'and', 'animated', 'the']\n"
]
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 94
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:25.086591433Z",
"start_time": "2026-04-22T07:03:25.032865323Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": [
"def count_corpus(tokens): #@save\n",
" \"\"\"统计词元的频率\"\"\"\n",
" # 这里的tokens是1D列表或2D列表\n",
" if len(tokens) == 0 or isinstance(tokens[0], list):\n",
" # 将词元列表展平成一个列表\n",
" tokens = [token for line in tokens for token in line]\n",
" return collections.Counter(tokens)\n",
"class Vocab: #@save\n",
" \"\"\"文本词表\"\"\"\n",
" def __init__(self, tokens=None, min_freq=0, reserved_tokens=None):\n",
" if tokens is None:\n",
" tokens = []\n",
" if reserved_tokens is None:\n",
" reserved_tokens = []\n",
" # 按出现频率排序\n",
" counter = count_corpus(tokens)\n",
" self._token_freqs = sorted(counter.items(), key=lambda x: x[1],\n",
" reverse=True)\n",
" # 未知词元的索引为0\n",
" self.idx_to_token = ['<unk>'] + reserved_tokens\n",
" self.token_to_idx = {token: idx\n",
" for idx, token in enumerate(self.idx_to_token)}\n",
" for token, freq in self._token_freqs:\n",
" if freq < min_freq:\n",
" break\n",
" if token not in self.token_to_idx:\n",
" self.idx_to_token.append(token)\n",
" self.token_to_idx[token] = len(self.idx_to_token) - 1\n",
" def __len__(self):\n",
" return len(self.idx_to_token)\n",
" def __getitem__(self, tokens):\n",
" if not isinstance(tokens, (list, tuple)):\n",
" return self.token_to_idx.get(tokens, self.unk)\n",
" return [self.__getitem__(token) for token in tokens]\n",
" def to_tokens(self, indices):\n",
" if not isinstance(indices, (list, tuple)):\n",
" return self.idx_to_token[indices]\n",
" return [self.idx_to_token[index] for index in indices]\n",
" @property\n",
" def unk(self): # 未知词元的索引为0\n",
" return 0\n",
" @property\n",
" def token_freqs(self):\n",
" return self._token_freqs"
],
"id": "bee8e5d7b798c6c",
"outputs": [],
2026-04-22 07:23:35 +00:00
"execution_count": 95
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:25.242997320Z",
"start_time": "2026-04-22T07:03:25.089155138Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": [
"vocab = Vocab(tokens)\n",
"print(list(vocab.token_to_idx.items())[:10])"
],
"id": "ff4e8ac2044850b7",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[('<unk>', 0), ('the', 1), ('i', 2), ('and', 3), ('of', 4), ('a', 5), ('to', 6), ('was', 7), ('in', 8), ('that', 9)]\n"
]
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 96
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:25.308904788Z",
"start_time": "2026-04-22T07:03:25.255767350Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": [
"for i in [0, 100]:\n",
" print('文本:', tokens[i])\n",
" print('索引:', vocab[tokens[i]])"
],
"id": "a4e569dfbd251608",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"文本: ['the', 'time', 'machine', 'by', 'h', 'g', 'wells']\n",
"索引: [1, 19, 50, 40, 2183, 2184, 400]\n",
"文本: ['were', 'three', 'dimensional', 'representations', 'of', 'his', 'four', 'dimensioned']\n",
"索引: [20, 175, 1452, 2250, 4, 25, 262, 2251]\n"
]
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 97
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:25.387400952Z",
"start_time": "2026-04-22T07:03:25.322962901Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": [
"def load_corpus_time_machine(max_tokens=-1): #@save\n",
" \"\"\"返回时光机器数据集的词元索引列表和词表\"\"\"\n",
" lines = read_time_machine()\n",
" tokens = tokenize(lines, 'char')\n",
" vocab = Vocab(tokens)\n",
" # 因为时光机器数据集中的每个文本行不一定是一个句子或一个段落,\n",
" # 所以将所有文本行展平到一个列表中\n",
" corpus = [vocab[token] for line in tokens for token in line]\n",
" if max_tokens > 0:\n",
" corpus = corpus[:max_tokens]\n",
" return corpus, vocab\n",
"corpus, vocab = load_corpus_time_machine()\n",
"\n",
"len(corpus), len(vocab)\n"
],
"id": "1b5c1776ae47af5c",
"outputs": [
{
"data": {
"text/plain": [
"(170580, 28)"
]
},
2026-04-22 07:23:35 +00:00
"execution_count": 98,
2026-03-25 15:07:28 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 98
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:25.453063309Z",
"start_time": "2026-04-22T07:03:25.388764888Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": [
"tokens = d2l.tokenize(read_time_machine())\n",
"# 因为每个文本行不一定是一个句子或一个段落,因此我们把所有文本行拼接到一起\n",
"corpus = [token for line in tokens for token in line]\n",
"vocab = d2l.Vocab(corpus)\n",
"vocab.token_freqs[:10]"
],
"id": "99deb85c025e5cdd",
"outputs": [
{
"data": {
"text/plain": [
"[('the', 2261),\n",
" ('i', 1267),\n",
" ('and', 1245),\n",
" ('of', 1155),\n",
" ('a', 816),\n",
" ('to', 695),\n",
" ('was', 552),\n",
" ('in', 541),\n",
" ('that', 443),\n",
" ('my', 440)]"
]
},
2026-04-22 07:23:35 +00:00
"execution_count": 99,
2026-03-25 15:07:28 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 99
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:25.731046331Z",
"start_time": "2026-04-22T07:03:25.454316415Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": [
"freqs = [freq for token, freq in vocab.token_freqs]\n",
"d2l.plot(freqs, xlabel='token: x', ylabel='frequency: n(x)',\n",
"xscale='log', yscale='log')"
],
"id": "5c0d846673e16c33",
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 350x250 with 1 Axes>"
],
2026-04-22 07:23:35 +00:00
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"247.978125pt\" height=\"183.35625pt\" viewBox=\"0 0 247.978125 183.35625\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n <metadata>\n <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n <cc:Work>\n <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n <dc:date>2026-04-22T15:03:25.667027</dc:date>\n <dc:format>image/svg+xml</dc:format>\n <dc:creator>\n <cc:Agent>\n <dc:title>Matplotlib v3.7.2, https://matplotlib.org/</dc:title>\n </cc:Agent>\n </dc:creator>\n </cc:Work>\n </rdf:RDF>\n </metadata>\n <defs>\n <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n </defs>\n <g id=\"figure_1\">\n <g id=\"patch_1\">\n <path d=\"M 0 183.35625 \nL 247.978125 183.35625 \nL 247.978125 0 \nL 0 0 \nz\n\"/>\n </g>\n <g id=\"axes_1\">\n <g id=\"patch_2\">\n <path d=\"M 45.478125 145.8 \nL 240.778125 145.8 \nL 240.778125 7.2 \nL 45.478125 7.2 \nz\n\"/>\n </g>\n <g id=\"matplotlib.axis_1\">\n <g id=\"xtick_1\">\n <g id=\"line2d_1\">\n <path d=\"M 54.355398 145.8 \nL 54.355398 7.2 \n\" clip-path=\"url(#pa5617703bf)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_2\">\n <defs>\n <path id=\"ma46b1d16ab\" d=\"M 0 0 \nL 0 3.5 \n\" style=\"stroke: #ffffff; stroke-width: 0.8\"/>\n </defs>\n <g>\n <use xlink:href=\"#ma46b1d16ab\" x=\"54.355398\" y=\"145.8\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_1\">\n <!-- $\\mathdefault{10^{0}}$ -->\n <g style=\"fill: #ffffff\" transform=\"translate(45.555398 160.398438) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-31\" d=\"M 794 531 \nL 
1825 531 \nL 1825 4091 \nL 703 3866 \nL 703 4441 \nL 1819 4666 \nL 2450 4666 \nL 2450 531 \nL 3481 531 \nL 3481 0 \nL 794 0 \nL 794 531 \nz\n\" transform=\"scale(0.015625)\"/>\n <path id=\"DejaVuSans-30\" d=\"M 2034 4250 \nQ 1547 4250 1301 3770 \nQ 1056 3291 1056 2328 \nQ 1056 1369 1301 889 \nQ 1547 409 2034 409 \nQ 2525 409 2770 889 \nQ 3016 1369 3016 2328 \nQ 3016 3291 2770 3770 \nQ 2525 4250 2034 4250 \nz\nM 2034 4750 \nQ 2819 4750 3233 4129 \nQ 3647 3509 3647 2328 \nQ 3647 1150 3233 529 \nQ 2819 -91 2034 -91 \nQ 1250 -91 836 529 \nQ 422 1150 422 2328 \nQ 422 3509 836 4129 \nQ 1250 4750 2034 4750 \nz\n\" transform=\"scale(0.015625)\"/>\n </defs>\n <use xlink:href=\"#DejaVuSans-31\" transform=\"translate(0 0.765625)\"/>\n <use xlink:href=\"#DejaVuSans-30\" transform=\"translate(63.623047 0.765625)\"/>\n <use xlink:href=\"#DejaVuSans-30\" transform=\"translate(128.203125 39.046875) scale(0.7)\"/>\n </g>\n </g>\n </g>\n <g id=\"xtick_2\">\n <g id=\"line2d_3\">\n <path d=\"M 102.85613 145.8 \nL 102.85613 7.2 \n\" clip-path=\"url(#pa5617703bf)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_4\">\n <g>\n <use xlink:href=\"#ma46b1d16ab\" x=\"102.85613\" y=\"145.8\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_2\">\n <!-- $\\mathdefault{10^{1}}$ -->\n <g style=\"fill: #ffffff\" transform=\"translate(94.05613 160.398438) scale(0.1 -0.1)\">\n <use xlink:href=\"#DejaVuSans-31\" transform=\"translate(0 0.684375)\"/>\n <use xlink:href=\"#DejaVuSans-30\" transform=\"translate(63.623047 0.684375)\"/>\n <use xlink:href=\"#DejaVuSans-31\" transform=\"translate(128.203125 38.965625) scale(0.7)\"/>\n </g>\n </g>\n </g>\n <g id=\"xtick_3\">\n <g id=\"line2d_5\">\n <path d=\"M 151.356861 145.8 \nL 15
2026-03-25 15:07:28 +00:00
},
"metadata": {},
"output_type": "display_data",
"jetTransient": {
"display_id": null
}
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 100
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:25.792389729Z",
"start_time": "2026-04-22T07:03:25.734951149Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": [
"bigram_tokens = [pair for pair in zip(corpus[:-1], corpus[1:])]\n",
"bigram_vocab = Vocab(bigram_tokens)\n",
"bigram_vocab.token_freqs[:10]"
],
"id": "2826e3ab0863ee64",
"outputs": [
{
"data": {
"text/plain": [
"[(('of', 'the'), 309),\n",
" (('in', 'the'), 169),\n",
" (('i', 'had'), 130),\n",
" (('i', 'was'), 112),\n",
" (('and', 'the'), 109),\n",
" (('the', 'time'), 102),\n",
" (('it', 'was'), 99),\n",
" (('to', 'the'), 85),\n",
" (('as', 'i'), 78),\n",
" (('of', 'a'), 73)]"
]
},
2026-04-22 07:23:35 +00:00
"execution_count": 101,
2026-03-25 15:07:28 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 101
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:25.888424198Z",
"start_time": "2026-04-22T07:03:25.808960001Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": [
"trigram_tokens = [triple for triple in zip(\n",
"corpus[:-2], corpus[1:-1], corpus[2:])]\n",
"trigram_vocab = Vocab(trigram_tokens)\n",
"trigram_vocab.token_freqs[:10]"
],
"id": "7c8cd0544bf872bb",
"outputs": [
{
"data": {
"text/plain": [
"[(('the', 'time', 'traveller'), 59),\n",
" (('the', 'time', 'machine'), 30),\n",
" (('the', 'medical', 'man'), 24),\n",
" (('it', 'seemed', 'to'), 16),\n",
" (('it', 'was', 'a'), 15),\n",
" (('here', 'and', 'there'), 15),\n",
" (('seemed', 'to', 'me'), 14),\n",
" (('i', 'did', 'not'), 14),\n",
" (('i', 'saw', 'the'), 13),\n",
" (('i', 'began', 'to'), 13)]"
]
},
2026-04-22 07:23:35 +00:00
"execution_count": 102,
2026-03-25 15:07:28 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 102
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:26.133953109Z",
"start_time": "2026-04-22T07:03:25.889784531Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": [
"bigram_freqs = [freq for token, freq in bigram_vocab.token_freqs]\n",
"trigram_freqs = [freq for token, freq in trigram_vocab.token_freqs]\n",
"d2l.plot([freqs, bigram_freqs, trigram_freqs], xlabel='token: x',\n",
"ylabel='frequency: n(x)', xscale='log', yscale='log',\n",
"legend=['unigram', 'bigram', 'trigram'])"
],
"id": "dc3d97dda738613d",
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 350x250 with 1 Axes>"
],
2026-04-22 07:23:35 +00:00
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"247.978125pt\" height=\"183.35625pt\" viewBox=\"0 0 247.978125 183.35625\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n <metadata>\n <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n <cc:Work>\n <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n <dc:date>2026-04-22T15:03:26.075629</dc:date>\n <dc:format>image/svg+xml</dc:format>\n <dc:creator>\n <cc:Agent>\n <dc:title>Matplotlib v3.7.2, https://matplotlib.org/</dc:title>\n </cc:Agent>\n </dc:creator>\n </cc:Work>\n </rdf:RDF>\n </metadata>\n <defs>\n <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n </defs>\n <g id=\"figure_1\">\n <g id=\"patch_1\">\n <path d=\"M 0 183.35625 \nL 247.978125 183.35625 \nL 247.978125 0 \nL 0 0 \nz\n\"/>\n </g>\n <g id=\"axes_1\">\n <g id=\"patch_2\">\n <path d=\"M 45.478125 145.8 \nL 240.778125 145.8 \nL 240.778125 7.2 \nL 45.478125 7.2 \nz\n\"/>\n </g>\n <g id=\"matplotlib.axis_1\">\n <g id=\"xtick_1\">\n <g id=\"line2d_1\">\n <path d=\"M 54.355398 145.8 \nL 54.355398 7.2 \n\" clip-path=\"url(#p000c6e7d8a)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_2\">\n <defs>\n <path id=\"mc905ff436c\" d=\"M 0 0 \nL 0 3.5 \n\" style=\"stroke: #ffffff; stroke-width: 0.8\"/>\n </defs>\n <g>\n <use xlink:href=\"#mc905ff436c\" x=\"54.355398\" y=\"145.8\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_1\">\n <!-- $\\mathdefault{10^{0}}$ -->\n <g style=\"fill: #ffffff\" transform=\"translate(45.555398 160.398438) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-31\" d=\"M 794 531 \nL 
1825 531 \nL 1825 4091 \nL 703 3866 \nL 703 4441 \nL 1819 4666 \nL 2450 4666 \nL 2450 531 \nL 3481 531 \nL 3481 0 \nL 794 0 \nL 794 531 \nz\n\" transform=\"scale(0.015625)\"/>\n <path id=\"DejaVuSans-30\" d=\"M 2034 4250 \nQ 1547 4250 1301 3770 \nQ 1056 3291 1056 2328 \nQ 1056 1369 1301 889 \nQ 1547 409 2034 409 \nQ 2525 409 2770 889 \nQ 3016 1369 3016 2328 \nQ 3016 3291 2770 3770 \nQ 2525 4250 2034 4250 \nz\nM 2034 4750 \nQ 2819 4750 3233 4129 \nQ 3647 3509 3647 2328 \nQ 3647 1150 3233 529 \nQ 2819 -91 2034 -91 \nQ 1250 -91 836 529 \nQ 422 1150 422 2328 \nQ 422 3509 836 4129 \nQ 1250 4750 2034 4750 \nz\n\" transform=\"scale(0.015625)\"/>\n </defs>\n <use xlink:href=\"#DejaVuSans-31\" transform=\"translate(0 0.765625)\"/>\n <use xlink:href=\"#DejaVuSans-30\" transform=\"translate(63.623047 0.765625)\"/>\n <use xlink:href=\"#DejaVuSans-30\" transform=\"translate(128.203125 39.046875) scale(0.7)\"/>\n </g>\n </g>\n </g>\n <g id=\"xtick_2\">\n <g id=\"line2d_3\">\n <path d=\"M 94.026857 145.8 \nL 94.026857 7.2 \n\" clip-path=\"url(#p000c6e7d8a)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_4\">\n <g>\n <use xlink:href=\"#mc905ff436c\" x=\"94.026857\" y=\"145.8\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_2\">\n <!-- $\\mathdefault{10^{1}}$ -->\n <g style=\"fill: #ffffff\" transform=\"translate(85.226857 160.398438) scale(0.1 -0.1)\">\n <use xlink:href=\"#DejaVuSans-31\" transform=\"translate(0 0.684375)\"/>\n <use xlink:href=\"#DejaVuSans-30\" transform=\"translate(63.623047 0.684375)\"/>\n <use xlink:href=\"#DejaVuSans-31\" transform=\"translate(128.203125 38.965625) scale(0.7)\"/>\n </g>\n </g>\n </g>\n <g id=\"xtick_3\">\n <g id=\"line2d_5\">\n <path d=\"M 133.698316 145.8 \nL 1
2026-03-25 15:07:28 +00:00
},
"metadata": {},
"output_type": "display_data",
"jetTransient": {
"display_id": null
}
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 103
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:26.196704442Z",
"start_time": "2026-04-22T07:03:26.148119411Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": [
"import random\n",
"def seq_data_iter_random(corpus, batch_size, num_steps): #@save\n",
" \"\"\"使用随机抽样生成一个小批量子序列\"\"\"\n",
" # 从随机偏移量开始对序列进行分区, 随机范围包括num_steps-1\n",
" corpus = corpus[random.randint(0, num_steps - 1):]\n",
" # 减去1, 是因为我们需要考虑标签\n",
" num_subseqs = (len(corpus) - 1) // num_steps\n",
" # 长度为num_steps的子序列的起始索引\n",
" initial_indices = list(range(0, num_subseqs * num_steps, num_steps))\n",
" # 在随机抽样的迭代过程中,\n",
" # 来自两个相邻的、随机的、小批量中的子序列不一定在原始序列上相邻\n",
" random.shuffle(initial_indices)\n",
" def data(pos):\n",
" # 返回从pos位置开始的长度为num_steps的序列\n",
" return corpus[pos: pos + num_steps]\n",
" num_batches = num_subseqs // batch_size\n",
" for i in range(0, batch_size * num_batches, batch_size):\n",
" # 在这里, initial_indices包含子序列的随机起始索引\n",
" initial_indices_per_batch = initial_indices[i: i + batch_size]\n",
" X = [data(j) for j in initial_indices_per_batch]\n",
" Y = [data(j + 1) for j in initial_indices_per_batch]\n",
" yield torch.tensor(X), torch.tensor(Y)"
],
"id": "fd015793938b83ab",
"outputs": [],
2026-04-22 07:23:35 +00:00
"execution_count": 104
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:26.251762886Z",
"start_time": "2026-04-22T07:03:26.199184612Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": [
"my_seq = list(range(35))\n",
"for X, Y in seq_data_iter_random(my_seq, batch_size=2, num_steps=5):\n",
" print('X: ', X, '\\nY:', Y)"
],
"id": "8961f4934b8c7cc",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2026-04-22 07:23:35 +00:00
"X: tensor([[ 3, 4, 5, 6, 7],\n",
" [28, 29, 30, 31, 32]]) \n",
"Y: tensor([[ 4, 5, 6, 7, 8],\n",
" [29, 30, 31, 32, 33]])\n",
"X: tensor([[23, 24, 25, 26, 27],\n",
" [ 8, 9, 10, 11, 12]]) \n",
"Y: tensor([[24, 25, 26, 27, 28],\n",
" [ 9, 10, 11, 12, 13]])\n",
"X: tensor([[18, 19, 20, 21, 22],\n",
" [13, 14, 15, 16, 17]]) \n",
"Y: tensor([[19, 20, 21, 22, 23],\n",
" [14, 15, 16, 17, 18]])\n"
2026-03-25 15:07:28 +00:00
]
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 105
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:26.315745320Z",
"start_time": "2026-04-22T07:03:26.265138594Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": [
"def seq_data_iter_sequential(corpus, batch_size, num_steps): #@save\n",
" \"\"\"使用顺序分区生成一个小批量子序列\"\"\"\n",
" # 从随机偏移量开始划分序列\n",
" offset = random.randint(0, num_steps)\n",
" num_tokens = ((len(corpus) - offset - 1) // batch_size) * batch_size\n",
" Xs = torch.tensor(corpus[offset: offset + num_tokens])\n",
" Ys = torch.tensor(corpus[offset + 1: offset + 1 + num_tokens])\n",
" Xs, Ys = Xs.reshape(batch_size, -1), Ys.reshape(batch_size, -1)\n",
" num_batches = Xs.shape[1] // num_steps\n",
" for i in range(0, num_steps * num_batches, num_steps):\n",
" X = Xs[:, i: i + num_steps]\n",
" Y = Ys[:, i: i + num_steps]\n",
" yield X, Y"
],
"id": "621b66c0614b22da",
"outputs": [],
2026-04-22 07:23:35 +00:00
"execution_count": 106
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:26.366004915Z",
"start_time": "2026-04-22T07:03:26.318858875Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": [
"class SeqDataLoader:\n",
" \"\"\"加载序列数据的迭代器\"\"\"\n",
" def __init__(self, batch_size, num_steps, use_random_iter, max_tokens):\n",
" if use_random_iter:\n",
" self.data_iter_fn = seq_data_iter_random\n",
" else:\n",
" self.data_iter_fn = seq_data_iter_sequential\n",
" self.corpus, self.vocab = load_corpus_time_machine(max_tokens)\n",
" self.batch_size, self.num_steps = batch_size, num_steps\n",
" def __iter__(self):\n",
" return self.data_iter_fn(self.corpus, self.batch_size, self.num_steps)\n",
"def load_data_time_machine(batch_size, num_steps, #@save\n",
" use_random_iter=False, max_tokens=10000):\n",
" \"\"\"返回时光机器数据集的迭代器和词表\"\"\"\n",
" data_iter = SeqDataLoader(\n",
" batch_size, num_steps, use_random_iter, max_tokens)\n",
" return data_iter, data_iter.vocab"
],
"id": "f09fe2507a925fe9",
"outputs": [],
2026-04-22 07:23:35 +00:00
"execution_count": 107
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:26.415821442Z",
"start_time": "2026-04-22T07:03:26.368476221Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": [
"batch_size, num_steps = 32, 35\n",
"train_iter, vocab = load_data_time_machine(batch_size, num_steps)"
],
"id": "69272a664d3b9ae1",
"outputs": [],
2026-04-22 07:23:35 +00:00
"execution_count": 108
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:26.469093543Z",
"start_time": "2026-04-22T07:03:26.417213406Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": "F.one_hot(torch.tensor([0,2]),len(vocab))",
"id": "35806d36e5ec3ca7",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0],\n",
" [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0]])"
]
},
2026-04-22 07:23:35 +00:00
"execution_count": 109,
2026-03-25 15:07:28 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 109
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:26.612564439Z",
"start_time": "2026-04-22T07:03:26.520392652Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": [
"X = torch.arange(10).reshape((2, 5))\n",
"F.one_hot(X.T, 28).shape"
],
"id": "6a4695284b898013",
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([5, 2, 28])"
]
},
2026-04-22 07:23:35 +00:00
"execution_count": 110,
2026-03-25 15:07:28 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 110
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:26.703519372Z",
"start_time": "2026-04-22T07:03:26.641662463Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": [
"def get_params(vocab_size, num_hiddens, device):\n",
" num_inputs = num_outputs = vocab_size\n",
" def normal(shape):\n",
" return torch.randn(size=shape, device=device) * 0.01\n",
" # 隐藏层参数\n",
" W_xh = normal((num_inputs, num_hiddens))\n",
" W_hh = normal((num_hiddens, num_hiddens))\n",
" b_h = torch.zeros(num_hiddens, device=device)\n",
" # 输出层参数\n",
" W_hq = normal((num_hiddens, num_outputs))\n",
" b_q = torch.zeros(num_outputs, device=device)\n",
" # 附加梯度\n",
" params = [W_xh, W_hh, b_h, W_hq, b_q]\n",
" '''\n",
" W_xh x->H_t\n",
" W_hh H_t-1->H_t\n",
" W_hq H_t->opt\n",
" '''\n",
" for param in params:\n",
" param.requires_grad_(True)\n",
" return params\n",
"def init_rnn_state(batch_size, num_hiddens, device):\n",
" return (torch.zeros((batch_size, num_hiddens), device=device), )\n",
"def rnn(inputs,state,params):\n",
" W_xh,W_hh,b_h,W_hq,b_q = params\n",
" H, = state\n",
" outputs = []\n",
" for X in inputs: # X (batchs,vocab)\n",
" H = torch.tanh(torch.mm(X,W_xh)+torch.mm(H,W_hh)+b_h)\n",
" Y = torch.mm(H,W_hq)+b_q\n",
" outputs.append(Y)\n",
" return torch.cat(outputs,dim=0),(H,)\n",
"class RNNModelScratch: #@save\n",
" \"\"\"从零开始实现的循环神经网络模型\"\"\"\n",
" def __init__(self, vocab_size, num_hiddens, device,\n",
" get_params, init_state, forward_fn):\n",
" self.vocab_size, self.num_hiddens = vocab_size, num_hiddens\n",
" self.params = get_params(vocab_size, num_hiddens, device)\n",
" self.init_state, self.forward_fn = init_state, forward_fn\n",
" def __call__(self, X, state):\n",
" X = F.one_hot(X.T, self.vocab_size).type(torch.float32)\n",
" return self.forward_fn(X, state, self.params)\n",
" def begin_state(self, batch_size, device):\n",
" return self.init_state(batch_size, self.num_hiddens, device)"
],
"id": "405f7c6af8bdd939",
"outputs": [],
2026-04-22 07:23:35 +00:00
"execution_count": 111
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:26.793241936Z",
"start_time": "2026-04-22T07:03:26.704490837Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": [
"num_hiddens = 512\n",
"net = RNNModelScratch(len(vocab), num_hiddens, d2l.try_gpu(), get_params,\n",
"init_rnn_state, rnn)\n",
"state = net.begin_state(X.shape[0], d2l.try_gpu())\n",
"Y, new_state = net(X.to(d2l.try_gpu()), state)\n",
"Y.shape, len(new_state), new_state[0].shape"
],
"id": "7bdd0b37b1458ec2",
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([10, 28]), 1, torch.Size([2, 512]))"
]
},
2026-04-22 07:23:35 +00:00
"execution_count": 112,
2026-03-25 15:07:28 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 112
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:26.859603241Z",
"start_time": "2026-04-22T07:03:26.795131832Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": [
"def predict_ch8(prefix, num_preds, net, vocab, device): #@save\n",
" \"\"\"在prefix后面生成新字符\"\"\"\n",
" state = net.begin_state(batch_size=1, device=device)\n",
" outputs = [vocab[prefix[0]]]\n",
" get_input = lambda: torch.tensor([outputs[-1]], device=device).reshape((1, 1))\n",
" for y in prefix[1:]: # 预热期\n",
" _, state = net(get_input(), state)\n",
" outputs.append(vocab[y])\n",
" for _ in range(num_preds): # 预测num_preds步\n",
" y, state = net(get_input(), state)\n",
" outputs.append(int(y.argmax(dim=1).reshape(1)))\n",
" return ''.join([vocab.idx_to_token[i] for i in outputs])\n",
"predict_ch8('time traveller ', 10, net, vocab, d2l.try_gpu())"
],
"id": "f38dbe2b11e02bc2",
"outputs": [
{
"data": {
"text/plain": [
2026-04-22 07:23:35 +00:00
"'time traveller v<unk>xmpussss'"
2026-03-25 15:07:28 +00:00
]
},
2026-04-22 07:23:35 +00:00
"execution_count": 113,
2026-03-25 15:07:28 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 113
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:26.909151706Z",
"start_time": "2026-04-22T07:03:26.861214463Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": [
"def grad_clipping(net, theta): #@save\n",
" \"\"\"裁剪梯度\"\"\"\n",
" if isinstance(net, nn.Module):\n",
" params = [p for p in net.parameters() if p.requires_grad]\n",
" else:\n",
" params = net.params\n",
" norm = torch.sqrt(sum(torch.sum((p.grad ** 2)) for p in params))\n",
" if norm > theta:\n",
" for param in params:\n",
" param.grad[:] *= theta / norm"
],
"id": "6c19717736ffbc68",
"outputs": [],
2026-04-22 07:23:35 +00:00
"execution_count": 114
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:26.964267753Z",
"start_time": "2026-04-22T07:03:26.913802768Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": [
"import math\n",
"def train_epoch_ch8(net, train_iter, loss, updater, device, use_random_iter):\n",
" \"\"\"训练网络一个迭代周期( 定义见第8章) \"\"\"\n",
" state, timer = None, d2l.Timer()\n",
" metric = d2l.Accumulator(2) # 训练损失之和,词元数量\n",
" for X, Y in train_iter:\n",
" if state is None or use_random_iter:\n",
" # 在第一次迭代或使用随机抽样时初始化state\n",
" state = net.begin_state(batch_size=X.shape[0], device=device)\n",
" else:\n",
" if isinstance(net, nn.Module) and not isinstance(state, tuple):\n",
" # state对于nn.GRU是个张量\n",
" state.detach_()\n",
" else:\n",
" # state对于nn.LSTM或对于我们从零开始实现的模型是个张量\n",
" for s in state:\n",
" s.detach_()\n",
" y = Y.T.reshape(-1)\n",
" X, y = X.to(device), y.to(device)\n",
" y_hat, state = net(X, state)\n",
" l = loss(y_hat, y.long()).mean()\n",
" if isinstance(updater, torch.optim.Optimizer):\n",
" updater.zero_grad()\n",
" l.backward()\n",
" grad_clipping(net, 1)\n",
" updater.step()\n",
" else:\n",
" l.backward()\n",
" grad_clipping(net, 1)\n",
" # 因为已经调用了mean函数\n",
" updater(batch_size=1)\n",
" metric.add(l * y.numel(), y.numel())\n",
" return math.exp(metric[0]/metric[1]),metric[1]/timer.stop()\n"
],
"id": "37ef611dedcb714b",
"outputs": [],
2026-04-22 07:23:35 +00:00
"execution_count": 115
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:27.028276769Z",
"start_time": "2026-04-22T07:03:26.966889257Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": [
"def train_ch8(net, train_iter, vocab, lr, num_epochs, device,\n",
"use_random_iter=False):\n",
" \"\"\"训练模型( 定义见第8章) \"\"\"\n",
" loss = nn.CrossEntropyLoss()\n",
" animator = d2l.Animator(xlabel='epoch', ylabel='perplexity',\n",
" legend=['train'], xlim=[10, num_epochs])\n",
" # 初始化\n",
" if isinstance(net, nn.Module):\n",
" updater = torch.optim.SGD(net.parameters(), lr)\n",
" else:\n",
" updater = lambda batch_size: d2l.sgd(net.params, lr, batch_size)\n",
" predict = lambda prefix: predict_ch8(prefix, 50, net, vocab, device)\n",
" # 训练和预测\n",
" for epoch in range(num_epochs):\n",
" ppl, speed = train_epoch_ch8(\n",
" net, train_iter, loss, updater, device, use_random_iter)\n",
" if (epoch + 1) % 10 == 0:\n",
" print(predict('time traveller'))\n",
" animator.add(epoch + 1, [ppl])\n",
" print(f'困惑度 {ppl:.1f}, {speed:.1f} 词元/秒 {str(device)}')\n",
" print(predict('time traveller'))\n",
" print(predict('traveller'))"
],
"id": "c96b60b55664378a",
"outputs": [],
2026-04-22 07:23:35 +00:00
"execution_count": 116
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:27.077729796Z",
"start_time": "2026-04-22T07:03:27.030591268Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": [
"num_epochs, lr = 500, 1\n",
2026-04-01 15:21:14 +00:00
"#train_ch8(net, train_iter, vocab, lr, num_epochs, d2l.try_gpu())"
2026-03-25 15:07:28 +00:00
],
"id": "ab4a2fbf4dfd21ef",
2026-04-01 15:21:14 +00:00
"outputs": [],
2026-04-22 07:23:35 +00:00
"execution_count": 117
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:27.135654266Z",
"start_time": "2026-04-22T07:03:27.079991786Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": [
"batch_size, num_steps = 32, 35\n",
"train_iter, vocab = load_data_time_machine(batch_size, num_steps)"
],
"id": "74d672745751714",
"outputs": [],
2026-04-22 07:23:35 +00:00
"execution_count": 118
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:27.595345257Z",
"start_time": "2026-04-22T07:03:27.137667402Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": [
"num_hiddens = 256\n",
"rnn_layer = nn.RNN(len(vocab), num_hiddens)\n",
"state = torch.zeros((1,batch_size,num_hiddens))\n",
"X = torch.rand(size=(num_steps, batch_size, len(vocab)))\n",
"Y, state_new = rnn_layer(X, state)\n",
"Y.shape, state_new.shape"
],
"id": "9694c029c9e657e8",
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([35, 32, 256]), torch.Size([1, 32, 256]))"
]
},
2026-04-22 07:23:35 +00:00
"execution_count": 119,
2026-03-25 15:07:28 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 119
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:27.653466458Z",
"start_time": "2026-04-22T07:03:27.598468205Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": [
"class RNNModel(nn.Module):\n",
" \"\"\"循环神经网络模型\"\"\"\n",
" def __init__(self, rnn_layer, vocab_size, **kwargs):\n",
" super(RNNModel, self).__init__(**kwargs)\n",
" self.rnn = rnn_layer\n",
" self.vocab_size = vocab_size\n",
" self.num_hiddens = self.rnn.hidden_size\n",
" # 如果RNN是双向的( 之后将介绍) , num_directions应该是2, 否则应该是1\n",
" if not self.rnn.bidirectional:\n",
" self.num_directions = 1\n",
" self.linear = nn.Linear(self.num_hiddens, self.vocab_size)\n",
" else:\n",
" self.num_directions = 2\n",
" self.linear = nn.Linear(self.num_hiddens * 2, self.vocab_size)\n",
" def forward(self, inputs, state):\n",
" X = F.one_hot(inputs.T.long(), self.vocab_size)\n",
" X = X.to(torch.float32)\n",
" Y, state = self.rnn(X, state)\n",
" # 全连接层首先将Y的形状改为(时间步数*批量大小,隐藏单元数)\n",
" # 它的输出形状是(时间步数*批量大小,词表大小)。\n",
" output = self.linear(Y.reshape((-1, Y.shape[-1])))\n",
" return output, state\n",
" def begin_state(self, device, batch_size=1):\n",
" if not isinstance(self.rnn, nn.LSTM):\n",
" # nn.GRU以张量作为隐状态\n",
" return torch.zeros((self.num_directions * self.rnn.num_layers,\n",
" batch_size, self.num_hiddens),\n",
" device=device)\n",
" else:\n",
" # nn.LSTM以元组作为隐状态\n",
" return (torch.zeros((\n",
" self.num_directions * self.rnn.num_layers,\n",
" batch_size, self.num_hiddens), device=device),\n",
" torch.zeros((\n",
" self.num_directions * self.rnn.num_layers,\n",
" batch_size, self.num_hiddens), device=device))"
],
"id": "858873034be01538",
"outputs": [],
2026-04-22 07:23:35 +00:00
"execution_count": 120
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:27.712481585Z",
"start_time": "2026-04-22T07:03:27.663190923Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": [
"device = d2l.try_gpu()\n",
"net = RNNModel(rnn_layer, vocab_size=len(vocab))\n",
"net = net.to(device)\n",
2026-04-01 15:21:14 +00:00
"#predict_ch8('time traveller', 10, net, vocab, device)"
2026-03-25 15:07:28 +00:00
],
"id": "d59c1599998c8fd4",
2026-04-01 15:21:14 +00:00
"outputs": [],
2026-04-22 07:23:35 +00:00
"execution_count": 121
2026-04-01 15:21:14 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:27.761031426Z",
"start_time": "2026-04-22T07:03:27.714738716Z"
2026-04-01 15:21:14 +00:00
}
},
"cell_type": "code",
"source": [
"vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()\n",
"num_inputs = vocab_size\n",
"gru_layer = nn.GRU(num_inputs, num_hiddens)\n",
"model = RNNModel(gru_layer, len(vocab))\n",
"model = model.to(device)\n",
"#train_ch8(model, train_iter, vocab, lr, num_epochs, device)"
],
"id": "adda23bc3664ec6b",
"outputs": [],
2026-04-22 07:23:35 +00:00
"execution_count": 122
2026-04-01 15:21:14 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:27.809310056Z",
"start_time": "2026-04-22T07:03:27.763581864Z"
2026-04-01 15:21:14 +00:00
}
},
"cell_type": "code",
"source": [
"num_inputs = vocab_size\n",
"lstm_layer = nn.LSTM(num_inputs, num_hiddens)\n",
"model = RNNModel(lstm_layer, len(vocab))\n",
"model = model.to(device)\n",
"#train_ch8(model, train_iter, vocab, lr, num_epochs, device)"
],
"id": "b4e30d643d6f755d",
"outputs": [],
2026-04-22 07:23:35 +00:00
"execution_count": 123
2026-04-01 15:21:14 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:27.856847550Z",
"start_time": "2026-04-22T07:03:27.811694587Z"
2026-04-01 15:21:14 +00:00
}
},
"cell_type": "code",
"source": [
"d2l.DATA_HUB['fra-eng'] = (d2l.DATA_URL + 'fra-eng.zip',\n",
" '94646ad1522d915e7b0f9296181140edcf86a4f5')"
],
"id": "50554e839be36011",
"outputs": [],
2026-04-22 07:23:35 +00:00
"execution_count": 124
2026-04-01 15:21:14 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:27.906829883Z",
"start_time": "2026-04-22T07:03:27.859224691Z"
2026-04-01 15:21:14 +00:00
}
},
"cell_type": "code",
"source": [
"import os\n",
"def read_data_nmt():\n",
" \"\"\"载入“英语-法语”数据集\"\"\"\n",
" data_dir = d2l.download_extract('fra-eng')\n",
2026-04-22 07:23:35 +00:00
" print(data_dir)\n",
2026-04-01 15:21:14 +00:00
" with open(os.path.join(data_dir, 'fra.txt'), 'r',\n",
" encoding='utf-8') as f:\n",
" return f.read()"
],
"id": "9cd4287ed84db220",
"outputs": [],
2026-04-22 07:23:35 +00:00
"execution_count": 125
2026-04-01 15:21:14 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:27.980494182Z",
"start_time": "2026-04-22T07:03:27.909008466Z"
2026-04-01 15:21:14 +00:00
}
},
"cell_type": "code",
"source": [
"raw_text = read_data_nmt()\n",
"print(raw_text[:75])"
],
"id": "7c4452b3b6a32f91",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2026-04-22 07:23:35 +00:00
"../data/fra-eng\n",
2026-04-01 15:21:14 +00:00
"Go.\tVa !\n",
"Hi.\tSalut !\n",
"Run!\tCours !\n",
"Run!\tCourez !\n",
"Who?\tQui ?\n",
"Wow!\tÇa alors !\n",
"\n"
]
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 126
2026-04-01 15:21:14 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:29.415816650Z",
"start_time": "2026-04-22T07:03:27.981645004Z"
2026-04-01 15:21:14 +00:00
}
},
"cell_type": "code",
"source": [
"def preprocess_nmt(text):\n",
" def no_space(char,prev_char):\n",
" return char in set(',.!?') and prev_char != ' '\n",
" text = text.replace('\\u202f',' ').replace('\\xa0',' ').lower()\n",
" out = [' ' + char if i >0 and no_space(char,text[i-1]) else char for i,char in enumerate(text)]\n",
" return ''.join(out)\n",
"text = preprocess_nmt(raw_text)\n",
"print(text[:80])"
],
"id": "1c729da265572287",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"go .\tva !\n",
"hi .\tsalut !\n",
"run !\tcours !\n",
"run !\tcourez !\n",
"who ?\tqui ?\n",
"wow !\tça alors !\n"
]
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 127
2026-04-01 15:21:14 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:30.160936469Z",
"start_time": "2026-04-22T07:03:30.102689660Z"
2026-04-01 15:21:14 +00:00
}
},
"cell_type": "code",
"source": [
"def tokenize_nmt(text,num_examples=None):\n",
" source,target = [],[]\n",
" for i,line in enumerate(text.split('\\n')):\n",
" if num_examples and i > num_examples:\n",
" break\n",
" parts = line.split('\\t')\n",
" if len(parts) == 2:\n",
" source.append(parts[0].split(' '))\n",
" target.append(parts[1].split(' '))\n",
" return source,target\n",
"\n"
],
"id": "ca16ef22cbe2c02a",
"outputs": [],
2026-04-22 07:23:35 +00:00
"execution_count": 128
2026-04-01 15:21:14 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:30.821851145Z",
"start_time": "2026-04-22T07:03:30.165447701Z"
2026-04-01 15:21:14 +00:00
}
},
"cell_type": "code",
"source": [
"source, target = tokenize_nmt(text)\n",
2026-04-22 07:23:35 +00:00
"source[:5], target[:5]"
2026-04-01 15:21:14 +00:00
],
"id": "5ece5cb4b78168d0",
2026-03-25 15:07:28 +00:00
"outputs": [
{
"data": {
"text/plain": [
2026-04-22 07:23:35 +00:00
"([['go', '.'], ['hi', '.'], ['run', '!'], ['run', '!'], ['who', '?']],\n",
" [['va', '!'], ['salut', '!'], ['cours', '!'], ['courez', '!'], ['qui', '?']])"
2026-03-25 15:07:28 +00:00
]
},
2026-04-22 07:23:35 +00:00
"execution_count": 129,
2026-03-25 15:07:28 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 129
2026-03-25 15:07:28 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:31.001500702Z",
"start_time": "2026-04-22T07:03:30.869254698Z"
2026-03-25 15:07:28 +00:00
}
},
"cell_type": "code",
"source": [
2026-04-01 15:21:14 +00:00
"def show_list_len_pair_hist(legend,xlabel,ylabel,xlist,ylist):\n",
" d2l.set_figsize()\n",
" _,_,patches = d2l.plt.hist([[len(l) for l in xlist],[len(l) for l in ylist]])\n",
" d2l.plt.xlabel(xlabel)\n",
" d2l.plt.ylabel(ylabel)\n",
" for patch in patches[1].patches:\n",
" patch.set_hatch('/')\n",
" d2l.plt.legend(legend)\n",
"\n",
"show_list_len_pair_hist(['source','target'],'# tokens per sequence','count',source,target)"
2026-03-25 15:07:28 +00:00
],
2026-04-01 15:21:14 +00:00
"id": "518249f852ec54c4",
2026-03-25 15:07:28 +00:00
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 350x250 with 1 Axes>"
],
2026-04-22 07:23:35 +00:00
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"274.320356pt\" height=\"183.35625pt\" viewBox=\"0 0 274.320356 183.35625\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n <metadata>\n <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n <cc:Work>\n <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n <dc:date>2026-04-22T15:03:30.973116</dc:date>\n <dc:format>image/svg+xml</dc:format>\n <dc:creator>\n <cc:Agent>\n <dc:title>Matplotlib v3.7.2, https://matplotlib.org/</dc:title>\n </cc:Agent>\n </dc:creator>\n </cc:Work>\n </rdf:RDF>\n </metadata>\n <defs>\n <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n </defs>\n <g id=\"figure_1\">\n <g id=\"patch_1\">\n <path d=\"M 0 183.35625 \nL 274.320356 183.35625 \nL 274.320356 0 \nL 0 0 \nz\n\"/>\n </g>\n <g id=\"axes_1\">\n <g id=\"patch_2\">\n <path d=\"M 66.053125 145.8 \nL 261.353125 145.8 \nL 261.353125 7.2 \nL 66.053125 7.2 \nz\n\"/>\n </g>\n <g id=\"patch_3\">\n <path d=\"M 74.930398 145.8 \nL 82.177151 145.8 \nL 82.177151 13.8 \nL 74.930398 13.8 \nz\n\" clip-path=\"url(#p080f3d0d15)\" style=\"fill: #8dd3c7\"/>\n </g>\n <g id=\"patch_4\">\n <path d=\"M 93.047281 145.8 \nL 100.294034 145.8 \nL 100.294034 70.342894 \nL 93.047281 70.342894 \nz\n\" clip-path=\"url(#p080f3d0d15)\" style=\"fill: #8dd3c7\"/>\n </g>\n <g id=\"patch_5\">\n <path d=\"M 111.164164 145.8 \nL 118.410917 145.8 \nL 118.410917 141.170363 \nL 111.164164 141.170363 \nz\n\" clip-path=\"url(#p080f3d0d15)\" style=\"fill: #8dd3c7\"/>\n </g>\n <g id=\"patch_6\">\n <path d=\"M 129.281047 145.8 \nL 136.5278 145.8 \nL 136.5278 145.34327 \nL 129.281047 145.34327 \nz\n\" clip-path=\"url(#p080f3d0d15)\" 
style=\"fill: #8dd3c7\"/>\n </g>\n <g id=\"patch_7\">\n <path d=\"M 147.39793 145.8 \nL 154.644683 145.8 \nL 154.644683 145.744022 \nL 147.39793 145.744022 \nz\n\" clip-path=\"url(#p080f3d0d15)\" style=\"fill: #8dd3c7\"/>\n </g>\n <g id=\"patch_8\">\n <path d=\"M 165.514813 145.8 \nL 172.761567 145.8 \nL 172.761567 145.783461 \nL 165.514813 145.783461 \nz\n\" clip-path=\"url(#p080f3d0d15)\" style=\"fill: #8dd3c7\"/>\n </g>\n <g id=\"patch_9\">\n <path d=\"M 183.631696 145.8 \nL 190.87845 145.8 \nL 190.87845 145.792367 \nL 183.631696 145.792367 \nz\n\" clip-path=\"url(#p080f3d0d15)\" style=\"fill: #8dd3c7\"/>\n </g>\n <g id=\"patch_10\">\n <path d=\"M 201.74858 145.8 \nL 208.995333 145.8 \nL 208.995333 145.798728 \nL 201.74858 145.798728 \nz\n\" clip-path=\"url(#p080f3d0d15)\" style=\"fill: #8dd3c7\"/>\n </g>\n <g id=\"patch_11\">\n <path d=\"M 219.865463 145.8 \nL 227.112216 145.8 \nL 227.112216 145.797456 \nL 219.865463 145.797456 \nz\n\" clip-path=\"url(#p080f3d0d15)\" style=\"fill: #8dd3c7\"/>\n </g>\n <g id=\"patch_12\">\n <path d=\"M 237.982346 145.8 \nL 245.229099 145.8 \nL 245.229099 145.8 \nL 237.982346 145.8 \nz\n\" clip-path=\"url(#p080f3d0d15)\" style=\"fill: #8dd3c7\"/>\n </g>\n <g id=\"patch_13\">\n <path d=\"M 82.177151 145.8 \nL 89.423904 145.8 \nL 89.423904 26.779268 \nL 82.177151 26.779268 \nz\n\" clip-path=\"url(#p080f3d0d15)\" style=\"fill: url(#h6a8879fef6)\"/>\n </g>\n <g id=\"patch_14\">\n <path d=\"M 100.294034 145.8 \nL 107.540787 145.8 \nL 107.540787 60.492034 \nL 100.294034 60.492034 \nz\n\" clip-path=\"url(#p080f3d0d15)\" style=\"fill: url(#h6a8879fef6)\"/>\n </g>\n <g id=\"patch_15\">\n <path d=\"M 118.410917 145.8 \nL 125.65767 145.8 \nL 125.65767 138.604279 \nL 118.410917 138.604279 \nz\n\" clip-path=\"url(#p080f3d0d15)\" style=\"fill: url(#h6a8879fef6)\"/>\n </g>\n <g id=\"patch_16\">\n <path d=\"M 136.5278 145.8 \nL 143.774554 145.8 \nL 143.774554 144.858551 \nL 136.5278 144.858551 \nz
2026-03-25 15:07:28 +00:00
},
"metadata": {},
"output_type": "display_data",
"jetTransient": {
"display_id": null
}
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 130
2026-04-01 15:21:14 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:31.126122101Z",
"start_time": "2026-04-22T07:03:31.016562308Z"
2026-04-01 15:21:14 +00:00
}
},
"cell_type": "code",
"source": [
"src_vocab=Vocab(source,min_freq=2,reserved_tokens=['<pad>','<bos>','<eos>'])\n",
"len(src_vocab)"
],
"id": "c2dc82617d5a41a4",
"outputs": [
{
"data": {
"text/plain": [
"10012"
]
},
2026-04-22 07:23:35 +00:00
"execution_count": 131,
2026-04-01 15:21:14 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 131
2026-04-01 15:21:14 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:31.181454267Z",
"start_time": "2026-04-22T07:03:31.127995134Z"
2026-04-01 15:21:14 +00:00
}
},
"cell_type": "code",
"source": [
"def truncate_pad(line,num_steps,padding_token):\n",
" if len(line) > num_steps:\n",
" return line[:num_steps]\n",
" return line + [padding_token] * (num_steps - len(line))\n",
"truncate_pad(src_vocab[source[0]], 10, src_vocab['<pad>'])"
],
"id": "93ae326a3258ecc",
"outputs": [
{
"data": {
"text/plain": [
"[47, 4, 1, 1, 1, 1, 1, 1, 1, 1]"
]
},
2026-04-22 07:23:35 +00:00
"execution_count": 132,
2026-04-01 15:21:14 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 132
2026-04-01 15:21:14 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:31.244013357Z",
"start_time": "2026-04-22T07:03:31.194631134Z"
2026-04-01 15:21:14 +00:00
}
},
"cell_type": "code",
"source": [
"def build_array_nmt(lines, vocab, num_steps):\n",
" \"\"\"将机器翻译的文本序列转换成小批量\"\"\"\n",
" lines = [vocab[l] for l in lines]\n",
" lines = [l + [vocab['<eos>']] for l in lines]\n",
" array = torch.tensor([truncate_pad(\n",
" l, num_steps, vocab['<pad>']) for l in lines])\n",
" valid_len = (array != vocab['<pad>']).type(torch.int32).sum(1)\n",
" return array, valid_len"
],
"id": "acd4344e678cf487",
"outputs": [],
2026-04-22 07:23:35 +00:00
"execution_count": 133
2026-04-01 15:21:14 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:31.295907069Z",
"start_time": "2026-04-22T07:03:31.246433021Z"
2026-04-01 15:21:14 +00:00
}
},
"cell_type": "code",
"source": [
"def load_data_nmt(batch_size, num_steps, num_examples=600):\n",
" \"\"\"返回翻译数据集的迭代器和词表\"\"\"\n",
" text = preprocess_nmt(read_data_nmt())\n",
" source, target = tokenize_nmt(text, num_examples)\n",
" src_vocab = d2l.Vocab(source, min_freq=2,\n",
" reserved_tokens=['<pad>', '<bos>', '<eos>'])\n",
" tgt_vocab = d2l.Vocab(target, min_freq=2,\n",
" reserved_tokens=['<pad>', '<bos>', '<eos>'])\n",
" src_array, src_valid_len = build_array_nmt(source, src_vocab, num_steps)\n",
" tgt_array, tgt_valid_len = build_array_nmt(target, tgt_vocab, num_steps)\n",
" data_arrays = (src_array, src_valid_len, tgt_array, tgt_valid_len)\n",
" data_iter = d2l.load_array(data_arrays, batch_size)\n",
" return data_iter, src_vocab, tgt_vocab"
],
"id": "62586b0175993a4f",
"outputs": [],
2026-04-22 07:23:35 +00:00
"execution_count": 134
2026-04-01 15:21:14 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:32.568811153Z",
"start_time": "2026-04-22T07:03:31.298284240Z"
2026-04-01 15:21:14 +00:00
}
},
"cell_type": "code",
"source": [
"train_iter, src_vocab, tgt_vocab = load_data_nmt(batch_size=2, num_steps=8)\n",
"for X, X_valid_len, Y, Y_valid_len in train_iter:\n",
" print('X:', X.type(torch.int32))\n",
" print('X的有效长度:', X_valid_len)\n",
" print('Y:', Y.type(torch.int32))\n",
" print('Y的有效长度:', Y_valid_len)\n",
" break"
],
"id": "87a2f147db41a91d",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2026-04-22 07:23:35 +00:00
"../data/fra-eng\n",
"X: tensor([[ 83, 163, 2, 4, 5, 5, 5, 5],\n",
" [ 29, 69, 2, 4, 5, 5, 5, 5]], dtype=torch.int32)\n",
"X的有效长度: tensor([4, 4])\n",
"Y: tensor([[100, 171, 6, 2, 4, 5, 5, 5],\n",
" [191, 6, 2, 4, 5, 5, 5, 5]], dtype=torch.int32)\n",
"Y的有效长度: tensor([5, 4])\n"
2026-04-01 15:21:14 +00:00
]
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 135
2026-04-01 15:21:14 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:32.679822180Z",
"start_time": "2026-04-22T07:03:32.616779320Z"
2026-04-01 15:21:14 +00:00
}
},
"cell_type": "code",
"source": [
"class Encoder(nn.Module):\n",
" def __init__(self,**kargs):\n",
" super(Encoder,self).__init__(**kargs)\n",
" def forward(self,X,*args):\n",
" raise NotImplementedError(\"必须实现这个方法\")\n",
"class Decoder(nn.Module):\n",
" \"\"\"编码器-解码器架构的基本解码器接口\"\"\"\n",
" def __init__(self, **kwargs):\n",
" super(Decoder, self).__init__(**kwargs)\n",
" def init_state(self, enc_outputs, *args):\n",
" raise NotImplementedError\n",
" def forward(self, X, state):\n",
" raise NotImplementedError\n",
"class EncoderDecoder(nn.Module):\n",
" \"\"\"编码器-解码器架构的基类\"\"\"\n",
" def __init__(self, encoder, decoder, **kwargs):\n",
" super(EncoderDecoder, self).__init__(**kwargs)\n",
" self.encoder = encoder\n",
" self.decoder = decoder\n",
" def forward(self, enc_X, dec_X, *args):\n",
" enc_outputs = self.encoder(enc_X, *args)\n",
" dec_state = self.decoder.init_state(enc_outputs, *args)\n",
" return self.decoder(dec_X, dec_state)\n",
"class Seq2SeqEncoder(Encoder):\n",
" def __init__(self,vocab_size,embed_size,num_hiddens,num_layers,dropout=0,**kwargs):\n",
" super(Seq2SeqEncoder, self).__init__(**kwargs)\n",
" self.embedding = nn.Embedding(vocab_size, embed_size)\n",
" self.rnn = nn.GRU(embed_size,num_hiddens,num_layers,dropout=dropout)\n",
" def forward(self,X,*args):\n",
" X = self.embedding(X)\n",
" X = X.permute(1,0,2) #(batch,steps,embed_size) -> (steps,batch,embed_size)\n",
" output,state = self.rnn(X)\n",
" return output,state\n",
" # shape of output (steps,batch,num_hiddens)\n",
" # shape of state (num_layers,batch,num_hiddens)\n",
"encoder = Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16,\n",
" num_layers=2)\n",
"X = torch.zeros((4,7),dtype=torch.long) #one-hot is integer\n",
"output,state = encoder(X)\n",
"output.shape\n"
],
"id": "d0d01aef4857ee9c",
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([7, 4, 16])"
]
},
2026-04-22 07:23:35 +00:00
"execution_count": 136,
2026-04-01 15:21:14 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 136
2026-04-01 15:21:14 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:32.745403173Z",
"start_time": "2026-04-22T07:03:32.681269832Z"
2026-04-01 15:21:14 +00:00
}
},
"cell_type": "code",
"source": "state.shape",
"id": "bba15a040c10cb01",
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([2, 4, 16])"
]
},
2026-04-22 07:23:35 +00:00
"execution_count": 137,
2026-04-01 15:21:14 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 137
2026-04-01 15:21:14 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:32.800607701Z",
"start_time": "2026-04-22T07:03:32.749059569Z"
2026-04-01 15:21:14 +00:00
}
},
"cell_type": "code",
"source": [
"class Seq2SeqDecoder(Decoder):\n",
" def __init__(self,vocab_size,embed_size,num_hiddens,num_layers,dropout=0,**kwargs):\n",
" super(Seq2SeqDecoder, self).__init__(**kwargs)\n",
" self.embedding = nn.Embedding(vocab_size, embed_size)\n",
" self.rnn = nn.GRU(embed_size+num_hiddens,num_hiddens,num_layers,dropout=dropout)\n",
" self.dense = nn.Linear(num_hiddens,vocab_size)\n",
" def init_state(self,enc_outputs,*args):\n",
" return enc_outputs[1]\n",
" def forward(self,X,state):\n",
" X = self.embedding(X).permute(1,0,2)\n",
" context = state[-1].repeat(X.shape[0],1,1)\n",
" X_and_context = torch.cat((X,context),2)\n",
" output,state = self.rnn(X_and_context,state)\n",
" output = self.dense(output).permute(1,0,2)\n",
" return output,state"
],
"id": "b659bfd2fdcabebe",
"outputs": [],
2026-04-22 07:23:35 +00:00
"execution_count": 138
2026-04-01 15:21:14 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:32.863556712Z",
"start_time": "2026-04-22T07:03:32.801449588Z"
2026-04-01 15:21:14 +00:00
}
},
"cell_type": "code",
"source": [
"decoder = Seq2SeqDecoder(vocab_size=10, embed_size=8, num_hiddens=16,\n",
"num_layers=2)\n",
"decoder.eval()\n",
"state = decoder.init_state(encoder(X))\n",
"output, state = decoder(X, state)\n",
"output.shape, state.shape"
],
"id": "e9c451e560ce3769",
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([4, 7, 10]), torch.Size([2, 4, 16]))"
]
},
2026-04-22 07:23:35 +00:00
"execution_count": 139,
2026-04-01 15:21:14 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 139
2026-04-01 15:21:14 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:32.926789294Z",
"start_time": "2026-04-22T07:03:32.864974421Z"
2026-04-01 15:21:14 +00:00
}
},
"cell_type": "code",
"source": [
"def sequence_mask(X, valid_len, value=0):\n",
" \"\"\"在序列中屏蔽不相关的项\"\"\"\n",
" maxlen = X.size(1)\n",
" mask = torch.arange((maxlen), dtype=torch.float32,\n",
" device=X.device)[None, :] < valid_len[:, None]\n",
" X[~mask] = value\n",
" return X\n",
"X = torch.tensor([[1, 2, 3], [4, 5, 6]])\n",
"sequence_mask(X, torch.tensor([1, 2]))"
],
"id": "9ee2db877c089d48",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[1, 0, 0],\n",
" [4, 5, 0]])"
]
},
2026-04-22 07:23:35 +00:00
"execution_count": 140,
2026-04-01 15:21:14 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 140
2026-04-01 15:21:14 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:33.055040504Z",
"start_time": "2026-04-22T07:03:32.985613772Z"
2026-04-01 15:21:14 +00:00
}
},
"cell_type": "code",
"source": [
"X = torch.ones(2, 3, 4)\n",
"sequence_mask(X, torch.tensor([1, 2]), value=-1)"
],
"id": "caca450bdfe7f650",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[ 1., 1., 1., 1.],\n",
" [-1., -1., -1., -1.],\n",
" [-1., -1., -1., -1.]],\n",
"\n",
" [[ 1., 1., 1., 1.],\n",
" [ 1., 1., 1., 1.],\n",
" [-1., -1., -1., -1.]]])"
]
},
2026-04-22 07:23:35 +00:00
"execution_count": 141,
2026-04-01 15:21:14 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 141
2026-04-01 15:21:14 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-04-22 07:23:35 +00:00
"end_time": "2026-04-22T07:03:33.183816304Z",
"start_time": "2026-04-22T07:03:33.114715644Z"
2026-04-01 15:21:14 +00:00
}
},
"cell_type": "code",
"source": [
"class MaskedSoftmaxCELoss(nn.CrossEntropyLoss):\n",
" \"\"\"带遮蔽的softmax交叉熵损失函数\"\"\"\n",
" # pred的形状: (batch_size,num_steps,vocab_size)\n",
" # label的形状: (batch_size,num_steps)\n",
" # valid_len的形状: (batch_size,)\n",
" def forward(self, pred, label, valid_len):\n",
" weights = torch.ones_like(label)\n",
" weights = sequence_mask(weights, valid_len)\n",
" self.reduction='none'\n",
" unweighted_loss = super(MaskedSoftmaxCELoss, self).forward(\n",
" pred.permute(0, 2, 1), label)\n",
" weighted_loss = (unweighted_loss * weights).mean(dim=1)\n",
" return weighted_loss\n",
"loss = MaskedSoftmaxCELoss()\n",
"loss(torch.ones(3, 4, 10), torch.ones((3, 4), dtype=torch.long),\n",
"torch.tensor([4, 2, 0]))"
],
"id": "46fc96f0246f32b7",
"outputs": [
{
"data": {
"text/plain": [
"tensor([2.3026, 1.1513, 0.0000])"
]
},
2026-04-22 07:23:35 +00:00
"execution_count": 142,
2026-04-01 15:21:14 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-04-22 07:23:35 +00:00
"execution_count": 142
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-04-22T07:11:21.961865073Z",
"start_time": "2026-04-22T07:11:21.895381436Z"
}
},
"cell_type": "code",
"source": [
"def train_seq2seq(net, data_iter, lr, num_epochs, tgt_vocab, device):\n",
" \"\"\"训练序列到序列模型\"\"\"\n",
" def xavier_init_weights(m):\n",
" if type(m) == nn.Linear:\n",
" nn.init.xavier_uniform_(m.weight)\n",
" if type(m) == nn.GRU:\n",
" for param in m._flat_weights_names:\n",
" if \"weight\" in param:\n",
" nn.init.xavier_uniform_(m._parameters[param])\n",
" net.apply(xavier_init_weights)\n",
" net.to(device)\n",
" optimizer = torch.optim.Adam(net.parameters(), lr=lr)\n",
" loss = MaskedSoftmaxCELoss()\n",
" net.train()\n",
" animator = d2l.Animator(xlabel='epoch', ylabel='loss',\n",
" xlim=[10, num_epochs])\n",
" for epoch in range(num_epochs):\n",
" timer = d2l.Timer()\n",
" metric = d2l.Accumulator(2) # 训练损失总和,词元数量\n",
" for batch in data_iter:\n",
" optimizer.zero_grad()\n",
" X, X_valid_len, Y, Y_valid_len = [x.to(device) for x in batch]\n",
" bos = torch.tensor([tgt_vocab['<bos>']] * Y.shape[0],\n",
" device=device).reshape(-1, 1)\n",
" dec_input = torch.cat([bos, Y[:, :-1]], 1) # 强制教学\n",
" Y_hat, _ = net(X, dec_input, X_valid_len)\n",
" l = loss(Y_hat, Y, Y_valid_len)\n",
" l.sum().backward() # 损失函数的标量进行“反向传播”\n",
" d2l.grad_clipping(net, 1)\n",
" num_tokens = Y_valid_len.sum()\n",
" optimizer.step()\n",
" with torch.no_grad():\n",
" metric.add(l.sum(), num_tokens)\n",
" if (epoch + 1) % 10 == 0:\n",
" animator.add(epoch + 1, (metric[0] / metric[1],))"
],
"id": "69c315b5875fc288",
"outputs": [],
"execution_count": 153
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-04-22T07:13:26.866073447Z",
"start_time": "2026-04-22T07:11:51.132170814Z"
}
},
"cell_type": "code",
"source": [
"embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1\n",
"batch_size, num_steps = 64, 10\n",
"lr, num_epochs, device = 0.005, 300, d2l.try_gpu()\n",
"train_iter, src_vocab, tgt_vocab = load_data_nmt(batch_size, num_steps)\n",
"encoder = Seq2SeqEncoder(len(src_vocab), embed_size, num_hiddens, num_layers,\n",
"dropout)\n",
"decoder = Seq2SeqDecoder(len(tgt_vocab), embed_size, num_hiddens, num_layers,\n",
"dropout)\n",
"net = EncoderDecoder(encoder, decoder)\n",
"train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)"
],
"id": "58e54d7b6b77205d",
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 350x250 with 1 Axes>"
],
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"262.1875pt\" height=\"183.35625pt\" viewBox=\"0 0 262.1875 183.35625\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n <metadata>\n <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n <cc:Work>\n <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n <dc:date>2026-04-22T15:13:26.821924</dc:date>\n <dc:format>image/svg+xml</dc:format>\n <dc:creator>\n <cc:Agent>\n <dc:title>Matplotlib v3.7.2, https://matplotlib.org/</dc:title>\n </cc:Agent>\n </dc:creator>\n </cc:Work>\n </rdf:RDF>\n </metadata>\n <defs>\n <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n </defs>\n <g id=\"figure_1\">\n <g id=\"patch_1\">\n <path d=\"M 0 183.35625 \nL 262.1875 183.35625 \nL 262.1875 0 \nL 0 0 \nz\n\"/>\n </g>\n <g id=\"axes_1\">\n <g id=\"patch_2\">\n <path d=\"M 50.14375 145.8 \nL 245.44375 145.8 \nL 245.44375 7.2 \nL 50.14375 7.2 \nz\n\"/>\n </g>\n <g id=\"matplotlib.axis_1\">\n <g id=\"xtick_1\">\n <g id=\"line2d_1\">\n <path d=\"M 77.081681 145.8 \nL 77.081681 7.2 \n\" clip-path=\"url(#p033e6c5bf4)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_2\">\n <defs>\n <path id=\"mcb3ceca69b\" d=\"M 0 0 \nL 0 3.5 \n\" style=\"stroke: #ffffff; stroke-width: 0.8\"/>\n </defs>\n <g>\n <use xlink:href=\"#mcb3ceca69b\" x=\"77.081681\" y=\"145.8\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_1\">\n <!-- 50 -->\n <g style=\"fill: #ffffff\" transform=\"translate(70.719181 160.398438) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-35\" d=\"M 691 4666 \nL 3169 4666 \nL 3169 4134 \nL 1269 
4134 \nL 1269 2991 \nQ 1406 3038 1543 3061 \nQ 1681 3084 1819 3084 \nQ 2600 3084 3056 2656 \nQ 3513 2228 3513 1497 \nQ 3513 744 3044 326 \nQ 2575 -91 1722 -91 \nQ 1428 -91 1123 -41 \nQ 819 9 494 109 \nL 494 744 \nQ 775 591 1075 516 \nQ 1375 441 1709 441 \nQ 2250 441 2565 725 \nQ 2881 1009 2881 1497 \nQ 2881 1984 2565 2268 \nQ 2250 2553 1709 2553 \nQ 1456 2553 1204 2497 \nQ 953 2441 691 2322 \nL 691 4666 \nz\n\" transform=\"scale(0.015625)\"/>\n <path id=\"DejaVuSans-30\" d=\"M 2034 4250 \nQ 1547 4250 1301 3770 \nQ 1056 3291 1056 2328 \nQ 1056 1369 1301 889 \nQ 1547 409 2034 409 \nQ 2525 409 2770 889 \nQ 3016 1369 3016 2328 \nQ 3016 3291 2770 3770 \nQ 2525 4250 2034 4250 \nz\nM 2034 4750 \nQ 2819 4750 3233 4129 \nQ 3647 3509 3647 2328 \nQ 3647 1150 3233 529 \nQ 2819 -91 2034 -91 \nQ 1250 -91 836 529 \nQ 422 1150 422 2328 \nQ 422 3509 836 4129 \nQ 1250 4750 2034 4750 \nz\n\" transform=\"scale(0.015625)\"/>\n </defs>\n <use xlink:href=\"#DejaVuSans-35\"/>\n <use xlink:href=\"#DejaVuSans-30\" x=\"63.623047\"/>\n </g>\n </g>\n </g>\n <g id=\"xtick_2\">\n <g id=\"line2d_3\">\n <path d=\"M 110.754095 145.8 \nL 110.754095 7.2 \n\" clip-path=\"url(#p033e6c5bf4)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_4\">\n <g>\n <use xlink:href=\"#mcb3ceca69b\" x=\"110.754095\" y=\"145.8\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_2\">\n <!-- 100 -->\n <g style=\"fill: #ffffff\" transform=\"translate(101.210345 160.398438) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-31\" d=\"M 794 531 \nL 1825 531 \nL 1825 4091 \nL 703 3866 \nL 703 4441 \nL 1819 4666 \nL 2450 4666 \nL 2450 531 \nL 3481 531 \nL 3481 0 \nL 794 0 \nL 794 531 \nz\n\" transform=\"scale(0.015625)\"/>\n </defs>\n <use xlink:href=\"#DejaVuSans-31\"/>\
},
"metadata": {},
"output_type": "display_data",
"jetTransient": {
"display_id": null
}
}
],
"execution_count": 155
2026-03-25 15:07:28 +00:00
},
2026-03-22 08:28:55 +00:00
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"source": "",
2026-04-22 07:23:35 +00:00
"id": "15f5b277bf8d51ed"
2026-03-22 08:28:55 +00:00
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}