2026-03-12 07:50:33 +00:00
|
|
|
|
{
|
|
|
|
|
|
"cells": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"id": "initial_id",
|
|
|
|
|
|
"metadata": {
|
|
|
|
|
|
"collapsed": true,
|
|
|
|
|
|
"ExecuteTime": {
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"end_time": "2026-03-19T06:08:49.521494353Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:08:47.762419340Z"
|
2026-03-12 07:50:33 +00:00
|
|
|
|
}
|
|
|
|
|
|
},
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"import torch\n",
|
|
|
|
|
|
"import numpy\n",
|
2026-03-19 05:47:14 +00:00
|
|
|
|
"import pandas"
|
|
|
|
|
|
],
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"outputs": [],
|
2026-03-14 11:51:56 +00:00
|
|
|
|
"execution_count": 1
|
2026-03-12 07:50:33 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:08:49.830469193Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:08:49.585762895Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-12 07:50:33 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": "torch.randn(3,4,2)",
|
|
|
|
|
|
"id": "3e141a42d342fa96",
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"outputs": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"data": {
|
|
|
|
|
|
"text/plain": [
|
|
|
|
|
|
"tensor([[[-1.0244, -0.4164],\n",
|
|
|
|
|
|
" [ 1.5765, -0.9106],\n",
|
|
|
|
|
|
" [-1.6388, -0.7727],\n",
|
|
|
|
|
|
" [-1.8594, -1.6634]],\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
" [[-0.3226, 0.6604],\n",
|
|
|
|
|
|
" [ 0.4341, -0.9600],\n",
|
|
|
|
|
|
" [ 0.2575, 2.0599],\n",
|
|
|
|
|
|
" [ 0.6960, 0.7095]],\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
" [[-0.0242, -0.5866],\n",
|
|
|
|
|
|
" [-0.8018, -0.3080],\n",
|
|
|
|
|
|
" [-1.3225, -0.0591],\n",
|
|
|
|
|
|
" [ 0.0322, 0.8251]]])"
|
|
|
|
|
|
]
|
|
|
|
|
|
},
|
|
|
|
|
|
"execution_count": 2,
|
|
|
|
|
|
"metadata": {},
|
|
|
|
|
|
"output_type": "execute_result"
|
|
|
|
|
|
}
|
|
|
|
|
|
],
|
|
|
|
|
|
"execution_count": 2
|
2026-03-14 11:51:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:08:49.974163555Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:08:49.860972205Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-14 11:51:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"X = torch.arange(12, dtype=torch.float32).reshape((3,4))\n",
|
|
|
|
|
|
"Y = torch.tensor([[2.0, 1, 4, 3], [1, 2, 3, 4], [4, 3, 2, 1]])\n",
|
|
|
|
|
|
"torch.cat((X, Y), dim=0), torch.cat((X, Y), dim=1)"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "8ae20ae68abbf32f",
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"outputs": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"data": {
|
|
|
|
|
|
"text/plain": [
|
|
|
|
|
|
"(tensor([[ 0., 1., 2., 3.],\n",
|
|
|
|
|
|
" [ 4., 5., 6., 7.],\n",
|
|
|
|
|
|
" [ 8., 9., 10., 11.],\n",
|
|
|
|
|
|
" [ 2., 1., 4., 3.],\n",
|
|
|
|
|
|
" [ 1., 2., 3., 4.],\n",
|
|
|
|
|
|
" [ 4., 3., 2., 1.]]),\n",
|
|
|
|
|
|
" tensor([[ 0., 1., 2., 3., 2., 1., 4., 3.],\n",
|
|
|
|
|
|
" [ 4., 5., 6., 7., 1., 2., 3., 4.],\n",
|
|
|
|
|
|
" [ 8., 9., 10., 11., 4., 3., 2., 1.]]))"
|
|
|
|
|
|
]
|
|
|
|
|
|
},
|
|
|
|
|
|
"execution_count": 3,
|
|
|
|
|
|
"metadata": {},
|
|
|
|
|
|
"output_type": "execute_result"
|
|
|
|
|
|
}
|
|
|
|
|
|
],
|
|
|
|
|
|
"execution_count": 3
|
2026-03-14 11:51:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:08:50.119655002Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:08:50.007109478Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-14 11:51:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"a = torch.arange(3).reshape((3, 1))\n",
|
|
|
|
|
|
"b = torch.arange(2).reshape((1, 2))\n",
|
|
|
|
|
|
"a, b\n",
|
|
|
|
|
|
"a+b"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "2960a1ded2cdd5a4",
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"outputs": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"data": {
|
|
|
|
|
|
"text/plain": [
|
|
|
|
|
|
"tensor([[0, 1],\n",
|
|
|
|
|
|
" [1, 2],\n",
|
|
|
|
|
|
" [2, 3]])"
|
|
|
|
|
|
]
|
|
|
|
|
|
},
|
|
|
|
|
|
"execution_count": 4,
|
|
|
|
|
|
"metadata": {},
|
|
|
|
|
|
"output_type": "execute_result"
|
|
|
|
|
|
}
|
|
|
|
|
|
],
|
|
|
|
|
|
"execution_count": 4
|
2026-03-14 11:51:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:08:50.341778101Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:08:50.149819865Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-14 11:51:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": "X[-1], X[1:3]\n",
|
|
|
|
|
|
"id": "69c2ec23ab6ae97c",
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"outputs": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"data": {
|
|
|
|
|
|
"text/plain": [
|
|
|
|
|
|
"(tensor([ 8., 9., 10., 11.]),\n",
|
|
|
|
|
|
" tensor([[ 4., 5., 6., 7.],\n",
|
|
|
|
|
|
" [ 8., 9., 10., 11.]]))"
|
|
|
|
|
|
]
|
|
|
|
|
|
},
|
|
|
|
|
|
"execution_count": 5,
|
|
|
|
|
|
"metadata": {},
|
|
|
|
|
|
"output_type": "execute_result"
|
|
|
|
|
|
}
|
|
|
|
|
|
],
|
|
|
|
|
|
"execution_count": 5
|
2026-03-14 11:51:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:08:51.169775176Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:08:50.471288979Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-14 11:51:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"A = X.numpy()\n",
|
|
|
|
|
|
"B = torch.tensor(A)\n",
|
|
|
|
|
|
"type(A), type(B)"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "b8d779a1bc7e4b1a",
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"outputs": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"data": {
|
|
|
|
|
|
"text/plain": [
|
|
|
|
|
|
"(numpy.ndarray, torch.Tensor)"
|
|
|
|
|
|
]
|
|
|
|
|
|
},
|
|
|
|
|
|
"execution_count": 6,
|
|
|
|
|
|
"metadata": {},
|
|
|
|
|
|
"output_type": "execute_result"
|
|
|
|
|
|
}
|
|
|
|
|
|
],
|
|
|
|
|
|
"execution_count": 6
|
2026-03-12 07:50:33 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:08:51.804601549Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:08:51.491609008Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-12 07:50:33 +00:00
|
|
|
|
"cell_type": "code",
|
2026-03-14 11:51:56 +00:00
|
|
|
|
"source": [
|
|
|
|
|
|
"import os\n",
|
|
|
|
|
|
"os.makedirs(os.path.join(\"..\",\"data\"),exist_ok=True)\n",
|
|
|
|
|
|
"data_file = os.path.join(os.path.join(\"..\",\"data\",\"data.csv\"))\n",
|
|
|
|
|
|
"with open(data_file, \"w\") as f:\n",
|
|
|
|
|
|
" f.write('NumRooms,Alley,Price\\n') # 列名\n",
|
|
|
|
|
|
" f.write('NA,Pave,127500\\n') # 每行表示一个数据样本\n",
|
|
|
|
|
|
" f.write('2,NA,106000\\n')\n",
|
|
|
|
|
|
" f.write('4,NA,178100\\n')\n",
|
|
|
|
|
|
" f.write('NA,NA,140000\\n')\n",
|
|
|
|
|
|
"\n"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "82be028b0f1dd1e3",
|
2026-03-12 07:50:33 +00:00
|
|
|
|
"outputs": [],
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"execution_count": 7
|
2026-03-14 11:51:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:08:51.995470564Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:08:51.819209475Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-14 11:51:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"import pandas as pd\n",
|
|
|
|
|
|
"data = pd.read_csv(data_file)\n",
|
|
|
|
|
|
"print(data)\n"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "ddd789a2656899d1",
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"outputs": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"name": "stdout",
|
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
|
"text": [
|
|
|
|
|
|
" NumRooms Alley Price\n",
|
|
|
|
|
|
"0 NaN Pave 127500\n",
|
|
|
|
|
|
"1 2.0 NaN 106000\n",
|
|
|
|
|
|
"2 4.0 NaN 178100\n",
|
|
|
|
|
|
"3 NaN NaN 140000\n"
|
|
|
|
|
|
]
|
|
|
|
|
|
}
|
|
|
|
|
|
],
|
|
|
|
|
|
"execution_count": 8
|
2026-03-14 11:51:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:08:52.239162983Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:08:52.016744354Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-14 11:51:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"inputs, outputs = data.iloc[:, 0:2], data.iloc[:, 2]\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
"inputs = pd.get_dummies(inputs, dummy_na=True)\n",
|
|
|
|
|
|
"print(inputs)\n",
|
|
|
|
|
|
"inputs = inputs.fillna(inputs.mean())\n",
|
|
|
|
|
|
"print(inputs)\n"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "e98fcc3bd4f067cf",
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"outputs": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"name": "stdout",
|
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
|
"text": [
|
|
|
|
|
|
" NumRooms Alley_Pave Alley_nan\n",
|
|
|
|
|
|
"0 NaN True False\n",
|
|
|
|
|
|
"1 2.0 False True\n",
|
|
|
|
|
|
"2 4.0 False True\n",
|
|
|
|
|
|
"3 NaN False True\n",
|
|
|
|
|
|
" NumRooms Alley_Pave Alley_nan\n",
|
|
|
|
|
|
"0 3.0 True False\n",
|
|
|
|
|
|
"1 2.0 False True\n",
|
|
|
|
|
|
"2 4.0 False True\n",
|
|
|
|
|
|
"3 3.0 False True\n"
|
|
|
|
|
|
]
|
|
|
|
|
|
}
|
|
|
|
|
|
],
|
|
|
|
|
|
"execution_count": 9
|
2026-03-14 11:51:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:08:52.439862351Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:08:52.252020096Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-14 11:51:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"X = torch.tensor(inputs.to_numpy(dtype=float))\n",
|
|
|
|
|
|
"y = torch.tensor(outputs.to_numpy(dtype=float))\n",
|
|
|
|
|
|
"X, y\n"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "8ff0f7b40f0e4996",
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"outputs": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"data": {
|
|
|
|
|
|
"text/plain": [
|
|
|
|
|
|
"(tensor([[3., 1., 0.],\n",
|
|
|
|
|
|
" [2., 0., 1.],\n",
|
|
|
|
|
|
" [4., 0., 1.],\n",
|
|
|
|
|
|
" [3., 0., 1.]], dtype=torch.float64),\n",
|
|
|
|
|
|
" tensor([127500., 106000., 178100., 140000.], dtype=torch.float64))"
|
|
|
|
|
|
]
|
|
|
|
|
|
},
|
|
|
|
|
|
"execution_count": 10,
|
|
|
|
|
|
"metadata": {},
|
|
|
|
|
|
"output_type": "execute_result"
|
|
|
|
|
|
}
|
|
|
|
|
|
],
|
|
|
|
|
|
"execution_count": 10
|
2026-03-14 11:51:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:08:52.613938268Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:08:52.458366888Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-14 11:51:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"B=torch.tensor([[1,2,3],[2,0,4],[3,4,5]])\n",
|
|
|
|
|
|
"B"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "91a6e0da442b95a0",
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"outputs": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"data": {
|
|
|
|
|
|
"text/plain": [
|
|
|
|
|
|
"tensor([[1, 2, 3],\n",
|
|
|
|
|
|
" [2, 0, 4],\n",
|
|
|
|
|
|
" [3, 4, 5]])"
|
|
|
|
|
|
]
|
|
|
|
|
|
},
|
|
|
|
|
|
"execution_count": 11,
|
|
|
|
|
|
"metadata": {},
|
|
|
|
|
|
"output_type": "execute_result"
|
|
|
|
|
|
}
|
|
|
|
|
|
],
|
|
|
|
|
|
"execution_count": 11
|
2026-03-14 11:51:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:08:52.927117918Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:08:52.661665190Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-14 11:51:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": "B==B.T",
|
|
|
|
|
|
"id": "297e6a678fb19be7",
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"outputs": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"data": {
|
|
|
|
|
|
"text/plain": [
|
|
|
|
|
|
"tensor([[True, True, True],\n",
|
|
|
|
|
|
" [True, True, True],\n",
|
|
|
|
|
|
" [True, True, True]])"
|
|
|
|
|
|
]
|
|
|
|
|
|
},
|
|
|
|
|
|
"execution_count": 12,
|
|
|
|
|
|
"metadata": {},
|
|
|
|
|
|
"output_type": "execute_result"
|
|
|
|
|
|
}
|
|
|
|
|
|
],
|
|
|
|
|
|
"execution_count": 12
|
2026-03-14 11:51:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:08:53.355336655Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:08:53.046923921Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-14 11:51:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"X=torch.arange(24).reshape(2,3,4)\n",
|
|
|
|
|
|
"X"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "24e864b336beb58b",
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"outputs": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"data": {
|
|
|
|
|
|
"text/plain": [
|
|
|
|
|
|
"tensor([[[ 0, 1, 2, 3],\n",
|
|
|
|
|
|
" [ 4, 5, 6, 7],\n",
|
|
|
|
|
|
" [ 8, 9, 10, 11]],\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
" [[12, 13, 14, 15],\n",
|
|
|
|
|
|
" [16, 17, 18, 19],\n",
|
|
|
|
|
|
" [20, 21, 22, 23]]])"
|
|
|
|
|
|
]
|
|
|
|
|
|
},
|
|
|
|
|
|
"execution_count": 13,
|
|
|
|
|
|
"metadata": {},
|
|
|
|
|
|
"output_type": "execute_result"
|
|
|
|
|
|
}
|
|
|
|
|
|
],
|
|
|
|
|
|
"execution_count": 13
|
2026-03-14 11:51:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:08:53.538512227Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:08:53.424545812Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-14 11:51:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"A = torch.arange(20, dtype=torch.float32).reshape(5, 4)\n",
|
|
|
|
|
|
"B = A.clone() # 通过分配新内存,将A的一个副本分配给B\n",
|
|
|
|
|
|
"A, A + B\n",
|
|
|
|
|
|
"#A = torch.arange(20, dtype=torch.float32).reshape(5, 4)\n",
|
|
|
|
|
|
"#B = A  # no new memory is allocated: B is just another reference to the same tensor as A, so id(A) == id(B)\n",
|
|
|
|
|
|
"id(A),id(B)"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "ee0905479b1dbc2b",
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"outputs": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"data": {
|
|
|
|
|
|
"text/plain": [
|
|
|
|
|
|
"(140069082123952, 140069083998928)"
|
|
|
|
|
|
]
|
|
|
|
|
|
},
|
|
|
|
|
|
"execution_count": 14,
|
|
|
|
|
|
"metadata": {},
|
|
|
|
|
|
"output_type": "execute_result"
|
|
|
|
|
|
}
|
|
|
|
|
|
],
|
|
|
|
|
|
"execution_count": 14
|
2026-03-14 11:51:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
|
|
|
|
|
"metadata": {},
|
|
|
|
|
|
"cell_type": "markdown",
|
|
|
|
|
|
"source": "Hadamard乘积 (element-wise product of two same-shape matrices, $A \\odot B$)",
|
|
|
|
|
|
"id": "136459f5efe765cf"
|
|
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:08:53.817503675Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:08:53.581674028Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-14 11:51:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": "A*B",
|
|
|
|
|
|
"id": "f576b0df17cc0e98",
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"outputs": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"data": {
|
|
|
|
|
|
"text/plain": [
|
|
|
|
|
|
"tensor([[ 0., 1., 4., 9.],\n",
|
|
|
|
|
|
" [ 16., 25., 36., 49.],\n",
|
|
|
|
|
|
" [ 64., 81., 100., 121.],\n",
|
|
|
|
|
|
" [144., 169., 196., 225.],\n",
|
|
|
|
|
|
" [256., 289., 324., 361.]])"
|
|
|
|
|
|
]
|
|
|
|
|
|
},
|
|
|
|
|
|
"execution_count": 15,
|
|
|
|
|
|
"metadata": {},
|
|
|
|
|
|
"output_type": "execute_result"
|
|
|
|
|
|
}
|
|
|
|
|
|
],
|
|
|
|
|
|
"execution_count": 15
|
2026-03-14 11:51:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:08:54.105858002Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:08:53.885857754Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-14 11:51:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"a=2\n",
|
|
|
|
|
|
"X=torch.arange(24).reshape(2,3,4)\n",
|
|
|
|
|
|
"a+X,(a*X).shape"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "b2373af1d7f2a45",
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"outputs": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"data": {
|
|
|
|
|
|
"text/plain": [
|
|
|
|
|
|
"(tensor([[[ 2, 3, 4, 5],\n",
|
|
|
|
|
|
" [ 6, 7, 8, 9],\n",
|
|
|
|
|
|
" [10, 11, 12, 13]],\n",
|
|
|
|
|
|
" \n",
|
|
|
|
|
|
" [[14, 15, 16, 17],\n",
|
|
|
|
|
|
" [18, 19, 20, 21],\n",
|
|
|
|
|
|
" [22, 23, 24, 25]]]),\n",
|
|
|
|
|
|
" torch.Size([2, 3, 4]))"
|
|
|
|
|
|
]
|
|
|
|
|
|
},
|
|
|
|
|
|
"execution_count": 16,
|
|
|
|
|
|
"metadata": {},
|
|
|
|
|
|
"output_type": "execute_result"
|
|
|
|
|
|
}
|
|
|
|
|
|
],
|
|
|
|
|
|
"execution_count": 16
|
2026-03-14 11:51:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:08:54.282899541Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:08:54.174420931Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-14 11:51:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"print(A)\n",
|
|
|
|
|
|
"A_sum_axis0=A.sum(axis=0)\n",
|
|
|
|
|
|
"A_sum_axis1=A.sum(axis=1)\n",
|
|
|
|
|
|
"A_sum_axis0,A_sum_axis1"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "2b50246e1ca8a3bc",
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"outputs": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"name": "stdout",
|
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
|
"text": [
|
|
|
|
|
|
"tensor([[ 0., 1., 2., 3.],\n",
|
|
|
|
|
|
" [ 4., 5., 6., 7.],\n",
|
|
|
|
|
|
" [ 8., 9., 10., 11.],\n",
|
|
|
|
|
|
" [12., 13., 14., 15.],\n",
|
|
|
|
|
|
" [16., 17., 18., 19.]])\n"
|
|
|
|
|
|
]
|
|
|
|
|
|
},
|
|
|
|
|
|
{
|
|
|
|
|
|
"data": {
|
|
|
|
|
|
"text/plain": [
|
|
|
|
|
|
"(tensor([40., 45., 50., 55.]), tensor([ 6., 22., 38., 54., 70.]))"
|
|
|
|
|
|
]
|
|
|
|
|
|
},
|
|
|
|
|
|
"execution_count": 17,
|
|
|
|
|
|
"metadata": {},
|
|
|
|
|
|
"output_type": "execute_result"
|
|
|
|
|
|
}
|
|
|
|
|
|
],
|
|
|
|
|
|
"execution_count": 17
|
2026-03-14 11:51:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:08:54.345533969Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:08:54.290459990Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-14 11:51:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"x=torch.arange(4,dtype=torch.float32)\n",
|
|
|
|
|
|
"torch.mv(A,x)"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "3195464dfeb554ed",
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"outputs": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"data": {
|
|
|
|
|
|
"text/plain": [
|
|
|
|
|
|
"tensor([ 14., 38., 62., 86., 110.])"
|
|
|
|
|
|
]
|
|
|
|
|
|
},
|
|
|
|
|
|
"execution_count": 18,
|
|
|
|
|
|
"metadata": {},
|
|
|
|
|
|
"output_type": "execute_result"
|
|
|
|
|
|
}
|
|
|
|
|
|
],
|
|
|
|
|
|
"execution_count": 18
|
2026-03-14 11:51:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:08:54.432702369Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:08:54.361198167Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-14 11:51:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"import time\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
"def showtime(func):\n",
|
|
|
|
|
|
" def wrapper():\n",
|
|
|
|
|
|
" start = time.time()\n",
|
|
|
|
|
|
" result = func() # 执行原始函数\n",
|
|
|
|
|
|
" end = time.time()\n",
|
|
|
|
|
|
" print(f\"执行时间: {end - start:.6f}秒\")\n",
|
|
|
|
|
|
" return result\n",
|
|
|
|
|
|
" return wrapper # 返回包装函数\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
"@showtime\n",
|
|
|
|
|
|
"def fun():\n",
|
|
|
|
|
|
" print(\"I am silly\")\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
"fun()\n"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "ebda8c74ead3e42b",
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"outputs": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"name": "stdout",
|
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
|
"text": [
|
|
|
|
|
|
"I am silly\n",
|
|
|
|
|
|
"执行时间: 0.000109秒\n"
|
|
|
|
|
|
]
|
|
|
|
|
|
}
|
|
|
|
|
|
],
|
|
|
|
|
|
"execution_count": 19
|
2026-03-14 11:51:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:08:54.696923466Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:08:54.436724383Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-14 11:51:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": "torch.norm(torch.ones((4, 9)))",
|
|
|
|
|
|
"id": "3343cc0c01d0161c",
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"outputs": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"data": {
|
|
|
|
|
|
"text/plain": [
|
|
|
|
|
|
"tensor(6.)"
|
|
|
|
|
|
]
|
|
|
|
|
|
},
|
|
|
|
|
|
"execution_count": 20,
|
|
|
|
|
|
"metadata": {},
|
|
|
|
|
|
"output_type": "execute_result"
|
|
|
|
|
|
}
|
|
|
|
|
|
],
|
|
|
|
|
|
"execution_count": 20
|
2026-03-14 11:51:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:08:54.768902496Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:08:54.719453919Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-14 11:51:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"x =torch.arange(4.0,requires_grad=True)\n",
|
|
|
|
|
|
"x.grad"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "674e2416e9417cfe",
|
|
|
|
|
|
"outputs": [],
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"execution_count": 21
|
2026-03-14 11:51:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:08:54.891549085Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:08:54.804248422Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-14 11:51:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"y=2*torch.dot(x,x)\n",
|
|
|
|
|
|
"y"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "66c0febebcf98cde",
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"outputs": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"data": {
|
|
|
|
|
|
"text/plain": [
|
|
|
|
|
|
"tensor(28., grad_fn=<MulBackward0>)"
|
|
|
|
|
|
]
|
|
|
|
|
|
},
|
|
|
|
|
|
"execution_count": 22,
|
|
|
|
|
|
"metadata": {},
|
|
|
|
|
|
"output_type": "execute_result"
|
|
|
|
|
|
}
|
|
|
|
|
|
],
|
|
|
|
|
|
"execution_count": 22
|
2026-03-14 11:51:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:08:54.993310716Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:08:54.914195086Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-14 11:51:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"y.backward()\n",
|
|
|
|
|
|
"x.grad"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "825f2ce6c46ca4a8",
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"outputs": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"data": {
|
|
|
|
|
|
"text/plain": [
|
|
|
|
|
|
"tensor([ 0., 4., 8., 12.])"
|
|
|
|
|
|
]
|
|
|
|
|
|
},
|
|
|
|
|
|
"execution_count": 23,
|
|
|
|
|
|
"metadata": {},
|
|
|
|
|
|
"output_type": "execute_result"
|
|
|
|
|
|
}
|
|
|
|
|
|
],
|
|
|
|
|
|
"execution_count": 23
|
2026-03-14 11:51:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:08:55.184135089Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:08:54.995848170Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-14 11:51:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"x.grad.zero_()\n",
|
|
|
|
|
|
"y = x.sum()\n",
|
|
|
|
|
|
"y.backward()\n",
|
|
|
|
|
|
"x.grad\n"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "df399463515e9d3c",
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"outputs": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"data": {
|
|
|
|
|
|
"text/plain": [
|
|
|
|
|
|
"tensor([1., 1., 1., 1.])"
|
|
|
|
|
|
]
|
|
|
|
|
|
},
|
|
|
|
|
|
"execution_count": 24,
|
|
|
|
|
|
"metadata": {},
|
|
|
|
|
|
"output_type": "execute_result"
|
|
|
|
|
|
}
|
|
|
|
|
|
],
|
|
|
|
|
|
"execution_count": 24
|
2026-03-14 11:51:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:08:55.569785595Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:08:55.211139587Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-14 11:51:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"# 对非标量调用backward需要传入一个gradient参数,该参数指定微分函数关于self的梯度。\n",
|
|
|
|
|
|
"# 本例只想求偏导数的和,所以传递一个1的梯度是合适的\n",
|
|
|
|
|
|
"x.grad.zero_()\n",
|
|
|
|
|
|
"y = x * x\n",
|
|
|
|
|
|
"# 等价于y.backward(torch.ones(len(x)))\n",
|
|
|
|
|
|
"print(y)\n",
|
|
|
|
|
|
"y.sum().backward()\n",
|
|
|
|
|
|
"x.grad"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "f9207619bd4b3de8",
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"outputs": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"name": "stdout",
|
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
|
"text": [
|
|
|
|
|
|
"tensor([0., 1., 4., 9.], grad_fn=<MulBackward0>)\n"
|
|
|
|
|
|
]
|
|
|
|
|
|
},
|
|
|
|
|
|
{
|
|
|
|
|
|
"data": {
|
|
|
|
|
|
"text/plain": [
|
|
|
|
|
|
"tensor([0., 2., 4., 6.])"
|
|
|
|
|
|
]
|
|
|
|
|
|
},
|
|
|
|
|
|
"execution_count": 25,
|
|
|
|
|
|
"metadata": {},
|
|
|
|
|
|
"output_type": "execute_result"
|
|
|
|
|
|
}
|
|
|
|
|
|
],
|
|
|
|
|
|
"execution_count": 25
|
2026-03-14 11:51:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:08:55.706654768Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:08:55.573574038Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-14 11:51:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": "torch.ones(len(x))",
|
|
|
|
|
|
"id": "409c14c230570859",
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"outputs": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"data": {
|
|
|
|
|
|
"text/plain": [
|
|
|
|
|
|
"tensor([1., 1., 1., 1.])"
|
|
|
|
|
|
]
|
|
|
|
|
|
},
|
|
|
|
|
|
"execution_count": 26,
|
|
|
|
|
|
"metadata": {},
|
|
|
|
|
|
"output_type": "execute_result"
|
|
|
|
|
|
}
|
|
|
|
|
|
],
|
|
|
|
|
|
"execution_count": 26
|
2026-03-14 11:51:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:08:55.928404498Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:08:55.717433023Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-14 11:51:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"x.grad.zero_()\n",
|
|
|
|
|
|
"y=x*x\n",
|
|
|
|
|
|
"u=y.detach()\n",
|
|
|
|
|
|
"z=u*x\n",
|
|
|
|
|
|
"z.sum().backward()\n",
|
|
|
|
|
|
"x.grad==u"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "521b948fe0683b12",
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"outputs": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"data": {
|
|
|
|
|
|
"text/plain": [
|
|
|
|
|
|
"tensor([True, True, True, True])"
|
|
|
|
|
|
]
|
|
|
|
|
|
},
|
|
|
|
|
|
"execution_count": 27,
|
|
|
|
|
|
"metadata": {},
|
|
|
|
|
|
"output_type": "execute_result"
|
|
|
|
|
|
}
|
|
|
|
|
|
],
|
|
|
|
|
|
"execution_count": 27
|
2026-03-14 11:51:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:08:56.100782679Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:08:55.980278337Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-14 11:51:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"x.grad.zero_()\n",
|
|
|
|
|
|
"y.sum().backward()\n",
|
|
|
|
|
|
"x.grad==2*x"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "b040beecf0632315",
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"outputs": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"data": {
|
|
|
|
|
|
"text/plain": [
|
|
|
|
|
|
"tensor([True, True, True, True])"
|
|
|
|
|
|
]
|
|
|
|
|
|
},
|
|
|
|
|
|
"execution_count": 28,
|
|
|
|
|
|
"metadata": {},
|
|
|
|
|
|
"output_type": "execute_result"
|
|
|
|
|
|
}
|
|
|
|
|
|
],
|
|
|
|
|
|
"execution_count": 28
|
2026-03-14 11:51:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:08:56.391506058Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:08:56.102643728Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-14 11:51:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"from torch.distributions import multinomial\n",
|
|
|
|
|
|
"fair_probs=torch.ones([6])\n",
|
|
|
|
|
|
"fair_probs"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "4e6ec763dbea5aa3",
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"outputs": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"data": {
|
|
|
|
|
|
"text/plain": [
|
|
|
|
|
|
"tensor([1., 1., 1., 1., 1., 1.])"
|
|
|
|
|
|
]
|
|
|
|
|
|
},
|
|
|
|
|
|
"execution_count": 29,
|
|
|
|
|
|
"metadata": {},
|
|
|
|
|
|
"output_type": "execute_result"
|
|
|
|
|
|
}
|
|
|
|
|
|
],
|
|
|
|
|
|
"execution_count": 29
|
2026-03-14 11:51:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:08:56.646489725Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:08:56.454875071Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-14 11:51:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": "multinomial.Multinomial(1, fair_probs).sample()",
|
|
|
|
|
|
"id": "f12d5e85bc6ab595",
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"outputs": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"data": {
|
|
|
|
|
|
"text/plain": [
|
|
|
|
|
|
"tensor([0., 0., 0., 1., 0., 0.])"
|
|
|
|
|
|
]
|
|
|
|
|
|
},
|
|
|
|
|
|
"execution_count": 30,
|
|
|
|
|
|
"metadata": {},
|
|
|
|
|
|
"output_type": "execute_result"
|
|
|
|
|
|
}
|
|
|
|
|
|
],
|
|
|
|
|
|
"execution_count": 30
|
2026-03-14 11:51:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:08:56.713816600Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:08:56.653507072Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-14 11:51:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"counts = multinomial.Multinomial(10, fair_probs).sample((500,))\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
"cum_counts = counts.cumsum(dim=0)\n",
|
|
|
|
|
|
"cum_counts.size()"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "b02f43376fd6f1fe",
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"outputs": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"data": {
|
|
|
|
|
|
"text/plain": [
|
|
|
|
|
|
"torch.Size([500, 6])"
|
|
|
|
|
|
]
|
|
|
|
|
|
},
|
|
|
|
|
|
"execution_count": 31,
|
|
|
|
|
|
"metadata": {},
|
|
|
|
|
|
"output_type": "execute_result"
|
|
|
|
|
|
}
|
|
|
|
|
|
],
|
|
|
|
|
|
"execution_count": 31
|
2026-03-14 11:51:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:08:57.183010802Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:08:56.761045787Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-14 11:51:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"import matplotlib.pyplot as plt\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
"# estimates[i, j]: cumulative relative frequency of face j+1 after the first i+1 groups of rolls\n",
|
|
|
|
|
|
"estimates = cum_counts / cum_counts.sum(dim=1, keepdims=True)\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
"# 设置图形大小 (等效于 d2l.set_figsize)\n",
|
|
|
|
|
|
"plt.figure(figsize=(6, 4.5))\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
"# 绘制每条概率曲线\n",
|
|
|
|
|
|
"for i in range(6):\n",
|
|
|
|
|
|
" plt.plot(estimates[:, i].numpy(),\n",
|
|
|
|
|
|
" label=f\"P(die={i + 1})\") # 使用 f-string 更简洁\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
"# 添加理论概率水平线\n",
|
|
|
|
|
|
"plt.axhline(y=0.167, color='black', linestyle='dashed', label='Theoretical probability')\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
"# 设置坐标轴标签\n",
|
|
|
|
|
|
"plt.xlabel('Groups of experiments')\n",
|
|
|
|
|
|
"plt.ylabel('Estimated probability')\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
"# 添加图例\n",
|
|
|
|
|
|
"plt.legend()\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
"# 显示图形\n",
|
|
|
|
|
|
"plt.show()\n",
|
|
|
|
|
|
"#plt.savefig('dice_probability.png', bbox_inches='tight')"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "8b80daa4edd0b066",
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"outputs": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"data": {
|
|
|
|
|
|
"text/plain": [
|
|
|
|
|
|
"<Figure size 600x450 with 1 Axes>"
|
|
|
|
|
|
],
|
|
|
|
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAiEAAAGZCAYAAABfZuECAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAADPyElEQVR4nOzdeVhU1f/A8fdszLDvi4CgAiou4L6Uu2ZqLmmlpX0ty/ppZbZZWmpqfeur5ZJmpWmopWWWuZR7mmYuKYYrKAqKguw7M8x6f38goyOCoixq5/U893mYe88999wB5n7mrDJAQhAEQRAEoYbJa7sAgiAIgiD8O4kgRBAEQRCEWiGCEEEQBEEQaoUIQgRBEARBqBUiCBEEQRAEoVaIIEQQBEEQhFohghBBEARBEGqFsrYLcLfy9/enoKCgtoshCIIgCPccZ2dnUlJSbppOBCE34O/vT3Jycm0XQxAEQRDuWQEBATcNREQQcgOlNSABAQGiNkQQBEEQKsHZ2Znk5ORben6KIKQCBQUFIggRBEEQhGoiOqYKgiAIglArRBAiCIIgCEKtEEGIIAiCIAi1QvQJEQRBABwcHPDy8kImk9V2UQThriVJEpmZmWi12irJTwQhgiD8q8lkMkaNGkW3bt1quyiCcM/4448/iIqKQpKkO8pHBCGCIPyrjRo1iq5du7J69Wri4uIwmUy1XSRBuGsplUoaN27M0KFDAfjmm2/uLL+qKJQgCMK9yNHRkW7durF69Wp+++232i6OINwTzp07B8CwYcP44Ycf7qhpRnRMFQThX8vT0xOAuLi4Wi6JINxbSv9nvLy87igfEYQIgvCvVdoJVTTBCELllP7P3GlH7rsiCHnppZdITExEp9Nx4MAB2rZte0vnDRs2DEmS+OWXX8ocmz59OikpKWi1WrZv305oaGhVF1sQBEEQhDtQ60HI0KFDmTNnDtOnT6dVq1YcPXqUrVu34u3tXeF5wcHBfPrpp+zZs6fMsbfffptXX32VMWPG0L59e4qKiti6dStqtbq6bkMQBEEQhEqq9SDkjTfe4Ouvv2bZsmXExsYyZswYtFotzz33XLnnyOVyVq5cyfvvv09CQkKZ46+99hoffvghGzZs4Pjx44wcORJ/f38effTRaryTiqkdHXB0c0WuVNRaGQRB+HdZsWIFkyZNqjBNYmIi48ePt76WJIlBgwZVd9HK5enpSVpaGgEBAbVWBqHm1GoQolKpaN26NTt27LDukySJHTt20LFjx3LPmzp1Kunp6TccGlS/fn3q1Kljk2d+fj4HDx4sN087OzucnZ1ttqr29vrvmfHnFuqEhVR53oIg/PuUztEgSRJ6vZ74+HimTJmCQlHyRSciIoJ+/foxf/78SuXr5+fH5s2bq6PIALzwwgvs2rWLvLw8JEnC1dXV5nhWVhYrVqxg+vTp1VYG4e5Rq0GIl5cXSqWStLQ0m/1paWn4+fnd8JwHH3yQ559/nhdeeOGGx0vPq0yekyZNIj8/37olJydX9lZuymw0AqBQilHRgiBUjc2bN+Pn50dYWBizZ89m2rRpTJgwAYBx48axZs0aioqKKpVnWloaBoOhOooLlMxMu2XLFj766KNy00RFRTFixAjc3d2rrRzC3aHWm2Mqw8nJiW+//ZYXXniBrKysKsv3448/xsXFxbpVRzWg2VjSk1ihUlV53oIgVC07e02tbJWl1+tJS0sjKSmJr776ih07djBw4EDkcjmPP/44GzdutEnv7e3Nhg0b0Gq1JCQkMHz48DJ5Xt8cExgYyOrVq8nJySErK4t169YRHBxc+Tf1is8++4yZM2dy4MCBctOcOnWKlJQUBg8efNvXEe4Ntfq1PDMzE5PJhK+vr81+X19fUlNTy6QPCQmhfv36Nv9YcnlJHGU0GmnUqJH1vOvz8PX1JSYm5oblMBgM1Rr5A5ivDGdSiiBEEO5qdvYaPv57V61ce1K77hh0xbd9vk6nw9PTk4iICN
zc3Dh8+LDN8WXLluHv70/37t0xGo3Mnz8fHx+fcvNTKpVs3bqV/fv307lzZ0wmE5MnT2bLli1ERERgNBoZPnw4ixYtqrBcffv2Ze/evZW6l7///pvOnTvf8Yycwt2tVoMQo9FIdHQ0PXv2ZP369UDJmOOePXvy+eefl0kfFxdHs2bNbPZ9+OGHODs7M378eC5evIjRaOTy5cv07NmTo0ePAuDs7Ez79u358ssvq/+mylFaEyIXzTGCIFSDnj178vDDD7NgwQKCg4MxmUykp6dbj4eFhdGvXz/atm1rDU6ef/75CidqGzZsGHK5nNGjR1v3jRo1itzcXLp168b27dvZsGEDBw8erLBst9PEnZKSQsuWLSt9nnBvqfUn4pw5c1i+fDmHDx/m77//5rXXXsPR0ZGoqCgAli9fTnJyMu+++y56vZ6TJ0/anJ+bmwtgs3/evHlMnjyZ+Ph4EhMT+eCDD0hJSWHdunU1dVtllPYJUapq/S0XBKECBl0xk9p1r7VrV0b//v0pKChApVIhl8tZtWoV06ZNY+DAgej1epu04eHh1i9+pU6fPk1OTk65+UdGRhIaGkpBQYHNfo1GQ0hICNu3b6ewsJDCwsJKlftW6HQ6HBwcqjxf4e5S60/EH3/8EW9vb2bMmIGfnx8xMTH06dPHGsEHBQVhsVgqleesWbNwdHRk8eLFuLm5sXfvXvr06VPmn7ImlTbHiJoQQbj73UmTSE3atWsXY8eOxWAwkJKSgtlsBkqauh0dHVGpVBivfAG6HU5OTkRHRzNixIgyxzIyMgCqrTnGw8PDeg3h/nVXPBEXLlzIwoULb3ise/eKv5GMGjXqhvvff/993n///TsuW1UpbY4RfUIEQagqRUVF1sXErlXa/61JkybWZum4uDjrtAilzTENGzascATKkSNHGDZsGOnp6WVqQ0pVV3NMs2bN+OOPPyp9nnBvuadGx9zLRE2IIAg1JTMzk+joaDp16mTdd+bMGTZv3syiRYto164drVq1YsmSJRWugLpy5UoyMzNZv349nTp1ol69enTt2pXPPvvMOoqwsLCQc+fOVbgVF1+tWfL19bU28wA0b96cyMhIm2DI3t6e1q1bs23btqp+a4S7jAhCaojoEyIIQk1asmRJmWaUUaNGkZKSwu7du1m7di2LFy+26bx6PZ1OR5cuXUhKSmLt2rXExsaydOlSNBoN+fn5t1WuMWPGEBMTw5IlSwD4888/iYmJYeDAgdY0gwYNIikpqdJNOMK9SRKb7ebs7CxJkiQ5OztXWZ7PzPlImn18v9Rx6OBavz+xiU1sJVtwcLC0YsUKKTg4uNbLUtWbRqORLly4IHXo0KHWy1LZbf/+/dJTTz1V6+UQW/lbRf87lXmGipqQGnK1JkT0CREEofoVFxczcuRIvLy8arsoleLp6cnatWv5/vvva7soQg0QbQM1xGwq6bUupm0XBKGm7N69u7aLUGlZWVl88skntV0MoYaImpAaYl07RtSECIIgCAIggpAaUzo6RqFU1HJJBEEQBOHuIIKQGmISNSGCIAiCYEMEITXEIvqECIIgCIINEYTUEFETIgiCIAi2RBBSQyylfULEZGWCIAiCAIggpMZYa0JEc4wgCIIgACIIqTGiJkQQhJq2YsUKJk2aVGGaxMRExo8fb30tSRKDBg2q7qKVKzw8nIsXL+Lg4FBrZRBqjghCaojJWBqEiD4hgiDcuaioKCRJQpIk9Ho98fHxTJkyBYWiZBqAiIgI+vXrx/z58yuVr5+fH5s3b66OIuPu7s78+fOJi4tDq9Vy4cIFPvvsM1xcXKxpYmNjOXDgAG+88Ua1lEG4u4ggpIZYa0JEc4wgCFVk8+bN+Pn5ERYWxuzZs5k2bRoTJkwAYNy4caxZs4aioqJK5ZmWlobBYKiO4uLv74+/vz9vvfUWzZo149lnn6VPnz4sXbrUJl1UVBRjx461BlTC/UsEITXk6ugYEYQIwt3OwUFdK1tl6fV60tLSSE
pK4quvvmLHjh0MHDgQuVzO448/zsaNG23Se3t7s2HDBrRaLQkJCQwfPrxMntc3xwQGBrJ69WpycnLIyspi3bp1BAc
|
|
|
|
|
|
},
|
|
|
|
|
|
"metadata": {},
|
|
|
|
|
|
"output_type": "display_data",
|
|
|
|
|
|
"jetTransient": {
|
|
|
|
|
|
"display_id": null
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
],
|
|
|
|
|
|
"execution_count": 32
|
2026-03-14 11:51:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:08:57.312713568Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:08:57.273001160Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-14 11:51:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"import numpy as np\n",
|
|
|
|
|
|
"class Timer:\n",
|
|
|
|
|
|
" \"\"\"记录多次运行时间\"\"\"\n",
|
|
|
|
|
|
" def __init__(self):\n",
|
|
|
|
|
|
" self.times = []\n",
|
|
|
|
|
|
" self.start()\n",
|
|
|
|
|
|
" def start(self):\n",
|
|
|
|
|
|
" \"\"\"启动计时器\"\"\"\n",
|
|
|
|
|
|
" self.tik = time.time()\n",
|
|
|
|
|
|
" def stop(self):\n",
|
|
|
|
|
|
" \"\"\"停止计时器并将时间记录在列表中\"\"\"\n",
|
|
|
|
|
|
" self.times.append(time.time() - self.tik)\n",
|
|
|
|
|
|
" return self.times[-1]\n",
|
|
|
|
|
|
" def avg(self):\n",
|
|
|
|
|
|
" \"\"\"返回平均时间\"\"\"\n",
|
|
|
|
|
|
" return sum(self.times) / len(self.times)\n",
|
|
|
|
|
|
" def sum(self):\n",
|
|
|
|
|
|
" \"\"\"返回时间总和\"\"\"\n",
|
|
|
|
|
|
" return sum(self.times)\n",
|
|
|
|
|
|
" def cumsum(self):\n",
|
|
|
|
|
|
" \"\"\"返回累计时间\"\"\"\n",
|
|
|
|
|
|
" return np.array(self.times).cumsum().tolist()"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "4bdbb4999907154a",
|
|
|
|
|
|
"outputs": [],
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"execution_count": 33
|
2026-03-14 11:51:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:08:57.504357761Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:08:57.314864132Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-14 11:51:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"n = 10000\n",
|
|
|
|
|
|
"a = torch.ones([n])\n",
|
|
|
|
|
|
"b = torch.ones([n])\n",
|
|
|
|
|
|
"c=torch.zeros(n)\n",
|
|
|
|
|
|
"timer = Timer()\n",
|
|
|
|
|
|
"for i in range(n):\n",
|
|
|
|
|
|
" c[i]=a[i]+b[i]\n",
|
|
|
|
|
|
"f'{timer.stop():.5f} sec'"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "c6f71622e2cc578a",
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"outputs": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"data": {
|
|
|
|
|
|
"text/plain": [
|
|
|
|
|
|
"'0.11861 sec'"
|
|
|
|
|
|
]
|
|
|
|
|
|
},
|
|
|
|
|
|
"execution_count": 34,
|
|
|
|
|
|
"metadata": {},
|
|
|
|
|
|
"output_type": "execute_result"
|
|
|
|
|
|
}
|
|
|
|
|
|
],
|
|
|
|
|
|
"execution_count": 34
|
2026-03-14 11:51:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:08:57.608980105Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:08:57.525327Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-14 11:51:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"timer.start()\n",
|
|
|
|
|
|
"d=a+b\n",
|
|
|
|
|
|
"f'{timer.stop():.5f} sec'"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "2578c79b1214a79f",
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"outputs": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"data": {
|
|
|
|
|
|
"text/plain": [
|
|
|
|
|
|
"'0.00074 sec'"
|
|
|
|
|
|
]
|
|
|
|
|
|
},
|
|
|
|
|
|
"execution_count": 35,
|
|
|
|
|
|
"metadata": {},
|
|
|
|
|
|
"output_type": "execute_result"
|
|
|
|
|
|
}
|
|
|
|
|
|
],
|
|
|
|
|
|
"execution_count": 35
|
2026-03-14 11:51:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:08:57.665429699Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:08:57.610507099Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-14 11:51:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"import math\n",
|
|
|
|
|
|
"def normal(x, mu, sigma):\n",
|
|
|
|
|
|
" p = 1 / math.sqrt(2 * math.pi * sigma**2)\n",
|
|
|
|
|
|
" return p * np.exp(-0.5 / sigma**2 * (x - mu)**2)"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "fd17fdbe38a5f79",
|
|
|
|
|
|
"outputs": [],
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"execution_count": 36
|
2026-03-14 11:51:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:08:57.729971931Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:08:57.669321828Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-14 11:51:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"from matplotlib_inline import backend_inline\n",
|
|
|
|
|
|
"def use_svg_display(): #@save\n",
|
|
|
|
|
|
" \"\"\"使用svg格式在Jupyter中显示绘图\"\"\"\n",
|
|
|
|
|
|
" backend_inline.set_matplotlib_formats('svg')\n",
|
|
|
|
|
|
"def set_figsize(figsize=(3.5, 2.5)): #@save\n",
|
|
|
|
|
|
" \"\"\"设置matplotlib的图表大小\"\"\"\n",
|
|
|
|
|
|
" use_svg_display()\n",
|
|
|
|
|
|
" plt.rcParams['figure.figsize'] = figsize\n",
|
|
|
|
|
|
"def set_axes(axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend):\n",
|
|
|
|
|
|
" \"\"\"设置matplotlib的轴\"\"\"\n",
|
|
|
|
|
|
" axes.set_xlabel(xlabel)\n",
|
|
|
|
|
|
" axes.set_ylabel(ylabel)\n",
|
|
|
|
|
|
" axes.set_xscale(xscale)\n",
|
|
|
|
|
|
" axes.set_yscale(yscale)\n",
|
|
|
|
|
|
" axes.set_xlim(xlim)\n",
|
|
|
|
|
|
" axes.set_ylim(ylim)\n",
|
|
|
|
|
|
" if legend:\n",
|
|
|
|
|
|
" axes.legend(legend)\n",
|
|
|
|
|
|
" axes.grid()\n",
|
|
|
|
|
|
"def plot(X, Y=None, xlabel=None, ylabel=None, legend=None, xlim=None,\n",
|
|
|
|
|
|
"ylim=None, xscale='linear', yscale='linear',\n",
|
|
|
|
|
|
"fmts=('-', 'm--', 'g-.', 'r:'), figsize=(3.5, 2.5), axes=None):\n",
|
|
|
|
|
|
" \"\"\"绘制数据点\"\"\"\n",
|
|
|
|
|
|
" if legend is None:\n",
|
|
|
|
|
|
" legend = []\n",
|
|
|
|
|
|
" set_figsize(figsize)\n",
|
|
|
|
|
|
" axes = axes if axes else plt.gca()\n",
|
|
|
|
|
|
" # 如果X有一个轴,输出True\n",
|
|
|
|
|
|
" def has_one_axis(X):\n",
|
|
|
|
|
|
" return (hasattr(X, \"ndim\") and X.ndim == 1 or isinstance(X, list)\n",
|
|
|
|
|
|
"and not hasattr(X[0], \"__len__\"))\n",
|
|
|
|
|
|
" if has_one_axis(X):\n",
|
|
|
|
|
|
" X = [X]\n",
|
|
|
|
|
|
" if Y is None:\n",
|
|
|
|
|
|
" X, Y = [[]] * len(X), X\n",
|
|
|
|
|
|
" elif has_one_axis(Y):\n",
|
|
|
|
|
|
" Y = [Y]\n",
|
|
|
|
|
|
" if len(X) != len(Y):\n",
|
|
|
|
|
|
" X = X * len(Y)\n",
|
|
|
|
|
|
" axes.cla()\n",
|
|
|
|
|
|
" for x, y, fmt in zip(X, Y, fmts):\n",
|
|
|
|
|
|
" if len(x):\n",
|
|
|
|
|
|
" axes.plot(x, y, fmt)\n",
|
|
|
|
|
|
" else:\n",
|
|
|
|
|
|
" axes.plot(y, fmt)\n",
|
|
|
|
|
|
" set_axes(axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend)"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "82158a69cba14da0",
|
|
|
|
|
|
"outputs": [],
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"execution_count": 37
|
2026-03-14 11:51:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:08:58.000767154Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:08:57.732355844Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-14 11:51:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"# 再次使用numpy进行可视化\n",
|
|
|
|
|
|
"x = np.arange(-7, 7, 0.01)\n",
|
|
|
|
|
|
"# 均值和标准差对\n",
|
|
|
|
|
|
"params = [(0, 1), (0, 2), (3, 1)]\n",
|
|
|
|
|
|
"plot(x, [normal(x, mu, sigma) for mu, sigma in params], xlabel='x',\n",
|
|
|
|
|
|
"ylabel='p(x)', figsize=(4.5, 2.5),\n",
|
|
|
|
|
|
"legend=[f'mean {mu}, std {sigma}' for mu, sigma in params])"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "f69ac10ebc3d13d8",
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"outputs": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"data": {
|
|
|
|
|
|
"text/plain": [
|
|
|
|
|
|
"<Figure size 450x250 with 1 Axes>"
|
|
|
|
|
|
],
|
|
|
|
|
|
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"302.08125pt\" height=\"183.35625pt\" viewBox=\"0 0 302.08125 183.35625\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n <metadata>\n <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n <cc:Work>\n <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n <dc:date>2026-03-19T14:08:57.923771</dc:date>\n <dc:format>image/svg+xml</dc:format>\n <dc:creator>\n <cc:Agent>\n <dc:title>Matplotlib v3.7.2, https://matplotlib.org/</dc:title>\n </cc:Agent>\n </dc:creator>\n </cc:Work>\n </rdf:RDF>\n </metadata>\n <defs>\n <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n </defs>\n <g id=\"figure_1\">\n <g id=\"patch_1\">\n <path d=\"M 0 183.35625 \nL 302.08125 183.35625 \nL 302.08125 0 \nL 0 0 \nz\n\"/>\n </g>\n <g id=\"axes_1\">\n <g id=\"patch_2\">\n <path d=\"M 43.78125 145.8 \nL 294.88125 145.8 \nL 294.88125 7.2 \nL 43.78125 7.2 \nz\n\"/>\n </g>\n <g id=\"matplotlib.axis_1\">\n <g id=\"xtick_1\">\n <g id=\"line2d_1\">\n <path d=\"M 71.511736 145.8 \nL 71.511736 7.2 \n\" clip-path=\"url(#p82bac51bd8)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_2\">\n <defs>\n <path id=\"m55ef5c2f68\" d=\"M 0 0 \nL 0 3.5 \n\" style=\"stroke: #ffffff; stroke-width: 0.8\"/>\n </defs>\n <g>\n <use xlink:href=\"#m55ef5c2f68\" x=\"71.511736\" y=\"145.8\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_1\">\n <!-- −6 -->\n <g style=\"fill: #ffffff\" transform=\"translate(64.140642 160.398438) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-2212\" d=\"M 678 2272 \nL 4684 2272 \nL 4684 1741 
\nL 678 1741 \nL 678 2272 \nz\n\" transform=\"scale(0.015625)\"/>\n <path id=\"DejaVuSans-36\" d=\"M 2113 2584 \nQ 1688 2584 1439 2293 \nQ 1191 2003 1191 1497 \nQ 1191 994 1439 701 \nQ 1688 409 2113 409 \nQ 2538 409 2786 701 \nQ 3034 994 3034 1497 \nQ 3034 2003 2786 2293 \nQ 2538 2584 2113 2584 \nz\nM 3366 4563 \nL 3366 3988 \nQ 3128 4100 2886 4159 \nQ 2644 4219 2406 4219 \nQ 1781 4219 1451 3797 \nQ 1122 3375 1075 2522 \nQ 1259 2794 1537 2939 \nQ 1816 3084 2150 3084 \nQ 2853 3084 3261 2657 \nQ 3669 2231 3669 1497 \nQ 3669 778 3244 343 \nQ 2819 -91 2113 -91 \nQ 1303 -91 875 529 \nQ 447 1150 447 2328 \nQ 447 3434 972 4092 \nQ 1497 4750 2381 4750 \nQ 2619 4750 2861 4703 \nQ 3103 4656 3366 4563 \nz\n\" transform=\"scale(0.015625)\"/>\n </defs>\n <use xlink:href=\"#DejaVuSans-2212\"/>\n <use xlink:href=\"#DejaVuSans-36\" x=\"83.789062\"/>\n </g>\n </g>\n </g>\n <g id=\"xtick_2\">\n <g id=\"line2d_3\">\n <path d=\"M 104.145435 145.8 \nL 104.145435 7.2 \n\" clip-path=\"url(#p82bac51bd8)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_4\">\n <g>\n <use xlink:href=\"#m55ef5c2f68\" x=\"104.145435\" y=\"145.8\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_2\">\n <!-- −4 -->\n <g style=\"fill: #ffffff\" transform=\"translate(96.774342 160.398438) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-34\" d=\"M 2419 4116 \nL 825 1625 \nL 2419 1625 \nL 2419 4116 \nz\nM 2253 4666 \nL 3047 4666 \nL 3047 1625 \nL 3713 1625 \nL 3713 1100 \nL 3047 1100 \nL 3047 0 \nL 2419 0 \nL 2419 1100 \nL 313 1100 \nL 313 1709 \nL 2253 4666 \nz\n\" transform=\"scale(0.015625)\"/>\n </defs>\n <use xlink:href=\"#DejaVuSans-2212\"/>\n <use xlink:href=\"#DejaVuSans-34\" x=\"83.789062\"/>\n </g>\n </g>\n </g>\n <g id=\"xt
|
|
|
|
|
|
},
|
|
|
|
|
|
"metadata": {},
|
|
|
|
|
|
"output_type": "display_data",
|
|
|
|
|
|
"jetTransient": {
|
|
|
|
|
|
"display_id": null
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
],
|
|
|
|
|
|
"execution_count": 38
|
2026-03-14 11:51:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:08:58.045153274Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:08:58.022867228Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-14 11:51:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"#注意一下matmul做向量乘上矩阵的时候不用考虑转置的情况\n",
|
|
|
|
|
|
"def synthetic_data(w, b, num_examples): #@save\n",
|
|
|
|
|
|
" \"\"\"生成y=Xw+b+噪声\"\"\"\n",
|
|
|
|
|
|
" X = torch.normal(0, 1, (num_examples, len(w)))\n",
|
|
|
|
|
|
" y = torch.matmul(X, w) + b\n",
|
|
|
|
|
|
" y += torch.normal(0, 0.01, y.shape)\n",
|
|
|
|
|
|
" return X, y.reshape((-1, 1))\n"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "7ed837bdd2b3a26d",
|
|
|
|
|
|
"outputs": [],
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"execution_count": 39
|
2026-03-14 11:51:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:08:58.199486802Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:08:58.048763885Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-14 11:51:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"true_w = torch.tensor([2, -3.4])\n",
|
|
|
|
|
|
"true_b = 4.2\n",
|
|
|
|
|
|
"features, labels = synthetic_data(true_w, true_b, 1000)"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "5ec2e204a6fd5cb2",
|
|
|
|
|
|
"outputs": [],
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"execution_count": 40
|
2026-03-14 11:51:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:08:58.497612094Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:08:58.215279467Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-14 11:51:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"set_figsize()\n",
|
|
|
|
|
|
"plt.scatter(features[:, (1)].detach().numpy(), labels.detach().numpy(), 1)"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "38213d46b3d9900d",
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"outputs": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"data": {
|
|
|
|
|
|
"text/plain": [
|
|
|
|
|
|
"<matplotlib.collections.PathCollection at 0x7f645fd37d50>"
|
|
|
|
|
|
]
|
|
|
|
|
|
},
|
|
|
|
|
|
"execution_count": 41,
|
|
|
|
|
|
"metadata": {},
|
|
|
|
|
|
"output_type": "execute_result"
|
|
|
|
|
|
},
|
|
|
|
|
|
{
|
|
|
|
|
|
"data": {
|
|
|
|
|
|
"text/plain": [
|
|
|
|
|
|
"<Figure size 350x250 with 1 Axes>"
|
|
|
|
|
|
],
|
|
|
|
|
|
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"229.425pt\" height=\"169.678125pt\" viewBox=\"0 0 229.425 169.678125\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n <metadata>\n <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n <cc:Work>\n <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n <dc:date>2026-03-19T14:08:58.330374</dc:date>\n <dc:format>image/svg+xml</dc:format>\n <dc:creator>\n <cc:Agent>\n <dc:title>Matplotlib v3.7.2, https://matplotlib.org/</dc:title>\n </cc:Agent>\n </dc:creator>\n </cc:Work>\n </rdf:RDF>\n </metadata>\n <defs>\n <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n </defs>\n <g id=\"figure_1\">\n <g id=\"patch_1\">\n <path d=\"M 0 169.678125 \nL 229.425 169.678125 \nL 229.425 0 \nL 0 0 \nz\n\"/>\n </g>\n <g id=\"axes_1\">\n <g id=\"patch_2\">\n <path d=\"M 26.925 145.8 \nL 222.225 145.8 \nL 222.225 7.2 \nL 26.925 7.2 \nz\n\"/>\n </g>\n <g id=\"PathCollection_1\">\n <defs>\n <path id=\"m6048dcfac9\" d=\"M 0 0.5 \nC 0.132602 0.5 0.25979 0.447317 0.353553 0.353553 \nC 0.447317 0.25979 0.5 0.132602 0.5 0 \nC 0.5 -0.132602 0.447317 -0.25979 0.353553 -0.353553 \nC 0.25979 -0.447317 0.132602 -0.5 0 -0.5 \nC -0.132602 -0.5 -0.25979 -0.447317 -0.353553 -0.353553 \nC -0.447317 -0.25979 -0.5 -0.132602 -0.5 0 \nC -0.5 0.132602 -0.447317 0.25979 -0.353553 0.353553 \nC -0.25979 0.447317 -0.132602 0.5 0 0.5 \nz\n\" style=\"stroke: #8dd3c7\"/>\n </defs>\n <g clip-path=\"url(#p567f518a09)\">\n <use xlink:href=\"#m6048dcfac9\" x=\"148.601687\" y=\"80.409012\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#m6048dcfac9\" x=\"130.05288\" y=\"67.476586\" style=\"fill: #8dd3c7; 
stroke: #8dd3c7\"/>\n <use xlink:href=\"#m6048dcfac9\" x=\"127.409953\" y=\"74.452727\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#m6048dcfac9\" x=\"143.990382\" y=\"72.86308\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#m6048dcfac9\" x=\"130.051176\" y=\"78.998008\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#m6048dcfac9\" x=\"127.778786\" y=\"80.782732\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#m6048dcfac9\" x=\"150.806525\" y=\"101.991179\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#m6048dcfac9\" x=\"152.645613\" y=\"90.59075\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#m6048dcfac9\" x=\"171.89944\" y=\"91.904379\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#m6048dcfac9\" x=\"163.638463\" y=\"83.919719\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#m6048dcfac9\" x=\"138.020287\" y=\"93.843834\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#m6048dcfac9\" x=\"148.93979\" y=\"77.746831\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#m6048dcfac9\" x=\"178.192872\" y=\"115.514121\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#m6048dcfac9\" x=\"123.237041\" y=\"92.408873\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#m6048dcfac9\" x=\"118.890825\" y=\"64.442598\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#m6048dcfac9\" x=\"161.626012\" y=\"102.937154\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#m6048dcfac9\" x=\"128.424616\" y=\"78.019978\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#m6048dcfac9\" x=\"134.645774\" y=\"72.883474\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#m6048dcfac9\" x=\"132.180296\" y=\"89.076328\" style=\"fill: #8dd3c7; stroke: #8dd3c7\"/>\n <use xlink:href=\"#m6048dcfac9\" x=\"113.72
|
|
|
|
|
|
},
|
|
|
|
|
|
"metadata": {},
|
|
|
|
|
|
"output_type": "display_data",
|
|
|
|
|
|
"jetTransient": {
|
|
|
|
|
|
"display_id": null
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
],
|
|
|
|
|
|
"execution_count": 41
|
2026-03-14 11:51:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:08:58.613128054Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:08:58.566128920Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-14 11:51:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"w=torch.normal(0,0.01,size=(2,1),requires_grad=True)\n",
|
|
|
|
|
|
"b=torch.zeros(1,requires_grad=True)\n",
|
|
|
|
|
|
"def linreg(X, w, b):\n",
|
|
|
|
|
|
" return torch.matmul(X,w)+b\n",
|
|
|
|
|
|
"def squared_loss(y_hat,y):\n",
|
|
|
|
|
|
" return (y_hat-y.reshape(y_hat.shape))**2/2\n",
|
|
|
|
|
|
"def sgd(params,lr,batch_size):\n",
|
|
|
|
|
|
" with torch.no_grad():\n",
|
|
|
|
|
|
" for param in params:\n",
|
|
|
|
|
|
" param-=lr*param.grad/batch_size\n",
|
|
|
|
|
|
" param.grad.zero_()\n",
|
|
|
|
|
|
"lr = 0.03\n",
|
|
|
|
|
|
"num_epochs =20\n",
|
|
|
|
|
|
"net = linreg\n",
|
|
|
|
|
|
"loss = squared_loss"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "12166e1bc3ddd695",
|
|
|
|
|
|
"outputs": [],
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"execution_count": 42
|
2026-03-14 11:51:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:08:58.733166644Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:08:58.676875700Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-14 11:51:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"import random\n",
|
|
|
|
|
|
"def data_iter(batch_size, features, labels):\n",
|
|
|
|
|
|
" num_examples = len(features)\n",
|
|
|
|
|
|
" indices = list(range(num_examples))\n",
|
|
|
|
|
|
" # 这些样本是随机读取的,没有特定的顺序\n",
|
|
|
|
|
|
" random.shuffle(indices)\n",
|
|
|
|
|
|
" for i in range(0, num_examples, batch_size):\n",
|
|
|
|
|
|
" batch_indices = torch.tensor(\n",
|
|
|
|
|
|
" indices[i: min(i + batch_size, num_examples)])\n",
|
|
|
|
|
|
" yield features[batch_indices], labels[batch_indices]"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "f3b7ee9f326bc687",
|
|
|
|
|
|
"outputs": [],
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"execution_count": 43
|
2026-03-14 11:51:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:08:58.810488693Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:08:58.736278875Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-14 11:51:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"batch_size =10\n",
|
|
|
|
|
|
"for X,y in data_iter(batch_size, features, labels):\n",
|
|
|
|
|
|
" print(X,'\\n',y)\n",
|
|
|
|
|
|
" break"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "f386e12d65afff2e",
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"outputs": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"name": "stdout",
|
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
|
"text": [
|
|
|
|
|
|
"tensor([[ 0.0169, -0.3729],\n",
|
|
|
|
|
|
" [ 0.2105, -1.0088],\n",
|
|
|
|
|
|
" [-1.3548, 0.9556],\n",
|
|
|
|
|
|
" [-0.0298, 0.4827],\n",
|
|
|
|
|
|
" [ 1.5137, -2.4433],\n",
|
|
|
|
|
|
" [ 0.0029, 0.6444],\n",
|
|
|
|
|
|
" [ 0.5705, 1.1589],\n",
|
|
|
|
|
|
" [ 0.3421, -0.5686],\n",
|
|
|
|
|
|
" [-0.8094, -0.8650],\n",
|
|
|
|
|
|
" [ 0.3897, -1.3542]]) \n",
|
|
|
|
|
|
" tensor([[ 5.5058],\n",
|
|
|
|
|
|
" [ 8.0436],\n",
|
|
|
|
|
|
" [-1.7624],\n",
|
|
|
|
|
|
" [ 2.4937],\n",
|
|
|
|
|
|
" [15.5411],\n",
|
|
|
|
|
|
" [ 2.0171],\n",
|
|
|
|
|
|
" [ 1.3941],\n",
|
|
|
|
|
|
" [ 6.8340],\n",
|
|
|
|
|
|
" [ 5.5124],\n",
|
|
|
|
|
|
" [ 9.5840]])\n"
|
|
|
|
|
|
]
|
|
|
|
|
|
}
|
|
|
|
|
|
],
|
|
|
|
|
|
"execution_count": 44
|
2026-03-14 11:51:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:08:59.082220079Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:08:58.813592858Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-14 11:51:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"for epoch in range(num_epochs):\n",
|
|
|
|
|
|
" for X, y in data_iter(batch_size, features, labels):\n",
|
|
|
|
|
|
" l=loss(net(X, w, b), y)\n",
|
|
|
|
|
|
" l.sum().backward()\n",
|
|
|
|
|
|
" sgd([w,b],lr,batch_size)\n",
|
|
|
|
|
|
" with torch.no_grad():\n",
|
|
|
|
|
|
" train_l =loss(net(features, w, b), labels)\n",
|
|
|
|
|
|
" print(f'epoch {epoch+1}, train loss: {float(train_l.mean()):3f}')"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "8888ab6adcec36f1",
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"outputs": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"name": "stdout",
|
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
|
"text": [
|
|
|
|
|
|
"epoch 1, train loss: 0.040479\n",
|
|
|
|
|
|
"epoch 2, train loss: 0.000158\n",
|
|
|
|
|
|
"epoch 3, train loss: 0.000052\n",
|
|
|
|
|
|
"epoch 4, train loss: 0.000052\n",
|
|
|
|
|
|
"epoch 5, train loss: 0.000052\n",
|
|
|
|
|
|
"epoch 6, train loss: 0.000052\n",
|
|
|
|
|
|
"epoch 7, train loss: 0.000052\n",
|
|
|
|
|
|
"epoch 8, train loss: 0.000052\n",
|
|
|
|
|
|
"epoch 9, train loss: 0.000052\n",
|
|
|
|
|
|
"epoch 10, train loss: 0.000052\n",
|
|
|
|
|
|
"epoch 11, train loss: 0.000052\n",
|
|
|
|
|
|
"epoch 12, train loss: 0.000052\n",
|
|
|
|
|
|
"epoch 13, train loss: 0.000052\n",
|
|
|
|
|
|
"epoch 14, train loss: 0.000052\n",
|
|
|
|
|
|
"epoch 15, train loss: 0.000052\n",
|
|
|
|
|
|
"epoch 16, train loss: 0.000052\n",
|
|
|
|
|
|
"epoch 17, train loss: 0.000052\n",
|
|
|
|
|
|
"epoch 18, train loss: 0.000052\n",
|
|
|
|
|
|
"epoch 19, train loss: 0.000052\n",
|
|
|
|
|
|
"epoch 20, train loss: 0.000052\n"
|
|
|
|
|
|
]
|
|
|
|
|
|
}
|
|
|
|
|
|
],
|
|
|
|
|
|
"execution_count": 45
|
2026-03-14 11:51:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:08:59.178712680Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:08:59.105484540Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-14 11:51:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"print(f'w的估计误差: {true_w - w.reshape(true_w.shape)}')\n",
|
|
|
|
|
|
"print(f'b的估计误差: {true_b - b}')"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "8199439fa7f26309",
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"outputs": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"name": "stdout",
|
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
|
"text": [
|
|
|
|
|
|
"w的估计误差: tensor([ 0.0010, -0.0003], grad_fn=<SubBackward0>)\n",
|
|
|
|
|
|
"b的估计误差: tensor([0.0002], grad_fn=<RsubBackward1>)\n"
|
|
|
|
|
|
]
|
|
|
|
|
|
}
|
|
|
|
|
|
],
|
|
|
|
|
|
"execution_count": 46
|
2026-03-14 11:51:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:08:59.278946646Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:08:59.180914818Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-14 11:51:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"from torch.utils import data\n",
|
|
|
|
|
|
"true_w = torch.tensor([2,-3.4])\n",
|
|
|
|
|
|
"true_b = 4.2\n",
|
|
|
|
|
|
"features,labels=synthetic_data(true_w, true_b, 1000)\n",
|
|
|
|
|
|
"def load_array(data_arrays,batch_size,is_train=True):\n",
|
|
|
|
|
|
" dataset = data.TensorDataset(*data_arrays)\n",
|
|
|
|
|
|
" return data.DataLoader(dataset,batch_size,shuffle=is_train)\n",
|
|
|
|
|
|
"batch_size = 10\n",
|
|
|
|
|
|
"data_iter = load_array((features,labels),batch_size)"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "560d537dcbb5a335",
|
|
|
|
|
|
"outputs": [],
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"execution_count": 47
|
2026-03-14 11:51:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:08:59.328289201Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:08:59.284666591Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-14 11:51:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"from torch import nn\n",
|
|
|
|
|
|
"net = nn.Sequential(nn.Linear(2, 1))\n",
|
|
|
|
|
|
"net[0].weight.data.normal_(0,0.001)\n",
|
|
|
|
|
|
"net[0].bias.data.fill_(0)"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "c54fe059d6fd20de",
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"outputs": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"data": {
|
|
|
|
|
|
"text/plain": [
|
|
|
|
|
|
"tensor([0.])"
|
|
|
|
|
|
]
|
|
|
|
|
|
},
|
|
|
|
|
|
"execution_count": 48,
|
|
|
|
|
|
"metadata": {},
|
|
|
|
|
|
"output_type": "execute_result"
|
|
|
|
|
|
}
|
|
|
|
|
|
],
|
|
|
|
|
|
"execution_count": 48
|
2026-03-14 11:51:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:09:00.631251025Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:08:59.342163794Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-14 11:51:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"loss = nn.MSELoss()\n",
|
|
|
|
|
|
"trainer = torch.optim.SGD(net.parameters(), lr=0.01)\n",
|
|
|
|
|
|
"num_epochs = 3\n",
|
|
|
|
|
|
"for epoch in range(num_epochs):\n",
|
|
|
|
|
|
" for X, y in data_iter:\n",
|
|
|
|
|
|
" l = loss(net(X) ,y)\n",
|
|
|
|
|
|
" trainer.zero_grad()\n",
|
|
|
|
|
|
" l.backward()\n",
|
|
|
|
|
|
" trainer.step()\n",
|
|
|
|
|
|
" l = loss(net(features), labels)\n",
|
|
|
|
|
|
" print(f'epoch {epoch + 1}, loss {l:f}')\n"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "e8a44851125b7cc6",
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"outputs": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"name": "stdout",
|
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
|
"text": [
|
|
|
|
|
|
"epoch 1, loss 0.546318\n",
|
|
|
|
|
|
"epoch 2, loss 0.009022\n",
|
|
|
|
|
|
"epoch 3, loss 0.000254\n"
|
|
|
|
|
|
]
|
|
|
|
|
|
}
|
|
|
|
|
|
],
|
|
|
|
|
|
"execution_count": 49
|
2026-03-15 06:42:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:09:00.829165029Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:09:00.673423450Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-15 06:42:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"import torchvision\n",
|
|
|
|
|
|
"from torchvision import transforms\n",
|
|
|
|
|
|
"trans =transforms.ToTensor()\n",
|
|
|
|
|
|
"mnist_train = torchvision.datasets.FashionMNIST(root=\"./data\",train=True,transform=trans,download=False)\n",
|
|
|
|
|
|
"mnist_test = torchvision.datasets.FashionMNIST(root=\"./data\",train=False,transform=trans,download=False)\n"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "bd4e8a65ccd03177",
|
|
|
|
|
|
"outputs": [],
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"execution_count": 50
|
2026-03-15 06:42:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:09:00.996874134Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:09:00.830172776Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-15 06:42:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"use_svg_display()\n",
|
|
|
|
|
|
"len(mnist_train),len(mnist_test)"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "ed2c915af7f6a76a",
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"outputs": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"data": {
|
|
|
|
|
|
"text/plain": [
|
|
|
|
|
|
"(60000, 10000)"
|
|
|
|
|
|
]
|
|
|
|
|
|
},
|
|
|
|
|
|
"execution_count": 51,
|
|
|
|
|
|
"metadata": {},
|
|
|
|
|
|
"output_type": "execute_result"
|
|
|
|
|
|
}
|
|
|
|
|
|
],
|
|
|
|
|
|
"execution_count": 51
|
2026-03-15 06:42:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:09:01.048577861Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:09:01.002950237Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-15 06:42:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": "mnist_train[0][0].shape",
|
|
|
|
|
|
"id": "4df1cbc292aa5981",
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"outputs": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"data": {
|
|
|
|
|
|
"text/plain": [
|
|
|
|
|
|
"torch.Size([1, 28, 28])"
|
|
|
|
|
|
]
|
|
|
|
|
|
},
|
|
|
|
|
|
"execution_count": 52,
|
|
|
|
|
|
"metadata": {},
|
|
|
|
|
|
"output_type": "execute_result"
|
|
|
|
|
|
}
|
|
|
|
|
|
],
|
|
|
|
|
|
"execution_count": 52
|
2026-03-15 06:42:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:09:01.146967321Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:09:01.075154216Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-15 06:42:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"def get_fashion_mnist_labels(labels):\n",
|
|
|
|
|
|
" text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',\n",
|
|
|
|
|
|
"'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']\n",
|
|
|
|
|
|
" return [text_labels[int(i)] for i in labels]\n"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "332f3d6da0bafbe8",
|
|
|
|
|
|
"outputs": [],
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"execution_count": 53
|
2026-03-15 06:42:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:09:01.203692216Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:09:01.148474439Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-15 06:42:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"def show_images(imgs, num_rows, num_cols, titles=None, scale=1): #@save\n",
|
|
|
|
|
|
" \"\"\"绘制图像列表\"\"\"\n",
|
|
|
|
|
|
" figsize = (num_cols * scale, num_rows * scale)\n",
|
|
|
|
|
|
" _, axes = plt.subplots(num_rows, num_cols, figsize=figsize)\n",
|
|
|
|
|
|
" axes = axes.flatten()\n",
|
|
|
|
|
|
" for i, (ax, img) in enumerate(zip(axes, imgs)):\n",
|
|
|
|
|
|
" if torch.is_tensor(img):\n",
|
|
|
|
|
|
" # 图片张量\n",
|
|
|
|
|
|
" ax.imshow(img.numpy())\n",
|
|
|
|
|
|
" else:\n",
|
|
|
|
|
|
" # PIL图片\n",
|
|
|
|
|
|
" ax.imshow(img)\n",
|
|
|
|
|
|
" ax.axes.get_xaxis().set_visible(False)\n",
|
|
|
|
|
|
" ax.axes.get_yaxis().set_visible(False)\n",
|
|
|
|
|
|
" if titles:\n",
|
|
|
|
|
|
" ax.set_title(titles[i])\n",
|
|
|
|
|
|
" return axes"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "c83202fe9b0ab487",
|
|
|
|
|
|
"outputs": [],
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"execution_count": 54
|
2026-03-15 06:42:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:09:01.696988026Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:09:01.208873844Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-15 06:42:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))\n",
|
|
|
|
|
|
"print(X.shape)\n",
|
|
|
|
|
|
"show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y));"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "cf4abd8370d55416",
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"outputs": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"name": "stdout",
|
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
|
"text": [
|
|
|
|
|
|
"torch.Size([18, 1, 28, 28])\n"
|
|
|
|
|
|
]
|
|
|
|
|
|
},
|
|
|
|
|
|
{
|
|
|
|
|
|
"data": {
|
|
|
|
|
|
"text/plain": [
|
|
|
|
|
|
"<Figure size 900x200 with 18 Axes>"
|
|
|
|
|
|
],
|
|
|
|
|
|
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"524.634446pt\" height=\"137.375483pt\" viewBox=\"0 0 524.634446 137.375483\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n <metadata>\n <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n <cc:Work>\n <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n <dc:date>2026-03-19T14:09:01.564207</dc:date>\n <dc:format>image/svg+xml</dc:format>\n <dc:creator>\n <cc:Agent>\n <dc:title>Matplotlib v3.7.2, https://matplotlib.org/</dc:title>\n </cc:Agent>\n </dc:creator>\n </cc:Work>\n </rdf:RDF>\n </metadata>\n <defs>\n <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n </defs>\n <g id=\"figure_1\">\n <g id=\"patch_1\">\n <path d=\"M 0 137.375483 \nL 524.634446 137.375483 \nL 524.634446 0 \nL 0 0 \nz\n\"/>\n </g>\n <g id=\"axes_1\">\n <g id=\"patch_2\">\n <path d=\"M 15.234446 69.695483 \nL 62.611804 69.695483 \nL 62.611804 22.318125 \nL 15.234446 22.318125 \nz\n\"/>\n </g>\n <g clip-path=\"url(#pc37636f183)\">\n <image 
xlink:href=\"data:image/png;base64,\niVBORw0KGgoAAAANSUhEUgAAAEIAAABCCAYAAADjVADoAAAWB0lEQVR4nO2bW48cyZmeny8iIzMr69DVJ5IzzZnhzIhzsHYl2StrsVrBWBu+tw34Bxi+N3zhH+GfYPjavjAM+8I3K3vXWNsLwYBgaXd2vSOMOGdyyD6yWV2nPERGhC8iK7ub5MgCpBF1MR9AFJqVGZHxxhvf4f2y5B/KPw18bagX/QC/LfY1EJ19DURnya86gCoKpBiA1ogxIALWEpxHshRSA94TygqcgxDAB0Jd49frX8cafi32qwOxNcHf3MHnCXaSEgTMskWsox2n1NMEXQfyswqpHeIc4gJqtsCXZQTmt8B+KSAky5A0jTu/NQatQCmCFqr9IdWuwRmwIwFAVwbVBuxQsGNBNVBPC1QL4gLiIb0Yke1OwF+ZxznwHTDhqc//n3XXiX/q+sYS6ga8IzQ2srOuCW37FBAiv3hCEdRrt2kOtpi9mfH4+5Z02DAc1GTGsp0fcpAvMeIZaIsSjw8KjzDUNSNdY4Nm3ua0QffDHpVjDlcTQhB8EEIQFuucpjLd80DwAlaBF5AA8vTiN58CHsQLYq9cJJAfK0YPA0npGZxa9NqiPz3CnZ4+BcSXLB5RiNaIVvitgnI/ZXUg/P7bn/BGccad/IyJKtlP5uyqeNaN+P5TEzACBnCA7R5ad8955DI+aW7gOn9tg+aj6iaH1RYAHsF6zdzmNE6TKI+S65vlQxzMOk0bFNZpSmvwAVQ3z1k+RUJCsoqbYFaa4WmBzFKCc+BdB8RTTNB7u/hXbtFOM87fyainQr3ncTuW4VZFqlpmbcF7y1dRBBLlyFSLJj6oFk8mLVo8RhxGXD+2kstzsPYpF23RgRP/v3SGTLcoQn/txFTxXp5l7IZ9MztgbnNarRgkNo7ZAecPYDYaErxw4SLD8u8ekD05YOfnDemfvUdo2+cwYnuLxd0xq5sK9w9mfPvGIQeDGS+lF1y4AV+U28yaAed1QdUafEdtgER5RAKZbvsdTFXbf7dZjJIQ70NQdNeLxyhH0oE30PYaqFdB7DetG++hmtKGy0zg6pg38wV+Xxjqhm8UxwD8+eO7fLGYcprvc/tH6RUgREhevY3bmTB/c8Tjbyqaqed2UZIoR+lSju2EVZuxaDMaF28z2qEIiIRu0Q4lnlS7Kwu7vgAlHt1R3IXL86wl9EA5FLVPUBKw6J5pioC/4ig2bLNBk4inDYrGaXxQLNsMHwTVjVsnCQ/1NkYcRWK5MVxyb2cP7r6GLpsIhKQpF999mcff1FSv1/yjb/2Uka45t0Nql3BSj1ivduIZbKMzK0zDILEUScM4qTHKMdQ1SgKFajDi+uNhg6bypn/4zVFwnVO9aAfYoPFBokMNUIZ4fet1xyiHEY9HqF0EaaAblATaoDHK0TrFymY0XjNbD2haTaI9mWkx2vG4GpIlLa8WT7g7POHDN/Y5/v4UZSGRLEMNcuotRb3nGE1LDrInGHHM25yaOKmSy53f0FsRSDraJuJQEq4tVHXfmchXXFDkyl6juQ2aTBlMcDgUPrT4IB0wChTxk+hAr7KoDRoVAtZrXHdEPdJHIu8VQV9npA9CohyFrskyix1JB8Tbr+Nyw+xteOd3H5Bry8+WL3cOK/47yGcMtKV0htNmROs1jd/soKLxCa3X8cGuOLWBbih0ZEehmo4p7TWwAMYqOsSpXlOomioYFm6AQ7F0OTZoLtoBS5eRqkBmSnwQzpshpTPdkUh6PyESKLKGNHGkSUuetAwSy0ExI1MtI11Te4PRjtUIVAtJuz2gzTXtTst3pl8wawuOyjEAW2mMEqOkZi9ZslA5jrjwWTOg6WjrgoCA7f7e7MyGOajIjg0IRlq0BHSXTeU6evpbyQVTVbIKhpmqsSEhl4YqpNigqX1CojyFaiJjkGsgtF71cxoVfVGWtGS6pUgapqakUM2lX1IBb2J+kjQ
TQ5sLYlqUBMZJRVK4y0UAlTccNlv9GTfakeaRwiNdU+imp7kLiuN6QtkaEvFk3lJhqL3Biu5pninLWFXkyrKv56TiOHFj7jW3GKqaqV6Ri2VqYo5yI5kzT3M8isobbEjQ4qm8oe4YufEfQB+GE+XQeDLVsmeWZMpixJFKS+sU+ZmgG0jqicJlgkrj4se64qaZA7DoaDlvc0qXkqmWqVljxJF3A+4kS3b0sqfz2mecNiMql5AoR+0NDsfapygCS5dhg2YrKdFJwEjLrl6Ri+P9+oD3Vwe8nM341qBkqCpeSeYUAlWYUwXFIhiO2i2qYNjSqx6YKhh8UNigUeLZTxYUqsYH1SdtMdfxTFRFLhbnFcWxJ6kDicsEl0KatuyZRT+YC6o/xyNdk6mWTLXsJCuUeHKxXYSoycWSiiMXSxUMx+nk0jl1R8RIDLWFbrBeM9IVO8mSlc/4d+ffZ2FzPp7vcb4esFOUfDrZY5jU3ErnFLpmrCqGqu6dMMSQS+gSuRDQEsO3JqCIn4jvU/HNdxuzrWayDujKkzQTwadwc2vBd/L73Lc7vF/e7tlR6Jqb5oJC1aTiGOuyQzcOmIslF0uhLDuqxXf/d5RNOWymPKrjkSpUQ6Zs7xdumQveMGf8cPG7/PCP/y7Dh8SzCjwebfNo5yWCBpcFSALZjTUvb19wMJzxg+lHDFWNxqMFPC2+q5my
|
|
|
|
|
|
},
|
|
|
|
|
|
"metadata": {},
|
|
|
|
|
|
"output_type": "display_data",
|
|
|
|
|
|
"jetTransient": {
|
|
|
|
|
|
"display_id": null
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
],
|
|
|
|
|
|
"execution_count": 55
|
2026-03-19 05:47:14 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:09:03.457306914Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:09:01.729639427Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-15 06:42:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"batch_size = 256\n",
|
|
|
|
|
|
"def get_dataloader_workers():\n",
|
|
|
|
|
|
" \"\"\"使用4个进程来读取数据\"\"\"\n",
|
|
|
|
|
|
" return 4\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
"train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,\n",
|
|
|
|
|
|
"num_workers=get_dataloader_workers())\n",
|
|
|
|
|
|
"timer = Timer()\n",
|
|
|
|
|
|
"for X, y in train_iter:\n",
|
|
|
|
|
|
" continue\n",
|
|
|
|
|
|
"f'{timer.stop():.2f} sec'"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "552769fbffc16142",
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"outputs": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"data": {
|
|
|
|
|
|
"text/plain": [
|
|
|
|
|
|
"'1.59 sec'"
|
|
|
|
|
|
]
|
|
|
|
|
|
},
|
|
|
|
|
|
"execution_count": 56,
|
|
|
|
|
|
"metadata": {},
|
|
|
|
|
|
"output_type": "execute_result"
|
|
|
|
|
|
}
|
|
|
|
|
|
],
|
|
|
|
|
|
"execution_count": 56
|
2026-03-15 06:42:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:09:03.736459982Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:09:03.597245687Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-15 06:42:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"def load_data_fashion_mnist(batch_size, resize=None):\n",
|
|
|
|
|
|
" \"\"\"下载Fashion-MNIST数据集,然后将其加载到内存中\"\"\"\n",
|
|
|
|
|
|
" trans = [transforms.ToTensor()]\n",
|
|
|
|
|
|
" if resize:\n",
|
|
|
|
|
|
" trans.insert(0, transforms.Resize(resize))\n",
|
|
|
|
|
|
" trans = transforms.Compose(trans)\n",
|
|
|
|
|
|
" mnist_train = torchvision.datasets.FashionMNIST(\n",
|
|
|
|
|
|
" root=\"./data\", train=True, transform=trans, download=False)\n",
|
|
|
|
|
|
" mnist_test = torchvision.datasets.FashionMNIST(\n",
|
|
|
|
|
|
" root=\"./data\", train=False, transform=trans, download=False)\n",
|
|
|
|
|
|
" return (data.DataLoader(mnist_train, batch_size, shuffle=True,\n",
|
|
|
|
|
|
" num_workers=get_dataloader_workers()),\n",
|
|
|
|
|
|
" data.DataLoader(mnist_test, batch_size, shuffle=False,\n",
|
|
|
|
|
|
" num_workers=get_dataloader_workers()))"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "aa81880abd86cae6",
|
|
|
|
|
|
"outputs": [],
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"execution_count": 57
|
2026-03-15 06:42:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:09:04.172516352Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:09:03.782776079Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-15 06:42:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"train_iter, test_iter = load_data_fashion_mnist(32, resize=64)\n",
|
|
|
|
|
|
"for X, y in train_iter:\n",
|
|
|
|
|
|
" print(X.shape, X.dtype, y.shape, y.dtype)\n",
|
|
|
|
|
|
" break"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "4248a103f745154",
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"outputs": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"name": "stdout",
|
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
|
"text": [
|
|
|
|
|
|
"torch.Size([32, 1, 64, 64]) torch.float32 torch.Size([32]) torch.int64\n"
|
|
|
|
|
|
]
|
|
|
|
|
|
}
|
|
|
|
|
|
],
|
|
|
|
|
|
"execution_count": 58
|
2026-03-15 06:42:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:09:04.306546718Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:09:04.202047232Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-15 06:42:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"from IPython import display\n",
|
|
|
|
|
|
"batch_size = 256\n",
|
|
|
|
|
|
"train_iter, test_iter = load_data_fashion_mnist(32)\n",
|
|
|
|
|
|
"num_inputs = 784\n",
|
|
|
|
|
|
"num_outputs = 10\n",
|
|
|
|
|
|
"W = torch.normal(0, 0.01, size=(num_inputs, num_outputs), requires_grad=True)\n",
|
|
|
|
|
|
"b = torch.zeros(num_outputs, requires_grad=True)\n",
|
|
|
|
|
|
"X = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])\n",
|
|
|
|
|
|
"X.sum(0, keepdim=True), X.sum(1, keepdim=True)\n"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "94c52f0cca88ef48",
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"outputs": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"data": {
|
|
|
|
|
|
"text/plain": [
|
|
|
|
|
|
"(tensor([[5., 7., 9.]]),\n",
|
|
|
|
|
|
" tensor([[ 6.],\n",
|
|
|
|
|
|
" [15.]]))"
|
|
|
|
|
|
]
|
|
|
|
|
|
},
|
|
|
|
|
|
"execution_count": 59,
|
|
|
|
|
|
"metadata": {},
|
|
|
|
|
|
"output_type": "execute_result"
|
|
|
|
|
|
}
|
|
|
|
|
|
],
|
|
|
|
|
|
"execution_count": 59
|
2026-03-15 06:42:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:09:04.394766349Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:09:04.309113610Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-15 06:42:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"def softmax(X):\n",
|
|
|
|
|
|
" X_exp = torch.exp(X)\n",
|
|
|
|
|
|
" partition = X_exp.sum(1, keepdim=True)\n",
|
|
|
|
|
|
" return X_exp / partition # 这里应用了广播机制\n",
|
|
|
|
|
|
"X = torch.normal(0, 1, (2, 5))\n",
|
|
|
|
|
|
"X_prob = softmax(X)\n",
|
|
|
|
|
|
"X_prob, X_prob.sum(1)"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "c4ab34373c5a664e",
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"outputs": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"data": {
|
|
|
|
|
|
"text/plain": [
|
|
|
|
|
|
"(tensor([[0.1969, 0.2421, 0.2949, 0.1017, 0.1645],\n",
|
|
|
|
|
|
" [0.2729, 0.1267, 0.0365, 0.2715, 0.2924]]),\n",
|
|
|
|
|
|
" tensor([1., 1.]))"
|
|
|
|
|
|
]
|
|
|
|
|
|
},
|
|
|
|
|
|
"execution_count": 60,
|
|
|
|
|
|
"metadata": {},
|
|
|
|
|
|
"output_type": "execute_result"
|
|
|
|
|
|
}
|
|
|
|
|
|
],
|
|
|
|
|
|
"execution_count": 60
|
2026-03-15 06:42:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:09:04.445839278Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:09:04.430535547Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-15 06:42:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"def net(X):\n",
|
|
|
|
|
|
" return softmax(torch.matmul(X.reshape((-1, W.shape[0])), W) + b)"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "6eacc53b2b9738af",
|
|
|
|
|
|
"outputs": [],
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"execution_count": 61
|
2026-03-15 06:42:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:09:04.567489691Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:09:04.449587260Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-15 06:42:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"y = torch.tensor([0, 2])\n",
|
|
|
|
|
|
"y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])\n",
|
|
|
|
|
|
"y_hat[[0, 1], y]"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "698449b4dafb545c",
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"outputs": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"data": {
|
|
|
|
|
|
"text/plain": [
|
|
|
|
|
|
"tensor([0.1000, 0.5000])"
|
|
|
|
|
|
]
|
|
|
|
|
|
},
|
|
|
|
|
|
"execution_count": 62,
|
|
|
|
|
|
"metadata": {},
|
|
|
|
|
|
"output_type": "execute_result"
|
|
|
|
|
|
}
|
|
|
|
|
|
],
|
|
|
|
|
|
"execution_count": 62
|
2026-03-15 06:42:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:09:04.745616835Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:09:04.613175566Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-15 06:42:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"def cross_entropy(y_hat, y):\n",
|
|
|
|
|
|
" return - torch.log(y_hat[range(len(y_hat)), y])\n",
|
|
|
|
|
|
"cross_entropy(y_hat, y)"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "1720369fc8568c8c",
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"outputs": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"data": {
|
|
|
|
|
|
"text/plain": [
|
|
|
|
|
|
"tensor([2.3026, 0.6931])"
|
|
|
|
|
|
]
|
|
|
|
|
|
},
|
|
|
|
|
|
"execution_count": 63,
|
|
|
|
|
|
"metadata": {},
|
|
|
|
|
|
"output_type": "execute_result"
|
|
|
|
|
|
}
|
|
|
|
|
|
],
|
|
|
|
|
|
"execution_count": 63
|
2026-03-14 11:51:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:09:04.855442356Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:09:04.774151992Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-14 11:51:56 +00:00
|
|
|
|
"cell_type": "code",
|
2026-03-15 06:42:56 +00:00
|
|
|
|
"source": [
|
|
|
|
|
|
"def accuracy(y_hat, y): #@save\n",
|
|
|
|
|
|
" \"\"\"计算预测正确的数量\"\"\"\n",
|
|
|
|
|
|
" if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:\n",
|
|
|
|
|
|
" y_hat = y_hat.argmax(axis=1)\n",
|
|
|
|
|
|
" cmp = y_hat.type(y.dtype) == y\n",
|
|
|
|
|
|
" return float(cmp.type(y.dtype).sum())\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
"accuracy(y_hat, y)/len(y)"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "e65719500a64ed87",
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"outputs": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"data": {
|
|
|
|
|
|
"text/plain": [
|
|
|
|
|
|
"0.5"
|
|
|
|
|
|
]
|
|
|
|
|
|
},
|
|
|
|
|
|
"execution_count": 64,
|
|
|
|
|
|
"metadata": {},
|
|
|
|
|
|
"output_type": "execute_result"
|
|
|
|
|
|
}
|
|
|
|
|
|
],
|
|
|
|
|
|
"execution_count": 64
|
2026-03-15 06:42:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:09:04.919559483Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:09:04.858709428Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-15 06:42:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"class Accumulator: #@save\n",
|
|
|
|
|
|
" \"\"\"在n个变量上累加\"\"\"\n",
|
|
|
|
|
|
" def __init__(self, n):\n",
|
|
|
|
|
|
" self.data = [0.0] * n\n",
|
|
|
|
|
|
" def add(self, *args):\n",
|
|
|
|
|
|
" self.data = [a + float(b) for a, b in zip(self.data, args)]\n",
|
|
|
|
|
|
" def reset(self):\n",
|
|
|
|
|
|
" self.data = [0.0] * len(self.data)\n",
|
|
|
|
|
|
" def __getitem__(self, idx):\n",
|
|
|
|
|
|
" return self.data[idx]\n"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "f1eebb35ff2e9fea",
|
|
|
|
|
|
"outputs": [],
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"execution_count": 65
|
2026-03-15 06:42:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:09:04.987417283Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:09:04.924473660Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-15 06:42:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"def evaluate_accuracy(net, data_iter): #@save\n",
|
|
|
|
|
|
" \"\"\"计算在指定数据集上模型的精度\"\"\"\n",
|
|
|
|
|
|
" if isinstance(net, torch.nn.Module):\n",
|
|
|
|
|
|
" net.eval() # 将模型设置为评估模式\n",
|
|
|
|
|
|
" metric = Accumulator(2) # 正确预测数、预测总数\n",
|
|
|
|
|
|
" with torch.no_grad():\n",
|
|
|
|
|
|
" for X, y in data_iter:\n",
|
|
|
|
|
|
" metric.add(accuracy(net(X), y), y.numel())\n",
|
|
|
|
|
|
" return metric[0] / metric[1]\n"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "bc2beb5f2d6afe7e",
|
|
|
|
|
|
"outputs": [],
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"execution_count": 66
|
2026-03-15 06:42:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:09:09.640811163Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:09:04.992201693Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-15 06:42:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": "evaluate_accuracy(net, test_iter)",
|
|
|
|
|
|
"id": "65bcfb7e40c1a98b",
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"outputs": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"data": {
|
|
|
|
|
|
"text/plain": [
|
|
|
|
|
|
"0.1045"
|
|
|
|
|
|
]
|
|
|
|
|
|
},
|
|
|
|
|
|
"execution_count": 67,
|
|
|
|
|
|
"metadata": {},
|
|
|
|
|
|
"output_type": "execute_result"
|
|
|
|
|
|
}
|
|
|
|
|
|
],
|
|
|
|
|
|
"execution_count": 67
|
2026-03-15 06:42:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:09:10.226880910Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:09:09.774018681Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-15 06:42:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"def train_epoch_ch3(net, train_iter, loss, updater): #@save\n",
|
|
|
|
|
|
" \"\"\"训练模型一个迭代周期(定义见第3章)\"\"\"\n",
|
|
|
|
|
|
" # 将模型设置为训练模式\n",
|
|
|
|
|
|
" if isinstance(net, torch.nn.Module):\n",
|
|
|
|
|
|
" net.train()\n",
|
|
|
|
|
|
" # 训练损失总和、训练准确度总和、样本数\n",
|
|
|
|
|
|
" metric = Accumulator(3)\n",
|
|
|
|
|
|
" for X, y in train_iter:\n",
|
|
|
|
|
|
" # 计算梯度并更新参数\n",
|
|
|
|
|
|
" y_hat = net(X)\n",
|
|
|
|
|
|
" l = loss(y_hat, y)\n",
|
|
|
|
|
|
" if isinstance(updater, torch.optim.Optimizer):\n",
|
|
|
|
|
|
" # 使用PyTorch内置的优化器和损失函数\n",
|
|
|
|
|
|
" updater.zero_grad()\n",
|
|
|
|
|
|
" l.mean().backward()\n",
|
|
|
|
|
|
" updater.step()\n",
|
|
|
|
|
|
" else:\n",
|
|
|
|
|
|
" # 使用定制的优化器和损失函数\n",
|
|
|
|
|
|
" l.sum().backward()\n",
|
|
|
|
|
|
" updater(X.shape[0])\n",
|
|
|
|
|
|
" metric.add(float(l.sum()), accuracy(y_hat, y), y.numel())\n",
|
|
|
|
|
|
"# 返回训练损失和训练精度\n",
|
|
|
|
|
|
" return metric[0] / metric[2], metric[1] / metric[2]"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "2faf1dcc6c023a53",
|
|
|
|
|
|
"outputs": [],
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"execution_count": 68
|
2026-03-15 06:42:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:09:10.673837338Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:09:10.455642059Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-15 06:42:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"class Animator: #@save\n",
|
|
|
|
|
|
" \"\"\"在动画中绘制数据\"\"\"\n",
|
|
|
|
|
|
" def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,\n",
|
|
|
|
|
|
" ylim=None, xscale='linear', yscale='linear',\n",
|
|
|
|
|
|
" fmts=('-', 'm--', 'g-.', 'r:'), nrows=1, ncols=1,\n",
|
|
|
|
|
|
" figsize=(3.5, 2.5)):\n",
|
|
|
|
|
|
" # 增量地绘制多条线\n",
|
|
|
|
|
|
" if legend is None:\n",
|
|
|
|
|
|
" legend = []\n",
|
|
|
|
|
|
" use_svg_display()\n",
|
|
|
|
|
|
" self.fig, self.axes = plt.subplots(nrows, ncols, figsize=figsize)\n",
|
|
|
|
|
|
" if nrows * ncols == 1:\n",
|
|
|
|
|
|
" self.axes = [self.axes, ]\n",
|
|
|
|
|
|
" # 使用lambda函数捕获参数\n",
|
|
|
|
|
|
" self.config_axes = lambda: set_axes(\n",
|
|
|
|
|
|
" self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend)\n",
|
|
|
|
|
|
" self.X, self.Y, self.fmts = None, None, fmts\n",
|
|
|
|
|
|
" def add(self, x, y):\n",
|
|
|
|
|
|
" # 向图表中添加多个数据点\n",
|
|
|
|
|
|
" if not hasattr(y, \"__len__\"):\n",
|
|
|
|
|
|
" y = [y]\n",
|
|
|
|
|
|
" n = len(y)\n",
|
|
|
|
|
|
" if not hasattr(x, \"__len__\"):\n",
|
|
|
|
|
|
" x = [x] * n\n",
|
|
|
|
|
|
" if not self.X:\n",
|
|
|
|
|
|
" self.X = [[] for _ in range(n)]\n",
|
|
|
|
|
|
" if not self.Y:\n",
|
|
|
|
|
|
" self.Y = [[] for _ in range(n)]\n",
|
|
|
|
|
|
" for i, (a, b) in enumerate(zip(x, y)):\n",
|
|
|
|
|
|
" if a is not None and b is not None:\n",
|
|
|
|
|
|
" self.X[i].append(a)\n",
|
|
|
|
|
|
" self.Y[i].append(b)\n",
|
|
|
|
|
|
" self.axes[0].cla()\n",
|
|
|
|
|
|
" for x, y, fmt in zip(self.X, self.Y, self.fmts):\n",
|
|
|
|
|
|
" self.axes[0].plot(x, y, fmt)\n",
|
|
|
|
|
|
" self.config_axes()\n",
|
|
|
|
|
|
" display.display(self.fig)\n",
|
|
|
|
|
|
" display.clear_output(wait=True)\n"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "7cd5367ab43c5e5f",
|
|
|
|
|
|
"outputs": [],
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"execution_count": 69
|
2026-03-15 06:42:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:09:10.768738929Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:09:10.687413593Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-15 06:42:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"def train_ch3(net, train_iter, test_iter, loss, num_epochs, updater): #@save\n",
|
|
|
|
|
|
" \"\"\"训练模型(定义见第3章)\"\"\"\n",
|
|
|
|
|
|
" animator = Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0.3, 0.9],\n",
|
|
|
|
|
|
" legend=['train loss', 'train acc', 'test acc'])\n",
|
|
|
|
|
|
" for epoch in range(num_epochs):\n",
|
|
|
|
|
|
" train_metrics = train_epoch_ch3(net, train_iter, loss, updater)\n",
|
|
|
|
|
|
" test_acc = evaluate_accuracy(net, test_iter)\n",
|
|
|
|
|
|
" animator.add(epoch + 1, train_metrics + (test_acc,))\n",
|
|
|
|
|
|
" train_loss, train_acc = train_metrics\n",
|
|
|
|
|
|
" assert train_loss < 0.5, train_loss\n",
|
|
|
|
|
|
" assert train_acc <= 1 and train_acc > 0.7, train_acc\n",
|
|
|
|
|
|
" assert test_acc <= 1 and test_acc > 0.7, test_acc"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "b02a143c75fad40",
|
|
|
|
|
|
"outputs": [],
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"execution_count": 70
|
2026-03-15 06:42:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:09:10.914319299Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:09:10.840221929Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-15 06:42:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"lr = 0.1\n",
|
|
|
|
|
|
"def updater(batch_size):\n",
|
|
|
|
|
|
" return sgd([W, b], lr, batch_size)"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "6a97b70779276b61",
|
|
|
|
|
|
"outputs": [],
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"execution_count": 71
|
2026-03-15 06:42:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:09:11.044553254Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:09:10.940821708Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-15 06:42:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"num_epochs = 10\n",
|
2026-03-19 05:47:14 +00:00
|
|
|
|
"#train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, updater)"
|
2026-03-15 06:42:56 +00:00
|
|
|
|
],
|
|
|
|
|
|
"id": "df3cceb72faee402",
|
2026-03-19 05:47:14 +00:00
|
|
|
|
"outputs": [],
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"execution_count": 72
|
2026-03-15 06:42:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:09:11.628097076Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:09:11.067755016Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-15 06:42:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"def predict_ch3(net, test_iter, n=6): #@save\n",
|
|
|
|
|
|
" \"\"\"预测标签(定义见第3章)\"\"\"\n",
|
|
|
|
|
|
" for (X, y),i in zip(test_iter,range(1)):\n",
|
|
|
|
|
|
" trues = get_fashion_mnist_labels(y)\n",
|
|
|
|
|
|
" preds = get_fashion_mnist_labels(net(X).argmax(axis=1))\n",
|
|
|
|
|
|
" titles = [true +'\\n' + pred for true, pred in zip(trues, preds)]\n",
|
|
|
|
|
|
" show_images(\n",
|
|
|
|
|
|
" X[0:n].reshape((n, 28, 28)), 1, n, titles=titles[0:n])\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
"predict_ch3(net, test_iter)\n"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "94f6177bfb40eece",
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"outputs": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"data": {
|
|
|
|
|
|
"text/plain": [
|
|
|
|
|
|
"<Figure size 600x100 with 6 Axes>"
|
|
|
|
|
|
],
|
|
|
|
|
|
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"357.008839pt\" height=\"90.784071pt\" viewBox=\"0 0 357.008839 90.784071\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n <metadata>\n <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n <cc:Work>\n <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n <dc:date>2026-03-19T14:09:11.541460</dc:date>\n <dc:format>image/svg+xml</dc:format>\n <dc:creator>\n <cc:Agent>\n <dc:title>Matplotlib v3.7.2, https://matplotlib.org/</dc:title>\n </cc:Agent>\n </dc:creator>\n </cc:Work>\n </rdf:RDF>\n </metadata>\n <defs>\n <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n </defs>\n <g id=\"figure_1\">\n <g id=\"patch_1\">\n <path d=\"M 0 90.784071 \nL 357.008839 90.784071 \nL 357.008839 0 \nL 0 0 \nz\n\"/>\n </g>\n <g id=\"axes_1\">\n <g id=\"patch_2\">\n <path d=\"M 15.008839 83.584071 \nL 62.837411 83.584071 \nL 62.837411 35.7555 \nL 15.008839 35.7555 \nz\n\"/>\n </g>\n <g clip-path=\"url(#p15736a100e)\">\n <image 
xlink:href=\"data:image/png;base64,\niVBORw0KGgoAAAANSUhEUgAAAEMAAABDCAYAAADHyrhzAAAOAElEQVR4nO2a2ZIkV1KGP/dzIiKX2rsltTZkwwi0GNhwwyVzBW/Ak/BevACGwTUYBjPGYswIBCYJqdVdXV1bZmREnOPOxYlcqjeJYYYWZvmbZVd2RFREnN/df19OyR/Lnzp7AKCv+wV+SNiTsYM9GTvYk7GDPRk72JOxgz0ZO9iTsYM9GTvYk7GDPRk72JOxg/i9rhIBUUTLTwAJCqrb82vo/4Bfs/LTffPds4EbnnM5/ipoQCdNef7uc9f3Xd/7Fc/2bPjQA99BhlQ1UkX03hn5zWOsifQnNblR2jNlOBIsQq4BBascD5TPqzhxEAftQbIQF1AtnGoB86974jJRfXFO/vZxIcXyC28TfudHfPMnb9IfQZo5VoEm0F4QA0kgVo5J3j4XL8fFnNNf9tR/+0sYhleQIYLUFVLX5PvHLD44YJgJiwdKnkL7/sDRm7ccND1vzBZMwsBpvWQaBg5Cx3FclhemWCZTvMdcya50Hvl6dcIi1Xx+dY/HF4f405r+Fw31Vc3Z7SHy9BJ68JeQ0b9zxPKPbvnxm+d8ePiYB/U1T4Y536yO6S1w1U/pUqQdKlZDxF3IWXGHnAJmgsuUt/9pCp0SpWkQEeS9t8lnB1gdSPOIByFNFYvC6lTozgSroT82vHbiwUBQI5tyuZqiMuHJao6KozgijooT1e4sIEomqqGydd953dMdtiyCcWVTYit0J2dMfv8Ei+W5GFStowlCb+jgXHxc8dGDz/nk6CE/njziXrjlsprxoLlisMhVnpJMuc0Ni9TQW+C6n5BdWQ4VQw4sjqdwfIj0E6IeHiBVxcUfvsXFp0I6cOStFVWVOZh2NDFxpEYYF2BeLNylSHahGyKPFgdYVoZVhKQwCNopCHh02JEUrw2dJqo68e7ZFafNkremN7w7v6QSY/rhgIqxSA2DK787f8RPD/6VG5vwV1ef8m13xFe3JzxZzPjR2QV/9t5f8E684UShEcXcGdbeOOrFymFAeJynfNY/YGEN58Mht7nhz989o3vvhNBlosxneF3RHwrDseGzzOGsowqZed1ThTy6t2AuZFOSKe0Q6VOk7yNDF/GksFJkULQXdKCQEQTfIcMymAR6E266BoBpHJjFHsLAiQ5ENaZhAOC36nPeD7esdMl/zR5yFFdUmmlC4r3ZJZWkEjLuxX12SNgNroBTSUYxAkYlmUoyro4ruApx+fFbeBQW7wr1gyUpBW7O55CFp50ieRSjXFbkUkRI8laUapPNMdbe7yMZcfs7ANYreRBc4fHtKY8VqAyJhlbGbNYR1YiheOPP6vf4y9mnVGJEzeNiheOmpcuRv779lEYHKskbfVpr1G2eMHhgpj2VZJZWcz4ckCzQ5orOImERiG2PdonYH0csQjowTmcd14sJ1ga0F6prRXvQ9cLZZglNoyFkJ3PI9pr1Bx8z70bJR08RCF0oBEXFY8lGt31AohGCIepcxwmPF3OaKvH2/JpZLB5Ta2Zw5cvVKWHUKZWtPiUPXPZTkivz2DMPPZ1F2lxhLvQWSBZKphkykox4+NkNRKU/OOByuIc41GNqCivQoaTKXN9dsMWSnhC2YbBzzqPjAl5tnKScC2C1F09KILZmcH2P8sVdwJ1hCKQUaLuaPgWqYHRDpE+BOmbuHywIshXkwQLtUG30zF2Y1AOTmDbXqDjHdcss9uTGSQc12mei/+xfkFhxXz6iuZqTJkJ/VKxacjGkGVh11wvW+XpDxnoRIwFWjbFY+XjNuFoFouEmSBvQ5FsyxvttargxFVofQJyurRBxbFGhS2U5M7r7kRi36tB1FUNb4SaQpXjkJBOrjAajqjJVyBzXLYfVCpsYw2FAeyXijudMuFkxPa9J04BmxYJsXX1tVymW3YUH8Lp4ATqSsVkMSCphIYzCoeDlH3ay6+Z6z1KiT8o
DPSvrAxKtvMp4HwAzwWysisVx3yk6pRhCg6PBUHXcBR+TgY2WdRVQH4suy/h/fEnzTcNkOmV+dozVkeFsQpoGrBZSU8jJ9UhSBAtCmkFuCilWAerFwiaolTDbxInsECzgwbeCbEAWvNdyzWjVTXwFR4MTq0yXy2KoDMtKAkK4G7OijjaGiDOZ9kzrgWxCnyLmwipXtLlCXLbZZG0UW61gtUJXHapKaGq8CeBgg6CD4gFSKi9rlWDBQYU0E9zHsKBkFmyddeSOtoiMJbuAiyD6jHu4lJxolBfd8QJRQ9XQ6ORosGNpd0GedTVxRG1TAK6rYAeyKb2FzXsCxE2Ajr5l/QCXV0gIxGVLrCoYmzJXgTDGSVBcFTuo6U4brBa6o0CuIc2FYV48wOriAfiapFKDFI9wPAgWtl6yzlAl5AyCI5UhwamqTAjGZNqTax37O8G9GKksQ9CqZBVLiqEsUmApExhJEXUu4ow2VcSFEhcDoTeihIDbaEp3sIwtFmXB19d8F3Q2Y356Ak3N5MExaRZp36gQU3IDfV1CSFIxieSSlovGCGaOjKILJbu4OoTykWiEytBQ0m0Qp6lLK9CnwKJtcBvDZoy5EKxowqBgUkLOpAh3k5HgtH1FykpohbgcU6vndQHxq225+pDwZQspEc8DYVIRVhOa64rcKKsTxSpKZzktVrcxOCVDyOuKbIzdAKKCySiWo/u7
|
|
|
|
|
|
},
|
|
|
|
|
|
"metadata": {},
|
|
|
|
|
|
"output_type": "display_data",
|
|
|
|
|
|
"jetTransient": {
|
|
|
|
|
|
"display_id": null
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
],
|
|
|
|
|
|
"execution_count": 73
|
2026-03-15 06:42:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:09:11.756853549Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:09:11.685653252Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-15 06:42:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"batch_size = 256\n",
|
|
|
|
|
|
"train_iter, test_iter = load_data_fashion_mnist(batch_size)"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "4a0bfd0479ec7386",
|
|
|
|
|
|
"outputs": [],
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"execution_count": 74
|
2026-03-15 06:42:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:09:42.908237720Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:09:42.860595602Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-15 06:42:56 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))\n",
|
|
|
|
|
|
"def init_weights(m):\n",
|
|
|
|
|
|
" if type(m) == nn.Linear:\n",
|
|
|
|
|
|
" nn.init.normal_(m.weight, std=0.01)\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
"net.apply(init_weights);\n",
|
|
|
|
|
|
"loss = nn.CrossEntropyLoss(reduction='none')\n",
|
|
|
|
|
|
"trainer = torch.optim.SGD(net.parameters(), lr=0.1)\n",
|
|
|
|
|
|
"num_epochs = 10\n",
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"#train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)"
|
2026-03-15 06:42:56 +00:00
|
|
|
|
],
|
|
|
|
|
|
"id": "b9808d88f5e6827b",
|
|
|
|
|
|
"outputs": [],
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"execution_count": 76
|
2026-03-15 06:42:56 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:09:46.372721439Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:09:46.108696173Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-15 06:42:56 +00:00
|
|
|
|
"cell_type": "code",
|
2026-03-19 05:47:14 +00:00
|
|
|
|
"source": [
|
|
|
|
|
|
"x = torch.arange(-8.0, 8.0, 0.1, requires_grad=True)\n",
|
|
|
|
|
|
"y = torch.relu(x)\n",
|
|
|
|
|
|
"plot(x.detach(), y.detach(), 'x', 'relu(x)', figsize=(5, 2.5))"
|
|
|
|
|
|
],
|
2026-03-15 06:42:56 +00:00
|
|
|
|
"id": "c25dd146307f58e0",
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"outputs": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"data": {
|
|
|
|
|
|
"text/plain": [
|
|
|
|
|
|
"<Figure size 500x250 with 1 Axes>"
|
|
|
|
|
|
],
|
|
|
|
|
|
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"320.440625pt\" height=\"183.35625pt\" viewBox=\"0 0 320.440625 183.35625\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n <metadata>\n <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n <cc:Work>\n <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n <dc:date>2026-03-19T14:09:46.256394</dc:date>\n <dc:format>image/svg+xml</dc:format>\n <dc:creator>\n <cc:Agent>\n <dc:title>Matplotlib v3.7.2, https://matplotlib.org/</dc:title>\n </cc:Agent>\n </dc:creator>\n </cc:Work>\n </rdf:RDF>\n </metadata>\n <defs>\n <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n </defs>\n <g id=\"figure_1\">\n <g id=\"patch_1\">\n <path d=\"M 0 183.35625 \nL 320.440625 183.35625 \nL 320.440625 0 \nL 0 0 \nz\n\"/>\n </g>\n <g id=\"axes_1\">\n <g id=\"patch_2\">\n <path d=\"M 34.240625 145.8 \nL 313.240625 145.8 \nL 313.240625 7.2 \nL 34.240625 7.2 \nz\n\"/>\n </g>\n <g id=\"matplotlib.axis_1\">\n <g id=\"xtick_1\">\n <g id=\"line2d_1\">\n <path d=\"M 46.922443 145.8 \nL 46.922443 7.2 \n\" clip-path=\"url(#pee650b9715)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_2\">\n <defs>\n <path id=\"m5bb76aae00\" d=\"M 0 0 \nL 0 3.5 \n\" style=\"stroke: #ffffff; stroke-width: 0.8\"/>\n </defs>\n <g>\n <use xlink:href=\"#m5bb76aae00\" x=\"46.922443\" y=\"145.8\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_1\">\n <!-- −8 -->\n <g style=\"fill: #ffffff\" transform=\"translate(39.551349 160.398438) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-2212\" d=\"M 678 2272 \nL 4684 2272 \nL 4684 
1741 \nL 678 1741 \nL 678 2272 \nz\n\" transform=\"scale(0.015625)\"/>\n <path id=\"DejaVuSans-38\" d=\"M 2034 2216 \nQ 1584 2216 1326 1975 \nQ 1069 1734 1069 1313 \nQ 1069 891 1326 650 \nQ 1584 409 2034 409 \nQ 2484 409 2743 651 \nQ 3003 894 3003 1313 \nQ 3003 1734 2745 1975 \nQ 2488 2216 2034 2216 \nz\nM 1403 2484 \nQ 997 2584 770 2862 \nQ 544 3141 544 3541 \nQ 544 4100 942 4425 \nQ 1341 4750 2034 4750 \nQ 2731 4750 3128 4425 \nQ 3525 4100 3525 3541 \nQ 3525 3141 3298 2862 \nQ 3072 2584 2669 2484 \nQ 3125 2378 3379 2068 \nQ 3634 1759 3634 1313 \nQ 3634 634 3220 271 \nQ 2806 -91 2034 -91 \nQ 1263 -91 848 271 \nQ 434 634 434 1313 \nQ 434 1759 690 2068 \nQ 947 2378 1403 2484 \nz\nM 1172 3481 \nQ 1172 3119 1398 2916 \nQ 1625 2713 2034 2713 \nQ 2441 2713 2670 2916 \nQ 2900 3119 2900 3481 \nQ 2900 3844 2670 4047 \nQ 2441 4250 2034 4250 \nQ 1625 4250 1398 4047 \nQ 1172 3844 1172 3481 \nz\n\" transform=\"scale(0.015625)\"/>\n </defs>\n <use xlink:href=\"#DejaVuSans-2212\"/>\n <use xlink:href=\"#DejaVuSans-38\" x=\"83.789062\"/>\n </g>\n </g>\n </g>\n <g id=\"xtick_2\">\n <g id=\"line2d_3\">\n <path d=\"M 78.826389 145.8 \nL 78.826389 7.2 \n\" clip-path=\"url(#pee650b9715)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_4\">\n <g>\n <use xlink:href=\"#m5bb76aae00\" x=\"78.826389\" y=\"145.8\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_2\">\n <!-- −6 -->\n <g style=\"fill: #ffffff\" transform=\"translate(71.455295 160.398438) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-36\" d=\"M 2113 2584 \nQ 1688 2584 1439 2293 \nQ 1191 2003 1191 1497 \nQ 1191 994 1439 701 \nQ 1688 409 2113 409 \nQ 2538 409 2786 701 \nQ 3034 994 3034 1497 \nQ 3034 2003 2786 2293 \nQ 2538 2584 2113 2584 \nz\nM 3366 4563 \nL 3366 3988 \nQ 3128
|
|
|
|
|
|
},
|
|
|
|
|
|
"metadata": {},
|
|
|
|
|
|
"output_type": "display_data",
|
|
|
|
|
|
"jetTransient": {
|
|
|
|
|
|
"display_id": null
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
],
|
|
|
|
|
|
"execution_count": 77
|
2026-03-19 05:47:14 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:09:46.713437268Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:09:46.494861144Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-19 05:47:14 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"y.backward(torch.ones_like(x), retain_graph=True)\n",
|
|
|
|
|
|
"plot(x.detach(), x.grad, 'x', 'grad of relu', figsize=(5, 2.5))"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "f96acd2015dccb38",
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"outputs": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"data": {
|
|
|
|
|
|
"text/plain": [
|
|
|
|
|
|
"<Figure size 500x250 with 1 Axes>"
|
|
|
|
|
|
],
|
|
|
|
|
|
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"329.98125pt\" height=\"183.35625pt\" viewBox=\"0 0 329.98125 183.35625\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n <metadata>\n <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n <cc:Work>\n <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n <dc:date>2026-03-19T14:09:46.640099</dc:date>\n <dc:format>image/svg+xml</dc:format>\n <dc:creator>\n <cc:Agent>\n <dc:title>Matplotlib v3.7.2, https://matplotlib.org/</dc:title>\n </cc:Agent>\n </dc:creator>\n </cc:Work>\n </rdf:RDF>\n </metadata>\n <defs>\n <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n </defs>\n <g id=\"figure_1\">\n <g id=\"patch_1\">\n <path d=\"M 0 183.35625 \nL 329.98125 183.35625 \nL 329.98125 0 \nL 0 0 \nz\n\"/>\n </g>\n <g id=\"axes_1\">\n <g id=\"patch_2\">\n <path d=\"M 43.78125 145.8 \nL 322.78125 145.8 \nL 322.78125 7.2 \nL 43.78125 7.2 \nz\n\"/>\n </g>\n <g id=\"matplotlib.axis_1\">\n <g id=\"xtick_1\">\n <g id=\"line2d_1\">\n <path d=\"M 56.463068 145.8 \nL 56.463068 7.2 \n\" clip-path=\"url(#p50e2fe8e16)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_2\">\n <defs>\n <path id=\"m08c14e3804\" d=\"M 0 0 \nL 0 3.5 \n\" style=\"stroke: #ffffff; stroke-width: 0.8\"/>\n </defs>\n <g>\n <use xlink:href=\"#m08c14e3804\" x=\"56.463068\" y=\"145.8\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_1\">\n <!-- −8 -->\n <g style=\"fill: #ffffff\" transform=\"translate(49.091974 160.398438) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-2212\" d=\"M 678 2272 \nL 4684 2272 \nL 4684 1741 
\nL 678 1741 \nL 678 2272 \nz\n\" transform=\"scale(0.015625)\"/>\n <path id=\"DejaVuSans-38\" d=\"M 2034 2216 \nQ 1584 2216 1326 1975 \nQ 1069 1734 1069 1313 \nQ 1069 891 1326 650 \nQ 1584 409 2034 409 \nQ 2484 409 2743 651 \nQ 3003 894 3003 1313 \nQ 3003 1734 2745 1975 \nQ 2488 2216 2034 2216 \nz\nM 1403 2484 \nQ 997 2584 770 2862 \nQ 544 3141 544 3541 \nQ 544 4100 942 4425 \nQ 1341 4750 2034 4750 \nQ 2731 4750 3128 4425 \nQ 3525 4100 3525 3541 \nQ 3525 3141 3298 2862 \nQ 3072 2584 2669 2484 \nQ 3125 2378 3379 2068 \nQ 3634 1759 3634 1313 \nQ 3634 634 3220 271 \nQ 2806 -91 2034 -91 \nQ 1263 -91 848 271 \nQ 434 634 434 1313 \nQ 434 1759 690 2068 \nQ 947 2378 1403 2484 \nz\nM 1172 3481 \nQ 1172 3119 1398 2916 \nQ 1625 2713 2034 2713 \nQ 2441 2713 2670 2916 \nQ 2900 3119 2900 3481 \nQ 2900 3844 2670 4047 \nQ 2441 4250 2034 4250 \nQ 1625 4250 1398 4047 \nQ 1172 3844 1172 3481 \nz\n\" transform=\"scale(0.015625)\"/>\n </defs>\n <use xlink:href=\"#DejaVuSans-2212\"/>\n <use xlink:href=\"#DejaVuSans-38\" x=\"83.789062\"/>\n </g>\n </g>\n </g>\n <g id=\"xtick_2\">\n <g id=\"line2d_3\">\n <path d=\"M 88.367014 145.8 \nL 88.367014 7.2 \n\" clip-path=\"url(#p50e2fe8e16)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_4\">\n <g>\n <use xlink:href=\"#m08c14e3804\" x=\"88.367014\" y=\"145.8\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_2\">\n <!-- −6 -->\n <g style=\"fill: #ffffff\" transform=\"translate(80.99592 160.398438) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-36\" d=\"M 2113 2584 \nQ 1688 2584 1439 2293 \nQ 1191 2003 1191 1497 \nQ 1191 994 1439 701 \nQ 1688 409 2113 409 \nQ 2538 409 2786 701 \nQ 3034 994 3034 1497 \nQ 3034 2003 2786 2293 \nQ 2538 2584 2113 2584 \nz\nM 3366 4563 \nL 3366 3988 \nQ 3128 4100 288
|
|
|
|
|
|
},
|
|
|
|
|
|
"metadata": {},
|
|
|
|
|
|
"output_type": "display_data",
|
|
|
|
|
|
"jetTransient": {
|
|
|
|
|
|
"display_id": null
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
],
|
|
|
|
|
|
"execution_count": 78
|
2026-03-19 05:47:14 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:09:47.267818729Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:09:47.018071215Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-19 05:47:14 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"y = torch.sigmoid(x)\n",
|
|
|
|
|
|
"plot(x.detach(), y.detach(), 'x', 'sigmoid(x)', figsize=(5, 2.5))"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "74013cea59cd8be3",
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"outputs": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"data": {
|
|
|
|
|
|
"text/plain": [
|
|
|
|
|
|
"<Figure size 500x250 with 1 Axes>"
|
|
|
|
|
|
],
|
|
|
|
|
|
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"329.98125pt\" height=\"183.35625pt\" viewBox=\"0 0 329.98125 183.35625\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n <metadata>\n <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n <cc:Work>\n <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n <dc:date>2026-03-19T14:09:47.185521</dc:date>\n <dc:format>image/svg+xml</dc:format>\n <dc:creator>\n <cc:Agent>\n <dc:title>Matplotlib v3.7.2, https://matplotlib.org/</dc:title>\n </cc:Agent>\n </dc:creator>\n </cc:Work>\n </rdf:RDF>\n </metadata>\n <defs>\n <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n </defs>\n <g id=\"figure_1\">\n <g id=\"patch_1\">\n <path d=\"M 0 183.35625 \nL 329.98125 183.35625 \nL 329.98125 0 \nL 0 0 \nz\n\"/>\n </g>\n <g id=\"axes_1\">\n <g id=\"patch_2\">\n <path d=\"M 43.78125 145.8 \nL 322.78125 145.8 \nL 322.78125 7.2 \nL 43.78125 7.2 \nz\n\"/>\n </g>\n <g id=\"matplotlib.axis_1\">\n <g id=\"xtick_1\">\n <g id=\"line2d_1\">\n <path d=\"M 56.463068 145.8 \nL 56.463068 7.2 \n\" clip-path=\"url(#p8269f525f0)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_2\">\n <defs>\n <path id=\"ma5cdbfd291\" d=\"M 0 0 \nL 0 3.5 \n\" style=\"stroke: #ffffff; stroke-width: 0.8\"/>\n </defs>\n <g>\n <use xlink:href=\"#ma5cdbfd291\" x=\"56.463068\" y=\"145.8\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_1\">\n <!-- −8 -->\n <g style=\"fill: #ffffff\" transform=\"translate(49.091974 160.398438) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-2212\" d=\"M 678 2272 \nL 4684 2272 \nL 4684 1741 
\nL 678 1741 \nL 678 2272 \nz\n\" transform=\"scale(0.015625)\"/>\n <path id=\"DejaVuSans-38\" d=\"M 2034 2216 \nQ 1584 2216 1326 1975 \nQ 1069 1734 1069 1313 \nQ 1069 891 1326 650 \nQ 1584 409 2034 409 \nQ 2484 409 2743 651 \nQ 3003 894 3003 1313 \nQ 3003 1734 2745 1975 \nQ 2488 2216 2034 2216 \nz\nM 1403 2484 \nQ 997 2584 770 2862 \nQ 544 3141 544 3541 \nQ 544 4100 942 4425 \nQ 1341 4750 2034 4750 \nQ 2731 4750 3128 4425 \nQ 3525 4100 3525 3541 \nQ 3525 3141 3298 2862 \nQ 3072 2584 2669 2484 \nQ 3125 2378 3379 2068 \nQ 3634 1759 3634 1313 \nQ 3634 634 3220 271 \nQ 2806 -91 2034 -91 \nQ 1263 -91 848 271 \nQ 434 634 434 1313 \nQ 434 1759 690 2068 \nQ 947 2378 1403 2484 \nz\nM 1172 3481 \nQ 1172 3119 1398 2916 \nQ 1625 2713 2034 2713 \nQ 2441 2713 2670 2916 \nQ 2900 3119 2900 3481 \nQ 2900 3844 2670 4047 \nQ 2441 4250 2034 4250 \nQ 1625 4250 1398 4047 \nQ 1172 3844 1172 3481 \nz\n\" transform=\"scale(0.015625)\"/>\n </defs>\n <use xlink:href=\"#DejaVuSans-2212\"/>\n <use xlink:href=\"#DejaVuSans-38\" x=\"83.789062\"/>\n </g>\n </g>\n </g>\n <g id=\"xtick_2\">\n <g id=\"line2d_3\">\n <path d=\"M 88.367014 145.8 \nL 88.367014 7.2 \n\" clip-path=\"url(#p8269f525f0)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_4\">\n <g>\n <use xlink:href=\"#ma5cdbfd291\" x=\"88.367014\" y=\"145.8\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_2\">\n <!-- −6 -->\n <g style=\"fill: #ffffff\" transform=\"translate(80.99592 160.398438) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-36\" d=\"M 2113 2584 \nQ 1688 2584 1439 2293 \nQ 1191 2003 1191 1497 \nQ 1191 994 1439 701 \nQ 1688 409 2113 409 \nQ 2538 409 2786 701 \nQ 3034 994 3034 1497 \nQ 3034 2003 2786 2293 \nQ 2538 2584 2113 2584 \nz\nM 3366 4563 \nL 3366 3988 \nQ 3128 4100 288
|
|
|
|
|
|
},
|
|
|
|
|
|
"metadata": {},
|
|
|
|
|
|
"output_type": "display_data",
|
|
|
|
|
|
"jetTransient": {
|
|
|
|
|
|
"display_id": null
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
],
|
|
|
|
|
|
"execution_count": 79
|
2026-03-19 05:47:14 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:09:47.710436882Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:09:47.473515430Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-19 05:47:14 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"x.grad.data.zero_()\n",
|
|
|
|
|
|
"y.backward(torch.ones_like(x),retain_graph=True)\n",
|
|
|
|
|
|
"plot(x.detach(), x.grad, 'x', 'grad of sigmoid', figsize=(5, 2.5))"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "6a0b4f529bf9cc5c",
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"outputs": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"data": {
|
|
|
|
|
|
"text/plain": [
|
|
|
|
|
|
"<Figure size 500x250 with 1 Axes>"
|
|
|
|
|
|
],
|
|
|
|
|
|
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"336.34375pt\" height=\"183.35625pt\" viewBox=\"0 0 336.34375 183.35625\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n <metadata>\n <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n <cc:Work>\n <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n <dc:date>2026-03-19T14:09:47.635692</dc:date>\n <dc:format>image/svg+xml</dc:format>\n <dc:creator>\n <cc:Agent>\n <dc:title>Matplotlib v3.7.2, https://matplotlib.org/</dc:title>\n </cc:Agent>\n </dc:creator>\n </cc:Work>\n </rdf:RDF>\n </metadata>\n <defs>\n <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n </defs>\n <g id=\"figure_1\">\n <g id=\"patch_1\">\n <path d=\"M 0 183.35625 \nL 336.34375 183.35625 \nL 336.34375 0 \nL 0 0 \nz\n\"/>\n </g>\n <g id=\"axes_1\">\n <g id=\"patch_2\">\n <path d=\"M 50.14375 145.8 \nL 329.14375 145.8 \nL 329.14375 7.2 \nL 50.14375 7.2 \nz\n\"/>\n </g>\n <g id=\"matplotlib.axis_1\">\n <g id=\"xtick_1\">\n <g id=\"line2d_1\">\n <path d=\"M 62.825568 145.8 \nL 62.825568 7.2 \n\" clip-path=\"url(#p1199fb60d9)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_2\">\n <defs>\n <path id=\"m63c8d8144b\" d=\"M 0 0 \nL 0 3.5 \n\" style=\"stroke: #ffffff; stroke-width: 0.8\"/>\n </defs>\n <g>\n <use xlink:href=\"#m63c8d8144b\" x=\"62.825568\" y=\"145.8\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_1\">\n <!-- −8 -->\n <g style=\"fill: #ffffff\" transform=\"translate(55.454474 160.398438) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-2212\" d=\"M 678 2272 \nL 4684 2272 \nL 4684 1741 
\nL 678 1741 \nL 678 2272 \nz\n\" transform=\"scale(0.015625)\"/>\n <path id=\"DejaVuSans-38\" d=\"M 2034 2216 \nQ 1584 2216 1326 1975 \nQ 1069 1734 1069 1313 \nQ 1069 891 1326 650 \nQ 1584 409 2034 409 \nQ 2484 409 2743 651 \nQ 3003 894 3003 1313 \nQ 3003 1734 2745 1975 \nQ 2488 2216 2034 2216 \nz\nM 1403 2484 \nQ 997 2584 770 2862 \nQ 544 3141 544 3541 \nQ 544 4100 942 4425 \nQ 1341 4750 2034 4750 \nQ 2731 4750 3128 4425 \nQ 3525 4100 3525 3541 \nQ 3525 3141 3298 2862 \nQ 3072 2584 2669 2484 \nQ 3125 2378 3379 2068 \nQ 3634 1759 3634 1313 \nQ 3634 634 3220 271 \nQ 2806 -91 2034 -91 \nQ 1263 -91 848 271 \nQ 434 634 434 1313 \nQ 434 1759 690 2068 \nQ 947 2378 1403 2484 \nz\nM 1172 3481 \nQ 1172 3119 1398 2916 \nQ 1625 2713 2034 2713 \nQ 2441 2713 2670 2916 \nQ 2900 3119 2900 3481 \nQ 2900 3844 2670 4047 \nQ 2441 4250 2034 4250 \nQ 1625 4250 1398 4047 \nQ 1172 3844 1172 3481 \nz\n\" transform=\"scale(0.015625)\"/>\n </defs>\n <use xlink:href=\"#DejaVuSans-2212\"/>\n <use xlink:href=\"#DejaVuSans-38\" x=\"83.789062\"/>\n </g>\n </g>\n </g>\n <g id=\"xtick_2\">\n <g id=\"line2d_3\">\n <path d=\"M 94.729514 145.8 \nL 94.729514 7.2 \n\" clip-path=\"url(#p1199fb60d9)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_4\">\n <g>\n <use xlink:href=\"#m63c8d8144b\" x=\"94.729514\" y=\"145.8\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_2\">\n <!-- −6 -->\n <g style=\"fill: #ffffff\" transform=\"translate(87.35842 160.398438) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-36\" d=\"M 2113 2584 \nQ 1688 2584 1439 2293 \nQ 1191 2003 1191 1497 \nQ 1191 994 1439 701 \nQ 1688 409 2113 409 \nQ 2538 409 2786 701 \nQ 3034 994 3034 1497 \nQ 3034 2003 2786 2293 \nQ 2538 2584 2113 2584 \nz\nM 3366 4563 \nL 3366 3988 \nQ 3128 4100 288
|
|
|
|
|
|
},
|
|
|
|
|
|
"metadata": {},
|
|
|
|
|
|
"output_type": "display_data",
|
|
|
|
|
|
"jetTransient": {
|
|
|
|
|
|
"display_id": null
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
],
|
|
|
|
|
|
"execution_count": 80
|
2026-03-19 05:47:14 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:09:48.096574814Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:09:48.002692660Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-19 05:47:14 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"batch_size = 256\n",
|
|
|
|
|
|
"train_iter, test_iter = load_data_fashion_mnist(batch_size)"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "f1de998439b1b9f",
|
|
|
|
|
|
"outputs": [],
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"execution_count": 81
|
2026-03-19 05:47:14 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:09:51.282745933Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:09:51.217670982Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-19 05:47:14 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"num_inputs,num_outputs,num_hiddens = 784, 10, 256\n",
|
|
|
|
|
|
"W1=nn.Parameter(torch.randn(num_inputs,num_hiddens,requires_grad=True)*0.01)\n",
|
|
|
|
|
|
"b1=nn.Parameter(torch.zeros(num_hiddens,requires_grad=True))\n",
|
|
|
|
|
|
"W2 = nn.Parameter(torch.randn(num_hiddens,num_outputs,requires_grad=True)*0.01)\n",
|
|
|
|
|
|
"b2=nn.Parameter(torch.zeros(num_outputs,requires_grad=True))\n",
|
|
|
|
|
|
"params=[W1,b1,W2,b2]\n"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "adcea8cd4ee792a8",
|
|
|
|
|
|
"outputs": [],
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"execution_count": 82
|
2026-03-19 05:47:14 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:09:53.068840701Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:09:52.988897182Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-19 05:47:14 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"def relu(X):\n",
|
|
|
|
|
|
" a = torch.zeros_like(X)\n",
|
|
|
|
|
|
" return torch.max(X,a)\n",
|
|
|
|
|
|
"def net(X):\n",
|
|
|
|
|
|
" X = X.reshape((-1,num_inputs))\n",
|
|
|
|
|
|
" H = relu(X@W1+b1)\n",
|
|
|
|
|
|
" return (H@W2+b2)\n",
|
|
|
|
|
|
"loss = nn.CrossEntropyLoss(reduction='none')\n",
|
|
|
|
|
|
"num_epochs,lr=10,0.05\n",
|
|
|
|
|
|
"updater=torch.optim.SGD(params,lr=lr)\n",
|
|
|
|
|
|
"#train_ch3(net,train_iter,test_iter,loss,num_epochs,updater)\n"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "cfd81a0f16d0c573",
|
|
|
|
|
|
"outputs": [],
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"execution_count": 83
|
2026-03-19 05:47:14 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:09:54.454079900Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:09:53.918996975Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-19 05:47:14 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": "predict_ch3(net, test_iter)",
|
|
|
|
|
|
"id": "f2ed2e9cee14c28a",
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"outputs": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"data": {
|
|
|
|
|
|
"text/plain": [
|
|
|
|
|
|
"<Figure size 600x100 with 6 Axes>"
|
|
|
|
|
|
],
|
|
|
|
|
|
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"357.008839pt\" height=\"90.784071pt\" viewBox=\"0 0 357.008839 90.784071\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n <metadata>\n <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n <cc:Work>\n <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n <dc:date>2026-03-19T14:09:54.360205</dc:date>\n <dc:format>image/svg+xml</dc:format>\n <dc:creator>\n <cc:Agent>\n <dc:title>Matplotlib v3.7.2, https://matplotlib.org/</dc:title>\n </cc:Agent>\n </dc:creator>\n </cc:Work>\n </rdf:RDF>\n </metadata>\n <defs>\n <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n </defs>\n <g id=\"figure_1\">\n <g id=\"patch_1\">\n <path d=\"M 0 90.784071 \nL 357.008839 90.784071 \nL 357.008839 0 \nL 0 0 \nz\n\"/>\n </g>\n <g id=\"axes_1\">\n <g id=\"patch_2\">\n <path d=\"M 15.008839 83.584071 \nL 62.837411 83.584071 \nL 62.837411 35.7555 \nL 15.008839 35.7555 \nz\n\"/>\n </g>\n <g clip-path=\"url(#p9beae8b09f)\">\n <image 
xlink:href=\"data:image/png;base64,\niVBORw0KGgoAAAANSUhEUgAAAEMAAABDCAYAAADHyrhzAAAOAElEQVR4nO2a2ZIkV1KGP/dzIiKX2rsltTZkwwi0GNhwwyVzBW/Ak/BevACGwTUYBjPGYswIBCYJqdVdXV1bZmREnOPOxYlcqjeJYYYWZvmbZVd2RFREnN/df19OyR/Lnzp7AKCv+wV+SNiTsYM9GTvYk7GDPRk72JOxgz0ZO9iTsYM9GTvYk7GDPRk72JOxg/i9rhIBUUTLTwAJCqrb82vo/4Bfs/LTffPds4EbnnM5/ipoQCdNef7uc9f3Xd/7Fc/2bPjQA99BhlQ1UkX03hn5zWOsifQnNblR2jNlOBIsQq4BBascD5TPqzhxEAftQbIQF1AtnGoB86974jJRfXFO/vZxIcXyC28TfudHfPMnb9IfQZo5VoEm0F4QA0kgVo5J3j4XL8fFnNNf9tR/+0sYhleQIYLUFVLX5PvHLD44YJgJiwdKnkL7/sDRm7ccND1vzBZMwsBpvWQaBg5Cx3FclhemWCZTvMdcya50Hvl6dcIi1Xx+dY/HF4f405r+Fw31Vc3Z7SHy9BJ68JeQ0b9zxPKPbvnxm+d8ePiYB/U1T4Y536yO6S1w1U/pUqQdKlZDxF3IWXGHnAJmgsuUt/9pCp0SpWkQEeS9t8lnB1gdSPOIByFNFYvC6lTozgSroT82vHbiwUBQI5tyuZqiMuHJao6KozgijooT1e4sIEomqqGydd953dMdtiyCcWVTYit0J2dMfv8Ei+W5GFStowlCb+jgXHxc8dGDz/nk6CE/njziXrjlsprxoLlisMhVnpJMuc0Ni9TQW+C6n5BdWQ4VQw4sjqdwfIj0E6IeHiBVxcUfvsXFp0I6cOStFVWVOZh2NDFxpEYYF2BeLNylSHahGyKPFgdYVoZVhKQwCNopCHh02JEUrw2dJqo68e7ZFafNkremN7w7v6QSY/rhgIqxSA2DK787f8RPD/6VG5vwV1ef8m13xFe3JzxZzPjR2QV/9t5f8E684UShEcXcGdbeOOrFymFAeJynfNY/YGEN58Mht7nhz989o3vvhNBlosxneF3RHwrDseGzzOGsowqZed1ThTy6t2AuZFOSKe0Q6VOk7yNDF/GksFJkULQXdKCQEQTfIcMymAR6E266BoBpHJjFHsLAiQ5ENaZhAOC36nPeD7esdMl/zR5yFFdUmmlC4r3ZJZWkEjLuxX12SNgNroBTSUYxAkYlmUoyro4ruApx+fFbeBQW7wr1gyUpBW7O55CFp50ieRSjXFbkUkRI8laUapPNMdbe7yMZcfs7ANYreRBc4fHtKY8VqAyJhlbGbNYR1YiheOPP6vf4y9mnVGJEzeNiheOmpcuRv779lEYHKskbfVpr1G2eMHhgpj2VZJZWcz4ckCzQ5orOImERiG2PdonYH0csQjowTmcd14sJ1ga0F6prRXvQ9cLZZglNoyFkJ3PI9pr1Bx8z70bJR08RCF0oBEXFY8lGt31AohGCIepcxwmPF3OaKvH2/JpZLB5Ta2Zw5cvVKWHUKZWtPiUPXPZTkivz2DMPPZ1F2lxhLvQWSBZKphkykox4+NkNRKU/OOByuIc41GNqCivQoaTKXN9dsMWSnhC2YbBzzqPjAl5tnKScC2C1F09KILZmcH2P8sVdwJ1hCKQUaLuaPgWqYHRDpE+BOmbuHywIshXkwQLtUG30zF2Y1AOTmDbXqDjHdcss9uTGSQc12mei/+xfkFhxXz6iuZqTJkJ/VKxacjGkGVh11wvW+XpDxnoRIwFWjbFY+XjNuFoFouEmSBvQ5FsyxvttargxFVofQJyurRBxbFGhS2U5M7r7kRi36tB1FUNb4SaQpXjkJBOrjAajqjJVyBzXLYfVCpsYw2FAeyXijudMuFkxPa9J04BmxYJsXX1tVymW3YUH8Lp4ATqSsVkMSCphIYzCoeDlH3ay6+Z6z1KiT8o
DPSvrAxKtvMp4HwAzwWysisVx3yk6pRhCg6PBUHXcBR+TgY2WdRVQH4suy/h/fEnzTcNkOmV+dozVkeFsQpoGrBZSU8jJ9UhSBAtCmkFuCilWAerFwiaolTDbxInsECzgwbeCbEAWvNdyzWjVTXwFR4MTq0yXy2KoDMtKAkK4G7OijjaGiDOZ9kzrgWxCnyLmwipXtLlCXLbZZG0UW61gtUJXHapKaGq8CeBgg6CD4gFSKi9rlWDBQYU0E9zHsKBkFmyddeSOtoiMJbuAiyD6jHu4lJxolBfd8QJRQ9XQ6ORosGNpd0GedTVxRG1TAK6rYAeyKb2FzXsCxE2Ajr5l/QCXV0gIxGVLrCoYmzJXgTDGSVBcFTuo6U4brBa6o0CuIc2FYV48wOriAfiapFKDFI9wPAgWtl6yzlAl5AyCI5UhwamqTAjGZNqTax37O8G9GKksQ9CqZBVLiqEsUmApExhJEXUu4ow2VcSFEhcDoTeihIDbaEp3sIwtFmXB19d8F3Q2Y356Ak3N5MExaRZp36gQU3IDfV1CSFIxieSSlovGCGaOjKILJbu4OoTykWiEytBQ0m0Qp6lLK9CnwKJtcBvDZoy5EKxowqBgUkLOpAh3k5HgtH1FykpohbgcU6vndQHxq225+pDwZQspEc8DYVIRVhOa64rcKKsTxSpKZzktVrcxOCVDyOuKbIzdAKKCySiWo/u7
|
|
|
|
|
|
},
|
|
|
|
|
|
"metadata": {},
|
|
|
|
|
|
"output_type": "display_data",
|
|
|
|
|
|
"jetTransient": {
|
|
|
|
|
|
"display_id": null
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
],
|
|
|
|
|
|
"execution_count": 84
|
2026-03-19 05:47:14 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:09:55.272193107Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:09:55.113735528Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-19 05:47:14 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"net = nn.Sequential(nn.Flatten(),\n",
|
|
|
|
|
|
" nn.Linear(784,256),\n",
|
|
|
|
|
|
" nn.ReLU(),\n",
|
|
|
|
|
|
" nn.Linear(256,10))\n",
|
|
|
|
|
|
"def init_weights(m):\n",
|
|
|
|
|
|
" if type(m) == nn.Linear:\n",
|
|
|
|
|
|
" nn.init.normal_(m.weight,std=0.01)\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
"net.apply(init_weights)\n"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "9e4ed6d103380bc7",
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"outputs": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"data": {
|
|
|
|
|
|
"text/plain": [
|
|
|
|
|
|
"Sequential(\n",
|
|
|
|
|
|
" (0): Flatten(start_dim=1, end_dim=-1)\n",
|
|
|
|
|
|
" (1): Linear(in_features=784, out_features=256, bias=True)\n",
|
|
|
|
|
|
" (2): ReLU()\n",
|
|
|
|
|
|
" (3): Linear(in_features=256, out_features=10, bias=True)\n",
|
|
|
|
|
|
")"
|
|
|
|
|
|
]
|
|
|
|
|
|
},
|
|
|
|
|
|
"execution_count": 85,
|
|
|
|
|
|
"metadata": {},
|
|
|
|
|
|
"output_type": "execute_result"
|
|
|
|
|
|
}
|
|
|
|
|
|
],
|
|
|
|
|
|
"execution_count": 85
|
2026-03-19 05:47:14 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2026-03-19T06:09:56.028892806Z",
|
|
|
|
|
|
"start_time": "2026-03-19T06:09:55.915661103Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2026-03-19 05:47:14 +00:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"batch_size,lr, num_epochs=256,0.1,10\n",
|
|
|
|
|
|
"loss = nn.CrossEntropyLoss(reduction='none')\n",
|
|
|
|
|
|
"trainer = torch.optim.SGD(net.parameters(),lr=lr)\n",
|
|
|
|
|
|
"train_iter, test_iter = load_data_fashion_mnist(batch_size)\n",
|
|
|
|
|
|
"#train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "52d71c77c4f51e90",
|
|
|
|
|
|
"outputs": [],
|
2026-03-22 08:28:55 +00:00
|
|
|
|
"execution_count": 86
|
2026-03-19 05:47:14 +00:00
|
|
|
|
},
|
|
|
|
|
|
{
|
|
|
|
|
|
"metadata": {},
|
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"source": "",
|
|
|
|
|
|
"id": "94706c936b4be3e1",
|
|
|
|
|
|
"outputs": [],
|
|
|
|
|
|
"execution_count": null
|
2026-03-12 07:50:33 +00:00
|
|
|
|
}
|
|
|
|
|
|
],
|
|
|
|
|
|
"metadata": {
|
|
|
|
|
|
"kernelspec": {
|
|
|
|
|
|
"display_name": "Python 3",
|
|
|
|
|
|
"language": "python",
|
|
|
|
|
|
"name": "python3"
|
|
|
|
|
|
},
|
|
|
|
|
|
"language_info": {
|
|
|
|
|
|
"codemirror_mode": {
|
|
|
|
|
|
"name": "ipython",
|
|
|
|
|
|
"version": 2
|
|
|
|
|
|
},
|
|
|
|
|
|
"file_extension": ".py",
|
|
|
|
|
|
"mimetype": "text/x-python",
|
|
|
|
|
|
"name": "python",
|
|
|
|
|
|
"nbconvert_exporter": "python",
|
|
|
|
|
|
"pygments_lexer": "ipython2",
|
|
|
|
|
|
"version": "2.7.6"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
|
|
|
|
|
"nbformat": 4,
|
|
|
|
|
|
"nbformat_minor": 5
|
|
|
|
|
|
}
|