nn/test.ipynb

1563 lines
258 KiB
Text
Raw Normal View History

2026-03-12 07:50:33 +00:00
{
"cells": [
{
"cell_type": "code",
"id": "initial_id",
"metadata": {
"collapsed": true,
"ExecuteTime": {
2026-03-14 11:51:56 +00:00
"end_time": "2026-03-14T11:37:46.716891442Z",
"start_time": "2026-03-14T11:37:45.891202979Z"
2026-03-12 07:50:33 +00:00
}
},
"source": [
"import torch\n",
"import numpy\n",
"import pandas\n",
2026-03-14 11:51:56 +00:00
"from sympy.physics.control.control_plots import matplotlib\n",
"from torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook import batched_powerSGD_hook\n",
2026-03-12 07:50:33 +00:00
"\n"
],
"outputs": [],
2026-03-14 11:51:56 +00:00
"execution_count": 1
2026-03-12 07:50:33 +00:00
},
{
"metadata": {
"ExecuteTime": {
2026-03-14 11:51:56 +00:00
"end_time": "2026-03-14T11:37:46.748820470Z",
"start_time": "2026-03-14T11:37:46.717673396Z"
2026-03-12 07:50:33 +00:00
}
},
"cell_type": "code",
"source": "torch.randn(3,4,2)",
"id": "3e141a42d342fa96",
"outputs": [
{
"data": {
"text/plain": [
2026-03-14 11:51:56 +00:00
"tensor([[[ 1.1696, -0.5395],\n",
" [-1.2794, -1.0168],\n",
" [ 3.2351, 0.6066],\n",
" [ 1.5116, -0.1253]],\n",
2026-03-12 07:50:33 +00:00
"\n",
2026-03-14 11:51:56 +00:00
" [[-0.1823, 0.1887],\n",
" [ 0.0186, -1.5205],\n",
" [-0.3032, 0.1184],\n",
" [-0.1708, 1.2866]],\n",
2026-03-12 07:50:33 +00:00
"\n",
2026-03-14 11:51:56 +00:00
" [[ 0.1142, 0.0435],\n",
" [-0.4102, -0.4663],\n",
" [ 0.2203, 0.3123],\n",
" [ 1.9645, 1.8992]]])"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 2
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-14T11:37:46.795608911Z",
"start_time": "2026-03-14T11:37:46.749770547Z"
}
},
"cell_type": "code",
"source": [
"X = torch.arange(12, dtype=torch.float32).reshape((3,4))\n",
"Y = torch.tensor([[2.0, 1, 4, 3], [1, 2, 3, 4], [4, 3, 2, 1]])\n",
"torch.cat((X, Y), dim=0), torch.cat((X, Y), dim=1)"
],
"id": "8ae20ae68abbf32f",
"outputs": [
{
"data": {
"text/plain": [
"(tensor([[ 0., 1., 2., 3.],\n",
" [ 4., 5., 6., 7.],\n",
" [ 8., 9., 10., 11.],\n",
" [ 2., 1., 4., 3.],\n",
" [ 1., 2., 3., 4.],\n",
" [ 4., 3., 2., 1.]]),\n",
" tensor([[ 0., 1., 2., 3., 2., 1., 4., 3.],\n",
" [ 4., 5., 6., 7., 1., 2., 3., 4.],\n",
" [ 8., 9., 10., 11., 4., 3., 2., 1.]]))"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 3
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-14T11:37:46.812288160Z",
"start_time": "2026-03-14T11:37:46.803470930Z"
}
},
"cell_type": "code",
"source": [
"a = torch.arange(3).reshape((3, 1))\n",
"b = torch.arange(2).reshape((1, 2))\n",
"a, b\n",
"a+b"
],
"id": "2960a1ded2cdd5a4",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0, 1],\n",
" [1, 2],\n",
" [2, 3]])"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 4
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-14T11:37:46.908394350Z",
"start_time": "2026-03-14T11:37:46.858309741Z"
}
},
"cell_type": "code",
"source": "X[-1], X[1:3]\n",
"id": "69c2ec23ab6ae97c",
"outputs": [
{
"data": {
"text/plain": [
"(tensor([ 8., 9., 10., 11.]),\n",
" tensor([[ 4., 5., 6., 7.],\n",
" [ 8., 9., 10., 11.]]))"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 5
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-14T11:37:46.990018992Z",
"start_time": "2026-03-14T11:37:46.944214131Z"
}
},
"cell_type": "code",
"source": [
"A = X.numpy()\n",
"B = torch.tensor(A)\n",
"type(A), type(B)"
],
"id": "b8d779a1bc7e4b1a",
"outputs": [
{
"data": {
"text/plain": [
"(numpy.ndarray, torch.Tensor)"
2026-03-12 07:50:33 +00:00
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 6
},
{
"metadata": {
"ExecuteTime": {
2026-03-14 11:51:56 +00:00
"end_time": "2026-03-14T11:37:47.097040164Z",
"start_time": "2026-03-14T11:37:46.993944041Z"
2026-03-12 07:50:33 +00:00
}
},
"cell_type": "code",
2026-03-14 11:51:56 +00:00
"source": [
"import os\n",
"os.makedirs(os.path.join(\"..\",\"data\"),exist_ok=True)\n",
"data_file = os.path.join(os.path.join(\"..\",\"data\",\"data.csv\"))\n",
"with open(data_file, \"w\") as f:\n",
" f.write('NumRooms,Alley,Price\\n') # 列名\n",
" f.write('NA,Pave,127500\\n') # 每行表示一个数据样本\n",
" f.write('2,NA,106000\\n')\n",
" f.write('4,NA,178100\\n')\n",
" f.write('NA,NA,140000\\n')\n",
"\n"
],
"id": "82be028b0f1dd1e3",
2026-03-12 07:50:33 +00:00
"outputs": [],
2026-03-14 11:51:56 +00:00
"execution_count": 7
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-14T11:37:47.138922799Z",
"start_time": "2026-03-14T11:37:47.109432980Z"
}
},
"cell_type": "code",
"source": [
"import pandas as pd\n",
"data = pd.read_csv(data_file)\n",
"print(data)\n"
],
"id": "ddd789a2656899d1",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" NumRooms Alley Price\n",
"0 NaN Pave 127500\n",
"1 2.0 NaN 106000\n",
"2 4.0 NaN 178100\n",
"3 NaN NaN 140000\n"
]
}
],
"execution_count": 8
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-14T11:37:47.162491839Z",
"start_time": "2026-03-14T11:37:47.139968456Z"
}
},
"cell_type": "code",
"source": [
"inputs, outputs = data.iloc[:, 0:2], data.iloc[:, 2]\n",
"\n",
"\n",
"inputs = pd.get_dummies(inputs, dummy_na=True)\n",
"print(inputs)\n",
"inputs = inputs.fillna(inputs.mean())\n",
"print(inputs)\n"
],
"id": "e98fcc3bd4f067cf",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" NumRooms Alley_Pave Alley_nan\n",
"0 NaN True False\n",
"1 2.0 False True\n",
"2 4.0 False True\n",
"3 NaN False True\n",
" NumRooms Alley_Pave Alley_nan\n",
"0 3.0 True False\n",
"1 2.0 False True\n",
"2 4.0 False True\n",
"3 3.0 False True\n"
]
}
],
"execution_count": 9
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-14T11:37:47.186754928Z",
"start_time": "2026-03-14T11:37:47.168374155Z"
}
},
"cell_type": "code",
"source": [
"X = torch.tensor(inputs.to_numpy(dtype=float))\n",
"y = torch.tensor(outputs.to_numpy(dtype=float))\n",
"X, y\n"
],
"id": "8ff0f7b40f0e4996",
"outputs": [
{
"data": {
"text/plain": [
"(tensor([[3., 1., 0.],\n",
" [2., 0., 1.],\n",
" [4., 0., 1.],\n",
" [3., 0., 1.]], dtype=torch.float64),\n",
" tensor([127500., 106000., 178100., 140000.], dtype=torch.float64))"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 10
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-14T11:37:47.214268957Z",
"start_time": "2026-03-14T11:37:47.192209367Z"
}
},
"cell_type": "code",
"source": [
"B=torch.tensor([[1,2,3],[2,0,4],[3,4,5]])\n",
"B"
],
"id": "91a6e0da442b95a0",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[1, 2, 3],\n",
" [2, 0, 4],\n",
" [3, 4, 5]])"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 11
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-14T11:37:47.309990671Z",
"start_time": "2026-03-14T11:37:47.241157425Z"
}
},
"cell_type": "code",
"source": "B==B.T",
"id": "297e6a678fb19be7",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[True, True, True],\n",
" [True, True, True],\n",
" [True, True, True]])"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 12
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-14T11:37:47.535195222Z",
"start_time": "2026-03-14T11:37:47.409682869Z"
}
},
"cell_type": "code",
"source": [
"X=torch.arange(24).reshape(2,3,4)\n",
"X"
],
"id": "24e864b336beb58b",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[ 0, 1, 2, 3],\n",
" [ 4, 5, 6, 7],\n",
" [ 8, 9, 10, 11]],\n",
"\n",
" [[12, 13, 14, 15],\n",
" [16, 17, 18, 19],\n",
" [20, 21, 22, 23]]])"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 13
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-14T11:37:47.648470258Z",
"start_time": "2026-03-14T11:37:47.558566189Z"
}
},
"cell_type": "code",
"source": [
"A = torch.arange(20, dtype=torch.float32).reshape(5, 4)\n",
"B = A.clone() # 通过分配新内存将A的一个副本分配给B\n",
"A, A + B\n",
"#A = torch.arange(20, dtype=torch.float32).reshape(5, 4)\n",
"#B = A # 通过分配新内存将A的一个副本分配给B\n",
"id(A),id(B)"
],
"id": "ee0905479b1dbc2b",
"outputs": [
{
"data": {
"text/plain": [
"(140539332541136, 140539333492432)"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 14
},
{
"metadata": {},
"cell_type": "markdown",
"source": "Hadamard乘积",
"id": "136459f5efe765cf"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-14T11:37:47.690032963Z",
"start_time": "2026-03-14T11:37:47.651693108Z"
}
},
"cell_type": "code",
"source": "A*B",
"id": "f576b0df17cc0e98",
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 0., 1., 4., 9.],\n",
" [ 16., 25., 36., 49.],\n",
" [ 64., 81., 100., 121.],\n",
" [144., 169., 196., 225.],\n",
" [256., 289., 324., 361.]])"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 15
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-14T11:37:47.778795013Z",
"start_time": "2026-03-14T11:37:47.722390317Z"
}
},
"cell_type": "code",
"source": [
"a=2\n",
"X=torch.arange(24).reshape(2,3,4)\n",
"a+X,(a*X).shape"
],
"id": "b2373af1d7f2a45",
"outputs": [
{
"data": {
"text/plain": [
"(tensor([[[ 2, 3, 4, 5],\n",
" [ 6, 7, 8, 9],\n",
" [10, 11, 12, 13]],\n",
" \n",
" [[14, 15, 16, 17],\n",
" [18, 19, 20, 21],\n",
" [22, 23, 24, 25]]]),\n",
" torch.Size([2, 3, 4]))"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 16
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-14T11:37:47.852744522Z",
"start_time": "2026-03-14T11:37:47.810917322Z"
}
},
"cell_type": "code",
"source": [
"print(A)\n",
"A_sum_axis0=A.sum(axis=0)\n",
"A_sum_axis1=A.sum(axis=1)\n",
"A_sum_axis0,A_sum_axis1"
],
"id": "2b50246e1ca8a3bc",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[ 0., 1., 2., 3.],\n",
" [ 4., 5., 6., 7.],\n",
" [ 8., 9., 10., 11.],\n",
" [12., 13., 14., 15.],\n",
" [16., 17., 18., 19.]])\n"
]
},
{
"data": {
"text/plain": [
"(tensor([40., 45., 50., 55.]), tensor([ 6., 22., 38., 54., 70.]))"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 17
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-14T11:37:48.039859489Z",
"start_time": "2026-03-14T11:37:47.861768656Z"
}
},
"cell_type": "code",
"source": [
"x=torch.arange(4,dtype=torch.float32)\n",
"torch.mv(A,x)"
],
"id": "3195464dfeb554ed",
"outputs": [
{
"data": {
"text/plain": [
"tensor([ 14., 38., 62., 86., 110.])"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 18
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-14T11:37:48.070161220Z",
"start_time": "2026-03-14T11:37:48.042514455Z"
}
},
"cell_type": "code",
"source": [
"import time\n",
"\n",
"def showtime(func):\n",
" def wrapper():\n",
" start = time.time()\n",
" result = func() # 执行原始函数\n",
" end = time.time()\n",
" print(f\"执行时间: {end - start:.6f}秒\")\n",
" return result\n",
" return wrapper # 返回包装函数\n",
"\n",
"@showtime\n",
"def fun():\n",
" print(\"I am silly\")\n",
"\n",
"fun()\n"
],
"id": "ebda8c74ead3e42b",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"I am silly\n",
"执行时间: 0.000630秒\n"
]
}
],
"execution_count": 19
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-14T11:37:48.092871646Z",
"start_time": "2026-03-14T11:37:48.071687740Z"
}
},
"cell_type": "code",
"source": "torch.norm(torch.ones((4, 9)))",
"id": "3343cc0c01d0161c",
"outputs": [
{
"data": {
"text/plain": [
"tensor(6.)"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 20
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-14T11:37:48.100506481Z",
"start_time": "2026-03-14T11:37:48.094144409Z"
}
},
"cell_type": "code",
"source": [
"x =torch.arange(4.0,requires_grad=True)\n",
"x.grad"
],
"id": "674e2416e9417cfe",
"outputs": [],
"execution_count": 21
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-14T11:37:48.121524857Z",
"start_time": "2026-03-14T11:37:48.101329494Z"
}
},
"cell_type": "code",
"source": [
"y=2*torch.dot(x,x)\n",
"y"
],
"id": "66c0febebcf98cde",
"outputs": [
{
"data": {
"text/plain": [
"tensor(28., grad_fn=<MulBackward0>)"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 22
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-14T11:37:48.148136131Z",
"start_time": "2026-03-14T11:37:48.129950999Z"
}
},
"cell_type": "code",
"source": [
"y.backward()\n",
"x.grad"
],
"id": "825f2ce6c46ca4a8",
"outputs": [
{
"data": {
"text/plain": [
"tensor([ 0., 4., 8., 12.])"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 23
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-14T11:37:48.175519287Z",
"start_time": "2026-03-14T11:37:48.153141505Z"
}
},
"cell_type": "code",
"source": [
"x.grad.zero_()\n",
"y = x.sum()\n",
"y.backward()\n",
"x.grad\n"
],
"id": "df399463515e9d3c",
"outputs": [
{
"data": {
"text/plain": [
"tensor([1., 1., 1., 1.])"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 24
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-14T11:37:48.190165115Z",
"start_time": "2026-03-14T11:37:48.176741284Z"
}
},
"cell_type": "code",
"source": [
"# 对非标量调用backward需要传入一个gradient参数该参数指定微分函数关于self的梯度。\n",
"# 本例只想求偏导数的和所以传递一个1的梯度是合适的\n",
"x.grad.zero_()\n",
"y = x * x\n",
"# 等价于y.backward(torch.ones(len(x)))\n",
"print(y)\n",
"y.sum().backward()\n",
"x.grad"
],
"id": "f9207619bd4b3de8",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([0., 1., 4., 9.], grad_fn=<MulBackward0>)\n"
]
},
{
"data": {
"text/plain": [
"tensor([0., 2., 4., 6.])"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 25
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-14T11:37:48.210632850Z",
"start_time": "2026-03-14T11:37:48.200923142Z"
}
},
"cell_type": "code",
"source": "torch.ones(len(x))",
"id": "409c14c230570859",
"outputs": [
{
"data": {
"text/plain": [
"tensor([1., 1., 1., 1.])"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 26
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-14T11:37:48.229459200Z",
"start_time": "2026-03-14T11:37:48.218856646Z"
}
},
"cell_type": "code",
"source": [
"x.grad.zero_()\n",
"y=x*x\n",
"u=y.detach()\n",
"z=u*x\n",
"z.sum().backward()\n",
"x.grad==u"
],
"id": "521b948fe0683b12",
"outputs": [
{
"data": {
"text/plain": [
"tensor([True, True, True, True])"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 27
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-14T11:37:48.269115256Z",
"start_time": "2026-03-14T11:37:48.244909724Z"
}
},
"cell_type": "code",
"source": [
"x.grad.zero_()\n",
"y.sum().backward()\n",
"x.grad==2*x"
],
"id": "b040beecf0632315",
"outputs": [
{
"data": {
"text/plain": [
"tensor([True, True, True, True])"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 28
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-14T11:37:48.300440123Z",
"start_time": "2026-03-14T11:37:48.279382427Z"
}
},
"cell_type": "code",
"source": [
"from torch.distributions import multinomial\n",
"fair_probs=torch.ones([6])\n",
"fair_probs"
],
"id": "4e6ec763dbea5aa3",
"outputs": [
{
"data": {
"text/plain": [
"tensor([1., 1., 1., 1., 1., 1.])"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 29
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-14T11:37:48.318972768Z",
"start_time": "2026-03-14T11:37:48.301409242Z"
}
},
"cell_type": "code",
"source": "multinomial.Multinomial(1, fair_probs).sample()",
"id": "f12d5e85bc6ab595",
"outputs": [
{
"data": {
"text/plain": [
"tensor([0., 0., 1., 0., 0., 0.])"
]
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 30
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-14T11:37:48.337802576Z",
"start_time": "2026-03-14T11:37:48.320185416Z"
}
},
"cell_type": "code",
"source": [
"counts = multinomial.Multinomial(10, fair_probs).sample((500,))\n",
"\n",
"cum_counts = counts.cumsum(dim=0)\n",
"cum_counts.size()"
],
"id": "b02f43376fd6f1fe",
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([500, 6])"
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 31
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-14T11:37:48.463214605Z",
"start_time": "2026-03-14T11:37:48.358849572Z"
}
},
"cell_type": "code",
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"# 假设 estimates 是你的数据张量\n",
"estimates = cum_counts / cum_counts.sum(dim=1, keepdims=True)\n",
"\n",
"# 设置图形大小 (等效于 d2l.set_figsize)\n",
"plt.figure(figsize=(6, 4.5))\n",
"\n",
"# 绘制每条概率曲线\n",
"for i in range(6):\n",
" plt.plot(estimates[:, i].numpy(),\n",
" label=f\"P(die={i + 1})\") # 使用 f-string 更简洁\n",
"\n",
"# 添加理论概率水平线\n",
"plt.axhline(y=0.167, color='black', linestyle='dashed', label='Theoretical probability')\n",
"\n",
"# 设置坐标轴标签\n",
"plt.xlabel('Groups of experiments')\n",
"plt.ylabel('Estimated probability')\n",
"\n",
"# 添加图例\n",
"plt.legend()\n",
"\n",
"# 显示图形\n",
"plt.show()\n",
"#plt.savefig('dice_probability.png', bbox_inches='tight')"
],
"id": "8b80daa4edd0b066",
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 600x450 with 1 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAiEAAAGZCAYAAABfZuECAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjgsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvwVt1zgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAzkxJREFUeJzs3Xd8U9X7wPHPzWhW0713oZRZ9gbZU0WGAm7FjRPcuBCcOFB/7gXi+CIoCqIiW0E2hbJn6d47HWnm/f0RmlLaAlWWcN6vV17KzcnNSdrmPjnnOc+RABlBEARBEITzTHGhOyAIgiAIwuVJBCGCIAiCIFwQIggRBEEQBOGCEEGIIAiCIAgXhAhCBEEQBEG4IEQQIgiCIAjCBSGCEEEQBEEQLgjVhe7AxSosLIzy8vIL3Q1BEARB+M8xGo1kZ2eftp0IQhoQFhZGVlbWhe6GIAiCIPxnhYeHnzYQEUFIA2pGQMLDw8VoiCAIgiA0gdFoJCsr64yunyIIOYXy8nIRhAiCIAjCOSISUwVBEARBuCBEECIIgiAIwgUhghBBEARBEC4IEYQIgiAIgnBBiCBEEARBEIQLQgQhgiAIgiBcECIIEQRBEAThghBBiCAIgiAIF4QIQgRBEARBuCAuiiDk/vvvJyUlBbPZzObNm+nWrVujbceOHcu2bdsoKSmhoqKCnTt3cvPNN9drN2PGDLKzs6mqqmLlypXExcWdy5cgCIIgCMI/IF/I24QJE+Tq6mr59ttvl1u3bi1/+umncnFxsRwYGNhg+/79+8tjxoyRW7VqJTdr1kx++OGHZZvNJg8bNszd5sknn5RLSkrka665Rk5ISJAXL14sJycnyxqN5oz6ZDQaZVmWZaPReEHfG3ETN3ETN3ETt//arYnX0Avb2c2bN8vvv/+++9+SJMmZmZnyU089dcbnSExMlGfOnOn+d3Z2tvzYY4+5/+3l5SWbzWZ54sSJ5+INPKObd3CI7BsaLus8vS74L4i4iZu4iZu4idu5ujXlGnpBp2PUajVdunRh1apV7mOyLLNq1Sp69ep1RucYNGgQLVu2ZN26dQDExsYSGhpa55wmk4ktW7Y0ek4PDw+MRmOd29k29fPfmPLJMp7+7m+ue2zWWT+/IAiCIPzXXNAgJCAgAJVKRV5eXp3jeXl5hISENPo4Ly8vysvLsVqt/Pbbbzz00EPuoKPmcU0557Rp0zCZTO5bVlbWv3lZjZDc/xfbvvGcF0EQBEG4XFwUialNVV5eTseOHenWrRvPPvsss2fPpn///v/4fK+99hpeXl7uW3h4+FnsrcsrE/qgCzcDICmUZ/38giAIgvBfo7qQT15YWIjdbic4OLjO8eDgYHJzcxt9nCzLJCcnA7Br1y5at27NtGnT+Ouvv9yPO/kcwcHBJCUlNXg+q9WK1Wr9l6/m1GzVZiRJBkCh+E/GfoIgCIJwVl3Qq6HNZiMxMZHBgwe7j0mSxODBg9m0adMZn0ehUKDRaABISUkhJyenzjmNRiM9evRo0jnPBafsBMRIiCAIgiDABR4JAZg9ezbz5s1j+/btbN26lSlTpmAwGJg7dy4A8+bNIysri2eeeQaAp59+mu3bt5OcnIxGo+HKK6/klltuYfLkye5zvvvuuzz33HMcOXKElJQUXnrpJbKzs1m8ePGFeIlustM1EiJJ0mlaCoIgCMKl74IHIQsXLiQwMJCZM2cSEhJCUlISI0aMID8/H4CoqCicTqe7vcFg4KOPPiIiIgKz2czBgwe5+eabWbhwobvNG2+8gcFg4LPPPsPHx4e///6bESNGYLFYzvvrO5HsdACgECMhgiAIgoCEa62ucAKj0YjJZHKvwjlbXtn4J/YCP+w2Ky9d1/WsnVcQBEEQLhZNuYaKDMnzSHbW5ISIt10QBEEQxNXwPHLKrukYkRMiCIIgCCIIOa9ku8gJEQRBEIQaIgg5j05MsBVTMoIgCMLlTlwJzyP5hCBEFCwTBEEQLnfiSngeOZ129/9LknjrBUEQhMubuBKeR07HidMxIi9EEARBuLyJIOQ8EtMxgiAIglBLXAnPI6fjhOkYEYQIgiAIlzlxJTyPajawAxGECIIgCIK4Ep5HdadjRE6IIAiCcHkTQch55HQ4qNmqR4yECIIgCJc7cSU8j1xBiIsYCREEQRAudyIIOY+cDqdr32LE/jGCIAiCIIKQ80h2ijohgiAIglBDBCHnUd3pGPHWC4IgCJc3cSU8j5zOE6ZjlGIkRBAEQbi8iSDkPJJPLNsuckIEQRCEy5wIQs4jp9NBTewhVscIgiAIlzsRhJxHTqcTUSdEEARBEFzElfA8OnE6RoyECIIgCJc7EYScR06nozYxVYyECIIgCJc5cSU8j5z22iW6IggRBEEQLnfiSngeOetsYCfeekEQBOHyJq6E55HsOHE6RuSECIIgCJc3EYScR2IkRBAEQRBqiSvheeR0OGsGQpAk8dYLgiAIlzdxJTyPZKcTJFEnRBAEQRBABCHnVd0N7EROiCAIgnB5E0HIeVRnAzsxEiIIgiBc5sSV8DyST0hMFUGIIAiCcLkTV8LzqO50jHjrBUEQhMubuBKeR05RJ0QQBEEQ3EQQch45T9jATqlSX8CeCIIgCMKFJ4KQ80h2Otx1QsY/PosOA6+5oP0RBEEQhAtJBCHnkdNRuzoGYNyUly9cZwRBEAThAhNByHnkKtsuX+huCIIgCMJFQQQh55F8wuoYQRAEQbjciSDkPDqxWJkgCIIgXO5EEHIenbg6RhAEQRAudyIIOY9kZ/3pGLVGdwF6IgiCIAgXnghCziOnw4l00nSMwdv3wnRGEARBEC6wiyIIuf/++0lJScFsNrN582a6devWaNu77rqLdevWUVxcTHFxMStXrqzXfu7cuciyXOe2bNmyc/0yTsvZQGKq3svn/HdEEARBEC4CFzwImTBhArNnz2bGjBl07tyZXbt2sXz5cgIDAxtsP2DAAObPn8/AgQPp1asXGRkZrFixgrCwsDrtli1bRkhIiPt2ww03nI+Xc0oNJabqvfwuTGcEQRAE4QK74EHIo48+yueff85XX33FgQMHuO+++6iqquKOO+5osP3NN9/Mxx9/zK5duzh06BB33XUXCoWCwYMH12lnsVjIy8tz30pLSxvtg4eHB0ajsc7tXHCNhNStEyKmYwRBEITL1QUNQtRqNV26dGHVqlXuY7Iss2rVKnr16nVG59Dr9ajVaoqLi+scHzBgAHl5eRw8eJCPPvoIP7/GRxymTZuGyWRy37Kysv7ZCzoN2Vl/dYynj/85eS5BEARBuNhd0CAkICAAlUpFXl5eneN5eXmEhISc0TlmzZpFdnZ2nUDmjz/+4NZbb2Xw4ME89dRT9O/fn2XLlqFQNPxyX3vtNby8vNy38PDwf/6iTuHksu0AXv7B5+S5BEEQBOFip7rQHfg3nnrqKa6//noGDBiAxWJxH1+wYIH7//fu3cvu3bs5duwYAwYMYM2aNfXOY7VasVqt57y/DS3R9QoQQYggCIJwebqgIyGFhYXY7XaCg+teiIODg8nNzT3lYx977DGefvpphg0bxp49e07ZNiUlhYKCAuLi4v51n/+NhoqVeQec2YiPIAiCIFxqLmgQYrPZSExMrJNUKkkSgwcPZtOmTY0+7oknnuD5559nxIgRJCYmnvZ5wsPD8ff3Jycn56z0+59yOuvXCfESQYggCIJwmbrgq2Nmz57N3Xffza233kqrVq34+OOPMRgMzJ07F4B58+bx6quvuts/+eSTvPTSS9xxxx2kpqYSHBxMcHAwBoMBAIPBwBtvvEGPHj2Ijo5m0KBBLFmyhKNHj7J8+fIL8hprNLSBndE3AKVKfQF6IwiCIAgX1gXPCVm4cCGBgYHMnDmTkJAQkpKSGDFiBPn5+QBERUW56mscN3nyZDQaDYsWLapznhdffJEZM2bgcDho3749t912Gz4+PmRnZ7NixQqef/7585L3cSrOBlbHABj9gijNPzcrcgRBEAThYiVxcuEKAaPRiMlkwsvLi/Ly8rN23naD+nHjC+9gK6s78jFn2iTS9p9+WkkQBEEQLnZNuYZe8OmYy4k
},
"metadata": {},
"output_type": "display_data",
"jetTransient": {
"display_id": null
}
}
],
"execution_count": 32
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-14T11:37:48.474923890Z",
"start_time": "2026-03-14T11:37:48.467148045Z"
}
},
"cell_type": "code",
"source": [
"import numpy as np\n",
"class Timer:\n",
" \"\"\"记录多次运行时间\"\"\"\n",
" def __init__(self):\n",
" self.times = []\n",
" self.start()\n",
" def start(self):\n",
" \"\"\"启动计时器\"\"\"\n",
" self.tik = time.time()\n",
" def stop(self):\n",
" \"\"\"停止计时器并将时间记录在列表中\"\"\"\n",
" self.times.append(time.time() - self.tik)\n",
" return self.times[-1]\n",
" def avg(self):\n",
" \"\"\"返回平均时间\"\"\"\n",
" return sum(self.times) / len(self.times)\n",
" def sum(self):\n",
" \"\"\"返回时间总和\"\"\"\n",
" return sum(self.times)\n",
" def cumsum(self):\n",
" \"\"\"返回累计时间\"\"\"\n",
" return np.array(self.times).cumsum().tolist()"
],
"id": "4bdbb4999907154a",
"outputs": [],
"execution_count": 33
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-14T11:37:48.638326094Z",
"start_time": "2026-03-14T11:37:48.475453591Z"
}
},
"cell_type": "code",
"source": [
"n = 10000\n",
"a = torch.ones([n])\n",
"b = torch.ones([n])\n",
"c=torch.zeros(n)\n",
"timer = Timer()\n",
"for i in range(n):\n",
" c[i]=a[i]+b[i]\n",
"f'{timer.stop():.5f} sec'"
],
"id": "c6f71622e2cc578a",
"outputs": [
{
"data": {
"text/plain": [
"'0.05042 sec'"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 34
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-14T11:37:48.678544906Z",
"start_time": "2026-03-14T11:37:48.640377978Z"
}
},
"cell_type": "code",
"source": [
"timer.start()\n",
"d=a+b\n",
"f'{timer.stop():.5f} sec'"
],
"id": "2578c79b1214a79f",
"outputs": [
{
"data": {
"text/plain": [
"'0.00046 sec'"
]
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 35
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-14T11:37:48.687129040Z",
"start_time": "2026-03-14T11:37:48.680082631Z"
}
},
"cell_type": "code",
"source": [
"import math\n",
"def normal(x, mu, sigma):\n",
" p = 1 / math.sqrt(2 * math.pi * sigma**2)\n",
" return p * np.exp(-0.5 / sigma**2 * (x - mu)**2)"
],
"id": "fd17fdbe38a5f79",
"outputs": [],
"execution_count": 36
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-14T11:37:48.705878968Z",
"start_time": "2026-03-14T11:37:48.687875003Z"
}
},
"cell_type": "code",
"source": [
"from matplotlib_inline import backend_inline\n",
"def use_svg_display(): #@save\n",
" \"\"\"使用svg格式在Jupyter中显示绘图\"\"\"\n",
" backend_inline.set_matplotlib_formats('svg')\n",
"def set_figsize(figsize=(3.5, 2.5)): #@save\n",
" \"\"\"设置matplotlib的图表大小\"\"\"\n",
" use_svg_display()\n",
" plt.rcParams['figure.figsize'] = figsize\n",
"def set_axes(axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend):\n",
" \"\"\"设置matplotlib的轴\"\"\"\n",
" axes.set_xlabel(xlabel)\n",
" axes.set_ylabel(ylabel)\n",
" axes.set_xscale(xscale)\n",
" axes.set_yscale(yscale)\n",
" axes.set_xlim(xlim)\n",
" axes.set_ylim(ylim)\n",
" if legend:\n",
" axes.legend(legend)\n",
" axes.grid()\n",
"def plot(X, Y=None, xlabel=None, ylabel=None, legend=None, xlim=None,\n",
"ylim=None, xscale='linear', yscale='linear',\n",
"fmts=('-', 'm--', 'g-.', 'r:'), figsize=(3.5, 2.5), axes=None):\n",
" \"\"\"绘制数据点\"\"\"\n",
" if legend is None:\n",
" legend = []\n",
" set_figsize(figsize)\n",
" axes = axes if axes else plt.gca()\n",
" # 如果X有一个轴输出True\n",
" def has_one_axis(X):\n",
" return (hasattr(X, \"ndim\") and X.ndim == 1 or isinstance(X, list)\n",
"and not hasattr(X[0], \"__len__\"))\n",
" if has_one_axis(X):\n",
" X = [X]\n",
" if Y is None:\n",
" X, Y = [[]] * len(X), X\n",
" elif has_one_axis(Y):\n",
" Y = [Y]\n",
" if len(X) != len(Y):\n",
" X = X * len(Y)\n",
" axes.cla()\n",
" for x, y, fmt in zip(X, Y, fmts):\n",
" if len(x):\n",
" axes.plot(x, y, fmt)\n",
" else:\n",
" axes.plot(y, fmt)\n",
" set_axes(axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend)"
],
"id": "82158a69cba14da0",
"outputs": [],
"execution_count": 37
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-14T11:37:48.820167717Z",
"start_time": "2026-03-14T11:37:48.707125239Z"
}
},
"cell_type": "code",
"source": [
"# 再次使用numpy进行可视化\n",
"x = np.arange(-7, 7, 0.01)\n",
"# 均值和标准差对\n",
"params = [(0, 1), (0, 2), (3, 1)]\n",
"plot(x, [normal(x, mu, sigma) for mu, sigma in params], xlabel='x',\n",
"ylabel='p(x)', figsize=(4.5, 2.5),\n",
"legend=[f'mean {mu}, std {sigma}' for mu, sigma in params])"
],
"id": "f69ac10ebc3d13d8",
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 450x250 with 1 Axes>"
],
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"302.08125pt\" height=\"183.35625pt\" viewBox=\"0 0 302.08125 183.35625\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n <metadata>\n <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n <cc:Work>\n <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n <dc:date>2026-03-14T19:37:48.782465</dc:date>\n <dc:format>image/svg+xml</dc:format>\n <dc:creator>\n <cc:Agent>\n <dc:title>Matplotlib v3.10.8, https://matplotlib.org/</dc:title>\n </cc:Agent>\n </dc:creator>\n </cc:Work>\n </rdf:RDF>\n </metadata>\n <defs>\n <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n </defs>\n <g id=\"figure_1\">\n <g id=\"patch_1\">\n <path d=\"M 0 183.35625 \nL 302.08125 183.35625 \nL 302.08125 0 \nL 0 0 \nz\n\"/>\n </g>\n <g id=\"axes_1\">\n <g id=\"patch_2\">\n <path d=\"M 43.78125 145.8 \nL 294.88125 145.8 \nL 294.88125 7.2 \nL 43.78125 7.2 \nz\n\"/>\n </g>\n <g id=\"matplotlib.axis_1\">\n <g id=\"xtick_1\">\n <g id=\"line2d_1\">\n <path d=\"M 71.511736 145.8 \nL 71.511736 7.2 \n\" clip-path=\"url(#p81920ab880)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_2\">\n <defs>\n <path id=\"me49776c533\" d=\"M 0 0 \nL 0 3.5 \n\" style=\"stroke: #ffffff; stroke-width: 0.8\"/>\n </defs>\n <g>\n <use xlink:href=\"#me49776c533\" x=\"71.511736\" y=\"145.8\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_1\">\n <!-- 6 -->\n <g style=\"fill: #ffffff\" transform=\"translate(64.140642 160.398438) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-2212\" d=\"M 678 2272 \nL 4684 2272 \nL 4684 1741 \nL 678 1741 \nL 678 2272 \nz\n\" transform=\"scale(0.015625)\"/>\n <path id=\"DejaVuSans-36\" d=\"M 2113 2584 \nQ 1688 2584 1439 2293 \nQ 1191 2003 1191 1497 \nQ 1191 994 1439 701 \nQ 1688 409 2113 409 \nQ 2538 409 2786 701 \nQ 3034 994 3034 1497 \nQ 3034 2003 2786 2293 \nQ 2538 2584 2113 2584 \nz\nM 3366 4563 \nL 3366 3988 \nQ 3128 4100 2886 4159 \nQ 2644 4219 2406 4219 \nQ 1781 4219 1451 3797 \nQ 1122 3375 1075 2522 \nQ 1259 2794 1537 2939 \nQ 1816 3084 2150 3084 \nQ 2853 3084 3261 2657 \nQ 3669 2231 3669 1497 \nQ 3669 778 3244 343 \nQ 2819 -91 2113 -91 \nQ 1303 -91 875 529 \nQ 447 1150 447 2328 \nQ 447 3434 972 4092 \nQ 1497 4750 2381 4750 \nQ 2619 4750 2861 4703 \nQ 3103 4656 3366 4563 \nz\n\" transform=\"scale(0.015625)\"/>\n </defs>\n <use xlink:href=\"#DejaVuSans-2212\"/>\n <use xlink:href=\"#DejaVuSans-36\" transform=\"translate(83.789062 0)\"/>\n </g>\n </g>\n </g>\n <g id=\"xtick_2\">\n <g id=\"line2d_3\">\n <path d=\"M 104.145435 145.8 \nL 104.145435 7.2 \n\" clip-path=\"url(#p81920ab880)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_4\">\n <g>\n <use xlink:href=\"#me49776c533\" x=\"104.145435\" y=\"145.8\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_2\">\n <!-- 4 -->\n <g style=\"fill: #ffffff\" transform=\"translate(96.774342 160.398438) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-34\" d=\"M 2419 4116 \nL 825 1625 \nL 2419 1625 \nL 2419 4116 \nz\nM 2253 4666 \nL 3047 4666 \nL 3047 1625 \nL 3713 1625 \nL 3713 1100 \nL 3047 1100 \nL 3047 0 \nL 2419 0 \nL 2419 1100 \nL 313 1100 \nL 313 1709 \nL 2253 4666 \nz\n\" transform=\"scale(0.015625)\"/>\n </defs>\n <use xlink:href=\"#DejaVuSans-2212\"/>\n <use xlink:href=\"#DejaVuSans-34\" transform=\"translate(83.789062 0)\"/>\n
},
"metadata": {},
"output_type": "display_data",
"jetTransient": {
"display_id": null
}
}
],
"execution_count": 38
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-14T11:37:48.849464685Z",
"start_time": "2026-03-14T11:37:48.839470354Z"
}
},
"cell_type": "code",
"source": [
"#注意一下matmul做向量乘上矩阵的时候不用考虑转置的情况\n",
"def synthetic_data(w, b, num_examples): #@save\n",
" \"\"\"生成y=Xw+b+噪声\"\"\"\n",
" X = torch.normal(0, 1, (num_examples, len(w)))\n",
" y = torch.matmul(X, w) + b\n",
" y += torch.normal(0, 0.01, y.shape)\n",
" return X, y.reshape((-1, 1))\n"
],
"id": "7ed837bdd2b3a26d",
"outputs": [],
"execution_count": 39
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-14T11:37:49.046153840Z",
"start_time": "2026-03-14T11:37:48.850521592Z"
}
},
"cell_type": "code",
"source": [
"true_w = torch.tensor([2, -3.4])\n",
"true_b = 4.2\n",
"features, labels = synthetic_data(true_w, true_b, 1000)"
],
"id": "5ec2e204a6fd5cb2",
"outputs": [],
"execution_count": 40
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-14T11:37:49.114421568Z",
"start_time": "2026-03-14T11:37:49.049220234Z"
}
},
"cell_type": "code",
"source": [
"set_figsize()\n",
"plt.scatter(features[:, (1)].detach().numpy(), labels.detach().numpy(), 1)"
],
"id": "38213d46b3d9900d",
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.collections.PathCollection at 0x7fd1dc5d8050>"
]
},
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"text/plain": [
"<Figure size 350x250 with 1 Axes>"
],
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"231.442187pt\" height=\"169.678125pt\" viewBox=\"0 0 231.442187 169.678125\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n <metadata>\n <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n <cc:Work>\n <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n <dc:date>2026-03-14T19:37:49.085396</dc:date>\n <dc:format>image/svg+xml</dc:format>\n <dc:creator>\n <cc:Agent>\n <dc:title>Matplotlib v3.10.8, https://matplotlib.org/</dc:title>\n </cc:Agent>\n </dc:creator>\n </cc:Work>\n </rdf:RDF>\n </metadata>\n <defs>\n <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n </defs>\n <g id=\"figure_1\">\n <g id=\"patch_1\">\n <path d=\"M 0 169.678125 \nL 231.442187 169.678125 \nL 231.442187 0 \nL 0 0 \nz\n\"/>\n </g>\n <g id=\"axes_1\">\n <g id=\"patch_2\">\n <path d=\"M 28.942188 145.8 \nL 224.242188 145.8 \nL 224.242188 7.2 \nL 28.942188 7.2 \nz\n\"/>\n </g>\n <g id=\"PathCollection_1\">\n <defs>\n <path id=\"mec311cd7aa\" d=\"M 0 0.5 \nC 0.132602 0.5 0.25979 0.447317 0.353553 0.353553 \nC 0.447317 0.25979 0.5 0.132602 0.5 0 \nC 0.5 -0.132602 0.447317 -0.25979 0.353553 -0.353553 \nC 0.25979 -0.447317 0.132602 -0.5 0 -0.5 \nC -0.132602 -0.5 -0.25979 -0.447317 -0.353553 -0.353553 \nC -0.447317 -0.25979 -0.5 -0.132602 -0.5 0 \nC -0.5 0.132602 -0.447317 0.25979 -0.353553 0.353553 \nC -0.25979 0.447317 -0.132602 0.5 0 0.5 \nz\n\" style=\"stroke: #8dd3c7\"/>\n </defs>\n <g clip-path=\"url(#p1de2f470cd)\">\n <use xlink:href=\"#mec311cd7aa\" x=\"134.486182\" y=\"94.084937\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#mec311cd7aa\" x=\"143.877044\" y=\"104.463934\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#mec311cd7aa\" x=\"100.666573\" y=\"76.987414\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#mec311cd7aa\" x=\"80.296119\" y=\"56.031343\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#mec311cd7aa\" x=\"110.760197\" y=\"78.583035\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#mec311cd7aa\" x=\"122.557872\" y=\"63.64687\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#mec311cd7aa\" x=\"128.714432\" y=\"82.836413\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#mec311cd7aa\" x=\"118.272401\" y=\"82.090179\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#mec311cd7aa\" x=\"89.042918\" y=\"55.266539\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#mec311cd7aa\" x=\"167.665828\" y=\"100.420188\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#mec311cd7aa\" x=\"139.287651\" y=\"75.376865\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#mec311cd7aa\" x=\"120.513179\" y=\"66.434633\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#mec311cd7aa\" x=\"120.160009\" y=\"63.938541\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#mec311cd7aa\" x=\"100.307909\" y=\"72.202839\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#mec311cd7aa\" x=\"109.388936\" y=\"79.87903\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#mec311cd7aa\" x=\"127.666158\" y=\"77.090934\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#mec311cd7aa\" x=\"127.250946\" y=\"96.440363\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#mec311cd7aa\" x=\"131.189483\" y=\"73.420853\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#mec311cd7aa\" x=\"140.363667\" y=\"89.033501\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"
},
"metadata": {},
"output_type": "display_data",
"jetTransient": {
"display_id": null
}
}
],
"execution_count": 41
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-14T11:37:49.133683224Z",
"start_time": "2026-03-14T11:37:49.125021829Z"
}
},
"cell_type": "code",
"source": [
"w=torch.normal(0,0.01,size=(2,1),requires_grad=True)\n",
"b=torch.zeros(1,requires_grad=True)\n",
"def linreg(X, w, b):\n",
" return torch.matmul(X,w)+b\n",
"def squared_loss(y_hat,y):\n",
" return (y_hat-y.reshape(y_hat.shape))**2/2\n",
"def sgd(params,lr,batch_size):\n",
" with torch.no_grad():\n",
" for param in params:\n",
" param-=lr*param.grad/batch_size\n",
" param.grad.zero_()\n",
"lr = 0.03\n",
"num_epochs =20\n",
"net = linreg\n",
"loss = squared_loss"
],
"id": "12166e1bc3ddd695",
"outputs": [],
"execution_count": 42
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-14T11:37:49.152287703Z",
"start_time": "2026-03-14T11:37:49.134387652Z"
}
},
"cell_type": "code",
"source": [
"import random\n",
"def data_iter(batch_size, features, labels):\n",
" num_examples = len(features)\n",
" indices = list(range(num_examples))\n",
" # 这些样本是随机读取的,没有特定的顺序\n",
" random.shuffle(indices)\n",
" for i in range(0, num_examples, batch_size):\n",
" batch_indices = torch.tensor(\n",
" indices[i: min(i + batch_size, num_examples)])\n",
" yield features[batch_indices], labels[batch_indices]"
],
"id": "f3b7ee9f326bc687",
"outputs": [],
"execution_count": 43
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-14T11:37:49.183095346Z",
"start_time": "2026-03-14T11:37:49.153518090Z"
}
},
"cell_type": "code",
"source": [
"batch_size =10\n",
"for X,y in data_iter(batch_size, features, labels):\n",
" print(X,'\\n',y)\n",
" break"
],
"id": "f386e12d65afff2e",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[-0.3577, 0.6754],\n",
" [-0.1904, -0.6314],\n",
" [-1.5305, -0.2903],\n",
" [ 2.0532, -0.3528],\n",
" [ 0.4056, -0.7645],\n",
" [-0.7985, 1.3492],\n",
" [-0.4550, 0.1608],\n",
" [ 1.1672, -0.5057],\n",
" [ 0.3912, -2.4489],\n",
" [ 1.9930, 1.6857]]) \n",
" tensor([[ 1.1766],\n",
" [ 5.9634],\n",
" [ 2.1265],\n",
" [ 9.5015],\n",
" [ 7.6220],\n",
" [-1.9910],\n",
" [ 2.7378],\n",
" [ 8.2718],\n",
" [13.2983],\n",
" [ 2.4553]])\n"
]
}
],
"execution_count": 44
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-14T11:37:49.350363398Z",
"start_time": "2026-03-14T11:37:49.184018282Z"
}
},
"cell_type": "code",
"source": [
"for epoch in range(num_epochs):\n",
" for X, y in data_iter(batch_size, features, labels):\n",
" l=loss(net(X, w, b), y)\n",
" l.sum().backward()\n",
" sgd([w,b],lr,batch_size)\n",
" with torch.no_grad():\n",
" train_l =loss(net(features, w, b), labels)\n",
" print(f'epoch {epoch+1}, train loss: {float(train_l.mean()):3f}')"
],
"id": "8888ab6adcec36f1",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 1, train loss: 0.052015\n",
"epoch 2, train loss: 0.000228\n",
"epoch 3, train loss: 0.000049\n",
"epoch 4, train loss: 0.000048\n",
"epoch 5, train loss: 0.000048\n",
"epoch 6, train loss: 0.000048\n",
"epoch 7, train loss: 0.000048\n",
"epoch 8, train loss: 0.000048\n",
"epoch 9, train loss: 0.000048\n",
"epoch 10, train loss: 0.000048\n",
"epoch 11, train loss: 0.000048\n",
"epoch 12, train loss: 0.000048\n",
"epoch 13, train loss: 0.000048\n",
"epoch 14, train loss: 0.000048\n",
"epoch 15, train loss: 0.000048\n",
"epoch 16, train loss: 0.000048\n",
"epoch 17, train loss: 0.000048\n",
"epoch 18, train loss: 0.000048\n",
"epoch 19, train loss: 0.000048\n",
"epoch 20, train loss: 0.000048\n"
]
}
],
"execution_count": 45
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-14T11:37:49.378897333Z",
"start_time": "2026-03-14T11:37:49.353047845Z"
}
},
"cell_type": "code",
"source": [
"print(f'w的估计误差: {true_w - w.reshape(true_w.shape)}')\n",
"print(f'b的估计误差: {true_b - b}')"
],
"id": "8199439fa7f26309",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"w的估计误差: tensor([-0.0004, 0.0002], grad_fn=<SubBackward0>)\n",
"b的估计误差: tensor([-0.0005], grad_fn=<RsubBackward1>)\n"
]
}
],
"execution_count": 46
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-14T11:40:57.656654704Z",
"start_time": "2026-03-14T11:40:57.637331612Z"
}
},
"cell_type": "code",
"source": [
"from torch.utils import data\n",
"true_w = torch.tensor([2,-3.4])\n",
"true_b = 4.2\n",
"features,labels=synthetic_data(true_w, true_b, 1000)\n",
"def load_array(data_arrays,batch_size,is_train=True):\n",
" dataset = data.TensorDataset(*data_arrays)\n",
" return data.DataLoader(dataset,batch_size,shuffle=is_train)\n",
"batch_size = 10\n",
"data_iter = load_array((features,labels),batch_size)"
],
"id": "560d537dcbb5a335",
"outputs": [],
"execution_count": 51
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-14T11:40:58.976005043Z",
"start_time": "2026-03-14T11:40:58.927183259Z"
}
},
"cell_type": "code",
"source": [
"from torch import nn\n",
"net = nn.Sequential(nn.Linear(2, 1))\n",
"net[0].weight.data.normal_(0,0.001)\n",
"net[0].bias.data.fill_(0)"
],
"id": "c54fe059d6fd20de",
"outputs": [
{
"data": {
"text/plain": [
"tensor([0.])"
]
},
"execution_count": 52,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 52
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-14T11:41:52.186040634Z",
"start_time": "2026-03-14T11:41:52.126046372Z"
}
},
"cell_type": "code",
"source": [
"loss = nn.MSELoss()\n",
"trainer = torch.optim.SGD(net.parameters(), lr=0.01)\n",
"num_epochs = 3\n",
"for epoch in range(num_epochs):\n",
" for X, y in data_iter:\n",
" l = loss(net(X) ,y)\n",
" trainer.zero_grad()\n",
" l.backward()\n",
" trainer.step()\n",
" l = loss(net(features), labels)\n",
" print(f'epoch {epoch + 1}, loss {l:f}')\n"
],
"id": "e8a44851125b7cc6",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 1, loss 0.000091\n",
"epoch 2, loss 0.000091\n",
"epoch 3, loss 0.000091\n"
]
}
],
"execution_count": 63
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-14T11:37:49.859894220Z",
"start_time": "2026-03-14T11:37:49.844955750Z"
}
},
"cell_type": "code",
"source": "",
"id": "bd4e8a65ccd03177",
"outputs": [],
"execution_count": 49
2026-03-12 07:50:33 +00:00
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}