nn/chapter5.ipynb
yukun-hh 350ecc39e9 chapter5 to 6 over
change environment manager and package manager from virtualenv to miniconda
2026-03-22 16:28:55 +08:00

1313 lines
31 KiB
Text
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "code",
"id": "initial_id",
"metadata": {
"collapsed": true,
"ExecuteTime": {
"end_time": "2026-03-22T08:08:29.467878519Z",
"start_time": "2026-03-22T08:08:28.169276694Z"
}
},
"source": [
"import d2l\n",
"import torch\n",
"import d2l\n",
"import numpy\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F"
],
"outputs": [],
"execution_count": 1
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-22T08:08:29.562521238Z",
"start_time": "2026-03-22T08:08:29.470022979Z"
}
},
"cell_type": "code",
"source": [
"net = nn.Sequential(nn.Linear(20, 256), nn.ReLU(), nn.Linear(256, 10))\n",
"X = torch.rand(2, 20)\n",
"net(X)"
],
"id": "dcd5590e7795eec1",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 0.0362, 0.0737, -0.0211, 0.0666, -0.1115, 0.0158, -0.1162, 0.0884,\n",
" 0.1486, -0.1063],\n",
" [ 0.1796, -0.0009, 0.1236, -0.0783, -0.0937, -0.0560, 0.0441, 0.0812,\n",
" 0.2236, -0.0597]], grad_fn=<AddmmBackward0>)"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 2
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-22T08:08:29.719824050Z",
"start_time": "2026-03-22T08:08:29.644546772Z"
}
},
"cell_type": "code",
"source": [
"class MLP(nn.Module):\n",
"    \"\"\"Two-layer perceptron: 20 inputs -> 256 hidden units (ReLU) -> 10 outputs.\"\"\"\n",
"\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.hidden = nn.Linear(20, 256)  # hidden layer\n",
"        self.out = nn.Linear(256, 10)  # output layer\n",
"\n",
"    def forward(self, X):\n",
"        \"\"\"Apply the hidden layer, a ReLU, then the output layer to ``X``.\"\"\"\n",
"        hidden_activations = F.relu(self.hidden(X))\n",
"        return self.out(hidden_activations)\n"
],
"id": "4ae330604b643cb4",
"outputs": [],
"execution_count": 3
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-22T08:08:29.957480273Z",
"start_time": "2026-03-22T08:08:29.789396471Z"
}
},
"cell_type": "code",
"source": [
"net=MLP()\n",
"net(X)"
],
"id": "cca55c6c0c7da12f",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 0.0376, -0.2522, -0.0243, -0.0838, 0.1215, 0.0258, -0.2358, 0.0799,\n",
" 0.0756, 0.0520],\n",
" [ 0.0098, -0.2070, 0.0638, 0.1173, 0.0275, 0.0116, -0.0448, -0.0448,\n",
" -0.0309, -0.0976]], grad_fn=<AddmmBackward0>)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 4
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-22T08:08:30.177898340Z",
"start_time": "2026-03-22T08:08:30.069505281Z"
}
},
"cell_type": "code",
"source": [
"class FixedHiddenMLP(nn.Module):\n",
"    \"\"\"MLP with a fixed random projection, a shared linear layer and data-dependent control flow.\n",
"\n",
"    ``forward`` takes a 2-D tensor with 20 columns and returns a scalar tensor.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        # Random weights that are never trained (no gradient is computed for them).\n",
"        # Registering them as a buffer instead of a plain attribute makes them follow the\n",
"        # module through .to()/.double()/.cuda() and be stored in state_dict().\n",
"        self.register_buffer('rand_weight', torch.rand((20, 20), requires_grad=False))\n",
"        self.linear = nn.Linear(20, 20)\n",
"\n",
"    def forward(self, X):\n",
"        X = self.linear(X)\n",
"        # Use the constant weights together with the relu and mm functions.\n",
"        X = F.relu(torch.mm(X, self.rand_weight) + 1)\n",
"        # Reuse the fully connected layer: equivalent to two layers sharing parameters.\n",
"        X = self.linear(X)\n",
"        # Control flow: keep halving X until the sum of absolute values is at most 1.\n",
"        while X.abs().sum() > 1:\n",
"            X /= 2\n",
"        return X.sum()"
],
"id": "4518d62611d5e749",
"outputs": [],
"execution_count": 5
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-22T08:08:30.414648562Z",
"start_time": "2026-03-22T08:08:30.251182946Z"
}
},
"cell_type": "code",
"source": [
"net = FixedHiddenMLP()\n",
"net(X)"
],
"id": "fae0187ece4ed5c6",
"outputs": [
{
"data": {
"text/plain": [
"tensor(0.1704, grad_fn=<SumBackward0>)"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 6
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-22T08:08:30.514145240Z",
"start_time": "2026-03-22T08:08:30.426612891Z"
}
},
"cell_type": "code",
"source": [
"class NestMLP(nn.Module):\n",
"    \"\"\"Block that nests an nn.Sequential (20 -> 64 -> 32, ReLU after each) before a 32 -> 16 linear layer.\"\"\"\n",
"\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.net = nn.Sequential(nn.Linear(20, 64), nn.ReLU(),\n",
"                                 nn.Linear(64, 32), nn.ReLU())\n",
"        self.linear = nn.Linear(32, 16)\n",
"\n",
"    def forward(self, X):\n",
"        return self.linear(self.net(X))\n",
"\n",
"\n",
"# Module-level statements (not part of the class): chain the custom blocks with a\n",
"# built-in layer, NestMLP (20 -> 16), Linear (16 -> 20), FixedHiddenMLP (20 -> scalar),\n",
"# and evaluate the composite model on X.\n",
"chimera = nn.Sequential(NestMLP(), nn.Linear(16, 20), FixedHiddenMLP())\n",
"chimera(X)"
],
"id": "407ef13a86453aae",
"outputs": [],
"execution_count": 7
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-22T08:08:30.617747011Z",
"start_time": "2026-03-22T08:08:30.517586238Z"
}
},
"cell_type": "code",
"source": [
"net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))\n",
"X = torch.rand(size=(2, 4))\n",
"net(X)"
],
"id": "9f3526f263c7a249",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[-0.2445],\n",
" [-0.2901]], grad_fn=<AddmmBackward0>)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 8
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-22T08:08:30.825671299Z",
"start_time": "2026-03-22T08:08:30.691388202Z"
}
},
"cell_type": "code",
"source": "print(net[2].state_dict())",
"id": "8c73f8daa02ba28b",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"OrderedDict([('weight', tensor([[-0.2116, 0.3448, 0.0726, -0.0626, -0.2922, 0.3172, 0.3025, -0.3025]])), ('bias', tensor([-0.3315]))])\n"
]
}
],
"execution_count": 9
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-22T08:08:31.020445691Z",
"start_time": "2026-03-22T08:08:30.902305394Z"
}
},
"cell_type": "code",
"source": "net[2].state_dict()",
"id": "b6fee6b64fb96e3c",
"outputs": [
{
"data": {
"text/plain": [
"OrderedDict([('weight',\n",
" tensor([[-0.2116, 0.3448, 0.0726, -0.0626, -0.2922, 0.3172, 0.3025, -0.3025]])),\n",
" ('bias', tensor([-0.3315]))])"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 10
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-22T08:08:31.127350738Z",
"start_time": "2026-03-22T08:08:31.055037575Z"
}
},
"cell_type": "code",
"source": "print(type(net[2].bias))",
"id": "b38e8dc384e038c5",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'torch.nn.parameter.Parameter'>\n"
]
}
],
"execution_count": 11
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-22T08:08:31.206008645Z",
"start_time": "2026-03-22T08:08:31.139239226Z"
}
},
"cell_type": "code",
"source": [
"print(net[2].bias)\n",
"print(net[2].bias.data)\n"
],
"id": "73f12ca3669d9ede",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Parameter containing:\n",
"tensor([-0.3315], requires_grad=True)\n",
"tensor([-0.3315])\n"
]
}
],
"execution_count": 12
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-22T08:08:31.286245677Z",
"start_time": "2026-03-22T08:08:31.230728228Z"
}
},
"cell_type": "code",
"source": "net[2].weight.grad==None",
"id": "db0fe33018c16fac",
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 13
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-22T08:08:31.385905750Z",
"start_time": "2026-03-22T08:08:31.307506282Z"
}
},
"cell_type": "code",
"source": [
"print(*[(name, param.shape) for name, param in net[0].named_parameters()])\n",
"print(*[(name, param.shape) for name, param in net.named_parameters()])"
],
"id": "75847a1c608ee5c7",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"('weight', torch.Size([8, 4])) ('bias', torch.Size([8]))\n",
"('0.weight', torch.Size([8, 4])) ('0.bias', torch.Size([8])) ('2.weight', torch.Size([1, 8])) ('2.bias', torch.Size([1]))\n"
]
}
],
"execution_count": 14
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-22T08:08:31.470489659Z",
"start_time": "2026-03-22T08:08:31.391689979Z"
}
},
"cell_type": "code",
"source": "net.state_dict()['2.bias'].data",
"id": "cc74913e8742da7d",
"outputs": [
{
"data": {
"text/plain": [
"tensor([-0.3315])"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 15
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-22T08:08:31.522581812Z",
"start_time": "2026-03-22T08:08:31.482437170Z"
}
},
"cell_type": "code",
"source": [
"def block1():\n",
"    \"\"\"Build one 4 -> 8 -> 4 MLP block with a ReLU after each linear layer.\"\"\"\n",
"    layers = [nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 4), nn.ReLU()]\n",
"    return nn.Sequential(*layers)\n",
"\n",
"\n",
"def block2():\n",
"    \"\"\"Stack four ``block1`` instances in a Sequential under the names block0..block3.\"\"\"\n",
"    stacked = nn.Sequential()\n",
"    for index in range(4):\n",
"        # Each sub-block is registered under an explicit name rather than a numeric index.\n",
"        stacked.add_module(f'block{index}', block1())\n",
"    return stacked"
],
"id": "53c39c5e61fa7bf5",
"outputs": [],
"execution_count": 16
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-22T08:08:31.643076590Z",
"start_time": "2026-03-22T08:08:31.558403449Z"
}
},
"cell_type": "code",
"source": [
"rgnet = nn.Sequential(block2(),nn.Linear(4,1))\n",
"rgnet(X)"
],
"id": "d3ac7759b619aca",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[-0.3640],\n",
" [-0.3640]], grad_fn=<AddmmBackward0>)"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 17
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-22T08:08:31.860114658Z",
"start_time": "2026-03-22T08:08:31.722546330Z"
}
},
"cell_type": "code",
"source": "print(rgnet)",
"id": "8fc60f64b07781e6",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sequential(\n",
" (0): Sequential(\n",
" (block0): Sequential(\n",
" (0): Linear(in_features=4, out_features=8, bias=True)\n",
" (1): ReLU()\n",
" (2): Linear(in_features=8, out_features=4, bias=True)\n",
" (3): ReLU()\n",
" )\n",
" (block1): Sequential(\n",
" (0): Linear(in_features=4, out_features=8, bias=True)\n",
" (1): ReLU()\n",
" (2): Linear(in_features=8, out_features=4, bias=True)\n",
" (3): ReLU()\n",
" )\n",
" (block2): Sequential(\n",
" (0): Linear(in_features=4, out_features=8, bias=True)\n",
" (1): ReLU()\n",
" (2): Linear(in_features=8, out_features=4, bias=True)\n",
" (3): ReLU()\n",
" )\n",
" (block3): Sequential(\n",
" (0): Linear(in_features=4, out_features=8, bias=True)\n",
" (1): ReLU()\n",
" (2): Linear(in_features=8, out_features=4, bias=True)\n",
" (3): ReLU()\n",
" )\n",
" )\n",
" (1): Linear(in_features=4, out_features=1, bias=True)\n",
")\n"
]
}
],
"execution_count": 18
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-22T08:08:32.103301882Z",
"start_time": "2026-03-22T08:08:31.980555778Z"
}
},
"cell_type": "code",
"source": "rgnet[0][1][0].bias.data",
"id": "e590aaafca787b50",
"outputs": [
{
"data": {
"text/plain": [
"tensor([ 0.3672, -0.3124, -0.3113, -0.3251, -0.4771, -0.3622, 0.1464, -0.4632])"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 19
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-22T08:08:32.231455211Z",
"start_time": "2026-03-22T08:08:32.137730392Z"
}
},
"cell_type": "code",
"source": [
"def init_normal(m):\n",
"    \"\"\"Initialise a linear layer in place: weights drawn from N(0, 0.01^2), bias set to 0.\n",
"\n",
"    Meant to be passed to ``Module.apply``; modules that are not ``nn.Linear`` are left untouched.\n",
"    \"\"\"\n",
"    # isinstance (rather than an exact type comparison) also covers nn.Linear subclasses.\n",
"    if isinstance(m, nn.Linear):\n",
"        nn.init.normal_(m.weight, mean=0, std=0.01)\n",
"        # Layers created with bias=False have m.bias set to None.\n",
"        if m.bias is not None:\n",
"            nn.init.zeros_(m.bias)\n",
"net.apply(init_normal)\n",
"net[0].weight.data[0], net[0].bias.data[0]"
],
"id": "925ca33221d0a87e",
"outputs": [
{
"data": {
"text/plain": [
"(tensor([-0.0004, 0.0166, -0.0085, -0.0099]), tensor(0.))"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 20
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-22T08:08:32.322842445Z",
"start_time": "2026-03-22T08:08:32.234982576Z"
}
},
"cell_type": "code",
"source": [
"def init_xavier(m):\n",
"    \"\"\"Xavier-uniform initialise the weight of a linear layer in place; other modules are skipped.\"\"\"\n",
"    # isinstance (rather than an exact type comparison) also covers nn.Linear subclasses.\n",
"    if isinstance(m, nn.Linear):\n",
"        nn.init.xavier_uniform_(m.weight)\n",
"\n",
"\n",
"def init_42(m):\n",
"    \"\"\"Fill the weight of a linear layer with the constant 42 in place; other modules are skipped.\"\"\"\n",
"    if isinstance(m, nn.Linear):\n",
"        nn.init.constant_(m.weight, 42)\n",
"\n",
"net[0].apply(init_xavier)\n",
"net[2].apply(init_42)\n",
"print(net[0].weight.data[0])\n",
"print(net[2].weight.data)"
],
"id": "81e2de84a8c4ef32",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([-0.3265, -0.5057, -0.5062, -0.2116])\n",
"tensor([[42., 42., 42., 42., 42., 42., 42., 42.]])\n"
]
}
],
"execution_count": 21
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-22T08:08:32.377993392Z",
"start_time": "2026-03-22T08:08:32.324885649Z"
}
},
"cell_type": "code",
"source": [
"x = torch.arange(4)\n",
"torch.save(x, 'x-file')"
],
"id": "f05bb378bb60ab9e",
"outputs": [],
"execution_count": 22
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-22T08:08:32.474581135Z",
"start_time": "2026-03-22T08:08:32.386815096Z"
}
},
"cell_type": "code",
"source": [
"x2 = torch.load('x-file')\n",
"x2"
],
"id": "a74ecaaac0d826c6",
"outputs": [
{
"data": {
"text/plain": [
"tensor([0, 1, 2, 3])"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 23
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-22T08:08:32.492632663Z",
"start_time": "2026-03-22T08:08:32.476644136Z"
}
},
"cell_type": "code",
"source": [
"class MLP(nn.Module):\n",
"    \"\"\"Two-layer perceptron (20 -> 256 -> 10) whose parameters are saved and reloaded below.\"\"\"\n",
"\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.hidden = nn.Linear(20, 256)\n",
"        self.output = nn.Linear(256, 10)\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Apply the hidden layer, a ReLU, then the output layer to ``x``.\"\"\"\n",
"        hidden_activations = F.relu(self.hidden(x))\n",
"        return self.output(hidden_activations)\n",
"\n",
"net = MLP()\n",
"X = torch.randn(size=(2, 20))\n",
"Y = net(X)"
],
"id": "b42598f0c4a8e801",
"outputs": [],
"execution_count": 24
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-22T08:08:32.600312816Z",
"start_time": "2026-03-22T08:08:32.528496997Z"
}
},
"cell_type": "code",
"source": "torch.save(net.state_dict(), 'mlp.params')",
"id": "aaa22eef549caa6f",
"outputs": [],
"execution_count": 25
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-22T08:08:32.697771732Z",
"start_time": "2026-03-22T08:08:32.616231856Z"
}
},
"cell_type": "code",
"source": [
"clone = MLP()\n",
"clone.load_state_dict(torch.load('mlp.params'))\n",
"clone.eval()"
],
"id": "b92f920229abeeae",
"outputs": [
{
"data": {
"text/plain": [
"MLP(\n",
" (hidden): Linear(in_features=20, out_features=256, bias=True)\n",
" (output): Linear(in_features=256, out_features=10, bias=True)\n",
")"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 26
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-22T08:08:32.782119612Z",
"start_time": "2026-03-22T08:08:32.715175504Z"
}
},
"cell_type": "code",
"source": [
"Y_clone = clone(X)\n",
"Y_clone == Y"
],
"id": "646c9eb6d7cc81c2",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[True, True, True, True, True, True, True, True, True, True],\n",
" [True, True, True, True, True, True, True, True, True, True]])"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 27
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-22T08:08:32.908809253Z",
"start_time": "2026-03-22T08:08:32.843784246Z"
}
},
"cell_type": "code",
"source": [
"def corr2d(X, K):\n",
"    \"\"\"Compute the 2-D cross-correlation of matrix ``X`` with kernel ``K``.\n",
"\n",
"    Args:\n",
"        X: 2-D input tensor of shape (H, W).\n",
"        K: 2-D kernel tensor of shape (h, w).\n",
"\n",
"    Returns:\n",
"        Tensor of shape (H - h + 1, W - w + 1) whose (i, j) entry is\n",
"        ``(X[i:i + h, j:j + w] * K).sum()``.\n",
"    \"\"\"\n",
"    h, w = K.shape\n",
"    # Allocate the output on X's device with a dtype wide enough for the X * K products;\n",
"    # the default float dtype is the floor, so integer inputs still give a float result.\n",
"    out_dtype = torch.promote_types(torch.result_type(X, K), torch.get_default_dtype())\n",
"    Y = torch.zeros((X.shape[0] - h + 1, X.shape[1] - w + 1),\n",
"                    dtype=out_dtype, device=X.device)\n",
"    for i in range(Y.shape[0]):\n",
"        for j in range(Y.shape[1]):\n",
"            Y[i, j] = (X[i:i + h, j:j + w] * K).sum()\n",
"    return Y\n"
],
"id": "d45f9adfe47fce20",
"outputs": [],
"execution_count": 28
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-22T08:08:33.174646764Z",
"start_time": "2026-03-22T08:08:33.092115317Z"
}
},
"cell_type": "code",
"source": [
"X = torch.tensor([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]])\n",
"K = torch.tensor([[0.0, 1.0], [2.0, 3.0]])\n",
"corr2d(X,K)"
],
"id": "db7279e13647c315",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[19., 25.],\n",
" [37., 43.]])"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 29
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-22T08:08:33.292523218Z",
"start_time": "2026-03-22T08:08:33.234892901Z"
}
},
"cell_type": "code",
"source": [
"class Conv2D(nn.Module):\n",
"    \"\"\"Layer that cross-correlates its input with a learnable kernel and adds a learnable bias.\"\"\"\n",
"\n",
"    def __init__(self, kernel_size):\n",
"        super().__init__()\n",
"        # Kernel entries start uniformly in [0, 1); the single-element bias starts at zero.\n",
"        self.weight = nn.Parameter(torch.rand(kernel_size))\n",
"        self.bias = nn.Parameter(torch.zeros(1))\n",
"\n",
"    def forward(self, x):\n",
"        correlated = corr2d(x, self.weight)\n",
"        return correlated + self.bias\n"
],
"id": "d60be1bd12a1f37e",
"outputs": [],
"execution_count": 30
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-22T08:08:33.497922435Z",
"start_time": "2026-03-22T08:08:33.387838879Z"
}
},
"cell_type": "code",
"source": [
"X = torch.ones((6, 8))\n",
"X[:, 2:6] = 0\n",
"X"
],
"id": "5083789b7a728442",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[1., 1., 0., 0., 0., 0., 1., 1.],\n",
" [1., 1., 0., 0., 0., 0., 1., 1.],\n",
" [1., 1., 0., 0., 0., 0., 1., 1.],\n",
" [1., 1., 0., 0., 0., 0., 1., 1.],\n",
" [1., 1., 0., 0., 0., 0., 1., 1.],\n",
" [1., 1., 0., 0., 0., 0., 1., 1.]])"
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 31
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-22T08:08:33.751910360Z",
"start_time": "2026-03-22T08:08:33.561038017Z"
}
},
"cell_type": "code",
"source": [
"K = torch.tensor([[1.0, -1.0]])\n",
"Y = corr2d(X, K)\n",
"Y"
],
"id": "ee8d6bedbde886ad",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 0., 1., 0., 0., 0., -1., 0.],\n",
" [ 0., 1., 0., 0., 0., -1., 0.],\n",
" [ 0., 1., 0., 0., 0., -1., 0.],\n",
" [ 0., 1., 0., 0., 0., -1., 0.],\n",
" [ 0., 1., 0., 0., 0., -1., 0.],\n",
" [ 0., 1., 0., 0., 0., -1., 0.]])"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 32
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-22T08:08:34.156019672Z",
"start_time": "2026-03-22T08:08:33.891686033Z"
}
},
"cell_type": "code",
"source": "corr2d(X.t(), K)",
"id": "a8278c3837fa9a1c",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0.]])"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 33
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-22T08:08:34.438245033Z",
"start_time": "2026-03-22T08:08:34.313636464Z"
}
},
"cell_type": "code",
"source": "conv2d = nn.Conv2d(1,1, kernel_size=(1, 2), bias=False)",
"id": "ec61cdb61a8cabff",
"outputs": [],
"execution_count": 34
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-22T08:08:34.619825611Z",
"start_time": "2026-03-22T08:08:34.507329239Z"
}
},
"cell_type": "code",
"source": [
"X = X.reshape((1, 1, 6, 8))\n",
"Y = Y.reshape((1, 1, 6, 7))\n",
"lr = 3e-2"
],
"id": "d2fc19d84c79a10",
"outputs": [],
"execution_count": 35
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-22T08:08:35.980584379Z",
"start_time": "2026-03-22T08:08:34.640606389Z"
}
},
"cell_type": "code",
"source": [
"for i in range(100):\n",
"    Y_hat = conv2d(X)\n",
"    l = (Y_hat - Y) ** 2  # element-wise squared error\n",
"    conv2d.zero_grad()\n",
"    l.sum().backward()\n",
"    # Update the kernel with one gradient-descent step. The step runs under no_grad so\n",
"    # it is not recorded by autograd, without going through the discouraged `.data`.\n",
"    with torch.no_grad():\n",
"        conv2d.weight.sub_(lr * conv2d.weight.grad)\n",
"    if (i + 1) % 20 == 0:\n",
"        print(f'epoch {i+1}, loss {l.sum():.3f}')"
],
"id": "51fbb2e6398a9bd5",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 20, loss 0.000\n",
"epoch 40, loss 0.000\n",
"epoch 60, loss 0.000\n",
"epoch 80, loss 0.000\n",
"epoch 100, loss 0.000\n"
]
}
],
"execution_count": 36
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-22T08:08:36.147166039Z",
"start_time": "2026-03-22T08:08:36.070619821Z"
}
},
"cell_type": "code",
"source": "conv2d.weight.data.reshape((1, 2))\n",
"id": "bf53a423f429dfe4",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 1.0000, -1.0000]])"
]
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 37
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-22T08:08:36.349805416Z",
"start_time": "2026-03-22T08:08:36.243749022Z"
}
},
"cell_type": "code",
"source": [
"\n",
"# For convenience, define a helper that applies a convolutional layer to a 2-D input.\n",
"# It adds two leading dimensions to the input and strips them from the output again.\n",
"def comp_conv2d(conv2d, X):\n",
"    \"\"\"Run ``conv2d`` on ``X`` and return the result without its two leading dimensions.\"\"\"\n",
"    # The leading (1, 1) means the batch size and the number of channels are both 1.\n",
"    batched = X.reshape((1, 1) + X.shape)\n",
"    result = conv2d(batched)\n",
"    # Drop the first two dimensions: batch size and channel.\n",
"    spatial_shape = result.shape[2:]\n",
"    return result.reshape(spatial_shape)\n",
"# 请注意这里每边都填充了1行或1列因此总共添加了2行或2列\n",
"conv2d = nn.Conv2d(1, 1, kernel_size=3, padding=1)"
],
"id": "77b61d8c9a2363cc",
"outputs": [],
"execution_count": 38
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-22T08:08:36.565573176Z",
"start_time": "2026-03-22T08:08:36.473905736Z"
}
},
"cell_type": "code",
"source": [
"X = torch.rand(size=(8, 8))\n",
"comp_conv2d(conv2d, X).shape"
],
"id": "beda6ffa67ec2677",
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([8, 8])"
]
},
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 39
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-22T08:08:36.629064391Z",
"start_time": "2026-03-22T08:08:36.577410135Z"
}
},
"cell_type": "code",
"source": [
"conv2d = nn.Conv2d(1, 1, kernel_size=(5, 3), padding=(2, 1))\n",
"comp_conv2d(conv2d, X).shape"
],
"id": "8c51095daea1432d",
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([8, 8])"
]
},
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 40
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-22T08:08:36.778747354Z",
"start_time": "2026-03-22T08:08:36.631642133Z"
}
},
"cell_type": "code",
"source": [
"conv2d = nn.Conv2d(1, 1, kernel_size=3, padding=1, stride=2)\n",
"comp_conv2d(conv2d, X).shape"
],
"id": "581bf1b15162cbf6",
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([4, 4])"
]
},
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 41
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-22T08:08:36.912486341Z",
"start_time": "2026-03-22T08:08:36.816037554Z"
}
},
"cell_type": "code",
"source": [
"conv2d = nn.Conv2d(1, 1, kernel_size=(3, 5), padding=(0, 1), stride=(3, 4))\n",
"comp_conv2d(conv2d, X).shape"
],
"id": "6f7a2411247baff0",
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([2, 2])"
]
},
"execution_count": 42,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 42
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-22T08:09:37.230541174Z",
"start_time": "2026-03-22T08:09:37.139260256Z"
}
},
"cell_type": "code",
"source": [
"def corr2d_multi_in(X, K):\n",
"    \"\"\"Cross-correlate each slice of ``X`` with the matching slice of ``K`` and add up the results.\"\"\"\n",
"    # Slices are paired along the first dimension; accumulation starts from the integer 0.\n",
"    total = 0\n",
"    for x, k in zip(X, K):\n",
"        total = total + corr2d(x, k)\n",
"    return total\n",
"X = torch.tensor([[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]],\n",
"[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]])\n",
"K = torch.tensor([[[0.0, 1.0], [2.0, 3.0]], [[1.0, 2.0], [3.0, 4.0]]])\n",
"corr2d_multi_in(X, K)"
],
"id": "7ac0f17f97b2daa8",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 56., 72.],\n",
" [104., 120.]])"
]
},
"execution_count": 50,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 50
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-22T08:09:01.390798830Z",
"start_time": "2026-03-22T08:09:01.334900206Z"
}
},
"cell_type": "code",
"source": [
"def corr2d_multi_in_out(X, K) -> torch.Tensor:\n",
"    \"\"\"Apply ``corr2d_multi_in`` to ``X`` once per entry of ``K`` and stack the results along a new dim 0.\"\"\"\n",
"    per_kernel = []\n",
"    for k in K:\n",
"        per_kernel.append(corr2d_multi_in(X, k))\n",
"    return torch.stack(per_kernel, dim=0)\n"
],
"id": "d409110d0d6b4b49",
"outputs": [],
"execution_count": 47
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-22T08:09:39.731392821Z",
"start_time": "2026-03-22T08:09:39.608604541Z"
}
},
"cell_type": "code",
"source": [
"K = torch.stack((K, K + 1, K + 2), 0)\n",
"K.shape"
],
"id": "4114cd871a627075",
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([3, 2, 2, 2])"
]
},
"execution_count": 51,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 51
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-22T08:09:43.297502870Z",
"start_time": "2026-03-22T08:09:43.186648920Z"
}
},
"cell_type": "code",
"source": "corr2d_multi_in_out(X, K)",
"id": "ce52f41dc9585f8c",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[ 56., 72.],\n",
" [104., 120.]],\n",
"\n",
" [[ 76., 100.],\n",
" [148., 172.]],\n",
"\n",
" [[ 96., 128.],\n",
" [192., 224.]]])"
]
},
"execution_count": 52,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 52
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-22T08:23:20.460754568Z",
"start_time": "2026-03-22T08:23:20.424813665Z"
}
},
"cell_type": "code",
"source": [
"def corr2d_multi_in_out_1x1(X, K):\n",
"    \"\"\"Multi-input, multi-output cross-correlation with a 1x1 kernel, done as one matrix product.\n",
"\n",
"    Args:\n",
"        X: 3-D tensor of shape (c_i, h, w).\n",
"        K: tensor whose first dimension is c_o and that holds c_o * c_i elements,\n",
"            e.g. shape (c_o, c_i, 1, 1).\n",
"\n",
"    Returns:\n",
"        Tensor of shape (c_o, h, w).\n",
"    \"\"\"\n",
"    c_i, h, w = X.shape\n",
"    c_o = K.shape[0]\n",
"    # Flatten the two trailing dimensions of X so every position becomes one column.\n",
"    X = X.reshape((c_i, h * w))\n",
"    K = K.reshape((c_o, c_i))\n",
"    # (c_o, c_i) @ (c_i, h * w) -> (c_o, h * w)\n",
"    Y = torch.matmul(K, X)\n",
"    return Y.reshape((c_o, h, w))"
],
"id": "362d8c692b3c1d75",
"outputs": [],
"execution_count": 56
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-22T08:23:21.690978973Z",
"start_time": "2026-03-22T08:23:21.638506037Z"
}
},
"cell_type": "code",
"source": [
"X = torch.normal(0, 1, (3, 3, 3))\n",
"K = torch.normal(0, 1, (2, 3, 1, 1))"
],
"id": "28e761f677df8b16",
"outputs": [],
"execution_count": 57
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-22T08:23:22.844449890Z",
"start_time": "2026-03-22T08:23:22.694694019Z"
}
},
"cell_type": "code",
"source": "Y1 = corr2d_multi_in_out_1x1(X, K)",
"id": "8eb276fed751a6b9",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([3, 9])\n",
"torch.Size([2, 3])\n"
]
}
],
"execution_count": 58
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-22T08:24:53.417955565Z",
"start_time": "2026-03-22T08:24:53.297421833Z"
}
},
"cell_type": "code",
"source": [
"Y2 = corr2d_multi_in_out(X, K)\n",
"assert float(torch.abs(Y1 - Y2).sum()) < 1e-6"
],
"id": "be28e27d30f36e2c",
"outputs": [],
"execution_count": 59
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"source": "",
"id": "3c3f71349a2e54c0"
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}