nn/chapter1-6.ipynb

2542 lines
634 KiB
Text
Raw Normal View History

2026-03-12 07:50:33 +00:00
{
"cells": [
{
"cell_type": "code",
"id": "initial_id",
"metadata": {
"collapsed": true,
"ExecuteTime": {
"end_time": "2026-03-19T06:08:49.521494353Z",
"start_time": "2026-03-19T06:08:47.762419340Z"
2026-03-12 07:50:33 +00:00
}
},
"source": [
"import torch\n",
"import numpy\n",
2026-03-19 05:47:14 +00:00
"import pandas"
],
"outputs": [],
2026-03-14 11:51:56 +00:00
"execution_count": 1
2026-03-12 07:50:33 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:08:49.830469193Z",
"start_time": "2026-03-19T06:08:49.585762895Z"
}
},
2026-03-12 07:50:33 +00:00
"cell_type": "code",
"source": "torch.randn(size=(3, 4, 2))  # standard-normal samples, shape (3, 4, 2)",
"id": "3e141a42d342fa96",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[-1.0244, -0.4164],\n",
" [ 1.5765, -0.9106],\n",
" [-1.6388, -0.7727],\n",
" [-1.8594, -1.6634]],\n",
"\n",
" [[-0.3226, 0.6604],\n",
" [ 0.4341, -0.9600],\n",
" [ 0.2575, 2.0599],\n",
" [ 0.6960, 0.7095]],\n",
"\n",
" [[-0.0242, -0.5866],\n",
" [-0.8018, -0.3080],\n",
" [-1.3225, -0.0591],\n",
" [ 0.0322, 0.8251]]])"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 2
2026-03-14 11:51:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:08:49.974163555Z",
"start_time": "2026-03-19T06:08:49.860972205Z"
}
},
2026-03-14 11:51:56 +00:00
"cell_type": "code",
"source": [
"# Concatenate two 3x4 tensors along rows (dim=0) and along columns (dim=1).\n",
"X = torch.arange(12, dtype=torch.float32).reshape(3, 4)\n",
"Y = torch.tensor([[2.0, 1, 4, 3],\n",
"                  [1, 2, 3, 4],\n",
"                  [4, 3, 2, 1]])\n",
"torch.cat([X, Y], dim=0), torch.cat([X, Y], dim=1)"
],
"id": "8ae20ae68abbf32f",
"outputs": [
{
"data": {
"text/plain": [
"(tensor([[ 0., 1., 2., 3.],\n",
" [ 4., 5., 6., 7.],\n",
" [ 8., 9., 10., 11.],\n",
" [ 2., 1., 4., 3.],\n",
" [ 1., 2., 3., 4.],\n",
" [ 4., 3., 2., 1.]]),\n",
" tensor([[ 0., 1., 2., 3., 2., 1., 4., 3.],\n",
" [ 4., 5., 6., 7., 1., 2., 3., 4.],\n",
" [ 8., 9., 10., 11., 4., 3., 2., 1.]]))"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 3
2026-03-14 11:51:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:08:50.119655002Z",
"start_time": "2026-03-19T06:08:50.007109478Z"
}
},
2026-03-14 11:51:56 +00:00
"cell_type": "code",
"source": [
"# Broadcasting: a has shape (3, 1) and b has shape (1, 2); both are expanded to (3, 2) before adding.\n",
"a = torch.arange(3).reshape((3, 1))\n",
"b = torch.arange(2).reshape((1, 2))\n",
"# Show the operands together with the result (a bare `a, b` line in the middle of a cell is never displayed).\n",
"a, b, a + b"
],
"id": "2960a1ded2cdd5a4",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0, 1],\n",
" [1, 2],\n",
" [2, 3]])"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 4
2026-03-14 11:51:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:08:50.341778101Z",
"start_time": "2026-03-19T06:08:50.149819865Z"
}
},
2026-03-14 11:51:56 +00:00
"cell_type": "code",
"source": "X[-1], X[1:3]  # last row; rows 1 and 2 (the end index is exclusive)",
"id": "69c2ec23ab6ae97c",
"outputs": [
{
"data": {
"text/plain": [
"(tensor([ 8., 9., 10., 11.]),\n",
" tensor([[ 4., 5., 6., 7.],\n",
" [ 8., 9., 10., 11.]]))"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 5
2026-03-14 11:51:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:08:51.169775176Z",
"start_time": "2026-03-19T06:08:50.471288979Z"
}
},
2026-03-14 11:51:56 +00:00
"cell_type": "code",
"source": [
"# Round trip: convert the tensor to a NumPy array, then build a new tensor from that array.\n",
"A = X.numpy()\n",
"B = torch.tensor(A)\n",
"(type(A), type(B))"
],
"id": "b8d779a1bc7e4b1a",
"outputs": [
{
"data": {
"text/plain": [
"(numpy.ndarray, torch.Tensor)"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 6
2026-03-12 07:50:33 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:08:51.804601549Z",
"start_time": "2026-03-19T06:08:51.491609008Z"
}
},
2026-03-12 07:50:33 +00:00
"cell_type": "code",
2026-03-14 11:51:56 +00:00
"source": [
"import os\n",
"\n",
"# Write a tiny CSV with missing values (NA) into ../data for the pandas demo that follows.\n",
"data_dir = os.path.join(\"..\", \"data\")\n",
"os.makedirs(data_dir, exist_ok=True)\n",
"data_file = os.path.join(data_dir, \"data.csv\")\n",
"with open(data_file, \"w\") as f:\n",
"    f.write('NumRooms,Alley,Price\\n')  # header row (column names)\n",
"    f.write('NA,Pave,127500\\n')  # each following row is one data sample\n",
"    f.write('2,NA,106000\\n')\n",
"    f.write('4,NA,178100\\n')\n",
"    f.write('NA,NA,140000\\n')"
],
"id": "82be028b0f1dd1e3",
2026-03-12 07:50:33 +00:00
"outputs": [],
"execution_count": 7
2026-03-14 11:51:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:08:51.995470564Z",
"start_time": "2026-03-19T06:08:51.819209475Z"
}
},
2026-03-14 11:51:56 +00:00
"cell_type": "code",
"source": [
"import pandas as pd\n",
"\n",
"# Load the CSV written above; the literal 'NA' entries come back as NaN.\n",
"data = pd.read_csv(data_file)\n",
"print(data)"
],
"id": "ddd789a2656899d1",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" NumRooms Alley Price\n",
"0 NaN Pave 127500\n",
"1 2.0 NaN 106000\n",
"2 4.0 NaN 178100\n",
"3 NaN NaN 140000\n"
]
}
],
"execution_count": 8
2026-03-14 11:51:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:08:52.239162983Z",
"start_time": "2026-03-19T06:08:52.016744354Z"
}
},
2026-03-14 11:51:56 +00:00
"cell_type": "code",
"source": [
"# Split features (first two columns) from the target (third column).\n",
"inputs = data.iloc[:, 0:2]\n",
"outputs = data.iloc[:, 2]\n",
"# One-hot encode the categorical Alley column, with an extra indicator column for NaN.\n",
"inputs = pd.get_dummies(inputs, dummy_na=True)\n",
"print(inputs)\n",
"# Impute the remaining NaNs (in NumRooms) with each column's mean.\n",
"inputs = inputs.fillna(inputs.mean())\n",
"print(inputs)"
],
"id": "e98fcc3bd4f067cf",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" NumRooms Alley_Pave Alley_nan\n",
"0 NaN True False\n",
"1 2.0 False True\n",
"2 4.0 False True\n",
"3 NaN False True\n",
" NumRooms Alley_Pave Alley_nan\n",
"0 3.0 True False\n",
"1 2.0 False True\n",
"2 4.0 False True\n",
"3 3.0 False True\n"
]
}
],
"execution_count": 9
2026-03-14 11:51:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:08:52.439862351Z",
"start_time": "2026-03-19T06:08:52.252020096Z"
}
},
2026-03-14 11:51:56 +00:00
"cell_type": "code",
"source": [
"# Convert the cleaned features and the target into float64 tensors.\n",
"X, y = (torch.tensor(frame.to_numpy(dtype=float)) for frame in (inputs, outputs))\n",
"X, y"
],
"id": "8ff0f7b40f0e4996",
"outputs": [
{
"data": {
"text/plain": [
"(tensor([[3., 1., 0.],\n",
" [2., 0., 1.],\n",
" [4., 0., 1.],\n",
" [3., 0., 1.]], dtype=torch.float64),\n",
" tensor([127500., 106000., 178100., 140000.], dtype=torch.float64))"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 10
2026-03-14 11:51:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:08:52.613938268Z",
"start_time": "2026-03-19T06:08:52.458366888Z"
}
},
2026-03-14 11:51:56 +00:00
"cell_type": "code",
"source": [
"# A symmetric matrix: entry (i, j) equals entry (j, i).\n",
"B = torch.tensor([[1, 2, 3],\n",
"                  [2, 0, 4],\n",
"                  [3, 4, 5]])\n",
"B"
],
"id": "91a6e0da442b95a0",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[1, 2, 3],\n",
" [2, 0, 4],\n",
" [3, 4, 5]])"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 11
2026-03-14 11:51:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:08:52.927117918Z",
"start_time": "2026-03-19T06:08:52.661665190Z"
}
},
2026-03-14 11:51:56 +00:00
"cell_type": "code",
"source": "torch.eq(B, B.T)  # element-wise comparison with the transpose",
"id": "297e6a678fb19be7",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[True, True, True],\n",
" [True, True, True],\n",
" [True, True, True]])"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 12
2026-03-14 11:51:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:08:53.355336655Z",
"start_time": "2026-03-19T06:08:53.046923921Z"
}
},
2026-03-14 11:51:56 +00:00
"cell_type": "code",
"source": [
"# 24 consecutive integers arranged as 2 blocks of 3 rows x 4 columns.\n",
"X = torch.arange(24).reshape((2, 3, 4))\n",
"X"
],
"id": "24e864b336beb58b",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[ 0, 1, 2, 3],\n",
" [ 4, 5, 6, 7],\n",
" [ 8, 9, 10, 11]],\n",
"\n",
" [[12, 13, 14, 15],\n",
" [16, 17, 18, 19],\n",
" [20, 21, 22, 23]]])"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 13
2026-03-14 11:51:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:08:53.538512227Z",
"start_time": "2026-03-19T06:08:53.424545812Z"
}
},
2026-03-14 11:51:56 +00:00
"cell_type": "code",
"source": [
"A = torch.arange(20, dtype=torch.float32).reshape(5, 4)\n",
"# clone() allocates new memory, so B is an independent copy of A.\n",
"# (A plain assignment `B = A` would NOT copy: both names would refer to the same tensor object.)\n",
"B = A.clone()\n",
"# Different ids confirm that A and B are distinct objects.\n",
"id(A), id(B)"
],
"id": "ee0905479b1dbc2b",
"outputs": [
{
"data": {
"text/plain": [
"(140069082123952, 140069083998928)"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 14
2026-03-14 11:51:56 +00:00
},
{
"metadata": {},
"cell_type": "markdown",
"source": "Hadamard积（Hadamard product）：两个形状相同矩阵的按元素乘法",
"id": "136459f5efe765cf"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:08:53.817503675Z",
"start_time": "2026-03-19T06:08:53.581674028Z"
}
},
2026-03-14 11:51:56 +00:00
"cell_type": "code",
"source": "A * B  # element-wise (Hadamard) product",
"id": "f576b0df17cc0e98",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 0., 1., 4., 9.],\n",
" [ 16., 25., 36., 49.],\n",
" [ 64., 81., 100., 121.],\n",
" [144., 169., 196., 225.],\n",
" [256., 289., 324., 361.]])"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 15
2026-03-14 11:51:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:08:54.105858002Z",
"start_time": "2026-03-19T06:08:53.885857754Z"
}
},
2026-03-14 11:51:56 +00:00
"cell_type": "code",
"source": [
"# A scalar combined with a tensor is applied to every element; the shape is unchanged.\n",
"a = 2\n",
"X = torch.arange(24).reshape(2, 3, 4)\n",
"a + X, (a * X).shape"
],
"id": "b2373af1d7f2a45",
"outputs": [
{
"data": {
"text/plain": [
"(tensor([[[ 2, 3, 4, 5],\n",
" [ 6, 7, 8, 9],\n",
" [10, 11, 12, 13]],\n",
" \n",
" [[14, 15, 16, 17],\n",
" [18, 19, 20, 21],\n",
" [22, 23, 24, 25]]]),\n",
" torch.Size([2, 3, 4]))"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 16
2026-03-14 11:51:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:08:54.282899541Z",
"start_time": "2026-03-19T06:08:54.174420931Z"
}
},
2026-03-14 11:51:56 +00:00
"cell_type": "code",
"source": [
"print(A)\n",
"# Reduce over rows (axis 0) and over columns (axis 1).\n",
"A_sum_axis0, A_sum_axis1 = A.sum(axis=0), A.sum(axis=1)\n",
"A_sum_axis0, A_sum_axis1"
],
"id": "2b50246e1ca8a3bc",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[ 0., 1., 2., 3.],\n",
" [ 4., 5., 6., 7.],\n",
" [ 8., 9., 10., 11.],\n",
" [12., 13., 14., 15.],\n",
" [16., 17., 18., 19.]])\n"
]
},
{
"data": {
"text/plain": [
"(tensor([40., 45., 50., 55.]), tensor([ 6., 22., 38., 54., 70.]))"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 17
2026-03-14 11:51:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:08:54.345533969Z",
"start_time": "2026-03-19T06:08:54.290459990Z"
}
},
2026-03-14 11:51:56 +00:00
"cell_type": "code",
"source": [
"x = torch.arange(4, dtype=torch.float32)\n",
"torch.mv(A, x)  # matrix-vector product of the (5, 4) matrix A with the length-4 vector x"
],
"id": "3195464dfeb554ed",
"outputs": [
{
"data": {
"text/plain": [
"tensor([ 14., 38., 62., 86., 110.])"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 18
2026-03-14 11:51:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:08:54.432702369Z",
"start_time": "2026-03-19T06:08:54.361198167Z"
}
},
2026-03-14 11:51:56 +00:00
"cell_type": "code",
"source": [
"import functools\n",
"import time\n",
"\n",
"def showtime(func):\n",
"    \"\"\"Decorator that prints how long each call to `func` takes.\n",
"\n",
"    Positional and keyword arguments are forwarded unchanged, and the wrapped\n",
"    function's return value is passed through. functools.wraps preserves\n",
"    func's __name__ and docstring on the wrapper.\n",
"    \"\"\"\n",
"    @functools.wraps(func)\n",
"    def wrapper(*args, **kwargs):\n",
"        start = time.perf_counter()  # monotonic, high-resolution clock for measuring durations\n",
"        result = func(*args, **kwargs)  # run the original function\n",
"        end = time.perf_counter()\n",
"        print(f\"执行时间: {end - start:.6f}秒\")\n",
"        return result\n",
"    return wrapper  # the wrapper replaces the decorated function\n",
"\n",
"@showtime\n",
"def fun():\n",
"    print(\"I am silly\")\n",
"\n",
"fun()"
],
"id": "ebda8c74ead3e42b",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"I am silly\n",
"执行时间: 0.000109秒\n"
]
}
],
"execution_count": 19
2026-03-14 11:51:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:08:54.696923466Z",
"start_time": "2026-03-19T06:08:54.436724383Z"
}
},
2026-03-14 11:51:56 +00:00
"cell_type": "code",
"source": "torch.linalg.norm(torch.ones((4, 9)))  # Frobenius norm of a 4x9 all-ones matrix: sqrt(36) = 6",
"id": "3343cc0c01d0161c",
"outputs": [
{
"data": {
"text/plain": [
"tensor(6.)"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 20
2026-03-14 11:51:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:08:54.768902496Z",
"start_time": "2026-03-19T06:08:54.719453919Z"
}
},
2026-03-14 11:51:56 +00:00
"cell_type": "code",
"source": [
"# Track gradients with respect to x; until backward() runs, x.grad is None, so nothing is displayed.\n",
"x = torch.arange(4.0, requires_grad=True)\n",
"x.grad"
],
"id": "674e2416e9417cfe",
"outputs": [],
"execution_count": 21
2026-03-14 11:51:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:08:54.891549085Z",
"start_time": "2026-03-19T06:08:54.804248422Z"
}
},
2026-03-14 11:51:56 +00:00
"cell_type": "code",
"source": [
"# y = 2 * (x . x) is a scalar whose gradient with respect to x is 4x.\n",
"y = 2 * torch.dot(x, x)\n",
"y"
],
"id": "66c0febebcf98cde",
"outputs": [
{
"data": {
"text/plain": [
"tensor(28., grad_fn=<MulBackward0>)"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 22
2026-03-14 11:51:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:08:54.993310716Z",
"start_time": "2026-03-19T06:08:54.914195086Z"
}
},
2026-03-14 11:51:56 +00:00
"cell_type": "code",
"source": [
"y.backward()  # fills x.grad with dy/dx\n",
"x.grad"
],
"id": "825f2ce6c46ca4a8",
"outputs": [
{
"data": {
"text/plain": [
"tensor([ 0., 4., 8., 12.])"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 23
2026-03-14 11:51:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:08:55.184135089Z",
"start_time": "2026-03-19T06:08:54.995848170Z"
}
},
2026-03-14 11:51:56 +00:00
"cell_type": "code",
"source": [
"# Gradients accumulate across backward() calls, so reset them first.\n",
"x.grad.zero_()\n",
"y = x.sum()\n",
"y.backward()  # d(sum(x))/dx is 1 for every element\n",
"x.grad"
],
"id": "df399463515e9d3c",
"outputs": [
{
"data": {
"text/plain": [
"tensor([1., 1., 1., 1.])"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 24
2026-03-14 11:51:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:08:55.569785595Z",
"start_time": "2026-03-19T06:08:55.211139587Z"
}
},
2026-03-14 11:51:56 +00:00
"cell_type": "code",
"source": [
"# Calling backward() on a non-scalar requires a `gradient` argument (the gradient of the\n",
"# final result with respect to self). Here we only want the sum of the partial derivatives,\n",
"# so a gradient of all ones is appropriate; summing first achieves the same thing.\n",
"x.grad.zero_()\n",
"y = x * x\n",
"print(y)\n",
"y.sum().backward()  # equivalent to y.backward(torch.ones(len(x)))\n",
"x.grad"
],
"id": "f9207619bd4b3de8",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([0., 1., 4., 9.], grad_fn=<MulBackward0>)\n"
]
},
{
"data": {
"text/plain": [
"tensor([0., 2., 4., 6.])"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 25
2026-03-14 11:51:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:08:55.706654768Z",
"start_time": "2026-03-19T06:08:55.573574038Z"
}
},
2026-03-14 11:51:56 +00:00
"cell_type": "code",
"source": "torch.ones(x.shape[0])  # an all-ones vector with one entry per element of x",
"id": "409c14c230570859",
"outputs": [
{
"data": {
"text/plain": [
"tensor([1., 1., 1., 1.])"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 26
2026-03-14 11:51:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:08:55.928404498Z",
"start_time": "2026-03-19T06:08:55.717433023Z"
}
},
2026-03-14 11:51:56 +00:00
"cell_type": "code",
"source": [
"# detach() gives u the same values as y but cuts it off from the graph,\n",
"# so z = u * x is treated as (constant) * x and dz/dx equals u.\n",
"x.grad.zero_()\n",
"y = x * x\n",
"u = y.detach()\n",
"z = u * x\n",
"z.sum().backward()\n",
"x.grad == u"
],
"id": "521b948fe0683b12",
"outputs": [
{
"data": {
"text/plain": [
"tensor([True, True, True, True])"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 27
2026-03-14 11:51:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:08:56.100782679Z",
"start_time": "2026-03-19T06:08:55.980278337Z"
}
},
2026-03-14 11:51:56 +00:00
"cell_type": "code",
"source": [
"# y = x * x is still attached to the graph, so its gradient is the full 2x.\n",
"x.grad.zero_()\n",
"y.sum().backward()\n",
"x.grad == 2 * x"
],
"id": "b040beecf0632315",
"outputs": [
{
"data": {
"text/plain": [
"tensor([True, True, True, True])"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 28
2026-03-14 11:51:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:08:56.391506058Z",
"start_time": "2026-03-19T06:08:56.102643728Z"
}
},
2026-03-14 11:51:56 +00:00
"cell_type": "code",
"source": [
"from torch.distributions import multinomial\n",
"\n",
"# Equal weights for each face of a six-sided die.\n",
"fair_probs = torch.ones(6)\n",
"fair_probs"
],
"id": "4e6ec763dbea5aa3",
"outputs": [
{
"data": {
"text/plain": [
"tensor([1., 1., 1., 1., 1., 1.])"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 29
2026-03-14 11:51:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:08:56.646489725Z",
"start_time": "2026-03-19T06:08:56.454875071Z"
}
},
2026-03-14 11:51:56 +00:00
"cell_type": "code",
"source": "multinomial.Multinomial(total_count=1, probs=fair_probs).sample()  # one roll, as a one-hot count vector",
"id": "f12d5e85bc6ab595",
"outputs": [
{
"data": {
"text/plain": [
"tensor([0., 0., 0., 1., 0., 0.])"
]
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 30
2026-03-14 11:51:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:08:56.713816600Z",
"start_time": "2026-03-19T06:08:56.653507072Z"
}
},
2026-03-14 11:51:56 +00:00
"cell_type": "code",
"source": [
"# 500 groups of 10 rolls each; row i of cum_counts holds the running totals over groups 0..i.\n",
"counts = multinomial.Multinomial(total_count=10, probs=fair_probs).sample(sample_shape=(500,))\n",
"cum_counts = torch.cumsum(counts, dim=0)\n",
"cum_counts.size()"
],
"id": "b02f43376fd6f1fe",
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([500, 6])"
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 31
2026-03-14 11:51:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:08:57.183010802Z",
"start_time": "2026-03-19T06:08:56.761045787Z"
}
},
2026-03-14 11:51:56 +00:00
"cell_type": "code",
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"# Relative frequency of each face after each group of experiments.\n",
"estimates = cum_counts / cum_counts.sum(dim=1, keepdim=True)\n",
"num_faces = estimates.shape[1]\n",
"\n",
"fig, ax = plt.subplots(figsize=(6, 4.5))\n",
"for i in range(num_faces):\n",
"    ax.plot(estimates[:, i].numpy(), label=f\"P(die={i + 1})\")\n",
"\n",
"# Reference line at the exact fair-die probability 1/num_faces (rather than the rounded 0.167).\n",
"ax.axhline(y=1 / num_faces, color='black', linestyle='dashed', label='Theoretical probability')\n",
"ax.set_title('Estimated die probabilities vs. number of experiment groups')\n",
"ax.set_xlabel('Groups of experiments')\n",
"ax.set_ylabel('Estimated probability')\n",
"ax.legend()\n",
"plt.show()"
],
"id": "8b80daa4edd0b066",
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 600x450 with 1 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAiEAAAGZCAYAAABfZuECAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAADPyElEQVR4nOzdeVhU1f/A8fdszLDvi4CgAiou4L6Uu2ZqLmmlpX0ty/ppZbZZWmpqfeur5ZJmpWmopWWWuZR7mmYuKYYrKAqKguw7M8x6f38goyOCoixq5/U893mYe88999wB5n7mrDJAQhAEQRAEoYbJa7sAgiAIgiD8O4kgRBAEQRCEWiGCEEEQBEEQaoUIQgRBEARBqBUiCBEEQRAEoVaIIEQQBEEQhFohghBBEARBEGqFsrYLcLfy9/enoKCgtoshCIIgCPccZ2dnUlJSbppOBCE34O/vT3Jycm0XQxAEQRDuWQEBATcNREQQcgOlNSABAQGiNkQQBEEQKsHZ2Znk5ORben6KIKQCBQUFIggRBEEQhGoiOqYKgiAIglArRBAiCIIgCEKtEEGIIAiCIAi1QvQJEQRBABwcHPDy8kImk9V2UQThriVJEpmZmWi12irJTwQhgiD8q8lkMkaNGkW3bt1quyiCcM/4448/iIqKQpKkO8pHBCGCIPyrjRo1iq5du7J69Wri4uIwmUy1XSRBuGsplUoaN27M0KFDAfjmm2/uLL+qKJQgCMK9yNHRkW7durF69Wp+++232i6OINwTzp07B8CwYcP44Ycf7qhpRnRMFQThX8vT0xOAuLi4Wi6JINxbSv9nvLy87igfEYQIgvCvVdoJVTTBCELllP7P3GlH7rsiCHnppZdITExEp9Nx4MAB2rZte0vnDRs2DEmS+OWXX8ocmz59OikpKWi1WrZv305oaGhVF1sQBEEQhDtQ60HI0KFDmTNnDtOnT6dVq1YcPXqUrVu34u3tXeF5wcHBfPrpp+zZs6fMsbfffptXX32VMWPG0L59e4qKiti6dStqtbq6bkMQBEEQhEqq9SDkjTfe4Ouvv2bZsmXExsYyZswYtFotzz33XLnnyOVyVq5cyfvvv09CQkKZ46+99hoffvghGzZs4Pjx44wcORJ/f38effTRaryTiqkdHXB0c0WuVNRaGQRB+HdZsWIFkyZNqjBNYmIi48ePt76WJIlBgwZVd9HK5enpSVpaGgEBAbVWBqHm1GoQolKpaN26NTt27LDukySJHTt20LFjx3LPmzp1Kunp6TccGlS/fn3q1Kljk2d+fj4HDx4sN087OzucnZ1ttqr29vrvmfHnFuqEhVR53oIg/PuUztEgSRJ6vZ74+HimTJmCQlHyRSciIoJ+/foxf/78SuXr5+fH5s2bq6PIALzwwgvs2rWLvLw8JEnC1dXV5nhWVhYrVqxg+vTp1VYG4e5Rq0GIl5cXSqWStLQ0m/1paWn4+fnd8JwHH3yQ559/nhdeeOGGx0vPq0yekyZNIj8/37olJydX9lZuymw0AqBQilHRgiBUjc2bN+Pn50dYWBizZ89m2rRpTJgwAYBx48axZs0aioqKKpVnWloaBoOhOooLlMxMu2XLFj766KNy00RFRTFixAjc3d2rrRzC3aHWm2Mqw8nJiW+//ZYXXniBrKysKsv3448/xsXFxbpVRzWg2VjSk1ihUlV53oIgVC07e02tbJWl1+tJS0sjKSmJr776ih07djBw4EDkcjmPP/44GzdutEnv7e3Nhg0b0Gq1JCQkMHz48DJ5Xt8cExgYyOrVq8nJySErK4t169YRHBxc+Tf1is8++4yZM2dy4MCBctOcOnWKlJQUBg8efNvXEe4Ntfq1PDMzE5PJhK+vr81+X19fUlNTy6QPCQmhfv36Nv9YcnlJHGU0GmnUqJH1vOvz8PX1JSYm5oblMBgM1Rr5A5ivDGdSiiBEEO5qdvYaPv57V61ce1K77hh0xbd9vk6nw9PTk4iICN
zc3Dh8+LDN8WXLluHv70/37t0xGo3Mnz8fHx+fcvNTKpVs3bqV/fv307lzZ0wmE5MnT2bLli1ERERgNBoZPnw4ixYtqrBcffv2Ze/evZW6l7///pvOnTvf8Yycwt2tVoMQo9FIdHQ0PXv2ZP369UDJmOOePXvy+eefl0kfFxdHs2bNbPZ9+OGHODs7M378eC5evIjRaOTy5cv07NmTo0ePAuDs7Ez79u358ssvq/+mylFaEyIXzTGCIFSDnj178vDDD7NgwQKCg4MxmUykp6dbj4eFhdGvXz/atm1rDU6ef/75CidqGzZsGHK5nNGjR1v3jRo1itzcXLp168b27dvZsGEDBw8erLBst9PEnZKSQsuWLSt9nnBvqfUn4pw5c1i+fDmHDx/m77//5rXXXsPR0ZGoqCgAli9fTnJyMu+++y56vZ6TJ0/anJ+bmwtgs3/evHlMnjyZ+Ph4EhMT+eCDD0hJSWHdunU1dVtllPYJUapq/S0XBKECBl0xk9p1r7VrV0b//v0pKChApVIhl8tZtWoV06ZNY+DAgej1epu04eHh1i9+pU6fPk1OTk65+UdGRhIaGkpBQYHNfo1GQ0hICNu3b6ewsJDCwsJKlftW6HQ6HBwcqjxf4e5S60/EH3/8EW9vb2bMmIGfnx8xMTH06dPHGsEHBQVhsVgqleesWbNwdHRk8eLFuLm5sXfvXvr06VPmn7ImlTbHiJoQQbj73UmTSE3atWsXY8eOxWAwkJKSgtlsBkqauh0dHVGpVBivfAG6HU5OTkRHRzNixIgyxzIyMgCqrTnGw8PDeg3h/nVXPBEXLlzIwoULb3ise/eKv5GMGjXqhvvff/993n///TsuW1UpbY4RfUIEQagqRUVF1sXErlXa/61JkybWZum4uDjrtAilzTENGzascATKkSNHGDZsGOnp6WVqQ0pVV3NMs2bN+OOPPyp9nnBvuadGx9zLRE2IIAg1JTMzk+joaDp16mTdd+bMGTZv3syiRYto164drVq1YsmSJRWugLpy5UoyMzNZv349nTp1ol69enTt2pXPPvvMOoqwsLCQc+fOVbgVF1+tWfL19bU28wA0b96cyMhIm2DI3t6e1q1bs23btqp+a4S7jAhCaojoEyIIQk1asmRJmWaUUaNGkZKSwu7du1m7di2LFy+26bx6PZ1OR5cuXUhKSmLt2rXExsaydOlSNBoN+fn5t1WuMWPGEBMTw5IlSwD4888/iYmJYeDAgdY0gwYNIikpqdJNOMK9SRKb7ebs7CxJkiQ5OztXWZ7PzPlImn18v9Rx6OBavz+xiU1sJVtwcLC0YsUKKTg4uNbLUtWbRqORLly4IHXo0KHWy1LZbf/+/dJTTz1V6+UQW/lbRf87lXmGipqQGnK1JkT0CREEofoVFxczcuRIvLy8arsoleLp6cnatWv5/vvva7soQg0QbQM1xGwq6bUupm0XBKGm7N69u7aLUGlZWVl88skntV0MoYaImpAaYl07RtSECIIgCAIggpAaUzo6RqFU1HJJBEEQBOHuIIKQGmISNSGCIAiCYEMEITXEIvqECIIgCIINEYTUEFETIgiCIAi2RBBSQyylfULEZGWCIAiCAIggpMZYa0JEc4wgCIIgACIIqTGiJkQQhJq2YsUKJk2aVGGaxMRExo8fb30tSRKDBg2q7qKVKzw8nIsXL+Lg4FBrZRBqjghCaojJWBqEiD4hgiDcuaioKCRJQpIk9Ho98fHxTJkyBYWiZBqAiIgI+vXrx/z58yuVr5+fH5s3b66OIuPu7s78+fOJi4tDq9Vy4cIFPvvsM1xcXKxpYmNjOXDgAG+88Ua1lEG4u4ggpIZYa0JEc4wgCFVk8+bN+Pn5ERYWxuzZs5k2bRoTJkwAYNy4caxZs4aioqJK5ZmWlobBYKiO4uLv74+/vz9vvfUWzZo149lnn6VPnz4sXbrUJl1UVBRjx461BlTC/UsEITXk6ugYEYQIwt3OwUFdK1tl6fV60tLSSE
pK4quvvmLHjh0MHDgQuVzO448/zsaNG23Se3t7s2HDBrRaLQkJCQwfPrxMntc3xwQGBrJ69WpycnLIyspi3bp1BAc
},
"metadata": {},
"output_type": "display_data",
"jetTransient": {
"display_id": null
}
}
],
"execution_count": 32
2026-03-14 11:51:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:08:57.312713568Z",
"start_time": "2026-03-19T06:08:57.273001160Z"
}
},
2026-03-14 11:51:56 +00:00
"cell_type": "code",
"source": [
"import time\n",
"\n",
"import numpy as np\n",
"\n",
"class Timer:\n",
"    \"\"\"Record the durations of multiple runs.\n",
"\n",
"    The timer starts on construction; each stop() appends the seconds elapsed\n",
"    since the most recent start().\n",
"    \"\"\"\n",
"    def __init__(self):\n",
"        self.times = []\n",
"        self.start()\n",
"\n",
"    def start(self):\n",
"        \"\"\"Start (or restart) the timer.\"\"\"\n",
"        self.tik = time.perf_counter()  # monotonic clock, suited to measuring intervals\n",
"\n",
"    def stop(self):\n",
"        \"\"\"Stop the timer, record the elapsed seconds and return them.\"\"\"\n",
"        self.times.append(time.perf_counter() - self.tik)\n",
"        return self.times[-1]\n",
"\n",
"    def avg(self):\n",
"        \"\"\"Return the average recorded time, or 0.0 if nothing has been recorded yet.\"\"\"\n",
"        return sum(self.times) / len(self.times) if self.times else 0.0\n",
"\n",
"    def sum(self):\n",
"        \"\"\"Return the total of all recorded times.\"\"\"\n",
"        return sum(self.times)\n",
"\n",
"    def cumsum(self):\n",
"        \"\"\"Return the running (cumulative) totals of the recorded times as a list.\"\"\"\n",
"        return np.array(self.times).cumsum().tolist()"
],
"id": "4bdbb4999907154a",
"outputs": [],
"execution_count": 33
2026-03-14 11:51:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:08:57.504357761Z",
"start_time": "2026-03-19T06:08:57.314864132Z"
}
},
2026-03-14 11:51:56 +00:00
"cell_type": "code",
"source": [
"# Baseline: element-by-element addition in a Python loop.\n",
"n = 10000\n",
"a, b = torch.ones([n]), torch.ones([n])\n",
"c = torch.zeros(n)\n",
"timer = Timer()\n",
"for i in range(n):\n",
"    c[i] = a[i] + b[i]\n",
"f'{timer.stop():.5f} sec'"
],
"id": "c6f71622e2cc578a",
"outputs": [
{
"data": {
"text/plain": [
"'0.11861 sec'"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 34
2026-03-14 11:51:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:08:57.608980105Z",
"start_time": "2026-03-19T06:08:57.525327Z"
}
},
2026-03-14 11:51:56 +00:00
"cell_type": "code",
"source": [
"# Vectorised addition: one tensor operation instead of a Python loop.\n",
"timer.start()\n",
"d = a + b\n",
"f'{timer.stop():.5f} sec'"
],
"id": "2578c79b1214a79f",
"outputs": [
{
"data": {
"text/plain": [
"'0.00074 sec'"
]
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 35
2026-03-14 11:51:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:08:57.665429699Z",
"start_time": "2026-03-19T06:08:57.610507099Z"
}
},
2026-03-14 11:51:56 +00:00
"cell_type": "code",
"source": [
"import math\n",
"\n",
"def normal(x, mu, sigma):\n",
"    \"\"\"Normal (Gaussian) probability density with mean mu and std sigma, evaluated at x.\"\"\"\n",
"    variance = sigma**2\n",
"    normaliser = 1 / math.sqrt(2 * math.pi * variance)\n",
"    return normaliser * np.exp(-0.5 / variance * (x - mu)**2)"
],
"id": "fd17fdbe38a5f79",
"outputs": [],
"execution_count": 36
2026-03-14 11:51:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:08:57.729971931Z",
"start_time": "2026-03-19T06:08:57.669321828Z"
}
},
2026-03-14 11:51:56 +00:00
"cell_type": "code",
"source": [
"import matplotlib.pyplot as plt\n",
"from matplotlib_inline import backend_inline\n",
"\n",
"def use_svg_display():  #@save\n",
"    \"\"\"Render matplotlib figures as SVG in Jupyter.\"\"\"\n",
"    backend_inline.set_matplotlib_formats('svg')\n",
"\n",
"def set_figsize(figsize=(3.5, 2.5)):  #@save\n",
"    \"\"\"Switch to SVG output and set the default matplotlib figure size.\"\"\"\n",
"    use_svg_display()\n",
"    plt.rcParams['figure.figsize'] = figsize\n",
"\n",
"def set_axes(axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend):\n",
"    \"\"\"Configure the labels, scales, limits, legend and grid of a matplotlib Axes.\"\"\"\n",
"    axes.set_xlabel(xlabel)\n",
"    axes.set_ylabel(ylabel)\n",
"    axes.set_xscale(xscale)\n",
"    axes.set_yscale(yscale)\n",
"    axes.set_xlim(xlim)\n",
"    axes.set_ylim(ylim)\n",
"    if legend:\n",
"        axes.legend(legend)\n",
"    axes.grid()\n",
"\n",
"def plot(X, Y=None, xlabel=None, ylabel=None, legend=None, xlim=None,\n",
"         ylim=None, xscale='linear', yscale='linear',\n",
"         fmts=('-', 'm--', 'g-.', 'r:'), figsize=(3.5, 2.5), axes=None):\n",
"    \"\"\"Plot one or more data series on a single Axes.\n",
"\n",
"    X and Y may each be a single 1-D series or a list of series. If Y is None,\n",
"    X is treated as the y-values and plotted against their indices.\n",
"    Series are paired with the format strings in `fmts`; series beyond\n",
"    len(fmts) are not drawn because zip() stops at the shortest input.\n",
"    \"\"\"\n",
"    if legend is None:\n",
"        legend = []\n",
"    set_figsize(figsize)\n",
"    axes = axes if axes else plt.gca()\n",
"\n",
"    def has_one_axis(X):\n",
"        \"\"\"Return True if X is a single 1-D series (a 1-D array/tensor or a flat list).\"\"\"\n",
"        if hasattr(X, \"ndim\"):\n",
"            return X.ndim == 1\n",
"        # Treat an empty list as one empty series instead of failing with IndexError on X[0].\n",
"        return isinstance(X, list) and (len(X) == 0 or not hasattr(X[0], \"__len__\"))\n",
"\n",
"    if has_one_axis(X):\n",
"        X = [X]\n",
"    if Y is None:\n",
"        X, Y = [[]] * len(X), X  # an empty x series means: plot y against its index\n",
"    elif has_one_axis(Y):\n",
"        Y = [Y]\n",
"    if len(X) != len(Y):\n",
"        X = X * len(Y)  # repeat the x series so each y series has one to pair with\n",
"    axes.cla()\n",
"    for x, y, fmt in zip(X, Y, fmts):\n",
"        if len(x):\n",
"            axes.plot(x, y, fmt)\n",
"        else:\n",
"            axes.plot(y, fmt)\n",
"    set_axes(axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend)"
],
"id": "82158a69cba14da0",
"outputs": [],
"execution_count": 37
2026-03-14 11:51:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:08:58.000767154Z",
"start_time": "2026-03-19T06:08:57.732355844Z"
}
},
2026-03-14 11:51:56 +00:00
"cell_type": "code",
"source": [
"# Visualise several normal densities with NumPy.\n",
"x = np.arange(-7, 7, 0.01)\n",
"params = [(0, 1), (0, 2), (3, 1)]  # (mean, standard deviation) pairs\n",
"plot(x,\n",
"     [normal(x, mu, sigma) for mu, sigma in params],\n",
"     xlabel='x', ylabel='p(x)', figsize=(4.5, 2.5),\n",
"     legend=[f'mean {mu}, std {sigma}' for mu, sigma in params])"
],
"id": "f69ac10ebc3d13d8",
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 450x250 with 1 Axes>"
],
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"302.08125pt\" height=\"183.35625pt\" viewBox=\"0 0 302.08125 183.35625\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n <metadata>\n <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n <cc:Work>\n <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n <dc:date>2026-03-19T14:08:57.923771</dc:date>\n <dc:format>image/svg+xml</dc:format>\n <dc:creator>\n <cc:Agent>\n <dc:title>Matplotlib v3.7.2, https://matplotlib.org/</dc:title>\n </cc:Agent>\n </dc:creator>\n </cc:Work>\n </rdf:RDF>\n </metadata>\n <defs>\n <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n </defs>\n <g id=\"figure_1\">\n <g id=\"patch_1\">\n <path d=\"M 0 183.35625 \nL 302.08125 183.35625 \nL 302.08125 0 \nL 0 0 \nz\n\"/>\n </g>\n <g id=\"axes_1\">\n <g id=\"patch_2\">\n <path d=\"M 43.78125 145.8 \nL 294.88125 145.8 \nL 294.88125 7.2 \nL 43.78125 7.2 \nz\n\"/>\n </g>\n <g id=\"matplotlib.axis_1\">\n <g id=\"xtick_1\">\n <g id=\"line2d_1\">\n <path d=\"M 71.511736 145.8 \nL 71.511736 7.2 \n\" clip-path=\"url(#p82bac51bd8)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_2\">\n <defs>\n <path id=\"m55ef5c2f68\" d=\"M 0 0 \nL 0 3.5 \n\" style=\"stroke: #ffffff; stroke-width: 0.8\"/>\n </defs>\n <g>\n <use xlink:href=\"#m55ef5c2f68\" x=\"71.511736\" y=\"145.8\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_1\">\n <!-- 6 -->\n <g style=\"fill: #ffffff\" transform=\"translate(64.140642 160.398438) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-2212\" d=\"M 678 2272 \nL 4684 2272 \nL 4684 1741 \nL 
678 1741 \nL 678 2272 \nz\n\" transform=\"scale(0.015625)\"/>\n <path id=\"DejaVuSans-36\" d=\"M 2113 2584 \nQ 1688 2584 1439 2293 \nQ 1191 2003 1191 1497 \nQ 1191 994 1439 701 \nQ 1688 409 2113 409 \nQ 2538 409 2786 701 \nQ 3034 994 3034 1497 \nQ 3034 2003 2786 2293 \nQ 2538 2584 2113 2584 \nz\nM 3366 4563 \nL 3366 3988 \nQ 3128 4100 2886 4159 \nQ 2644 4219 2406 4219 \nQ 1781 4219 1451 3797 \nQ 1122 3375 1075 2522 \nQ 1259 2794 1537 2939 \nQ 1816 3084 2150 3084 \nQ 2853 3084 3261 2657 \nQ 3669 2231 3669 1497 \nQ 3669 778 3244 343 \nQ 2819 -91 2113 -91 \nQ 1303 -91 875 529 \nQ 447 1150 447 2328 \nQ 447 3434 972 4092 \nQ 1497 4750 2381 4750 \nQ 2619 4750 2861 4703 \nQ 3103 4656 3366 4563 \nz\n\" transform=\"scale(0.015625)\"/>\n </defs>\n <use xlink:href=\"#DejaVuSans-2212\"/>\n <use xlink:href=\"#DejaVuSans-36\" x=\"83.789062\"/>\n </g>\n </g>\n </g>\n <g id=\"xtick_2\">\n <g id=\"line2d_3\">\n <path d=\"M 104.145435 145.8 \nL 104.145435 7.2 \n\" clip-path=\"url(#p82bac51bd8)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_4\">\n <g>\n <use xlink:href=\"#m55ef5c2f68\" x=\"104.145435\" y=\"145.8\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_2\">\n <!-- 4 -->\n <g style=\"fill: #ffffff\" transform=\"translate(96.774342 160.398438) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-34\" d=\"M 2419 4116 \nL 825 1625 \nL 2419 1625 \nL 2419 4116 \nz\nM 2253 4666 \nL 3047 4666 \nL 3047 1625 \nL 3713 1625 \nL 3713 1100 \nL 3047 1100 \nL 3047 0 \nL 2419 0 \nL 2419 1100 \nL 313 1100 \nL 313 1709 \nL 2253 4666 \nz\n\" transform=\"scale(0.015625)\"/>\n </defs>\n <use xlink:href=\"#DejaVuSans-2212\"/>\n <use xlink:href=\"#DejaVuSans-34\" x=\"83.789062\"/>\n </g>\n </g>\n </g>\n <g id=\"xt
},
"metadata": {},
"output_type": "display_data",
"jetTransient": {
"display_id": null
}
}
],
"execution_count": 38
2026-03-14 11:51:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:08:58.045153274Z",
"start_time": "2026-03-19T06:08:58.022867228Z"
}
},
2026-03-14 11:51:56 +00:00
"cell_type": "code",
"source": [
"def synthetic_data(w, b, num_examples):  #@save\n",
"    \"\"\"Generate a synthetic linear-regression dataset y = Xw + b + noise.\n",
"\n",
"    Features are drawn from N(0, 1) with shape (num_examples, len(w)), and\n",
"    Gaussian noise with standard deviation 0.01 is added to the targets.\n",
"    Returns (X, y), with y reshaped into a column of shape (num_examples, 1).\n",
"    \"\"\"\n",
"    X = torch.normal(0, 1, (num_examples, len(w)))\n",
"    # matmul of a matrix with a 1-D vector needs no transpose; the result is 1-D.\n",
"    clean = torch.matmul(X, w) + b\n",
"    noisy = clean + torch.normal(0, 0.01, clean.shape)\n",
"    return X, noisy.reshape((-1, 1))"
],
"id": "7ed837bdd2b3a26d",
"outputs": [],
"execution_count": 39
2026-03-14 11:51:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:08:58.199486802Z",
"start_time": "2026-03-19T06:08:58.048763885Z"
}
},
2026-03-14 11:51:56 +00:00
"cell_type": "code",
"source": [
"# Ground-truth parameters the models below should recover\n",
"true_w, true_b = torch.tensor([2, -3.4]), 4.2\n",
"features, labels = synthetic_data(true_w, true_b, 1000)"
],
"id": "5ec2e204a6fd5cb2",
"outputs": [],
"execution_count": 40
2026-03-14 11:51:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:08:58.497612094Z",
"start_time": "2026-03-19T06:08:58.215279467Z"
}
},
2026-03-14 11:51:56 +00:00
"cell_type": "code",
"source": [
"set_figsize()\n",
"# Labels against the second feature; the trailing ';' hides the PathCollection repr\n",
"plt.scatter(features[:, 1].detach().numpy(), labels.detach().numpy(), 1);"
],
"id": "38213d46b3d9900d",
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.collections.PathCollection at 0x7f645fd37d50>"
]
},
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"text/plain": [
"<Figure size 350x250 with 1 Axes>"
],
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"229.425pt\" height=\"169.678125pt\" viewBox=\"0 0 229.425 169.678125\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n <metadata>\n <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n <cc:Work>\n <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n <dc:date>2026-03-19T14:08:58.330374</dc:date>\n <dc:format>image/svg+xml</dc:format>\n <dc:creator>\n <cc:Agent>\n <dc:title>Matplotlib v3.7.2, https://matplotlib.org/</dc:title>\n </cc:Agent>\n </dc:creator>\n </cc:Work>\n </rdf:RDF>\n </metadata>\n <defs>\n <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n </defs>\n <g id=\"figure_1\">\n <g id=\"patch_1\">\n <path d=\"M 0 169.678125 \nL 229.425 169.678125 \nL 229.425 0 \nL 0 0 \nz\n\"/>\n </g>\n <g id=\"axes_1\">\n <g id=\"patch_2\">\n <path d=\"M 26.925 145.8 \nL 222.225 145.8 \nL 222.225 7.2 \nL 26.925 7.2 \nz\n\"/>\n </g>\n <g id=\"PathCollection_1\">\n <defs>\n <path id=\"m6048dcfac9\" d=\"M 0 0.5 \nC 0.132602 0.5 0.25979 0.447317 0.353553 0.353553 \nC 0.447317 0.25979 0.5 0.132602 0.5 0 \nC 0.5 -0.132602 0.447317 -0.25979 0.353553 -0.353553 \nC 0.25979 -0.447317 0.132602 -0.5 0 -0.5 \nC -0.132602 -0.5 -0.25979 -0.447317 -0.353553 -0.353553 \nC -0.447317 -0.25979 -0.5 -0.132602 -0.5 0 \nC -0.5 0.132602 -0.447317 0.25979 -0.353553 0.353553 \nC -0.25979 0.447317 -0.132602 0.5 0 0.5 \nz\n\" style=\"stroke: #8dd3c7\"/>\n </defs>\n <g clip-path=\"url(#p567f518a09)\">\n <use xlink:href=\"#m6048dcfac9\" x=\"148.601687\" y=\"80.409012\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#m6048dcfac9\" x=\"130.05288\" y=\"67.476586\" style=\"fill: #8dd3c7; 
stroke: #8dd3c7\"/>\n <use xlink:href=\"#m6048dcfac9\" x=\"127.409953\" y=\"74.452727\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#m6048dcfac9\" x=\"143.990382\" y=\"72.86308\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#m6048dcfac9\" x=\"130.051176\" y=\"78.998008\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#m6048dcfac9\" x=\"127.778786\" y=\"80.782732\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#m6048dcfac9\" x=\"150.806525\" y=\"101.991179\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#m6048dcfac9\" x=\"152.645613\" y=\"90.59075\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#m6048dcfac9\" x=\"171.89944\" y=\"91.904379\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#m6048dcfac9\" x=\"163.638463\" y=\"83.919719\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#m6048dcfac9\" x=\"138.020287\" y=\"93.843834\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#m6048dcfac9\" x=\"148.93979\" y=\"77.746831\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#m6048dcfac9\" x=\"178.192872\" y=\"115.514121\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#m6048dcfac9\" x=\"123.237041\" y=\"92.408873\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#m6048dcfac9\" x=\"118.890825\" y=\"64.442598\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#m6048dcfac9\" x=\"161.626012\" y=\"102.937154\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#m6048dcfac9\" x=\"128.424616\" y=\"78.019978\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#m6048dcfac9\" x=\"134.645774\" y=\"72.883474\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#m6048dcfac9\" x=\"132.180296\" y=\"89.076328\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#m6048dcfac9\" x=\"113.72
},
"metadata": {},
"output_type": "display_data",
"jetTransient": {
"display_id": null
}
}
],
"execution_count": 41
2026-03-14 11:51:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:08:58.613128054Z",
"start_time": "2026-03-19T06:08:58.566128920Z"
}
},
2026-03-14 11:51:56 +00:00
"cell_type": "code",
"source": [
"# Parameters of the from-scratch linear regression model\n",
"w = torch.normal(0, 0.01, size=(2, 1), requires_grad=True)\n",
"b = torch.zeros(1, requires_grad=True)\n",
"\n",
"def linreg(X, w, b):\n",
"    \"\"\"Linear model X @ w + b.\"\"\"\n",
"    return X @ w + b\n",
"\n",
"def squared_loss(y_hat, y):\n",
"    \"\"\"Half squared error per example; y is reshaped to y_hat's shape first.\"\"\"\n",
"    residual = y_hat - y.reshape(y_hat.shape)\n",
"    return residual ** 2 / 2\n",
"\n",
"def sgd(params, lr, batch_size):\n",
"    \"\"\"In-place minibatch SGD step, then zero each parameter's gradient.\n",
"\n",
"    Expects gradients summed over the minibatch; dividing by batch_size averages them.\n",
"    \"\"\"\n",
"    with torch.no_grad():\n",
"        for param in params:\n",
"            param -= lr * param.grad / batch_size\n",
"            param.grad.zero_()\n",
"\n",
"lr, num_epochs = 0.03, 20\n",
"net, loss = linreg, squared_loss"
],
"id": "12166e1bc3ddd695",
"outputs": [],
"execution_count": 42
2026-03-14 11:51:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:08:58.733166644Z",
"start_time": "2026-03-19T06:08:58.676875700Z"
}
},
2026-03-14 11:51:56 +00:00
"cell_type": "code",
"source": [
"import random\n",
"\n",
"def data_iter(batch_size, features, labels):\n",
"    \"\"\"Yield shuffled (features, labels) minibatches; the final one may be smaller.\"\"\"\n",
"    order = list(range(len(features)))\n",
"    random.shuffle(order)  # visit the examples in a random order on every call\n",
"    for start in range(0, len(order), batch_size):\n",
"        batch_indices = torch.tensor(order[start:start + batch_size])\n",
"        yield features[batch_indices], labels[batch_indices]"
],
"id": "f3b7ee9f326bc687",
"outputs": [],
"execution_count": 43
2026-03-14 11:51:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:08:58.810488693Z",
"start_time": "2026-03-19T06:08:58.736278875Z"
}
},
2026-03-14 11:51:56 +00:00
"cell_type": "code",
"source": [
"batch_size = 10\n",
"# Peek at a single minibatch to check that features and labels line up\n",
"X, y = next(data_iter(batch_size, features, labels))\n",
"print(X, '\\n', y)"
],
"id": "f386e12d65afff2e",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[ 0.0169, -0.3729],\n",
" [ 0.2105, -1.0088],\n",
" [-1.3548, 0.9556],\n",
" [-0.0298, 0.4827],\n",
" [ 1.5137, -2.4433],\n",
" [ 0.0029, 0.6444],\n",
" [ 0.5705, 1.1589],\n",
" [ 0.3421, -0.5686],\n",
" [-0.8094, -0.8650],\n",
" [ 0.3897, -1.3542]]) \n",
" tensor([[ 5.5058],\n",
" [ 8.0436],\n",
" [-1.7624],\n",
" [ 2.4937],\n",
" [15.5411],\n",
" [ 2.0171],\n",
" [ 1.3941],\n",
" [ 6.8340],\n",
" [ 5.5124],\n",
" [ 9.5840]])\n"
]
}
],
"execution_count": 44
2026-03-14 11:51:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:08:59.082220079Z",
"start_time": "2026-03-19T06:08:58.813592858Z"
}
},
2026-03-14 11:51:56 +00:00
"cell_type": "code",
"source": [
"for epoch in range(num_epochs):\n",
"    for X, y in data_iter(batch_size, features, labels):\n",
"        l = loss(net(X, w, b), y)\n",
"        l.sum().backward()\n",
"        # Average over the actual batch length: the final minibatch can be smaller than batch_size\n",
"        sgd([w, b], lr, len(y))\n",
"    with torch.no_grad():\n",
"        train_l = loss(net(features, w, b), labels)\n",
"        print(f'epoch {epoch+1}, train loss: {float(train_l.mean()):3f}')"
],
"id": "8888ab6adcec36f1",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 1, train loss: 0.040479\n",
"epoch 2, train loss: 0.000158\n",
"epoch 3, train loss: 0.000052\n",
"epoch 4, train loss: 0.000052\n",
"epoch 5, train loss: 0.000052\n",
"epoch 6, train loss: 0.000052\n",
"epoch 7, train loss: 0.000052\n",
"epoch 8, train loss: 0.000052\n",
"epoch 9, train loss: 0.000052\n",
"epoch 10, train loss: 0.000052\n",
"epoch 11, train loss: 0.000052\n",
"epoch 12, train loss: 0.000052\n",
"epoch 13, train loss: 0.000052\n",
"epoch 14, train loss: 0.000052\n",
"epoch 15, train loss: 0.000052\n",
"epoch 16, train loss: 0.000052\n",
"epoch 17, train loss: 0.000052\n",
"epoch 18, train loss: 0.000052\n",
"epoch 19, train loss: 0.000052\n",
"epoch 20, train loss: 0.000052\n"
]
}
],
"execution_count": 45
2026-03-14 11:51:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:08:59.178712680Z",
"start_time": "2026-03-19T06:08:59.105484540Z"
}
},
2026-03-14 11:51:56 +00:00
"cell_type": "code",
"source": [
"# Compare outside autograd so the printout is not cluttered with grad_fn\n",
"with torch.no_grad():\n",
"    print(f'w的估计误差: {true_w - w.reshape(true_w.shape)}')\n",
"    print(f'b的估计误差: {true_b - b}')"
],
"id": "8199439fa7f26309",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"w的估计误差: tensor([ 0.0010, -0.0003], grad_fn=<SubBackward0>)\n",
"b的估计误差: tensor([0.0002], grad_fn=<RsubBackward1>)\n"
]
}
],
"execution_count": 46
2026-03-14 11:51:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:08:59.278946646Z",
"start_time": "2026-03-19T06:08:59.180914818Z"
}
},
2026-03-14 11:51:56 +00:00
"cell_type": "code",
"source": [
"from torch.utils import data\n",
"\n",
"# Fresh synthetic dataset for the concise nn.Module implementation\n",
"true_w = torch.tensor([2, -3.4])\n",
"true_b = 4.2\n",
"features, labels = synthetic_data(true_w, true_b, 1000)\n",
"\n",
"def load_array(data_arrays, batch_size, is_train=True):\n",
"    \"\"\"Wrap tensors in a DataLoader; batches are shuffled only when is_train is true.\"\"\"\n",
"    return data.DataLoader(data.TensorDataset(*data_arrays), batch_size, shuffle=is_train)\n",
"\n",
"batch_size = 10\n",
"# NOTE: this rebinds data_iter (the generator function defined earlier) to a DataLoader\n",
"data_iter = load_array((features, labels), batch_size)"
],
"id": "560d537dcbb5a335",
"outputs": [],
"execution_count": 47
2026-03-14 11:51:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:08:59.328289201Z",
"start_time": "2026-03-19T06:08:59.284666591Z"
}
},
2026-03-14 11:51:56 +00:00
"cell_type": "code",
"source": [
"from torch import nn\n",
"net = nn.Sequential(nn.Linear(2, 1))\n",
"# Initialise via nn.init rather than mutating .data directly; ';' suppresses the returned tensor\n",
"nn.init.normal_(net[0].weight, mean=0, std=0.001)\n",
"nn.init.zeros_(net[0].bias);"
],
"id": "c54fe059d6fd20de",
"outputs": [
{
"data": {
"text/plain": [
"tensor([0.])"
]
},
"execution_count": 48,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 48
2026-03-14 11:51:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:09:00.631251025Z",
"start_time": "2026-03-19T06:08:59.342163794Z"
}
},
2026-03-14 11:51:56 +00:00
"cell_type": "code",
"source": [
"loss = nn.MSELoss()\n",
"trainer = torch.optim.SGD(net.parameters(), lr=0.01)\n",
"num_epochs = 3\n",
"for epoch in range(num_epochs):\n",
"    for X, y in data_iter:\n",
"        l = loss(net(X), y)\n",
"        trainer.zero_grad()\n",
"        l.backward()\n",
"        trainer.step()\n",
"    # Evaluation only: no autograd graph is needed for the full-dataset loss\n",
"    with torch.no_grad():\n",
"        l = loss(net(features), labels)\n",
"    print(f'epoch {epoch + 1}, loss {l:f}')\n"
],
"id": "e8a44851125b7cc6",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 1, loss 0.546318\n",
"epoch 2, loss 0.009022\n",
"epoch 3, loss 0.000254\n"
]
}
],
"execution_count": 49
2026-03-15 06:42:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:09:00.829165029Z",
"start_time": "2026-03-19T06:09:00.673423450Z"
}
},
2026-03-15 06:42:56 +00:00
"cell_type": "code",
"source": [
"import torchvision\n",
"from torchvision import transforms\n",
"\n",
"# ToTensor converts images to float tensors; with download=False the files must already be under ./data\n",
"trans = transforms.ToTensor()\n",
"mnist_train = torchvision.datasets.FashionMNIST(\n",
"    root=\"./data\", train=True, transform=trans, download=False)\n",
"mnist_test = torchvision.datasets.FashionMNIST(\n",
"    root=\"./data\", train=False, transform=trans, download=False)\n"
],
"id": "bd4e8a65ccd03177",
"outputs": [],
"execution_count": 50
2026-03-15 06:42:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:09:00.996874134Z",
"start_time": "2026-03-19T06:09:00.830172776Z"
}
},
2026-03-15 06:42:56 +00:00
"cell_type": "code",
"source": [
"use_svg_display()\n",
"# Number of training and test examples\n",
"(len(mnist_train), len(mnist_test))"
],
"id": "ed2c915af7f6a76a",
"outputs": [
{
"data": {
"text/plain": [
"(60000, 10000)"
]
},
"execution_count": 51,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 51
2026-03-15 06:42:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:09:01.048577861Z",
"start_time": "2026-03-19T06:09:01.002950237Z"
}
},
2026-03-15 06:42:56 +00:00
"cell_type": "code",
"source": "# Shape of the first training image tensor\nmnist_train[0][0].size()",
"id": "4df1cbc292aa5981",
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([1, 28, 28])"
]
},
"execution_count": 52,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 52
2026-03-15 06:42:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:09:01.146967321Z",
"start_time": "2026-03-19T06:09:01.075154216Z"
}
},
2026-03-15 06:42:56 +00:00
"cell_type": "code",
"source": [
"def get_fashion_mnist_labels(labels):\n",
"    \"\"\"Map numeric Fashion-MNIST class indices to their text names.\"\"\"\n",
"    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',\n",
"                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']\n",
"    return [text_labels[int(idx)] for idx in labels]\n"
],
"id": "332f3d6da0bafbe8",
"outputs": [],
"execution_count": 53
2026-03-15 06:42:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:09:01.203692216Z",
"start_time": "2026-03-19T06:09:01.148474439Z"
}
},
2026-03-15 06:42:56 +00:00
"cell_type": "code",
"source": [
"def show_images(imgs, num_rows, num_cols, titles=None, scale=1): #@save\n",
"    \"\"\"Plot images on a num_rows x num_cols grid and return the flat array of Axes.\n",
"\n",
"    imgs may contain torch tensors or PIL images; titles (if given) is indexed per image.\n",
"    squeeze=False keeps a 1x1 grid working, because plt.subplots would otherwise\n",
"    return a bare Axes that has no flatten().\n",
"    \"\"\"\n",
"    figsize = (num_cols * scale, num_rows * scale)\n",
"    _, axes = plt.subplots(num_rows, num_cols, figsize=figsize, squeeze=False)\n",
"    axes = axes.flatten()\n",
"    for i, (ax, img) in enumerate(zip(axes, imgs)):\n",
"        # Tensors go through numpy; PIL images can be handed to imshow directly\n",
"        ax.imshow(img.numpy() if torch.is_tensor(img) else img)\n",
"        ax.axes.get_xaxis().set_visible(False)\n",
"        ax.axes.get_yaxis().set_visible(False)\n",
"        if titles:\n",
"            ax.set_title(titles[i])\n",
"    return axes"
],
"id": "c83202fe9b0ab487",
"outputs": [],
"execution_count": 54
2026-03-15 06:42:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:09:01.696988026Z",
"start_time": "2026-03-19T06:09:01.208873844Z"
}
},
2026-03-15 06:42:56 +00:00
"cell_type": "code",
"source": [
"# Grid layout for the preview; the batch size is derived from it instead of repeating 18\n",
"num_rows, num_cols = 2, 9\n",
"X, y = next(iter(data.DataLoader(mnist_train, batch_size=num_rows * num_cols)))\n",
"print(X.shape)\n",
"show_images(X.reshape(num_rows * num_cols, 28, 28), num_rows, num_cols, titles=get_fashion_mnist_labels(y));"
],
"id": "cf4abd8370d55416",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([18, 1, 28, 28])\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 900x200 with 18 Axes>"
],
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"524.634446pt\" height=\"137.375483pt\" viewBox=\"0 0 524.634446 137.375483\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n <metadata>\n <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n <cc:Work>\n <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n <dc:date>2026-03-19T14:09:01.564207</dc:date>\n <dc:format>image/svg+xml</dc:format>\n <dc:creator>\n <cc:Agent>\n <dc:title>Matplotlib v3.7.2, https://matplotlib.org/</dc:title>\n </cc:Agent>\n </dc:creator>\n </cc:Work>\n </rdf:RDF>\n </metadata>\n <defs>\n <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n </defs>\n <g id=\"figure_1\">\n <g id=\"patch_1\">\n <path d=\"M 0 137.375483 \nL 524.634446 137.375483 \nL 524.634446 0 \nL 0 0 \nz\n\"/>\n </g>\n <g id=\"axes_1\">\n <g id=\"patch_2\">\n <path d=\"M 15.234446 69.695483 \nL 62.611804 69.695483 \nL 62.611804 22.318125 \nL 15.234446 22.318125 \nz\n\"/>\n </g>\n <g clip-path=\"url(#pc37636f183)\">\n <image 
xlink:href=\"data:image/png;base64,\niVBORw0KGgoAAAANSUhEUgAAAEIAAABCCAYAAADjVADoAAAWB0lEQVR4nO2bW48cyZmeny8iIzMr69DVJ5IzzZnhzIhzsHYl2StrsVrBWBu+tw34Bxi+N3zhH+GfYPjavjAM+8I3K3vXWNsLwYBgaXd2vSOMOGdyyD6yWV2nPERGhC8iK7ub5MgCpBF1MR9AFJqVGZHxxhvf4f2y5B/KPw18bagX/QC/LfY1EJ19DURnya86gCoKpBiA1ogxIALWEpxHshRSA94TygqcgxDAB0Jd49frX8cafi32qwOxNcHf3MHnCXaSEgTMskWsox2n1NMEXQfyswqpHeIc4gJqtsCXZQTmt8B+KSAky5A0jTu/NQatQCmCFqr9IdWuwRmwIwFAVwbVBuxQsGNBNVBPC1QL4gLiIb0Yke1OwF+ZxznwHTDhqc//n3XXiX/q+sYS6ga8IzQ2srOuCW37FBAiv3hCEdRrt2kOtpi9mfH4+5Z02DAc1GTGsp0fcpAvMeIZaIsSjw8KjzDUNSNdY4Nm3ua0QffDHpVjDlcTQhB8EEIQFuucpjLd80DwAlaBF5AA8vTiN58CHsQLYq9cJJAfK0YPA0npGZxa9NqiPz3CnZ4+BcSXLB5RiNaIVvitgnI/ZXUg/P7bn/BGccad/IyJKtlP5uyqeNaN+P5TEzACBnCA7R5ad8955DI+aW7gOn9tg+aj6iaH1RYAHsF6zdzmNE6TKI+S65vlQxzMOk0bFNZpSmvwAVQ3z1k+RUJCsoqbYFaa4WmBzFKCc+BdB8RTTNB7u/hXbtFOM87fyainQr3ncTuW4VZFqlpmbcF7y1dRBBLlyFSLJj6oFk8mLVo8RhxGXD+2kstzsPYpF23RgRP/v3SGTLcoQn/txFTxXp5l7IZ9MztgbnNarRgkNo7ZAecPYDYaErxw4SLD8u8ekD05YOfnDemfvUdo2+cwYnuLxd0xq5sK9w9mfPvGIQeDGS+lF1y4AV+U28yaAed1QdUafEdtgER5RAKZbvsdTFXbf7dZjJIQ70NQdNeLxyhH0oE30PYaqFdB7DetG++hmtKGy0zg6pg38wV+Xxjqhm8UxwD8+eO7fLGYcprvc/tH6RUgREhevY3bmTB/c8Tjbyqaqed2UZIoR+lSju2EVZuxaDMaF28z2qEIiIRu0Q4lnlS7Kwu7vgAlHt1R3IXL86wl9EA5FLVPUBKw6J5pioC/4ig2bLNBk4inDYrGaXxQLNsMHwTVjVsnCQ/1NkYcRWK5MVxyb2cP7r6GLpsIhKQpF999mcff1FSv1/yjb/2Uka45t0Nql3BSj1ivduIZbKMzK0zDILEUScM4qTHKMdQ1SgKFajDi+uNhg6bypn/4zVFwnVO9aAfYoPFBokMNUIZ4fet1xyiHEY9HqF0EaaAblATaoDHK0TrFymY0XjNbD2haTaI9mWkx2vG4GpIlLa8WT7g7POHDN/Y5/v4UZSGRLEMNcuotRb3nGE1LDrInGHHM25yaOKmSy53f0FsRSDraJuJQEq4tVHXfmchXXFDkyl6juQ2aTBlMcDgUPrT4IB0wChTxk+hAr7KoDRoVAtZrXHdEPdJHIu8VQV9npA9CohyFrskyix1JB8Tbr+Nyw+xteOd3H5Bry8+WL3cOK/47yGcMtKV0htNmROs1jd/soKLxCa3X8cGuOLWBbih0ZEehmo4p7TWwAMYqOsSpXlOomioYFm6AQ7F0OTZoLtoBS5eRqkBmSnwQzpshpTPdkUh6PyESKLKGNHGkSUuetAwSy0ExI1MtI11Te4PRjtUIVAtJuz2gzTXtTst3pl8wawuOyjEAW2mMEqOkZi9ZslA5jrjwWTOg6WjrgoCA7f7e7MyGOajIjg0IRlq0BHSXTeU6evpbyQVTVbIKhpmqsSEhl4YqpNigqX1CojyFaiJjkGsgtF71cxoVfVGWtGS6pUgapqakUM2lX1IBb2J+kjQ
TQ5sLYlqUBMZJRVK4y0UAlTccNlv9GTfakeaRwiNdU+imp7kLiuN6QtkaEvFk3lJhqL3Biu5pninLWFXkyrKv56TiOHFj7jW3GKqaqV6Ri2VqYo5yI5kzT3M8isobbEjQ4qm8oe4YufEfQB+GE+XQeDLVsmeWZMpixJFKS+sU+ZmgG0jqicJlgkrj4se64qaZA7DoaDlvc0qXkqmWqVljxJF3A+4kS3b0sqfz2mecNiMql5AoR+0NDsfapygCS5dhg2YrKdFJwEjLrl6Ri+P9+oD3Vwe8nM341qBkqCpeSeYUAlWYUwXFIhiO2i2qYNjSqx6YKhh8UNigUeLZTxYUqsYH1SdtMdfxTFRFLhbnFcWxJ6kDicsEl0KatuyZRT+YC6o/xyNdk6mWTLXsJCuUeHKxXYSoycWSiiMXSxUMx+nk0jl1R8RIDLWFbrBeM9IVO8mSlc/4d+ffZ2FzPp7vcb4esFOUfDrZY5jU3ErnFLpmrCqGqu6dMMSQS+gSuRDQEsO3JqCIn4jvU/HNdxuzrWayDujKkzQTwadwc2vBd/L73Lc7vF/e7tlR6Jqb5oJC1aTiGOuyQzcOmIslF0uhLDuqxXf/d5RNOWymPKrjkSpUQ6Zs7xdumQveMGf8cPG7/PCP/y7Dh8SzCjwebfNo5yWCBpcFSALZjTUvb19wMJzxg+lHDFWNxqMFPC2+q5my
},
"metadata": {},
"output_type": "display_data",
"jetTransient": {
"display_id": null
}
}
],
"execution_count": 55
2026-03-19 05:47:14 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:09:03.457306914Z",
"start_time": "2026-03-19T06:09:01.729639427Z"
}
},
2026-03-15 06:42:56 +00:00
"cell_type": "code",
"source": [
"batch_size = 256\n",
"\n",
"def get_dataloader_workers():\n",
"    \"\"\"Number of worker processes the DataLoaders use (fixed at 4).\"\"\"\n",
"    return 4\n",
"\n",
"train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,\n",
"                             num_workers=get_dataloader_workers())\n",
"# Time one full pass over the training set\n",
"timer = Timer()\n",
"for X, y in train_iter:\n",
"    pass\n",
"f'{timer.stop():.2f} sec'"
],
"id": "552769fbffc16142",
"outputs": [
{
"data": {
"text/plain": [
"'1.59 sec'"
]
},
"execution_count": 56,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 56
2026-03-15 06:42:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:09:03.736459982Z",
"start_time": "2026-03-19T06:09:03.597245687Z"
}
},
2026-03-15 06:42:56 +00:00
"cell_type": "code",
"source": [
"def load_data_fashion_mnist(batch_size, resize=None):\n",
"    \"\"\"Load Fashion-MNIST from ./data and return (train_iter, test_iter) DataLoaders.\n",
"\n",
"    When resize is truthy, images are resized before ToTensor. download=False, so the\n",
"    dataset files must already exist locally.\n",
"    \"\"\"\n",
"    steps = ([transforms.Resize(resize)] if resize else []) + [transforms.ToTensor()]\n",
"    trans = transforms.Compose(steps)\n",
"\n",
"    def make_loader(train, shuffle):\n",
"        dataset = torchvision.datasets.FashionMNIST(\n",
"            root=\"./data\", train=train, transform=trans, download=False)\n",
"        return data.DataLoader(dataset, batch_size, shuffle=shuffle,\n",
"                               num_workers=get_dataloader_workers())\n",
"\n",
"    return make_loader(True, True), make_loader(False, False)"
],
"id": "aa81880abd86cae6",
"outputs": [],
"execution_count": 57
2026-03-15 06:42:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:09:04.172516352Z",
"start_time": "2026-03-19T06:09:03.782776079Z"
}
},
2026-03-15 06:42:56 +00:00
"cell_type": "code",
"source": [
"train_iter, test_iter = load_data_fashion_mnist(32, resize=64)\n",
"# Inspect the first batch to confirm the resize took effect\n",
"X, y = next(iter(train_iter))\n",
"print(X.shape, X.dtype, y.shape, y.dtype)"
],
"id": "4248a103f745154",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([32, 1, 64, 64]) torch.float32 torch.Size([32]) torch.int64\n"
]
}
],
"execution_count": 58
2026-03-15 06:42:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:09:04.306546718Z",
"start_time": "2026-03-19T06:09:04.202047232Z"
}
},
2026-03-15 06:42:56 +00:00
"cell_type": "code",
"source": [
"from IPython import display\n",
"batch_size = 256\n",
"# Use the batch_size defined above instead of a hard-coded 32\n",
"train_iter, test_iter = load_data_fashion_mnist(batch_size)\n",
"\n",
"# Softmax-regression parameters: 784 = 28*28 flattened pixels in, 10 classes out\n",
"num_inputs = 784\n",
"num_outputs = 10\n",
"W = torch.normal(0, 0.01, size=(num_inputs, num_outputs), requires_grad=True)\n",
"b = torch.zeros(num_outputs, requires_grad=True)\n",
"\n",
"# Reminder of how sum(dim, keepdim=True) behaves along each axis\n",
"X = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])\n",
"X.sum(0, keepdim=True), X.sum(1, keepdim=True)\n"
],
"id": "94c52f0cca88ef48",
"outputs": [
{
"data": {
"text/plain": [
"(tensor([[5., 7., 9.]]),\n",
" tensor([[ 6.],\n",
" [15.]]))"
]
},
"execution_count": 59,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 59
2026-03-15 06:42:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:09:04.394766349Z",
"start_time": "2026-03-19T06:09:04.309113610Z"
}
},
2026-03-15 06:42:56 +00:00
"cell_type": "code",
"source": [
"def softmax(X):\n",
"    \"\"\"Row-wise softmax of a 2-D tensor.\n",
"\n",
"    The row maximum is subtracted before exponentiating. Softmax is invariant to\n",
"    that shift, and it keeps exp() from overflowing to inf (and producing NaN) for large logits.\n",
"    \"\"\"\n",
"    X_exp = torch.exp(X - X.max(dim=1, keepdim=True).values)\n",
"    partition = X_exp.sum(1, keepdim=True)\n",
"    return X_exp / partition  # broadcasting divides each row by its own sum\n",
"X = torch.normal(0, 1, (2, 5))\n",
"X_prob = softmax(X)\n",
"X_prob, X_prob.sum(1)"
],
"id": "c4ab34373c5a664e",
"outputs": [
{
"data": {
"text/plain": [
"(tensor([[0.1969, 0.2421, 0.2949, 0.1017, 0.1645],\n",
" [0.2729, 0.1267, 0.0365, 0.2715, 0.2924]]),\n",
" tensor([1., 1.]))"
]
},
"execution_count": 60,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 60
2026-03-15 06:42:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:09:04.445839278Z",
"start_time": "2026-03-19T06:09:04.430535547Z"
}
},
2026-03-15 06:42:56 +00:00
"cell_type": "code",
"source": [
"def net(X):\n",
"    \"\"\"Softmax regression: flatten each example to a row, then softmax(XW + b).\"\"\"\n",
"    flat = X.reshape((-1, W.shape[0]))\n",
"    return softmax(flat @ W + b)"
],
"id": "6eacc53b2b9738af",
"outputs": [],
"execution_count": 61
2026-03-15 06:42:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:09:04.567489691Z",
"start_time": "2026-03-19T06:09:04.449587260Z"
}
},
2026-03-15 06:42:56 +00:00
"cell_type": "code",
"source": [
"y = torch.tensor([0, 2])\n",
"y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])\n",
"# For each row, pick the predicted probability of that row's true class\n",
"y_hat[torch.arange(len(y)), y]"
],
"id": "698449b4dafb545c",
"outputs": [
{
"data": {
"text/plain": [
"tensor([0.1000, 0.5000])"
]
},
"execution_count": 62,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 62
2026-03-15 06:42:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:09:04.745616835Z",
"start_time": "2026-03-19T06:09:04.613175566Z"
}
},
2026-03-15 06:42:56 +00:00
"cell_type": "code",
"source": [
"def cross_entropy(y_hat, y):\n",
"    \"\"\"Per-example negative log of the probability assigned to the true class.\"\"\"\n",
"    true_class_prob = y_hat[range(len(y_hat)), y]\n",
"    return -torch.log(true_class_prob)\n",
"cross_entropy(y_hat, y)"
],
"id": "1720369fc8568c8c",
"outputs": [
{
"data": {
"text/plain": [
"tensor([2.3026, 0.6931])"
]
},
"execution_count": 63,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 63
2026-03-14 11:51:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:09:04.855442356Z",
"start_time": "2026-03-19T06:09:04.774151992Z"
}
},
2026-03-14 11:51:56 +00:00
"cell_type": "code",
2026-03-15 06:42:56 +00:00
"source": [
"def accuracy(y_hat, y): #@save\n",
"    \"\"\"Count correct predictions, returned as a float count (not a fraction).\n",
"\n",
"    A 2-D y_hat with more than one column is first reduced to class indices via argmax.\n",
"    \"\"\"\n",
"    is_score_matrix = len(y_hat.shape) > 1 and y_hat.shape[1] > 1\n",
"    predicted = y_hat.argmax(axis=1) if is_score_matrix else y_hat\n",
"    matches = predicted.type(y.dtype) == y\n",
"    return float(matches.type(y.dtype).sum())\n",
"\n",
"accuracy(y_hat, y)/len(y)"
],
"id": "e65719500a64ed87",
"outputs": [
{
"data": {
"text/plain": [
"0.5"
]
},
"execution_count": 64,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 64
2026-03-15 06:42:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:09:04.919559483Z",
"start_time": "2026-03-19T06:09:04.858709428Z"
}
},
2026-03-15 06:42:56 +00:00
"cell_type": "code",
"source": [
"class Accumulator: #@save\n",
"    \"\"\"Keep running float sums for n separate quantities.\"\"\"\n",
"    def __init__(self, n):\n",
"        self.data = [0.0 for _ in range(n)]\n",
"\n",
"    def add(self, *args):\n",
"        \"\"\"Add each argument to its slot. zip truncates, so pass exactly n values.\"\"\"\n",
"        self.data = [total + float(value) for total, value in zip(self.data, args)]\n",
"\n",
"    def reset(self):\n",
"        \"\"\"Zero every slot, keeping the current number of slots.\"\"\"\n",
"        self.data = [0.0 for _ in self.data]\n",
"\n",
"    def __getitem__(self, idx):\n",
"        return self.data[idx]\n"
],
"id": "f1eebb35ff2e9fea",
"outputs": [],
"execution_count": 65
2026-03-15 06:42:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:09:04.987417283Z",
"start_time": "2026-03-19T06:09:04.924473660Z"
}
},
2026-03-15 06:42:56 +00:00
"cell_type": "code",
"source": [
"def evaluate_accuracy(net, data_iter): #@save\n",
"    \"\"\"Return the fraction of examples in data_iter that net classifies correctly.\"\"\"\n",
"    if isinstance(net, torch.nn.Module):\n",
"        net.eval()  # switch nn.Module models to evaluation mode\n",
"    correct, total = 0.0, 0.0  # correct predictions, examples seen\n",
"    with torch.no_grad():\n",
"        for X, y in data_iter:\n",
"            correct += accuracy(net(X), y)\n",
"            total += float(y.numel())\n",
"    return correct / total\n"
],
"id": "bc2beb5f2d6afe7e",
"outputs": [],
"execution_count": 66
2026-03-15 06:42:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:09:09.640811163Z",
"start_time": "2026-03-19T06:09:04.992201693Z"
}
},
2026-03-15 06:42:56 +00:00
"cell_type": "code",
"source": "# Accuracy before training; with 10 classes this is expected to sit near 0.1\nevaluate_accuracy(net, test_iter)",
"id": "65bcfb7e40c1a98b",
"outputs": [
{
"data": {
"text/plain": [
"0.1045"
]
},
"execution_count": 67,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 67
2026-03-15 06:42:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:09:10.226880910Z",
"start_time": "2026-03-19T06:09:09.774018681Z"
}
},
2026-03-15 06:42:56 +00:00
"cell_type": "code",
"source": [
"def train_epoch_ch3(net, train_iter, loss, updater): #@save\n",
"    \"\"\"Train net for one epoch (defined in Chapter 3).\n",
"\n",
"    updater is either a torch.optim.Optimizer or a callable that takes the batch size.\n",
"    Returns (average training loss, training accuracy) over the epoch.\n",
"    \"\"\"\n",
"    if isinstance(net, torch.nn.Module):\n",
"        net.train()  # put nn.Module models into training mode\n",
"    metric = Accumulator(3)  # summed loss, correct predictions, example count\n",
"    use_torch_optimizer = isinstance(updater, torch.optim.Optimizer)\n",
"    for X, y in train_iter:\n",
"        y_hat = net(X)\n",
"        l = loss(y_hat, y)\n",
"        if use_torch_optimizer:\n",
"            # Built-in optimizer: backprop the mean loss\n",
"            updater.zero_grad()\n",
"            l.mean().backward()\n",
"            updater.step()\n",
"        else:\n",
"            # Custom updater: backprop the summed loss and pass the batch size\n",
"            l.sum().backward()\n",
"            updater(X.shape[0])\n",
"        metric.add(float(l.sum()), accuracy(y_hat, y), y.numel())\n",
"    avg_loss, avg_acc = metric[0] / metric[2], metric[1] / metric[2]\n",
"    return avg_loss, avg_acc"
],
"id": "2faf1dcc6c023a53",
"outputs": [],
"execution_count": 68
2026-03-15 06:42:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:09:10.673837338Z",
"start_time": "2026-03-19T06:09:10.455642059Z"
}
},
2026-03-15 06:42:56 +00:00
"cell_type": "code",
"source": [
"class Animator: #@save\n",
"    \"\"\"Incrementally plot several curves, redrawing the figure after every add().\"\"\"\n",
"    def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,\n",
"                 ylim=None, xscale='linear', yscale='linear',\n",
"                 fmts=('-', 'm--', 'g-.', 'r:'), nrows=1, ncols=1,\n",
"                 figsize=(3.5, 2.5)):\n",
"        legend = [] if legend is None else legend\n",
"        use_svg_display()\n",
"        self.fig, self.axes = plt.subplots(nrows, ncols, figsize=figsize)\n",
"        if nrows * ncols == 1:\n",
"            self.axes = [self.axes]  # wrap the single Axes so self.axes[0] always works\n",
"        # Capture the axis settings now so they can be reapplied after each cla()\n",
"        self.config_axes = lambda: set_axes(\n",
"            self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend)\n",
"        self.X, self.Y, self.fmts = None, None, fmts\n",
"\n",
"    def add(self, x, y):\n",
"        \"\"\"Append one point per curve (a scalar x is shared by all curves) and redraw.\"\"\"\n",
"        if not hasattr(y, \"__len__\"):\n",
"            y = [y]\n",
"        n = len(y)\n",
"        if not hasattr(x, \"__len__\"):\n",
"            x = [x] * n\n",
"        if not self.X:\n",
"            self.X = [[] for _ in range(n)]\n",
"        if not self.Y:\n",
"            self.Y = [[] for _ in range(n)]\n",
"        for idx, (xv, yv) in enumerate(zip(x, y)):\n",
"            if xv is None or yv is None:\n",
"                continue  # skip points with a missing coordinate\n",
"            self.X[idx].append(xv)\n",
"            self.Y[idx].append(yv)\n",
"        ax = self.axes[0]\n",
"        ax.cla()\n",
"        for xs, ys, fmt in zip(self.X, self.Y, self.fmts):\n",
"            ax.plot(xs, ys, fmt)\n",
"        self.config_axes()\n",
"        display.display(self.fig)\n",
"        display.clear_output(wait=True)\n"
],
"id": "7cd5367ab43c5e5f",
"outputs": [],
"execution_count": 69
2026-03-15 06:42:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:09:10.768738929Z",
"start_time": "2026-03-19T06:09:10.687413593Z"
}
},
2026-03-15 06:42:56 +00:00
"cell_type": "code",
"source": [
"def train_ch3(net, train_iter, test_iter, loss, num_epochs, updater): #@save\n",
"    \"\"\"Train for num_epochs, plotting train loss/accuracy and test accuracy (Chapter 3).\n",
"\n",
"    The final asserts are sanity thresholds on the last epoch's metrics.\n",
"    \"\"\"\n",
"    animator = Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0.3, 0.9],\n",
"                        legend=['train loss', 'train acc', 'test acc'])\n",
"    for epoch in range(1, num_epochs + 1):\n",
"        train_metrics = train_epoch_ch3(net, train_iter, loss, updater)\n",
"        test_acc = evaluate_accuracy(net, test_iter)\n",
"        animator.add(epoch, train_metrics + (test_acc,))\n",
"    train_loss, train_acc = train_metrics\n",
"    assert train_loss < 0.5, train_loss\n",
"    assert 0.7 < train_acc <= 1, train_acc\n",
"    assert 0.7 < test_acc <= 1, test_acc"
],
"id": "b02a143c75fad40",
"outputs": [],
"execution_count": 70
2026-03-15 06:42:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:09:10.914319299Z",
"start_time": "2026-03-19T06:09:10.840221929Z"
}
},
2026-03-15 06:42:56 +00:00
"cell_type": "code",
"source": [
"lr = 0.1\n",
"\n",
"def updater(batch_size):\n",
"    \"\"\"Apply one from-scratch SGD step to the softmax parameters W and b.\"\"\"\n",
"    return sgd([W, b], lr, batch_size)"
],
"id": "6a97b70779276b61",
"outputs": [],
"execution_count": 71
2026-03-15 06:42:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:09:11.044553254Z",
"start_time": "2026-03-19T06:09:10.940821708Z"
}
},
2026-03-15 06:42:56 +00:00
"cell_type": "code",
"source": [
"num_epochs = 10\n",
2026-03-19 05:47:14 +00:00
"# Explicit switch instead of commented-out code: set to True to run softmax training\n",
"TRAIN_SOFTMAX = False\n",
"if TRAIN_SOFTMAX:\n",
"    train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, updater)"
2026-03-15 06:42:56 +00:00
],
"id": "df3cceb72faee402",
2026-03-19 05:47:14 +00:00
"outputs": [],
"execution_count": 72
2026-03-15 06:42:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:09:11.628097076Z",
"start_time": "2026-03-19T06:09:11.067755016Z"
}
},
2026-03-15 06:42:56 +00:00
"cell_type": "code",
"source": [
"def predict_ch3(net, test_iter, n=6): #@save\n",
"    \"\"\"Show the first n images of the first test batch with true and predicted labels (Chapter 3).\n",
"\n",
"    Fetches exactly one batch; zip(test_iter, range(1)) used to pull a second, wasted batch.\n",
"    Predictions run under no_grad, since no autograd graph is needed for inference.\n",
"    \"\"\"\n",
"    batch = next(iter(test_iter), None)\n",
"    if batch is None:\n",
"        return  # empty loader: nothing to show\n",
"    X, y = batch\n",
"    trues = get_fashion_mnist_labels(y)\n",
"    with torch.no_grad():\n",
"        preds = get_fashion_mnist_labels(net(X).argmax(axis=1))\n",
"    titles = [true + '\\n' + pred for true, pred in zip(trues, preds)]\n",
"    show_images(\n",
"        X[0:n].reshape((n, 28, 28)), 1, n, titles=titles[0:n])\n",
"\n",
"predict_ch3(net, test_iter)\n"
],
"id": "94f6177bfb40eece",
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 600x100 with 6 Axes>"
],
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"357.008839pt\" height=\"90.784071pt\" viewBox=\"0 0 357.008839 90.784071\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n <metadata>\n <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n <cc:Work>\n <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n <dc:date>2026-03-19T14:09:11.541460</dc:date>\n <dc:format>image/svg+xml</dc:format>\n <dc:creator>\n <cc:Agent>\n <dc:title>Matplotlib v3.7.2, https://matplotlib.org/</dc:title>\n </cc:Agent>\n </dc:creator>\n </cc:Work>\n </rdf:RDF>\n </metadata>\n <defs>\n <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n </defs>\n <g id=\"figure_1\">\n <g id=\"patch_1\">\n <path d=\"M 0 90.784071 \nL 357.008839 90.784071 \nL 357.008839 0 \nL 0 0 \nz\n\"/>\n </g>\n <g id=\"axes_1\">\n <g id=\"patch_2\">\n <path d=\"M 15.008839 83.584071 \nL 62.837411 83.584071 \nL 62.837411 35.7555 \nL 15.008839 35.7555 \nz\n\"/>\n </g>\n <g clip-path=\"url(#p15736a100e)\">\n <image 
xlink:href=\"data:image/png;base64,\niVBORw0KGgoAAAANSUhEUgAAAEMAAABDCAYAAADHyrhzAAAOAElEQVR4nO2a2ZIkV1KGP/dzIiKX2rsltTZkwwi0GNhwwyVzBW/Ak/BevACGwTUYBjPGYswIBCYJqdVdXV1bZmREnOPOxYlcqjeJYYYWZvmbZVd2RFREnN/df19OyR/Lnzp7AKCv+wV+SNiTsYM9GTvYk7GDPRk72JOxgz0ZO9iTsYM9GTvYk7GDPRk72JOxg/i9rhIBUUTLTwAJCqrb82vo/4Bfs/LTffPds4EbnnM5/ipoQCdNef7uc9f3Xd/7Fc/2bPjQA99BhlQ1UkX03hn5zWOsifQnNblR2jNlOBIsQq4BBascD5TPqzhxEAftQbIQF1AtnGoB86974jJRfXFO/vZxIcXyC28TfudHfPMnb9IfQZo5VoEm0F4QA0kgVo5J3j4XL8fFnNNf9tR/+0sYhleQIYLUFVLX5PvHLD44YJgJiwdKnkL7/sDRm7ccND1vzBZMwsBpvWQaBg5Cx3FclhemWCZTvMdcya50Hvl6dcIi1Xx+dY/HF4f405r+Fw31Vc3Z7SHy9BJ68JeQ0b9zxPKPbvnxm+d8ePiYB/U1T4Y536yO6S1w1U/pUqQdKlZDxF3IWXGHnAJmgsuUt/9pCp0SpWkQEeS9t8lnB1gdSPOIByFNFYvC6lTozgSroT82vHbiwUBQI5tyuZqiMuHJao6KozgijooT1e4sIEomqqGydd953dMdtiyCcWVTYit0J2dMfv8Ei+W5GFStowlCb+jgXHxc8dGDz/nk6CE/njziXrjlsprxoLlisMhVnpJMuc0Ni9TQW+C6n5BdWQ4VQw4sjqdwfIj0E6IeHiBVxcUfvsXFp0I6cOStFVWVOZh2NDFxpEYYF2BeLNylSHahGyKPFgdYVoZVhKQwCNopCHh02JEUrw2dJqo68e7ZFafNkremN7w7v6QSY/rhgIqxSA2DK787f8RPD/6VG5vwV1ef8m13xFe3JzxZzPjR2QV/9t5f8E684UShEcXcGdbeOOrFymFAeJynfNY/YGEN58Mht7nhz989o3vvhNBlosxneF3RHwrDseGzzOGsowqZed1ThTy6t2AuZFOSKe0Q6VOk7yNDF/GksFJkULQXdKCQEQTfIcMymAR6E266BoBpHJjFHsLAiQ5ENaZhAOC36nPeD7esdMl/zR5yFFdUmmlC4r3ZJZWkEjLuxX12SNgNroBTSUYxAkYlmUoyro4ruApx+fFbeBQW7wr1gyUpBW7O55CFp50ieRSjXFbkUkRI8laUapPNMdbe7yMZcfs7ANYreRBc4fHtKY8VqAyJhlbGbNYR1YiheOPP6vf4y9mnVGJEzeNiheOmpcuRv779lEYHKskbfVpr1G2eMHhgpj2VZJZWcz4ckCzQ5orOImERiG2PdonYH0csQjowTmcd14sJ1ga0F6prRXvQ9cLZZglNoyFkJ3PI9pr1Bx8z70bJR08RCF0oBEXFY8lGt31AohGCIepcxwmPF3OaKvH2/JpZLB5Ta2Zw5cvVKWHUKZWtPiUPXPZTkivz2DMPPZ1F2lxhLvQWSBZKphkykox4+NkNRKU/OOByuIc41GNqCivQoaTKXN9dsMWSnhC2YbBzzqPjAl5tnKScC2C1F09KILZmcH2P8sVdwJ1hCKQUaLuaPgWqYHRDpE+BOmbuHywIshXkwQLtUG30zF2Y1AOTmDbXqDjHdcss9uTGSQc12mei/+xfkFhxXz6iuZqTJkJ/VKxacjGkGVh11wvW+XpDxnoRIwFWjbFY+XjNuFoFouEmSBvQ5FsyxvttargxFVofQJyurRBxbFGhS2U5M7r7kRi36tB1FUNb4SaQpXjkJBOrjAajqjJVyBzXLYfVCpsYw2FAeyXijudMuFkxPa9J04BmxYJsXX1tVymW3YUH8Lp4ATqSsVkMSCphIYzCoeDlH3ay6+Z6z1KiT8o
DPSvrAxKtvMp4HwAzwWysisVx3yk6pRhCg6PBUHXcBR+TgY2WdRVQH4suy/h/fEnzTcNkOmV+dozVkeFsQpoGrBZSU8jJ9UhSBAtCmkFuCilWAerFwiaolTDbxInsECzgwbeCbEAWvNdyzWjVTXwFR4MTq0yXy2KoDMtKAkK4G7OijjaGiDOZ9kzrgWxCnyLmwipXtLlCXLbZZG0UW61gtUJXHapKaGq8CeBgg6CD4gFSKi9rlWDBQYU0E9zHsKBkFmyddeSOtoiMJbuAiyD6jHu4lJxolBfd8QJRQ9XQ6ORosGNpd0GedTVxRG1TAK6rYAeyKb2FzXsCxE2Ajr5l/QCXV0gIxGVLrCoYmzJXgTDGSVBcFTuo6U4brBa6o0CuIc2FYV48wOriAfiapFKDFI9wPAgWtl6yzlAl5AyCI5UhwamqTAjGZNqTax37O8G9GKksQ9CqZBVLiqEsUmApExhJEXUu4ow2VcSFEhcDoTeihIDbaEp3sIwtFmXB19d8F3Q2Y356Ak3N5MExaRZp36gQU3IDfV1CSFIxieSSlovGCGaOjKILJbu4OoTykWiEytBQ0m0Qp6lLK9CnwKJtcBvDZoy5EKxowqBgUkLOpAh3k5HgtH1FykpohbgcU6vndQHxq225+pDwZQspEc8DYVIRVhOa64rcKKsTxSpKZzktVrcxOCVDyOuKbIzdAKKCySiWo/u7
},
"metadata": {},
"output_type": "display_data",
"jetTransient": {
"display_id": null
}
}
],
"execution_count": 73
2026-03-15 06:42:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:09:11.756853549Z",
"start_time": "2026-03-19T06:09:11.685653252Z"
}
},
2026-03-15 06:42:56 +00:00
"cell_type": "code",
"source": [
"batch_size = 256\n",
"train_iter, test_iter = load_data_fashion_mnist(batch_size)"
],
"id": "4a0bfd0479ec7386",
"outputs": [],
"execution_count": 74
2026-03-15 06:42:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:09:42.908237720Z",
"start_time": "2026-03-19T06:09:42.860595602Z"
}
},
2026-03-15 06:42:56 +00:00
"cell_type": "code",
"source": [
"net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))\n",
"def init_weights(m):\n",
" if type(m) == nn.Linear:\n",
" nn.init.normal_(m.weight, std=0.01)\n",
"\n",
"net.apply(init_weights);\n",
"loss = nn.CrossEntropyLoss(reduction='none')\n",
"trainer = torch.optim.SGD(net.parameters(), lr=0.1)\n",
"num_epochs = 10\n",
"#train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)"
2026-03-15 06:42:56 +00:00
],
"id": "b9808d88f5e6827b",
"outputs": [],
"execution_count": 76
2026-03-15 06:42:56 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:09:46.372721439Z",
"start_time": "2026-03-19T06:09:46.108696173Z"
}
},
2026-03-15 06:42:56 +00:00
"cell_type": "code",
2026-03-19 05:47:14 +00:00
"source": [
"x = torch.arange(-8.0, 8.0, 0.1, requires_grad=True)\n",
"y = torch.relu(x)\n",
"plot(x.detach(), y.detach(), 'x', 'relu(x)', figsize=(5, 2.5))"
],
2026-03-15 06:42:56 +00:00
"id": "c25dd146307f58e0",
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 500x250 with 1 Axes>"
],
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"320.440625pt\" height=\"183.35625pt\" viewBox=\"0 0 320.440625 183.35625\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n <metadata>\n <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n <cc:Work>\n <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n <dc:date>2026-03-19T14:09:46.256394</dc:date>\n <dc:format>image/svg+xml</dc:format>\n <dc:creator>\n <cc:Agent>\n <dc:title>Matplotlib v3.7.2, https://matplotlib.org/</dc:title>\n </cc:Agent>\n </dc:creator>\n </cc:Work>\n </rdf:RDF>\n </metadata>\n <defs>\n <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n </defs>\n <g id=\"figure_1\">\n <g id=\"patch_1\">\n <path d=\"M 0 183.35625 \nL 320.440625 183.35625 \nL 320.440625 0 \nL 0 0 \nz\n\"/>\n </g>\n <g id=\"axes_1\">\n <g id=\"patch_2\">\n <path d=\"M 34.240625 145.8 \nL 313.240625 145.8 \nL 313.240625 7.2 \nL 34.240625 7.2 \nz\n\"/>\n </g>\n <g id=\"matplotlib.axis_1\">\n <g id=\"xtick_1\">\n <g id=\"line2d_1\">\n <path d=\"M 46.922443 145.8 \nL 46.922443 7.2 \n\" clip-path=\"url(#pee650b9715)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_2\">\n <defs>\n <path id=\"m5bb76aae00\" d=\"M 0 0 \nL 0 3.5 \n\" style=\"stroke: #ffffff; stroke-width: 0.8\"/>\n </defs>\n <g>\n <use xlink:href=\"#m5bb76aae00\" x=\"46.922443\" y=\"145.8\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_1\">\n <!-- 8 -->\n <g style=\"fill: #ffffff\" transform=\"translate(39.551349 160.398438) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-2212\" d=\"M 678 2272 \nL 4684 2272 \nL 4684 
1741 \nL 678 1741 \nL 678 2272 \nz\n\" transform=\"scale(0.015625)\"/>\n <path id=\"DejaVuSans-38\" d=\"M 2034 2216 \nQ 1584 2216 1326 1975 \nQ 1069 1734 1069 1313 \nQ 1069 891 1326 650 \nQ 1584 409 2034 409 \nQ 2484 409 2743 651 \nQ 3003 894 3003 1313 \nQ 3003 1734 2745 1975 \nQ 2488 2216 2034 2216 \nz\nM 1403 2484 \nQ 997 2584 770 2862 \nQ 544 3141 544 3541 \nQ 544 4100 942 4425 \nQ 1341 4750 2034 4750 \nQ 2731 4750 3128 4425 \nQ 3525 4100 3525 3541 \nQ 3525 3141 3298 2862 \nQ 3072 2584 2669 2484 \nQ 3125 2378 3379 2068 \nQ 3634 1759 3634 1313 \nQ 3634 634 3220 271 \nQ 2806 -91 2034 -91 \nQ 1263 -91 848 271 \nQ 434 634 434 1313 \nQ 434 1759 690 2068 \nQ 947 2378 1403 2484 \nz\nM 1172 3481 \nQ 1172 3119 1398 2916 \nQ 1625 2713 2034 2713 \nQ 2441 2713 2670 2916 \nQ 2900 3119 2900 3481 \nQ 2900 3844 2670 4047 \nQ 2441 4250 2034 4250 \nQ 1625 4250 1398 4047 \nQ 1172 3844 1172 3481 \nz\n\" transform=\"scale(0.015625)\"/>\n </defs>\n <use xlink:href=\"#DejaVuSans-2212\"/>\n <use xlink:href=\"#DejaVuSans-38\" x=\"83.789062\"/>\n </g>\n </g>\n </g>\n <g id=\"xtick_2\">\n <g id=\"line2d_3\">\n <path d=\"M 78.826389 145.8 \nL 78.826389 7.2 \n\" clip-path=\"url(#pee650b9715)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_4\">\n <g>\n <use xlink:href=\"#m5bb76aae00\" x=\"78.826389\" y=\"145.8\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_2\">\n <!-- 6 -->\n <g style=\"fill: #ffffff\" transform=\"translate(71.455295 160.398438) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-36\" d=\"M 2113 2584 \nQ 1688 2584 1439 2293 \nQ 1191 2003 1191 1497 \nQ 1191 994 1439 701 \nQ 1688 409 2113 409 \nQ 2538 409 2786 701 \nQ 3034 994 3034 1497 \nQ 3034 2003 2786 2293 \nQ 2538 2584 2113 2584 \nz\nM 3366 4563 \nL 3366 3988 \nQ 3128
},
"metadata": {},
"output_type": "display_data",
"jetTransient": {
"display_id": null
}
}
],
"execution_count": 77
2026-03-19 05:47:14 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:09:46.713437268Z",
"start_time": "2026-03-19T06:09:46.494861144Z"
}
},
2026-03-19 05:47:14 +00:00
"cell_type": "code",
"source": [
"y.backward(torch.ones_like(x), retain_graph=True)\n",
"plot(x.detach(), x.grad, 'x', 'grad of relu', figsize=(5, 2.5))"
],
"id": "f96acd2015dccb38",
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 500x250 with 1 Axes>"
],
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"329.98125pt\" height=\"183.35625pt\" viewBox=\"0 0 329.98125 183.35625\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n <metadata>\n <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n <cc:Work>\n <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n <dc:date>2026-03-19T14:09:46.640099</dc:date>\n <dc:format>image/svg+xml</dc:format>\n <dc:creator>\n <cc:Agent>\n <dc:title>Matplotlib v3.7.2, https://matplotlib.org/</dc:title>\n </cc:Agent>\n </dc:creator>\n </cc:Work>\n </rdf:RDF>\n </metadata>\n <defs>\n <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n </defs>\n <g id=\"figure_1\">\n <g id=\"patch_1\">\n <path d=\"M 0 183.35625 \nL 329.98125 183.35625 \nL 329.98125 0 \nL 0 0 \nz\n\"/>\n </g>\n <g id=\"axes_1\">\n <g id=\"patch_2\">\n <path d=\"M 43.78125 145.8 \nL 322.78125 145.8 \nL 322.78125 7.2 \nL 43.78125 7.2 \nz\n\"/>\n </g>\n <g id=\"matplotlib.axis_1\">\n <g id=\"xtick_1\">\n <g id=\"line2d_1\">\n <path d=\"M 56.463068 145.8 \nL 56.463068 7.2 \n\" clip-path=\"url(#p50e2fe8e16)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_2\">\n <defs>\n <path id=\"m08c14e3804\" d=\"M 0 0 \nL 0 3.5 \n\" style=\"stroke: #ffffff; stroke-width: 0.8\"/>\n </defs>\n <g>\n <use xlink:href=\"#m08c14e3804\" x=\"56.463068\" y=\"145.8\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_1\">\n <!-- 8 -->\n <g style=\"fill: #ffffff\" transform=\"translate(49.091974 160.398438) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-2212\" d=\"M 678 2272 \nL 4684 2272 \nL 4684 1741 \nL 
678 1741 \nL 678 2272 \nz\n\" transform=\"scale(0.015625)\"/>\n <path id=\"DejaVuSans-38\" d=\"M 2034 2216 \nQ 1584 2216 1326 1975 \nQ 1069 1734 1069 1313 \nQ 1069 891 1326 650 \nQ 1584 409 2034 409 \nQ 2484 409 2743 651 \nQ 3003 894 3003 1313 \nQ 3003 1734 2745 1975 \nQ 2488 2216 2034 2216 \nz\nM 1403 2484 \nQ 997 2584 770 2862 \nQ 544 3141 544 3541 \nQ 544 4100 942 4425 \nQ 1341 4750 2034 4750 \nQ 2731 4750 3128 4425 \nQ 3525 4100 3525 3541 \nQ 3525 3141 3298 2862 \nQ 3072 2584 2669 2484 \nQ 3125 2378 3379 2068 \nQ 3634 1759 3634 1313 \nQ 3634 634 3220 271 \nQ 2806 -91 2034 -91 \nQ 1263 -91 848 271 \nQ 434 634 434 1313 \nQ 434 1759 690 2068 \nQ 947 2378 1403 2484 \nz\nM 1172 3481 \nQ 1172 3119 1398 2916 \nQ 1625 2713 2034 2713 \nQ 2441 2713 2670 2916 \nQ 2900 3119 2900 3481 \nQ 2900 3844 2670 4047 \nQ 2441 4250 2034 4250 \nQ 1625 4250 1398 4047 \nQ 1172 3844 1172 3481 \nz\n\" transform=\"scale(0.015625)\"/>\n </defs>\n <use xlink:href=\"#DejaVuSans-2212\"/>\n <use xlink:href=\"#DejaVuSans-38\" x=\"83.789062\"/>\n </g>\n </g>\n </g>\n <g id=\"xtick_2\">\n <g id=\"line2d_3\">\n <path d=\"M 88.367014 145.8 \nL 88.367014 7.2 \n\" clip-path=\"url(#p50e2fe8e16)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_4\">\n <g>\n <use xlink:href=\"#m08c14e3804\" x=\"88.367014\" y=\"145.8\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_2\">\n <!-- 6 -->\n <g style=\"fill: #ffffff\" transform=\"translate(80.99592 160.398438) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-36\" d=\"M 2113 2584 \nQ 1688 2584 1439 2293 \nQ 1191 2003 1191 1497 \nQ 1191 994 1439 701 \nQ 1688 409 2113 409 \nQ 2538 409 2786 701 \nQ 3034 994 3034 1497 \nQ 3034 2003 2786 2293 \nQ 2538 2584 2113 2584 \nz\nM 3366 4563 \nL 3366 3988 \nQ 3128 4100 288
},
"metadata": {},
"output_type": "display_data",
"jetTransient": {
"display_id": null
}
}
],
"execution_count": 78
2026-03-19 05:47:14 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:09:47.267818729Z",
"start_time": "2026-03-19T06:09:47.018071215Z"
}
},
2026-03-19 05:47:14 +00:00
"cell_type": "code",
"source": [
"y = torch.sigmoid(x)\n",
"plot(x.detach(), y.detach(), 'x', 'sigmoid(x)', figsize=(5, 2.5))"
],
"id": "74013cea59cd8be3",
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 500x250 with 1 Axes>"
],
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"329.98125pt\" height=\"183.35625pt\" viewBox=\"0 0 329.98125 183.35625\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n <metadata>\n <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n <cc:Work>\n <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n <dc:date>2026-03-19T14:09:47.185521</dc:date>\n <dc:format>image/svg+xml</dc:format>\n <dc:creator>\n <cc:Agent>\n <dc:title>Matplotlib v3.7.2, https://matplotlib.org/</dc:title>\n </cc:Agent>\n </dc:creator>\n </cc:Work>\n </rdf:RDF>\n </metadata>\n <defs>\n <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n </defs>\n <g id=\"figure_1\">\n <g id=\"patch_1\">\n <path d=\"M 0 183.35625 \nL 329.98125 183.35625 \nL 329.98125 0 \nL 0 0 \nz\n\"/>\n </g>\n <g id=\"axes_1\">\n <g id=\"patch_2\">\n <path d=\"M 43.78125 145.8 \nL 322.78125 145.8 \nL 322.78125 7.2 \nL 43.78125 7.2 \nz\n\"/>\n </g>\n <g id=\"matplotlib.axis_1\">\n <g id=\"xtick_1\">\n <g id=\"line2d_1\">\n <path d=\"M 56.463068 145.8 \nL 56.463068 7.2 \n\" clip-path=\"url(#p8269f525f0)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_2\">\n <defs>\n <path id=\"ma5cdbfd291\" d=\"M 0 0 \nL 0 3.5 \n\" style=\"stroke: #ffffff; stroke-width: 0.8\"/>\n </defs>\n <g>\n <use xlink:href=\"#ma5cdbfd291\" x=\"56.463068\" y=\"145.8\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_1\">\n <!-- 8 -->\n <g style=\"fill: #ffffff\" transform=\"translate(49.091974 160.398438) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-2212\" d=\"M 678 2272 \nL 4684 2272 \nL 4684 1741 \nL 
678 1741 \nL 678 2272 \nz\n\" transform=\"scale(0.015625)\"/>\n <path id=\"DejaVuSans-38\" d=\"M 2034 2216 \nQ 1584 2216 1326 1975 \nQ 1069 1734 1069 1313 \nQ 1069 891 1326 650 \nQ 1584 409 2034 409 \nQ 2484 409 2743 651 \nQ 3003 894 3003 1313 \nQ 3003 1734 2745 1975 \nQ 2488 2216 2034 2216 \nz\nM 1403 2484 \nQ 997 2584 770 2862 \nQ 544 3141 544 3541 \nQ 544 4100 942 4425 \nQ 1341 4750 2034 4750 \nQ 2731 4750 3128 4425 \nQ 3525 4100 3525 3541 \nQ 3525 3141 3298 2862 \nQ 3072 2584 2669 2484 \nQ 3125 2378 3379 2068 \nQ 3634 1759 3634 1313 \nQ 3634 634 3220 271 \nQ 2806 -91 2034 -91 \nQ 1263 -91 848 271 \nQ 434 634 434 1313 \nQ 434 1759 690 2068 \nQ 947 2378 1403 2484 \nz\nM 1172 3481 \nQ 1172 3119 1398 2916 \nQ 1625 2713 2034 2713 \nQ 2441 2713 2670 2916 \nQ 2900 3119 2900 3481 \nQ 2900 3844 2670 4047 \nQ 2441 4250 2034 4250 \nQ 1625 4250 1398 4047 \nQ 1172 3844 1172 3481 \nz\n\" transform=\"scale(0.015625)\"/>\n </defs>\n <use xlink:href=\"#DejaVuSans-2212\"/>\n <use xlink:href=\"#DejaVuSans-38\" x=\"83.789062\"/>\n </g>\n </g>\n </g>\n <g id=\"xtick_2\">\n <g id=\"line2d_3\">\n <path d=\"M 88.367014 145.8 \nL 88.367014 7.2 \n\" clip-path=\"url(#p8269f525f0)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_4\">\n <g>\n <use xlink:href=\"#ma5cdbfd291\" x=\"88.367014\" y=\"145.8\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_2\">\n <!-- 6 -->\n <g style=\"fill: #ffffff\" transform=\"translate(80.99592 160.398438) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-36\" d=\"M 2113 2584 \nQ 1688 2584 1439 2293 \nQ 1191 2003 1191 1497 \nQ 1191 994 1439 701 \nQ 1688 409 2113 409 \nQ 2538 409 2786 701 \nQ 3034 994 3034 1497 \nQ 3034 2003 2786 2293 \nQ 2538 2584 2113 2584 \nz\nM 3366 4563 \nL 3366 3988 \nQ 3128 4100 288
},
"metadata": {},
"output_type": "display_data",
"jetTransient": {
"display_id": null
}
}
],
"execution_count": 79
2026-03-19 05:47:14 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:09:47.710436882Z",
"start_time": "2026-03-19T06:09:47.473515430Z"
}
},
2026-03-19 05:47:14 +00:00
"cell_type": "code",
"source": [
"x.grad.data.zero_()\n",
"y.backward(torch.ones_like(x),retain_graph=True)\n",
"plot(x.detach(), x.grad, 'x', 'grad of sigmoid', figsize=(5, 2.5))"
],
"id": "6a0b4f529bf9cc5c",
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 500x250 with 1 Axes>"
],
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"336.34375pt\" height=\"183.35625pt\" viewBox=\"0 0 336.34375 183.35625\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n <metadata>\n <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n <cc:Work>\n <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n <dc:date>2026-03-19T14:09:47.635692</dc:date>\n <dc:format>image/svg+xml</dc:format>\n <dc:creator>\n <cc:Agent>\n <dc:title>Matplotlib v3.7.2, https://matplotlib.org/</dc:title>\n </cc:Agent>\n </dc:creator>\n </cc:Work>\n </rdf:RDF>\n </metadata>\n <defs>\n <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n </defs>\n <g id=\"figure_1\">\n <g id=\"patch_1\">\n <path d=\"M 0 183.35625 \nL 336.34375 183.35625 \nL 336.34375 0 \nL 0 0 \nz\n\"/>\n </g>\n <g id=\"axes_1\">\n <g id=\"patch_2\">\n <path d=\"M 50.14375 145.8 \nL 329.14375 145.8 \nL 329.14375 7.2 \nL 50.14375 7.2 \nz\n\"/>\n </g>\n <g id=\"matplotlib.axis_1\">\n <g id=\"xtick_1\">\n <g id=\"line2d_1\">\n <path d=\"M 62.825568 145.8 \nL 62.825568 7.2 \n\" clip-path=\"url(#p1199fb60d9)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_2\">\n <defs>\n <path id=\"m63c8d8144b\" d=\"M 0 0 \nL 0 3.5 \n\" style=\"stroke: #ffffff; stroke-width: 0.8\"/>\n </defs>\n <g>\n <use xlink:href=\"#m63c8d8144b\" x=\"62.825568\" y=\"145.8\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_1\">\n <!-- 8 -->\n <g style=\"fill: #ffffff\" transform=\"translate(55.454474 160.398438) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-2212\" d=\"M 678 2272 \nL 4684 2272 \nL 4684 1741 \nL 
678 1741 \nL 678 2272 \nz\n\" transform=\"scale(0.015625)\"/>\n <path id=\"DejaVuSans-38\" d=\"M 2034 2216 \nQ 1584 2216 1326 1975 \nQ 1069 1734 1069 1313 \nQ 1069 891 1326 650 \nQ 1584 409 2034 409 \nQ 2484 409 2743 651 \nQ 3003 894 3003 1313 \nQ 3003 1734 2745 1975 \nQ 2488 2216 2034 2216 \nz\nM 1403 2484 \nQ 997 2584 770 2862 \nQ 544 3141 544 3541 \nQ 544 4100 942 4425 \nQ 1341 4750 2034 4750 \nQ 2731 4750 3128 4425 \nQ 3525 4100 3525 3541 \nQ 3525 3141 3298 2862 \nQ 3072 2584 2669 2484 \nQ 3125 2378 3379 2068 \nQ 3634 1759 3634 1313 \nQ 3634 634 3220 271 \nQ 2806 -91 2034 -91 \nQ 1263 -91 848 271 \nQ 434 634 434 1313 \nQ 434 1759 690 2068 \nQ 947 2378 1403 2484 \nz\nM 1172 3481 \nQ 1172 3119 1398 2916 \nQ 1625 2713 2034 2713 \nQ 2441 2713 2670 2916 \nQ 2900 3119 2900 3481 \nQ 2900 3844 2670 4047 \nQ 2441 4250 2034 4250 \nQ 1625 4250 1398 4047 \nQ 1172 3844 1172 3481 \nz\n\" transform=\"scale(0.015625)\"/>\n </defs>\n <use xlink:href=\"#DejaVuSans-2212\"/>\n <use xlink:href=\"#DejaVuSans-38\" x=\"83.789062\"/>\n </g>\n </g>\n </g>\n <g id=\"xtick_2\">\n <g id=\"line2d_3\">\n <path d=\"M 94.729514 145.8 \nL 94.729514 7.2 \n\" clip-path=\"url(#p1199fb60d9)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_4\">\n <g>\n <use xlink:href=\"#m63c8d8144b\" x=\"94.729514\" y=\"145.8\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_2\">\n <!-- 6 -->\n <g style=\"fill: #ffffff\" transform=\"translate(87.35842 160.398438) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-36\" d=\"M 2113 2584 \nQ 1688 2584 1439 2293 \nQ 1191 2003 1191 1497 \nQ 1191 994 1439 701 \nQ 1688 409 2113 409 \nQ 2538 409 2786 701 \nQ 3034 994 3034 1497 \nQ 3034 2003 2786 2293 \nQ 2538 2584 2113 2584 \nz\nM 3366 4563 \nL 3366 3988 \nQ 3128 4100 288
},
"metadata": {},
"output_type": "display_data",
"jetTransient": {
"display_id": null
}
}
],
"execution_count": 80
2026-03-19 05:47:14 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:09:48.096574814Z",
"start_time": "2026-03-19T06:09:48.002692660Z"
}
},
2026-03-19 05:47:14 +00:00
"cell_type": "code",
"source": [
"batch_size = 256\n",
"train_iter, test_iter = load_data_fashion_mnist(batch_size)"
],
"id": "f1de998439b1b9f",
"outputs": [],
"execution_count": 81
2026-03-19 05:47:14 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:09:51.282745933Z",
"start_time": "2026-03-19T06:09:51.217670982Z"
}
},
2026-03-19 05:47:14 +00:00
"cell_type": "code",
"source": [
"num_inputs,num_outputs,num_hiddens = 784, 10, 256\n",
"W1=nn.Parameter(torch.randn(num_inputs,num_hiddens,requires_grad=True)*0.01)\n",
"b1=nn.Parameter(torch.zeros(num_hiddens,requires_grad=True))\n",
"W2 = nn.Parameter(torch.randn(num_hiddens,num_outputs,requires_grad=True)*0.01)\n",
"b2=nn.Parameter(torch.zeros(num_outputs,requires_grad=True))\n",
"params=[W1,b1,W2,b2]\n"
],
"id": "adcea8cd4ee792a8",
"outputs": [],
"execution_count": 82
2026-03-19 05:47:14 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:09:53.068840701Z",
"start_time": "2026-03-19T06:09:52.988897182Z"
}
},
2026-03-19 05:47:14 +00:00
"cell_type": "code",
"source": [
"def relu(X):\n",
" a = torch.zeros_like(X)\n",
" return torch.max(X,a)\n",
"def net(X):\n",
" X = X.reshape((-1,num_inputs))\n",
" H = relu(X@W1+b1)\n",
" return (H@W2+b2)\n",
"loss = nn.CrossEntropyLoss(reduction='none')\n",
"num_epochs,lr=10,0.05\n",
"updater=torch.optim.SGD(params,lr=lr)\n",
"#train_ch3(net,train_iter,test_iter,loss,num_epochs,updater)\n"
],
"id": "cfd81a0f16d0c573",
"outputs": [],
"execution_count": 83
2026-03-19 05:47:14 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:09:54.454079900Z",
"start_time": "2026-03-19T06:09:53.918996975Z"
}
},
2026-03-19 05:47:14 +00:00
"cell_type": "code",
"source": "predict_ch3(net, test_iter)",
"id": "f2ed2e9cee14c28a",
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 600x100 with 6 Axes>"
],
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"357.008839pt\" height=\"90.784071pt\" viewBox=\"0 0 357.008839 90.784071\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n <metadata>\n <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n <cc:Work>\n <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n <dc:date>2026-03-19T14:09:54.360205</dc:date>\n <dc:format>image/svg+xml</dc:format>\n <dc:creator>\n <cc:Agent>\n <dc:title>Matplotlib v3.7.2, https://matplotlib.org/</dc:title>\n </cc:Agent>\n </dc:creator>\n </cc:Work>\n </rdf:RDF>\n </metadata>\n <defs>\n <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n </defs>\n <g id=\"figure_1\">\n <g id=\"patch_1\">\n <path d=\"M 0 90.784071 \nL 357.008839 90.784071 \nL 357.008839 0 \nL 0 0 \nz\n\"/>\n </g>\n <g id=\"axes_1\">\n <g id=\"patch_2\">\n <path d=\"M 15.008839 83.584071 \nL 62.837411 83.584071 \nL 62.837411 35.7555 \nL 15.008839 35.7555 \nz\n\"/>\n </g>\n <g clip-path=\"url(#p9beae8b09f)\">\n <image 
xlink:href=\"data:image/png;base64,\niVBORw0KGgoAAAANSUhEUgAAAEMAAABDCAYAAADHyrhzAAAOAElEQVR4nO2a2ZIkV1KGP/dzIiKX2rsltTZkwwi0GNhwwyVzBW/Ak/BevACGwTUYBjPGYswIBCYJqdVdXV1bZmREnOPOxYlcqjeJYYYWZvmbZVd2RFREnN/df19OyR/Lnzp7AKCv+wV+SNiTsYM9GTvYk7GDPRk72JOxgz0ZO9iTsYM9GTvYk7GDPRk72JOxg/i9rhIBUUTLTwAJCqrb82vo/4Bfs/LTffPds4EbnnM5/ipoQCdNef7uc9f3Xd/7Fc/2bPjQA99BhlQ1UkX03hn5zWOsifQnNblR2jNlOBIsQq4BBascD5TPqzhxEAftQbIQF1AtnGoB86974jJRfXFO/vZxIcXyC28TfudHfPMnb9IfQZo5VoEm0F4QA0kgVo5J3j4XL8fFnNNf9tR/+0sYhleQIYLUFVLX5PvHLD44YJgJiwdKnkL7/sDRm7ccND1vzBZMwsBpvWQaBg5Cx3FclhemWCZTvMdcya50Hvl6dcIi1Xx+dY/HF4f405r+Fw31Vc3Z7SHy9BJ68JeQ0b9zxPKPbvnxm+d8ePiYB/U1T4Y536yO6S1w1U/pUqQdKlZDxF3IWXGHnAJmgsuUt/9pCp0SpWkQEeS9t8lnB1gdSPOIByFNFYvC6lTozgSroT82vHbiwUBQI5tyuZqiMuHJao6KozgijooT1e4sIEomqqGydd953dMdtiyCcWVTYit0J2dMfv8Ei+W5GFStowlCb+jgXHxc8dGDz/nk6CE/njziXrjlsprxoLlisMhVnpJMuc0Ni9TQW+C6n5BdWQ4VQw4sjqdwfIj0E6IeHiBVxcUfvsXFp0I6cOStFVWVOZh2NDFxpEYYF2BeLNylSHahGyKPFgdYVoZVhKQwCNopCHh02JEUrw2dJqo68e7ZFafNkremN7w7v6QSY/rhgIqxSA2DK787f8RPD/6VG5vwV1ef8m13xFe3JzxZzPjR2QV/9t5f8E684UShEcXcGdbeOOrFymFAeJynfNY/YGEN58Mht7nhz989o3vvhNBlosxneF3RHwrDseGzzOGsowqZed1ThTy6t2AuZFOSKe0Q6VOk7yNDF/GksFJkULQXdKCQEQTfIcMymAR6E266BoBpHJjFHsLAiQ5ENaZhAOC36nPeD7esdMl/zR5yFFdUmmlC4r3ZJZWkEjLuxX12SNgNroBTSUYxAkYlmUoyro4ruApx+fFbeBQW7wr1gyUpBW7O55CFp50ieRSjXFbkUkRI8laUapPNMdbe7yMZcfs7ANYreRBc4fHtKY8VqAyJhlbGbNYR1YiheOPP6vf4y9mnVGJEzeNiheOmpcuRv779lEYHKskbfVpr1G2eMHhgpj2VZJZWcz4ckCzQ5orOImERiG2PdonYH0csQjowTmcd14sJ1ga0F6prRXvQ9cLZZglNoyFkJ3PI9pr1Bx8z70bJR08RCF0oBEXFY8lGt31AohGCIepcxwmPF3OaKvH2/JpZLB5Ta2Zw5cvVKWHUKZWtPiUPXPZTkivz2DMPPZ1F2lxhLvQWSBZKphkykox4+NkNRKU/OOByuIc41GNqCivQoaTKXN9dsMWSnhC2YbBzzqPjAl5tnKScC2C1F09KILZmcH2P8sVdwJ1hCKQUaLuaPgWqYHRDpE+BOmbuHywIshXkwQLtUG30zF2Y1AOTmDbXqDjHdcss9uTGSQc12mei/+xfkFhxXz6iuZqTJkJ/VKxacjGkGVh11wvW+XpDxnoRIwFWjbFY+XjNuFoFouEmSBvQ5FsyxvttargxFVofQJyurRBxbFGhS2U5M7r7kRi36tB1FUNb4SaQpXjkJBOrjAajqjJVyBzXLYfVCpsYw2FAeyXijudMuFkxPa9J04BmxYJsXX1tVymW3YUH8Lp4ATqSsVkMSCphIYzCoeDlH3ay6+Z6z1KiT8o
DPSvrAxKtvMp4HwAzwWysisVx3yk6pRhCg6PBUHXcBR+TgY2WdRVQH4suy/h/fEnzTcNkOmV+dozVkeFsQpoGrBZSU8jJ9UhSBAtCmkFuCilWAerFwiaolTDbxInsECzgwbeCbEAWvNdyzWjVTXwFR4MTq0yXy2KoDMtKAkK4G7OijjaGiDOZ9kzrgWxCnyLmwipXtLlCXLbZZG0UW61gtUJXHapKaGq8CeBgg6CD4gFSKi9rlWDBQYU0E9zHsKBkFmyddeSOtoiMJbuAiyD6jHu4lJxolBfd8QJRQ9XQ6ORosGNpd0GedTVxRG1TAK6rYAeyKb2FzXsCxE2Ajr5l/QCXV0gIxGVLrCoYmzJXgTDGSVBcFTuo6U4brBa6o0CuIc2FYV48wOriAfiapFKDFI9wPAgWtl6yzlAl5AyCI5UhwamqTAjGZNqTax37O8G9GKksQ9CqZBVLiqEsUmApExhJEXUu4ow2VcSFEhcDoTeihIDbaEp3sIwtFmXB19d8F3Q2Y356Ak3N5MExaRZp36gQU3IDfV1CSFIxieSSlovGCGaOjKILJbu4OoTykWiEytBQ0m0Qp6lLK9CnwKJtcBvDZoy5EKxowqBgUkLOpAh3k5HgtH1FykpohbgcU6vndQHxq225+pDwZQspEc8DYVIRVhOa64rcKKsTxSpKZzktVrcxOCVDyOuKbIzdAKKCySiWo/u7
},
"metadata": {},
"output_type": "display_data",
"jetTransient": {
"display_id": null
}
}
],
"execution_count": 84
2026-03-19 05:47:14 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:09:55.272193107Z",
"start_time": "2026-03-19T06:09:55.113735528Z"
}
},
2026-03-19 05:47:14 +00:00
"cell_type": "code",
"source": [
"net = nn.Sequential(nn.Flatten(),\n",
" nn.Linear(784,256),\n",
" nn.ReLU(),\n",
" nn.Linear(256,10))\n",
"def init_weights(m):\n",
" if type(m) == nn.Linear:\n",
" nn.init.normal_(m.weight,std=0.01)\n",
"\n",
"net.apply(init_weights)\n"
],
"id": "9e4ed6d103380bc7",
"outputs": [
{
"data": {
"text/plain": [
"Sequential(\n",
" (0): Flatten(start_dim=1, end_dim=-1)\n",
" (1): Linear(in_features=784, out_features=256, bias=True)\n",
" (2): ReLU()\n",
" (3): Linear(in_features=256, out_features=10, bias=True)\n",
")"
]
},
"execution_count": 85,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 85
2026-03-19 05:47:14 +00:00
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-19T06:09:56.028892806Z",
"start_time": "2026-03-19T06:09:55.915661103Z"
}
},
2026-03-19 05:47:14 +00:00
"cell_type": "code",
"source": [
"batch_size,lr, num_epochs=256,0.1,10\n",
"loss = nn.CrossEntropyLoss(reduction='none')\n",
"trainer = torch.optim.SGD(net.parameters(),lr=lr)\n",
"train_iter, test_iter = load_data_fashion_mnist(batch_size)\n",
"#train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)"
],
"id": "52d71c77c4f51e90",
"outputs": [],
"execution_count": 86
2026-03-19 05:47:14 +00:00
},
{
"metadata": {},
"cell_type": "code",
"source": "",
"id": "94706c936b4be3e1",
"outputs": [],
"execution_count": null
2026-03-12 07:50:33 +00:00
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}