{ "cells": [ { "cell_type": "code", "id": "initial_id", "metadata": { "collapsed": true, "ExecuteTime": { "end_time": "2026-03-29T09:05:27.116778011Z", "start_time": "2026-03-29T09:05:25.689631294Z" } }, "source": [ "\n", "import torch\n", "import d2l\n", "import numpy\n", "import torch.nn as nn\n", "import torch.nn.functional as F" ], "outputs": [], "execution_count": 1 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:27.211229222Z", "start_time": "2026-03-29T09:05:27.119198350Z" } }, "cell_type": "code", "source": [ "net = nn.Sequential(nn.Linear(20, 256), nn.ReLU(), nn.Linear(256, 10))\n", "X = torch.rand(2, 20)\n", "net(X)" ], "id": "dcd5590e7795eec1", "outputs": [ { "data": { "text/plain": [ "tensor([[-0.0620, 0.3494, -0.2706, 0.0934, 0.1624, 0.0507, 0.0276, 0.1200,\n", " 0.1418, 0.1422],\n", " [-0.2129, 0.3243, -0.3727, -0.0147, 0.1620, -0.0534, 0.0918, 0.0569,\n", " 0.0515, -0.0515]], grad_fn=)" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 2 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:28.023879514Z", "start_time": "2026-03-29T09:05:27.503008526Z" } }, "cell_type": "code", "source": [ "class MLP(nn.Module):\n", " def __init__(self):\n", " super().__init__()\n", " self.hidden=nn.Linear(20,256)\n", " self.out=nn.Linear(256,10)\n", " def forward(self,X):\n", " return self.out(F.relu(self.hidden(X)))\n" ], "id": "4ae330604b643cb4", "outputs": [], "execution_count": 3 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:28.598152494Z", "start_time": "2026-03-29T09:05:28.312949801Z" } }, "cell_type": "code", "source": [ "net=MLP()\n", "net(X)" ], "id": "cca55c6c0c7da12f", "outputs": [ { "data": { "text/plain": [ "tensor([[-0.0426, -0.0041, 0.0686, 0.0151, -0.0754, 0.0269, 0.2757, 0.0227,\n", " 0.2260, 0.0424],\n", " [ 0.0319, 0.0394, 0.0179, 0.0704, -0.1369, -0.0294, 0.2276, 0.0702,\n", " 0.1313, 0.2124]], grad_fn=)" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 4 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:28.833430055Z", "start_time": "2026-03-29T09:05:28.633643058Z" } }, "cell_type": "code", "source": [ "class FixedHiddenMLP(nn.Module):\n", " def __init__(self):\n", " super().__init__()\n", " # 不计算梯度的随机权重参数。因此其在训练期间保持不变\n", " self.rand_weight = torch.rand((20, 20), requires_grad=False)\n", " self.linear = nn.Linear(20, 20)\n", " def forward(self, X):\n", " X = self.linear(X)\n", " # 使用创建的常量参数以及relu和mm函数\n", " X = F.relu(torch.mm(X, self.rand_weight) + 1)\n", " # 复用全连接层。这相当于两个全连接层共享参数\n", " X = self.linear(X)\n", " # 控制流\n", " while X.abs().sum() > 1:\n", " X /= 2\n", " return X.sum()" ], "id": "4518d62611d5e749", "outputs": [], "execution_count": 5 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:28.918094363Z", "start_time": "2026-03-29T09:05:28.865853300Z" } }, "cell_type": "code", "source": [ "net = FixedHiddenMLP()\n", "net(X)" ], "id": "fae0187ece4ed5c6", "outputs": [ { "data": { "text/plain": [ "tensor(-0.0847, grad_fn=)" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 6 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:29.031399095Z", "start_time": "2026-03-29T09:05:28.937762813Z" } }, "cell_type": "code", "source": [ "class NestMLP(nn.Module):\n", " def __init__(self):\n", " super().__init__()\n", " self.net = nn.Sequential(nn.Linear(20, 64), nn.ReLU(),\n", " nn.Linear(64, 32), nn.ReLU())\n", " self.linear = nn.Linear(32, 16)\n", " def forward(self, X):\n", " return self.linear(self.net(X))\n", " chimera = nn.Sequential(NestMLP(), nn.Linear(16, 20), FixedHiddenMLP())\n", " chimera(X)" ], "id": "407ef13a86453aae", "outputs": [], "execution_count": 7 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:29.115640738Z", "start_time": "2026-03-29T09:05:29.033684765Z" } }, "cell_type": "code", "source": [ "net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))\n", "X = torch.rand(size=(2, 4))\n", "net(X)" ], "id": "9f3526f263c7a249", "outputs": [ { "data": { "text/plain": [ "tensor([[-0.5641],\n", " [-0.5857]], grad_fn=)" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 8 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:29.228591089Z", "start_time": "2026-03-29T09:05:29.144124890Z" } }, "cell_type": "code", "source": "print(net[2].state_dict())", "id": "8c73f8daa02ba28b", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "OrderedDict([('weight', tensor([[-0.3334, 0.0416, -0.3501, -0.1285, 0.0512, 0.2549, -0.2154, -0.2633]])), ('bias', tensor([-0.0772]))])\n" ] } ], "execution_count": 9 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:29.323152541Z", "start_time": "2026-03-29T09:05:29.255547310Z" } }, "cell_type": "code", "source": "net[2].state_dict()", "id": "b6fee6b64fb96e3c", "outputs": [ { "data": { "text/plain": [ "OrderedDict([('weight',\n", " tensor([[-0.3334, 0.0416, -0.3501, -0.1285, 0.0512, 0.2549, -0.2154, -0.2633]])),\n", " ('bias', tensor([-0.0772]))])" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 10 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:29.379782432Z", "start_time": "2026-03-29T09:05:29.328071536Z" } }, "cell_type": "code", "source": "print(type(net[2].bias))", "id": "b38e8dc384e038c5", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n" ] } ], "execution_count": 11 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:29.451394790Z", "start_time": "2026-03-29T09:05:29.397103231Z" } }, "cell_type": "code", "source": [ "print(net[2].bias)\n", "print(net[2].bias.data)\n" ], "id": "73f12ca3669d9ede", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Parameter containing:\n", "tensor([-0.0772], requires_grad=True)\n", "tensor([-0.0772])\n" ] } ], "execution_count": 12 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:29.518076804Z", "start_time": "2026-03-29T09:05:29.463897541Z" } }, "cell_type": "code", "source": "net[2].weight.grad==None", "id": "db0fe33018c16fac", "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 13 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:29.588672960Z", "start_time": "2026-03-29T09:05:29.534910431Z" } }, "cell_type": "code", "source": [ "print(*[(name, param.shape) for name, param in net[0].named_parameters()])\n", "print(*[(name, param.shape) for name, param in net.named_parameters()])" ], "id": "75847a1c608ee5c7", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "('weight', torch.Size([8, 4])) ('bias', torch.Size([8]))\n", "('0.weight', torch.Size([8, 4])) ('0.bias', torch.Size([8])) ('2.weight', torch.Size([1, 8])) ('2.bias', torch.Size([1]))\n" ] } ], "execution_count": 14 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:29.642447216Z", "start_time": "2026-03-29T09:05:29.589891260Z" } }, "cell_type": "code", "source": "net.state_dict()['2.bias'].data", "id": "cc74913e8742da7d", "outputs": [ { "data": { "text/plain": [ "tensor([-0.0772])" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 15 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:29.708774862Z", "start_time": "2026-03-29T09:05:29.659219518Z" } }, "cell_type": "code", "source": [ "def block1():\n", " return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 4),nn.ReLU())\n", "def block2():\n", " net = nn.Sequential()\n", " for i in range(4):\n", " net.add_module(f'block{i}', block1())\n", " return net" ], "id": "53c39c5e61fa7bf5", "outputs": [], "execution_count": 16 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:29.765983699Z", "start_time": "2026-03-29T09:05:29.710778796Z" } }, "cell_type": "code", "source": [ "rgnet = nn.Sequential(block2(),nn.Linear(4,1))\n", "rgnet(X)" ], "id": "d3ac7759b619aca", "outputs": [ { "data": { "text/plain": [ "tensor([[0.2190],\n", " [0.2190]], grad_fn=)" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 17 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:30.160316585Z", "start_time": "2026-03-29T09:05:29.826193253Z" } }, "cell_type": "code", "source": "print(rgnet)", "id": "8fc60f64b07781e6", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sequential(\n", " (0): Sequential(\n", " (block0): Sequential(\n", " (0): Linear(in_features=4, out_features=8, bias=True)\n", " (1): ReLU()\n", " (2): Linear(in_features=8, out_features=4, bias=True)\n", " (3): ReLU()\n", " )\n", " (block1): Sequential(\n", " (0): Linear(in_features=4, out_features=8, bias=True)\n", " (1): ReLU()\n", " (2): Linear(in_features=8, out_features=4, bias=True)\n", " (3): ReLU()\n", " )\n", " (block2): Sequential(\n", " (0): Linear(in_features=4, out_features=8, bias=True)\n", " (1): ReLU()\n", " (2): Linear(in_features=8, out_features=4, bias=True)\n", " (3): ReLU()\n", " )\n", " (block3): Sequential(\n", " (0): Linear(in_features=4, out_features=8, bias=True)\n", " (1): ReLU()\n", " (2): Linear(in_features=8, out_features=4, bias=True)\n", " (3): ReLU()\n", " )\n", " )\n", " (1): Linear(in_features=4, out_features=1, bias=True)\n", ")\n" ] } ], "execution_count": 18 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:30.301122409Z", "start_time": "2026-03-29T09:05:30.198111011Z" } }, "cell_type": "code", "source": "rgnet[0][1][0].bias.data", "id": "e590aaafca787b50", "outputs": [ { "data": { "text/plain": [ "tensor([-0.3955, 0.0030, -0.0100, 0.3198, -0.4639, -0.4023, -0.3653, 0.0766])" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 19 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:30.342375999Z", "start_time": "2026-03-29T09:05:30.305643495Z" } }, "cell_type": "code", "source": [ "def init_normal(m):\n", " if type(m) == nn.Linear:\n", " nn.init.normal_(m.weight, mean=0, std=0.01)\n", " nn.init.zeros_(m.bias)\n", "net.apply(init_normal)\n", "net[0].weight.data[0], net[0].bias.data[0]" ], "id": "925ca33221d0a87e", "outputs": [ { "data": { "text/plain": [ "(tensor([-0.0059, -0.0004, -0.0091, 0.0014]), tensor(0.))" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 20 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:30.413353915Z", "start_time": "2026-03-29T09:05:30.350001341Z" } }, "cell_type": "code", "source": [ "def init_xavier(m):\n", " if type(m) == nn.Linear:\n", " nn.init.xavier_uniform_(m.weight)\n", "def init_42(m):\n", " if type(m) == nn.Linear:\n", " nn.init.constant_(m.weight, 42)\n", "\n", "net[0].apply(init_xavier)\n", "net[2].apply(init_42)\n", "print(net[0].weight.data[0])\n", "print(net[2].weight.data)" ], "id": "81e2de84a8c4ef32", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([ 0.1297, -0.3070, -0.2955, 0.3630])\n", "tensor([[42., 42., 42., 42., 42., 42., 42., 42.]])\n" ] } ], "execution_count": 21 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:30.481562932Z", "start_time": "2026-03-29T09:05:30.431139492Z" } }, "cell_type": "code", "source": [ "x = torch.arange(4)\n", "torch.save(x, 'x-file')" ], "id": "f05bb378bb60ab9e", "outputs": [], "execution_count": 22 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:30.535399376Z", "start_time": "2026-03-29T09:05:30.483767993Z" } }, "cell_type": "code", "source": [ "x2 = torch.load('x-file')\n", "x2" ], "id": "a74ecaaac0d826c6", "outputs": [ { "data": { "text/plain": [ "tensor([0, 1, 2, 3])" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 23 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:30.599578181Z", "start_time": "2026-03-29T09:05:30.548047255Z" } }, "cell_type": "code", "source": [ "class MLP(nn.Module):\n", " def __init__(self):\n", " super().__init__()\n", " self.hidden = nn.Linear(20, 256)\n", " self.output = nn.Linear(256, 10)\n", " def forward(self, x):\n", " return self.output(F.relu(self.hidden(x)))\n", "\n", "net = MLP()\n", "X = torch.randn(size=(2, 20))\n", "Y = net(X)" ], "id": "b42598f0c4a8e801", "outputs": [], "execution_count": 24 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:30.650496800Z", "start_time": "2026-03-29T09:05:30.601298115Z" } }, "cell_type": "code", "source": "torch.save(net.state_dict(), 'mlp.params')", "id": "aaa22eef549caa6f", "outputs": [], "execution_count": 25 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:30.706861993Z", "start_time": "2026-03-29T09:05:30.652573915Z" } }, "cell_type": "code", "source": [ "clone = MLP()\n", "clone.load_state_dict(torch.load('mlp.params'))\n", "clone.eval()" ], "id": "b92f920229abeeae", "outputs": [ { "data": { "text/plain": [ "MLP(\n", " (hidden): Linear(in_features=20, out_features=256, bias=True)\n", " (output): Linear(in_features=256, out_features=10, bias=True)\n", ")" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 26 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:30.786547453Z", "start_time": "2026-03-29T09:05:30.723425498Z" } }, "cell_type": "code", "source": [ "Y_clone = clone(X)\n", "Y_clone == Y" ], "id": "646c9eb6d7cc81c2", "outputs": [ { "data": { "text/plain": [ "tensor([[True, True, True, True, True, True, True, True, True, True],\n", " [True, True, True, True, True, True, True, True, True, True]])" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 27 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:30.901224528Z", "start_time": "2026-03-29T09:05:30.851425725Z" } }, "cell_type": "code", "source": [ "def corr2d(X,K):\n", " h,w=K.shape\n", " Y=torch.ones((X.shape[0]-h+1,X.shape[1]-w+1))\n", " for i in range(Y.shape[0]):\n", " for j in range(Y.shape[1]):\n", " Y[i,j]=(X[i:i+h,j:j+w]*K).sum()\n", " return Y\n" ], "id": "d45f9adfe47fce20", "outputs": [], "execution_count": 28 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:31.060500958Z", "start_time": "2026-03-29T09:05:30.961112847Z" } }, "cell_type": "code", "source": [ "X = torch.tensor([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]])\n", "K = torch.tensor([[0.0, 1.0], [2.0, 3.0]])\n", "corr2d(X,K)" ], "id": "db7279e13647c315", "outputs": [ { "data": { "text/plain": [ "tensor([[19., 25.],\n", " [37., 43.]])" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 29 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:31.226121987Z", "start_time": "2026-03-29T09:05:31.161589982Z" } }, "cell_type": "code", "source": [ "class Conv2D(nn.Module):\n", " def __init__(self, kernel_size):\n", " super().__init__()\n", " self.weight = nn.Parameter(torch.rand(kernel_size))\n", " self.bias = nn.Parameter(torch.zeros(1))\n", " def forward(self, x):\n", " return corr2d(x, self.weight) + self.bias\n" ], "id": "d60be1bd12a1f37e", "outputs": [], "execution_count": 30 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:31.353964548Z", "start_time": "2026-03-29T09:05:31.284845417Z" } }, "cell_type": "code", "source": [ "X = torch.ones((6, 8))\n", "X[:, 2:6] = 0\n", "X" ], "id": "5083789b7a728442", "outputs": [ { "data": { "text/plain": [ "tensor([[1., 1., 0., 0., 0., 0., 1., 1.],\n", " [1., 1., 0., 0., 0., 0., 1., 1.],\n", " [1., 1., 0., 0., 0., 0., 1., 1.],\n", " [1., 1., 0., 0., 0., 0., 1., 1.],\n", " [1., 1., 0., 0., 0., 0., 1., 1.],\n", " [1., 1., 0., 0., 0., 0., 1., 1.]])" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 31 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:31.497133186Z", "start_time": "2026-03-29T09:05:31.413547301Z" } }, "cell_type": "code", "source": [ "K = torch.tensor([[1.0, -1.0]])\n", "Y = corr2d(X, K)\n", "Y" ], "id": "ee8d6bedbde886ad", "outputs": [ { "data": { "text/plain": [ "tensor([[ 0., 1., 0., 0., 0., -1., 0.],\n", " [ 0., 1., 0., 0., 0., -1., 0.],\n", " [ 0., 1., 0., 0., 0., -1., 0.],\n", " [ 0., 1., 0., 0., 0., -1., 0.],\n", " [ 0., 1., 0., 0., 0., -1., 0.],\n", " [ 0., 1., 0., 0., 0., -1., 0.]])" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 32 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:31.722341828Z", "start_time": "2026-03-29T09:05:31.614326731Z" } }, "cell_type": "code", "source": "corr2d(X.t(), K)", "id": "a8278c3837fa9a1c", "outputs": [ { "data": { "text/plain": [ "tensor([[0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0.]])" ] }, "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 33 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:31.872242267Z", "start_time": "2026-03-29T09:05:31.806910138Z" } }, "cell_type": "code", "source": "conv2d = nn.Conv2d(1,1, kernel_size=(1, 2), bias=False)", "id": "ec61cdb61a8cabff", "outputs": [], "execution_count": 34 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:31.992426160Z", "start_time": "2026-03-29T09:05:31.931105219Z" } }, "cell_type": "code", "source": [ "X = X.reshape((1, 1, 6, 8))\n", "Y = Y.reshape((1, 1, 6, 7))\n", "lr = 3e-2" ], "id": "d2fc19d84c79a10", "outputs": [], "execution_count": 35 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:32.264274507Z", "start_time": "2026-03-29T09:05:31.995396351Z" } }, "cell_type": "code", "source": [ "for i in range(100):\n", " Y_hat = conv2d(X)\n", " l = (Y_hat - Y) ** 2\n", " conv2d.zero_grad()\n", " l.sum().backward()\n", " # 迭代卷积核\n", " conv2d.weight.data[:] -= lr * conv2d.weight.grad\n", " if (i + 1) % 20 == 0:\n", " print(f'epoch {i+1}, loss {l.sum():.3f}')" ], "id": "51fbb2e6398a9bd5", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch 20, loss 0.000\n", "epoch 40, loss 0.000\n", "epoch 60, loss 0.000\n", "epoch 80, loss 0.000\n", "epoch 100, loss 0.000\n" ] } ], "execution_count": 36 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:32.346542597Z", "start_time": "2026-03-29T09:05:32.281931301Z" } }, "cell_type": "code", "source": "conv2d.weight.data.reshape((1, 2))\n", "id": "bf53a423f429dfe4", "outputs": [ { "data": { "text/plain": [ "tensor([[ 1.0000, -1.0000]])" ] }, "execution_count": 37, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 37 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:32.474945924Z", "start_time": "2026-03-29T09:05:32.410988598Z" } }, "cell_type": "code", "source": [ "\n", "# 为了方便起见,我们定义了一个计算卷积层的函数。\n", "# 此函数初始化卷积层权重,并对输入和输出提高和缩减相应的维数\n", "def comp_conv2d(conv2d, X):\n", "# 这里的(1,1)表示批量大小和通道数都是1\n", " X = X.reshape((1, 1) + X.shape)\n", " Y = conv2d(X)\n", " # 省略前两个维度:批量大小和通道\n", " return Y.reshape(Y.shape[2:])\n", "# 请注意,这里每边都填充了1行或1列,因此总共添加了2行或2列\n", "conv2d = nn.Conv2d(1, 1, kernel_size=3, padding=1)" ], "id": "77b61d8c9a2363cc", "outputs": [], "execution_count": 38 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:32.610536979Z", "start_time": "2026-03-29T09:05:32.537011140Z" } }, "cell_type": "code", "source": [ "X = torch.rand(size=(8, 8))\n", "comp_conv2d(conv2d, X).shape" ], "id": "beda6ffa67ec2677", "outputs": [ { "data": { "text/plain": [ "torch.Size([8, 8])" ] }, "execution_count": 39, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 39 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:32.680637151Z", "start_time": "2026-03-29T09:05:32.618558400Z" } }, "cell_type": "code", "source": [ "conv2d = nn.Conv2d(1, 1, kernel_size=(5, 3), padding=(2, 1))\n", "comp_conv2d(conv2d, X).shape" ], "id": "8c51095daea1432d", "outputs": [ { "data": { "text/plain": [ "torch.Size([8, 8])" ] }, "execution_count": 40, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 40 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:32.771287457Z", "start_time": "2026-03-29T09:05:32.697379946Z" } }, "cell_type": "code", "source": [ "conv2d = nn.Conv2d(1, 1, kernel_size=3, padding=1, stride=2)\n", "comp_conv2d(conv2d, X).shape" ], "id": "581bf1b15162cbf6", "outputs": [ { "data": { "text/plain": [ "torch.Size([4, 4])" ] }, "execution_count": 41, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 41 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:32.831539044Z", "start_time": "2026-03-29T09:05:32.773186750Z" } }, "cell_type": "code", "source": [ "conv2d = nn.Conv2d(1, 1, kernel_size=(3, 5), padding=(0, 1), stride=(3, 4))\n", "comp_conv2d(conv2d, X).shape" ], "id": "6f7a2411247baff0", "outputs": [ { "data": { "text/plain": [ "torch.Size([2, 2])" ] }, "execution_count": 42, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 42 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:32.896336773Z", "start_time": "2026-03-29T09:05:32.842010971Z" } }, "cell_type": "code", "source": [ "def corr2d_multi_in(X,K):\n", " return sum(corr2d(x,k) for x,k in zip(X,K))\n", "X = torch.tensor([[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]],\n", "[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]])\n", "K = torch.tensor([[[0.0, 1.0], [2.0, 3.0]], [[1.0, 2.0], [3.0, 4.0]]])\n", "corr2d_multi_in(X, K)" ], "id": "7ac0f17f97b2daa8", "outputs": [ { "data": { "text/plain": [ "tensor([[ 56., 72.],\n", " [104., 120.]])" ] }, "execution_count": 43, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 43 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:33.000406757Z", "start_time": "2026-03-29T09:05:32.949274100Z" } }, "cell_type": "code", "source": [ "def corr2d_multi_in_out(X,K) ->torch.Tensor :\n", " return torch.stack([corr2d_multi_in(X,k) for k in K],0)\n" ], "id": "d409110d0d6b4b49", "outputs": [], "execution_count": 44 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:33.125525859Z", "start_time": "2026-03-29T09:05:33.060398301Z" } }, "cell_type": "code", "source": [ "K = torch.stack((K, K + 1, K + 2), 0)\n", "K.shape" ], "id": "4114cd871a627075", "outputs": [ { "data": { "text/plain": [ "torch.Size([3, 2, 2, 2])" ] }, "execution_count": 45, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 45 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:33.198169422Z", "start_time": "2026-03-29T09:05:33.131159996Z" } }, "cell_type": "code", "source": "corr2d_multi_in_out(X, K)", "id": "ce52f41dc9585f8c", "outputs": [ { "data": { "text/plain": [ "tensor([[[ 56., 72.],\n", " [104., 120.]],\n", "\n", " [[ 76., 100.],\n", " [148., 172.]],\n", "\n", " [[ 96., 128.],\n", " [192., 224.]]])" ] }, "execution_count": 46, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 46 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:33.249073Z", "start_time": "2026-03-29T09:05:33.199594667Z" } }, "cell_type": "code", "source": [ "def corr2d_multi_in_out_1x1(X, K):\n", " h_i,h,w=X.shape\n", " h_o=K.shape[0]\n", " X=X.reshape((h_i,h*w))\n", " print(X.shape)\n", " K=K.reshape((h_o,h_i))\n", " print(K.shape)\n", " Y=torch.matmul(K,X)\n", " return Y.reshape((h_o,h,w))" ], "id": "362d8c692b3c1d75", "outputs": [], "execution_count": 47 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:33.298538943Z", "start_time": "2026-03-29T09:05:33.250967377Z" } }, "cell_type": "code", "source": [ "X = torch.normal(0, 1, (3, 3, 3))\n", "K = torch.normal(0, 1, (2, 3, 1, 1))" ], "id": "28e761f677df8b16", "outputs": [], "execution_count": 48 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:33.354138870Z", "start_time": "2026-03-29T09:05:33.300557739Z" } }, "cell_type": "code", "source": "Y1 = corr2d_multi_in_out_1x1(X, K)", "id": "8eb276fed751a6b9", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([3, 9])\n", "torch.Size([2, 3])\n" ] } ], "execution_count": 49 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:33.417520891Z", "start_time": "2026-03-29T09:05:33.369461725Z" } }, "cell_type": "code", "source": [ "Y2 = corr2d_multi_in_out(X, K)\n", "assert float(torch.abs(Y1 - Y2).sum()) < 1e-6" ], "id": "be28e27d30f36e2c", "outputs": [], "execution_count": 50 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:33.468327774Z", "start_time": "2026-03-29T09:05:33.419539671Z" } }, "cell_type": "code", "source": [ "def pool2d(X,pool_size,mode='max'):\n", " p_h,p_w =pool_size\n", " Y = torch.zeros((X.shape[0]-p_h+1,X.shape[1]-p_w+1))\n", " for i in range(Y.shape[0]):\n", " for j in range(Y.shape[1]):\n", " match mode:\n", " case 'max':\n", " Y[i,j]=X[i:i+p_h,j:j+p_w].max()\n", " case 'avg':\n", " Y[i,j]=X[i:i+p_h,j:j+p_w].mean()\n", "\n", " return Y" ], "id": "3c3f71349a2e54c0", "outputs": [], "execution_count": 51 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:33.526566195Z", "start_time": "2026-03-29T09:05:33.470337819Z" } }, "cell_type": "code", "source": [ "X = torch.tensor([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]])\n", "pool2d(X, (2, 2))" ], "id": "a67207c861cf0cfd", "outputs": [ { "data": { "text/plain": [ "tensor([[4., 5.],\n", " [7., 8.]])" ] }, "execution_count": 52, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 52 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:33.680166201Z", "start_time": "2026-03-29T09:05:33.578598740Z" } }, "cell_type": "code", "source": "pool2d(X, (2, 2), 'avg')", "id": "e387b48df3831b85", "outputs": [ { "data": { "text/plain": [ "tensor([[2., 3.],\n", " [5., 6.]])" ] }, "execution_count": 53, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 53 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:33.855460170Z", "start_time": "2026-03-29T09:05:33.765585069Z" } }, "cell_type": "code", "source": [ "X = torch.arange(16, dtype=torch.float32).reshape((1, 1, 4, 4))\n", "X" ], "id": "41b618b3a48522b4", "outputs": [ { "data": { "text/plain": [ "tensor([[[[ 0., 1., 2., 3.],\n", " [ 4., 5., 6., 7.],\n", " [ 8., 9., 10., 11.],\n", " [12., 13., 14., 15.]]]])" ] }, "execution_count": 54, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 54 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:34.023104664Z", "start_time": "2026-03-29T09:05:33.949683264Z" } }, "cell_type": "code", "source": [ "pool2d=nn.MaxPool2d(3)\n", "pool2d(X)" ], "id": "c77484a8d1267259", "outputs": [ { "data": { "text/plain": [ "tensor([[[[10.]]]])" ] }, "execution_count": 55, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 55 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:34.093370019Z", "start_time": "2026-03-29T09:05:34.026511981Z" } }, "cell_type": "code", "source": [ "pool2d = nn.MaxPool2d(3, padding=1, stride=2)\n", "pool2d(X)" ], "id": "847a2bacfb6f2bd7", "outputs": [ { "data": { "text/plain": [ "tensor([[[[ 5., 7.],\n", " [13., 15.]]]])" ] }, "execution_count": 56, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 56 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:34.146680689Z", "start_time": "2026-03-29T09:05:34.094779518Z" } }, "cell_type": "code", "source": [ "pool2d = nn.MaxPool2d((2, 3), stride=(2, 3), padding=(0, 1))\n", "pool2d(X)" ], "id": "5efad1e0b616fff7", "outputs": [ { "data": { "text/plain": [ "tensor([[[[ 5., 7.],\n", " [13., 15.]]]])" ] }, "execution_count": 57, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 57 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:34.239200323Z", "start_time": "2026-03-29T09:05:34.159117900Z" } }, "cell_type": "code", "source": [ "X = torch.cat((X, X + 1), 1)\n", "X" ], "id": "386d4b3eb8069328", "outputs": [ { "data": { "text/plain": [ "tensor([[[[ 0., 1., 2., 3.],\n", " [ 4., 5., 6., 7.],\n", " [ 8., 9., 10., 11.],\n", " [12., 13., 14., 15.]],\n", "\n", " [[ 1., 2., 3., 4.],\n", " [ 5., 6., 7., 8.],\n", " [ 9., 10., 11., 12.],\n", " [13., 14., 15., 16.]]]])" ] }, "execution_count": 58, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 58 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:34.331040051Z", "start_time": "2026-03-29T09:05:34.248225619Z" } }, "cell_type": "code", "source": [ "pool2d = nn.MaxPool2d(3, padding=1, stride=2)\n", "pool2d(X)" ], "id": "ba5f57a8ca2a3b06", "outputs": [ { "data": { "text/plain": [ "tensor([[[[ 5., 7.],\n", " [13., 15.]],\n", "\n", " [[ 6., 8.],\n", " [14., 16.]]]])" ] }, "execution_count": 59, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 59 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:34.425299189Z", "start_time": "2026-03-29T09:05:34.352449123Z" } }, "cell_type": "code", "source": [ "net = nn.Sequential(\n", " nn.Conv2d(1,6,kernel_size=5,padding=2), #1*1*28*28 -> 1*6*28*28\n", " nn.Sigmoid(),\n", " nn.AvgPool2d(kernel_size=2, stride=2), #1*6*28*28 -> 1*6*14*14\n", " nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(), #1*6*14*14 -> 1*16*10*10\n", " nn.AvgPool2d(kernel_size=2, stride=2), #1*16*10*10 -> 1*16*5*5\n", " nn.Flatten(),\n", " nn.Linear(16 * 5 * 5, 120), nn.Sigmoid(),\n", " nn.Linear(120, 84), nn.Sigmoid(),\n", " nn.Linear(84, 10)\n", ")\n", "X = torch.rand(size=(1,1,28,28),dtype=torch.float32)\n", "for layer in net:\n", " X=layer(X)\n", " print(layer.__class__.__name__,'output shape: \\t',X.shape)" ], "id": "1eabc29f9c838842", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Conv2d output shape: \t torch.Size([1, 6, 28, 28])\n", "Sigmoid output shape: \t torch.Size([1, 6, 28, 28])\n", "AvgPool2d output shape: \t torch.Size([1, 6, 14, 14])\n", "Conv2d output shape: \t torch.Size([1, 16, 10, 10])\n", "Sigmoid output shape: \t torch.Size([1, 16, 10, 10])\n", "AvgPool2d output shape: \t torch.Size([1, 16, 5, 5])\n", "Flatten output shape: \t torch.Size([1, 400])\n", "Linear output shape: \t torch.Size([1, 120])\n", "Sigmoid output shape: \t torch.Size([1, 120])\n", "Linear output shape: \t torch.Size([1, 84])\n", "Sigmoid output shape: \t torch.Size([1, 84])\n", "Linear output shape: \t torch.Size([1, 10])\n" ] } ], "execution_count": 60 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:35.484128366Z", "start_time": "2026-03-29T09:05:34.438360644Z" } }, "cell_type": "code", "source": [ "import d2l.torch as d2l\n", "batch_size = 256\n", "train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)" ], "id": "e372f75817ad4a0f", "outputs": [], "execution_count": 61 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:35.582060752Z", "start_time": "2026-03-29T09:05:35.533014324Z" } }, "cell_type": "code", "source": [ "lr, num_epochs = 0.9, 10\n", "#d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())" ], "id": "9aaeb948f3353955", "outputs": [], "execution_count": 62 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:35.632954849Z", "start_time": "2026-03-29T09:05:35.583695074Z" } }, "cell_type": "code", "source": [ "class Inception(nn.Module):\n", " def __init__(self,in_channels,c1,c2,c3,c4,**kwargs):\n", " super(Inception,self).__init__(**kwargs)\n", " self.p1_1 = nn.Conv2d(in_channels,c1,kernel_size=1)\n", " self.p2_1 = nn.Conv2d(in_channels,c2[0],kernel_size=1)\n", " self.p2_2 = nn.Conv2d(c2[0],c2[1],kernel_size=3,padding=1)\n", " self.p3_1 = nn.Conv2d(in_channels,c3[0],kernel_size=1)\n", " self.p3_2 = nn.Conv2d(c3[0],c3[1],kernel_size=5,padding=2)\n", " self.p4_1 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)\n", " self.p4_2 = nn.Conv2d(in_channels, c4, kernel_size=1)\n", " def forward(self,x):\n", " p1 = F.relu(self.p1_1(x))\n", " p2 = F.relu(self.p2_2(F.relu(self.p2_1(x))))\n", " p3 = F.relu(self.p3_2(F.relu(self.p3_1(x))))\n", " p4 = F.relu(self.p4_2(self.p4_1(x)))\n", " return torch.cat((p1,p2,p3,p4),dim=1)" ], "id": "6d3bb3f70f297dba", "outputs": [], "execution_count": 63 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:35.690181474Z", "start_time": "2026-03-29T09:05:35.634277862Z" } }, "cell_type": "code", "source": [ "b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),\n", " nn.ReLU(),\n", " nn.MaxPool2d(kernel_size=3, stride=2, padding=1))\n", "b2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1),\n", " nn.ReLU(),\n", " nn.Conv2d(64, 192, kernel_size=3, padding=1),\n", " nn.ReLU(),\n", " nn.MaxPool2d(kernel_size=3, stride=2, padding=1))\n", "b3 = nn.Sequential(Inception(192, 64, (96, 128), (16, 32), 32),\n", " Inception(256, 128, (128, 192), (32, 96), 64),\n", " nn.MaxPool2d(kernel_size=3, stride=2, padding=1))\n", "b4 = nn.Sequential(Inception(480, 192, (96, 208), (16, 48), 64),\n", " Inception(512, 160, (112, 224), (24, 64), 64),\n", " Inception(512, 128, (128, 256), (24, 64), 64),\n", " Inception(512, 112, (144, 288), (32, 64), 64),\n", " Inception(528, 256, (160, 320), (32, 128), 128),\n", " nn.MaxPool2d(kernel_size=3, stride=2, padding=1))\n", "b5 = nn.Sequential(Inception(832, 256, (160, 320), (32, 128), 128),\n", " Inception(832, 384, (192, 384), (48, 128), 128),\n", " nn.AdaptiveAvgPool2d((1,1)),\n", " nn.Flatten())\n", "net = nn.Sequential(b1, b2, b3, b4, b5, nn.Linear(1024, 10))\n", "X = torch.rand(size=(1, 1, 96, 96))\n", "for layer in net:\n", " X = layer(X)\n", " print(layer.__class__.__name__,'output shape:\\t', X.shape)" ], "id": "6ef7022bcb288d65", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sequential output shape:\t torch.Size([1, 64, 24, 24])\n", "Sequential output shape:\t torch.Size([1, 192, 12, 12])\n", "Sequential output shape:\t torch.Size([1, 480, 6, 6])\n", "Sequential output shape:\t torch.Size([1, 832, 3, 3])\n", "Sequential output shape:\t torch.Size([1, 1024])\n", "Linear output shape:\t torch.Size([1, 10])\n" ] } ], "execution_count": 64 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:36.618596478Z", "start_time": "2026-03-29T09:05:35.704418296Z" } }, "cell_type": "code", "source": [ "import torchinfo\n", "torchinfo.summary(net,(1,1,96,96))" ], "id": "acc019ce7afa4470", "outputs": [ { "data": { "text/plain": [ "==========================================================================================\n", "Layer (type:depth-idx) Output Shape Param #\n", "==========================================================================================\n", "Sequential [1, 10] --\n", "├─Sequential: 1-1 [1, 64, 24, 24] --\n", "│ └─Conv2d: 2-1 [1, 64, 48, 48] 3,200\n", "│ └─ReLU: 2-2 [1, 64, 48, 48] --\n", "│ └─MaxPool2d: 2-3 [1, 64, 24, 24] --\n", "├─Sequential: 1-2 [1, 192, 12, 12] --\n", "│ └─Conv2d: 2-4 [1, 64, 24, 24] 4,160\n", "│ └─ReLU: 2-5 [1, 64, 24, 24] --\n", "│ └─Conv2d: 2-6 [1, 192, 24, 24] 110,784\n", "│ └─ReLU: 2-7 [1, 192, 24, 24] --\n", "│ └─MaxPool2d: 2-8 [1, 192, 12, 12] --\n", "├─Sequential: 1-3 [1, 480, 6, 6] --\n", "│ └─Inception: 2-9 [1, 256, 12, 12] --\n", "│ │ └─Conv2d: 3-1 [1, 64, 12, 12] 12,352\n", "│ │ └─Conv2d: 3-2 [1, 96, 12, 12] 18,528\n", "│ │ └─Conv2d: 3-3 [1, 128, 12, 12] 110,720\n", "│ │ └─Conv2d: 3-4 [1, 16, 12, 12] 3,088\n", "│ │ └─Conv2d: 3-5 [1, 32, 12, 12] 12,832\n", "│ │ └─MaxPool2d: 3-6 [1, 192, 12, 12] --\n", "│ │ └─Conv2d: 3-7 [1, 32, 12, 12] 6,176\n", "│ └─Inception: 2-10 [1, 480, 12, 12] --\n", "│ │ └─Conv2d: 3-8 [1, 128, 12, 12] 32,896\n", "│ │ └─Conv2d: 3-9 [1, 128, 12, 12] 32,896\n", "│ │ └─Conv2d: 3-10 [1, 192, 12, 12] 221,376\n", "│ │ └─Conv2d: 3-11 [1, 32, 12, 12] 8,224\n", "│ │ └─Conv2d: 3-12 [1, 96, 12, 12] 76,896\n", "│ │ └─MaxPool2d: 3-13 [1, 256, 12, 12] --\n", "│ │ └─Conv2d: 3-14 [1, 64, 12, 12] 16,448\n", "│ └─MaxPool2d: 2-11 [1, 480, 6, 6] --\n", "├─Sequential: 1-4 [1, 832, 3, 3] --\n", "│ └─Inception: 2-12 [1, 512, 6, 6] --\n", "│ │ └─Conv2d: 3-15 [1, 192, 6, 6] 92,352\n", "│ │ └─Conv2d: 3-16 [1, 96, 6, 6] 46,176\n", "│ │ └─Conv2d: 3-17 [1, 208, 6, 6] 179,920\n", "│ │ └─Conv2d: 3-18 [1, 16, 6, 6] 7,696\n", "│ │ └─Conv2d: 3-19 [1, 48, 6, 6] 19,248\n", "│ │ └─MaxPool2d: 3-20 [1, 480, 6, 6] --\n", "│ │ └─Conv2d: 3-21 [1, 64, 6, 6] 30,784\n", "│ └─Inception: 2-13 [1, 512, 6, 6] --\n", "│ │ └─Conv2d: 3-22 [1, 160, 6, 6] 82,080\n", "│ │ └─Conv2d: 3-23 [1, 112, 6, 6] 57,456\n", "│ │ └─Conv2d: 3-24 [1, 224, 6, 6] 226,016\n", "│ │ └─Conv2d: 3-25 [1, 24, 6, 6] 12,312\n", "│ │ └─Conv2d: 3-26 [1, 64, 6, 6] 38,464\n", "│ │ └─MaxPool2d: 3-27 [1, 512, 6, 6] --\n", "│ │ └─Conv2d: 3-28 [1, 64, 6, 6] 32,832\n", "│ └─Inception: 2-14 [1, 512, 6, 6] --\n", "│ │ └─Conv2d: 3-29 [1, 128, 6, 6] 65,664\n", "│ │ └─Conv2d: 3-30 [1, 128, 6, 6] 65,664\n", "│ │ └─Conv2d: 3-31 [1, 256, 6, 6] 295,168\n", "│ │ └─Conv2d: 3-32 [1, 24, 6, 6] 12,312\n", "│ │ └─Conv2d: 3-33 [1, 64, 6, 6] 38,464\n", "│ │ └─MaxPool2d: 3-34 [1, 512, 6, 6] --\n", "│ │ └─Conv2d: 3-35 [1, 64, 6, 6] 32,832\n", "│ └─Inception: 2-15 [1, 528, 6, 6] --\n", "│ │ └─Conv2d: 3-36 [1, 112, 6, 6] 57,456\n", "│ │ └─Conv2d: 3-37 [1, 144, 6, 6] 73,872\n", "│ │ └─Conv2d: 3-38 [1, 288, 6, 6] 373,536\n", "│ │ └─Conv2d: 3-39 [1, 32, 6, 6] 16,416\n", "│ │ └─Conv2d: 3-40 [1, 64, 6, 6] 51,264\n", "│ │ └─MaxPool2d: 3-41 [1, 512, 6, 6] --\n", "│ │ └─Conv2d: 3-42 [1, 64, 6, 6] 32,832\n", "│ └─Inception: 2-16 [1, 832, 6, 6] --\n", "│ │ └─Conv2d: 3-43 [1, 256, 6, 6] 135,424\n", "│ │ └─Conv2d: 3-44 [1, 160, 6, 6] 84,640\n", "│ │ └─Conv2d: 3-45 [1, 320, 6, 6] 461,120\n", "│ │ └─Conv2d: 3-46 [1, 32, 6, 6] 16,928\n", "│ │ └─Conv2d: 3-47 [1, 128, 6, 6] 102,528\n", "│ │ └─MaxPool2d: 3-48 [1, 528, 6, 6] --\n", "│ │ └─Conv2d: 3-49 [1, 128, 6, 6] 67,712\n", "│ └─MaxPool2d: 2-17 [1, 832, 3, 3] --\n", "├─Sequential: 1-5 [1, 1024] --\n", "│ └─Inception: 2-18 [1, 832, 3, 3] --\n", "│ │ └─Conv2d: 3-50 [1, 256, 3, 3] 213,248\n", "│ │ └─Conv2d: 3-51 [1, 160, 3, 3] 133,280\n", "│ │ └─Conv2d: 3-52 [1, 320, 3, 3] 461,120\n", "│ │ └─Conv2d: 3-53 [1, 32, 3, 3] 26,656\n", "│ │ └─Conv2d: 3-54 [1, 128, 3, 3] 102,528\n", "│ │ └─MaxPool2d: 3-55 [1, 832, 3, 3] --\n", "│ │ └─Conv2d: 3-56 [1, 128, 3, 3] 106,624\n", "│ └─Inception: 2-19 [1, 1024, 3, 3] --\n", "│ │ └─Conv2d: 3-57 [1, 384, 3, 3] 319,872\n", "│ │ └─Conv2d: 3-58 [1, 192, 3, 3] 159,936\n", "│ │ └─Conv2d: 3-59 [1, 384, 3, 3] 663,936\n", "│ │ └─Conv2d: 3-60 [1, 48, 3, 3] 39,984\n", "│ │ └─Conv2d: 3-61 [1, 128, 3, 3] 153,728\n", "│ │ └─MaxPool2d: 3-62 [1, 832, 3, 3] --\n", "│ │ └─Conv2d: 3-63 [1, 128, 3, 3] 106,624\n", "│ └─AdaptiveAvgPool2d: 2-20 [1, 1024, 1, 1] --\n", "│ └─Flatten: 2-21 [1, 1024] --\n", "├─Linear: 1-6 [1, 10] 10,250\n", "==========================================================================================\n", "Total params: 5,977,530\n", "Trainable params: 5,977,530\n", "Non-trainable params: 0\n", "Total mult-adds (Units.MEGABYTES): 276.66\n", "==========================================================================================\n", "Input size (MB): 0.04\n", "Forward/backward pass size (MB): 4.74\n", "Params size (MB): 23.91\n", "Estimated Total Size (MB): 28.69\n", "==========================================================================================" ] }, "execution_count": 65, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 65 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:36.729050116Z", "start_time": "2026-03-29T09:05:36.672212588Z" } }, "cell_type": "code", "source": [ "lr, num_epochs, batch_size = 0.1, 10, 128\n", "train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)\n", "#d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())" ], "id": "3760a5e5813405f7", "outputs": [], "execution_count": 66 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:36.793463633Z", "start_time": "2026-03-29T09:05:36.742329034Z" } }, "cell_type": "code", "source": [ "class Residual(nn.Module):\n", " def __init__(self,input_channels,num_channels,use_1x1conv=False,strides=1):\n", " super().__init__()\n", " self.conv1 = nn.Conv2d(input_channels,num_channels,kernel_size=3,padding=1,stride=strides)\n", " self.conv2 = nn.Conv2d(num_channels,num_channels,kernel_size=3,padding=1)\n", " if use_1x1conv:\n", " self.conv3 = nn.Conv2d(input_channels,num_channels,kernel_size=1,stride=strides)\n", " else:\n", " self.conv3= None\n", " self.bn1=nn.BatchNorm2d(num_channels)\n", " self.bn2=nn.BatchNorm2d(num_channels)\n", " def forward(self,X):\n", " Y=F.relu(self.bn1(self.conv1(X)))\n", " Y=self.bn2(self.conv2(Y))\n", " if self.conv3:\n", " X = self.conv3(X)\n", " Y+=X\n", " return F.relu(Y)\n" ], "id": "9300979845ba6916", "outputs": [], "execution_count": 67 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:36.854913593Z", "start_time": "2026-03-29T09:05:36.806531736Z" } }, "cell_type": "code", "source": [ "blk = Residual(3,3)\n", "X = torch.rand(4, 3, 6, 6)" ], "id": "1248323517ff3228", "outputs": [], "execution_count": 68 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:36.923726941Z", "start_time": "2026-03-29T09:05:36.856121115Z" } }, "cell_type": "code", "source": [ "blk = Residual(3,6, use_1x1conv=True, strides=2)\n", "blk(X).shape" ], "id": "82cdbd71a157b51c", "outputs": [ { "data": { "text/plain": [ "torch.Size([4, 6, 3, 3])" ] }, "execution_count": 69, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 69 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:36.973017126Z", "start_time": "2026-03-29T09:05:36.924981665Z" } }, "cell_type": "code", "source": [ "b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),\n", "nn.BatchNorm2d(64), nn.ReLU(),\n", "nn.MaxPool2d(kernel_size=3, stride=2, padding=1))" ], "id": "727da1d2d363ac62", "outputs": [], "execution_count": 70 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:37.023199935Z", "start_time": "2026-03-29T09:05:36.974467726Z" } }, "cell_type": "code", "source": [ "def resnet_block(input_channels, num_channels, num_residuals,\n", " first_block=False):\n", " blk = []\n", " for i in range(num_residuals):\n", " if i == 0 and not first_block:\n", " blk.append(Residual(input_channels, num_channels,\n", " use_1x1conv=True, strides=2))\n", " else:\n", " blk.append(Residual(num_channels, num_channels))\n", " return blk" ], "id": "124134971f8441c0", "outputs": [], "execution_count": 71 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:37.072041927Z", "start_time": "2026-03-29T09:05:37.024683317Z" } }, "cell_type": "code", "source": [ "b2 = nn.Sequential(*resnet_block(64, 64, 2, first_block=True))\n", "b3 = nn.Sequential(*resnet_block(64, 128, 2))\n", "b4 = nn.Sequential(*resnet_block(128, 256, 2))\n", "b5 = nn.Sequential(*resnet_block(256, 512, 2))" ], "id": "ca1f1c69fba3e913", "outputs": [], "execution_count": 72 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:37.121767680Z", "start_time": "2026-03-29T09:05:37.073541025Z" } }, "cell_type": "code", "source": [ "net = nn.Sequential(b1, b2, b3, b4, b5,\n", "nn.AdaptiveAvgPool2d((1,1)),\n", "nn.Flatten(), nn.Linear(512, 10))" ], "id": "f21db27de5dbdec1", "outputs": [], "execution_count": 73 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:37.202821371Z", "start_time": "2026-03-29T09:05:37.123064637Z" } }, "cell_type": "code", "source": [ "X = torch.rand(size=(1, 1, 224, 224))\n", "for layer in net:\n", " X = layer(X)\n", " print(layer.__class__.__name__,'output shape:\\t', X.shape)" ], "id": "6f8851a2bfd18c4e", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sequential output shape:\t torch.Size([1, 64, 56, 56])\n", "Sequential output shape:\t torch.Size([1, 64, 56, 56])\n", "Sequential output shape:\t torch.Size([1, 128, 28, 28])\n", "Sequential output shape:\t torch.Size([1, 256, 14, 14])\n", "Sequential output shape:\t torch.Size([1, 512, 7, 7])\n", "AdaptiveAvgPool2d output shape:\t torch.Size([1, 512, 1, 1])\n", "Flatten output shape:\t torch.Size([1, 512])\n", "Linear output shape:\t torch.Size([1, 10])\n" ] } ], "execution_count": 74 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:37.251179948Z", "start_time": "2026-03-29T09:05:37.204125947Z" } }, "cell_type": "code", "source": [ "lr, num_epochs, batch_size = 0.05, 10, 256\n", "train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)\n", "#d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())" ], "id": "e095d74b29dffef6", "outputs": [], "execution_count": 75 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:37.324948458Z", "start_time": "2026-03-29T09:05:37.256990577Z" } }, "cell_type": "code", "source": [ "import torch\n", "import d2l.torch as d2l\n", "import numpy\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "print(torch.version.__version__)" ], "id": "3fd6d22221f87bea", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2.10.0+cu128\n" ] } ], "execution_count": 76 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:37.375981476Z", "start_time": "2026-03-29T09:05:37.328240979Z" } }, "cell_type": "code", "source": [ "A=torch.Tensor([[1,2,0,0],[0,2,0,0],[0,0,2,1],[0,0,0,3]])\n", "C=torch.Tensor([[1,0,0,0],[0,1,0,0],[0,0,-2,3],[0,0,0,-3]])" ], "id": "254f5d3d659dbe0f", "outputs": [], "execution_count": 77 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:37.425553458Z", "start_time": "2026-03-29T09:05:37.377516460Z" } }, "cell_type": "code", "source": "B=torch.Tensor([[2,0,0,0],[-2,1,0,0],[0,0,-3,0],[0,0,0,-3]])", "id": "a13d9c27c2fdbfad", "outputs": [], "execution_count": 78 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:37.480263855Z", "start_time": "2026-03-29T09:05:37.427439519Z" } }, "cell_type": "code", "source": "torch.mm(A,C)", "id": "e513a37beaa85f8f", "outputs": [ { "data": { "text/plain": [ "tensor([[ 1., 2., 0., 0.],\n", " [ 0., 2., 0., 0.],\n", " [ 0., 0., -4., 3.],\n", " [ 0., 0., 0., -9.]])" ] }, "execution_count": 79, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 79 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:37.626602811Z", "start_time": "2026-03-29T09:05:37.543653646Z" } }, "cell_type": "code", "source": "torch.det(torch.mm(torch.mm(A,C),B))", "id": "9a85eceac652875f", "outputs": [ { "data": { "text/plain": [ "tensor(1296.)" ] }, "execution_count": 80, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 80 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:37.767593515Z", "start_time": "2026-03-29T09:05:37.692991032Z" } }, "cell_type": "code", "source": "1296**5\n", "id": "6dc27d79722da58f", "outputs": [ { "data": { "text/plain": [ "3656158440062976" ] }, "execution_count": 81, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 81 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:37.817415051Z", "start_time": "2026-03-29T09:05:37.779730510Z" } }, "cell_type": "code", "source": "torch.mm(C,B)", "id": "ec5a170d775f4705", "outputs": [ { "data": { "text/plain": [ "tensor([[ 2., 0., 0., 0.],\n", " [-2., 1., 0., 0.],\n", " [ 0., 0., 6., -9.],\n", " [ 0., 0., 0., 9.]])" ] }, "execution_count": 82, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 82 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:38.035076815Z", "start_time": "2026-03-29T09:05:37.870776022Z" } }, "cell_type": "code", "source": [ "T = 1000 # 总共产生1000个点\n", "time = torch.arange(1, T + 1, dtype=torch.float32)\n", "x = torch.sin(0.01 * time) + torch.normal(0, 0.2, (T,))\n", "d2l.plot(time, [x], 'time', 'x', xlim=[1, 1000], figsize=(6, 3))" ], "id": "c3884b10464c6baa", "outputs": [ { "data": { "text/plain": [ "
" ], "image/svg+xml": "\n\n\n \n \n \n \n 2026-03-29T17:05:37.982813\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" }, "metadata": {}, "output_type": "display_data", "jetTransient": { "display_id": null } } ], "execution_count": 83 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:38.254208009Z", "start_time": "2026-03-29T09:05:38.094754063Z" } }, "cell_type": "code", "source": [ "tau = 4\n", "features = torch.zeros((T - tau, tau))\n", "for i in range(tau):\n", " features[:, i] = x[i: T - tau + i]\n", "labels = x[tau:].reshape((-1, 1))\n", "x,features,labels\n" ], "id": "5d450c6f6b724a14", "outputs": [ { "data": { "text/plain": [ "(tensor([ 6.4438e-02, -2.8849e-02, 2.2463e-01, -1.7219e-01, 6.6353e-02,\n", " -1.0363e-01, 8.8868e-02, 1.5603e-01, 2.0316e-01, 2.2209e-01,\n", " 4.0233e-01, 1.8191e-01, 9.4271e-02, 1.9833e-01, 2.7132e-01,\n", " -2.6344e-02, 1.3314e-01, -9.5498e-02, 4.2949e-01, 2.9735e-01,\n", " 2.6208e-01, 2.5798e-01, 2.6224e-01, 4.0028e-01, 1.6453e-01,\n", " -3.6497e-03, 4.5941e-02, 2.9152e-01, 2.8247e-01, 3.1230e-01,\n", " 4.0011e-01, 1.4096e-01, 4.1744e-01, 3.4225e-01, -4.5256e-02,\n", " 2.8870e-01, 3.8852e-01, 3.4837e-01, 6.1889e-01, 6.2549e-01,\n", " 2.3834e-01, 4.8642e-01, 3.5614e-01, 1.1784e-01, 3.8346e-01,\n", " 4.5669e-01, 3.6588e-01, 2.6488e-01, 6.0995e-01, 6.9697e-01,\n", " 7.5780e-01, 5.8101e-01, 3.5400e-01, 2.4635e-01, 4.7288e-01,\n", " 6.6484e-01, 6.3196e-01, 5.6758e-01, 3.1575e-01, 7.3676e-01,\n", " 8.1908e-01, 8.8408e-01, 7.6086e-01, 4.3549e-01, 7.9157e-01,\n", " 4.1029e-01, 3.4122e-01, 1.0624e+00, 9.8399e-01, 7.3473e-01,\n", " 6.9833e-01, 3.4119e-01, 7.3251e-01, 6.7880e-01, 6.3626e-01,\n", " 1.0105e+00, 7.0007e-01, 1.1702e+00, 6.0600e-01, 8.9456e-01,\n", " 5.1218e-01, 7.3733e-01, 6.1851e-01, 7.6468e-01, 6.5189e-01,\n", " 1.0688e+00, 1.0419e+00, 9.7937e-01, 1.0725e+00, 5.4258e-01,\n", " 9.2915e-01, 4.3405e-01, 4.7934e-01, 1.1528e+00, 7.5340e-01,\n", " 5.4904e-01, 5.4025e-01, 5.1751e-01, 2.9075e-01, 5.5143e-01,\n", " 8.6160e-01, 9.6728e-01, 6.0795e-01, 7.0219e-01, 1.0551e+00,\n", " 7.9270e-01, 9.2103e-01, 8.7458e-01, 9.9153e-01, 6.0989e-01,\n", " 7.4993e-01, 6.9077e-01, 5.6804e-01, 7.0561e-01, 7.8830e-01,\n", " 9.7916e-01, 9.9039e-01, 7.9061e-01, 9.6164e-01, 8.5340e-01,\n", " 8.2899e-01, 7.3213e-01, 6.8678e-01, 1.2765e+00, 1.2545e+00,\n", " 1.1249e+00, 1.3865e+00, 8.9114e-01, 8.0419e-01, 1.2773e+00,\n", " 8.9564e-01, 6.2510e-01, 1.1143e+00, 9.5270e-01, 9.5466e-01,\n", " 7.9755e-01, 1.0294e+00, 5.8184e-01, 1.2175e+00, 1.0392e+00,\n", " 9.4017e-01, 1.1067e+00, 9.4888e-01, 8.5048e-01, 9.3845e-01,\n", " 1.2021e+00, 9.6893e-01, 1.0378e+00, 1.1524e+00, 1.0356e+00,\n", " 1.2582e+00, 9.6289e-01, 1.1062e+00, 6.3397e-01, 8.6299e-01,\n", " 8.4336e-01, 6.3310e-01, 1.3410e+00, 9.9408e-01, 7.2785e-01,\n", " 1.2686e+00, 8.6769e-01, 1.2330e+00, 1.0719e+00, 7.1942e-01,\n", " 1.1396e+00, 9.7066e-01, 1.1835e+00, 9.9540e-01, 7.0843e-01,\n", " 1.0240e+00, 9.7269e-01, 8.0341e-01, 1.0116e+00, 8.1968e-01,\n", " 7.5934e-01, 9.6155e-01, 1.1158e+00, 9.6646e-01, 7.0824e-01,\n", " 1.2226e+00, 7.9395e-01, 9.5238e-01, 1.1413e+00, 8.7932e-01,\n", " 9.0194e-01, 8.6707e-01, 1.2885e+00, 1.2093e+00, 1.2944e+00,\n", " 6.3165e-01, 5.5781e-01, 1.0286e+00, 1.3556e+00, 8.9051e-01,\n", " 1.1046e+00, 9.9294e-01, 1.2333e+00, 8.1726e-01, 3.1740e-01,\n", " 7.6761e-01, 6.5092e-01, 8.9621e-01, 6.8024e-01, 1.0196e+00,\n", " 7.5904e-01, 9.4622e-01, 8.6102e-01, 8.3147e-01, 7.4786e-01,\n", " 9.6059e-01, 8.6270e-01, 9.4766e-01, 5.2817e-01, 1.0704e+00,\n", " 9.7133e-01, 8.4503e-01, 8.8452e-01, 5.0776e-01, 9.2430e-01,\n", " 5.9001e-01, 8.0198e-01, 7.8422e-01, 5.8749e-01, 7.0924e-01,\n", " 7.5607e-01, 4.9436e-01, 7.7539e-01, 6.4349e-01, 7.5043e-01,\n", " 7.3827e-01, 8.6847e-01, 5.4753e-01, 6.5105e-01, 1.0554e+00,\n", " 7.8901e-01, 8.9882e-01, 7.0064e-01, 5.4479e-01, 8.2511e-01,\n", " 5.1943e-01, 1.5267e-01, 7.0765e-01, 5.7810e-01, 6.0173e-01,\n", " 4.8342e-01, 6.4010e-01, 9.0313e-01, 3.0786e-01, 1.0283e+00,\n", " 2.3870e-01, 5.7824e-01, -3.3643e-02, 4.2503e-01, 5.7349e-01,\n", " 4.4148e-01, 6.8640e-01, 6.1931e-01, 2.5912e-01, 2.2371e-01,\n", " 6.0460e-01, 3.7744e-01, 5.9038e-01, 3.7926e-01, 5.2749e-01,\n", " 6.2748e-01, 6.6149e-01, 4.3518e-01, 4.0026e-01, 2.9409e-01,\n", " 4.5821e-01, 3.7015e-01, 3.4187e-01, 1.8859e-01, 6.9215e-01,\n", " 3.2195e-01, 2.4332e-02, 5.2798e-01, 9.0723e-02, 1.6245e-01,\n", " 2.7128e-01, 2.0240e-01, -8.3513e-02, 3.9523e-01, 5.8745e-01,\n", " 3.2908e-01, 2.4919e-01, 2.8691e-01, 1.1735e-01, 4.3031e-01,\n", " 2.6840e-01, 2.5892e-01, 1.7928e-01, 4.2978e-01, 5.2001e-02,\n", " 2.7463e-01, -1.3417e-01, -1.6025e-02, 2.9625e-01, 2.3450e-01,\n", " 3.7070e-01, 7.5755e-02, 1.6683e-01, 1.1036e-01, -6.8264e-03,\n", " 7.2137e-03, 2.7841e-01, 9.1316e-02, 9.1231e-03, -3.5385e-01,\n", " 1.1431e-01, -2.0163e-01, 2.0756e-01, -9.4054e-02, -1.2446e-01,\n", " 2.4384e-01, -2.3242e-01, -2.0931e-01, -1.8707e-01, -1.3957e-01,\n", " -1.8903e-01, -8.1507e-02, -3.4759e-01, -1.1257e-02, -1.9703e-01,\n", " 2.7359e-02, 7.6564e-04, -6.1846e-01, -7.9818e-02, -3.6661e-01,\n", " 6.5931e-02, -5.8843e-01, -1.7423e-01, -2.1431e-01, 6.6695e-02,\n", " -2.9555e-01, -1.3760e-01, -5.6146e-02, -1.7448e-02, -3.2177e-01,\n", " -1.8931e-01, -3.3209e-01, -2.6944e-01, -3.6146e-01, -3.5334e-01,\n", " -3.3019e-01, -2.8488e-02, -3.6981e-01, -3.1455e-01, -4.2320e-01,\n", " -5.8333e-01, -4.5083e-01, -2.6372e-01, -6.6177e-01, -5.4376e-01,\n", " -2.1988e-01, -7.4067e-02, -3.6120e-01, -7.7958e-01, -2.7244e-01,\n", " -3.3669e-01, -6.1547e-01, -7.1691e-01, -3.2713e-01, -3.2994e-01,\n", " -4.0011e-01, -4.5194e-01, -5.5936e-01, -4.8557e-01, -7.0421e-01,\n", " -1.8149e-01, -6.7299e-01, -4.4816e-01, -5.8107e-01, -4.7465e-01,\n", " -4.2599e-01, -8.3749e-01, -6.6348e-01, -6.9997e-01, -5.8357e-01,\n", " -3.5789e-01, -7.1656e-01, -8.7729e-01, -6.3883e-01, -6.9591e-01,\n", " -9.3954e-01, -4.9190e-01, -7.3111e-01, -4.3942e-01, -6.2770e-01,\n", " -9.2674e-01, -8.6653e-01, -9.6315e-01, -6.9102e-01, -5.9326e-01,\n", " -8.5505e-01, -8.5113e-01, -6.2499e-01, -9.9391e-01, -8.4853e-01,\n", " -7.8337e-01, -6.2505e-01, -8.0748e-01, -8.2683e-01, -6.9701e-01,\n", " -7.8696e-01, -6.9023e-01, -6.2324e-01, -9.8813e-01, -8.8023e-01,\n", " -7.4747e-01, -9.1390e-01, -1.1208e+00, -1.3740e+00, -1.0556e+00,\n", " -9.5917e-01, -7.6300e-01, -1.0235e+00, -1.0120e+00, -7.2330e-01,\n", " -1.0387e+00, -5.4913e-01, -5.8775e-01, -9.5260e-01, -7.9546e-01,\n", " -6.3363e-01, -7.6522e-01, -1.0495e+00, -1.1376e+00, -9.7715e-01,\n", " -1.0136e+00, -1.0233e+00, -4.2618e-01, -9.1957e-01, -1.1356e+00,\n", " -9.9493e-01, -6.0729e-01, -6.6857e-01, -8.6124e-01, -6.8043e-01,\n", " -7.9954e-01, -9.4514e-01, -9.6364e-01, -7.6007e-01, -7.9190e-01,\n", " -9.0256e-01, -1.0012e+00, -8.1351e-01, -1.0263e+00, -7.2729e-01,\n", " -8.8496e-01, -1.2832e+00, -8.1455e-01, -7.6072e-01, -9.7935e-01,\n", " -1.0354e+00, -1.1628e+00, -8.2264e-01, -7.7549e-01, -1.6201e+00,\n", " -6.8463e-01, -1.0060e+00, -8.6768e-01, -7.4937e-01, -1.0767e+00,\n", " -5.6351e-01, -9.3733e-01, -1.2270e+00, -9.3480e-01, -1.1192e+00,\n", " -1.2890e+00, -1.3016e+00, -9.8122e-01, -1.2527e+00, -8.9113e-01,\n", " -9.9467e-01, -7.3336e-01, -1.2790e+00, -1.2139e+00, -7.6957e-01,\n", " -8.9945e-01, -1.2749e+00, -7.1642e-01, -1.0271e+00, -1.3689e+00,\n", " -8.8271e-01, -9.3780e-01, -1.0709e+00, -1.0079e+00, -1.2095e+00,\n", " -8.3435e-01, -1.1892e+00, -5.8446e-01, -1.0576e+00, -7.8082e-01,\n", " -9.9774e-01, -1.0047e+00, -9.4661e-01, -7.9260e-01, -7.8298e-01,\n", " -8.1630e-01, -1.1429e+00, -9.0614e-01, -1.2286e+00, -1.0185e+00,\n", " -9.2398e-01, -9.3490e-01, -1.1074e+00, -7.6938e-01, -7.7835e-01,\n", " -7.9201e-01, -8.3866e-01, -5.0138e-01, -1.0518e+00, -1.1464e+00,\n", " -8.3545e-01, -6.3239e-01, -8.6411e-01, -1.0649e+00, -8.3904e-01,\n", " -9.3103e-01, -9.5688e-01, -1.3042e+00, -7.8724e-01, -8.9785e-01,\n", " -5.8319e-01, -9.7922e-01, -9.8292e-01, -1.0255e+00, -6.4694e-01,\n", " -8.0609e-01, -6.8586e-01, -1.0256e+00, -6.2613e-01, -5.2035e-01,\n", " -8.2406e-01, -6.1214e-01, -6.3858e-01, -7.9211e-01, -8.4110e-01,\n", " -8.7759e-01, -1.0926e+00, -4.8413e-01, -8.8961e-01, -8.6125e-01,\n", " -8.4024e-01, -7.4395e-01, -7.6605e-01, -7.7586e-01, -6.6531e-01,\n", " -8.7798e-01, -5.3314e-01, -3.8761e-01, -5.2371e-01, -4.9831e-01,\n", " -4.7124e-01, -4.0311e-01, -3.9151e-01, -6.0217e-01, -3.5831e-01,\n", " -8.1952e-01, -3.7521e-01, -4.1182e-01, -6.9520e-01, -1.3176e-01,\n", " -3.5725e-01, -6.3746e-01, -7.1734e-01, -6.0116e-01, -4.7620e-01,\n", " -1.3156e-01, -6.7144e-01, -4.4765e-01, -7.9655e-01, -4.6641e-01,\n", " -5.1395e-01, -5.9736e-01, -7.0441e-02, -2.9234e-01, -6.3963e-01,\n", " -5.2695e-01, -8.9920e-01, -3.9060e-01, -3.3070e-01, -5.8975e-02,\n", " -3.4768e-01, -3.9728e-01, -5.0462e-01, -6.9052e-01, -2.6284e-01,\n", " -5.3189e-01, -2.8471e-01, -3.8808e-01, -2.5389e-01, -1.0635e-01,\n", " -4.4742e-01, -2.8809e-01, -4.6124e-01, -1.8804e-01, -6.5422e-01,\n", " -2.9021e-01, -1.6320e-01, -2.7098e-01, -1.8750e-01, 2.4683e-01,\n", " -1.2878e-01, -2.1855e-01, -5.1811e-01, 5.4305e-02, -1.7425e-01,\n", " -2.6757e-01, 1.3357e-01, -3.1198e-01, 8.1655e-02, -2.8527e-01,\n", " -1.3569e-01, -1.4587e-01, 1.7095e-01, -2.3103e-02, 2.1838e-01,\n", " -1.6752e-01, 3.1579e-01, 2.7031e-01, -2.4856e-01, 7.6009e-03,\n", " -1.1322e-03, -2.0730e-01, -1.2818e-01, 1.4944e-01, 9.0087e-02,\n", " 4.0082e-01, 2.9144e-01, -1.4654e-01, 8.8202e-02, -1.7362e-01,\n", " -9.1909e-03, -8.0324e-02, -5.5343e-02, 5.8975e-01, 1.6623e-01,\n", " 3.3504e-01, 2.4773e-02, 8.7046e-02, -1.6163e-01, 5.1961e-01,\n", " 1.7512e-01, 1.0362e-01, 2.0362e-01, 1.9839e-01, 4.1845e-01,\n", " 4.6793e-01, -1.1853e-01, 1.2487e-01, 1.9344e-01, 3.0220e-01,\n", " 8.3792e-02, 3.1021e-01, 3.2869e-01, 3.0734e-01, 5.7626e-01,\n", " 4.3928e-01, 3.1897e-01, 2.8437e-01, 5.5234e-01, 6.2213e-01,\n", " 6.1580e-01, 3.9967e-01, 4.5823e-01, 4.3247e-01, 5.0114e-01,\n", " 8.3448e-01, 4.9888e-01, 5.0631e-01, 2.0848e-01, 3.6072e-01,\n", " 2.7618e-01, 4.0099e-01, 5.4027e-01, 2.4210e-01, 1.2701e-01,\n", " 4.4325e-01, 3.0193e-01, 3.6690e-01, 5.7623e-01, 5.2195e-01,\n", " 6.5280e-01, 5.7883e-01, 2.9837e-01, 2.5124e-01, 3.4579e-01,\n", " 2.2099e-01, 3.4217e-01, 8.5317e-01, 7.0991e-01, 3.0030e-01,\n", " 7.5253e-01, 7.0718e-01, 7.5546e-01, 8.3272e-01, 8.2167e-01,\n", " 6.8525e-01, 8.5421e-01, 3.8577e-01, 6.1654e-01, 6.7905e-01,\n", " 9.9523e-01, 7.6051e-01, 8.6416e-01, 6.0249e-01, 1.2840e+00,\n", " 6.4849e-01, 5.6504e-01, 6.7845e-01, 3.4798e-01, 6.4645e-01,\n", " 7.8018e-01, 9.8716e-01, 7.3607e-01, 7.5667e-01, 9.4265e-01,\n", " 8.0938e-01, 2.6675e-01, 6.1355e-01, 9.0162e-01, 4.2799e-01,\n", " 8.3804e-01, 4.6611e-01, 8.9841e-01, 9.1118e-01, 6.0615e-01,\n", " 8.8064e-01, 1.1570e+00, 9.5360e-01, 6.1671e-01, 7.3730e-01,\n", " 1.1604e+00, 7.5139e-01, 1.0676e+00, 8.5293e-01, 1.2444e+00,\n", " 7.3487e-01, 7.2214e-01, 1.3246e+00, 9.9059e-01, 1.1876e+00,\n", " 8.9527e-01, 1.1861e+00, 1.1246e+00, 1.4331e+00, 1.0912e+00,\n", " 1.0825e+00, 1.1025e+00, 9.9692e-01, 7.3992e-01, 6.6441e-01,\n", " 9.1873e-01, 9.2798e-01, 1.3192e+00, 1.0102e+00, 1.3542e+00,\n", " 1.0066e+00, 1.0593e+00, 1.0544e+00, 1.0728e+00, 8.1399e-01,\n", " 1.2641e+00, 1.0079e+00, 1.0601e+00, 1.0796e+00, 1.0223e+00,\n", " 1.1092e+00, 7.7240e-01, 1.3857e+00, 9.1455e-01, 8.6129e-01,\n", " 8.0728e-01, 6.1754e-01, 1.0502e+00, 9.3687e-01, 9.5378e-01,\n", " 9.7054e-01, 1.2540e+00, 1.1463e+00, 1.1405e+00, 1.2536e+00,\n", " 8.8310e-01, 1.3811e+00, 1.1267e+00, 7.9463e-01, 1.2574e+00,\n", " 1.0988e+00, 1.3334e+00, 1.2709e+00, 1.0338e+00, 8.9485e-01,\n", " 8.5191e-01, 6.2941e-01, 8.1570e-01, 1.1244e+00, 1.0805e+00,\n", " 9.9755e-01, 1.0758e+00, 1.1607e+00, 1.0960e+00, 9.6049e-01,\n", " 1.0852e+00, 9.1462e-01, 9.4122e-01, 9.8505e-01, 7.3513e-01,\n", " 1.0134e+00, 8.3373e-01, 7.4578e-01, 1.1270e+00, 1.0679e+00,\n", " 8.9848e-01, 9.9106e-01, 9.4795e-01, 1.0659e+00, 8.2919e-01,\n", " 8.6020e-01, 1.3219e+00, 1.0991e+00, 1.0899e+00, 1.1484e+00,\n", " 1.0549e+00, 8.9757e-01, 1.2341e+00, 7.1129e-01, 7.8177e-01,\n", " 7.1453e-01, 9.2287e-01, 5.1673e-01, 7.2670e-01, 5.6472e-01,\n", " 1.0603e+00, 5.5677e-01, 7.6662e-01, 5.9738e-01, 8.7946e-01,\n", " 7.2365e-01, 1.1941e+00, 9.4780e-01, 5.6618e-01, 5.3710e-01,\n", " 6.8202e-01, 1.0785e+00, 7.5097e-01, 7.3525e-01, 7.4950e-01,\n", " 7.1948e-01, 8.9217e-01, 3.9244e-01, 7.3835e-01, 4.3247e-01,\n", " 7.5097e-01, 7.1474e-01, 8.1818e-01, 6.3685e-01, 1.0300e+00,\n", " 5.9656e-01, 1.0586e+00, 8.1963e-01, 4.9452e-01, 1.0996e+00,\n", " 5.0523e-01, 9.3571e-01, 5.5205e-01, 6.1644e-01, 6.4985e-01,\n", " 6.3577e-01, 8.5211e-01, 9.2536e-01, 4.3236e-01, 5.6647e-01,\n", " 4.7429e-01, 8.5065e-01, 5.0285e-01, 1.0053e+00, 4.8989e-01,\n", " 1.1755e-01, 8.2002e-01, 7.0019e-01, 6.1519e-02, 6.4777e-01,\n", " 3.1640e-01, 2.8206e-01, 7.1172e-01, 7.0953e-01, 6.4411e-01,\n", " 4.9723e-01, 4.9160e-01, 7.2991e-01, 4.4568e-01, 4.7622e-01,\n", " 2.6474e-01, 6.0209e-01, 5.5910e-01, 4.3042e-01, 5.2249e-01,\n", " 2.1004e-01, 5.4428e-01, 1.2475e-01, 4.2799e-01, 7.4566e-02,\n", " 5.3251e-01, 6.1238e-01, 3.2354e-01, 1.3797e-02, 2.1109e-01,\n", " 5.6343e-01, 3.2116e-01, 5.0386e-01, 9.1126e-02, 4.6912e-01,\n", " 3.4669e-02, 4.0979e-01, 1.4810e-02, 3.8405e-01, 2.2161e-01,\n", " 1.9445e-01, -3.5447e-01, 1.5456e-01, 1.6863e-01, 2.0110e-01,\n", " 1.5556e-01, 2.2514e-02, 1.6489e-01, 1.6907e-01, -9.4499e-02,\n", " 1.3021e-01, 2.4134e-01, 9.6924e-02, 1.5037e-01, 3.9969e-02,\n", " -2.2726e-01, 2.8770e-01, -1.7184e-01, -1.3635e-01, -8.5396e-02,\n", " -9.3818e-02, -4.1428e-02, -4.6396e-01, -1.7805e-01, 3.6114e-01,\n", " 1.5889e-01, -2.7120e-01, 2.0932e-01, -4.9246e-01, -1.9852e-02,\n", " -9.9432e-02, -3.6289e-01, 2.1602e-01, -1.5902e-01, 2.5226e-01,\n", " -4.1119e-01, 7.3532e-03, -2.6737e-01, -9.9375e-02, -5.8365e-01,\n", " -3.8112e-01, 1.0808e-02, -6.2558e-01, -4.5019e-01, -3.2798e-01,\n", " -7.1162e-02, -2.6805e-01, -2.4978e-01, -3.4975e-01, -2.8487e-01,\n", " -2.4127e-01, -5.3032e-01, -5.4788e-01, -6.5170e-01, -3.3645e-01,\n", " -3.3031e-01, -2.5862e-01, -4.1498e-01, -3.1122e-01, -4.6045e-01,\n", " -5.4418e-01, -2.6614e-01, -4.7850e-01, -3.8730e-01, -3.8611e-01,\n", " -4.1716e-01, -4.9462e-01, -6.8122e-01, -4.3859e-01, -4.6447e-01,\n", " -2.6121e-01, -6.4777e-01, -2.9884e-01, -2.7754e-01, -3.8261e-01,\n", " -5.6598e-01, -1.7966e-01, -8.3324e-01, -5.7268e-01, -5.2891e-01]),\n", " tensor([[ 0.0644, -0.0288, 0.2246, -0.1722],\n", " [-0.0288, 0.2246, -0.1722, 0.0664],\n", " [ 0.2246, -0.1722, 0.0664, -0.1036],\n", " ...,\n", " [-0.2775, -0.3826, -0.5660, -0.1797],\n", " [-0.3826, -0.5660, -0.1797, -0.8332],\n", " [-0.5660, -0.1797, -0.8332, -0.5727]]),\n", " tensor([[ 6.6353e-02],\n", " [-1.0363e-01],\n", " [ 8.8868e-02],\n", " [ 1.5603e-01],\n", " [ 2.0316e-01],\n", " [ 2.2209e-01],\n", " [ 4.0233e-01],\n", " [ 1.8191e-01],\n", " [ 9.4271e-02],\n", " [ 1.9833e-01],\n", " [ 2.7132e-01],\n", " [-2.6344e-02],\n", " [ 1.3314e-01],\n", " [-9.5498e-02],\n", " [ 4.2949e-01],\n", " [ 2.9735e-01],\n", " [ 2.6208e-01],\n", " [ 2.5798e-01],\n", " [ 2.6224e-01],\n", " [ 4.0028e-01],\n", " [ 1.6453e-01],\n", " [-3.6497e-03],\n", " [ 4.5941e-02],\n", " [ 2.9152e-01],\n", " [ 2.8247e-01],\n", " [ 3.1230e-01],\n", " [ 4.0011e-01],\n", " [ 1.4096e-01],\n", " [ 4.1744e-01],\n", " [ 3.4225e-01],\n", " [-4.5256e-02],\n", " [ 2.8870e-01],\n", " [ 3.8852e-01],\n", " [ 3.4837e-01],\n", " [ 6.1889e-01],\n", " [ 6.2549e-01],\n", " [ 2.3834e-01],\n", " [ 4.8642e-01],\n", " [ 3.5614e-01],\n", " [ 1.1784e-01],\n", " [ 3.8346e-01],\n", " [ 4.5669e-01],\n", " [ 3.6588e-01],\n", " [ 2.6488e-01],\n", " [ 6.0995e-01],\n", " [ 6.9697e-01],\n", " [ 7.5780e-01],\n", " [ 5.8101e-01],\n", " [ 3.5400e-01],\n", " [ 2.4635e-01],\n", " [ 4.7288e-01],\n", " [ 6.6484e-01],\n", " [ 6.3196e-01],\n", " [ 5.6758e-01],\n", " [ 3.1575e-01],\n", " [ 7.3676e-01],\n", " [ 8.1908e-01],\n", " [ 8.8408e-01],\n", " [ 7.6086e-01],\n", " [ 4.3549e-01],\n", " [ 7.9157e-01],\n", " [ 4.1029e-01],\n", " [ 3.4122e-01],\n", " [ 1.0624e+00],\n", " [ 9.8399e-01],\n", " [ 7.3473e-01],\n", " [ 6.9833e-01],\n", " [ 3.4119e-01],\n", " [ 7.3251e-01],\n", " [ 6.7880e-01],\n", " [ 6.3626e-01],\n", " [ 1.0105e+00],\n", " [ 7.0007e-01],\n", " [ 1.1702e+00],\n", " [ 6.0600e-01],\n", " [ 8.9456e-01],\n", " [ 5.1218e-01],\n", " [ 7.3733e-01],\n", " [ 6.1851e-01],\n", " [ 7.6468e-01],\n", " [ 6.5189e-01],\n", " [ 1.0688e+00],\n", " [ 1.0419e+00],\n", " [ 9.7937e-01],\n", " [ 1.0725e+00],\n", " [ 5.4258e-01],\n", " [ 9.2915e-01],\n", " [ 4.3405e-01],\n", " [ 4.7934e-01],\n", " [ 1.1528e+00],\n", " [ 7.5340e-01],\n", " [ 5.4904e-01],\n", " [ 5.4025e-01],\n", " [ 5.1751e-01],\n", " [ 2.9075e-01],\n", " [ 5.5143e-01],\n", " [ 8.6160e-01],\n", " [ 9.6728e-01],\n", " [ 6.0795e-01],\n", " [ 7.0219e-01],\n", " [ 1.0551e+00],\n", " [ 7.9270e-01],\n", " [ 9.2103e-01],\n", " [ 8.7458e-01],\n", " [ 9.9153e-01],\n", " [ 6.0989e-01],\n", " [ 7.4993e-01],\n", " [ 6.9077e-01],\n", " [ 5.6804e-01],\n", " [ 7.0561e-01],\n", " [ 7.8830e-01],\n", " [ 9.7916e-01],\n", " [ 9.9039e-01],\n", " [ 7.9061e-01],\n", " [ 9.6164e-01],\n", " [ 8.5340e-01],\n", " [ 8.2899e-01],\n", " [ 7.3213e-01],\n", " [ 6.8678e-01],\n", " [ 1.2765e+00],\n", " [ 1.2545e+00],\n", " [ 1.1249e+00],\n", " [ 1.3865e+00],\n", " [ 8.9114e-01],\n", " [ 8.0419e-01],\n", " [ 1.2773e+00],\n", " [ 8.9564e-01],\n", " [ 6.2510e-01],\n", " [ 1.1143e+00],\n", " [ 9.5270e-01],\n", " [ 9.5466e-01],\n", " [ 7.9755e-01],\n", " [ 1.0294e+00],\n", " [ 5.8184e-01],\n", " [ 1.2175e+00],\n", " [ 1.0392e+00],\n", " [ 9.4017e-01],\n", " [ 1.1067e+00],\n", " [ 9.4888e-01],\n", " [ 8.5048e-01],\n", " [ 9.3845e-01],\n", " [ 1.2021e+00],\n", " [ 9.6893e-01],\n", " [ 1.0378e+00],\n", " [ 1.1524e+00],\n", " [ 1.0356e+00],\n", " [ 1.2582e+00],\n", " [ 9.6289e-01],\n", " [ 1.1062e+00],\n", " [ 6.3397e-01],\n", " [ 8.6299e-01],\n", " [ 8.4336e-01],\n", " [ 6.3310e-01],\n", " [ 1.3410e+00],\n", " [ 9.9408e-01],\n", " [ 7.2785e-01],\n", " [ 1.2686e+00],\n", " [ 8.6769e-01],\n", " [ 1.2330e+00],\n", " [ 1.0719e+00],\n", " [ 7.1942e-01],\n", " [ 1.1396e+00],\n", " [ 9.7066e-01],\n", " [ 1.1835e+00],\n", " [ 9.9540e-01],\n", " [ 7.0843e-01],\n", " [ 1.0240e+00],\n", " [ 9.7269e-01],\n", " [ 8.0341e-01],\n", " [ 1.0116e+00],\n", " [ 8.1968e-01],\n", " [ 7.5934e-01],\n", " [ 9.6155e-01],\n", " [ 1.1158e+00],\n", " [ 9.6646e-01],\n", " [ 7.0824e-01],\n", " [ 1.2226e+00],\n", " [ 7.9395e-01],\n", " [ 9.5238e-01],\n", " [ 1.1413e+00],\n", " [ 8.7932e-01],\n", " [ 9.0194e-01],\n", " [ 8.6707e-01],\n", " [ 1.2885e+00],\n", " [ 1.2093e+00],\n", " [ 1.2944e+00],\n", " [ 6.3165e-01],\n", " [ 5.5781e-01],\n", " [ 1.0286e+00],\n", " [ 1.3556e+00],\n", " [ 8.9051e-01],\n", " [ 1.1046e+00],\n", " [ 9.9294e-01],\n", " [ 1.2333e+00],\n", " [ 8.1726e-01],\n", " [ 3.1740e-01],\n", " [ 7.6761e-01],\n", " [ 6.5092e-01],\n", " [ 8.9621e-01],\n", " [ 6.8024e-01],\n", " [ 1.0196e+00],\n", " [ 7.5904e-01],\n", " [ 9.4622e-01],\n", " [ 8.6102e-01],\n", " [ 8.3147e-01],\n", " [ 7.4786e-01],\n", " [ 9.6059e-01],\n", " [ 8.6270e-01],\n", " [ 9.4766e-01],\n", " [ 5.2817e-01],\n", " [ 1.0704e+00],\n", " [ 9.7133e-01],\n", " [ 8.4503e-01],\n", " [ 8.8452e-01],\n", " [ 5.0776e-01],\n", " [ 9.2430e-01],\n", " [ 5.9001e-01],\n", " [ 8.0198e-01],\n", " [ 7.8422e-01],\n", " [ 5.8749e-01],\n", " [ 7.0924e-01],\n", " [ 7.5607e-01],\n", " [ 4.9436e-01],\n", " [ 7.7539e-01],\n", " [ 6.4349e-01],\n", " [ 7.5043e-01],\n", " [ 7.3827e-01],\n", " [ 8.6847e-01],\n", " [ 5.4753e-01],\n", " [ 6.5105e-01],\n", " [ 1.0554e+00],\n", " [ 7.8901e-01],\n", " [ 8.9882e-01],\n", " [ 7.0064e-01],\n", " [ 5.4479e-01],\n", " [ 8.2511e-01],\n", " [ 5.1943e-01],\n", " [ 1.5267e-01],\n", " [ 7.0765e-01],\n", " [ 5.7810e-01],\n", " [ 6.0173e-01],\n", " [ 4.8342e-01],\n", " [ 6.4010e-01],\n", " [ 9.0313e-01],\n", " [ 3.0786e-01],\n", " [ 1.0283e+00],\n", " [ 2.3870e-01],\n", " [ 5.7824e-01],\n", " [-3.3643e-02],\n", " [ 4.2503e-01],\n", " [ 5.7349e-01],\n", " [ 4.4148e-01],\n", " [ 6.8640e-01],\n", " [ 6.1931e-01],\n", " [ 2.5912e-01],\n", " [ 2.2371e-01],\n", " [ 6.0460e-01],\n", " [ 3.7744e-01],\n", " [ 5.9038e-01],\n", " [ 3.7926e-01],\n", " [ 5.2749e-01],\n", " [ 6.2748e-01],\n", " [ 6.6149e-01],\n", " [ 4.3518e-01],\n", " [ 4.0026e-01],\n", " [ 2.9409e-01],\n", " [ 4.5821e-01],\n", " [ 3.7015e-01],\n", " [ 3.4187e-01],\n", " [ 1.8859e-01],\n", " [ 6.9215e-01],\n", " [ 3.2195e-01],\n", " [ 2.4332e-02],\n", " [ 5.2798e-01],\n", " [ 9.0723e-02],\n", " [ 1.6245e-01],\n", " [ 2.7128e-01],\n", " [ 2.0240e-01],\n", " [-8.3513e-02],\n", " [ 3.9523e-01],\n", " [ 5.8745e-01],\n", " [ 3.2908e-01],\n", " [ 2.4919e-01],\n", " [ 2.8691e-01],\n", " [ 1.1735e-01],\n", " [ 4.3031e-01],\n", " [ 2.6840e-01],\n", " [ 2.5892e-01],\n", " [ 1.7928e-01],\n", " [ 4.2978e-01],\n", " [ 5.2001e-02],\n", " [ 2.7463e-01],\n", " [-1.3417e-01],\n", " [-1.6025e-02],\n", " [ 2.9625e-01],\n", " [ 2.3450e-01],\n", " [ 3.7070e-01],\n", " [ 7.5755e-02],\n", " [ 1.6683e-01],\n", " [ 1.1036e-01],\n", " [-6.8264e-03],\n", " [ 7.2137e-03],\n", " [ 2.7841e-01],\n", " [ 9.1316e-02],\n", " [ 9.1231e-03],\n", " [-3.5385e-01],\n", " [ 1.1431e-01],\n", " [-2.0163e-01],\n", " [ 2.0756e-01],\n", " [-9.4054e-02],\n", " [-1.2446e-01],\n", " [ 2.4384e-01],\n", " [-2.3242e-01],\n", " [-2.0931e-01],\n", " [-1.8707e-01],\n", " [-1.3957e-01],\n", " [-1.8903e-01],\n", " [-8.1507e-02],\n", " [-3.4759e-01],\n", " [-1.1257e-02],\n", " [-1.9703e-01],\n", " [ 2.7359e-02],\n", " [ 7.6564e-04],\n", " [-6.1846e-01],\n", " [-7.9818e-02],\n", " [-3.6661e-01],\n", " [ 6.5931e-02],\n", " [-5.8843e-01],\n", " [-1.7423e-01],\n", " [-2.1431e-01],\n", " [ 6.6695e-02],\n", " [-2.9555e-01],\n", " [-1.3760e-01],\n", " [-5.6146e-02],\n", " [-1.7448e-02],\n", " [-3.2177e-01],\n", " [-1.8931e-01],\n", " [-3.3209e-01],\n", " [-2.6944e-01],\n", " [-3.6146e-01],\n", " [-3.5334e-01],\n", " [-3.3019e-01],\n", " [-2.8488e-02],\n", " [-3.6981e-01],\n", " [-3.1455e-01],\n", " [-4.2320e-01],\n", " [-5.8333e-01],\n", " [-4.5083e-01],\n", " [-2.6372e-01],\n", " [-6.6177e-01],\n", " [-5.4376e-01],\n", " [-2.1988e-01],\n", " [-7.4067e-02],\n", " [-3.6120e-01],\n", " [-7.7958e-01],\n", " [-2.7244e-01],\n", " [-3.3669e-01],\n", " [-6.1547e-01],\n", " [-7.1691e-01],\n", " [-3.2713e-01],\n", " [-3.2994e-01],\n", " [-4.0011e-01],\n", " [-4.5194e-01],\n", " [-5.5936e-01],\n", " [-4.8557e-01],\n", " [-7.0421e-01],\n", " [-1.8149e-01],\n", " [-6.7299e-01],\n", " [-4.4816e-01],\n", " [-5.8107e-01],\n", " [-4.7465e-01],\n", " [-4.2599e-01],\n", " [-8.3749e-01],\n", " [-6.6348e-01],\n", " [-6.9997e-01],\n", " [-5.8357e-01],\n", " [-3.5789e-01],\n", " [-7.1656e-01],\n", " [-8.7729e-01],\n", " [-6.3883e-01],\n", " [-6.9591e-01],\n", " [-9.3954e-01],\n", " [-4.9190e-01],\n", " [-7.3111e-01],\n", " [-4.3942e-01],\n", " [-6.2770e-01],\n", " [-9.2674e-01],\n", " [-8.6653e-01],\n", " [-9.6315e-01],\n", " [-6.9102e-01],\n", " [-5.9326e-01],\n", " [-8.5505e-01],\n", " [-8.5113e-01],\n", " [-6.2499e-01],\n", " [-9.9391e-01],\n", " [-8.4853e-01],\n", " [-7.8337e-01],\n", " [-6.2505e-01],\n", " [-8.0748e-01],\n", " [-8.2683e-01],\n", " [-6.9701e-01],\n", " [-7.8696e-01],\n", " [-6.9023e-01],\n", " [-6.2324e-01],\n", " [-9.8813e-01],\n", " [-8.8023e-01],\n", " [-7.4747e-01],\n", " [-9.1390e-01],\n", " [-1.1208e+00],\n", " [-1.3740e+00],\n", " [-1.0556e+00],\n", " [-9.5917e-01],\n", " [-7.6300e-01],\n", " [-1.0235e+00],\n", " [-1.0120e+00],\n", " [-7.2330e-01],\n", " [-1.0387e+00],\n", " [-5.4913e-01],\n", " [-5.8775e-01],\n", " [-9.5260e-01],\n", " [-7.9546e-01],\n", " [-6.3363e-01],\n", " [-7.6522e-01],\n", " [-1.0495e+00],\n", " [-1.1376e+00],\n", " [-9.7715e-01],\n", " [-1.0136e+00],\n", " [-1.0233e+00],\n", " [-4.2618e-01],\n", " [-9.1957e-01],\n", " [-1.1356e+00],\n", " [-9.9493e-01],\n", " [-6.0729e-01],\n", " [-6.6857e-01],\n", " [-8.6124e-01],\n", " [-6.8043e-01],\n", " [-7.9954e-01],\n", " [-9.4514e-01],\n", " [-9.6364e-01],\n", " [-7.6007e-01],\n", " [-7.9190e-01],\n", " [-9.0256e-01],\n", " [-1.0012e+00],\n", " [-8.1351e-01],\n", " [-1.0263e+00],\n", " [-7.2729e-01],\n", " [-8.8496e-01],\n", " [-1.2832e+00],\n", " [-8.1455e-01],\n", " [-7.6072e-01],\n", " [-9.7935e-01],\n", " [-1.0354e+00],\n", " [-1.1628e+00],\n", " [-8.2264e-01],\n", " [-7.7549e-01],\n", " [-1.6201e+00],\n", " [-6.8463e-01],\n", " [-1.0060e+00],\n", " [-8.6768e-01],\n", " [-7.4937e-01],\n", " [-1.0767e+00],\n", " [-5.6351e-01],\n", " [-9.3733e-01],\n", " [-1.2270e+00],\n", " [-9.3480e-01],\n", " [-1.1192e+00],\n", " [-1.2890e+00],\n", " [-1.3016e+00],\n", " [-9.8122e-01],\n", " [-1.2527e+00],\n", " [-8.9113e-01],\n", " [-9.9467e-01],\n", " [-7.3336e-01],\n", " [-1.2790e+00],\n", " [-1.2139e+00],\n", " [-7.6957e-01],\n", " [-8.9945e-01],\n", " [-1.2749e+00],\n", " [-7.1642e-01],\n", " [-1.0271e+00],\n", " [-1.3689e+00],\n", " [-8.8271e-01],\n", " [-9.3780e-01],\n", " [-1.0709e+00],\n", " [-1.0079e+00],\n", " [-1.2095e+00],\n", " [-8.3435e-01],\n", " [-1.1892e+00],\n", " [-5.8446e-01],\n", " [-1.0576e+00],\n", " [-7.8082e-01],\n", " [-9.9774e-01],\n", " [-1.0047e+00],\n", " [-9.4661e-01],\n", " [-7.9260e-01],\n", " [-7.8298e-01],\n", " [-8.1630e-01],\n", " [-1.1429e+00],\n", " [-9.0614e-01],\n", " [-1.2286e+00],\n", " [-1.0185e+00],\n", " [-9.2398e-01],\n", " [-9.3490e-01],\n", " [-1.1074e+00],\n", " [-7.6938e-01],\n", " [-7.7835e-01],\n", " [-7.9201e-01],\n", " [-8.3866e-01],\n", " [-5.0138e-01],\n", " [-1.0518e+00],\n", " [-1.1464e+00],\n", " [-8.3545e-01],\n", " [-6.3239e-01],\n", " [-8.6411e-01],\n", " [-1.0649e+00],\n", " [-8.3904e-01],\n", " [-9.3103e-01],\n", " [-9.5688e-01],\n", " [-1.3042e+00],\n", " [-7.8724e-01],\n", " [-8.9785e-01],\n", " [-5.8319e-01],\n", " [-9.7922e-01],\n", " [-9.8292e-01],\n", " [-1.0255e+00],\n", " [-6.4694e-01],\n", " [-8.0609e-01],\n", " [-6.8586e-01],\n", " [-1.0256e+00],\n", " [-6.2613e-01],\n", " [-5.2035e-01],\n", " [-8.2406e-01],\n", " [-6.1214e-01],\n", " [-6.3858e-01],\n", " [-7.9211e-01],\n", " [-8.4110e-01],\n", " [-8.7759e-01],\n", " [-1.0926e+00],\n", " [-4.8413e-01],\n", " [-8.8961e-01],\n", " [-8.6125e-01],\n", " [-8.4024e-01],\n", " [-7.4395e-01],\n", " [-7.6605e-01],\n", " [-7.7586e-01],\n", " [-6.6531e-01],\n", " [-8.7798e-01],\n", " [-5.3314e-01],\n", " [-3.8761e-01],\n", " [-5.2371e-01],\n", " [-4.9831e-01],\n", " [-4.7124e-01],\n", " [-4.0311e-01],\n", " [-3.9151e-01],\n", " [-6.0217e-01],\n", " [-3.5831e-01],\n", " [-8.1952e-01],\n", " [-3.7521e-01],\n", " [-4.1182e-01],\n", " [-6.9520e-01],\n", " [-1.3176e-01],\n", " [-3.5725e-01],\n", " [-6.3746e-01],\n", " [-7.1734e-01],\n", " [-6.0116e-01],\n", " [-4.7620e-01],\n", " [-1.3156e-01],\n", " [-6.7144e-01],\n", " [-4.4765e-01],\n", " [-7.9655e-01],\n", " [-4.6641e-01],\n", " [-5.1395e-01],\n", " [-5.9736e-01],\n", " [-7.0441e-02],\n", " [-2.9234e-01],\n", " [-6.3963e-01],\n", " [-5.2695e-01],\n", " [-8.9920e-01],\n", " [-3.9060e-01],\n", " [-3.3070e-01],\n", " [-5.8975e-02],\n", " [-3.4768e-01],\n", " [-3.9728e-01],\n", " [-5.0462e-01],\n", " [-6.9052e-01],\n", " [-2.6284e-01],\n", " [-5.3189e-01],\n", " [-2.8471e-01],\n", " [-3.8808e-01],\n", " [-2.5389e-01],\n", " [-1.0635e-01],\n", " [-4.4742e-01],\n", " [-2.8809e-01],\n", " [-4.6124e-01],\n", " [-1.8804e-01],\n", " [-6.5422e-01],\n", " [-2.9021e-01],\n", " [-1.6320e-01],\n", " [-2.7098e-01],\n", " [-1.8750e-01],\n", " [ 2.4683e-01],\n", " [-1.2878e-01],\n", " [-2.1855e-01],\n", " [-5.1811e-01],\n", " [ 5.4305e-02],\n", " [-1.7425e-01],\n", " [-2.6757e-01],\n", " [ 1.3357e-01],\n", " [-3.1198e-01],\n", " [ 8.1655e-02],\n", " [-2.8527e-01],\n", " [-1.3569e-01],\n", " [-1.4587e-01],\n", " [ 1.7095e-01],\n", " [-2.3103e-02],\n", " [ 2.1838e-01],\n", " [-1.6752e-01],\n", " [ 3.1579e-01],\n", " [ 2.7031e-01],\n", " [-2.4856e-01],\n", " [ 7.6009e-03],\n", " [-1.1322e-03],\n", " [-2.0730e-01],\n", " [-1.2818e-01],\n", " [ 1.4944e-01],\n", " [ 9.0087e-02],\n", " [ 4.0082e-01],\n", " [ 2.9144e-01],\n", " [-1.4654e-01],\n", " [ 8.8202e-02],\n", " [-1.7362e-01],\n", " [-9.1909e-03],\n", " [-8.0324e-02],\n", " [-5.5343e-02],\n", " [ 5.8975e-01],\n", " [ 1.6623e-01],\n", " [ 3.3504e-01],\n", " [ 2.4773e-02],\n", " [ 8.7046e-02],\n", " [-1.6163e-01],\n", " [ 5.1961e-01],\n", " [ 1.7512e-01],\n", " [ 1.0362e-01],\n", " [ 2.0362e-01],\n", " [ 1.9839e-01],\n", " [ 4.1845e-01],\n", " [ 4.6793e-01],\n", " [-1.1853e-01],\n", " [ 1.2487e-01],\n", " [ 1.9344e-01],\n", " [ 3.0220e-01],\n", " [ 8.3792e-02],\n", " [ 3.1021e-01],\n", " [ 3.2869e-01],\n", " [ 3.0734e-01],\n", " [ 5.7626e-01],\n", " [ 4.3928e-01],\n", " [ 3.1897e-01],\n", " [ 2.8437e-01],\n", " [ 5.5234e-01],\n", " [ 6.2213e-01],\n", " [ 6.1580e-01],\n", " [ 3.9967e-01],\n", " [ 4.5823e-01],\n", " [ 4.3247e-01],\n", " [ 5.0114e-01],\n", " [ 8.3448e-01],\n", " [ 4.9888e-01],\n", " [ 5.0631e-01],\n", " [ 2.0848e-01],\n", " [ 3.6072e-01],\n", " [ 2.7618e-01],\n", " [ 4.0099e-01],\n", " [ 5.4027e-01],\n", " [ 2.4210e-01],\n", " [ 1.2701e-01],\n", " [ 4.4325e-01],\n", " [ 3.0193e-01],\n", " [ 3.6690e-01],\n", " [ 5.7623e-01],\n", " [ 5.2195e-01],\n", " [ 6.5280e-01],\n", " [ 5.7883e-01],\n", " [ 2.9837e-01],\n", " [ 2.5124e-01],\n", " [ 3.4579e-01],\n", " [ 2.2099e-01],\n", " [ 3.4217e-01],\n", " [ 8.5317e-01],\n", " [ 7.0991e-01],\n", " [ 3.0030e-01],\n", " [ 7.5253e-01],\n", " [ 7.0718e-01],\n", " [ 7.5546e-01],\n", " [ 8.3272e-01],\n", " [ 8.2167e-01],\n", " [ 6.8525e-01],\n", " [ 8.5421e-01],\n", " [ 3.8577e-01],\n", " [ 6.1654e-01],\n", " [ 6.7905e-01],\n", " [ 9.9523e-01],\n", " [ 7.6051e-01],\n", " [ 8.6416e-01],\n", " [ 6.0249e-01],\n", " [ 1.2840e+00],\n", " [ 6.4849e-01],\n", " [ 5.6504e-01],\n", " [ 6.7845e-01],\n", " [ 3.4798e-01],\n", " [ 6.4645e-01],\n", " [ 7.8018e-01],\n", " [ 9.8716e-01],\n", " [ 7.3607e-01],\n", " [ 7.5667e-01],\n", " [ 9.4265e-01],\n", " [ 8.0938e-01],\n", " [ 2.6675e-01],\n", " [ 6.1355e-01],\n", " [ 9.0162e-01],\n", " [ 4.2799e-01],\n", " [ 8.3804e-01],\n", " [ 4.6611e-01],\n", " [ 8.9841e-01],\n", " [ 9.1118e-01],\n", " [ 6.0615e-01],\n", " [ 8.8064e-01],\n", " [ 1.1570e+00],\n", " [ 9.5360e-01],\n", " [ 6.1671e-01],\n", " [ 7.3730e-01],\n", " [ 1.1604e+00],\n", " [ 7.5139e-01],\n", " [ 1.0676e+00],\n", " [ 8.5293e-01],\n", " [ 1.2444e+00],\n", " [ 7.3487e-01],\n", " [ 7.2214e-01],\n", " [ 1.3246e+00],\n", " [ 9.9059e-01],\n", " [ 1.1876e+00],\n", " [ 8.9527e-01],\n", " [ 1.1861e+00],\n", " [ 1.1246e+00],\n", " [ 1.4331e+00],\n", " [ 1.0912e+00],\n", " [ 1.0825e+00],\n", " [ 1.1025e+00],\n", " [ 9.9692e-01],\n", " [ 7.3992e-01],\n", " [ 6.6441e-01],\n", " [ 9.1873e-01],\n", " [ 9.2798e-01],\n", " [ 1.3192e+00],\n", " [ 1.0102e+00],\n", " [ 1.3542e+00],\n", " [ 1.0066e+00],\n", " [ 1.0593e+00],\n", " [ 1.0544e+00],\n", " [ 1.0728e+00],\n", " [ 8.1399e-01],\n", " [ 1.2641e+00],\n", " [ 1.0079e+00],\n", " [ 1.0601e+00],\n", " [ 1.0796e+00],\n", " [ 1.0223e+00],\n", " [ 1.1092e+00],\n", " [ 7.7240e-01],\n", " [ 1.3857e+00],\n", " [ 9.1455e-01],\n", " [ 8.6129e-01],\n", " [ 8.0728e-01],\n", " [ 6.1754e-01],\n", " [ 1.0502e+00],\n", " [ 9.3687e-01],\n", " [ 9.5378e-01],\n", " [ 9.7054e-01],\n", " [ 1.2540e+00],\n", " [ 1.1463e+00],\n", " [ 1.1405e+00],\n", " [ 1.2536e+00],\n", " [ 8.8310e-01],\n", " [ 1.3811e+00],\n", " [ 1.1267e+00],\n", " [ 7.9463e-01],\n", " [ 1.2574e+00],\n", " [ 1.0988e+00],\n", " [ 1.3334e+00],\n", " [ 1.2709e+00],\n", " [ 1.0338e+00],\n", " [ 8.9485e-01],\n", " [ 8.5191e-01],\n", " [ 6.2941e-01],\n", " [ 8.1570e-01],\n", " [ 1.1244e+00],\n", " [ 1.0805e+00],\n", " [ 9.9755e-01],\n", " [ 1.0758e+00],\n", " [ 1.1607e+00],\n", " [ 1.0960e+00],\n", " [ 9.6049e-01],\n", " [ 1.0852e+00],\n", " [ 9.1462e-01],\n", " [ 9.4122e-01],\n", " [ 9.8505e-01],\n", " [ 7.3513e-01],\n", " [ 1.0134e+00],\n", " [ 8.3373e-01],\n", " [ 7.4578e-01],\n", " [ 1.1270e+00],\n", " [ 1.0679e+00],\n", " [ 8.9848e-01],\n", " [ 9.9106e-01],\n", " [ 9.4795e-01],\n", " [ 1.0659e+00],\n", " [ 8.2919e-01],\n", " [ 8.6020e-01],\n", " [ 1.3219e+00],\n", " [ 1.0991e+00],\n", " [ 1.0899e+00],\n", " [ 1.1484e+00],\n", " [ 1.0549e+00],\n", " [ 8.9757e-01],\n", " [ 1.2341e+00],\n", " [ 7.1129e-01],\n", " [ 7.8177e-01],\n", " [ 7.1453e-01],\n", " [ 9.2287e-01],\n", " [ 5.1673e-01],\n", " [ 7.2670e-01],\n", " [ 5.6472e-01],\n", " [ 1.0603e+00],\n", " [ 5.5677e-01],\n", " [ 7.6662e-01],\n", " [ 5.9738e-01],\n", " [ 8.7946e-01],\n", " [ 7.2365e-01],\n", " [ 1.1941e+00],\n", " [ 9.4780e-01],\n", " [ 5.6618e-01],\n", " [ 5.3710e-01],\n", " [ 6.8202e-01],\n", " [ 1.0785e+00],\n", " [ 7.5097e-01],\n", " [ 7.3525e-01],\n", " [ 7.4950e-01],\n", " [ 7.1948e-01],\n", " [ 8.9217e-01],\n", " [ 3.9244e-01],\n", " [ 7.3835e-01],\n", " [ 4.3247e-01],\n", " [ 7.5097e-01],\n", " [ 7.1474e-01],\n", " [ 8.1818e-01],\n", " [ 6.3685e-01],\n", " [ 1.0300e+00],\n", " [ 5.9656e-01],\n", " [ 1.0586e+00],\n", " [ 8.1963e-01],\n", " [ 4.9452e-01],\n", " [ 1.0996e+00],\n", " [ 5.0523e-01],\n", " [ 9.3571e-01],\n", " [ 5.5205e-01],\n", " [ 6.1644e-01],\n", " [ 6.4985e-01],\n", " [ 6.3577e-01],\n", " [ 8.5211e-01],\n", " [ 9.2536e-01],\n", " [ 4.3236e-01],\n", " [ 5.6647e-01],\n", " [ 4.7429e-01],\n", " [ 8.5065e-01],\n", " [ 5.0285e-01],\n", " [ 1.0053e+00],\n", " [ 4.8989e-01],\n", " [ 1.1755e-01],\n", " [ 8.2002e-01],\n", " [ 7.0019e-01],\n", " [ 6.1519e-02],\n", " [ 6.4777e-01],\n", " [ 3.1640e-01],\n", " [ 2.8206e-01],\n", " [ 7.1172e-01],\n", " [ 7.0953e-01],\n", " [ 6.4411e-01],\n", " [ 4.9723e-01],\n", " [ 4.9160e-01],\n", " [ 7.2991e-01],\n", " [ 4.4568e-01],\n", " [ 4.7622e-01],\n", " [ 2.6474e-01],\n", " [ 6.0209e-01],\n", " [ 5.5910e-01],\n", " [ 4.3042e-01],\n", " [ 5.2249e-01],\n", " [ 2.1004e-01],\n", " [ 5.4428e-01],\n", " [ 1.2475e-01],\n", " [ 4.2799e-01],\n", " [ 7.4566e-02],\n", " [ 5.3251e-01],\n", " [ 6.1238e-01],\n", " [ 3.2354e-01],\n", " [ 1.3797e-02],\n", " [ 2.1109e-01],\n", " [ 5.6343e-01],\n", " [ 3.2116e-01],\n", " [ 5.0386e-01],\n", " [ 9.1126e-02],\n", " [ 4.6912e-01],\n", " [ 3.4669e-02],\n", " [ 4.0979e-01],\n", " [ 1.4810e-02],\n", " [ 3.8405e-01],\n", " [ 2.2161e-01],\n", " [ 1.9445e-01],\n", " [-3.5447e-01],\n", " [ 1.5456e-01],\n", " [ 1.6863e-01],\n", " [ 2.0110e-01],\n", " [ 1.5556e-01],\n", " [ 2.2514e-02],\n", " [ 1.6489e-01],\n", " [ 1.6907e-01],\n", " [-9.4499e-02],\n", " [ 1.3021e-01],\n", " [ 2.4134e-01],\n", " [ 9.6924e-02],\n", " [ 1.5037e-01],\n", " [ 3.9969e-02],\n", " [-2.2726e-01],\n", " [ 2.8770e-01],\n", " [-1.7184e-01],\n", " [-1.3635e-01],\n", " [-8.5396e-02],\n", " [-9.3818e-02],\n", " [-4.1428e-02],\n", " [-4.6396e-01],\n", " [-1.7805e-01],\n", " [ 3.6114e-01],\n", " [ 1.5889e-01],\n", " [-2.7120e-01],\n", " [ 2.0932e-01],\n", " [-4.9246e-01],\n", " [-1.9852e-02],\n", " [-9.9432e-02],\n", " [-3.6289e-01],\n", " [ 2.1602e-01],\n", " [-1.5902e-01],\n", " [ 2.5226e-01],\n", " [-4.1119e-01],\n", " [ 7.3532e-03],\n", " [-2.6737e-01],\n", " [-9.9375e-02],\n", " [-5.8365e-01],\n", " [-3.8112e-01],\n", " [ 1.0808e-02],\n", " [-6.2558e-01],\n", " [-4.5019e-01],\n", " [-3.2798e-01],\n", " [-7.1162e-02],\n", " [-2.6805e-01],\n", " [-2.4978e-01],\n", " [-3.4975e-01],\n", " [-2.8487e-01],\n", " [-2.4127e-01],\n", " [-5.3032e-01],\n", " [-5.4788e-01],\n", " [-6.5170e-01],\n", " [-3.3645e-01],\n", " [-3.3031e-01],\n", " [-2.5862e-01],\n", " [-4.1498e-01],\n", " [-3.1122e-01],\n", " [-4.6045e-01],\n", " [-5.4418e-01],\n", " [-2.6614e-01],\n", " [-4.7850e-01],\n", " [-3.8730e-01],\n", " [-3.8611e-01],\n", " [-4.1716e-01],\n", " [-4.9462e-01],\n", " [-6.8122e-01],\n", " [-4.3859e-01],\n", " [-4.6447e-01],\n", " [-2.6121e-01],\n", " [-6.4777e-01],\n", " [-2.9884e-01],\n", " [-2.7754e-01],\n", " [-3.8261e-01],\n", " [-5.6598e-01],\n", " [-1.7966e-01],\n", " [-8.3324e-01],\n", " [-5.7268e-01],\n", " [-5.2891e-01]]))" ] }, "execution_count": 84, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 84 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:38.270738873Z", "start_time": "2026-03-29T09:05:38.256891510Z" } }, "cell_type": "code", "source": [ "batch_size, n_train = 16, 600\n", "# 只有前n_train个样本用于训练\n", "train_iter = d2l.load_array((features[:n_train], labels[:n_train]),\n", "batch_size, is_train=True)\n" ], "id": "239a596b20d40dec", "outputs": [], "execution_count": 85 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:38.324191264Z", "start_time": "2026-03-29T09:05:38.275367502Z" } }, "cell_type": "code", "source": [ "def init_weights(m):\n", " if type(m) == nn.Linear:\n", " nn.init.xavier_uniform_(m.weight)" ], "id": "54d30bd0ee41cb8", "outputs": [], "execution_count": 86 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:38.371737960Z", "start_time": "2026-03-29T09:05:38.325724185Z" } }, "cell_type": "code", "source": [ "def get_net():\n", " net = nn.Sequential(nn.Linear(4, 10),\n", " nn.ReLU(),\n", " nn.Linear(10, 1))\n", " net.apply(init_weights)\n", " return net\n", "loss = nn.MSELoss(reduction='none')" ], "id": "5d095792e3b3681", "outputs": [], "execution_count": 87 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:38.471484710Z", "start_time": "2026-03-29T09:05:38.373039400Z" } }, "cell_type": "code", "source": [ "def train(net, train_iter, loss, epochs, lr):\n", " trainer = torch.optim.Adam(net.parameters(), lr)\n", " for epoch in range(epochs):\n", " for X, y in train_iter:\n", " trainer.zero_grad()\n", " l = loss(net(X), y)\n", " l.sum().backward()\n", " trainer.step()\n", " print(f'epoch {epoch + 1}, '\n", " f'loss: {d2l.evaluate_loss(net, train_iter, loss):f}')\n", "net = get_net()\n", "train(net, train_iter, loss, 5, 0.01)" ], "id": "5c1dba484e805335", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch 1, loss: 0.057278\n", "epoch 2, loss: 0.051284\n", "epoch 3, loss: 0.047806\n", "epoch 4, loss: 0.048928\n", "epoch 5, loss: 0.049009\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/home/yukun/.conda/envs/nn/lib/python3.11/site-packages/d2l/torch.py:3179: UserWarning: Converting a tensor with requires_grad=True to a scalar may lead to unexpected behavior.\n", "Consider using tensor.detach() first. (Triggered internally at /pytorch/torch/csrc/autograd/generated/python_variable_methods.cpp:836.)\n", " self.data = [a + float(b) for a, b in zip(self.data, args)]\n" ] } ], "execution_count": 88 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:38.651527416Z", "start_time": "2026-03-29T09:05:38.486176418Z" } }, "cell_type": "code", "source": [ "onestep_preds = net(features)\n", "d2l.plot([time, time[tau:]],\n", "[x.detach().numpy(), onestep_preds.detach().numpy()], 'time',\n", "'x', legend=['data', '1-step preds'], xlim=[1, 1000],\n", "figsize=(6, 3))" ], "id": "a6efb4a978cd9375", "outputs": [ { "data": { "text/plain": [ "
" ], "image/svg+xml": "\n\n\n \n \n \n \n 2026-03-29T17:05:38.555907\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" }, "metadata": {}, "output_type": "display_data", "jetTransient": { "display_id": null } } ], "execution_count": 89 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:38.758036538Z", "start_time": "2026-03-29T09:05:38.654924198Z" } }, "cell_type": "code", "source": [ "multistep_preds = torch.zeros(T)\n", "multistep_preds[: n_train + tau] = x[: n_train + tau]\n", "for i in range(n_train + tau, T):\n", " multistep_preds[i] = net(\n", " multistep_preds[i - tau:i].reshape((1, -1)))\n", "d2l.plot([time, time[tau:], time[n_train + tau:]],\n", " [x.detach().numpy(), onestep_preds.detach().numpy(),\n", " multistep_preds[n_train + tau:].detach().numpy()], 'time',\n", " 'x', legend=['data', '1-step preds', 'multistep preds'],\n", " xlim=[1, 1000], figsize=(6, 3))" ], "id": "12c3a1c3912da4dd", "outputs": [ { "data": { "text/plain": [ "
" ], "image/svg+xml": "\n\n\n \n \n \n \n 2026-03-29T17:05:38.729306\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" }, "metadata": {}, "output_type": "display_data", "jetTransient": { "display_id": null } } ], "execution_count": 90 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:38.831552628Z", "start_time": "2026-03-29T09:05:38.780075298Z" } }, "cell_type": "code", "source": [ "import collections\n", "import re" ], "id": "aab66c10a4c143d2", "outputs": [], "execution_count": 91 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:38.889096136Z", "start_time": "2026-03-29T09:05:38.833703204Z" } }, "cell_type": "code", "source": [ "d2l.DATA_HUB['time_machine'] = (d2l.DATA_URL + 'timemachine.txt',\n", "'090b5e7e70c295757f55df93cb0a180b9691891a')\n", "def read_time_machine(): #@save\n", " \"\"\"将时间机器数据集加载到文本行的列表中\"\"\"\n", " with open(d2l.download('time_machine'), 'r') as f:\n", " lines = f.readlines()\n", " return [re.sub('[^A-Za-z]+', ' ', line).strip().lower() for line in lines]\n", "lines = read_time_machine()\n", "print(f'# 文本总行数: {len(lines)}')\n", "print(lines[0])\n", "print(lines[10])" ], "id": "1aff117af810525e", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "# 文本总行数: 3221\n", "the time machine by h g wells\n", "twinkled and his usually pale face was flushed and animated the\n" ] } ], "execution_count": 92 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:38.974340058Z", "start_time": "2026-03-29T09:05:38.905164643Z" } }, "cell_type": "code", "source": [ "def tokenize(lines, token='word'): #@save\n", " \"\"\"将文本行拆分为单词或字符词元\"\"\"\n", " if token == 'word':\n", " return [line.split() for line in lines]\n", " elif token == 'char':\n", " return [list(line) for line in lines]\n", " else:\n", " print('错误:未知词元类型:' + token)\n", "tokens = tokenize(lines)\n", "for i in range(11):\n", " print(tokens[i])" ], "id": "eb4fe9745fbaa5e2", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['the', 'time', 'machine', 'by', 'h', 'g', 'wells']\n", "[]\n", "[]\n", "[]\n", "[]\n", "['i']\n", "[]\n", "[]\n", "['the', 'time', 'traveller', 'for', 'so', 'it', 'will', 'be', 'convenient', 'to', 'speak', 'of', 'him']\n", "['was', 'expounding', 'a', 'recondite', 'matter', 'to', 'us', 'his', 'grey', 'eyes', 'shone', 'and']\n", "['twinkled', 'and', 'his', 'usually', 'pale', 'face', 'was', 'flushed', 'and', 'animated', 'the']\n" ] } ], "execution_count": 93 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:39.026306689Z", "start_time": "2026-03-29T09:05:38.975870867Z" } }, "cell_type": "code", "source": [ "def count_corpus(tokens): #@save\n", " \"\"\"统计词元的频率\"\"\"\n", " # 这里的tokens是1D列表或2D列表\n", " if len(tokens) == 0 or isinstance(tokens[0], list):\n", " # 将词元列表展平成一个列表\n", " tokens = [token for line in tokens for token in line]\n", " return collections.Counter(tokens)\n", "class Vocab: #@save\n", " \"\"\"文本词表\"\"\"\n", " def __init__(self, tokens=None, min_freq=0, reserved_tokens=None):\n", " if tokens is None:\n", " tokens = []\n", " if reserved_tokens is None:\n", " reserved_tokens = []\n", " # 按出现频率排序\n", " counter = count_corpus(tokens)\n", " self._token_freqs = sorted(counter.items(), key=lambda x: x[1],\n", " reverse=True)\n", " # 未知词元的索引为0\n", " self.idx_to_token = [''] + reserved_tokens\n", " self.token_to_idx = {token: idx\n", " for idx, token in enumerate(self.idx_to_token)}\n", " for token, freq in self._token_freqs:\n", " if freq < min_freq:\n", " break\n", " if token not in self.token_to_idx:\n", " self.idx_to_token.append(token)\n", " self.token_to_idx[token] = len(self.idx_to_token) - 1\n", " def __len__(self):\n", " return len(self.idx_to_token)\n", " def __getitem__(self, tokens):\n", " if not isinstance(tokens, (list, tuple)):\n", " return self.token_to_idx.get(tokens, self.unk)\n", " return [self.__getitem__(token) for token in tokens]\n", " def to_tokens(self, indices):\n", " if not isinstance(indices, (list, tuple)):\n", " return self.idx_to_token[indices]\n", " return [self.idx_to_token[index] for index in indices]\n", " @property\n", " def unk(self): # 未知词元的索引为0\n", " return 0\n", " @property\n", " def token_freqs(self):\n", " return self._token_freqs" ], "id": "bee8e5d7b798c6c", "outputs": [], "execution_count": 94 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:39.084414078Z", "start_time": "2026-03-29T09:05:39.028224941Z" } }, "cell_type": "code", "source": [ "vocab = Vocab(tokens)\n", "print(list(vocab.token_to_idx.items())[:10])" ], "id": "ff4e8ac2044850b7", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[('', 0), ('the', 1), ('i', 2), ('and', 3), ('of', 4), ('a', 5), ('to', 6), ('was', 7), ('in', 8), ('that', 9)]\n" ] } ], "execution_count": 95 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:39.172122701Z", "start_time": "2026-03-29T09:05:39.098729247Z" } }, "cell_type": "code", "source": [ "for i in [0, 100]:\n", " print('文本:', tokens[i])\n", " print('索引:', vocab[tokens[i]])" ], "id": "a4e569dfbd251608", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "文本: ['the', 'time', 'machine', 'by', 'h', 'g', 'wells']\n", "索引: [1, 19, 50, 40, 2183, 2184, 400]\n", "文本: ['were', 'three', 'dimensional', 'representations', 'of', 'his', 'four', 'dimensioned']\n", "索引: [20, 175, 1452, 2250, 4, 25, 262, 2251]\n" ] } ], "execution_count": 96 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:39.242984562Z", "start_time": "2026-03-29T09:05:39.173224665Z" } }, "cell_type": "code", "source": [ "def load_corpus_time_machine(max_tokens=-1): #@save\n", " \"\"\"返回时光机器数据集的词元索引列表和词表\"\"\"\n", " lines = read_time_machine()\n", " tokens = tokenize(lines, 'char')\n", " vocab = Vocab(tokens)\n", " # 因为时光机器数据集中的每个文本行不一定是一个句子或一个段落,\n", " # 所以将所有文本行展平到一个列表中\n", " corpus = [vocab[token] for line in tokens for token in line]\n", " if max_tokens > 0:\n", " corpus = corpus[:max_tokens]\n", " return corpus, vocab\n", "corpus, vocab = load_corpus_time_machine()\n", "\n", "len(corpus), len(vocab)\n" ], "id": "1b5c1776ae47af5c", "outputs": [ { "data": { "text/plain": [ "(170580, 28)" ] }, "execution_count": 97, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 97 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:39.315199771Z", "start_time": "2026-03-29T09:05:39.258108438Z" } }, "cell_type": "code", "source": [ "tokens = d2l.tokenize(read_time_machine())\n", "# 因为每个文本行不一定是一个句子或一个段落,因此我们把所有文本行拼接到一起\n", "corpus = [token for line in tokens for token in line]\n", "vocab = d2l.Vocab(corpus)\n", "vocab.token_freqs[:10]" ], "id": "99deb85c025e5cdd", "outputs": [ { "data": { "text/plain": [ "[('the', 2261),\n", " ('i', 1267),\n", " ('and', 1245),\n", " ('of', 1155),\n", " ('a', 816),\n", " ('to', 695),\n", " ('was', 552),\n", " ('in', 541),\n", " ('that', 443),\n", " ('my', 440)]" ] }, "execution_count": 98, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 98 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:39.598676534Z", "start_time": "2026-03-29T09:05:39.332385572Z" } }, "cell_type": "code", "source": [ "freqs = [freq for token, freq in vocab.token_freqs]\n", "d2l.plot(freqs, xlabel='token: x', ylabel='frequency: n(x)',\n", "xscale='log', yscale='log')" ], "id": "5c0d846673e16c33", "outputs": [ { "data": { "text/plain": [ "
" ], "image/svg+xml": "\n\n\n \n \n \n \n 2026-03-29T17:05:39.535495\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" }, "metadata": {}, "output_type": "display_data", "jetTransient": { "display_id": null } } ], "execution_count": 99 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:39.691613898Z", "start_time": "2026-03-29T09:05:39.602184215Z" } }, "cell_type": "code", "source": [ "bigram_tokens = [pair for pair in zip(corpus[:-1], corpus[1:])]\n", "bigram_vocab = Vocab(bigram_tokens)\n", "bigram_vocab.token_freqs[:10]" ], "id": "2826e3ab0863ee64", "outputs": [ { "data": { "text/plain": [ "[(('of', 'the'), 309),\n", " (('in', 'the'), 169),\n", " (('i', 'had'), 130),\n", " (('i', 'was'), 112),\n", " (('and', 'the'), 109),\n", " (('the', 'time'), 102),\n", " (('it', 'was'), 99),\n", " (('to', 'the'), 85),\n", " (('as', 'i'), 78),\n", " (('of', 'a'), 73)]" ] }, "execution_count": 100, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 100 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:39.773596513Z", "start_time": "2026-03-29T09:05:39.693158268Z" } }, "cell_type": "code", "source": [ "trigram_tokens = [triple for triple in zip(\n", "corpus[:-2], corpus[1:-1], corpus[2:])]\n", "trigram_vocab = Vocab(trigram_tokens)\n", "trigram_vocab.token_freqs[:10]" ], "id": "7c8cd0544bf872bb", "outputs": [ { "data": { "text/plain": [ "[(('the', 'time', 'traveller'), 59),\n", " (('the', 'time', 'machine'), 30),\n", " (('the', 'medical', 'man'), 24),\n", " (('it', 'seemed', 'to'), 16),\n", " (('it', 'was', 'a'), 15),\n", " (('here', 'and', 'there'), 15),\n", " (('seemed', 'to', 'me'), 14),\n", " (('i', 'did', 'not'), 14),\n", " (('i', 'saw', 'the'), 13),\n", " (('i', 'began', 'to'), 13)]" ] }, "execution_count": 101, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 101 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:40.025057797Z", "start_time": "2026-03-29T09:05:39.775024733Z" } }, "cell_type": "code", "source": [ "bigram_freqs = [freq for token, freq in bigram_vocab.token_freqs]\n", "trigram_freqs = [freq for token, freq in trigram_vocab.token_freqs]\n", "d2l.plot([freqs, bigram_freqs, trigram_freqs], xlabel='token: x',\n", "ylabel='frequency: n(x)', xscale='log', yscale='log',\n", "legend=['unigram', 'bigram', 'trigram'])" ], "id": "dc3d97dda738613d", "outputs": [ { "data": { "text/plain": [ "
" ], "image/svg+xml": "\n\n\n \n \n \n \n 2026-03-29T17:05:39.966276\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" }, "metadata": {}, "output_type": "display_data", "jetTransient": { "display_id": null } } ], "execution_count": 102 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:40.090429874Z", "start_time": "2026-03-29T09:05:40.040896926Z" } }, "cell_type": "code", "source": [ "import random\n", "def seq_data_iter_random(corpus, batch_size, num_steps): #@save\n", " \"\"\"使用随机抽样生成一个小批量子序列\"\"\"\n", " # 从随机偏移量开始对序列进行分区,随机范围包括num_steps-1\n", " corpus = corpus[random.randint(0, num_steps - 1):]\n", " # 减去1,是因为我们需要考虑标签\n", " num_subseqs = (len(corpus) - 1) // num_steps\n", " # 长度为num_steps的子序列的起始索引\n", " initial_indices = list(range(0, num_subseqs * num_steps, num_steps))\n", " # 在随机抽样的迭代过程中,\n", " # 来自两个相邻的、随机的、小批量中的子序列不一定在原始序列上相邻\n", " random.shuffle(initial_indices)\n", " def data(pos):\n", " # 返回从pos位置开始的长度为num_steps的序列\n", " return corpus[pos: pos + num_steps]\n", " num_batches = num_subseqs // batch_size\n", " for i in range(0, batch_size * num_batches, batch_size):\n", " # 在这里,initial_indices包含子序列的随机起始索引\n", " initial_indices_per_batch = initial_indices[i: i + batch_size]\n", " X = [data(j) for j in initial_indices_per_batch]\n", " Y = [data(j + 1) for j in initial_indices_per_batch]\n", " yield torch.tensor(X), torch.tensor(Y)" ], "id": "fd015793938b83ab", "outputs": [], "execution_count": 103 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:40.149638908Z", "start_time": "2026-03-29T09:05:40.092265248Z" } }, "cell_type": "code", "source": [ "my_seq = list(range(35))\n", "for X, Y in seq_data_iter_random(my_seq, batch_size=2, num_steps=5):\n", " print('X: ', X, '\\nY:', Y)" ], "id": "8961f4934b8c7cc", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "X: tensor([[26, 27, 28, 29, 30],\n", " [21, 22, 23, 24, 25]]) \n", "Y: tensor([[27, 28, 29, 30, 31],\n", " [22, 23, 24, 25, 26]])\n", "X: tensor([[11, 12, 13, 14, 15],\n", " [ 1, 2, 3, 4, 5]]) \n", "Y: tensor([[12, 13, 14, 15, 16],\n", " [ 2, 3, 4, 5, 6]])\n", "X: tensor([[ 6, 7, 8, 9, 10],\n", " [16, 17, 18, 19, 20]]) \n", "Y: tensor([[ 7, 8, 9, 10, 11],\n", " [17, 18, 19, 20, 21]])\n" ] } ], "execution_count": 104 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:40.215887628Z", "start_time": "2026-03-29T09:05:40.163501969Z" } }, "cell_type": "code", "source": [ "def seq_data_iter_sequential(corpus, batch_size, num_steps): #@save\n", " \"\"\"使用顺序分区生成一个小批量子序列\"\"\"\n", " # 从随机偏移量开始划分序列\n", " offset = random.randint(0, num_steps)\n", " num_tokens = ((len(corpus) - offset - 1) // batch_size) * batch_size\n", " Xs = torch.tensor(corpus[offset: offset + num_tokens])\n", " Ys = torch.tensor(corpus[offset + 1: offset + 1 + num_tokens])\n", " Xs, Ys = Xs.reshape(batch_size, -1), Ys.reshape(batch_size, -1)\n", " num_batches = Xs.shape[1] // num_steps\n", " for i in range(0, num_steps * num_batches, num_steps):\n", " X = Xs[:, i: i + num_steps]\n", " Y = Ys[:, i: i + num_steps]\n", " yield X, Y" ], "id": "621b66c0614b22da", "outputs": [], "execution_count": 105 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:40.267211418Z", "start_time": "2026-03-29T09:05:40.217997368Z" } }, "cell_type": "code", "source": [ "class SeqDataLoader:\n", " \"\"\"加载序列数据的迭代器\"\"\"\n", " def __init__(self, batch_size, num_steps, use_random_iter, max_tokens):\n", " if use_random_iter:\n", " self.data_iter_fn = seq_data_iter_random\n", " else:\n", " self.data_iter_fn = seq_data_iter_sequential\n", " self.corpus, self.vocab = load_corpus_time_machine(max_tokens)\n", " self.batch_size, self.num_steps = batch_size, num_steps\n", " def __iter__(self):\n", " return self.data_iter_fn(self.corpus, self.batch_size, self.num_steps)\n", "def load_data_time_machine(batch_size, num_steps, #@save\n", " use_random_iter=False, max_tokens=10000):\n", " \"\"\"返回时光机器数据集的迭代器和词表\"\"\"\n", " data_iter = SeqDataLoader(\n", " batch_size, num_steps, use_random_iter, max_tokens)\n", " return data_iter, data_iter.vocab" ], "id": "f09fe2507a925fe9", "outputs": [], "execution_count": 106 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:40.327565785Z", "start_time": "2026-03-29T09:05:40.269141766Z" } }, "cell_type": "code", "source": [ "batch_size, num_steps = 32, 35\n", "train_iter, vocab = load_data_time_machine(batch_size, num_steps)" ], "id": "69272a664d3b9ae1", "outputs": [], "execution_count": 107 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:40.390353470Z", "start_time": "2026-03-29T09:05:40.335766683Z" } }, "cell_type": "code", "source": "F.one_hot(torch.tensor([0,2]),len(vocab))", "id": "35806d36e5ec3ca7", "outputs": [ { "data": { "text/plain": [ "tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0],\n", " [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0]])" ] }, "execution_count": 108, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 108 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:40.545044245Z", "start_time": "2026-03-29T09:05:40.446941053Z" } }, "cell_type": "code", "source": [ "X = torch.arange(10).reshape((2, 5))\n", "F.one_hot(X.T, 28).shape" ], "id": "6a4695284b898013", "outputs": [ { "data": { "text/plain": [ "torch.Size([5, 2, 28])" ] }, "execution_count": 109, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 109 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:40.636340487Z", "start_time": "2026-03-29T09:05:40.574500540Z" } }, "cell_type": "code", "source": [ "def get_params(vocab_size, num_hiddens, device):\n", " num_inputs = num_outputs = vocab_size\n", " def normal(shape):\n", " return torch.randn(size=shape, device=device) * 0.01\n", " # 隐藏层参数\n", " W_xh = normal((num_inputs, num_hiddens))\n", " W_hh = normal((num_hiddens, num_hiddens))\n", " b_h = torch.zeros(num_hiddens, device=device)\n", " # 输出层参数\n", " W_hq = normal((num_hiddens, num_outputs))\n", " b_q = torch.zeros(num_outputs, device=device)\n", " # 附加梯度\n", " params = [W_xh, W_hh, b_h, W_hq, b_q]\n", " '''\n", " W_xh x->H_t\n", " W_hh H_t-1->H_t\n", " W_hq H_t->opt\n", " '''\n", " for param in params:\n", " param.requires_grad_(True)\n", " return params\n", "def init_rnn_state(batch_size, num_hiddens, device):\n", " return (torch.zeros((batch_size, num_hiddens), device=device), )\n", "def rnn(inputs,state,params):\n", " W_xh,W_hh,b_h,W_hq,b_q = params\n", " H, = state\n", " outputs = []\n", " for X in inputs: # X (batchs,vocab)\n", " H = torch.tanh(torch.mm(X,W_xh)+torch.mm(H,W_hh)+b_h)\n", " Y = torch.mm(H,W_hq)+b_q\n", " outputs.append(Y)\n", " return torch.cat(outputs,dim=0),(H,)\n", "class RNNModelScratch: #@save\n", " \"\"\"从零开始实现的循环神经网络模型\"\"\"\n", " def __init__(self, vocab_size, num_hiddens, device,\n", " get_params, init_state, forward_fn):\n", " self.vocab_size, self.num_hiddens = vocab_size, num_hiddens\n", " self.params = get_params(vocab_size, num_hiddens, device)\n", " self.init_state, self.forward_fn = init_state, forward_fn\n", " def __call__(self, X, state):\n", " X = F.one_hot(X.T, self.vocab_size).type(torch.float32)\n", " return self.forward_fn(X, state, self.params)\n", " def begin_state(self, batch_size, device):\n", " return self.init_state(batch_size, self.num_hiddens, device)" ], "id": "405f7c6af8bdd939", "outputs": [], "execution_count": 110 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:40.714678739Z", "start_time": "2026-03-29T09:05:40.640615581Z" } }, "cell_type": "code", "source": [ "num_hiddens = 512\n", "net = RNNModelScratch(len(vocab), num_hiddens, d2l.try_gpu(), get_params,\n", "init_rnn_state, rnn)\n", "state = net.begin_state(X.shape[0], d2l.try_gpu())\n", "Y, new_state = net(X.to(d2l.try_gpu()), state)\n", "Y.shape, len(new_state), new_state[0].shape" ], "id": "7bdd0b37b1458ec2", "outputs": [ { "data": { "text/plain": [ "(torch.Size([10, 28]), 1, torch.Size([2, 512]))" ] }, "execution_count": 111, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 111 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:40.933485442Z", "start_time": "2026-03-29T09:05:40.715946914Z" } }, "cell_type": "code", "source": [ "def predict_ch8(prefix, num_preds, net, vocab, device): #@save\n", " \"\"\"在prefix后面生成新字符\"\"\"\n", " state = net.begin_state(batch_size=1, device=device)\n", " outputs = [vocab[prefix[0]]]\n", " get_input = lambda: torch.tensor([outputs[-1]], device=device).reshape((1, 1))\n", " for y in prefix[1:]: # 预热期\n", " _, state = net(get_input(), state)\n", " outputs.append(vocab[y])\n", " for _ in range(num_preds): # 预测num_preds步\n", " y, state = net(get_input(), state)\n", " outputs.append(int(y.argmax(dim=1).reshape(1)))\n", " return ''.join([vocab.idx_to_token[i] for i in outputs])\n", "predict_ch8('time traveller ', 10, net, vocab, d2l.try_gpu())" ], "id": "f38dbe2b11e02bc2", "outputs": [ { "data": { "text/plain": [ "'time traveller qxyumumumu'" ] }, "execution_count": 112, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 112 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:40.987574570Z", "start_time": "2026-03-29T09:05:40.936566018Z" } }, "cell_type": "code", "source": [ "def grad_clipping(net, theta): #@save\n", " \"\"\"裁剪梯度\"\"\"\n", " if isinstance(net, nn.Module):\n", " params = [p for p in net.parameters() if p.requires_grad]\n", " else:\n", " params = net.params\n", " norm = torch.sqrt(sum(torch.sum((p.grad ** 2)) for p in params))\n", " if norm > theta:\n", " for param in params:\n", " param.grad[:] *= theta / norm" ], "id": "6c19717736ffbc68", "outputs": [], "execution_count": 113 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:41.038078698Z", "start_time": "2026-03-29T09:05:40.989835785Z" } }, "cell_type": "code", "source": [ "import math\n", "def train_epoch_ch8(net, train_iter, loss, updater, device, use_random_iter):\n", " \"\"\"训练网络一个迭代周期(定义见第8章)\"\"\"\n", " state, timer = None, d2l.Timer()\n", " metric = d2l.Accumulator(2) # 训练损失之和,词元数量\n", " for X, Y in train_iter:\n", " if state is None or use_random_iter:\n", " # 在第一次迭代或使用随机抽样时初始化state\n", " state = net.begin_state(batch_size=X.shape[0], device=device)\n", " else:\n", " if isinstance(net, nn.Module) and not isinstance(state, tuple):\n", " # state对于nn.GRU是个张量\n", " state.detach_()\n", " else:\n", " # state对于nn.LSTM或对于我们从零开始实现的模型是个张量\n", " for s in state:\n", " s.detach_()\n", " y = Y.T.reshape(-1)\n", " X, y = X.to(device), y.to(device)\n", " y_hat, state = net(X, state)\n", " l = loss(y_hat, y.long()).mean()\n", " if isinstance(updater, torch.optim.Optimizer):\n", " updater.zero_grad()\n", " l.backward()\n", " grad_clipping(net, 1)\n", " updater.step()\n", " else:\n", " l.backward()\n", " grad_clipping(net, 1)\n", " # 因为已经调用了mean函数\n", " updater(batch_size=1)\n", " metric.add(l * y.numel(), y.numel())\n", " return math.exp(metric[0]/metric[1]),metric[1]/timer.stop()\n" ], "id": "37ef611dedcb714b", "outputs": [], "execution_count": 114 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:41.089234437Z", "start_time": "2026-03-29T09:05:41.040415821Z" } }, "cell_type": "code", "source": [ "def train_ch8(net, train_iter, vocab, lr, num_epochs, device,\n", "use_random_iter=False):\n", " \"\"\"训练模型(定义见第8章)\"\"\"\n", " loss = nn.CrossEntropyLoss()\n", " animator = d2l.Animator(xlabel='epoch', ylabel='perplexity',\n", " legend=['train'], xlim=[10, num_epochs])\n", " # 初始化\n", " if isinstance(net, nn.Module):\n", " updater = torch.optim.SGD(net.parameters(), lr)\n", " else:\n", " updater = lambda batch_size: d2l.sgd(net.params, lr, batch_size)\n", " predict = lambda prefix: predict_ch8(prefix, 50, net, vocab, device)\n", " # 训练和预测\n", " for epoch in range(num_epochs):\n", " ppl, speed = train_epoch_ch8(\n", " net, train_iter, loss, updater, device, use_random_iter)\n", " if (epoch + 1) % 10 == 0:\n", " print(predict('time traveller'))\n", " animator.add(epoch + 1, [ppl])\n", " print(f'困惑度 {ppl:.1f}, {speed:.1f} 词元/秒 {str(device)}')\n", " print(predict('time traveller'))\n", " print(predict('traveller'))" ], "id": "c96b60b55664378a", "outputs": [], "execution_count": 115 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:41.147914921Z", "start_time": "2026-03-29T09:05:41.099546332Z" } }, "cell_type": "code", "source": [ "num_epochs, lr = 500, 1\n", "#train_ch8(net, train_iter, vocab, lr, num_epochs, d2l.try_gpu())" ], "id": "ab4a2fbf4dfd21ef", "outputs": [], "execution_count": 116 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:41.196189157Z", "start_time": "2026-03-29T09:05:41.149680280Z" } }, "cell_type": "code", "source": [ "batch_size, num_steps = 32, 35\n", "train_iter, vocab = load_data_time_machine(batch_size, num_steps)" ], "id": "74d672745751714", "outputs": [], "execution_count": 117 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:41.583876478Z", "start_time": "2026-03-29T09:05:41.197696611Z" } }, "cell_type": "code", "source": [ "num_hiddens = 256\n", "rnn_layer = nn.RNN(len(vocab), num_hiddens)\n", "state = torch.zeros((1,batch_size,num_hiddens))\n", "X = torch.rand(size=(num_steps, batch_size, len(vocab)))\n", "Y, state_new = rnn_layer(X, state)\n", "Y.shape, state_new.shape" ], "id": "9694c029c9e657e8", "outputs": [ { "data": { "text/plain": [ "(torch.Size([35, 32, 256]), torch.Size([1, 32, 256]))" ] }, "execution_count": 118, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 118 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:41.638041580Z", "start_time": "2026-03-29T09:05:41.586864085Z" } }, "cell_type": "code", "source": [ "class RNNModel(nn.Module):\n", " \"\"\"循环神经网络模型\"\"\"\n", " def __init__(self, rnn_layer, vocab_size, **kwargs):\n", " super(RNNModel, self).__init__(**kwargs)\n", " self.rnn = rnn_layer\n", " self.vocab_size = vocab_size\n", " self.num_hiddens = self.rnn.hidden_size\n", " # 如果RNN是双向的(之后将介绍),num_directions应该是2,否则应该是1\n", " if not self.rnn.bidirectional:\n", " self.num_directions = 1\n", " self.linear = nn.Linear(self.num_hiddens, self.vocab_size)\n", " else:\n", " self.num_directions = 2\n", " self.linear = nn.Linear(self.num_hiddens * 2, self.vocab_size)\n", " def forward(self, inputs, state):\n", " X = F.one_hot(inputs.T.long(), self.vocab_size)\n", " X = X.to(torch.float32)\n", " Y, state = self.rnn(X, state)\n", " # 全连接层首先将Y的形状改为(时间步数*批量大小,隐藏单元数)\n", " # 它的输出形状是(时间步数*批量大小,词表大小)。\n", " output = self.linear(Y.reshape((-1, Y.shape[-1])))\n", " return output, state\n", " def begin_state(self, device, batch_size=1):\n", " if not isinstance(self.rnn, nn.LSTM):\n", " # nn.GRU以张量作为隐状态\n", " return torch.zeros((self.num_directions * self.rnn.num_layers,\n", " batch_size, self.num_hiddens),\n", " device=device)\n", " else:\n", " # nn.LSTM以元组作为隐状态\n", " return (torch.zeros((\n", " self.num_directions * self.rnn.num_layers,\n", " batch_size, self.num_hiddens), device=device),\n", " torch.zeros((\n", " self.num_directions * self.rnn.num_layers,\n", " batch_size, self.num_hiddens), device=device))" ], "id": "858873034be01538", "outputs": [], "execution_count": 119 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:41.686881805Z", "start_time": "2026-03-29T09:05:41.640393942Z" } }, "cell_type": "code", "source": [ "device = d2l.try_gpu()\n", "net = RNNModel(rnn_layer, vocab_size=len(vocab))\n", "net = net.to(device)\n", "#predict_ch8('time traveller', 10, net, vocab, device)" ], "id": "d59c1599998c8fd4", "outputs": [], "execution_count": 120 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:41.737181902Z", "start_time": "2026-03-29T09:05:41.688775143Z" } }, "cell_type": "code", "source": [ "vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()\n", "num_inputs = vocab_size\n", "gru_layer = nn.GRU(num_inputs, num_hiddens)\n", "model = RNNModel(gru_layer, len(vocab))\n", "model = model.to(device)\n", "#train_ch8(model, train_iter, vocab, lr, num_epochs, device)" ], "id": "adda23bc3664ec6b", "outputs": [], "execution_count": 121 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:41.786646674Z", "start_time": "2026-03-29T09:05:41.738416914Z" } }, "cell_type": "code", "source": [ "num_inputs = vocab_size\n", "lstm_layer = nn.LSTM(num_inputs, num_hiddens)\n", "model = RNNModel(lstm_layer, len(vocab))\n", "model = model.to(device)\n", "#train_ch8(model, train_iter, vocab, lr, num_epochs, device)" ], "id": "b4e30d643d6f755d", "outputs": [], "execution_count": 122 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:41.841175343Z", "start_time": "2026-03-29T09:05:41.788640688Z" } }, "cell_type": "code", "source": [ "d2l.DATA_HUB['fra-eng'] = (d2l.DATA_URL + 'fra-eng.zip',\n", " '94646ad1522d915e7b0f9296181140edcf86a4f5')" ], "id": "50554e839be36011", "outputs": [], "execution_count": 123 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:41.893584442Z", "start_time": "2026-03-29T09:05:41.843443764Z" } }, "cell_type": "code", "source": [ "import os\n", "def read_data_nmt():\n", " \"\"\"载入“英语-法语”数据集\"\"\"\n", " data_dir = d2l.download_extract('fra-eng')\n", " with open(os.path.join(data_dir, 'fra.txt'), 'r',\n", " encoding='utf-8') as f:\n", " return f.read()" ], "id": "9cd4287ed84db220", "outputs": [], "execution_count": 124 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:41.981913006Z", "start_time": "2026-03-29T09:05:41.895418207Z" } }, "cell_type": "code", "source": [ "raw_text = read_data_nmt()\n", "print(raw_text[:75])" ], "id": "7c4452b3b6a32f91", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Go.\tVa !\n", "Hi.\tSalut !\n", "Run!\tCours !\n", "Run!\tCourez !\n", "Who?\tQui ?\n", "Wow!\tÇa alors !\n", "\n" ] } ], "execution_count": 125 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:43.184417099Z", "start_time": "2026-03-29T09:05:41.984779926Z" } }, "cell_type": "code", "source": [ "def preprocess_nmt(text):\n", " def no_space(char,prev_char):\n", " return char in set(',.!?') and prev_char != ' '\n", " text = text.replace('\\u202f',' ').replace('\\xa0',' ').lower()\n", " out = [' ' + char if i >0 and no_space(char,text[i-1]) else char for i,char in enumerate(text)]\n", " return ''.join(out)\n", "text = preprocess_nmt(raw_text)\n", "print(text[:80])" ], "id": "1c729da265572287", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "go .\tva !\n", "hi .\tsalut !\n", "run !\tcours !\n", "run !\tcourez !\n", "who ?\tqui ?\n", "wow !\tça alors !\n" ] } ], "execution_count": 126 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:43.281720452Z", "start_time": "2026-03-29T09:05:43.232801710Z" } }, "cell_type": "code", "source": [ "def tokenize_nmt(text,num_examples=None):\n", " source,target = [],[]\n", " for i,line in enumerate(text.split('\\n')):\n", " if num_examples and i > num_examples:\n", " break\n", " parts = line.split('\\t')\n", " if len(parts) == 2:\n", " source.append(parts[0].split(' '))\n", " target.append(parts[1].split(' '))\n", " return source,target\n", "\n" ], "id": "ca16ef22cbe2c02a", "outputs": [], "execution_count": 127 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:43.925371949Z", "start_time": "2026-03-29T09:05:43.283929461Z" } }, "cell_type": "code", "source": [ "source, target = tokenize_nmt(text)\n", "source[:6], target[:6]" ], "id": "5ece5cb4b78168d0", "outputs": [ { "data": { "text/plain": [ "([['go', '.'],\n", " ['hi', '.'],\n", " ['run', '!'],\n", " ['run', '!'],\n", " ['who', '?'],\n", " ['wow', '!']],\n", " [['va', '!'],\n", " ['salut', '!'],\n", " ['cours', '!'],\n", " ['courez', '!'],\n", " ['qui', '?'],\n", " ['ça', 'alors', '!']])" ] }, "execution_count": 128, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 128 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:44.088961043Z", "start_time": "2026-03-29T09:05:43.974341311Z" } }, "cell_type": "code", "source": [ "def show_list_len_pair_hist(legend,xlabel,ylabel,xlist,ylist):\n", " d2l.set_figsize()\n", " _,_,patches = d2l.plt.hist([[len(l) for l in xlist],[len(l) for l in ylist]])\n", " d2l.plt.xlabel(xlabel)\n", " d2l.plt.ylabel(ylabel)\n", " for patch in patches[1].patches:\n", " patch.set_hatch('/')\n", " d2l.plt.legend(legend)\n", "\n", "show_list_len_pair_hist(['source','target'],'# tokens per sequence','count',source,target)" ], "id": "518249f852ec54c4", "outputs": [ { "data": { "text/plain": [ "
" ], "image/svg+xml": "\n\n\n \n \n \n \n 2026-03-29T17:05:44.064312\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" }, "metadata": {}, "output_type": "display_data", "jetTransient": { "display_id": null } } ], "execution_count": 129 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:44.211764847Z", "start_time": "2026-03-29T09:05:44.108134314Z" } }, "cell_type": "code", "source": [ "src_vocab=Vocab(source,min_freq=2,reserved_tokens=['','',''])\n", "len(src_vocab)" ], "id": "c2dc82617d5a41a4", "outputs": [ { "data": { "text/plain": [ "10012" ] }, "execution_count": 130, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 130 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:44.266693213Z", "start_time": "2026-03-29T09:05:44.214356393Z" } }, "cell_type": "code", "source": [ "def truncate_pad(line,num_steps,padding_token):\n", " if len(line) > num_steps:\n", " return line[:num_steps]\n", " return line + [padding_token] * (num_steps - len(line))\n", "truncate_pad(src_vocab[source[0]], 10, src_vocab[''])" ], "id": "93ae326a3258ecc", "outputs": [ { "data": { "text/plain": [ "[47, 4, 1, 1, 1, 1, 1, 1, 1, 1]" ] }, "execution_count": 131, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 131 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:44.326945861Z", "start_time": "2026-03-29T09:05:44.279624651Z" } }, "cell_type": "code", "source": [ "def build_array_nmt(lines, vocab, num_steps):\n", " \"\"\"将机器翻译的文本序列转换成小批量\"\"\"\n", " lines = [vocab[l] for l in lines]\n", " lines = [l + [vocab['']] for l in lines]\n", " array = torch.tensor([truncate_pad(\n", " l, num_steps, vocab['']) for l in lines])\n", " valid_len = (array != vocab['']).type(torch.int32).sum(1)\n", " return array, valid_len" ], "id": "acd4344e678cf487", "outputs": [], "execution_count": 132 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:44.376299140Z", "start_time": "2026-03-29T09:05:44.328437032Z" } }, "cell_type": "code", "source": [ "def load_data_nmt(batch_size, num_steps, num_examples=600):\n", " \"\"\"返回翻译数据集的迭代器和词表\"\"\"\n", " text = preprocess_nmt(read_data_nmt())\n", " source, target = tokenize_nmt(text, num_examples)\n", " src_vocab = d2l.Vocab(source, min_freq=2,\n", " reserved_tokens=['', '', ''])\n", " tgt_vocab = d2l.Vocab(target, min_freq=2,\n", " reserved_tokens=['', '', ''])\n", " src_array, src_valid_len = build_array_nmt(source, src_vocab, num_steps)\n", " tgt_array, tgt_valid_len = build_array_nmt(target, tgt_vocab, num_steps)\n", " data_arrays = (src_array, src_valid_len, tgt_array, tgt_valid_len)\n", " data_iter = d2l.load_array(data_arrays, batch_size)\n", " return data_iter, src_vocab, tgt_vocab" ], "id": "62586b0175993a4f", "outputs": [], "execution_count": 133 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T09:05:45.610515926Z", "start_time": "2026-03-29T09:05:44.378149052Z" } }, "cell_type": "code", "source": [ "train_iter, src_vocab, tgt_vocab = load_data_nmt(batch_size=2, num_steps=8)\n", "for X, X_valid_len, Y, Y_valid_len in train_iter:\n", " print('X:', X.type(torch.int32))\n", " print('X的有效长度:', X_valid_len)\n", " print('Y:', Y.type(torch.int32))\n", " print('Y的有效长度:', Y_valid_len)\n", " break" ], "id": "87a2f147db41a91d", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "X: tensor([[81, 6, 2, 4, 5, 5, 5, 5],\n", " [81, 11, 96, 2, 4, 5, 5, 5]], dtype=torch.int32)\n", "X的有效长度: tensor([4, 5])\n", "Y: tensor([[103, 79, 166, 55, 105, 6, 2, 4],\n", " [100, 171, 75, 2, 4, 5, 5, 5]], dtype=torch.int32)\n", "Y的有效长度: tensor([8, 5])\n" ] } ], "execution_count": 134 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T12:26:28.968567662Z", "start_time": "2026-03-29T12:26:28.880277260Z" } }, "cell_type": "code", "source": [ "class Encoder(nn.Module):\n", " def __init__(self,**kargs):\n", " super(Encoder,self).__init__(**kargs)\n", " def forward(self,X,*args):\n", " raise NotImplementedError(\"必须实现这个方法\")\n", "class Decoder(nn.Module):\n", " \"\"\"编码器-解码器架构的基本解码器接口\"\"\"\n", " def __init__(self, **kwargs):\n", " super(Decoder, self).__init__(**kwargs)\n", " def init_state(self, enc_outputs, *args):\n", " raise NotImplementedError\n", " def forward(self, X, state):\n", " raise NotImplementedError\n", "class EncoderDecoder(nn.Module):\n", " \"\"\"编码器-解码器架构的基类\"\"\"\n", " def __init__(self, encoder, decoder, **kwargs):\n", " super(EncoderDecoder, self).__init__(**kwargs)\n", " self.encoder = encoder\n", " self.decoder = decoder\n", " def forward(self, enc_X, dec_X, *args):\n", " enc_outputs = self.encoder(enc_X, *args)\n", " dec_state = self.decoder.init_state(enc_outputs, *args)\n", " return self.decoder(dec_X, dec_state)\n", "class Seq2SeqEncoder(Encoder):\n", " def __init__(self,vocab_size,embed_size,num_hiddens,num_layers,dropout=0,**kwargs):\n", " super(Seq2SeqEncoder, self).__init__(**kwargs)\n", " self.embedding = nn.Embedding(vocab_size, embed_size)\n", " self.rnn = nn.GRU(embed_size,num_hiddens,num_layers,dropout=dropout)\n", " def forward(self,X,*args):\n", " X = self.embedding(X)\n", " X = X.permute(1,0,2) #(batch,steps,embed_size) -> (steps,batch,embed_size)\n", " output,state = self.rnn(X)\n", " return output,state\n", " # shape of output (steps,batch,num_hiddens)\n", " # shape of state (num_layers,batch,num_hiddens)\n", "encoder = Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16,\n", " num_layers=2)\n", "X = torch.zeros((4,7),dtype=torch.long) #one-hot is integer\n", "output,state = encoder(X)\n", "output.shape\n" ], "id": "d0d01aef4857ee9c", "outputs": [ { "data": { "text/plain": [ "torch.Size([7, 4, 16])" ] }, "execution_count": 139, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 139 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T12:26:43.259123495Z", "start_time": "2026-03-29T12:26:43.204225780Z" } }, "cell_type": "code", "source": "state.shape", "id": "bba15a040c10cb01", "outputs": [ { "data": { "text/plain": [ "torch.Size([2, 4, 16])" ] }, "execution_count": 140, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 140 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T13:47:51.954052296Z", "start_time": "2026-03-29T13:47:51.853781706Z" } }, "cell_type": "code", "source": [ "class Seq2SeqDecoder(Decoder):\n", " def __init__(self,vocab_size,embed_size,num_hiddens,num_layers,dropout=0,**kwargs):\n", " super(Seq2SeqDecoder, self).__init__(**kwargs)\n", " self.embedding = nn.Embedding(vocab_size, embed_size)\n", " self.rnn = nn.GRU(embed_size+num_hiddens,num_hiddens,num_layers,dropout=dropout)\n", " self.dense = nn.Linear(num_hiddens,vocab_size)\n", " def init_state(self,enc_outputs,*args):\n", " return enc_outputs[1]\n", " def forward(self,X,state):\n", " X = self.embedding(X).permute(1,0,2)\n", " context = state[-1].repeat(X.shape[0],1,1)\n", " X_and_context = torch.cat((X,context),2)\n", " output,state = self.rnn(X_and_context,state)\n", " output = self.dense(output).permute(1,0,2)\n", " return output,state" ], "id": "b659bfd2fdcabebe", "outputs": [], "execution_count": 141 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T13:48:04.405972579Z", "start_time": "2026-03-29T13:48:04.307991931Z" } }, "cell_type": "code", "source": [ "decoder = Seq2SeqDecoder(vocab_size=10, embed_size=8, num_hiddens=16,\n", "num_layers=2)\n", "decoder.eval()\n", "state = decoder.init_state(encoder(X))\n", "output, state = decoder(X, state)\n", "output.shape, state.shape" ], "id": "e9c451e560ce3769", "outputs": [ { "data": { "text/plain": [ "(torch.Size([4, 7, 10]), torch.Size([2, 4, 16]))" ] }, "execution_count": 142, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 142 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T13:56:58.324121647Z", "start_time": "2026-03-29T13:56:58.160904863Z" } }, "cell_type": "code", "source": [ "def sequence_mask(X, valid_len, value=0):\n", " \"\"\"在序列中屏蔽不相关的项\"\"\"\n", " maxlen = X.size(1)\n", " mask = torch.arange((maxlen), dtype=torch.float32,\n", " device=X.device)[None, :] < valid_len[:, None]\n", " X[~mask] = value\n", " return X\n", "X = torch.tensor([[1, 2, 3], [4, 5, 6]])\n", "sequence_mask(X, torch.tensor([1, 2]))" ], "id": "9ee2db877c089d48", "outputs": [ { "data": { "text/plain": [ "tensor([[1, 0, 0],\n", " [4, 5, 0]])" ] }, "execution_count": 143, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 143 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T13:57:55.473311100Z", "start_time": "2026-03-29T13:57:55.287150264Z" } }, "cell_type": "code", "source": [ "X = torch.ones(2, 3, 4)\n", "sequence_mask(X, torch.tensor([1, 2]), value=-1)" ], "id": "caca450bdfe7f650", "outputs": [ { "data": { "text/plain": [ "tensor([[[ 1., 1., 1., 1.],\n", " [-1., -1., -1., -1.],\n", " [-1., -1., -1., -1.]],\n", "\n", " [[ 1., 1., 1., 1.],\n", " [ 1., 1., 1., 1.],\n", " [-1., -1., -1., -1.]]])" ] }, "execution_count": 144, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 144 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-29T14:08:10.919373362Z", "start_time": "2026-03-29T14:08:10.636085289Z" } }, "cell_type": "code", "source": [ "class MaskedSoftmaxCELoss(nn.CrossEntropyLoss):\n", " \"\"\"带遮蔽的softmax交叉熵损失函数\"\"\"\n", " # pred的形状:(batch_size,num_steps,vocab_size)\n", " # label的形状:(batch_size,num_steps)\n", " # valid_len的形状:(batch_size,)\n", " def forward(self, pred, label, valid_len):\n", " weights = torch.ones_like(label)\n", " weights = sequence_mask(weights, valid_len)\n", " self.reduction='none'\n", " unweighted_loss = super(MaskedSoftmaxCELoss, self).forward(\n", " pred.permute(0, 2, 1), label)\n", " weighted_loss = (unweighted_loss * weights).mean(dim=1)\n", " return weighted_loss\n", "loss = MaskedSoftmaxCELoss()\n", "loss(torch.ones(3, 4, 10), torch.ones((3, 4), dtype=torch.long),\n", "torch.tensor([4, 2, 0]))" ], "id": "46fc96f0246f32b7", "outputs": [ { "data": { "text/plain": [ "tensor([2.3026, 1.1513, 0.0000])" ] }, "execution_count": 145, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 145 }, { "metadata": {}, "cell_type": "code", "outputs": [], "execution_count": null, "source": "", "id": "69c315b5875fc288" } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.6" } }, "nbformat": 4, "nbformat_minor": 5 }