{ "cells": [ { "cell_type": "code", "id": "initial_id", "metadata": { "collapsed": true, "ExecuteTime": { "end_time": "2026-03-22T08:08:29.467878519Z", "start_time": "2026-03-22T08:08:28.169276694Z" } }, "source": [ "import d2l\n", "import torch\n", "import d2l\n", "import numpy\n", "import torch.nn as nn\n", "import torch.nn.functional as F" ], "outputs": [], "execution_count": 1 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-22T08:08:29.562521238Z", "start_time": "2026-03-22T08:08:29.470022979Z" } }, "cell_type": "code", "source": [ "net = nn.Sequential(nn.Linear(20, 256), nn.ReLU(), nn.Linear(256, 10))\n", "X = torch.rand(2, 20)\n", "net(X)" ], "id": "dcd5590e7795eec1", "outputs": [ { "data": { "text/plain": [ "tensor([[ 0.0362, 0.0737, -0.0211, 0.0666, -0.1115, 0.0158, -0.1162, 0.0884,\n", " 0.1486, -0.1063],\n", " [ 0.1796, -0.0009, 0.1236, -0.0783, -0.0937, -0.0560, 0.0441, 0.0812,\n", " 0.2236, -0.0597]], grad_fn=)" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 2 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-22T08:08:29.719824050Z", "start_time": "2026-03-22T08:08:29.644546772Z" } }, "cell_type": "code", "source": [ "class MLP(nn.Module):\n", " def __init__(self):\n", " super().__init__()\n", " self.hidden=nn.Linear(20,256)\n", " self.out=nn.Linear(256,10)\n", " def forward(self,X):\n", " return self.out(F.relu(self.hidden(X)))\n" ], "id": "4ae330604b643cb4", "outputs": [], "execution_count": 3 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-22T08:08:29.957480273Z", "start_time": "2026-03-22T08:08:29.789396471Z" } }, "cell_type": "code", "source": [ "net=MLP()\n", "net(X)" ], "id": "cca55c6c0c7da12f", "outputs": [ { "data": { "text/plain": [ "tensor([[ 0.0376, -0.2522, -0.0243, -0.0838, 0.1215, 0.0258, -0.2358, 0.0799,\n", " 0.0756, 0.0520],\n", " [ 0.0098, -0.2070, 0.0638, 0.1173, 0.0275, 0.0116, -0.0448, -0.0448,\n", " -0.0309, -0.0976]], grad_fn=)" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 4 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-22T08:08:30.177898340Z", "start_time": "2026-03-22T08:08:30.069505281Z" } }, "cell_type": "code", "source": [ "class FixedHiddenMLP(nn.Module):\n", " def __init__(self):\n", " super().__init__()\n", " # 不计算梯度的随机权重参数。因此其在训练期间保持不变\n", " self.rand_weight = torch.rand((20, 20), requires_grad=False)\n", " self.linear = nn.Linear(20, 20)\n", " def forward(self, X):\n", " X = self.linear(X)\n", " # 使用创建的常量参数以及relu和mm函数\n", " X = F.relu(torch.mm(X, self.rand_weight) + 1)\n", " # 复用全连接层。这相当于两个全连接层共享参数\n", " X = self.linear(X)\n", " # 控制流\n", " while X.abs().sum() > 1:\n", " X /= 2\n", " return X.sum()" ], "id": "4518d62611d5e749", "outputs": [], "execution_count": 5 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-22T08:08:30.414648562Z", "start_time": "2026-03-22T08:08:30.251182946Z" } }, "cell_type": "code", "source": [ "net = FixedHiddenMLP()\n", "net(X)" ], "id": "fae0187ece4ed5c6", "outputs": [ { "data": { "text/plain": [ "tensor(0.1704, grad_fn=)" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 6 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-22T08:08:30.514145240Z", "start_time": "2026-03-22T08:08:30.426612891Z" } }, "cell_type": "code", "source": [ "class NestMLP(nn.Module):\n", " def __init__(self):\n", " super().__init__()\n", " self.net = nn.Sequential(nn.Linear(20, 64), nn.ReLU(),\n", " nn.Linear(64, 32), nn.ReLU())\n", " self.linear = nn.Linear(32, 16)\n", " def forward(self, X):\n", " return self.linear(self.net(X))\n", " chimera = nn.Sequential(NestMLP(), nn.Linear(16, 20), FixedHiddenMLP())\n", " chimera(X)" ], "id": "407ef13a86453aae", "outputs": [], "execution_count": 7 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-22T08:08:30.617747011Z", "start_time": "2026-03-22T08:08:30.517586238Z" } }, "cell_type": "code", "source": [ "net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))\n", "X = torch.rand(size=(2, 4))\n", "net(X)" ], "id": "9f3526f263c7a249", "outputs": [ { "data": { "text/plain": [ "tensor([[-0.2445],\n", " [-0.2901]], grad_fn=)" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 8 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-22T08:08:30.825671299Z", "start_time": "2026-03-22T08:08:30.691388202Z" } }, "cell_type": "code", "source": "print(net[2].state_dict())", "id": "8c73f8daa02ba28b", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "OrderedDict([('weight', tensor([[-0.2116, 0.3448, 0.0726, -0.0626, -0.2922, 0.3172, 0.3025, -0.3025]])), ('bias', tensor([-0.3315]))])\n" ] } ], "execution_count": 9 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-22T08:08:31.020445691Z", "start_time": "2026-03-22T08:08:30.902305394Z" } }, "cell_type": "code", "source": "net[2].state_dict()", "id": "b6fee6b64fb96e3c", "outputs": [ { "data": { "text/plain": [ "OrderedDict([('weight',\n", " tensor([[-0.2116, 0.3448, 0.0726, -0.0626, -0.2922, 0.3172, 0.3025, -0.3025]])),\n", " ('bias', tensor([-0.3315]))])" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 10 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-22T08:08:31.127350738Z", "start_time": "2026-03-22T08:08:31.055037575Z" } }, "cell_type": "code", "source": "print(type(net[2].bias))", "id": "b38e8dc384e038c5", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n" ] } ], "execution_count": 11 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-22T08:08:31.206008645Z", "start_time": "2026-03-22T08:08:31.139239226Z" } }, "cell_type": "code", "source": [ "print(net[2].bias)\n", "print(net[2].bias.data)\n" ], "id": "73f12ca3669d9ede", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Parameter containing:\n", "tensor([-0.3315], requires_grad=True)\n", "tensor([-0.3315])\n" ] } ], "execution_count": 12 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-22T08:08:31.286245677Z", "start_time": "2026-03-22T08:08:31.230728228Z" } }, "cell_type": "code", "source": "net[2].weight.grad==None", "id": "db0fe33018c16fac", "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 13 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-22T08:08:31.385905750Z", "start_time": "2026-03-22T08:08:31.307506282Z" } }, "cell_type": "code", "source": [ "print(*[(name, param.shape) for name, param in net[0].named_parameters()])\n", "print(*[(name, param.shape) for name, param in net.named_parameters()])" ], "id": "75847a1c608ee5c7", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "('weight', torch.Size([8, 4])) ('bias', torch.Size([8]))\n", "('0.weight', torch.Size([8, 4])) ('0.bias', torch.Size([8])) ('2.weight', torch.Size([1, 8])) ('2.bias', torch.Size([1]))\n" ] } ], "execution_count": 14 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-22T08:08:31.470489659Z", "start_time": "2026-03-22T08:08:31.391689979Z" } }, "cell_type": "code", "source": "net.state_dict()['2.bias'].data", "id": "cc74913e8742da7d", "outputs": [ { "data": { "text/plain": [ "tensor([-0.3315])" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 15 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-22T08:08:31.522581812Z", "start_time": "2026-03-22T08:08:31.482437170Z" } }, "cell_type": "code", "source": [ "def block1():\n", " return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 4),nn.ReLU())\n", "def block2():\n", " net = nn.Sequential()\n", " for i in range(4):\n", " net.add_module(f'block{i}', block1())\n", " return net" ], "id": "53c39c5e61fa7bf5", "outputs": [], "execution_count": 16 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-22T08:08:31.643076590Z", "start_time": "2026-03-22T08:08:31.558403449Z" } }, "cell_type": "code", "source": [ "rgnet = nn.Sequential(block2(),nn.Linear(4,1))\n", "rgnet(X)" ], "id": "d3ac7759b619aca", "outputs": [ { "data": { "text/plain": [ "tensor([[-0.3640],\n", " [-0.3640]], grad_fn=)" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 17 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-22T08:08:31.860114658Z", "start_time": "2026-03-22T08:08:31.722546330Z" } }, "cell_type": "code", "source": "print(rgnet)", "id": "8fc60f64b07781e6", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sequential(\n", " (0): Sequential(\n", " (block0): Sequential(\n", " (0): Linear(in_features=4, out_features=8, bias=True)\n", " (1): ReLU()\n", " (2): Linear(in_features=8, out_features=4, bias=True)\n", " (3): ReLU()\n", " )\n", " (block1): Sequential(\n", " (0): Linear(in_features=4, out_features=8, bias=True)\n", " (1): ReLU()\n", " (2): Linear(in_features=8, out_features=4, bias=True)\n", " (3): ReLU()\n", " )\n", " (block2): Sequential(\n", " (0): Linear(in_features=4, out_features=8, bias=True)\n", " (1): ReLU()\n", " (2): Linear(in_features=8, out_features=4, bias=True)\n", " (3): ReLU()\n", " )\n", " (block3): Sequential(\n", " (0): Linear(in_features=4, out_features=8, bias=True)\n", " (1): ReLU()\n", " (2): Linear(in_features=8, out_features=4, bias=True)\n", " (3): ReLU()\n", " )\n", " )\n", " (1): Linear(in_features=4, out_features=1, bias=True)\n", ")\n" ] } ], "execution_count": 18 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-22T08:08:32.103301882Z", "start_time": "2026-03-22T08:08:31.980555778Z" } }, "cell_type": "code", "source": "rgnet[0][1][0].bias.data", "id": "e590aaafca787b50", "outputs": [ { "data": { "text/plain": [ "tensor([ 0.3672, -0.3124, -0.3113, -0.3251, -0.4771, -0.3622, 0.1464, -0.4632])" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 19 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-22T08:08:32.231455211Z", "start_time": "2026-03-22T08:08:32.137730392Z" } }, "cell_type": "code", "source": [ "def init_normal(m):\n", " if type(m) == nn.Linear:\n", " nn.init.normal_(m.weight, mean=0, std=0.01)\n", " nn.init.zeros_(m.bias)\n", "net.apply(init_normal)\n", "net[0].weight.data[0], net[0].bias.data[0]" ], "id": "925ca33221d0a87e", "outputs": [ { "data": { "text/plain": [ "(tensor([-0.0004, 0.0166, -0.0085, -0.0099]), tensor(0.))" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 20 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-22T08:08:32.322842445Z", "start_time": "2026-03-22T08:08:32.234982576Z" } }, "cell_type": "code", "source": [ "def init_xavier(m):\n", " if type(m) == nn.Linear:\n", " nn.init.xavier_uniform_(m.weight)\n", "def init_42(m):\n", " if type(m) == nn.Linear:\n", " nn.init.constant_(m.weight, 42)\n", "\n", "net[0].apply(init_xavier)\n", "net[2].apply(init_42)\n", "print(net[0].weight.data[0])\n", "print(net[2].weight.data)" ], "id": "81e2de84a8c4ef32", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([-0.3265, -0.5057, -0.5062, -0.2116])\n", "tensor([[42., 42., 42., 42., 42., 42., 42., 42.]])\n" ] } ], "execution_count": 21 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-22T08:08:32.377993392Z", "start_time": "2026-03-22T08:08:32.324885649Z" } }, "cell_type": "code", "source": [ "x = torch.arange(4)\n", "torch.save(x, 'x-file')" ], "id": "f05bb378bb60ab9e", "outputs": [], "execution_count": 22 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-22T08:08:32.474581135Z", "start_time": "2026-03-22T08:08:32.386815096Z" } }, "cell_type": "code", "source": [ "x2 = torch.load('x-file')\n", "x2" ], "id": "a74ecaaac0d826c6", "outputs": [ { "data": { "text/plain": [ "tensor([0, 1, 2, 3])" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 23 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-22T08:08:32.492632663Z", "start_time": "2026-03-22T08:08:32.476644136Z" } }, "cell_type": "code", "source": [ "class MLP(nn.Module):\n", " def __init__(self):\n", " super().__init__()\n", " self.hidden = nn.Linear(20, 256)\n", " self.output = nn.Linear(256, 10)\n", " def forward(self, x):\n", " return self.output(F.relu(self.hidden(x)))\n", "\n", "net = MLP()\n", "X = torch.randn(size=(2, 20))\n", "Y = net(X)" ], "id": "b42598f0c4a8e801", "outputs": [], "execution_count": 24 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-22T08:08:32.600312816Z", "start_time": "2026-03-22T08:08:32.528496997Z" } }, "cell_type": "code", "source": "torch.save(net.state_dict(), 'mlp.params')", "id": "aaa22eef549caa6f", "outputs": [], "execution_count": 25 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-22T08:08:32.697771732Z", "start_time": "2026-03-22T08:08:32.616231856Z" } }, "cell_type": "code", "source": [ "clone = MLP()\n", "clone.load_state_dict(torch.load('mlp.params'))\n", "clone.eval()" ], "id": "b92f920229abeeae", "outputs": [ { "data": { "text/plain": [ "MLP(\n", " (hidden): Linear(in_features=20, out_features=256, bias=True)\n", " (output): Linear(in_features=256, out_features=10, bias=True)\n", ")" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 26 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-22T08:08:32.782119612Z", "start_time": "2026-03-22T08:08:32.715175504Z" } }, "cell_type": "code", "source": [ "Y_clone = clone(X)\n", "Y_clone == Y" ], "id": "646c9eb6d7cc81c2", "outputs": [ { "data": { "text/plain": [ "tensor([[True, True, True, True, True, True, True, True, True, True],\n", " [True, True, True, True, True, True, True, True, True, True]])" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 27 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-22T08:08:32.908809253Z", "start_time": "2026-03-22T08:08:32.843784246Z" } }, "cell_type": "code", "source": [ "def corr2d(X,K):\n", " h,w=K.shape\n", " Y=torch.ones((X.shape[0]-h+1,X.shape[1]-w+1))\n", " for i in range(Y.shape[0]):\n", " for j in range(Y.shape[1]):\n", " Y[i,j]=(X[i:i+h,j:j+w]*K).sum()\n", " return Y\n" ], "id": "d45f9adfe47fce20", "outputs": [], "execution_count": 28 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-22T08:08:33.174646764Z", "start_time": "2026-03-22T08:08:33.092115317Z" } }, "cell_type": "code", "source": [ "X = torch.tensor([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]])\n", "K = torch.tensor([[0.0, 1.0], [2.0, 3.0]])\n", "corr2d(X,K)" ], "id": "db7279e13647c315", "outputs": [ { "data": { "text/plain": [ "tensor([[19., 25.],\n", " [37., 43.]])" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 29 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-22T08:08:33.292523218Z", "start_time": "2026-03-22T08:08:33.234892901Z" } }, "cell_type": "code", "source": [ "class Conv2D(nn.Module):\n", " def __init__(self, kernel_size):\n", " super().__init__()\n", " self.weight = nn.Parameter(torch.rand(kernel_size))\n", " self.bias = nn.Parameter(torch.zeros(1))\n", " def forward(self, x):\n", " return corr2d(x, self.weight) + self.bias\n" ], "id": "d60be1bd12a1f37e", "outputs": [], "execution_count": 30 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-22T08:08:33.497922435Z", "start_time": "2026-03-22T08:08:33.387838879Z" } }, "cell_type": "code", "source": [ "X = torch.ones((6, 8))\n", "X[:, 2:6] = 0\n", "X" ], "id": "5083789b7a728442", "outputs": [ { "data": { "text/plain": [ "tensor([[1., 1., 0., 0., 0., 0., 1., 1.],\n", " [1., 1., 0., 0., 0., 0., 1., 1.],\n", " [1., 1., 0., 0., 0., 0., 1., 1.],\n", " [1., 1., 0., 0., 0., 0., 1., 1.],\n", " [1., 1., 0., 0., 0., 0., 1., 1.],\n", " [1., 1., 0., 0., 0., 0., 1., 1.]])" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 31 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-22T08:08:33.751910360Z", "start_time": "2026-03-22T08:08:33.561038017Z" } }, "cell_type": "code", "source": [ "K = torch.tensor([[1.0, -1.0]])\n", "Y = corr2d(X, K)\n", "Y" ], "id": "ee8d6bedbde886ad", "outputs": [ { "data": { "text/plain": [ "tensor([[ 0., 1., 0., 0., 0., -1., 0.],\n", " [ 0., 1., 0., 0., 0., -1., 0.],\n", " [ 0., 1., 0., 0., 0., -1., 0.],\n", " [ 0., 1., 0., 0., 0., -1., 0.],\n", " [ 0., 1., 0., 0., 0., -1., 0.],\n", " [ 0., 1., 0., 0., 0., -1., 0.]])" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 32 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-22T08:08:34.156019672Z", "start_time": "2026-03-22T08:08:33.891686033Z" } }, "cell_type": "code", "source": "corr2d(X.t(), K)", "id": "a8278c3837fa9a1c", "outputs": [ { "data": { "text/plain": [ "tensor([[0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0.]])" ] }, "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 33 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-22T08:08:34.438245033Z", "start_time": "2026-03-22T08:08:34.313636464Z" } }, "cell_type": "code", "source": "conv2d = nn.Conv2d(1,1, kernel_size=(1, 2), bias=False)", "id": "ec61cdb61a8cabff", "outputs": [], "execution_count": 34 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-22T08:08:34.619825611Z", "start_time": "2026-03-22T08:08:34.507329239Z" } }, "cell_type": "code", "source": [ "X = X.reshape((1, 1, 6, 8))\n", "Y = Y.reshape((1, 1, 6, 7))\n", "lr = 3e-2" ], "id": "d2fc19d84c79a10", "outputs": [], "execution_count": 35 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-22T08:08:35.980584379Z", "start_time": "2026-03-22T08:08:34.640606389Z" } }, "cell_type": "code", "source": [ "for i in range(100):\n", " Y_hat = conv2d(X)\n", " l = (Y_hat - Y) ** 2\n", " conv2d.zero_grad()\n", " l.sum().backward()\n", " # 迭代卷积核\n", " conv2d.weight.data[:] -= lr * conv2d.weight.grad\n", " if (i + 1) % 20 == 0:\n", " print(f'epoch {i+1}, loss {l.sum():.3f}')" ], "id": "51fbb2e6398a9bd5", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch 20, loss 0.000\n", "epoch 40, loss 0.000\n", "epoch 60, loss 0.000\n", "epoch 80, loss 0.000\n", "epoch 100, loss 0.000\n" ] } ], "execution_count": 36 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-22T08:08:36.147166039Z", "start_time": "2026-03-22T08:08:36.070619821Z" } }, "cell_type": "code", "source": "conv2d.weight.data.reshape((1, 2))\n", "id": "bf53a423f429dfe4", "outputs": [ { "data": { "text/plain": [ "tensor([[ 1.0000, -1.0000]])" ] }, "execution_count": 37, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 37 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-22T08:08:36.349805416Z", "start_time": "2026-03-22T08:08:36.243749022Z" } }, "cell_type": "code", "source": [ "\n", "# 为了方便起见,我们定义了一个计算卷积层的函数。\n", "# 此函数初始化卷积层权重,并对输入和输出提高和缩减相应的维数\n", "def comp_conv2d(conv2d, X):\n", "# 这里的(1,1)表示批量大小和通道数都是1\n", " X = X.reshape((1, 1) + X.shape)\n", " Y = conv2d(X)\n", " # 省略前两个维度:批量大小和通道\n", " return Y.reshape(Y.shape[2:])\n", "# 请注意,这里每边都填充了1行或1列,因此总共添加了2行或2列\n", "conv2d = nn.Conv2d(1, 1, kernel_size=3, padding=1)" ], "id": "77b61d8c9a2363cc", "outputs": [], "execution_count": 38 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-22T08:08:36.565573176Z", "start_time": "2026-03-22T08:08:36.473905736Z" } }, "cell_type": "code", "source": [ "X = torch.rand(size=(8, 8))\n", "comp_conv2d(conv2d, X).shape" ], "id": "beda6ffa67ec2677", "outputs": [ { "data": { "text/plain": [ "torch.Size([8, 8])" ] }, "execution_count": 39, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 39 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-22T08:08:36.629064391Z", "start_time": "2026-03-22T08:08:36.577410135Z" } }, "cell_type": "code", "source": [ "conv2d = nn.Conv2d(1, 1, kernel_size=(5, 3), padding=(2, 1))\n", "comp_conv2d(conv2d, X).shape" ], "id": "8c51095daea1432d", "outputs": [ { "data": { "text/plain": [ "torch.Size([8, 8])" ] }, "execution_count": 40, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 40 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-22T08:08:36.778747354Z", "start_time": "2026-03-22T08:08:36.631642133Z" } }, "cell_type": "code", "source": [ "conv2d = nn.Conv2d(1, 1, kernel_size=3, padding=1, stride=2)\n", "comp_conv2d(conv2d, X).shape" ], "id": "581bf1b15162cbf6", "outputs": [ { "data": { "text/plain": [ "torch.Size([4, 4])" ] }, "execution_count": 41, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 41 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-22T08:08:36.912486341Z", "start_time": "2026-03-22T08:08:36.816037554Z" } }, "cell_type": "code", "source": [ "conv2d = nn.Conv2d(1, 1, kernel_size=(3, 5), padding=(0, 1), stride=(3, 4))\n", "comp_conv2d(conv2d, X).shape" ], "id": "6f7a2411247baff0", "outputs": [ { "data": { "text/plain": [ "torch.Size([2, 2])" ] }, "execution_count": 42, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 42 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-22T08:09:37.230541174Z", "start_time": "2026-03-22T08:09:37.139260256Z" } }, "cell_type": "code", "source": [ "def corr2d_multi_in(X,K):\n", " return sum(corr2d(x,k) for x,k in zip(X,K))\n", "X = torch.tensor([[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]],\n", "[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]])\n", "K = torch.tensor([[[0.0, 1.0], [2.0, 3.0]], [[1.0, 2.0], [3.0, 4.0]]])\n", "corr2d_multi_in(X, K)" ], "id": "7ac0f17f97b2daa8", "outputs": [ { "data": { "text/plain": [ "tensor([[ 56., 72.],\n", " [104., 120.]])" ] }, "execution_count": 50, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 50 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-22T08:09:01.390798830Z", "start_time": "2026-03-22T08:09:01.334900206Z" } }, "cell_type": "code", "source": [ "def corr2d_multi_in_out(X,K) ->torch.Tensor :\n", " return torch.stack([corr2d_multi_in(X,k) for k in K],0)\n" ], "id": "d409110d0d6b4b49", "outputs": [], "execution_count": 47 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-22T08:09:39.731392821Z", "start_time": "2026-03-22T08:09:39.608604541Z" } }, "cell_type": "code", "source": [ "K = torch.stack((K, K + 1, K + 2), 0)\n", "K.shape" ], "id": "4114cd871a627075", "outputs": [ { "data": { "text/plain": [ "torch.Size([3, 2, 2, 2])" ] }, "execution_count": 51, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 51 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-22T08:09:43.297502870Z", "start_time": "2026-03-22T08:09:43.186648920Z" } }, "cell_type": "code", "source": "corr2d_multi_in_out(X, K)", "id": "ce52f41dc9585f8c", "outputs": [ { "data": { "text/plain": [ "tensor([[[ 56., 72.],\n", " [104., 120.]],\n", "\n", " [[ 76., 100.],\n", " [148., 172.]],\n", "\n", " [[ 96., 128.],\n", " [192., 224.]]])" ] }, "execution_count": 52, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 52 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-22T08:23:20.460754568Z", "start_time": "2026-03-22T08:23:20.424813665Z" } }, "cell_type": "code", "source": [ "def corr2d_multi_in_out_1x1(X, K):\n", " h_i,h,w=X.shape\n", " h_o=K.shape[0]\n", " X=X.reshape((h_i,h*w))\n", " print(X.shape)\n", " K=K.reshape((h_o,h_i))\n", " print(K.shape)\n", " Y=torch.matmul(K,X)\n", " return Y.reshape((h_o,h,w))" ], "id": "362d8c692b3c1d75", "outputs": [], "execution_count": 56 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-22T08:23:21.690978973Z", "start_time": "2026-03-22T08:23:21.638506037Z" } }, "cell_type": "code", "source": [ "X = torch.normal(0, 1, (3, 3, 3))\n", "K = torch.normal(0, 1, (2, 3, 1, 1))" ], "id": "28e761f677df8b16", "outputs": [], "execution_count": 57 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-22T08:23:22.844449890Z", "start_time": "2026-03-22T08:23:22.694694019Z" } }, "cell_type": "code", "source": "Y1 = corr2d_multi_in_out_1x1(X, K)", "id": "8eb276fed751a6b9", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([3, 9])\n", "torch.Size([2, 3])\n" ] } ], "execution_count": 58 }, { "metadata": { "ExecuteTime": { "end_time": "2026-03-22T08:24:53.417955565Z", "start_time": "2026-03-22T08:24:53.297421833Z" } }, "cell_type": "code", "source": [ "Y2 = corr2d_multi_in_out(X, K)\n", "assert float(torch.abs(Y1 - Y2).sum()) < 1e-6" ], "id": "be28e27d30f36e2c", "outputs": [], "execution_count": 59 }, { "metadata": {}, "cell_type": "code", "outputs": [], "execution_count": null, "source": "", "id": "3c3f71349a2e54c0" } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.6" } }, "nbformat": 4, "nbformat_minor": 5 }