1313 lines
31 KiB
Text
1313 lines
31 KiB
Text
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "code",
|
||
"id": "initial_id",
|
||
"metadata": {
|
||
"collapsed": true,
|
||
"ExecuteTime": {
|
||
"end_time": "2026-03-22T08:08:29.467878519Z",
|
||
"start_time": "2026-03-22T08:08:28.169276694Z"
|
||
}
|
||
},
|
||
"source": [
|
||
"import d2l\n",
|
||
"import torch\n",
|
||
"import d2l\n",
|
||
"import numpy\n",
|
||
"import torch.nn as nn\n",
|
||
"import torch.nn.functional as F"
|
||
],
|
||
"outputs": [],
|
||
"execution_count": 1
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2026-03-22T08:08:29.562521238Z",
|
||
"start_time": "2026-03-22T08:08:29.470022979Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"net = nn.Sequential(nn.Linear(20, 256), nn.ReLU(), nn.Linear(256, 10))\n",
|
||
"X = torch.rand(2, 20)\n",
|
||
"net(X)"
|
||
],
|
||
"id": "dcd5590e7795eec1",
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"tensor([[ 0.0362, 0.0737, -0.0211, 0.0666, -0.1115, 0.0158, -0.1162, 0.0884,\n",
|
||
" 0.1486, -0.1063],\n",
|
||
" [ 0.1796, -0.0009, 0.1236, -0.0783, -0.0937, -0.0560, 0.0441, 0.0812,\n",
|
||
" 0.2236, -0.0597]], grad_fn=<AddmmBackward0>)"
|
||
]
|
||
},
|
||
"execution_count": 2,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"execution_count": 2
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2026-03-22T08:08:29.719824050Z",
|
||
"start_time": "2026-03-22T08:08:29.644546772Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"class MLP(nn.Module):\n",
|
||
" def __init__(self):\n",
|
||
" super().__init__()\n",
|
||
" self.hidden=nn.Linear(20,256)\n",
|
||
" self.out=nn.Linear(256,10)\n",
|
||
" def forward(self,X):\n",
|
||
" return self.out(F.relu(self.hidden(X)))\n"
|
||
],
|
||
"id": "4ae330604b643cb4",
|
||
"outputs": [],
|
||
"execution_count": 3
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2026-03-22T08:08:29.957480273Z",
|
||
"start_time": "2026-03-22T08:08:29.789396471Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"net=MLP()\n",
|
||
"net(X)"
|
||
],
|
||
"id": "cca55c6c0c7da12f",
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"tensor([[ 0.0376, -0.2522, -0.0243, -0.0838, 0.1215, 0.0258, -0.2358, 0.0799,\n",
|
||
" 0.0756, 0.0520],\n",
|
||
" [ 0.0098, -0.2070, 0.0638, 0.1173, 0.0275, 0.0116, -0.0448, -0.0448,\n",
|
||
" -0.0309, -0.0976]], grad_fn=<AddmmBackward0>)"
|
||
]
|
||
},
|
||
"execution_count": 4,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"execution_count": 4
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2026-03-22T08:08:30.177898340Z",
|
||
"start_time": "2026-03-22T08:08:30.069505281Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"class FixedHiddenMLP(nn.Module):\n",
|
||
" def __init__(self):\n",
|
||
" super().__init__()\n",
|
||
" # 不计算梯度的随机权重参数。因此其在训练期间保持不变\n",
|
||
" self.rand_weight = torch.rand((20, 20), requires_grad=False)\n",
|
||
" self.linear = nn.Linear(20, 20)\n",
|
||
" def forward(self, X):\n",
|
||
" X = self.linear(X)\n",
|
||
" # 使用创建的常量参数以及relu和mm函数\n",
|
||
" X = F.relu(torch.mm(X, self.rand_weight) + 1)\n",
|
||
" # 复用全连接层。这相当于两个全连接层共享参数\n",
|
||
" X = self.linear(X)\n",
|
||
" # 控制流\n",
|
||
" while X.abs().sum() > 1:\n",
|
||
" X /= 2\n",
|
||
" return X.sum()"
|
||
],
|
||
"id": "4518d62611d5e749",
|
||
"outputs": [],
|
||
"execution_count": 5
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2026-03-22T08:08:30.414648562Z",
|
||
"start_time": "2026-03-22T08:08:30.251182946Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"net = FixedHiddenMLP()\n",
|
||
"net(X)"
|
||
],
|
||
"id": "fae0187ece4ed5c6",
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"tensor(0.1704, grad_fn=<SumBackward0>)"
|
||
]
|
||
},
|
||
"execution_count": 6,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"execution_count": 6
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2026-03-22T08:08:30.514145240Z",
|
||
"start_time": "2026-03-22T08:08:30.426612891Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"class NestMLP(nn.Module):\n",
|
||
" def __init__(self):\n",
|
||
" super().__init__()\n",
|
||
" self.net = nn.Sequential(nn.Linear(20, 64), nn.ReLU(),\n",
|
||
" nn.Linear(64, 32), nn.ReLU())\n",
|
||
" self.linear = nn.Linear(32, 16)\n",
|
||
" def forward(self, X):\n",
|
||
" return self.linear(self.net(X))\n",
|
||
"chimera = nn.Sequential(NestMLP(), nn.Linear(16, 20), FixedHiddenMLP())\n",
|
||
"chimera(X)"
|
||
],
|
||
"id": "407ef13a86453aae",
|
||
"outputs": [],
|
||
"execution_count": 7
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2026-03-22T08:08:30.617747011Z",
|
||
"start_time": "2026-03-22T08:08:30.517586238Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))\n",
|
||
"X = torch.rand(size=(2, 4))\n",
|
||
"net(X)"
|
||
],
|
||
"id": "9f3526f263c7a249",
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"tensor([[-0.2445],\n",
|
||
" [-0.2901]], grad_fn=<AddmmBackward0>)"
|
||
]
|
||
},
|
||
"execution_count": 8,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"execution_count": 8
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2026-03-22T08:08:30.825671299Z",
|
||
"start_time": "2026-03-22T08:08:30.691388202Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": "print(net[2].state_dict())",
|
||
"id": "8c73f8daa02ba28b",
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"OrderedDict([('weight', tensor([[-0.2116, 0.3448, 0.0726, -0.0626, -0.2922, 0.3172, 0.3025, -0.3025]])), ('bias', tensor([-0.3315]))])\n"
|
||
]
|
||
}
|
||
],
|
||
"execution_count": 9
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2026-03-22T08:08:31.020445691Z",
|
||
"start_time": "2026-03-22T08:08:30.902305394Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": "net[2].state_dict()",
|
||
"id": "b6fee6b64fb96e3c",
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"OrderedDict([('weight',\n",
|
||
" tensor([[-0.2116, 0.3448, 0.0726, -0.0626, -0.2922, 0.3172, 0.3025, -0.3025]])),\n",
|
||
" ('bias', tensor([-0.3315]))])"
|
||
]
|
||
},
|
||
"execution_count": 10,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"execution_count": 10
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2026-03-22T08:08:31.127350738Z",
|
||
"start_time": "2026-03-22T08:08:31.055037575Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": "print(type(net[2].bias))",
|
||
"id": "b38e8dc384e038c5",
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"<class 'torch.nn.parameter.Parameter'>\n"
|
||
]
|
||
}
|
||
],
|
||
"execution_count": 11
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2026-03-22T08:08:31.206008645Z",
|
||
"start_time": "2026-03-22T08:08:31.139239226Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"print(net[2].bias)\n",
|
||
"print(net[2].bias.data)\n"
|
||
],
|
||
"id": "73f12ca3669d9ede",
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Parameter containing:\n",
|
||
"tensor([-0.3315], requires_grad=True)\n",
|
||
"tensor([-0.3315])\n"
|
||
]
|
||
}
|
||
],
|
||
"execution_count": 12
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2026-03-22T08:08:31.286245677Z",
|
||
"start_time": "2026-03-22T08:08:31.230728228Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": "net[2].weight.grad is None",
|
||
"id": "db0fe33018c16fac",
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"True"
|
||
]
|
||
},
|
||
"execution_count": 13,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"execution_count": 13
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2026-03-22T08:08:31.385905750Z",
|
||
"start_time": "2026-03-22T08:08:31.307506282Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"print(*[(name, param.shape) for name, param in net[0].named_parameters()])\n",
|
||
"print(*[(name, param.shape) for name, param in net.named_parameters()])"
|
||
],
|
||
"id": "75847a1c608ee5c7",
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"('weight', torch.Size([8, 4])) ('bias', torch.Size([8]))\n",
|
||
"('0.weight', torch.Size([8, 4])) ('0.bias', torch.Size([8])) ('2.weight', torch.Size([1, 8])) ('2.bias', torch.Size([1]))\n"
|
||
]
|
||
}
|
||
],
|
||
"execution_count": 14
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2026-03-22T08:08:31.470489659Z",
|
||
"start_time": "2026-03-22T08:08:31.391689979Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": "net.state_dict()['2.bias'].data",
|
||
"id": "cc74913e8742da7d",
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"tensor([-0.3315])"
|
||
]
|
||
},
|
||
"execution_count": 15,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"execution_count": 15
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2026-03-22T08:08:31.522581812Z",
|
||
"start_time": "2026-03-22T08:08:31.482437170Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"def block1():\n",
|
||
" return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 4),nn.ReLU())\n",
|
||
"def block2():\n",
|
||
" net = nn.Sequential()\n",
|
||
" for i in range(4):\n",
|
||
" net.add_module(f'block{i}', block1())\n",
|
||
" return net"
|
||
],
|
||
"id": "53c39c5e61fa7bf5",
|
||
"outputs": [],
|
||
"execution_count": 16
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2026-03-22T08:08:31.643076590Z",
|
||
"start_time": "2026-03-22T08:08:31.558403449Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"rgnet = nn.Sequential(block2(),nn.Linear(4,1))\n",
|
||
"rgnet(X)"
|
||
],
|
||
"id": "d3ac7759b619aca",
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"tensor([[-0.3640],\n",
|
||
" [-0.3640]], grad_fn=<AddmmBackward0>)"
|
||
]
|
||
},
|
||
"execution_count": 17,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"execution_count": 17
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2026-03-22T08:08:31.860114658Z",
|
||
"start_time": "2026-03-22T08:08:31.722546330Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": "print(rgnet)",
|
||
"id": "8fc60f64b07781e6",
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Sequential(\n",
|
||
" (0): Sequential(\n",
|
||
" (block0): Sequential(\n",
|
||
" (0): Linear(in_features=4, out_features=8, bias=True)\n",
|
||
" (1): ReLU()\n",
|
||
" (2): Linear(in_features=8, out_features=4, bias=True)\n",
|
||
" (3): ReLU()\n",
|
||
" )\n",
|
||
" (block1): Sequential(\n",
|
||
" (0): Linear(in_features=4, out_features=8, bias=True)\n",
|
||
" (1): ReLU()\n",
|
||
" (2): Linear(in_features=8, out_features=4, bias=True)\n",
|
||
" (3): ReLU()\n",
|
||
" )\n",
|
||
" (block2): Sequential(\n",
|
||
" (0): Linear(in_features=4, out_features=8, bias=True)\n",
|
||
" (1): ReLU()\n",
|
||
" (2): Linear(in_features=8, out_features=4, bias=True)\n",
|
||
" (3): ReLU()\n",
|
||
" )\n",
|
||
" (block3): Sequential(\n",
|
||
" (0): Linear(in_features=4, out_features=8, bias=True)\n",
|
||
" (1): ReLU()\n",
|
||
" (2): Linear(in_features=8, out_features=4, bias=True)\n",
|
||
" (3): ReLU()\n",
|
||
" )\n",
|
||
" )\n",
|
||
" (1): Linear(in_features=4, out_features=1, bias=True)\n",
|
||
")\n"
|
||
]
|
||
}
|
||
],
|
||
"execution_count": 18
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2026-03-22T08:08:32.103301882Z",
|
||
"start_time": "2026-03-22T08:08:31.980555778Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": "rgnet[0][1][0].bias.data",
|
||
"id": "e590aaafca787b50",
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"tensor([ 0.3672, -0.3124, -0.3113, -0.3251, -0.4771, -0.3622, 0.1464, -0.4632])"
|
||
]
|
||
},
|
||
"execution_count": 19,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"execution_count": 19
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2026-03-22T08:08:32.231455211Z",
|
||
"start_time": "2026-03-22T08:08:32.137730392Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"def init_normal(m):\n",
|
||
" if type(m) == nn.Linear:\n",
|
||
" nn.init.normal_(m.weight, mean=0, std=0.01)\n",
|
||
" nn.init.zeros_(m.bias)\n",
|
||
"net.apply(init_normal)\n",
|
||
"net[0].weight.data[0], net[0].bias.data[0]"
|
||
],
|
||
"id": "925ca33221d0a87e",
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"(tensor([-0.0004, 0.0166, -0.0085, -0.0099]), tensor(0.))"
|
||
]
|
||
},
|
||
"execution_count": 20,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"execution_count": 20
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2026-03-22T08:08:32.322842445Z",
|
||
"start_time": "2026-03-22T08:08:32.234982576Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"def init_xavier(m):\n",
|
||
" if type(m) == nn.Linear:\n",
|
||
" nn.init.xavier_uniform_(m.weight)\n",
|
||
"def init_42(m):\n",
|
||
" if type(m) == nn.Linear:\n",
|
||
" nn.init.constant_(m.weight, 42)\n",
|
||
"\n",
|
||
"net[0].apply(init_xavier)\n",
|
||
"net[2].apply(init_42)\n",
|
||
"print(net[0].weight.data[0])\n",
|
||
"print(net[2].weight.data)"
|
||
],
|
||
"id": "81e2de84a8c4ef32",
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"tensor([-0.3265, -0.5057, -0.5062, -0.2116])\n",
|
||
"tensor([[42., 42., 42., 42., 42., 42., 42., 42.]])\n"
|
||
]
|
||
}
|
||
],
|
||
"execution_count": 21
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2026-03-22T08:08:32.377993392Z",
|
||
"start_time": "2026-03-22T08:08:32.324885649Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"x = torch.arange(4)\n",
|
||
"torch.save(x, 'x-file')"
|
||
],
|
||
"id": "f05bb378bb60ab9e",
|
||
"outputs": [],
|
||
"execution_count": 22
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2026-03-22T08:08:32.474581135Z",
|
||
"start_time": "2026-03-22T08:08:32.386815096Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"x2 = torch.load('x-file')\n",
|
||
"x2"
|
||
],
|
||
"id": "a74ecaaac0d826c6",
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"tensor([0, 1, 2, 3])"
|
||
]
|
||
},
|
||
"execution_count": 23,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"execution_count": 23
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2026-03-22T08:08:32.492632663Z",
|
||
"start_time": "2026-03-22T08:08:32.476644136Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"class MLP(nn.Module):\n",
|
||
" def __init__(self):\n",
|
||
" super().__init__()\n",
|
||
" self.hidden = nn.Linear(20, 256)\n",
|
||
" self.output = nn.Linear(256, 10)\n",
|
||
" def forward(self, x):\n",
|
||
" return self.output(F.relu(self.hidden(x)))\n",
|
||
"\n",
|
||
"net = MLP()\n",
|
||
"X = torch.randn(size=(2, 20))\n",
|
||
"Y = net(X)"
|
||
],
|
||
"id": "b42598f0c4a8e801",
|
||
"outputs": [],
|
||
"execution_count": 24
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2026-03-22T08:08:32.600312816Z",
|
||
"start_time": "2026-03-22T08:08:32.528496997Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": "torch.save(net.state_dict(), 'mlp.params')",
|
||
"id": "aaa22eef549caa6f",
|
||
"outputs": [],
|
||
"execution_count": 25
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2026-03-22T08:08:32.697771732Z",
|
||
"start_time": "2026-03-22T08:08:32.616231856Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"clone = MLP()\n",
|
||
"clone.load_state_dict(torch.load('mlp.params'))\n",
|
||
"clone.eval()"
|
||
],
|
||
"id": "b92f920229abeeae",
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"MLP(\n",
|
||
" (hidden): Linear(in_features=20, out_features=256, bias=True)\n",
|
||
" (output): Linear(in_features=256, out_features=10, bias=True)\n",
|
||
")"
|
||
]
|
||
},
|
||
"execution_count": 26,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"execution_count": 26
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2026-03-22T08:08:32.782119612Z",
|
||
"start_time": "2026-03-22T08:08:32.715175504Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"Y_clone = clone(X)\n",
|
||
"Y_clone == Y"
|
||
],
|
||
"id": "646c9eb6d7cc81c2",
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"tensor([[True, True, True, True, True, True, True, True, True, True],\n",
|
||
" [True, True, True, True, True, True, True, True, True, True]])"
|
||
]
|
||
},
|
||
"execution_count": 27,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"execution_count": 27
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2026-03-22T08:08:32.908809253Z",
|
||
"start_time": "2026-03-22T08:08:32.843784246Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"def corr2d(X,K):\n",
|
||
" h,w=K.shape\n",
|
||
" Y=torch.ones((X.shape[0]-h+1,X.shape[1]-w+1))\n",
|
||
" for i in range(Y.shape[0]):\n",
|
||
" for j in range(Y.shape[1]):\n",
|
||
" Y[i,j]=(X[i:i+h,j:j+w]*K).sum()\n",
|
||
" return Y\n"
|
||
],
|
||
"id": "d45f9adfe47fce20",
|
||
"outputs": [],
|
||
"execution_count": 28
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2026-03-22T08:08:33.174646764Z",
|
||
"start_time": "2026-03-22T08:08:33.092115317Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"X = torch.tensor([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]])\n",
|
||
"K = torch.tensor([[0.0, 1.0], [2.0, 3.0]])\n",
|
||
"corr2d(X,K)"
|
||
],
|
||
"id": "db7279e13647c315",
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"tensor([[19., 25.],\n",
|
||
" [37., 43.]])"
|
||
]
|
||
},
|
||
"execution_count": 29,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"execution_count": 29
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2026-03-22T08:08:33.292523218Z",
|
||
"start_time": "2026-03-22T08:08:33.234892901Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"class Conv2D(nn.Module):\n",
|
||
" def __init__(self, kernel_size):\n",
|
||
" super().__init__()\n",
|
||
" self.weight = nn.Parameter(torch.rand(kernel_size))\n",
|
||
" self.bias = nn.Parameter(torch.zeros(1))\n",
|
||
" def forward(self, x):\n",
|
||
" return corr2d(x, self.weight) + self.bias\n"
|
||
],
|
||
"id": "d60be1bd12a1f37e",
|
||
"outputs": [],
|
||
"execution_count": 30
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2026-03-22T08:08:33.497922435Z",
|
||
"start_time": "2026-03-22T08:08:33.387838879Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"X = torch.ones((6, 8))\n",
|
||
"X[:, 2:6] = 0\n",
|
||
"X"
|
||
],
|
||
"id": "5083789b7a728442",
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"tensor([[1., 1., 0., 0., 0., 0., 1., 1.],\n",
|
||
" [1., 1., 0., 0., 0., 0., 1., 1.],\n",
|
||
" [1., 1., 0., 0., 0., 0., 1., 1.],\n",
|
||
" [1., 1., 0., 0., 0., 0., 1., 1.],\n",
|
||
" [1., 1., 0., 0., 0., 0., 1., 1.],\n",
|
||
" [1., 1., 0., 0., 0., 0., 1., 1.]])"
|
||
]
|
||
},
|
||
"execution_count": 31,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"execution_count": 31
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2026-03-22T08:08:33.751910360Z",
|
||
"start_time": "2026-03-22T08:08:33.561038017Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"K = torch.tensor([[1.0, -1.0]])\n",
|
||
"Y = corr2d(X, K)\n",
|
||
"Y"
|
||
],
|
||
"id": "ee8d6bedbde886ad",
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"tensor([[ 0., 1., 0., 0., 0., -1., 0.],\n",
|
||
" [ 0., 1., 0., 0., 0., -1., 0.],\n",
|
||
" [ 0., 1., 0., 0., 0., -1., 0.],\n",
|
||
" [ 0., 1., 0., 0., 0., -1., 0.],\n",
|
||
" [ 0., 1., 0., 0., 0., -1., 0.],\n",
|
||
" [ 0., 1., 0., 0., 0., -1., 0.]])"
|
||
]
|
||
},
|
||
"execution_count": 32,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"execution_count": 32
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2026-03-22T08:08:34.156019672Z",
|
||
"start_time": "2026-03-22T08:08:33.891686033Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": "corr2d(X.t(), K)",
|
||
"id": "a8278c3837fa9a1c",
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"tensor([[0., 0., 0., 0., 0.],\n",
|
||
" [0., 0., 0., 0., 0.],\n",
|
||
" [0., 0., 0., 0., 0.],\n",
|
||
" [0., 0., 0., 0., 0.],\n",
|
||
" [0., 0., 0., 0., 0.],\n",
|
||
" [0., 0., 0., 0., 0.],\n",
|
||
" [0., 0., 0., 0., 0.],\n",
|
||
" [0., 0., 0., 0., 0.]])"
|
||
]
|
||
},
|
||
"execution_count": 33,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"execution_count": 33
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2026-03-22T08:08:34.438245033Z",
|
||
"start_time": "2026-03-22T08:08:34.313636464Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": "conv2d = nn.Conv2d(1,1, kernel_size=(1, 2), bias=False)",
|
||
"id": "ec61cdb61a8cabff",
|
||
"outputs": [],
|
||
"execution_count": 34
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2026-03-22T08:08:34.619825611Z",
|
||
"start_time": "2026-03-22T08:08:34.507329239Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"X = X.reshape((1, 1, 6, 8))\n",
|
||
"Y = Y.reshape((1, 1, 6, 7))\n",
|
||
"lr = 3e-2"
|
||
],
|
||
"id": "d2fc19d84c79a10",
|
||
"outputs": [],
|
||
"execution_count": 35
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2026-03-22T08:08:35.980584379Z",
|
||
"start_time": "2026-03-22T08:08:34.640606389Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"for i in range(100):\n",
|
||
" Y_hat = conv2d(X)\n",
|
||
" l = (Y_hat - Y) ** 2\n",
|
||
" conv2d.zero_grad()\n",
|
||
" l.sum().backward()\n",
|
||
" # 迭代卷积核\n",
|
||
" conv2d.weight.data[:] -= lr * conv2d.weight.grad\n",
|
||
" if (i + 1) % 20 == 0:\n",
|
||
" print(f'epoch {i+1}, loss {l.sum():.3f}')"
|
||
],
|
||
"id": "51fbb2e6398a9bd5",
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"epoch 20, loss 0.000\n",
|
||
"epoch 40, loss 0.000\n",
|
||
"epoch 60, loss 0.000\n",
|
||
"epoch 80, loss 0.000\n",
|
||
"epoch 100, loss 0.000\n"
|
||
]
|
||
}
|
||
],
|
||
"execution_count": 36
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2026-03-22T08:08:36.147166039Z",
|
||
"start_time": "2026-03-22T08:08:36.070619821Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": "conv2d.weight.data.reshape((1, 2))\n",
|
||
"id": "bf53a423f429dfe4",
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"tensor([[ 1.0000, -1.0000]])"
|
||
]
|
||
},
|
||
"execution_count": 37,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"execution_count": 37
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2026-03-22T08:08:36.349805416Z",
|
||
"start_time": "2026-03-22T08:08:36.243749022Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"\n",
|
||
"# 为了方便起见,我们定义了一个计算卷积层的函数。\n",
|
||
"# 此函数初始化卷积层权重,并对输入和输出提高和缩减相应的维数\n",
|
||
"def comp_conv2d(conv2d, X):\n",
|
||
"# 这里的(1,1)表示批量大小和通道数都是1\n",
|
||
" X = X.reshape((1, 1) + X.shape)\n",
|
||
" Y = conv2d(X)\n",
|
||
" # 省略前两个维度:批量大小和通道\n",
|
||
" return Y.reshape(Y.shape[2:])\n",
|
||
"# 请注意,这里每边都填充了1行或1列,因此总共添加了2行或2列\n",
|
||
"conv2d = nn.Conv2d(1, 1, kernel_size=3, padding=1)"
|
||
],
|
||
"id": "77b61d8c9a2363cc",
|
||
"outputs": [],
|
||
"execution_count": 38
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2026-03-22T08:08:36.565573176Z",
|
||
"start_time": "2026-03-22T08:08:36.473905736Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"X = torch.rand(size=(8, 8))\n",
|
||
"comp_conv2d(conv2d, X).shape"
|
||
],
|
||
"id": "beda6ffa67ec2677",
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"torch.Size([8, 8])"
|
||
]
|
||
},
|
||
"execution_count": 39,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"execution_count": 39
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2026-03-22T08:08:36.629064391Z",
|
||
"start_time": "2026-03-22T08:08:36.577410135Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"conv2d = nn.Conv2d(1, 1, kernel_size=(5, 3), padding=(2, 1))\n",
|
||
"comp_conv2d(conv2d, X).shape"
|
||
],
|
||
"id": "8c51095daea1432d",
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"torch.Size([8, 8])"
|
||
]
|
||
},
|
||
"execution_count": 40,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"execution_count": 40
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2026-03-22T08:08:36.778747354Z",
|
||
"start_time": "2026-03-22T08:08:36.631642133Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"conv2d = nn.Conv2d(1, 1, kernel_size=3, padding=1, stride=2)\n",
|
||
"comp_conv2d(conv2d, X).shape"
|
||
],
|
||
"id": "581bf1b15162cbf6",
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"torch.Size([4, 4])"
|
||
]
|
||
},
|
||
"execution_count": 41,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"execution_count": 41
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2026-03-22T08:08:36.912486341Z",
|
||
"start_time": "2026-03-22T08:08:36.816037554Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"conv2d = nn.Conv2d(1, 1, kernel_size=(3, 5), padding=(0, 1), stride=(3, 4))\n",
|
||
"comp_conv2d(conv2d, X).shape"
|
||
],
|
||
"id": "6f7a2411247baff0",
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"torch.Size([2, 2])"
|
||
]
|
||
},
|
||
"execution_count": 42,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"execution_count": 42
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2026-03-22T08:09:37.230541174Z",
|
||
"start_time": "2026-03-22T08:09:37.139260256Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"def corr2d_multi_in(X,K):\n",
|
||
" return sum(corr2d(x,k) for x,k in zip(X,K))\n",
|
||
"X = torch.tensor([[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]],\n",
|
||
"[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]])\n",
|
||
"K = torch.tensor([[[0.0, 1.0], [2.0, 3.0]], [[1.0, 2.0], [3.0, 4.0]]])\n",
|
||
"corr2d_multi_in(X, K)"
|
||
],
|
||
"id": "7ac0f17f97b2daa8",
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"tensor([[ 56., 72.],\n",
|
||
" [104., 120.]])"
|
||
]
|
||
},
|
||
"execution_count": 50,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"execution_count": 50
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2026-03-22T08:09:01.390798830Z",
|
||
"start_time": "2026-03-22T08:09:01.334900206Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"def corr2d_multi_in_out(X,K) ->torch.Tensor :\n",
|
||
" return torch.stack([corr2d_multi_in(X,k) for k in K],0)\n"
|
||
],
|
||
"id": "d409110d0d6b4b49",
|
||
"outputs": [],
|
||
"execution_count": 47
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2026-03-22T08:09:39.731392821Z",
|
||
"start_time": "2026-03-22T08:09:39.608604541Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"K = torch.stack((K, K + 1, K + 2), 0)\n",
|
||
"K.shape"
|
||
],
|
||
"id": "4114cd871a627075",
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"torch.Size([3, 2, 2, 2])"
|
||
]
|
||
},
|
||
"execution_count": 51,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"execution_count": 51
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2026-03-22T08:09:43.297502870Z",
|
||
"start_time": "2026-03-22T08:09:43.186648920Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": "corr2d_multi_in_out(X, K)",
|
||
"id": "ce52f41dc9585f8c",
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"tensor([[[ 56., 72.],\n",
|
||
" [104., 120.]],\n",
|
||
"\n",
|
||
" [[ 76., 100.],\n",
|
||
" [148., 172.]],\n",
|
||
"\n",
|
||
" [[ 96., 128.],\n",
|
||
" [192., 224.]]])"
|
||
]
|
||
},
|
||
"execution_count": 52,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"execution_count": 52
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2026-03-22T08:23:20.460754568Z",
|
||
"start_time": "2026-03-22T08:23:20.424813665Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"def corr2d_multi_in_out_1x1(X, K):\n",
|
||
" h_i,h,w=X.shape\n",
|
||
" h_o=K.shape[0]\n",
|
||
" X=X.reshape((h_i,h*w))\n",
|
||
" print(X.shape)\n",
|
||
" K=K.reshape((h_o,h_i))\n",
|
||
" print(K.shape)\n",
|
||
" Y=torch.matmul(K,X)\n",
|
||
" return Y.reshape((h_o,h,w))"
|
||
],
|
||
"id": "362d8c692b3c1d75",
|
||
"outputs": [],
|
||
"execution_count": 56
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2026-03-22T08:23:21.690978973Z",
|
||
"start_time": "2026-03-22T08:23:21.638506037Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"X = torch.normal(0, 1, (3, 3, 3))\n",
|
||
"K = torch.normal(0, 1, (2, 3, 1, 1))"
|
||
],
|
||
"id": "28e761f677df8b16",
|
||
"outputs": [],
|
||
"execution_count": 57
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2026-03-22T08:23:22.844449890Z",
|
||
"start_time": "2026-03-22T08:23:22.694694019Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": "Y1 = corr2d_multi_in_out_1x1(X, K)",
|
||
"id": "8eb276fed751a6b9",
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"torch.Size([3, 9])\n",
|
||
"torch.Size([2, 3])\n"
|
||
]
|
||
}
|
||
],
|
||
"execution_count": 58
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2026-03-22T08:24:53.417955565Z",
|
||
"start_time": "2026-03-22T08:24:53.297421833Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"Y2 = corr2d_multi_in_out(X, K)\n",
|
||
"assert float(torch.abs(Y1 - Y2).sum()) < 1e-6"
|
||
],
|
||
"id": "be28e27d30f36e2c",
|
||
"outputs": [],
|
||
"execution_count": 59
|
||
},
|
||
{
|
||
"metadata": {},
|
||
"cell_type": "code",
|
||
"outputs": [],
|
||
"execution_count": null,
|
||
"source": "",
|
||
"id": "3c3f71349a2e54c0"
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "Python 3",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "2.7.6"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 5
|
||
}
|