nn/test.ipynb

2300 lines
533 KiB
Text
Raw Normal View History

2026-03-12 07:50:33 +00:00
{
"cells": [
{
"cell_type": "code",
"id": "initial_id",
"metadata": {
"collapsed": true,
"ExecuteTime": {
2026-03-15 06:42:56 +00:00
"end_time": "2026-03-15T06:40:53.937066295Z",
"start_time": "2026-03-15T06:40:52.962818496Z"
2026-03-12 07:50:33 +00:00
}
},
"source": [
"import torch\n",
"import numpy\n",
"import pandas\n",
2026-03-14 11:51:56 +00:00
"from sympy.physics.control.control_plots import matplotlib\n",
"from torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook import batched_powerSGD_hook\n",
2026-03-12 07:50:33 +00:00
"\n"
],
"outputs": [],
2026-03-14 11:51:56 +00:00
"execution_count": 1
2026-03-12 07:50:33 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-03-15 06:42:56 +00:00
"end_time": "2026-03-15T06:40:53.993520411Z",
"start_time": "2026-03-15T06:40:53.939990190Z"
2026-03-12 07:50:33 +00:00
}
},
"cell_type": "code",
"source": "torch.randn(3,4,2)",
"id": "3e141a42d342fa96",
"outputs": [
{
"data": {
"text/plain": [
2026-03-15 06:42:56 +00:00
"tensor([[[ 0.5509, -1.6216],\n",
" [ 0.1083, 0.4464],\n",
" [ 1.8819, 0.4029],\n",
" [-0.0733, 2.6961]],\n",
2026-03-12 07:50:33 +00:00
"\n",
2026-03-15 06:42:56 +00:00
" [[-2.0316, 0.7172],\n",
" [-0.3774, 0.5248],\n",
" [-0.0134, 0.3256],\n",
" [ 0.3433, 0.1697]],\n",
2026-03-12 07:50:33 +00:00
"\n",
2026-03-15 06:42:56 +00:00
" [[ 1.1434, 0.6595],\n",
" [ 0.2386, -0.6560],\n",
" [ 1.3177, -0.6876],\n",
" [-1.0916, -0.6199]]])"
2026-03-14 11:51:56 +00:00
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 2
},
{
"metadata": {
"ExecuteTime": {
2026-03-15 06:42:56 +00:00
"end_time": "2026-03-15T06:40:54.023331390Z",
"start_time": "2026-03-15T06:40:53.996610811Z"
2026-03-14 11:51:56 +00:00
}
},
"cell_type": "code",
"source": [
"X = torch.arange(12, dtype=torch.float32).reshape((3,4))\n",
"Y = torch.tensor([[2.0, 1, 4, 3], [1, 2, 3, 4], [4, 3, 2, 1]])\n",
"torch.cat((X, Y), dim=0), torch.cat((X, Y), dim=1)"
],
"id": "8ae20ae68abbf32f",
"outputs": [
{
"data": {
"text/plain": [
"(tensor([[ 0., 1., 2., 3.],\n",
" [ 4., 5., 6., 7.],\n",
" [ 8., 9., 10., 11.],\n",
" [ 2., 1., 4., 3.],\n",
" [ 1., 2., 3., 4.],\n",
" [ 4., 3., 2., 1.]]),\n",
" tensor([[ 0., 1., 2., 3., 2., 1., 4., 3.],\n",
" [ 4., 5., 6., 7., 1., 2., 3., 4.],\n",
" [ 8., 9., 10., 11., 4., 3., 2., 1.]]))"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 3
},
{
"metadata": {
"ExecuteTime": {
2026-03-15 06:42:56 +00:00
"end_time": "2026-03-15T06:40:54.081323347Z",
"start_time": "2026-03-15T06:40:54.028230344Z"
2026-03-14 11:51:56 +00:00
}
},
"cell_type": "code",
"source": [
"a = torch.arange(3).reshape((3, 1))\n",
"b = torch.arange(2).reshape((1, 2))\n",
"a, b\n",
"a+b"
],
"id": "2960a1ded2cdd5a4",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0, 1],\n",
" [1, 2],\n",
" [2, 3]])"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 4
},
{
"metadata": {
"ExecuteTime": {
2026-03-15 06:42:56 +00:00
"end_time": "2026-03-15T06:40:54.538103014Z",
"start_time": "2026-03-15T06:40:54.136512769Z"
2026-03-14 11:51:56 +00:00
}
},
"cell_type": "code",
"source": "X[-1], X[1:3]\n",
"id": "69c2ec23ab6ae97c",
"outputs": [
{
"data": {
"text/plain": [
"(tensor([ 8., 9., 10., 11.]),\n",
" tensor([[ 4., 5., 6., 7.],\n",
" [ 8., 9., 10., 11.]]))"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 5
},
{
"metadata": {
"ExecuteTime": {
2026-03-15 06:42:56 +00:00
"end_time": "2026-03-15T06:40:54.691375092Z",
"start_time": "2026-03-15T06:40:54.563788479Z"
2026-03-14 11:51:56 +00:00
}
},
"cell_type": "code",
"source": [
"A = X.numpy()\n",
"B = torch.tensor(A)\n",
"type(A), type(B)"
],
"id": "b8d779a1bc7e4b1a",
"outputs": [
{
"data": {
"text/plain": [
"(numpy.ndarray, torch.Tensor)"
2026-03-12 07:50:33 +00:00
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 6
},
{
"metadata": {
"ExecuteTime": {
2026-03-15 06:42:56 +00:00
"end_time": "2026-03-15T06:40:54.732824469Z",
"start_time": "2026-03-15T06:40:54.714530988Z"
2026-03-12 07:50:33 +00:00
}
},
"cell_type": "code",
2026-03-14 11:51:56 +00:00
"source": [
"import os\n",
"os.makedirs(os.path.join(\"..\",\"data\"),exist_ok=True)\n",
"data_file = os.path.join(os.path.join(\"..\",\"data\",\"data.csv\"))\n",
"with open(data_file, \"w\") as f:\n",
" f.write('NumRooms,Alley,Price\\n') # 列名\n",
" f.write('NA,Pave,127500\\n') # 每行表示一个数据样本\n",
" f.write('2,NA,106000\\n')\n",
" f.write('4,NA,178100\\n')\n",
" f.write('NA,NA,140000\\n')\n",
"\n"
],
"id": "82be028b0f1dd1e3",
2026-03-12 07:50:33 +00:00
"outputs": [],
2026-03-14 11:51:56 +00:00
"execution_count": 7
},
{
"metadata": {
"ExecuteTime": {
2026-03-15 06:42:56 +00:00
"end_time": "2026-03-15T06:40:54.811604177Z",
"start_time": "2026-03-15T06:40:54.734220773Z"
2026-03-14 11:51:56 +00:00
}
},
"cell_type": "code",
"source": [
"import pandas as pd\n",
"data = pd.read_csv(data_file)\n",
"print(data)\n"
],
"id": "ddd789a2656899d1",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" NumRooms Alley Price\n",
"0 NaN Pave 127500\n",
"1 2.0 NaN 106000\n",
"2 4.0 NaN 178100\n",
"3 NaN NaN 140000\n"
]
}
],
"execution_count": 8
},
{
"metadata": {
"ExecuteTime": {
2026-03-15 06:42:56 +00:00
"end_time": "2026-03-15T06:40:54.853161280Z",
"start_time": "2026-03-15T06:40:54.813874506Z"
2026-03-14 11:51:56 +00:00
}
},
"cell_type": "code",
"source": [
"inputs, outputs = data.iloc[:, 0:2], data.iloc[:, 2]\n",
"\n",
"\n",
"inputs = pd.get_dummies(inputs, dummy_na=True)\n",
"print(inputs)\n",
"inputs = inputs.fillna(inputs.mean())\n",
"print(inputs)\n"
],
"id": "e98fcc3bd4f067cf",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" NumRooms Alley_Pave Alley_nan\n",
"0 NaN True False\n",
"1 2.0 False True\n",
"2 4.0 False True\n",
"3 NaN False True\n",
" NumRooms Alley_Pave Alley_nan\n",
"0 3.0 True False\n",
"1 2.0 False True\n",
"2 4.0 False True\n",
"3 3.0 False True\n"
]
}
],
"execution_count": 9
},
{
"metadata": {
"ExecuteTime": {
2026-03-15 06:42:56 +00:00
"end_time": "2026-03-15T06:40:54.885351255Z",
"start_time": "2026-03-15T06:40:54.855249881Z"
2026-03-14 11:51:56 +00:00
}
},
"cell_type": "code",
"source": [
"X = torch.tensor(inputs.to_numpy(dtype=float))\n",
"y = torch.tensor(outputs.to_numpy(dtype=float))\n",
"X, y\n"
],
"id": "8ff0f7b40f0e4996",
"outputs": [
{
"data": {
"text/plain": [
"(tensor([[3., 1., 0.],\n",
" [2., 0., 1.],\n",
" [4., 0., 1.],\n",
" [3., 0., 1.]], dtype=torch.float64),\n",
" tensor([127500., 106000., 178100., 140000.], dtype=torch.float64))"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 10
},
{
"metadata": {
"ExecuteTime": {
2026-03-15 06:42:56 +00:00
"end_time": "2026-03-15T06:40:54.918539225Z",
"start_time": "2026-03-15T06:40:54.887053932Z"
2026-03-14 11:51:56 +00:00
}
},
"cell_type": "code",
"source": [
"B=torch.tensor([[1,2,3],[2,0,4],[3,4,5]])\n",
"B"
],
"id": "91a6e0da442b95a0",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[1, 2, 3],\n",
" [2, 0, 4],\n",
" [3, 4, 5]])"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 11
},
{
"metadata": {
"ExecuteTime": {
2026-03-15 06:42:56 +00:00
"end_time": "2026-03-15T06:40:55.091544877Z",
"start_time": "2026-03-15T06:40:54.972764187Z"
2026-03-14 11:51:56 +00:00
}
},
"cell_type": "code",
"source": "B==B.T",
"id": "297e6a678fb19be7",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[True, True, True],\n",
" [True, True, True],\n",
" [True, True, True]])"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 12
},
{
"metadata": {
"ExecuteTime": {
2026-03-15 06:42:56 +00:00
"end_time": "2026-03-15T06:40:55.318112438Z",
"start_time": "2026-03-15T06:40:55.208719181Z"
2026-03-14 11:51:56 +00:00
}
},
"cell_type": "code",
"source": [
"X=torch.arange(24).reshape(2,3,4)\n",
"X"
],
"id": "24e864b336beb58b",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[ 0, 1, 2, 3],\n",
" [ 4, 5, 6, 7],\n",
" [ 8, 9, 10, 11]],\n",
"\n",
" [[12, 13, 14, 15],\n",
" [16, 17, 18, 19],\n",
" [20, 21, 22, 23]]])"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 13
},
{
"metadata": {
"ExecuteTime": {
2026-03-15 06:42:56 +00:00
"end_time": "2026-03-15T06:40:55.396075402Z",
"start_time": "2026-03-15T06:40:55.345582838Z"
2026-03-14 11:51:56 +00:00
}
},
"cell_type": "code",
"source": [
"A = torch.arange(20, dtype=torch.float32).reshape(5, 4)\n",
"B = A.clone() # 通过分配新内存将A的一个副本分配给B\n",
"A, A + B\n",
"#A = torch.arange(20, dtype=torch.float32).reshape(5, 4)\n",
"#B = A # 通过分配新内存将A的一个副本分配给B\n",
"id(A),id(B)"
],
"id": "ee0905479b1dbc2b",
"outputs": [
{
"data": {
"text/plain": [
2026-03-15 06:42:56 +00:00
"(140556050244048, 140556050244432)"
2026-03-14 11:51:56 +00:00
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 14
},
{
"metadata": {},
"cell_type": "markdown",
"source": "Hadamard乘积",
"id": "136459f5efe765cf"
},
{
"metadata": {
"ExecuteTime": {
2026-03-15 06:42:56 +00:00
"end_time": "2026-03-15T06:40:55.435745645Z",
"start_time": "2026-03-15T06:40:55.410554891Z"
2026-03-14 11:51:56 +00:00
}
},
"cell_type": "code",
"source": "A*B",
"id": "f576b0df17cc0e98",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 0., 1., 4., 9.],\n",
" [ 16., 25., 36., 49.],\n",
" [ 64., 81., 100., 121.],\n",
" [144., 169., 196., 225.],\n",
" [256., 289., 324., 361.]])"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 15
},
{
"metadata": {
"ExecuteTime": {
2026-03-15 06:42:56 +00:00
"end_time": "2026-03-15T06:40:55.547870490Z",
"start_time": "2026-03-15T06:40:55.455797448Z"
2026-03-14 11:51:56 +00:00
}
},
"cell_type": "code",
"source": [
"a=2\n",
"X=torch.arange(24).reshape(2,3,4)\n",
"a+X,(a*X).shape"
],
"id": "b2373af1d7f2a45",
"outputs": [
{
"data": {
"text/plain": [
"(tensor([[[ 2, 3, 4, 5],\n",
" [ 6, 7, 8, 9],\n",
" [10, 11, 12, 13]],\n",
" \n",
" [[14, 15, 16, 17],\n",
" [18, 19, 20, 21],\n",
" [22, 23, 24, 25]]]),\n",
" torch.Size([2, 3, 4]))"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 16
},
{
"metadata": {
"ExecuteTime": {
2026-03-15 06:42:56 +00:00
"end_time": "2026-03-15T06:40:55.646393885Z",
"start_time": "2026-03-15T06:40:55.571643582Z"
2026-03-14 11:51:56 +00:00
}
},
"cell_type": "code",
"source": [
"print(A)\n",
"A_sum_axis0=A.sum(axis=0)\n",
"A_sum_axis1=A.sum(axis=1)\n",
"A_sum_axis0,A_sum_axis1"
],
"id": "2b50246e1ca8a3bc",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[ 0., 1., 2., 3.],\n",
" [ 4., 5., 6., 7.],\n",
" [ 8., 9., 10., 11.],\n",
" [12., 13., 14., 15.],\n",
" [16., 17., 18., 19.]])\n"
]
},
{
"data": {
"text/plain": [
"(tensor([40., 45., 50., 55.]), tensor([ 6., 22., 38., 54., 70.]))"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 17
},
{
"metadata": {
"ExecuteTime": {
2026-03-15 06:42:56 +00:00
"end_time": "2026-03-15T06:40:55.715281021Z",
"start_time": "2026-03-15T06:40:55.664043259Z"
2026-03-14 11:51:56 +00:00
}
},
"cell_type": "code",
"source": [
"x=torch.arange(4,dtype=torch.float32)\n",
"torch.mv(A,x)"
],
"id": "3195464dfeb554ed",
"outputs": [
{
"data": {
"text/plain": [
"tensor([ 14., 38., 62., 86., 110.])"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 18
},
{
"metadata": {
"ExecuteTime": {
2026-03-15 06:42:56 +00:00
"end_time": "2026-03-15T06:40:55.752354225Z",
"start_time": "2026-03-15T06:40:55.717824570Z"
2026-03-14 11:51:56 +00:00
}
},
"cell_type": "code",
"source": [
"import time\n",
"\n",
"def showtime(func):\n",
" def wrapper():\n",
" start = time.time()\n",
" result = func() # 执行原始函数\n",
" end = time.time()\n",
" print(f\"执行时间: {end - start:.6f}秒\")\n",
" return result\n",
" return wrapper # 返回包装函数\n",
"\n",
"@showtime\n",
"def fun():\n",
" print(\"I am silly\")\n",
"\n",
"fun()\n"
],
"id": "ebda8c74ead3e42b",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"I am silly\n",
2026-03-15 06:42:56 +00:00
"执行时间: 0.000183秒\n"
2026-03-14 11:51:56 +00:00
]
}
],
"execution_count": 19
},
{
"metadata": {
"ExecuteTime": {
2026-03-15 06:42:56 +00:00
"end_time": "2026-03-15T06:40:55.822293071Z",
"start_time": "2026-03-15T06:40:55.771678228Z"
2026-03-14 11:51:56 +00:00
}
},
"cell_type": "code",
"source": "torch.norm(torch.ones((4, 9)))",
"id": "3343cc0c01d0161c",
"outputs": [
{
"data": {
"text/plain": [
"tensor(6.)"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 20
},
{
"metadata": {
"ExecuteTime": {
2026-03-15 06:42:56 +00:00
"end_time": "2026-03-15T06:40:55.831732982Z",
"start_time": "2026-03-15T06:40:55.824867487Z"
2026-03-14 11:51:56 +00:00
}
},
"cell_type": "code",
"source": [
"x =torch.arange(4.0,requires_grad=True)\n",
"x.grad"
],
"id": "674e2416e9417cfe",
"outputs": [],
"execution_count": 21
},
{
"metadata": {
"ExecuteTime": {
2026-03-15 06:42:56 +00:00
"end_time": "2026-03-15T06:40:55.870751082Z",
"start_time": "2026-03-15T06:40:55.832780649Z"
2026-03-14 11:51:56 +00:00
}
},
"cell_type": "code",
"source": [
"y=2*torch.dot(x,x)\n",
"y"
],
"id": "66c0febebcf98cde",
"outputs": [
{
"data": {
"text/plain": [
"tensor(28., grad_fn=<MulBackward0>)"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 22
},
{
"metadata": {
"ExecuteTime": {
2026-03-15 06:42:56 +00:00
"end_time": "2026-03-15T06:40:55.897976833Z",
"start_time": "2026-03-15T06:40:55.872431759Z"
2026-03-14 11:51:56 +00:00
}
},
"cell_type": "code",
"source": [
"y.backward()\n",
"x.grad"
],
"id": "825f2ce6c46ca4a8",
"outputs": [
{
"data": {
"text/plain": [
"tensor([ 0., 4., 8., 12.])"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 23
},
{
"metadata": {
"ExecuteTime": {
2026-03-15 06:42:56 +00:00
"end_time": "2026-03-15T06:40:55.913347196Z",
"start_time": "2026-03-15T06:40:55.901720461Z"
2026-03-14 11:51:56 +00:00
}
},
"cell_type": "code",
"source": [
"x.grad.zero_()\n",
"y = x.sum()\n",
"y.backward()\n",
"x.grad\n"
],
"id": "df399463515e9d3c",
"outputs": [
{
"data": {
"text/plain": [
"tensor([1., 1., 1., 1.])"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 24
},
{
"metadata": {
"ExecuteTime": {
2026-03-15 06:42:56 +00:00
"end_time": "2026-03-15T06:40:55.985375574Z",
"start_time": "2026-03-15T06:40:55.929676346Z"
2026-03-14 11:51:56 +00:00
}
},
"cell_type": "code",
"source": [
"# 对非标量调用backward需要传入一个gradient参数该参数指定微分函数关于self的梯度。\n",
"# 本例只想求偏导数的和所以传递一个1的梯度是合适的\n",
"x.grad.zero_()\n",
"y = x * x\n",
"# 等价于y.backward(torch.ones(len(x)))\n",
"print(y)\n",
"y.sum().backward()\n",
"x.grad"
],
"id": "f9207619bd4b3de8",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([0., 1., 4., 9.], grad_fn=<MulBackward0>)\n"
]
},
{
"data": {
"text/plain": [
"tensor([0., 2., 4., 6.])"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 25
},
{
"metadata": {
"ExecuteTime": {
2026-03-15 06:42:56 +00:00
"end_time": "2026-03-15T06:40:56.018025749Z",
"start_time": "2026-03-15T06:40:55.987257926Z"
2026-03-14 11:51:56 +00:00
}
},
"cell_type": "code",
"source": "torch.ones(len(x))",
"id": "409c14c230570859",
"outputs": [
{
"data": {
"text/plain": [
"tensor([1., 1., 1., 1.])"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 26
},
{
"metadata": {
"ExecuteTime": {
2026-03-15 06:42:56 +00:00
"end_time": "2026-03-15T06:40:56.046029736Z",
"start_time": "2026-03-15T06:40:56.019720336Z"
2026-03-14 11:51:56 +00:00
}
},
"cell_type": "code",
"source": [
"x.grad.zero_()\n",
"y=x*x\n",
"u=y.detach()\n",
"z=u*x\n",
"z.sum().backward()\n",
"x.grad==u"
],
"id": "521b948fe0683b12",
"outputs": [
{
"data": {
"text/plain": [
"tensor([True, True, True, True])"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 27
},
{
"metadata": {
"ExecuteTime": {
2026-03-15 06:42:56 +00:00
"end_time": "2026-03-15T06:40:56.086132664Z",
"start_time": "2026-03-15T06:40:56.049185612Z"
2026-03-14 11:51:56 +00:00
}
},
"cell_type": "code",
"source": [
"x.grad.zero_()\n",
"y.sum().backward()\n",
"x.grad==2*x"
],
"id": "b040beecf0632315",
"outputs": [
{
"data": {
"text/plain": [
"tensor([True, True, True, True])"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 28
},
{
"metadata": {
"ExecuteTime": {
2026-03-15 06:42:56 +00:00
"end_time": "2026-03-15T06:40:56.121364505Z",
"start_time": "2026-03-15T06:40:56.091095688Z"
2026-03-14 11:51:56 +00:00
}
},
"cell_type": "code",
"source": [
"from torch.distributions import multinomial\n",
"fair_probs=torch.ones([6])\n",
"fair_probs"
],
"id": "4e6ec763dbea5aa3",
"outputs": [
{
"data": {
"text/plain": [
"tensor([1., 1., 1., 1., 1., 1.])"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 29
},
{
"metadata": {
"ExecuteTime": {
2026-03-15 06:42:56 +00:00
"end_time": "2026-03-15T06:40:56.150868516Z",
"start_time": "2026-03-15T06:40:56.126763544Z"
2026-03-14 11:51:56 +00:00
}
},
"cell_type": "code",
"source": "multinomial.Multinomial(1, fair_probs).sample()",
"id": "f12d5e85bc6ab595",
"outputs": [
{
"data": {
"text/plain": [
2026-03-15 06:42:56 +00:00
"tensor([0., 1., 0., 0., 0., 0.])"
2026-03-14 11:51:56 +00:00
]
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 30
},
{
"metadata": {
"ExecuteTime": {
2026-03-15 06:42:56 +00:00
"end_time": "2026-03-15T06:40:56.197988031Z",
"start_time": "2026-03-15T06:40:56.152490623Z"
2026-03-14 11:51:56 +00:00
}
},
"cell_type": "code",
"source": [
"counts = multinomial.Multinomial(10, fair_probs).sample((500,))\n",
"\n",
"cum_counts = counts.cumsum(dim=0)\n",
"cum_counts.size()"
],
"id": "b02f43376fd6f1fe",
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([500, 6])"
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 31
},
{
"metadata": {
"ExecuteTime": {
2026-03-15 06:42:56 +00:00
"end_time": "2026-03-15T06:40:56.313824234Z",
"start_time": "2026-03-15T06:40:56.200143209Z"
2026-03-14 11:51:56 +00:00
}
},
"cell_type": "code",
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"# 假设 estimates 是你的数据张量\n",
"estimates = cum_counts / cum_counts.sum(dim=1, keepdims=True)\n",
"\n",
"# 设置图形大小 (等效于 d2l.set_figsize)\n",
"plt.figure(figsize=(6, 4.5))\n",
"\n",
"# 绘制每条概率曲线\n",
"for i in range(6):\n",
" plt.plot(estimates[:, i].numpy(),\n",
" label=f\"P(die={i + 1})\") # 使用 f-string 更简洁\n",
"\n",
"# 添加理论概率水平线\n",
"plt.axhline(y=0.167, color='black', linestyle='dashed', label='Theoretical probability')\n",
"\n",
"# 设置坐标轴标签\n",
"plt.xlabel('Groups of experiments')\n",
"plt.ylabel('Estimated probability')\n",
"\n",
"# 添加图例\n",
"plt.legend()\n",
"\n",
"# 显示图形\n",
"plt.show()\n",
"#plt.savefig('dice_probability.png', bbox_inches='tight')"
],
"id": "8b80daa4edd0b066",
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 600x450 with 1 Axes>"
],
2026-03-15 06:42:56 +00:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAikAAAGZCAYAAABMsaH2AAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjgsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvwVt1zgAAAAlwSFlzAAAPYQAAD2EBqD+naQAA/CVJREFUeJzs3Xdc1dX/wPHXvXDZl40gKIiKggPce+PONLXcWaYNK0szU8td3/xpubNyj9I00xzlxJm5Udw4EEVB9t6Xy+f3x5UPXgHlKgjmeT4e5yH38zmf8zn3Ive+75kKQEIQBEEQBKGcUZZ1BQRBEARBEAojghRBEARBEMolEaQIgiAIglAuiSBFEARBEIRySQQpgiAIgiCUSyJIEQRBEAShXBJBiiAIgiAI5ZJxWVfgRebq6kpKSkpZV0MQBEEQXjhqtZqIiIjH5hFBylNydXUlPDy8rKshCIIgCC8sNze3xwYqIkh5SnktKG5ubqI1RRAEQRAMoFarCQ8Pf+LnpwhSnlFKSooIUgRBEAShFIiBs4IgCIIglEsiSBEEQRAEoVwSQYogCIIgCOWSGJMiCIJQTBYWFjg6OqJQKMq6KoJQbkmSRGxsLOnp6c9clghSBEEQnkChUDBs2DDatWtX1lURhBfGoUOHWLVqFZIkPXUZIkgRBEF4gmHDhtG2bVs2btxIcHAwOTk5ZV0lQSi3jI2N8fb2pl+/fgCsXLny6csqqUoJgiD8F1laWtKuXTs2btzI33//XdbVEYQXQkhICAD9+/dnw4YNT931IwbOCoIgPIaDgwMAwcHBZVwTQXix5P3NODo6PnUZIkgRBEF4jLxBsqKLRxAMk/c38ywDzctFkPLhhx8SGhpKRkYGJ06coHHjxkXm7d27N6dPnyYhIYHU1FTOnTvHkCFDCuSbPn06ERERpKens2/fPqpXr6533s7Ojl9//ZWkpCQSEhJYvnw5lpaWJf7cBEEQBEF4OmUepPTr14+5c+cyffp0GjRowPnz59mzZw9OTk6F5o+Pj+d///sfzZs3x9fXl1WrVrFq1So6d+4s5/niiy/45JNP+OCDD2jatClpaWns2bMHU1NTOc+6deuoXbs2nTp1okePHrRp04alS5eW+vMVBEEQBKH4pLJMJ06ckBYtWiQ/VigU0r1796Tx48cXu4zAwEBpxowZ8uOIiAhp7Nix8mNra2spIyND6t+/vwRI3t7ekiRJUsOGDeU8Xbp0kbRarVSxYsVi3VOtVkuSJElqtbpEXgeFUimZW1lL5lbWZfr7EEkkkfSTh4eHtHbtWsnDw6PM61Iaae3atdLEiRMfmyc0NFT69NNP5ceSJEm9evUqszo7ODhIUVFRkpubW5m/fiIVnR73t1Pcz9AybUlRqVQ0bNiQgIAA+ZgkSQQEBNC8efNildGhQwdq1qzJkSNHAPD09KRixYp6ZSYnJ3Py5Em5zObNm5OQkEBgYKCcJyAggNzcXJo2bVrofUxMTFCr1XqpJNk6VWTCuqOMWb63RMsVBOHllbdGhSRJZGVlcePGDSZPnoyRkREAvr6+dO/enYULFxpUrouLC7t27SqNKgPw7rvvcvDgQZKSkpAkCRsbG73zcXFxrF27lunTp5daHYTyoUyDFEdHR4yNjYmKitI7HhUVhYuLS5HXWVtbk5KSQnZ2Nn///TejRo2Sg5K86x5XpouLC9HR0XrntVot8fHxRd534sSJJCcnyyk8PNywJ/sEeWvdiJUsBUEoSbt27cLFxQUvLy/mzJnDtGnTGDduHACjRo1i06ZNpKWlGVRmVFQU2dnZpVFdQLey7+7du/n222+LzLNq1SoGDx6MnZ1dqdVDKHtlPiblaaSkpFCvXj0aN27MV199xdy5c2nbtm2p3nPmzJlYW1vLyc3NrUTLl6Rc3Q8iRhGEF4KJudlzT08jKyuLqKgowsLC+PnnnwkICKBnz54olUpef/11duzYoZffycmJ7du3k56ezq1btx
g0aFCBMiVJolevXvLjSpUqsXHjRhISEoiLi2Pr1q14eHg8VX0BFixYwKxZszhx4kSRea5cuUJERAS9e/d+6vsI5V+ZLuYWGxtLTk4Ozs7OesednZ2JjIws8jpJkuSFYs6fP4+Pjw8TJ07k8OHD8nWPluHs7ExQUBAAkZGRVKhQQa9MIyMj7O3ti7xvdnZ2qX5zyGtKUSheyLhREF4qJuZmzDx18Lnfd2KT9mRnZD5TGRkZGTg4OODr64utrS1nzpzRO7969WpcXV1p3749Go2GhQsXFni/fJixsTF79uzh+PHjtG7dmpycHCZNmsTu3bvx9fVFo9EwaNAglixZ8th6devWjaNHjxr0XE6dOkXr1q2faUVToXwr0yBFo9EQGBiIv78/27ZtA3TdHf7+/vzwww/FLkepVMozd0JDQ7l//z7+/v6cP38eALVaTdOmTfnpp58AOH78OHZ2djRo0ICzZ88CurEtSqWSkydPluRTLDaJvCBFNKUIglA6/P396dKlC4sWLcLDw4OcnBy9rm8vLy+6d+9O48aN5eBl+PDhj13Irn///iiVSkaMGCEfGzZsGImJibRr1459+/axffv2J763Pk0XekREBPXr1zf4OuHFUebL4s+dO5c1a9Zw5swZTp06xejRo7G0tGTVqlUArFmzhvDwcL788ksAJkyYwJkzZwgJCcHU1JTu3bvz5ptvMnLkSLnM+fPnM2nSJG7cuEFoaChff/01ERERbN26FdCtgrdr1y6WLVvGBx98gEql4ocffmDDhg3cv3//ub8GAFKuCFIE4UWRnZHJxCbty+S+hurRowcpKSmoVCqUSiXr169n2rRp9OzZk6ysLL28Pj4+8pfHPNeuXSMhIaHI8v38/KhevTopKSl6x83MzKhWrRr79u0jNTWV1NRUg+v+JBkZGVhYWJR4uUL5UeZByu+//46TkxMzZszAxcWFoKAgunbtKkf37u7u5ObmyvktLS358ccfqVSpEhkZGQQHBzNkyBB+//13Oc/s2bOxtLRk6dKl2NracvToUbp27ar3Bzl48GB++OEH9u/fT25uLps3b+aTTz55fk+8AHnkbBnWQRCE4nrWbpfn5eDBg4wcOZLs7GwiIiLQarWArrvd0tISlUqFRqN56vKtrKwIDAxk8ODBBc7FxMQAlFp3j729vXwP4b+pzIMUgMWLF7N48eJCz7Vvr/9tZfLkyUyePPmJZU6dOpWpU6cWeT4hIaHQP6qyIj0IxBRi5KwgCCUoLS1NHsP3sLwxerVq1ZK7xoODg+WlIfK6e2rUqPHYGTRnz56lf//+REdHF2hNyVNa3T116tTh0KFDBl8nvDjEKM1yQnowcFb5YP0CQRCE0hQbG0tgYCCtWrWSj12/fp1du3axZMkSmjRpQoMGDVi+fPljd7Bdt24dsbGxbNu2jVatWlGlShXatm3LggUL5FmQqamphISEPDZlZua3TDk7O8vdSAB169bFz89PL1gyNzenYcOG7N0r1pb6LxNBiiAIwktq+fLlBVqUhw0bRkREBIcPH2bLli0sXbq0wLpSD8vIyKBNmzaEhYWxZcsWrl69yooVKzAzMyM5Ofmp6vXBBx8QFBTE8uXLAfjnn38ICgqiZ8+ecp5evXoRFhZmcBeR8OIp86VzX8RU0sviW6htpenbLkjTt10o8+cmkkgi5af/8rL4ZmZm0p07d6RmzZqVeV0MTcePH5cGDhxY5vUQqej0wi+LL+TLm4IMoFCKX4sgCKUvMzOToUOH4ujoWNZVMYiDgwNbtmzht99+K+uqCKWsXAycFUBeFx/d4FnpMVkFQRBKyuHDh8u6CgaLi4vju+++K+tqCM+B+MpeTkgPBylKMcNHEARBEESQUk48HKSIDXwEQRAEQQQp5Ya8wSBi1VlBEARBABGklB8Pd/eITQYFQRAEQQQp5YV+b49oSREEQRAEEaSUE/rdPWVYEUEQBEEoJ0SQUl6I7h5BEARB0CM+DcsJvdk9oilFEI
TnZO3atUycOPGxeUJDQ/n000/lx5Ik0atXr9KuWpF8fHy4e/cuFhYWZVYH4fkQQUo5obdOighSBEEoAatWrUKSJCR
2026-03-14 11:51:56 +00:00
},
"metadata": {},
"output_type": "display_data",
"jetTransient": {
"display_id": null
}
}
],
"execution_count": 32
},
{
"metadata": {
"ExecuteTime": {
2026-03-15 06:42:56 +00:00
"end_time": "2026-03-15T06:40:56.335068964Z",
"start_time": "2026-03-15T06:40:56.317850591Z"
2026-03-14 11:51:56 +00:00
}
},
"cell_type": "code",
"source": [
"import numpy as np\n",
"class Timer:\n",
" \"\"\"记录多次运行时间\"\"\"\n",
" def __init__(self):\n",
" self.times = []\n",
" self.start()\n",
" def start(self):\n",
" \"\"\"启动计时器\"\"\"\n",
" self.tik = time.time()\n",
" def stop(self):\n",
" \"\"\"停止计时器并将时间记录在列表中\"\"\"\n",
" self.times.append(time.time() - self.tik)\n",
" return self.times[-1]\n",
" def avg(self):\n",
" \"\"\"返回平均时间\"\"\"\n",
" return sum(self.times) / len(self.times)\n",
" def sum(self):\n",
" \"\"\"返回时间总和\"\"\"\n",
" return sum(self.times)\n",
" def cumsum(self):\n",
" \"\"\"返回累计时间\"\"\"\n",
" return np.array(self.times).cumsum().tolist()"
],
"id": "4bdbb4999907154a",
"outputs": [],
"execution_count": 33
},
{
"metadata": {
"ExecuteTime": {
2026-03-15 06:42:56 +00:00
"end_time": "2026-03-15T06:40:56.450342214Z",
"start_time": "2026-03-15T06:40:56.336961011Z"
2026-03-14 11:51:56 +00:00
}
},
"cell_type": "code",
"source": [
"n = 10000\n",
"a = torch.ones([n])\n",
"b = torch.ones([n])\n",
"c=torch.zeros(n)\n",
"timer = Timer()\n",
"for i in range(n):\n",
" c[i]=a[i]+b[i]\n",
"f'{timer.stop():.5f} sec'"
],
"id": "c6f71622e2cc578a",
"outputs": [
{
"data": {
"text/plain": [
2026-03-15 06:42:56 +00:00
"'0.03117 sec'"
2026-03-14 11:51:56 +00:00
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 34
},
{
"metadata": {
"ExecuteTime": {
2026-03-15 06:42:56 +00:00
"end_time": "2026-03-15T06:40:56.498587786Z",
"start_time": "2026-03-15T06:40:56.454031052Z"
2026-03-14 11:51:56 +00:00
}
},
"cell_type": "code",
"source": [
"timer.start()\n",
"d=a+b\n",
"f'{timer.stop():.5f} sec'"
],
"id": "2578c79b1214a79f",
"outputs": [
{
"data": {
"text/plain": [
2026-03-15 06:42:56 +00:00
"'0.00041 sec'"
2026-03-14 11:51:56 +00:00
]
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 35
},
{
"metadata": {
"ExecuteTime": {
2026-03-15 06:42:56 +00:00
"end_time": "2026-03-15T06:40:56.507308186Z",
"start_time": "2026-03-15T06:40:56.500119417Z"
2026-03-14 11:51:56 +00:00
}
},
"cell_type": "code",
"source": [
"import math\n",
"def normal(x, mu, sigma):\n",
" p = 1 / math.sqrt(2 * math.pi * sigma**2)\n",
" return p * np.exp(-0.5 / sigma**2 * (x - mu)**2)"
],
"id": "fd17fdbe38a5f79",
"outputs": [],
"execution_count": 36
},
{
"metadata": {
"ExecuteTime": {
2026-03-15 06:42:56 +00:00
"end_time": "2026-03-15T06:40:56.533559761Z",
"start_time": "2026-03-15T06:40:56.508247056Z"
2026-03-14 11:51:56 +00:00
}
},
"cell_type": "code",
"source": [
"from matplotlib_inline import backend_inline\n",
"def use_svg_display(): #@save\n",
" \"\"\"使用svg格式在Jupyter中显示绘图\"\"\"\n",
" backend_inline.set_matplotlib_formats('svg')\n",
"def set_figsize(figsize=(3.5, 2.5)): #@save\n",
" \"\"\"设置matplotlib的图表大小\"\"\"\n",
" use_svg_display()\n",
" plt.rcParams['figure.figsize'] = figsize\n",
"def set_axes(axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend):\n",
" \"\"\"设置matplotlib的轴\"\"\"\n",
" axes.set_xlabel(xlabel)\n",
" axes.set_ylabel(ylabel)\n",
" axes.set_xscale(xscale)\n",
" axes.set_yscale(yscale)\n",
" axes.set_xlim(xlim)\n",
" axes.set_ylim(ylim)\n",
" if legend:\n",
" axes.legend(legend)\n",
" axes.grid()\n",
"def plot(X, Y=None, xlabel=None, ylabel=None, legend=None, xlim=None,\n",
"ylim=None, xscale='linear', yscale='linear',\n",
"fmts=('-', 'm--', 'g-.', 'r:'), figsize=(3.5, 2.5), axes=None):\n",
" \"\"\"绘制数据点\"\"\"\n",
" if legend is None:\n",
" legend = []\n",
" set_figsize(figsize)\n",
" axes = axes if axes else plt.gca()\n",
" # 如果X有一个轴输出True\n",
" def has_one_axis(X):\n",
" return (hasattr(X, \"ndim\") and X.ndim == 1 or isinstance(X, list)\n",
"and not hasattr(X[0], \"__len__\"))\n",
" if has_one_axis(X):\n",
" X = [X]\n",
" if Y is None:\n",
" X, Y = [[]] * len(X), X\n",
" elif has_one_axis(Y):\n",
" Y = [Y]\n",
" if len(X) != len(Y):\n",
" X = X * len(Y)\n",
" axes.cla()\n",
" for x, y, fmt in zip(X, Y, fmts):\n",
" if len(x):\n",
" axes.plot(x, y, fmt)\n",
" else:\n",
" axes.plot(y, fmt)\n",
" set_axes(axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend)"
],
"id": "82158a69cba14da0",
"outputs": [],
"execution_count": 37
},
{
"metadata": {
"ExecuteTime": {
2026-03-15 06:42:56 +00:00
"end_time": "2026-03-15T06:40:56.625694875Z",
"start_time": "2026-03-15T06:40:56.535032024Z"
2026-03-14 11:51:56 +00:00
}
},
"cell_type": "code",
"source": [
"# 再次使用numpy进行可视化\n",
"x = np.arange(-7, 7, 0.01)\n",
"# 均值和标准差对\n",
"params = [(0, 1), (0, 2), (3, 1)]\n",
"plot(x, [normal(x, mu, sigma) for mu, sigma in params], xlabel='x',\n",
"ylabel='p(x)', figsize=(4.5, 2.5),\n",
"legend=[f'mean {mu}, std {sigma}' for mu, sigma in params])"
],
"id": "f69ac10ebc3d13d8",
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 450x250 with 1 Axes>"
],
2026-03-15 06:42:56 +00:00
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"302.08125pt\" height=\"183.35625pt\" viewBox=\"0 0 302.08125 183.35625\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n <metadata>\n <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n <cc:Work>\n <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n <dc:date>2026-03-15T14:40:56.601846</dc:date>\n <dc:format>image/svg+xml</dc:format>\n <dc:creator>\n <cc:Agent>\n <dc:title>Matplotlib v3.10.8, https://matplotlib.org/</dc:title>\n </cc:Agent>\n </dc:creator>\n </cc:Work>\n </rdf:RDF>\n </metadata>\n <defs>\n <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n </defs>\n <g id=\"figure_1\">\n <g id=\"patch_1\">\n <path d=\"M 0 183.35625 \nL 302.08125 183.35625 \nL 302.08125 0 \nL 0 0 \nz\n\"/>\n </g>\n <g id=\"axes_1\">\n <g id=\"patch_2\">\n <path d=\"M 43.78125 145.8 \nL 294.88125 145.8 \nL 294.88125 7.2 \nL 43.78125 7.2 \nz\n\"/>\n </g>\n <g id=\"matplotlib.axis_1\">\n <g id=\"xtick_1\">\n <g id=\"line2d_1\">\n <path d=\"M 71.511736 145.8 \nL 71.511736 7.2 \n\" clip-path=\"url(#p41fc57fa9b)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_2\">\n <defs>\n <path id=\"m3a442d9e56\" d=\"M 0 0 \nL 0 3.5 \n\" style=\"stroke: #ffffff; stroke-width: 0.8\"/>\n </defs>\n <g>\n <use xlink:href=\"#m3a442d9e56\" x=\"71.511736\" y=\"145.8\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_1\">\n <!-- 6 -->\n <g style=\"fill: #ffffff\" transform=\"translate(64.140642 160.398438) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-2212\" d=\"M 678 2272 \nL 4684 2272 \nL 4684 1741 
\nL 678 1741 \nL 678 2272 \nz\n\" transform=\"scale(0.015625)\"/>\n <path id=\"DejaVuSans-36\" d=\"M 2113 2584 \nQ 1688 2584 1439 2293 \nQ 1191 2003 1191 1497 \nQ 1191 994 1439 701 \nQ 1688 409 2113 409 \nQ 2538 409 2786 701 \nQ 3034 994 3034 1497 \nQ 3034 2003 2786 2293 \nQ 2538 2584 2113 2584 \nz\nM 3366 4563 \nL 3366 3988 \nQ 3128 4100 2886 4159 \nQ 2644 4219 2406 4219 \nQ 1781 4219 1451 3797 \nQ 1122 3375 1075 2522 \nQ 1259 2794 1537 2939 \nQ 1816 3084 2150 3084 \nQ 2853 3084 3261 2657 \nQ 3669 2231 3669 1497 \nQ 3669 778 3244 343 \nQ 2819 -91 2113 -91 \nQ 1303 -91 875 529 \nQ 447 1150 447 2328 \nQ 447 3434 972 4092 \nQ 1497 4750 2381 4750 \nQ 2619 4750 2861 4703 \nQ 3103 4656 3366 4563 \nz\n\" transform=\"scale(0.015625)\"/>\n </defs>\n <use xlink:href=\"#DejaVuSans-2212\"/>\n <use xlink:href=\"#DejaVuSans-36\" transform=\"translate(83.789062 0)\"/>\n </g>\n </g>\n </g>\n <g id=\"xtick_2\">\n <g id=\"line2d_3\">\n <path d=\"M 104.145435 145.8 \nL 104.145435 7.2 \n\" clip-path=\"url(#p41fc57fa9b)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_4\">\n <g>\n <use xlink:href=\"#m3a442d9e56\" x=\"104.145435\" y=\"145.8\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_2\">\n <!-- 4 -->\n <g style=\"fill: #ffffff\" transform=\"translate(96.774342 160.398438) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-34\" d=\"M 2419 4116 \nL 825 1625 \nL 2419 1625 \nL 2419 4116 \nz\nM 2253 4666 \nL 3047 4666 \nL 3047 1625 \nL 3713 1625 \nL 3713 1100 \nL 3047 1100 \nL 3047 0 \nL 2419 0 \nL 2419 1100 \nL 313 1100 \nL 313 1709 \nL 2253 4666 \nz\n\" transform=\"scale(0.015625)\"/>\n </defs>\n <use xlink:href=\"#DejaVuSans-2212\"/>\n <use xlink:href=\"#DejaVuSans-34\" transform=\"translate(83.789062 0)\"/>\n
2026-03-14 11:51:56 +00:00
},
"metadata": {},
"output_type": "display_data",
"jetTransient": {
"display_id": null
}
}
],
"execution_count": 38
},
{
"metadata": {
"ExecuteTime": {
2026-03-15 06:42:56 +00:00
"end_time": "2026-03-15T06:40:56.659870613Z",
"start_time": "2026-03-15T06:40:56.640905156Z"
2026-03-14 11:51:56 +00:00
}
},
"cell_type": "code",
"source": [
"# With a 1-D w, torch.matmul(X, w) is a matrix-vector product: no transpose is needed.\n",
"def synthetic_data(w, b, num_examples):  #@save\n",
"    \"\"\"Generate a synthetic linear-regression dataset y = Xw + b + noise.\n",
"\n",
"    Features are drawn from N(0, 1) with shape (num_examples, len(w)); the noise\n",
"    has std 0.01. Labels are returned as a column of shape (num_examples, 1).\n",
"    \"\"\"\n",
"    inputs = torch.normal(0, 1, (num_examples, len(w)))\n",
"    targets = inputs @ w + b\n",
"    targets = targets + torch.normal(0, 0.01, targets.shape)\n",
"    return inputs, targets.reshape((-1, 1))\n"
],
"id": "7ed837bdd2b3a26d",
"outputs": [],
"execution_count": 39
},
{
"metadata": {
"ExecuteTime": {
2026-03-15 06:42:56 +00:00
"end_time": "2026-03-15T06:40:56.700642098Z",
"start_time": "2026-03-15T06:40:56.665034974Z"
2026-03-14 11:51:56 +00:00
}
},
"cell_type": "code",
"source": [
"# Ground-truth parameters used to generate the synthetic data.\n",
"true_w, true_b = torch.tensor([2, -3.4]), 4.2\n",
"features, labels = synthetic_data(true_w, true_b, num_examples=1000)"
],
"id": "5ec2e204a6fd5cb2",
"outputs": [],
"execution_count": 40
},
{
"metadata": {
"ExecuteTime": {
2026-03-15 06:42:56 +00:00
"end_time": "2026-03-15T06:40:56.766436205Z",
"start_time": "2026-03-15T06:40:56.704907868Z"
2026-03-14 11:51:56 +00:00
}
},
"cell_type": "code",
"source": [
"set_figsize()\n",
"# Second feature vs. label: the trend should have a slope of roughly true_w[1] = -3.4.\n",
"plt.scatter(features[:, 1].detach().numpy(), labels.detach().numpy(), 1)\n",
"plt.xlabel('features[:, 1]')\n",
"# Trailing ';' suppresses the artist repr (it contains a memory address that changes every run).\n",
"plt.ylabel('labels');"
],
"id": "38213d46b3d9900d",
"outputs": [
{
"data": {
"text/plain": [
2026-03-15 06:42:56 +00:00
"<matplotlib.collections.PathCollection at 0x7fd5c02b8830>"
2026-03-14 11:51:56 +00:00
]
},
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"text/plain": [
"<Figure size 350x250 with 1 Axes>"
],
2026-03-15 06:42:56 +00:00
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"231.442187pt\" height=\"169.678125pt\" viewBox=\"0 0 231.442187 169.678125\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n <metadata>\n <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n <cc:Work>\n <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n <dc:date>2026-03-15T14:40:56.736997</dc:date>\n <dc:format>image/svg+xml</dc:format>\n <dc:creator>\n <cc:Agent>\n <dc:title>Matplotlib v3.10.8, https://matplotlib.org/</dc:title>\n </cc:Agent>\n </dc:creator>\n </cc:Work>\n </rdf:RDF>\n </metadata>\n <defs>\n <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n </defs>\n <g id=\"figure_1\">\n <g id=\"patch_1\">\n <path d=\"M 0 169.678125 \nL 231.442187 169.678125 \nL 231.442187 0 \nL 0 0 \nz\n\"/>\n </g>\n <g id=\"axes_1\">\n <g id=\"patch_2\">\n <path d=\"M 28.942188 145.8 \nL 224.242188 145.8 \nL 224.242188 7.2 \nL 28.942188 7.2 \nz\n\"/>\n </g>\n <g id=\"PathCollection_1\">\n <defs>\n <path id=\"m31e52db880\" d=\"M 0 0.5 \nC 0.132602 0.5 0.25979 0.447317 0.353553 0.353553 \nC 0.447317 0.25979 0.5 0.132602 0.5 0 \nC 0.5 -0.132602 0.447317 -0.25979 0.353553 -0.353553 \nC 0.25979 -0.447317 0.132602 -0.5 0 -0.5 \nC -0.132602 -0.5 -0.25979 -0.447317 -0.353553 -0.353553 \nC -0.447317 -0.25979 -0.5 -0.132602 -0.5 0 \nC -0.5 0.132602 -0.447317 0.25979 -0.353553 0.353553 \nC -0.25979 0.447317 -0.132602 0.5 0 0.5 \nz\n\" style=\"stroke: #8dd3c7\"/>\n </defs>\n <g clip-path=\"url(#p4989f295ce)\">\n <use xlink:href=\"#m31e52db880\" x=\"107.425452\" y=\"58.204664\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#m31e52db880\" x=\"168.09734\" y=\"97.834837\" 
style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#m31e52db880\" x=\"142.779082\" y=\"101.067423\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#m31e52db880\" x=\"132.744701\" y=\"88.534686\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#m31e52db880\" x=\"136.882127\" y=\"82.604629\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#m31e52db880\" x=\"127.396832\" y=\"72.027719\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#m31e52db880\" x=\"109.854126\" y=\"98.996271\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#m31e52db880\" x=\"134.061634\" y=\"92.088623\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#m31e52db880\" x=\"90.03612\" y=\"50.804814\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#m31e52db880\" x=\"121.38814\" y=\"76.82656\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#m31e52db880\" x=\"130.512817\" y=\"78.428948\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#m31e52db880\" x=\"85.656041\" y=\"37.59452\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#m31e52db880\" x=\"122.524868\" y=\"81.482687\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#m31e52db880\" x=\"144.431826\" y=\"82.705835\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#m31e52db880\" x=\"82.913166\" y=\"56.144404\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#m31e52db880\" x=\"77.023104\" y=\"35.081326\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#m31e52db880\" x=\"69.558241\" y=\"41.717562\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#m31e52db880\" x=\"150.233438\" y=\"108.863484\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#m31e52db880\" x=\"147.303054\" y=\"88.662968\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#m31e5
2026-03-14 11:51:56 +00:00
},
"metadata": {},
"output_type": "display_data",
"jetTransient": {
"display_id": null
}
}
],
"execution_count": 41
},
{
"metadata": {
"ExecuteTime": {
2026-03-15 06:42:56 +00:00
"end_time": "2026-03-15T06:40:56.788076510Z",
"start_time": "2026-03-15T06:40:56.778166680Z"
2026-03-14 11:51:56 +00:00
}
},
"cell_type": "code",
"source": [
"# Parameters: small random weights and a zero bias, both tracked by autograd.\n",
"w = torch.normal(0, 0.01, size=(2, 1), requires_grad=True)\n",
"b = torch.zeros(1, requires_grad=True)\n",
"\n",
"\n",
"def linreg(X, w, b):\n",
"    \"\"\"Linear model: X @ w + b, with b broadcast across rows.\"\"\"\n",
"    return X @ w + b\n",
"\n",
"\n",
"def squared_loss(y_hat, y):\n",
"    \"\"\"Per-example half squared error (y is reshaped to y_hat's shape).\"\"\"\n",
"    residual = y_hat - y.reshape(y_hat.shape)\n",
"    return residual ** 2 / 2\n",
"\n",
"\n",
"def sgd(params, lr, batch_size):\n",
"    \"\"\"In-place minibatch SGD step on each parameter, then zero its gradient.\n",
"\n",
"    Dividing by batch_size assumes the caller backpropagated the *summed*\n",
"    minibatch loss, as the training loops in this notebook do.\n",
"    \"\"\"\n",
"    with torch.no_grad():\n",
"        for p in params:\n",
"            p -= lr * p.grad / batch_size\n",
"            p.grad.zero_()\n",
"\n",
"\n",
"lr, num_epochs = 0.03, 20\n",
"net, loss = linreg, squared_loss"
],
"id": "12166e1bc3ddd695",
"outputs": [],
"execution_count": 42
},
{
"metadata": {
"ExecuteTime": {
2026-03-15 06:42:56 +00:00
"end_time": "2026-03-15T06:40:56.811093546Z",
"start_time": "2026-03-15T06:40:56.788825920Z"
2026-03-14 11:51:56 +00:00
}
},
"cell_type": "code",
"source": [
"import random\n",
"\n",
"\n",
"def data_iter(batch_size, features, labels):\n",
"    \"\"\"Yield (features, labels) minibatches in a freshly shuffled order.\n",
"\n",
"    The final minibatch is smaller when batch_size does not divide the dataset size.\n",
"    \"\"\"\n",
"    num_examples = len(features)\n",
"    order = list(range(num_examples))\n",
"    random.shuffle(order)  # no fixed order: reshuffled on every call\n",
"    for start in range(0, num_examples, batch_size):\n",
"        # Slicing clamps at the end of the list, so no explicit min() is needed.\n",
"        batch_idx = torch.tensor(order[start:start + batch_size])\n",
"        yield features[batch_idx], labels[batch_idx]"
],
"id": "f3b7ee9f326bc687",
"outputs": [],
"execution_count": 43
},
{
"metadata": {
"ExecuteTime": {
2026-03-15 06:42:56 +00:00
"end_time": "2026-03-15T06:40:56.833447562Z",
"start_time": "2026-03-15T06:40:56.813045993Z"
2026-03-14 11:51:56 +00:00
}
},
"cell_type": "code",
"source": [
"batch_size = 10\n",
"# Print only the first minibatch to sanity-check shapes and values.\n",
"X, y = next(iter(data_iter(batch_size, features, labels)))\n",
"print(X, '\\n', y)"
],
"id": "f386e12d65afff2e",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2026-03-15 06:42:56 +00:00
"tensor([[-1.8455, -0.9126],\n",
" [-0.0299, 1.6530],\n",
" [-2.3513, 0.6457],\n",
" [ 0.1707, 0.8342],\n",
" [ 0.2096, -0.4362],\n",
" [ 0.6160, 1.7403],\n",
" [ 0.4242, 0.0484],\n",
" [-1.4459, 0.7434],\n",
" [ 0.5302, -0.5594],\n",
" [-0.5957, -1.5179]]) \n",
" tensor([[ 3.6056],\n",
" [-1.4872],\n",
" [-2.6965],\n",
" [ 1.6980],\n",
" [ 6.1001],\n",
" [-0.4750],\n",
" [ 4.9009],\n",
" [-1.2270],\n",
" [ 7.1486],\n",
" [ 8.1680]])\n"
2026-03-14 11:51:56 +00:00
]
}
],
"execution_count": 44
},
{
"metadata": {
"ExecuteTime": {
2026-03-15 06:42:56 +00:00
"end_time": "2026-03-15T06:40:57.028635227Z",
"start_time": "2026-03-15T06:40:56.852110495Z"
2026-03-14 11:51:56 +00:00
}
},
"cell_type": "code",
"source": [
"for epoch in range(num_epochs):\n",
"    for X, y in data_iter(batch_size, features, labels):\n",
"        # Backprop the summed minibatch loss; sgd divides by batch_size.\n",
"        l = loss(net(X, w, b), y)\n",
"        l.sum().backward()\n",
"        sgd([w, b], lr, batch_size)\n",
"    with torch.no_grad():\n",
"        # Full-training-set loss, computed without building an autograd graph.\n",
"        train_l = loss(net(features, w, b), labels)\n",
"        # ':f' prints 6 decimals; the previous ':3f' only set a minimum field width of 3.\n",
"        print(f'epoch {epoch + 1}, train loss: {float(train_l.mean()):f}')"
],
"id": "8888ab6adcec36f1",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2026-03-15 06:42:56 +00:00
"epoch 1, train loss: 0.030682\n",
"epoch 2, train loss: 0.000100\n",
2026-03-14 11:51:56 +00:00
"epoch 3, train loss: 0.000049\n",
2026-03-15 06:42:56 +00:00
"epoch 4, train loss: 0.000049\n",
"epoch 5, train loss: 0.000049\n",
"epoch 6, train loss: 0.000049\n",
"epoch 7, train loss: 0.000049\n",
"epoch 8, train loss: 0.000049\n",
"epoch 9, train loss: 0.000049\n",
"epoch 10, train loss: 0.000049\n",
"epoch 11, train loss: 0.000049\n",
"epoch 12, train loss: 0.000049\n",
"epoch 13, train loss: 0.000049\n",
"epoch 14, train loss: 0.000049\n",
"epoch 15, train loss: 0.000049\n",
"epoch 16, train loss: 0.000049\n",
"epoch 17, train loss: 0.000049\n",
"epoch 18, train loss: 0.000049\n",
"epoch 19, train loss: 0.000049\n",
"epoch 20, train loss: 0.000049\n"
2026-03-14 11:51:56 +00:00
]
}
],
"execution_count": 45
},
{
"metadata": {
"ExecuteTime": {
2026-03-15 06:42:56 +00:00
"end_time": "2026-03-15T06:40:57.085057122Z",
"start_time": "2026-03-15T06:40:57.045310900Z"
2026-03-14 11:51:56 +00:00
}
},
"cell_type": "code",
"source": [
"# Residual error of the learned parameters (should be close to zero).\n",
"w_err = true_w - w.reshape(true_w.shape)\n",
"b_err = true_b - b\n",
"print(f'w的估计误差: {w_err}')\n",
"print(f'b的估计误差: {b_err}')"
],
"id": "8199439fa7f26309",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2026-03-15 06:42:56 +00:00
"w的估计误差: tensor([-0.0002, 0.0003], grad_fn=<SubBackward0>)\n",
"b的估计误差: tensor([0.0002], grad_fn=<RsubBackward1>)\n"
2026-03-14 11:51:56 +00:00
]
}
],
"execution_count": 46
},
{
"metadata": {
"ExecuteTime": {
2026-03-15 06:42:56 +00:00
"end_time": "2026-03-15T06:40:57.146334186Z",
"start_time": "2026-03-15T06:40:57.086581923Z"
2026-03-14 11:51:56 +00:00
}
},
"cell_type": "code",
"source": [
"from torch.utils import data\n",
"\n",
"true_w = torch.tensor([2, -3.4])\n",
"true_b = 4.2\n",
"features, labels = synthetic_data(true_w, true_b, 1000)\n",
"\n",
"\n",
"def load_array(data_arrays, batch_size, is_train=True):\n",
"    \"\"\"Wrap the given tensors in a TensorDataset and return a DataLoader over it.\n",
"\n",
"    Examples are shuffled only when is_train is true.\n",
"    \"\"\"\n",
"    return data.DataLoader(data.TensorDataset(*data_arrays), batch_size,\n",
"                           shuffle=is_train)\n",
"\n",
"\n",
"batch_size = 10\n",
"# NOTE: this rebinds `data_iter` from the generator function defined earlier to a\n",
"# DataLoader; cells that call data_iter(...) will fail if re-run after this one.\n",
"data_iter = load_array((features, labels), batch_size)"
],
"id": "560d537dcbb5a335",
"outputs": [],
2026-03-15 06:42:56 +00:00
"execution_count": 47
2026-03-14 11:51:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-03-15 06:42:56 +00:00
"end_time": "2026-03-15T06:40:57.187559751Z",
"start_time": "2026-03-15T06:40:57.147857709Z"
2026-03-14 11:51:56 +00:00
}
},
"cell_type": "code",
"source": [
"from torch import nn\n",
"\n",
"net = nn.Sequential(nn.Linear(2, 1))\n",
"linear_layer = net[0]\n",
"linear_layer.weight.data.normal_(0, 0.001)  # small random initial weights\n",
"linear_layer.bias.data.fill_(0)  # the zeroed bias is this cell's displayed output"
],
"id": "c54fe059d6fd20de",
"outputs": [
{
"data": {
"text/plain": [
"tensor([0.])"
]
},
2026-03-15 06:42:56 +00:00
"execution_count": 48,
2026-03-14 11:51:56 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-03-15 06:42:56 +00:00
"execution_count": 48
2026-03-14 11:51:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-03-15 06:42:56 +00:00
"end_time": "2026-03-15T06:40:57.523547926Z",
"start_time": "2026-03-15T06:40:57.191697380Z"
2026-03-14 11:51:56 +00:00
}
},
"cell_type": "code",
"source": [
"loss = nn.MSELoss()\n",
"trainer = torch.optim.SGD(net.parameters(), lr=0.01)\n",
"num_epochs = 3\n",
"for epoch in range(num_epochs):\n",
"    for X, y in data_iter:\n",
"        l = loss(net(X), y)\n",
"        trainer.zero_grad()\n",
"        l.backward()\n",
"        trainer.step()\n",
"    # The full-dataset loss is only reported, so compute it under no_grad\n",
"    # instead of building (and keeping alive) an unused autograd graph.\n",
"    with torch.no_grad():\n",
"        l = loss(net(features), labels)\n",
"    print(f'epoch {epoch + 1}, loss {l:f}')"
],
"id": "e8a44851125b7cc6",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2026-03-15 06:42:56 +00:00
"epoch 1, loss 0.599957\n",
"epoch 2, loss 0.011503\n",
"epoch 3, loss 0.000325\n"
]
}
],
"execution_count": 49
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-15T06:40:57.637796611Z",
"start_time": "2026-03-15T06:40:57.539577276Z"
}
},
"cell_type": "code",
"source": [
"import torchvision\n",
"from torchvision import transforms\n",
"\n",
"# Every image is converted to a tensor by ToTensor.\n",
"trans = transforms.ToTensor()\n",
"# download=False: the dataset is expected to already exist under ./data.\n",
"mnist_train = torchvision.datasets.FashionMNIST(\n",
"    root=\"./data\", train=True, transform=trans, download=False)\n",
"mnist_test = torchvision.datasets.FashionMNIST(\n",
"    root=\"./data\", train=False, transform=trans, download=False)"
],
"id": "bd4e8a65ccd03177",
"outputs": [],
"execution_count": 50
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-15T06:40:57.669099255Z",
"start_time": "2026-03-15T06:40:57.639295575Z"
}
},
"cell_type": "code",
"source": [
"use_svg_display()\n",
"# (number of training examples, number of test examples)\n",
"(len(mnist_train), len(mnist_test))"
],
"id": "ed2c915af7f6a76a",
"outputs": [
{
"data": {
"text/plain": [
"(60000, 10000)"
]
},
"execution_count": 51,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 51
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-15T06:40:57.717512337Z",
"start_time": "2026-03-15T06:40:57.672873756Z"
}
},
"cell_type": "code",
"source": "# Shape of the first training image tensor.\nmnist_train[0][0].size()",
"id": "4df1cbc292aa5981",
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([1, 28, 28])"
]
},
"execution_count": 52,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 52
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-15T06:40:57.739205016Z",
"start_time": "2026-03-15T06:40:57.719179425Z"
}
},
"cell_type": "code",
"source": [
"def get_fashion_mnist_labels(labels):\n",
"    \"\"\"Convert numeric Fashion-MNIST labels to their text class names.\"\"\"\n",
"    names = ('t-shirt', 'trouser', 'pullover', 'dress', 'coat',\n",
"             'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot')\n",
"    return [names[int(idx)] for idx in labels]\n"
],
"id": "332f3d6da0bafbe8",
"outputs": [],
"execution_count": 53
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-15T06:40:57.764139878Z",
"start_time": "2026-03-15T06:40:57.754725053Z"
}
},
"cell_type": "code",
"source": [
"def show_images(imgs, num_rows, num_cols, titles=None, scale=1):  #@save\n",
"    \"\"\"Draw imgs on a num_rows x num_cols grid and return the flat array of Axes.\n",
"\n",
"    Each image may be a torch tensor or a PIL image; titles, if given, is\n",
"    indexed in the same order as imgs.\n",
"    \"\"\"\n",
"    figsize = (num_cols * scale, num_rows * scale)\n",
"    # squeeze=False always yields a 2-D array of Axes, so .flatten() also works for\n",
"    # a 1x1 grid (where plt.subplots would otherwise return a bare Axes object).\n",
"    _, axes = plt.subplots(num_rows, num_cols, figsize=figsize, squeeze=False)\n",
"    axes = axes.flatten()\n",
"    for i, (ax, img) in enumerate(zip(axes, imgs)):\n",
"        if torch.is_tensor(img):\n",
"            ax.imshow(img.numpy())  # image tensor\n",
"        else:\n",
"            ax.imshow(img)  # PIL image\n",
"        ax.axes.get_xaxis().set_visible(False)\n",
"        ax.axes.get_yaxis().set_visible(False)\n",
"        if titles:\n",
"            ax.set_title(titles[i])\n",
"    return axes"
],
"id": "c83202fe9b0ab487",
"outputs": [],
"execution_count": 54
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-15T06:40:57.957479055Z",
"start_time": "2026-03-15T06:40:57.765366722Z"
}
},
"cell_type": "code",
"source": [
"n_show = 18\n",
"X, y = next(iter(data.DataLoader(mnist_train, batch_size=n_show)))\n",
"print(X.shape)\n",
"# Drop the size-1 dimension so each image is a 28x28 array for imshow.\n",
"show_images(X.reshape(n_show, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y));"
],
"id": "cf4abd8370d55416",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([18, 1, 28, 28])\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 900x200 with 18 Axes>"
],
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"524.634446pt\" height=\"137.375483pt\" viewBox=\"0 0 524.634446 137.375483\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n <metadata>\n <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n <cc:Work>\n <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n <dc:date>2026-03-15T14:40:57.897836</dc:date>\n <dc:format>image/svg+xml</dc:format>\n <dc:creator>\n <cc:Agent>\n <dc:title>Matplotlib v3.10.8, https://matplotlib.org/</dc:title>\n </cc:Agent>\n </dc:creator>\n </cc:Work>\n </rdf:RDF>\n </metadata>\n <defs>\n <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n </defs>\n <g id=\"figure_1\">\n <g id=\"patch_1\">\n <path d=\"M 0 137.375483 \nL 524.634446 137.375483 \nL 524.634446 0 \nL 0 0 \nz\n\"/>\n </g>\n <g id=\"axes_1\">\n <g id=\"patch_2\">\n <path d=\"M 15.234446 69.695483 \nL 62.611804 69.695483 \nL 62.611804 22.318125 \nL 15.234446 22.318125 \nz\n\"/>\n </g>\n <g clip-path=\"url(#pba22bd51fd)\">\n <image 
xlink:href=\"data:image/png;base64,\niVBORw0KGgoAAAANSUhEUgAAAEIAAABCCAYAAADjVADoAAAW9klEQVR4nO2b2Y9l13Xef2vtfYY71dQTm2qSokmJjiI7sQHDgSLAQeCXAHlL8mcGeQwQIQ7gIBHiwIoiiYkUDhLZZA/VNdcdzjl7WHnY596ublKKAplqP3Chq28N9wz7O2v41rf2lb+Uf218beirvoG/L/Y1EKN9DcRo/nc+w6SFSYM5IVcKImjMkIxcKbkWJIPrEpKBbGAGwwDr7ndfwd+R/e5AHCzgtdtY44iLCgT8KqEhkxaeft/j+kx7MiAhIzEjOWNnF1jXF2D+HthvBYQ1Hqs8rk74WUQ1U0nCieHubXB3l1jjCHMPAtW6ABHnnrDn0MFopgM6AkE24mLD0AbMChBmEMwTTQFhC08yARMQgC8HzUzKK4K99BaLCoOAgUaQbDCEch8vAKFjmrD8pdcxFeKDW6QHt1i8dc39P3/CZNrzen3FXHv2J0/YbyOiIF4QwOeMM0OdoJVAhjxkMEMMxIzTvubRpiEbZBMyymf9Ic/CFDMhI6SsrGJNzIqKofLiDdq4+JSVlIVoSh89ZoKIIUA6bkiPWvzGaE8zbpNwv3yCnFy9BMTOBDBMACkIixrigP0GuzfDvRlo/0iZLYTDJrLvBm65JbfcCsVwGCJQYTgxPFCLkDGilRuX8WuSZoRwQEIxIJrjYgNXg8NMSCYkUzZDhWSHSsa9BEQubkJODsuKZUcKFdkEFUMw0t6E6KawNnAJWUV46jEPZEreAuQv9d88900gHc0IDw5xC2P2zjXVXmBx3zO77bGDhD2ISJWpLOLINERaCci4eBWj1oyK4cWopDw3xtdtmVpZxUVqMcBpufZ1almnGhXDSfHQbWioZOQlIGS87Ysw5SJMySaEXK6wDd10XZEuKiYpcr+/woXMp596zk4V//MO/zcrJIJ/OajS/oThvdeo7mWqf2ZM76+5355zr7niMk542B3RhZrTfkYXPdmUbFKA0IwK1C7iJOM148cFbF+LwxlmgiEotnt/rYlKE5UkWhdwGI0GnGQUG5//FgVDMQzBd5mA30W2YLQu4iWTJ0K6I+z5jncmxyjGycnbPFnegn9/if/RGmLxXkxhuDclHjbMvqncfe8Yd5jRWY+qsUoNx8OCZWy4Dg0huxJXmnGSUIrbVppQybQvAbF7fICOiyqJTV74m4y/T6YM2SMY0RTFynE3AIQRXIyEopJJWRlyCa0+VMVdpER6SsoT3ceRaavI/dkl/W1j/c1D6LdAeGX5J7dZ/ukd7rz5iO/90U+xSvg43eLaWo77BZ+sjxiSYx2rEuNVYOIDc98zrwYqSUxdj5dM68LomhlPIqNscoUhVJLwkjFKkozmuE4t0RzRlGRKtOfvz7ks3Gs5LpkSsqICrQt4yfS5opJEQlmFhiE5rjctfXBUPlNVkUoTz/o5jYt8c3rK2/NTfvmt+/zsL94mRsVbU0Gj+AOluWtUh4ZMM+YE2RiaDBFDyajoLmFts7hqiWcnGaclYRY3Lq9ODKG4vRnUmvCSdqUuWiaYI1imNtklymiOjDCI2wVFMh0TpGBmxKxkKdUiU0J0+1XOoagZfvSiMVvhNdFoxDWQFhUpOrx9+y20Md76B0tm3/kQ9cZP198Y03tJUq+3F8xdzzrVnA5zgildrsqNmdKliiCuuCgGtABMNDBxxVtmbsBpopFIJakAJUY2Yd+tMYQDXTPTni5XXOeWiLJMLcEcl2nCMrUll/hEMuE8zOiSJ2RHyMWjDFA12jrgfaJxibYKTF3g9ekFrUYmLjCYZ/BKWECMhrejBTSZxdEFd29dcBEmPOn2AWOv6mhcYu57bldLlqkh4BiyJwXFcnmqyQTQMXcYybY+UcoYCk4KCdt+qWQcGQTaMWne9dcc6Jq11UzSQMDRaKI3TzBPl2sqSUzcUDxmEIa8BWKbtIu3Vi7h1GhcpHG
RiR/Yrza0LqKUB5AVcgVZwA97HtckrC4kZOF76slZqfljoulyxdNhj4zQjJm91YCZMHMdM+0xlEB5Kk/7PVapwUtm0IiYsc41HseSFjOh1YGZ62klcscv8SROwpyP0x0mOrDn17QS2feXCHBXV1ynmojS4wnmqIh0qWKw8nCyCcEcYtBIwJNK6KrRaOSwXlNJopGSwy6HfabHmRAFP+w7XEVpjoC577hXD2SkeIA5rmPLWZrRauCg2uAlMakCXhJHbsmhXzGY5ypN2OSak2FOlzxeMrUWzrZJFSqeVWoYsmffbwodd5lbbkUrkV9sXuOnm2/wRnPGd/2GVgIP/IqpRDoTehOuc8PjtMdgjn23KcDkii5XZAoQDuOWv2aqfckdY44pOSuz0I5WBp4Od5kdZ4ag+FQLUsOs6rjtr0vpMlfYGUYlibnraCXQamDfbcYFRrwkZjrQSqCRTAt0GjmploVDCWOiEyoppdZUqCWy5zbccmt8yDx8epfUK4+721wNezxtYbJITDRyVm1oNBHrTKxKKGUHNuYvbxDHZK3Ijtj5kYkKGRlDRuR5IhdAQ8KvArk3/LAnSAX3Z+f88eQhnw2HvL/5BkYhIa0E7lUdU+1pJDLTrpxsJEetBFoJTARu1aXutxJ5Eic8Gfb5fDgEM2auo9GIsEaA+/6Kt6tTnl4c8u/+wz/l0dMjNlrRq+f
},
"metadata": {},
"output_type": "display_data",
"jetTransient": {
"display_id": null
}
}
],
"execution_count": 55
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-15T06:41:00.262791420Z",
"start_time": "2026-03-15T06:40:57.970609683Z"
}
},
"cell_type": "code",
"source": [
"batch_size = 256\n",
"\n",
"\n",
"def get_dataloader_workers():\n",
"    \"\"\"Number of worker processes DataLoader uses to read data (fixed at 4).\"\"\"\n",
"    return 4\n",
"\n",
"\n",
"train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,\n",
"                             num_workers=get_dataloader_workers())\n",
"# Time one full pass over the training set that only loads batches.\n",
"timer = Timer()\n",
"for X, y in train_iter:\n",
"    pass\n",
"f'{timer.stop():.2f} sec'"
],
"id": "552769fbffc16142",
"outputs": [
{
"data": {
"text/plain": [
"'2.26 sec'"
]
},
"execution_count": 56,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 56
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-15T06:41:00.481904106Z",
"start_time": "2026-03-15T06:41:00.308644956Z"
}
},
"cell_type": "code",
"source": [
"def load_data_fashion_mnist(batch_size, resize=None):\n",
"    \"\"\"Return (train_loader, test_loader) for Fashion-MNIST read from ./data.\n",
"\n",
"    If resize is given, images are resized before ToTensor is applied. Only the\n",
"    training loader shuffles. Nothing is downloaded (download=False).\n",
"    \"\"\"\n",
"    steps = [transforms.ToTensor()]\n",
"    if resize:\n",
"        steps = [transforms.Resize(resize)] + steps\n",
"    trans = transforms.Compose(steps)\n",
"\n",
"    def make_split(train):\n",
"        return torchvision.datasets.FashionMNIST(\n",
"            root=\"./data\", train=train, transform=trans, download=False)\n",
"\n",
"    def make_loader(dataset, shuffle):\n",
"        return data.DataLoader(dataset, batch_size, shuffle=shuffle,\n",
"                               num_workers=get_dataloader_workers())\n",
"\n",
"    train_set = make_split(True)\n",
"    test_set = make_split(False)\n",
"    return make_loader(train_set, True), make_loader(test_set, False)"
],
"id": "aa81880abd86cae6",
"outputs": [],
"execution_count": 57
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-15T06:41:02.218194285Z",
"start_time": "2026-03-15T06:41:00.632922086Z"
}
},
"cell_type": "code",
"source": [
"train_iter, test_iter = load_data_fashion_mnist(32, resize=64)\n",
"# Inspect the shape and dtype of one resized minibatch.\n",
"X, y = next(iter(train_iter))\n",
"print(X.shape, X.dtype, y.shape, y.dtype)"
],
"id": "4248a103f745154",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([32, 1, 64, 64]) torch.float32 torch.Size([32]) torch.int64\n"
2026-03-14 11:51:56 +00:00
]
}
],
2026-03-15 06:42:56 +00:00
"execution_count": 58
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-15T06:41:02.319582689Z",
"start_time": "2026-03-15T06:41:02.252523696Z"
}
},
"cell_type": "code",
"source": [
"from IPython import display\n",
"\n",
"batch_size = 256\n",
"# Pass batch_size rather than a hard-coded 32 so the value defined above is\n",
"# actually used (32 would give ~8x as many minibatches per epoch as 256).\n",
"train_iter, test_iter = load_data_fashion_mnist(batch_size)\n",
"num_inputs = 784  # each 28x28 image is flattened into a 784-vector\n",
"num_outputs = 10  # one output per Fashion-MNIST class\n",
"W = torch.normal(0, 0.01, size=(num_inputs, num_outputs), requires_grad=True)\n",
"b = torch.zeros(num_outputs, requires_grad=True)\n",
"\n",
"# Demo: sum over dim 0 (per-column totals) vs. dim 1 (per-row totals), keeping the rank.\n",
"X = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])\n",
"X.sum(0, keepdim=True), X.sum(1, keepdim=True)\n"
],
"id": "94c52f0cca88ef48",
"outputs": [
{
"data": {
"text/plain": [
"(tensor([[5., 7., 9.]]),\n",
" tensor([[ 6.],\n",
" [15.]]))"
]
},
"execution_count": 59,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 59
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-15T06:41:02.375197243Z",
"start_time": "2026-03-15T06:41:02.320989852Z"
}
},
"cell_type": "code",
"source": [
"def softmax(X):\n",
"    \"\"\"Row-wise softmax: exponentiate, then normalize each row to sum to 1.\n",
"\n",
"    The row maximum is subtracted first. This leaves the result mathematically\n",
"    unchanged, but it stops torch.exp from overflowing to inf (and the division\n",
"    from producing nan) when logits are large.\n",
"    \"\"\"\n",
"    X_exp = torch.exp(X - X.max(dim=1, keepdim=True).values)\n",
"    partition = X_exp.sum(1, keepdim=True)\n",
"    return X_exp / partition  # broadcasting divides each row by its own sum\n",
"\n",
"\n",
"X = torch.normal(0, 1, (2, 5))\n",
"X_prob = softmax(X)\n",
"X_prob, X_prob.sum(1)"
],
"id": "c4ab34373c5a664e",
"outputs": [
{
"data": {
"text/plain": [
"(tensor([[0.4360, 0.2195, 0.2044, 0.0829, 0.0571],\n",
" [0.0678, 0.3243, 0.2988, 0.0572, 0.2519]]),\n",
" tensor([1.0000, 1.0000]))"
]
},
"execution_count": 60,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 60
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-15T06:41:02.400574302Z",
"start_time": "2026-03-15T06:41:02.379951465Z"
}
},
"cell_type": "code",
"source": [
"def net(X):\n",
"    \"\"\"Softmax-regression forward pass using the global W and b.\n",
"\n",
"    Each example in X is flattened to W.shape[0] features before the affine map.\n",
"    \"\"\"\n",
"    flat = X.reshape((-1, W.shape[0]))\n",
"    return softmax(flat @ W + b)"
],
"id": "6eacc53b2b9738af",
"outputs": [],
"execution_count": 61
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-15T06:41:02.433540648Z",
"start_time": "2026-03-15T06:41:02.402525751Z"
}
},
"cell_type": "code",
"source": [
"y = torch.tensor([0, 2])\n",
"y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])\n",
"# For each row i, select the predicted probability of its true class y[i].\n",
"y_hat[torch.arange(len(y_hat)), y]"
],
"id": "698449b4dafb545c",
"outputs": [
{
"data": {
"text/plain": [
"tensor([0.1000, 0.5000])"
]
},
"execution_count": 62,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 62
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-15T06:41:02.473364236Z",
"start_time": "2026-03-15T06:41:02.438405972Z"
}
},
"cell_type": "code",
"source": [
"def cross_entropy(y_hat, y):\n",
"    \"\"\"Per-example cross-entropy: -log of the probability assigned to the true label.\"\"\"\n",
"    true_class_prob = y_hat[range(len(y_hat)), y]\n",
"    return -torch.log(true_class_prob)\n",
"\n",
"\n",
"cross_entropy(y_hat, y)"
],
"id": "1720369fc8568c8c",
"outputs": [
{
"data": {
"text/plain": [
"tensor([2.3026, 0.6931])"
]
},
"execution_count": 63,
"metadata": {},
"output_type": "execute_result"
}
],
2026-03-14 11:51:56 +00:00
"execution_count": 63
},
{
"metadata": {
"ExecuteTime": {
2026-03-15 06:42:56 +00:00
"end_time": "2026-03-15T06:41:02.495413592Z",
"start_time": "2026-03-15T06:41:02.475628534Z"
2026-03-14 11:51:56 +00:00
}
},
"cell_type": "code",
2026-03-15 06:42:56 +00:00
"source": [
"def accuracy(y_hat, y):  #@save\n",
"    \"\"\"Return the *number* (as a float) of predictions in y_hat that equal y.\n",
"\n",
"    A 2-D y_hat with more than one column is treated as per-class scores and\n",
"    reduced to predicted labels with argmax over dim 1.\n",
"    \"\"\"\n",
"    scores_given = len(y_hat.shape) > 1 and y_hat.shape[1] > 1\n",
"    preds = y_hat.argmax(axis=1) if scores_given else y_hat\n",
"    hits = preds.type(y.dtype) == y\n",
"    return float(hits.type(y.dtype).sum())\n",
"\n",
"\n",
"accuracy(y_hat, y) / len(y)"
],
"id": "e65719500a64ed87",
"outputs": [
{
"data": {
"text/plain": [
"0.5"
]
},
"execution_count": 64,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 64
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-15T06:41:02.505437086Z",
"start_time": "2026-03-15T06:41:02.497345583Z"
}
},
"cell_type": "code",
"source": [
"class Accumulator:  #@save\n",
"    \"\"\"Maintain running float totals for n quantities.\"\"\"\n",
"\n",
"    def __init__(self, n):\n",
"        self.data = [0.0 for _ in range(n)]\n",
"\n",
"    def add(self, *args):\n",
"        \"\"\"Add args element-wise (as floats) to the running totals.\n",
"\n",
"        NOTE(review): zip() stops at the shorter input, so passing fewer than\n",
"        n values silently shortens self.data.\n",
"        \"\"\"\n",
"        self.data = [total + float(value) for total, value in zip(self.data, args)]\n",
"\n",
"    def reset(self):\n",
"        \"\"\"Set every total back to 0.0.\"\"\"\n",
"        self.data = [0.0] * len(self.data)\n",
"\n",
"    def __getitem__(self, idx):\n",
"        return self.data[idx]\n"
],
"id": "f1eebb35ff2e9fea",
"outputs": [],
"execution_count": 65
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-15T06:41:02.526181989Z",
"start_time": "2026-03-15T06:41:02.506819558Z"
}
},
"cell_type": "code",
"source": [
"def evaluate_accuracy(net, data_iter):  #@save\n",
"    \"\"\"Return the fraction of examples in data_iter that net classifies correctly.\"\"\"\n",
"    if isinstance(net, torch.nn.Module):\n",
"        net.eval()  # evaluation mode (this function does not switch it back)\n",
"    n_correct, n_total = 0.0, 0.0\n",
"    with torch.no_grad():\n",
"        for X, y in data_iter:\n",
"            n_correct += float(accuracy(net(X), y))\n",
"            n_total += float(y.numel())\n",
"    return n_correct / n_total\n"
],
"id": "bc2beb5f2d6afe7e",
"outputs": [],
"execution_count": 66
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-15T06:41:06.805858093Z",
"start_time": "2026-03-15T06:41:02.527858694Z"
}
},
"cell_type": "code",
"source": "# Test accuracy before any training (W and b were just initialized).\nevaluate_accuracy(net, test_iter)",
"id": "65bcfb7e40c1a98b",
"outputs": [
{
"data": {
"text/plain": [
"0.0498"
]
},
"execution_count": 67,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 67
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-15T06:41:06.843120784Z",
"start_time": "2026-03-15T06:41:06.827746920Z"
}
},
"cell_type": "code",
"source": [
"def train_epoch_ch3(net, train_iter, loss, updater):  #@save\n",
"    \"\"\"Run one training epoch (Chapter 3); return (mean loss, accuracy).\n",
"\n",
"    updater is either a torch.optim.Optimizer or a callable that receives the\n",
"    minibatch size and updates the parameters itself.\n",
"    \"\"\"\n",
"    if isinstance(net, torch.nn.Module):\n",
"        net.train()  # put the model in training mode\n",
"    # Running totals: summed loss, number correct, number of examples.\n",
"    metric = Accumulator(3)\n",
"    for X, y in train_iter:\n",
"        y_hat = net(X)\n",
"        l = loss(y_hat, y)\n",
"        if isinstance(updater, torch.optim.Optimizer):\n",
"            # Built-in optimizer: backprop the mean loss.\n",
"            updater.zero_grad()\n",
"            l.mean().backward()\n",
"            updater.step()\n",
"        else:\n",
"            # Custom updater (e.g. one wrapping sgd, which divides by the batch\n",
"            # size it is given): backprop the summed loss.\n",
"            l.sum().backward()\n",
"            updater(X.shape[0])\n",
"        # NOTE(review): float(l.sum()) is a per-batch *sum* only if loss returns\n",
"        # per-example values (reduction='none'); a mean-reduced loss undercounts.\n",
"        metric.add(float(l.sum()), accuracy(y_hat, y), y.numel())\n",
"    return metric[0] / metric[2], metric[1] / metric[2]"
],
"id": "2faf1dcc6c023a53",
"outputs": [],
"execution_count": 68
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-15T06:41:06.897158404Z",
"start_time": "2026-03-15T06:41:06.851017516Z"
}
},
"cell_type": "code",
"source": [
"class Animator:  #@save\n",
"    \"\"\"Redraw a set of curves in one figure each time new points are added.\"\"\"\n",
"\n",
"    def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,\n",
"                 ylim=None, xscale='linear', yscale='linear',\n",
"                 fmts=('-', 'm--', 'g-.', 'r:'), nrows=1, ncols=1,\n",
"                 figsize=(3.5, 2.5)):\n",
"        legend = [] if legend is None else legend\n",
"        use_svg_display()\n",
"        self.fig, self.axes = plt.subplots(nrows, ncols, figsize=figsize)\n",
"        if nrows * ncols == 1:\n",
"            self.axes = [self.axes]  # so self.axes[0] also works for a single plot\n",
"        # The axis settings are captured here and re-applied after every redraw.\n",
"        self.config_axes = lambda: set_axes(\n",
"            self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend)\n",
"        self.X, self.Y, self.fmts = None, None, fmts\n",
"\n",
"    def add(self, x, y):\n",
"        \"\"\"Append one point per curve and redraw the figure.\n",
"\n",
"        y may be a scalar or a sequence (one value per curve); a scalar x is\n",
"        shared by every curve. Pairs containing None are skipped.\n",
"        \"\"\"\n",
"        ys = y if hasattr(y, \"__len__\") else [y]\n",
"        n = len(ys)\n",
"        xs = x if hasattr(x, \"__len__\") else [x] * n\n",
"        if not self.X:\n",
"            self.X = [[] for _ in range(n)]\n",
"        if not self.Y:\n",
"            self.Y = [[] for _ in range(n)]\n",
"        for i, (xi, yi) in enumerate(zip(xs, ys)):\n",
"            if xi is not None and yi is not None:\n",
"                self.X[i].append(xi)\n",
"                self.Y[i].append(yi)\n",
"        ax = self.axes[0]\n",
"        ax.cla()\n",
"        for curve_x, curve_y, fmt in zip(self.X, self.Y, self.fmts):\n",
"            ax.plot(curve_x, curve_y, fmt)\n",
"        self.config_axes()\n",
"        display.display(self.fig)\n",
"        display.clear_output(wait=True)\n"
],
"id": "7cd5367ab43c5e5f",
"outputs": [],
"execution_count": 69
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-15T06:41:06.913685898Z",
"start_time": "2026-03-15T06:41:06.898876289Z"
}
},
"cell_type": "code",
"source": [
"def train_ch3(net, train_iter, test_iter, loss, num_epochs, updater):  #@save\n",
"    \"\"\"Train for num_epochs (Chapter 3), plotting train loss/acc and test acc per epoch.\n",
"\n",
"    Finishes with sanity assertions on the last epoch's metrics.\n",
"    \"\"\"\n",
"    animator = Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0.3, 0.9],\n",
"                        legend=['train loss', 'train acc', 'test acc'])\n",
"    for epoch in range(1, num_epochs + 1):\n",
"        train_metrics = train_epoch_ch3(net, train_iter, loss, updater)\n",
"        test_acc = evaluate_accuracy(net, test_iter)\n",
"        animator.add(epoch, (*train_metrics, test_acc))\n",
"    train_loss, train_acc = train_metrics\n",
"    assert train_loss < 0.5, train_loss\n",
"    assert 0.7 < train_acc <= 1, train_acc\n",
"    assert 0.7 < test_acc <= 1, test_acc"
],
"id": "b02a143c75fad40",
"outputs": [],
"execution_count": 70
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-15T06:41:06.945935582Z",
"start_time": "2026-03-15T06:41:06.916434775Z"
}
},
"cell_type": "code",
"source": [
"lr = 0.1\n",
"\n",
"\n",
"def updater(batch_size):\n",
"    \"\"\"Custom updater for train_epoch_ch3: one sgd step on the global W and b.\"\"\"\n",
"    return sgd([W, b], lr, batch_size)"
],
"id": "6a97b70779276b61",
"outputs": [],
"execution_count": 71
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-15T06:41:54.907637364Z",
"start_time": "2026-03-15T06:41:06.946861739Z"
}
},
"cell_type": "code",
"source": [
"num_epochs = 10\n",
"# Train the from-scratch softmax-regression model with the custom sgd updater.\n",
"train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, updater)"
],
"id": "df3cceb72faee402",
"outputs": [
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001B[31m---------------------------------------------------------------------------\u001B[39m",
"\u001B[31mKeyboardInterrupt\u001B[39m Traceback (most recent call last)",
"\u001B[36mCell\u001B[39m\u001B[36m \u001B[39m\u001B[32mIn[72]\u001B[39m\u001B[32m, line 2\u001B[39m\n\u001B[32m 1\u001B[39m num_epochs = \u001B[32m10\u001B[39m\n\u001B[32m----> \u001B[39m\u001B[32m2\u001B[39m \u001B[43mtrain_ch3\u001B[49m\u001B[43m(\u001B[49m\u001B[43mnet\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mtrain_iter\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mtest_iter\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mcross_entropy\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mnum_epochs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mupdater\u001B[49m\u001B[43m)\u001B[49m\n",
"\u001B[36mCell\u001B[39m\u001B[36m \u001B[39m\u001B[32mIn[70]\u001B[39m\u001B[32m, line 6\u001B[39m, in \u001B[36mtrain_ch3\u001B[39m\u001B[34m(net, train_iter, test_iter, loss, num_epochs, updater)\u001B[39m\n\u001B[32m 3\u001B[39m animator = Animator(xlabel=\u001B[33m'\u001B[39m\u001B[33mepoch\u001B[39m\u001B[33m'\u001B[39m, xlim=[\u001B[32m1\u001B[39m, num_epochs], ylim=[\u001B[32m0.3\u001B[39m, \u001B[32m0.9\u001B[39m],\n\u001B[32m 4\u001B[39m legend=[\u001B[33m'\u001B[39m\u001B[33mtrain loss\u001B[39m\u001B[33m'\u001B[39m, \u001B[33m'\u001B[39m\u001B[33mtrain acc\u001B[39m\u001B[33m'\u001B[39m, \u001B[33m'\u001B[39m\u001B[33mtest acc\u001B[39m\u001B[33m'\u001B[39m])\n\u001B[32m 5\u001B[39m \u001B[38;5;28;01mfor\u001B[39;00m epoch \u001B[38;5;129;01min\u001B[39;00m \u001B[38;5;28mrange\u001B[39m(num_epochs):\n\u001B[32m----> \u001B[39m\u001B[32m6\u001B[39m train_metrics = \u001B[43mtrain_epoch_ch3\u001B[49m\u001B[43m(\u001B[49m\u001B[43mnet\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mtrain_iter\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mloss\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mupdater\u001B[49m\u001B[43m)\u001B[49m\n\u001B[32m 7\u001B[39m test_acc = evaluate_accuracy(net, test_iter)\n\u001B[32m 8\u001B[39m animator.add(epoch + \u001B[32m1\u001B[39m, train_metrics + (test_acc,))\n",
"\u001B[36mCell\u001B[39m\u001B[36m \u001B[39m\u001B[32mIn[68]\u001B[39m\u001B[32m, line 10\u001B[39m, in \u001B[36mtrain_epoch_ch3\u001B[39m\u001B[34m(net, train_iter, loss, updater)\u001B[39m\n\u001B[32m 7\u001B[39m metric = Accumulator(\u001B[32m3\u001B[39m)\n\u001B[32m 8\u001B[39m \u001B[38;5;28;01mfor\u001B[39;00m X, y \u001B[38;5;129;01min\u001B[39;00m train_iter:\n\u001B[32m 9\u001B[39m \u001B[38;5;66;03m# 计算梯度并更新参数\u001B[39;00m\n\u001B[32m---> \u001B[39m\u001B[32m10\u001B[39m y_hat = \u001B[43mnet\u001B[49m\u001B[43m(\u001B[49m\u001B[43mX\u001B[49m\u001B[43m)\u001B[49m\n\u001B[32m 11\u001B[39m l = loss(y_hat, y)\n\u001B[32m 12\u001B[39m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28misinstance\u001B[39m(updater, torch.optim.Optimizer):\n\u001B[32m 13\u001B[39m \u001B[38;5;66;03m# 使用PyTorch内置的优化器和损失函数\u001B[39;00m\n",
"\u001B[36mCell\u001B[39m\u001B[36m \u001B[39m\u001B[32mIn[61]\u001B[39m\u001B[32m, line 2\u001B[39m, in \u001B[36mnet\u001B[39m\u001B[34m(X)\u001B[39m\n\u001B[32m 1\u001B[39m \u001B[38;5;28;01mdef\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[34mnet\u001B[39m(X):\n\u001B[32m----> \u001B[39m\u001B[32m2\u001B[39m \u001B[38;5;28;01mreturn\u001B[39;00m softmax(\u001B[43mtorch\u001B[49m\u001B[43m.\u001B[49m\u001B[43mmatmul\u001B[49m\u001B[43m(\u001B[49m\u001B[43mX\u001B[49m\u001B[43m.\u001B[49m\u001B[43mreshape\u001B[49m\u001B[43m(\u001B[49m\u001B[43m(\u001B[49m\u001B[43m-\u001B[49m\u001B[32;43m1\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mW\u001B[49m\u001B[43m.\u001B[49m\u001B[43mshape\u001B[49m\u001B[43m[\u001B[49m\u001B[32;43m0\u001B[39;49m\u001B[43m]\u001B[49m\u001B[43m)\u001B[49m\u001B[43m)\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mW\u001B[49m\u001B[43m)\u001B[49m + b)\n",
"\u001B[31mKeyboardInterrupt\u001B[39m: "
]
},
{
"data": {
"text/plain": [
"<Figure size 350x250 with 1 Axes>"
],
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"238.965625pt\" height=\"183.35625pt\" viewBox=\"0 0 238.965625 183.35625\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n <metadata>\n <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n <cc:Work>\n <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n <dc:date>2026-03-15T14:41:54.863039</dc:date>\n <dc:format>image/svg+xml</dc:format>\n <dc:creator>\n <cc:Agent>\n <dc:title>Matplotlib v3.10.8, https://matplotlib.org/</dc:title>\n </cc:Agent>\n </dc:creator>\n </cc:Work>\n </rdf:RDF>\n </metadata>\n <defs>\n <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n </defs>\n <g id=\"figure_1\">\n <g id=\"patch_1\">\n <path d=\"M 0 183.35625 \nL 238.965625 183.35625 \nL 238.965625 0 \nL 0 0 \nz\n\"/>\n </g>\n <g id=\"axes_1\">\n <g id=\"patch_2\">\n <path d=\"M 30.103125 145.8 \nL 225.403125 145.8 \nL 225.403125 7.2 \nL 30.103125 7.2 \nz\n\"/>\n </g>\n <g id=\"matplotlib.axis_1\">\n <g id=\"xtick_1\">\n <g id=\"line2d_1\">\n <path d=\"M 51.803125 145.8 \nL 51.803125 7.2 \n\" clip-path=\"url(#p5a806478f0)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_2\">\n <defs>\n <path id=\"mc7bc3224be\" d=\"M 0 0 \nL 0 3.5 \n\" style=\"stroke: #ffffff; stroke-width: 0.8\"/>\n </defs>\n <g>\n <use xlink:href=\"#mc7bc3224be\" x=\"51.803125\" y=\"145.8\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_1\">\n <!-- 2 -->\n <g style=\"fill: #ffffff\" transform=\"translate(48.621875 160.398438) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-32\" d=\"M 1228 531 \nL 3431 531 \nL 3431 0 
\nL 469 0 \nL 469 531 \nQ 828 903 1448 1529 \nQ 2069 2156 2228 2338 \nQ 2531 2678 2651 2914 \nQ 2772 3150 2772 3378 \nQ 2772 3750 2511 3984 \nQ 2250 4219 1831 4219 \nQ 1534 4219 1204 4116 \nQ 875 4013 500 3803 \nL 500 4441 \nQ 881 4594 1212 4672 \nQ 1544 4750 1819 4750 \nQ 2544 4750 2975 4387 \nQ 3406 4025 3406 3419 \nQ 3406 3131 3298 2873 \nQ 3191 2616 2906 2266 \nQ 2828 2175 2409 1742 \nQ 1991 1309 1228 531 \nz\n\" transform=\"scale(0.015625)\"/>\n </defs>\n <use xlink:href=\"#DejaVuSans-32\"/>\n </g>\n </g>\n </g>\n <g id=\"xtick_2\">\n <g id=\"line2d_3\">\n <path d=\"M 95.203125 145.8 \nL 95.203125 7.2 \n\" clip-path=\"url(#p5a806478f0)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_4\">\n <g>\n <use xlink:href=\"#mc7bc3224be\" x=\"95.203125\" y=\"145.8\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_2\">\n <!-- 4 -->\n <g style=\"fill: #ffffff\" transform=\"translate(92.021875 160.398438) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-34\" d=\"M 2419 4116 \nL 825 1625 \nL 2419 1625 \nL 2419 4116 \nz\nM 2253 4666 \nL 3047 4666 \nL 3047 1625 \nL 3713 1625 \nL 3713 1100 \nL 3047 1100 \nL 3047 0 \nL 2419 0 \nL 2419 1100 \nL 313 1100 \nL 313 1709 \nL 2253 4666 \nz\n\" transform=\"scale(0.015625)\"/>\n </defs>\n <use xlink:href=\"#DejaVuSans-34\"/>\n </g>\n </g>\n </g>\n <g id=\"xtick_3\">\n <g id=\"line2d_5\">\n <path d=\"M 138.603125 145.8 \nL 138.603125 7.2 \n\" clip-path=\"url(#p5a806478f0)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_6\">\n <g>\n <use xlink:href=\"#mc7bc3224be\" x=\"138.603125\" y=\"145.8\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\
},
"metadata": {},
"output_type": "display_data",
"jetTransient": {
"display_id": null
}
}
],
"execution_count": 72
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-15T06:42:03.186617720Z",
"start_time": "2026-03-15T06:42:01.659427155Z"
}
},
"cell_type": "code",
"source": [
"def predict_ch3(net, test_iter, n=6):  #@save\n",
"    \"\"\"Show the first n images of one test batch with true and predicted labels.\n",
"\n",
"    Only the first batch of test_iter is used. Each title is the true label,\n",
"    a newline, then the label of net(X).argmax(axis=1). Images are reshaped\n",
"    to 28x28 for display. Nothing is shown if test_iter yields no batches.\n",
"    \"\"\"\n",
"    # Fetch exactly one batch; zip(test_iter, range(1)) would pull a second,\n",
"    # discarded batch before noticing range(1) is exhausted.\n",
"    batch = next(iter(test_iter), None)\n",
"    if batch is None:\n",
"        return\n",
"    X, y = batch\n",
"    n = min(n, len(X))  # the batch may hold fewer than n images\n",
"    trues = get_fashion_mnist_labels(y)\n",
"    preds = get_fashion_mnist_labels(net(X).argmax(axis=1))\n",
"    titles = [true + '\\n' + pred for true, pred in zip(trues, preds)]\n",
"    show_images(\n",
"        X[0:n].reshape((n, 28, 28)), 1, n, titles=titles[0:n])\n",
"\n",
"predict_ch3(net, test_iter)\n"
],
"id": "94f6177bfb40eece",
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 600x100 with 6 Axes>"
],
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"357.008839pt\" height=\"90.784071pt\" viewBox=\"0 0 357.008839 90.784071\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n <metadata>\n <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n <cc:Work>\n <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n <dc:date>2026-03-15T14:42:03.146128</dc:date>\n <dc:format>image/svg+xml</dc:format>\n <dc:creator>\n <cc:Agent>\n <dc:title>Matplotlib v3.10.8, https://matplotlib.org/</dc:title>\n </cc:Agent>\n </dc:creator>\n </cc:Work>\n </rdf:RDF>\n </metadata>\n <defs>\n <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n </defs>\n <g id=\"figure_1\">\n <g id=\"patch_1\">\n <path d=\"M 0 90.784071 \nL 357.008839 90.784071 \nL 357.008839 0 \nL 0 0 \nz\n\"/>\n </g>\n <g id=\"axes_1\">\n <g id=\"patch_2\">\n <path d=\"M 15.008839 83.584071 \nL 62.837411 83.584071 \nL 62.837411 35.7555 \nL 15.008839 35.7555 \nz\n\"/>\n </g>\n <g clip-path=\"url(#p3ffb63f14c)\">\n <image 
xlink:href=\"data:image/png;base64,\niVBORw0KGgoAAAANSUhEUgAAAEMAAABDCAYAAADHyrhzAAAOL0lEQVR4nO2a348cx3HHP1XdPbO7d0feURQlOZbj2EBiwHEegvgxL4H/g+TvDfKSxyCAYMBxbMN2Iiu0xCN5P/bHTHdX5aFn94bUiaINI1SALWC4u3M9M93f/lbVt2ooP5F/co4GgL7rCXyT7AjGzI5gzOwIxsyOYMzsCMbMjmDM7AjGzI5gzOwIxsyOYMzsCMbM4teOEHAREJm+T+cVRFqNJwLgyOGSu+/3mR8+2yhzwV3AQXwaYI7Y19SQCqQ2LxFHcECw/f39bhZ3d5L5D6Q6Uh38TWCIIF2HhEB5tCI/PsE6YXwgWAK9yMhZoY+F024girGKI0krC830Wl4BZQ6AuVBcuSpLRou82Kx4uVmiG6X/fUC3Rv/fV8TnG7walHI/qB9H+IcTwgN4srjhLO5Y156rvKS4si2J6spYAqVqe7ZNG1AVd1j9547TTzZQ/U1ggKSEdAl7/IDyvceUlbL5UKhLCB9vkCcj2g2E1ZouFB50xjI4p9E5CxnFD+xxb/tmKNWF0SI+RG5LT31xxvXLc+KLgP08Ea+MdFuR6wwU/CvA4EnEf7JCPoKLs1ue9JXLURiGnrEGtuOSWgPDmNjlhLtQrbGwFsVN0HDN6S8qmo0ofY8rDO935IuERiMuCiFAXAViAh5n9INbaq/oQ6V2UE8KrkZ15WZYsMXI20iisqSwIjcgxCdmTjsiYAFMlEEjVZRVl3lydkunztn3IG4qi3RD/P7AQ3Y88TXZhaelZ+uB29KzKR0P/sr47ntPOTktfLy44TztuAkDfxa3FAusS6Kasi2BoQR2FriqPdmVbekoNaDnDucrLEPUB2d4FDY/vuDqR2d0ZyMnT9Z0qXCy3LGIhT5sIOwwhJGAIdzWjp1Fcgk8XS/wqvg2QBFChjBMMSb6q67aGb40Ulf46NEVq8XIxcmGD0+vWT7JvPf9NRGDbFCdH3bX/P3yGWuP/PP6Mb8rS36zfsDt5oKPHn7OP/75T3ncbTkNSi+COVQE97YB7jC6kd25tAU/H8/ZWOJZOWNTe37/rff59Nsf4VmJvuqwKPiDiD8K2CpSTxMlKrmrRL0LQvtA5y4wKhTFc6DuIl4V2wa8KDULMk5gBPCJHQBujmulmjIMiYgTopOiUYOSg+LieGwePPQdu1XP4AHrAlIU7YWwEHQluAomQnZF0AkEn1yyuaaLoGKIO6qOMn26IQFEFVclbv7yMR5Bv2OcfHBLLYH15QqpwjicEosjBmIcHgBQi2JVMFeSt0lgd1kBF1Dw4C0bTXhYEiyDq3J5e8HLYGjcH5VuNSLBUHVEnZ+lkX/tt6galiqmRvbAg27LhsS/rL9PL4UkEPaAA9WFTekorvQh02lmY4nLsmS0wKZ2ZItsNj2yLWgWYr5YQnTkbEs6yfitkLcJHxW7FsIgDYw6PWhSJlIaQAjoXq3I3RjX9ttcmDLv9Om4KC6w3YU2JjoegTRljuANHHU+j5Xf7AqLWPjo9JplzAD0oTIS+O14juIEsSnFNyumXOcl2ZXTOLIMI3kCoboyWqCaUktEiiFZiN1/fAZRKIslpSwJFki5RdqwBS2N6pYmjbEHI0xbsNces4kcFjctdP43D+Bdy+sUQay5E353vwOoDqUErChFI78vkEJFskBWQqr0JwMSGvUBrAq1BMyEXekwF/qUSangCFUEEedBN7CMA5u+Z3eSoBNi99NPISjIx/j6HF8I3UPBZWKEQ1lA7aaF78GYTb6BcRcoLYElbwxJ84wCqDcmmiBbhTJdf3d5Gz8JplIUGxUE1tu+aZdNgE0gLAvp8Q6JdsA7j5Fxm3ATqO2h2le0MzQYMVb6WDjvdpylgbw06lnCshLFHMfQmx3h81tYBjwnPAqEANpWK3v6h5kKdcFjW/CeNT5Xqd7cac9f2Y/Zx5QZAD6NdwO
qTuAKXgU3aQI42N0zVBARqis6uavgmDdR58jheRKcEAzRFljdhTodRgvC6KRAxZz4X5eEp1ewWiCPHuJdory3wpYdtROsF0yF2gMqWARTp65g6CaQJjYwBVM10Cx3G70HKjTw9sDuk41XwceA6xSj9qhOQMRkSDRKbbGoJkeqYtCyw/7ZNHkeFhVRZ7kYWfSFasJQIjaJvm3tGD1gKniQOwUqu9yODBqW0Dv0LdhIFmpWNIBkaZNNgsQpbS3Bo2D7QDq50D7wynxR4i3jSEu5Mr+GxgLwxgaT5lbcLVDVkNBcjWmn937mPis7xBsr1FFph02+6iIUU7IpNvm9OMRDKnBvkT6P2MuXEBRZXxNiQBVCkAM997HDFXzV4Y9WeBco5wtsEcknSj4RPLRY49qoIXvG5HZOvAHr0bHQZtRSeAPBgzeKx+bvIVZUjbgcoc9UF7LpQV4zuYCmpgOsKIayzcq4MVygiqBqvIgrtiUh60S8LZCFKKoNLZ9yZ6l42TS0rl9JBPearJbohUGfyB9G6klg915bfO2hdj6jffvUsq+EwYLjQV7ZZTHHU3MniYZ0je4anKBOSJWgRi6Bsu0xo+3w5HMSDFywoWXFoel/JLR7aRBuhp5djSy3gdNthSJEN7tnhYd/7ljz+t/250rFNxvIEb4w9DoSb5X+ueJ9oLvo8U4ZzxNlFad0O92qgprMMlFjkyu
},
"metadata": {},
"output_type": "display_data",
"jetTransient": {
"display_id": null
}
}
],
"execution_count": 73
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-15T06:42:05.791711771Z",
"start_time": "2026-03-15T06:42:05.741066616Z"
}
},
"cell_type": "code",
"source": [
"batch_size = 256\n",
"train_iter, test_iter = load_data_fashion_mnist(batch_size)"
],
"id": "4a0bfd0479ec7386",
"outputs": [],
"execution_count": 74
},
{
"metadata": {},
"cell_type": "code",
"source": [
"net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))\n",
"\n",
"def init_weights(m):\n",
"    \"\"\"Re-initialize the weight of each nn.Linear module from a normal distribution with std 0.01.\"\"\"\n",
"    if isinstance(m, nn.Linear):\n",
"        nn.init.normal_(m.weight, std=0.01)\n",
"\n",
"net.apply(init_weights);\n",
"loss = nn.CrossEntropyLoss(reduction='none')\n",
"trainer = torch.optim.SGD(net.parameters(), lr=0.1)\n",
"num_epochs = 10\n",
"train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)"
],
"id": "b9808d88f5e6827b",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": "",
"id": "c25dd146307f58e0",
"outputs": [],
"execution_count": null
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}