From 45deb26b74a377638b7a387e7154d97b2dbeb1b1 Mon Sep 17 00:00:00 2001 From: yukun-hh Date: Tue, 30 Jun 2026 11:57:35 +0800 Subject: [PATCH] =?UTF-8?q?transformer=20encode=20=E5=9D=97=E5=AE=8C?= =?UTF-8?q?=E6=88=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- chapter10.ipynb | 697 +++++++++++++++++++++++++++++------------------- 1 file changed, 419 insertions(+), 278 deletions(-) diff --git a/chapter10.ipynb b/chapter10.ipynb index cd1ba10..4c660b1 100644 --- a/chapter10.ipynb +++ b/chapter10.ipynb @@ -4,27 +4,36 @@ "cell_type": "code", "id": "initial_id", "metadata": { - "collapsed": true, "ExecuteTime": { - "end_time": "2026-05-29T15:19:42.541639835Z", - "start_time": "2026-05-29T15:19:40.265547407Z" + "end_time": "2026-06-30T02:59:03.101034703Z", + "start_time": "2026-06-30T02:58:59.676757575Z" } }, "source": [ "import torch\n", "from d2l import torch as d2l\n" ], - "outputs": [], - "execution_count": 2 + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/yukun/.conda/envs/nn/lib/python3.11/site-packages/torch/cuda/__init__.py:1007: UserWarning: Can't initialize NVML\n", + " raw_cnt = _raw_device_count_nvml()\n" + ] + } + ], + "execution_count": 1 }, { + "cell_type": "code", + "id": "a0e3f725b7764f08", "metadata": { "ExecuteTime": { - "end_time": "2026-05-29T15:19:42.593785744Z", - "start_time": "2026-05-29T15:19:42.545273168Z" + "end_time": "2026-06-30T02:59:03.137220451Z", + "start_time": "2026-06-30T02:59:03.118596105Z" } }, - "cell_type": "code", "source": [ "def show_heatmaps(matrices,xlabel,ylabel,titles=None,figsize=(2.5,2.5),cmap='Reds'):\n", " d2l.use_svg_display()\n", @@ -41,31 +50,30 @@ " ax.set_title(titles[j])\n", " fig.colorbar(pcm,ax=axes,shrink=0.6)" ], - "id": "a0e3f725b7764f08", "outputs": [], - "execution_count": 3 + "execution_count": 2 }, { + "cell_type": "code", + "id": "bb798de34fc4d0fa", "metadata": { "ExecuteTime": { - "end_time": "2026-05-29T15:19:42.832298248Z", - "start_time": "2026-05-29T15:19:42.594979091Z" + "end_time": "2026-06-30T02:59:03.336399820Z", + "start_time": "2026-06-30T02:59:03.139187972Z" } }, - "cell_type": "code", "source": [ "attention_weights = torch.eye(10).reshape((1, 1, 10, 10))\n", "\n", "show_heatmaps(attention_weights, xlabel='Keys', ylabel='Queries')" ], - "id": "bb798de34fc4d0fa", "outputs": [ { "data": { "text/plain": [ "
" ], - "image/svg+xml": "\n\n\n \n \n \n \n 2026-05-29T23:19:42.792345\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" + "image/svg+xml": "\n\n\n \n \n \n \n 2026-06-30T10:59:03.270032\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" }, "metadata": {}, "output_type": "display_data", @@ -74,48 +82,49 @@ } } ], - "execution_count": 4 + "execution_count": 3 }, { + "cell_type": "code", + "id": "a4070c75847fb887", "metadata": { "ExecuteTime": { - "end_time": "2026-05-29T15:19:42.925242334Z", - "start_time": "2026-05-29T15:19:42.844425757Z" + "end_time": "2026-06-30T02:59:03.393286045Z", + "start_time": "2026-06-30T02:59:03.339609950Z" } }, - "cell_type": "code", "source": [ "n_train = 50\n", "x_train,_ = torch.sort(torch.rand(n_train)*5)" ], - "id": "a4070c75847fb887", "outputs": [], - "execution_count": 5 + "execution_count": 4 }, { + "cell_type": "code", + "id": "2aca6952876d9cf5", "metadata": { "ExecuteTime": { - "end_time": "2026-05-29T15:19:42.978395734Z", - "start_time": "2026-05-29T15:19:42.926738710Z" + "end_time": "2026-06-30T02:59:03.465395938Z", + "start_time": "2026-06-30T02:59:03.397781049Z" } }, - "cell_type": "code", "source": [ "def f(x):\n", " return 2*torch.sin(x)+x**0.8" ], - "id": "2aca6952876d9cf5", "outputs": [], - "execution_count": 6 + "execution_count": 5 }, { + "cell_type": "code", + "id": "ea7b94585dcde934", "metadata": { "ExecuteTime": { - "end_time": "2026-05-29T15:19:43.056742866Z", - "start_time": "2026-05-29T15:19:42.982949395Z" + "end_time": "2026-06-30T02:59:03.554188488Z", + "start_time": "2026-06-30T02:59:03.467856816Z" } }, - "cell_type": "code", "source": [ "y_train = f(x_train)+torch.normal(0.0,0.5,(n_train,))\n", "x_test = torch.arange(0,5,0.1)\n", @@ -123,7 +132,6 @@ "n_test = len(x_test)\n", "n_test" ], - "id": "ea7b94585dcde934", "outputs": [ { "data": { @@ -131,51 +139,51 @@ "50" ] }, - "execution_count": 7, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], - "execution_count": 7 + "execution_count": 6 }, { + "cell_type": "code", + "id": "35c2403bbb5509cf", "metadata": { "ExecuteTime": { - "end_time": "2026-05-29T15:19:43.117629637Z", - "start_time": "2026-05-29T15:19:43.058324467Z" + "end_time": "2026-06-30T02:59:03.578849203Z", + "start_time": "2026-06-30T02:59:03.563506811Z" } }, - "cell_type": "code", "source": [ "def plot_kernel_reg(y_hat):\n", " d2l.plot(x_test, [y_truth, y_hat], 'x', 'y', legend=['Truth', 'Pred'],\n", " xlim=[0, 5], ylim=[-1, 5])\n", " d2l.plt.plot(x_train, y_train, 'o', alpha=0.5);" ], - "id": "35c2403bbb5509cf", "outputs": [], - "execution_count": 8 + "execution_count": 7 }, { + "cell_type": "code", + "id": "f27ecb9c68aed894", "metadata": { "ExecuteTime": { - "end_time": "2026-05-29T15:19:43.224202029Z", - "start_time": "2026-05-29T15:19:43.119223163Z" + "end_time": "2026-06-30T02:59:03.728510578Z", + "start_time": "2026-06-30T02:59:03.580633476Z" } }, - "cell_type": "code", "source": [ "y_hat = torch.repeat_interleave(y_train.mean(),n_test)\n", "plot_kernel_reg(y_hat)" ], - "id": "f27ecb9c68aed894", "outputs": [ { "data": { "text/plain": [ "
" ], - "image/svg+xml": "\n\n\n \n \n \n \n 2026-05-29T23:19:43.188981\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" + "image/svg+xml": "\n\n\n \n \n \n \n 2026-06-30T10:59:03.678092\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" }, "metadata": {}, "output_type": "display_data", @@ -184,16 +192,17 @@ } } ], - "execution_count": 9 + "execution_count": 8 }, { + "cell_type": "code", + "id": "fa7579d7cd8bfcb7", "metadata": { "ExecuteTime": { - "end_time": "2026-05-29T15:19:43.317103476Z", - "start_time": "2026-05-29T15:19:43.225907041Z" + "end_time": "2026-06-30T02:59:03.807801559Z", + "start_time": "2026-06-30T02:59:03.732916546Z" } }, - "cell_type": "code", "source": [ "from torch import nn\n", "X_repeat = x_test.repeat_interleave(n_train).reshape((-1, n_train))\n", @@ -201,20 +210,16 @@ "y_hat = torch.matmul(attention_weights,y_train)\n", "x_train,X_repeat,attention_weights" ], - "id": "fa7579d7cd8bfcb7", "outputs": [ { "data": { "text/plain": [ - "(tensor([9.4801e-04, 5.2621e-02, 2.3643e-01, 2.9541e-01, 3.2529e-01, 3.5294e-01,\n", - " 3.6399e-01, 4.1190e-01, 5.8751e-01, 6.8494e-01, 7.2698e-01, 8.0101e-01,\n", - " 9.2743e-01, 9.7225e-01, 1.1656e+00, 1.4397e+00, 1.4788e+00, 1.6088e+00,\n", - " 1.7147e+00, 1.8255e+00, 1.8951e+00, 1.9322e+00, 2.3849e+00, 2.4183e+00,\n", - " 2.4902e+00, 2.5441e+00, 2.6483e+00, 2.8371e+00, 2.9705e+00, 3.1383e+00,\n", - " 3.2709e+00, 3.3751e+00, 3.3849e+00, 3.4902e+00, 3.5679e+00, 3.5782e+00,\n", - " 3.5821e+00, 3.8648e+00, 4.1552e+00, 4.3914e+00, 4.4348e+00, 4.4414e+00,\n", - " 4.6038e+00, 4.6099e+00, 4.7481e+00, 4.7528e+00, 4.8005e+00, 4.8134e+00,\n", - " 4.8934e+00, 4.9044e+00]),\n", + "(tensor([0.0153, 0.1180, 0.1875, 0.4215, 0.5745, 0.6163, 0.6557, 0.6917, 0.7939,\n", + " 0.9230, 1.0240, 1.2932, 1.3246, 1.5167, 1.6193, 1.6496, 1.6613, 1.7104,\n", + " 1.7485, 1.7840, 1.7889, 2.0187, 2.0398, 2.0413, 2.2665, 2.4287, 2.4440,\n", + " 2.6026, 2.6988, 2.7922, 2.8129, 2.9356, 3.0271, 3.0442, 3.0646, 3.1996,\n", + " 3.2037, 3.3222, 3.4918, 3.5313, 3.7002, 3.7714, 3.7949, 3.9609, 4.3272,\n", + " 4.4440, 4.6667, 4.7213, 4.8963, 4.9971]),\n", " tensor([[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", " [0.1000, 0.1000, 0.1000, ..., 0.1000, 0.1000, 0.1000],\n", " [0.2000, 0.2000, 0.2000, ..., 0.2000, 0.2000, 0.2000],\n", @@ -222,45 +227,78 @@ " [4.7000, 4.7000, 4.7000, ..., 4.7000, 4.7000, 4.7000],\n", " [4.8000, 4.8000, 4.8000, ..., 4.8000, 4.8000, 4.8000],\n", " [4.9000, 4.9000, 4.9000, ..., 4.9000, 4.9000, 4.9000]]),\n", - " tensor([[6.8662e-02, 6.8567e-02, 6.6769e-02, ..., 6.3916e-07, 4.3363e-07,\n", - " 4.1091e-07],\n", - " [6.4257e-02, 6.4501e-02, 6.3975e-02, ..., 9.6787e-07, 6.6191e-07,\n", - " 6.2792e-07],\n", - " [5.9935e-02, 6.0473e-02, 6.1093e-02, ..., 1.4608e-06, 1.0070e-06,\n", - " 9.5634e-07],\n", + " tensor([[7.9004e-02, 7.8465e-02, 7.7637e-02, ..., 1.1413e-06, 4.9195e-07,\n", + " 2.9875e-07],\n", + " [7.2616e-02, 7.2865e-02, 7.2599e-02, ..., 1.6794e-06, 7.3669e-07,\n", + " 4.5190e-07],\n", + " [6.6459e-02, 6.7376e-02, 6.7597e-02, ..., 2.4606e-06, 1.0985e-06,\n", + " 6.8065e-07],\n", " ...,\n", - " [9.4878e-07, 1.2079e-06, 2.7906e-06, ..., 5.8779e-02, 5.8062e-02,\n", - " 5.7935e-02],\n", - " [6.1962e-07, 7.9294e-07, 1.8658e-06, ..., 6.2113e-02, 6.1848e-02,\n", - " 6.1781e-02],\n", - " [4.0301e-07, 5.1842e-07, 1.2425e-06, ..., 6.5370e-02, 6.5614e-02,\n", - " 6.5615e-02]]))" + " [1.3743e-06, 2.2120e-06, 3.0334e-06, ..., 8.0086e-02, 7.8576e-02,\n", + " 7.6646e-02],\n", + " [9.2161e-07, 1.4986e-06, 2.0694e-06, ..., 8.5977e-02, 8.5846e-02,\n", + " 8.4585e-02],\n", + " [6.1460e-07, 1.0097e-06, 1.4040e-06, ..., 9.1792e-02, 9.3269e-02,\n", + " 9.2831e-02]]))" ] }, - "execution_count": 10, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], - "execution_count": 10 + "execution_count": 9 }, { + "cell_type": "code", + "id": "f76573b2eb74ba7", "metadata": { "ExecuteTime": { - "end_time": "2026-05-29T15:19:43.435846223Z", - "start_time": "2026-05-29T15:19:43.319732830Z" + "end_time": "2026-06-30T02:59:03.923768674Z", + "start_time": "2026-06-30T02:59:03.813863850Z" } }, - "cell_type": "code", - "source": "plot_kernel_reg(y_hat)", - "id": "f76573b2eb74ba7", + "source": [ + "plot_kernel_reg(y_hat)" + ], "outputs": [ { "data": { "text/plain": [ "
" ], - "image/svg+xml": "\n\n\n \n \n \n \n 2026-05-29T23:19:43.394543\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" + "image/svg+xml": "\n\n\n \n \n \n \n 2026-06-30T10:59:03.882515\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" + }, + "metadata": {}, + "output_type": "display_data", + "jetTransient": { + "display_id": null + } + } + ], + "execution_count": 10 + }, + { + "cell_type": "code", + "id": "1bbe3af9e8a0f412", + "metadata": { + "ExecuteTime": { + "end_time": "2026-06-30T02:59:04.041592329Z", + "start_time": "2026-06-30T02:59:03.938730755Z" + } + }, + "source": [ + "d2l.show_heatmaps(attention_weights.unsqueeze(0).unsqueeze(0),\n", + " xlabel='Sorted training inputs',\n", + " ylabel='Sorted testing inputs')" + ], + "outputs": [ + { + "data": { + "text/plain": [ + "
" + ], + "image/svg+xml": "\n\n\n \n \n \n \n 2026-06-30T10:59:04.008946\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" }, "metadata": {}, "output_type": "display_data", @@ -272,44 +310,14 @@ "execution_count": 11 }, { + "cell_type": "code", + "id": "aeb79a3bd64c826", "metadata": { "ExecuteTime": { - "end_time": "2026-05-29T15:19:43.558539946Z", - "start_time": "2026-05-29T15:19:43.438360863Z" + "end_time": "2026-06-30T02:59:04.105924659Z", + "start_time": "2026-06-30T02:59:04.055849961Z" } }, - "cell_type": "code", - "source": [ - "d2l.show_heatmaps(attention_weights.unsqueeze(0).unsqueeze(0),\n", - " xlabel='Sorted training inputs',\n", - " ylabel='Sorted testing inputs')" - ], - "id": "1bbe3af9e8a0f412", - "outputs": [ - { - "data": { - "text/plain": [ - "
" - ], - "image/svg+xml": "\n\n\n \n \n \n \n 2026-05-29T23:19:43.512317\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" - }, - "metadata": {}, - "output_type": "display_data", - "jetTransient": { - "display_id": null - } - } - ], - "execution_count": 12 - }, - { - "metadata": { - "ExecuteTime": { - "end_time": "2026-05-29T15:19:43.610782245Z", - "start_time": "2026-05-29T15:19:43.560607937Z" - } - }, - "cell_type": "code", "source": [ "class NWKernelRegression(nn.Module):\n", " def __init__(self,**kwargs):\n", @@ -321,18 +329,18 @@ " -((queries - keys) * self.w)**2 / 2, dim=1)\n", " return torch.bmm(self.attention_weights.unsqueeze(1),values.unsqueeze(-1)).reshape(-1)" ], - "id": "aeb79a3bd64c826", "outputs": [], - "execution_count": 13 + "execution_count": 12 }, { + "cell_type": "code", + "id": "7061ed7f139bc78e", "metadata": { "ExecuteTime": { - "end_time": "2026-05-29T15:19:43.661276280Z", - "start_time": "2026-05-29T15:19:43.611752619Z" + "end_time": "2026-06-30T02:59:04.158366434Z", + "start_time": "2026-06-30T02:59:04.107607867Z" } }, - "cell_type": "code", "source": [ "# X_tile的形状:(n_train,n_train),每一行都包含着相同的训练输入\n", "X_tile = x_train.repeat((n_train, 1))\n", @@ -343,18 +351,18 @@ "# values的形状:('n_train','n_train'-1)\n", "values = Y_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))" ], - "id": "7061ed7f139bc78e", "outputs": [], - "execution_count": 14 + "execution_count": 13 }, { + "cell_type": "code", + "id": "3dca11d4e962d859", "metadata": { "ExecuteTime": { - "end_time": "2026-05-29T15:19:44.080669795Z", - "start_time": "2026-05-29T15:19:43.662206275Z" + "end_time": "2026-06-30T02:59:04.750466831Z", + "start_time": "2026-06-30T02:59:04.162845978Z" } }, - "cell_type": "code", "source": [ "net = NWKernelRegression()\n", "loss = nn.MSELoss(reduction='none')\n", @@ -368,14 +376,47 @@ " print(f'epoch {epoch + 1}, loss {float(l.sum()):.6f}')\n", " animator.add(epoch + 1, float(l.sum()))" ], - "id": "3dca11d4e962d859", "outputs": [ { "data": { "text/plain": [ "
" ], - "image/svg+xml": "\n\n\n \n \n \n \n 2026-05-29T23:19:44.050719\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" + "image/svg+xml": "\n\n\n \n \n \n \n 2026-06-30T10:59:04.709978\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" + }, + "metadata": {}, + "output_type": "display_data", + "jetTransient": { + "display_id": null + } + } + ], + "execution_count": 14 + }, + { + "cell_type": "code", + "id": "4ff00f0b4983b55e", + "metadata": { + "ExecuteTime": { + "end_time": "2026-06-30T02:59:04.976760199Z", + "start_time": "2026-06-30T02:59:04.799699382Z" + } + }, + "source": [ + "# keys的形状:(n_test,n_train),每一行包含着相同的训练输入(例如,相同的键)\n", + "keys = x_train.repeat((n_test, 1))\n", + "# value的形状:(n_test,n_train)\n", + "values = y_train.repeat((n_test, 1))\n", + "y_hat = net(x_test, keys, values).unsqueeze(1).detach()\n", + "plot_kernel_reg(y_hat)" + ], + "outputs": [ + { + "data": { + "text/plain": [ + "
" + ], + "image/svg+xml": "\n\n\n \n \n \n \n 2026-06-30T10:59:04.872782\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" }, "metadata": {}, "output_type": "display_data", @@ -387,29 +428,26 @@ "execution_count": 15 }, { + "cell_type": "code", + "id": "97a690478767c3c3", "metadata": { "ExecuteTime": { - "end_time": "2026-05-29T15:19:44.196985129Z", - "start_time": "2026-05-29T15:19:44.082458369Z" + "end_time": "2026-06-30T02:59:05.088839245Z", + "start_time": "2026-06-30T02:59:04.978734877Z" } }, - "cell_type": "code", "source": [ - "# keys的形状:(n_test,n_train),每一行包含着相同的训练输入(例如,相同的键)\n", - "keys = x_train.repeat((n_test, 1))\n", - "# value的形状:(n_test,n_train)\n", - "values = y_train.repeat((n_test, 1))\n", - "y_hat = net(x_test, keys, values).unsqueeze(1).detach()\n", - "plot_kernel_reg(y_hat)" + "d2l.show_heatmaps(net.attention_weights.unsqueeze(0).unsqueeze(0),\n", + "xlabel='Sorted training inputs',\n", + "ylabel='Sorted testing inputs')" ], - "id": "4ff00f0b4983b55e", "outputs": [ { "data": { "text/plain": [ - "
" + "
" ], - "image/svg+xml": "\n\n\n \n \n \n \n 2026-05-29T23:19:44.158025\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" + "image/svg+xml": "\n\n\n \n \n \n \n 2026-06-30T10:59:05.049223\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" }, "metadata": {}, "output_type": "display_data", @@ -421,44 +459,14 @@ "execution_count": 16 }, { + "cell_type": "code", + "id": "8d03d0c7f735b755", "metadata": { "ExecuteTime": { - "end_time": "2026-05-29T15:19:44.303651198Z", - "start_time": "2026-05-29T15:19:44.199584310Z" + "end_time": "2026-06-30T02:59:05.142764987Z", + "start_time": "2026-06-30T02:59:05.093810141Z" } }, - "cell_type": "code", - "source": [ - "d2l.show_heatmaps(net.attention_weights.unsqueeze(0).unsqueeze(0),\n", - "xlabel='Sorted training inputs',\n", - "ylabel='Sorted testing inputs')" - ], - "id": "97a690478767c3c3", - "outputs": [ - { - "data": { - "text/plain": [ - "
" - ], - "image/svg+xml": "\n\n\n \n \n \n \n 2026-05-29T23:19:44.256961\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" - }, - "metadata": {}, - "output_type": "display_data", - "jetTransient": { - "display_id": null - } - } - ], - "execution_count": 17 - }, - { - "metadata": { - "ExecuteTime": { - "end_time": "2026-05-29T15:19:44.358514818Z", - "start_time": "2026-05-29T15:19:44.306386112Z" - } - }, - "cell_type": "code", "source": [ "def masked_softmax(X,valid_lens):\n", " # X:3D valid_len 1D or 2D\n", @@ -474,46 +482,48 @@ " X = d2l.sequence_mask(X.reshape(-1,shape[-1]),valid_lens,value=-1e6)\n", " return nn.functional.softmax(X.reshape(shape),dim=-1)" ], - "id": "8d03d0c7f735b755", "outputs": [], - "execution_count": 18 + "execution_count": 17 }, { + "cell_type": "code", + "id": "27f8c24ec572b267", "metadata": { "ExecuteTime": { - "end_time": "2026-05-29T15:19:44.420710191Z", - "start_time": "2026-05-29T15:19:44.359351187Z" + "end_time": "2026-06-30T02:59:05.208569156Z", + "start_time": "2026-06-30T02:59:05.144491160Z" } }, - "cell_type": "code", - "source": "masked_softmax(torch.rand(2,2,4),torch.tensor([2,3]))", - "id": "27f8c24ec572b267", + "source": [ + "masked_softmax(torch.rand(2,2,4),torch.tensor([2,3]))" + ], "outputs": [ { "data": { "text/plain": [ - "tensor([[[0.6937, 0.3063, 0.0000, 0.0000],\n", - " [0.6165, 0.3835, 0.0000, 0.0000]],\n", + "tensor([[[0.5584, 0.4416, 0.0000, 0.0000],\n", + " [0.6205, 0.3795, 0.0000, 0.0000]],\n", "\n", - " [[0.3526, 0.3291, 0.3183, 0.0000],\n", - " [0.2735, 0.2359, 0.4906, 0.0000]]])" + " [[0.4421, 0.3398, 0.2181, 0.0000],\n", + " [0.3929, 0.3625, 0.2446, 0.0000]]])" ] }, - "execution_count": 19, + "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], - "execution_count": 19 + "execution_count": 18 }, { + "cell_type": "code", + "id": "c7e5254395b14b1", "metadata": { "ExecuteTime": { - "end_time": "2026-05-29T15:19:44.470391356Z", - "start_time": "2026-05-29T15:19:44.422576814Z" + "end_time": "2026-06-30T02:59:05.262719443Z", + "start_time": "2026-06-30T02:59:05.210360072Z" } }, - "cell_type": "code", "source": [ "class AdditiveAttention(nn.Module):\n", " def __init__(self,key_size,query_size,num_hiddens,dropout,**kwargs):\n", @@ -535,18 +545,18 @@ "\n", " return torch.bmm(self.dropout(self.attention_weights), value)" ], - "id": "c7e5254395b14b1", "outputs": [], - "execution_count": 20 + "execution_count": 19 }, { + "cell_type": "code", + "id": "77e325a032e70d98", "metadata": { "ExecuteTime": { - "end_time": "2026-05-29T15:19:44.536311739Z", - "start_time": "2026-05-29T15:19:44.471725853Z" + "end_time": "2026-06-30T02:59:05.332247710Z", + "start_time": "2026-06-30T02:59:05.264012533Z" } }, - "cell_type": "code", "source": [ "queries,keys = torch.normal(0,1,(2,1,20)) , torch.ones((2,10,2))\n", "values = torch.arange(40,dtype=torch.float32).reshape(1,10,4).repeat(2,1,1)\n", @@ -556,7 +566,6 @@ "attention.eval()\n", "attention(queries,keys,values,valid_lens)" ], - "id": "77e325a032e70d98", "outputs": [ { "data": { @@ -566,33 +575,33 @@ " [[10.0000, 11.0000, 12.0000, 13.0000]]], grad_fn=)" ] }, - "execution_count": 21, + "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], - "execution_count": 21 + "execution_count": 20 }, { + "cell_type": "code", + "id": "93ceedfd6e03982f", "metadata": { "ExecuteTime": { - "end_time": "2026-05-29T15:19:44.618227172Z", - "start_time": "2026-05-29T15:19:44.538570102Z" + "end_time": "2026-06-30T02:59:05.426180301Z", + "start_time": "2026-06-30T02:59:05.338806296Z" } }, - "cell_type": "code", "source": [ "d2l.show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),\n", " xlabel='Keys', ylabel='Queries')" ], - "id": "93ceedfd6e03982f", "outputs": [ { "data": { "text/plain": [ "
" ], - "image/svg+xml": "\n\n\n \n \n \n \n 2026-05-29T23:19:44.586919\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" + "image/svg+xml": "\n\n\n \n \n \n \n 2026-06-30T10:59:05.393045\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" }, "metadata": {}, "output_type": "display_data", @@ -601,16 +610,17 @@ } } ], - "execution_count": 22 + "execution_count": 21 }, { + "cell_type": "code", + "id": "8453c76623a5b435", "metadata": { "ExecuteTime": { - "end_time": "2026-05-29T15:19:44.668038302Z", - "start_time": "2026-05-29T15:19:44.620075434Z" + "end_time": "2026-06-30T02:59:05.479115341Z", + "start_time": "2026-06-30T02:59:05.429811516Z" } }, - "cell_type": "code", "source": [ "import math\n", "class DotProductAttention(nn.Module):\n", @@ -630,25 +640,24 @@ "\n", " return torch.bmm(self.dropout(self.attention_weights), values)" ], - "id": "8453c76623a5b435", "outputs": [], - "execution_count": 23 + "execution_count": 22 }, { + "cell_type": "code", + "id": "8af6b28944977a62", "metadata": { "ExecuteTime": { - "end_time": "2026-05-29T15:19:44.754268677Z", - "start_time": "2026-05-29T15:19:44.668929892Z" + "end_time": "2026-06-30T02:59:05.547775309Z", + "start_time": "2026-06-30T02:59:05.480327680Z" } }, - "cell_type": "code", "source": [ "queries = torch.normal(0, 1, (2, 1, 2))\n", "attention = DotProductAttention(dropout=0.5)\n", "attention.eval()\n", "attention(queries, keys, values, valid_lens)" ], - "id": "8af6b28944977a62", "outputs": [ { "data": { @@ -658,33 +667,33 @@ " [[10.0000, 11.0000, 12.0000, 13.0000]]])" ] }, - "execution_count": 24, + "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], - "execution_count": 24 + "execution_count": 23 }, { + "cell_type": "code", + "id": "28f7e16771076e33", "metadata": { "ExecuteTime": { - "end_time": "2026-05-29T15:19:44.822905172Z", - "start_time": "2026-05-29T15:19:44.756082157Z" + "end_time": "2026-06-30T02:59:05.663459211Z", + "start_time": "2026-06-30T02:59:05.549703930Z" } }, - "cell_type": "code", "source": [ "d2l.show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),\n", "xlabel='Keys', ylabel='Queries')" ], - "id": "28f7e16771076e33", "outputs": [ { "data": { "text/plain": [ "
" ], - "image/svg+xml": "\n\n\n \n \n \n \n 2026-05-29T23:19:44.802392\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" + "image/svg+xml": "\n\n\n \n \n \n \n 2026-06-30T10:59:05.624810\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" }, "metadata": {}, "output_type": "display_data", @@ -693,16 +702,17 @@ } } ], - "execution_count": 25 + "execution_count": 24 }, { + "cell_type": "code", + "id": "33507bd2e1917a29", "metadata": { "ExecuteTime": { - "end_time": "2026-05-29T15:19:44.874393721Z", - "start_time": "2026-05-29T15:19:44.824430337Z" + "end_time": "2026-06-30T02:59:05.730333551Z", + "start_time": "2026-06-30T02:59:05.680995441Z" } }, - "cell_type": "code", "source": [ "class AttentionDecoder(d2l.Decoder):\n", " def __init__(self,**kwargs):\n", @@ -711,18 +721,18 @@ " def attention_weight(self):\n", " raise NotImplementedError" ], - "id": "33507bd2e1917a29", "outputs": [], - "execution_count": 26 + "execution_count": 25 }, { + "cell_type": "code", + "id": "a8a497c9041910a1", "metadata": { "ExecuteTime": { - "end_time": "2026-05-29T15:19:44.926627628Z", - "start_time": "2026-05-29T15:19:44.875567483Z" + "end_time": "2026-06-30T02:59:05.781440228Z", + "start_time": "2026-06-30T02:59:05.731688196Z" } }, - "cell_type": "code", "source": [ "class Seq2SeqAttentionDecoder(AttentionDecoder):\n", " def __init__(self,vocab_size,embed_size,num_hiddens,num_layers,dropout=0,**kwargs):\n", @@ -772,18 +782,18 @@ " def attention_weights(self):\n", " return self._attention_weights" ], - "id": "a8a497c9041910a1", "outputs": [], - "execution_count": 27 + "execution_count": 26 }, { + "cell_type": "code", + "id": "b2a6bc735743bd0f", "metadata": { "ExecuteTime": { - "end_time": "2026-05-29T15:19:44.996664037Z", - "start_time": "2026-05-29T15:19:44.927689328Z" + "end_time": "2026-06-30T02:59:05.842620719Z", + "start_time": "2026-06-30T02:59:05.782837732Z" } }, - "cell_type": "code", "source": [ "encoder = d2l.Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16,\n", "num_layers=2)\n", @@ -796,7 +806,6 @@ "output, state = decoder(X, state)\n", "output.shape, len(state), state[0].shape, len(state[1]), state[1][0].shape" ], - "id": "b2a6bc735743bd0f", "outputs": [ { "data": { @@ -804,34 +813,35 @@ "(torch.Size([4, 7, 10]), 3, torch.Size([4, 7, 16]), 2, torch.Size([4, 16]))" ] }, - "execution_count": 28, + "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], - "execution_count": 28 + "execution_count": 27 }, { - "metadata": { - "ExecuteTime": { - "end_time": "2026-05-29T15:19:45.047118759Z", - "start_time": "2026-05-29T15:19:44.998053092Z" - } - }, "cell_type": "code", - "source": "", "id": "4c7bf86095f15c98", - "outputs": [], - "execution_count": 28 - }, - { "metadata": { "ExecuteTime": { - "end_time": "2026-05-29T15:20:14.678359202Z", - "start_time": "2026-05-29T15:19:45.048442204Z" + "end_time": "2026-06-30T02:59:05.903307506Z", + "start_time": "2026-06-30T02:59:05.854726928Z" } }, + "source": [], + "outputs": [], + "execution_count": 27 + }, + { "cell_type": "code", + "id": "5f5cedf74b97bd12", + "metadata": { + "ExecuteTime": { + "end_time": "2026-06-30T03:22:39.836556973Z", + "start_time": "2026-06-30T02:59:05.904535279Z" + } + }, "source": [ "embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1\n", "batch_size, num_steps = 64, 10\n", @@ -844,13 +854,26 @@ "net = d2l.EncoderDecoder(encoder, decoder)\n", "d2l.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)" ], - "id": "5f5cedf74b97bd12", "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "loss 0.019, 17862.0 tokens/sec on cuda:0\n" + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001B[31m---------------------------------------------------------------------------\u001B[39m", + "\u001B[31mKeyboardInterrupt\u001B[39m Traceback (most recent call last)", + "\u001B[36mCell\u001B[39m\u001B[36m \u001B[39m\u001B[32mIn[28]\u001B[39m\u001B[32m, line 10\u001B[39m\n\u001B[32m 7\u001B[39m decoder = Seq2SeqAttentionDecoder(\n\u001B[32m 8\u001B[39m \u001B[38;5;28mlen\u001B[39m(tgt_vocab), embed_size, num_hiddens, num_layers, dropout)\n\u001B[32m 9\u001B[39m net = d2l.EncoderDecoder(encoder, decoder)\n\u001B[32m---> \u001B[39m\u001B[32m10\u001B[39m \u001B[43md2l\u001B[49m\u001B[43m.\u001B[49m\u001B[43mtrain_seq2seq\u001B[49m\u001B[43m(\u001B[49m\u001B[43mnet\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mtrain_iter\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mlr\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mnum_epochs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mtgt_vocab\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mdevice\u001B[49m\u001B[43m)\u001B[49m\n", + "\u001B[36mFile \u001B[39m\u001B[32m~/.conda/envs/nn/lib/python3.11/site-packages/d2l/torch.py:3421\u001B[39m, in \u001B[36mtrain_seq2seq\u001B[39m\u001B[34m(net, data_iter, lr, num_epochs, tgt_vocab, device)\u001B[39m\n\u001B[32m 3418\u001B[39m bos = torch.tensor([tgt_vocab[\u001B[33m'\u001B[39m\u001B[33m\u001B[39m\u001B[33m'\u001B[39m]] * Y.shape[\u001B[32m0\u001B[39m],\n\u001B[32m 3419\u001B[39m device=device).reshape(-\u001B[32m1\u001B[39m, \u001B[32m1\u001B[39m)\n\u001B[32m 3420\u001B[39m dec_input = d2l.concat([bos, Y[:, :-\u001B[32m1\u001B[39m]], \u001B[32m1\u001B[39m) \u001B[38;5;66;03m# Teacher forcing\u001B[39;00m\n\u001B[32m-> \u001B[39m\u001B[32m3421\u001B[39m Y_hat, _ = \u001B[43mnet\u001B[49m\u001B[43m(\u001B[49m\u001B[43mX\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mdec_input\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mX_valid_len\u001B[49m\u001B[43m)\u001B[49m\n\u001B[32m 3422\u001B[39m l = loss(Y_hat, Y, Y_valid_len)\n\u001B[32m 3423\u001B[39m l.sum().backward() \u001B[38;5;66;03m# Make the loss scalar for `backward`\u001B[39;00m\n", + "\u001B[36mFile \u001B[39m\u001B[32m~/.conda/envs/nn/lib/python3.11/site-packages/torch/nn/modules/module.py:1776\u001B[39m, in \u001B[36mModule._wrapped_call_impl\u001B[39m\u001B[34m(self, *args, **kwargs)\u001B[39m\n\u001B[32m 1774\u001B[39m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mself\u001B[39m._compiled_call_impl(*args, **kwargs) \u001B[38;5;66;03m# type: ignore[misc]\u001B[39;00m\n\u001B[32m 1775\u001B[39m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[32m-> \u001B[39m\u001B[32m1776\u001B[39m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43mself\u001B[39;49m\u001B[43m.\u001B[49m\u001B[43m_call_impl\u001B[49m\u001B[43m(\u001B[49m\u001B[43m*\u001B[49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43m*\u001B[49m\u001B[43m*\u001B[49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n", + "\u001B[36mFile \u001B[39m\u001B[32m~/.conda/envs/nn/lib/python3.11/site-packages/torch/nn/modules/module.py:1787\u001B[39m, in \u001B[36mModule._call_impl\u001B[39m\u001B[34m(self, *args, **kwargs)\u001B[39m\n\u001B[32m 1782\u001B[39m \u001B[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001B[39;00m\n\u001B[32m 1783\u001B[39m \u001B[38;5;66;03m# this function, and just call forward.\u001B[39;00m\n\u001B[32m 1784\u001B[39m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m (\u001B[38;5;28mself\u001B[39m._backward_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m._backward_pre_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m._forward_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m._forward_pre_hooks\n\u001B[32m 1785\u001B[39m \u001B[38;5;129;01mor\u001B[39;00m _global_backward_pre_hooks \u001B[38;5;129;01mor\u001B[39;00m _global_backward_hooks\n\u001B[32m 1786\u001B[39m \u001B[38;5;129;01mor\u001B[39;00m _global_forward_hooks \u001B[38;5;129;01mor\u001B[39;00m _global_forward_pre_hooks):\n\u001B[32m-> \u001B[39m\u001B[32m1787\u001B[39m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mforward_call\u001B[49m\u001B[43m(\u001B[49m\u001B[43m*\u001B[49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43m*\u001B[49m\u001B[43m*\u001B[49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[32m 1789\u001B[39m result = \u001B[38;5;28;01mNone\u001B[39;00m\n\u001B[32m 1790\u001B[39m called_always_called_hooks = \u001B[38;5;28mset\u001B[39m()\n", + "\u001B[36mFile \u001B[39m\u001B[32m~/.conda/envs/nn/lib/python3.11/site-packages/d2l/torch.py:964\u001B[39m, in \u001B[36mEncoderDecoder.forward\u001B[39m\u001B[34m(self, enc_X, dec_X, *args)\u001B[39m\n\u001B[32m 962\u001B[39m dec_state = \u001B[38;5;28mself\u001B[39m.decoder.init_state(enc_all_outputs, *args)\n\u001B[32m 963\u001B[39m \u001B[38;5;66;03m# Return decoder output only\u001B[39;00m\n\u001B[32m--> \u001B[39m\u001B[32m964\u001B[39m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43mself\u001B[39;49m\u001B[43m.\u001B[49m\u001B[43mdecoder\u001B[49m\u001B[43m(\u001B[49m\u001B[43mdec_X\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mdec_state\u001B[49m\u001B[43m)\u001B[49m\n", + "\u001B[36mFile \u001B[39m\u001B[32m~/.conda/envs/nn/lib/python3.11/site-packages/torch/nn/modules/module.py:1776\u001B[39m, in \u001B[36mModule._wrapped_call_impl\u001B[39m\u001B[34m(self, *args, **kwargs)\u001B[39m\n\u001B[32m 1774\u001B[39m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mself\u001B[39m._compiled_call_impl(*args, **kwargs) \u001B[38;5;66;03m# type: ignore[misc]\u001B[39;00m\n\u001B[32m 1775\u001B[39m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[32m-> \u001B[39m\u001B[32m1776\u001B[39m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43mself\u001B[39;49m\u001B[43m.\u001B[49m\u001B[43m_call_impl\u001B[49m\u001B[43m(\u001B[49m\u001B[43m*\u001B[49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43m*\u001B[49m\u001B[43m*\u001B[49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n", + "\u001B[36mFile \u001B[39m\u001B[32m~/.conda/envs/nn/lib/python3.11/site-packages/torch/nn/modules/module.py:1787\u001B[39m, in \u001B[36mModule._call_impl\u001B[39m\u001B[34m(self, *args, **kwargs)\u001B[39m\n\u001B[32m 1782\u001B[39m \u001B[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001B[39;00m\n\u001B[32m 1783\u001B[39m \u001B[38;5;66;03m# this function, and just call forward.\u001B[39;00m\n\u001B[32m 1784\u001B[39m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m (\u001B[38;5;28mself\u001B[39m._backward_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m._backward_pre_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m._forward_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m._forward_pre_hooks\n\u001B[32m 1785\u001B[39m \u001B[38;5;129;01mor\u001B[39;00m _global_backward_pre_hooks \u001B[38;5;129;01mor\u001B[39;00m _global_backward_hooks\n\u001B[32m 1786\u001B[39m \u001B[38;5;129;01mor\u001B[39;00m _global_forward_hooks \u001B[38;5;129;01mor\u001B[39;00m _global_forward_pre_hooks):\n\u001B[32m-> \u001B[39m\u001B[32m1787\u001B[39m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mforward_call\u001B[49m\u001B[43m(\u001B[49m\u001B[43m*\u001B[49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43m*\u001B[49m\u001B[43m*\u001B[49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[32m 1789\u001B[39m result = \u001B[38;5;28;01mNone\u001B[39;00m\n\u001B[32m 1790\u001B[39m called_always_called_hooks = \u001B[38;5;28mset\u001B[39m()\n", + "\u001B[36mCell\u001B[39m\u001B[36m \u001B[39m\u001B[32mIn[26]\u001B[39m\u001B[32m, line 37\u001B[39m, in \u001B[36mSeq2SeqAttentionDecoder.forward\u001B[39m\u001B[34m(self, X, state)\u001B[39m\n\u001B[32m 35\u001B[39m x = torch.cat((context, torch.unsqueeze(x, dim=\u001B[32m1\u001B[39m)), dim=-\u001B[32m1\u001B[39m)\n\u001B[32m 36\u001B[39m \u001B[38;5;66;03m# 将x变形为(1,batch_size,embed_size+num_hiddens)\u001B[39;00m\n\u001B[32m---> \u001B[39m\u001B[32m37\u001B[39m out, hidden_state = \u001B[38;5;28;43mself\u001B[39;49m\u001B[43m.\u001B[49m\u001B[43mrnn\u001B[49m\u001B[43m(\u001B[49m\u001B[43mx\u001B[49m\u001B[43m.\u001B[49m\u001B[43mpermute\u001B[49m\u001B[43m(\u001B[49m\u001B[32;43m1\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[32;43m0\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[32;43m2\u001B[39;49m\u001B[43m)\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mhidden_state\u001B[49m\u001B[43m)\u001B[49m\n\u001B[32m 38\u001B[39m outputs.append(out)\n\u001B[32m 39\u001B[39m \u001B[38;5;28mself\u001B[39m._attention_weights.append(\u001B[38;5;28mself\u001B[39m.attention.attention_weights)\n", + "\u001B[36mFile \u001B[39m\u001B[32m~/.conda/envs/nn/lib/python3.11/site-packages/torch/nn/modules/module.py:1776\u001B[39m, in \u001B[36mModule._wrapped_call_impl\u001B[39m\u001B[34m(self, *args, **kwargs)\u001B[39m\n\u001B[32m 1774\u001B[39m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mself\u001B[39m._compiled_call_impl(*args, **kwargs) \u001B[38;5;66;03m# type: ignore[misc]\u001B[39;00m\n\u001B[32m 1775\u001B[39m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[32m-> \u001B[39m\u001B[32m1776\u001B[39m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43mself\u001B[39;49m\u001B[43m.\u001B[49m\u001B[43m_call_impl\u001B[49m\u001B[43m(\u001B[49m\u001B[43m*\u001B[49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43m*\u001B[49m\u001B[43m*\u001B[49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n", + "\u001B[36mFile \u001B[39m\u001B[32m~/.conda/envs/nn/lib/python3.11/site-packages/torch/nn/modules/module.py:1787\u001B[39m, in \u001B[36mModule._call_impl\u001B[39m\u001B[34m(self, *args, **kwargs)\u001B[39m\n\u001B[32m 1782\u001B[39m \u001B[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001B[39;00m\n\u001B[32m 1783\u001B[39m \u001B[38;5;66;03m# this function, and just call forward.\u001B[39;00m\n\u001B[32m 1784\u001B[39m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m (\u001B[38;5;28mself\u001B[39m._backward_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m._backward_pre_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m._forward_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m._forward_pre_hooks\n\u001B[32m 1785\u001B[39m \u001B[38;5;129;01mor\u001B[39;00m _global_backward_pre_hooks \u001B[38;5;129;01mor\u001B[39;00m _global_backward_hooks\n\u001B[32m 1786\u001B[39m \u001B[38;5;129;01mor\u001B[39;00m _global_forward_hooks \u001B[38;5;129;01mor\u001B[39;00m _global_forward_pre_hooks):\n\u001B[32m-> \u001B[39m\u001B[32m1787\u001B[39m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mforward_call\u001B[49m\u001B[43m(\u001B[49m\u001B[43m*\u001B[49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43m*\u001B[49m\u001B[43m*\u001B[49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[32m 1789\u001B[39m result = \u001B[38;5;28;01mNone\u001B[39;00m\n\u001B[32m 1790\u001B[39m called_always_called_hooks = \u001B[38;5;28mset\u001B[39m()\n", + "\u001B[36mFile \u001B[39m\u001B[32m~/.conda/envs/nn/lib/python3.11/site-packages/torch/nn/modules/rnn.py:1415\u001B[39m, in \u001B[36mGRU.forward\u001B[39m\u001B[34m(self, input, hx)\u001B[39m\n\u001B[32m 1413\u001B[39m \u001B[38;5;28mself\u001B[39m.check_forward_args(\u001B[38;5;28minput\u001B[39m, hx, batch_sizes)\n\u001B[32m 1414\u001B[39m \u001B[38;5;28;01mif\u001B[39;00m batch_sizes \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m:\n\u001B[32m-> \u001B[39m\u001B[32m1415\u001B[39m result = \u001B[43m_VF\u001B[49m\u001B[43m.\u001B[49m\u001B[43mgru\u001B[49m\u001B[43m(\u001B[49m\n\u001B[32m 1416\u001B[39m \u001B[43m \u001B[49m\u001B[38;5;28;43minput\u001B[39;49m\u001B[43m,\u001B[49m\n\u001B[32m 1417\u001B[39m \u001B[43m \u001B[49m\u001B[43mhx\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 1418\u001B[39m \u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[43m.\u001B[49m\u001B[43m_flat_weights\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;66;43;03m# type: ignore[arg-type]\u001B[39;49;00m\n\u001B[32m 1419\u001B[39m \u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[43m.\u001B[49m\u001B[43mbias\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 1420\u001B[39m \u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[43m.\u001B[49m\u001B[43mnum_layers\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 1421\u001B[39m \u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[43m.\u001B[49m\u001B[43mdropout\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 1422\u001B[39m \u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[43m.\u001B[49m\u001B[43mtraining\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 1423\u001B[39m \u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[43m.\u001B[49m\u001B[43mbidirectional\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 1424\u001B[39m \u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[43m.\u001B[49m\u001B[43mbatch_first\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 1425\u001B[39m \u001B[43m \u001B[49m\u001B[43m)\u001B[49m\n\u001B[32m 1426\u001B[39m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[32m 1427\u001B[39m result = _VF.gru(\n\u001B[32m 1428\u001B[39m \u001B[38;5;28minput\u001B[39m,\n\u001B[32m 1429\u001B[39m batch_sizes,\n\u001B[32m (...)\u001B[39m\u001B[32m 1436\u001B[39m \u001B[38;5;28mself\u001B[39m.bidirectional,\n\u001B[32m 1437\u001B[39m )\n", + "\u001B[31mKeyboardInterrupt\u001B[39m: " ] }, { @@ -858,7 +881,7 @@ "text/plain": [ "
" ], - "image/svg+xml": "\n\n\n \n \n \n \n 2026-05-29T23:20:14.649299\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" + "image/svg+xml": "\n\n\n \n \n \n \n 2026-06-30T11:22:39.515794\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" }, "metadata": {}, "output_type": "display_data", @@ -867,16 +890,17 @@ } } ], - "execution_count": 29 + "execution_count": 28 }, { + "cell_type": "code", + "id": "ddaed8bd284ee01b", "metadata": { "ExecuteTime": { - "end_time": "2026-05-29T15:20:14.794039857Z", - "start_time": "2026-05-29T15:20:14.727727180Z" + "end_time": "2026-06-30T03:22:50.427691735Z", + "start_time": "2026-06-30T03:22:50.314718050Z" } }, - "cell_type": "code", "source": [ "engs = ['go .', \"i lost .\", 'he\\'s calm .', 'i\\'m home .']\n", "fras = ['va !', 'j\\'ai perdu .', 'il est calme .', 'je suis chez moi .']\n", @@ -886,7 +910,6 @@ " print(f'{eng} => {translation}, ',\n", " f'bleu {d2l.bleu(translation, fra, k=2):.3f}')" ], - "id": "ddaed8bd284ee01b", "outputs": [ { "name": "stdout", @@ -894,21 +917,22 @@ "text": [ "go . => va !, bleu 1.000\n", "i lost . => j'ai perdu ., bleu 1.000\n", - "he's calm . => il est ., bleu 0.658\n", - "i'm home . => je suis chez moi ., bleu 1.000\n" + "he's calm . => il est riche ., bleu 0.658\n", + "i'm home . => je suis calme ., bleu 0.512\n" ] } ], - "execution_count": 30 + "execution_count": 29 }, { + "cell_type": "code", + "id": "13d19f8b5048c74d", "metadata": { "ExecuteTime": { - "end_time": "2026-05-29T15:20:38.342057516Z", - "start_time": "2026-05-29T15:20:38.288567783Z" + "end_time": "2026-06-30T03:22:57.445168273Z", + "start_time": "2026-06-30T03:22:57.376381899Z" } }, - "cell_type": "code", "source": [ "class MultiHeadAttention(nn.Module):\n", " def __init__(self, key_size, query_size, value_size, num_hiddens,\n", @@ -955,25 +979,24 @@ " X = X.permute(0, 2, 1, 3)\n", " return X.reshape(X.shape[0], X.shape[1], -1)\n" ], - "id": "13d19f8b5048c74d", "outputs": [], - "execution_count": 32 + "execution_count": 30 }, { + "cell_type": "code", + "id": "b25c661a511d7763", "metadata": { "ExecuteTime": { - "end_time": "2026-05-29T15:20:39.251763154Z", - "start_time": "2026-05-29T15:20:39.201715213Z" + "end_time": "2026-06-30T03:22:58.271805155Z", + "start_time": "2026-06-30T03:22:58.200822898Z" } }, - "cell_type": "code", "source": [ "num_hiddens, num_heads = 100, 5\n", "attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,\n", "num_hiddens, num_heads, 0.5)\n", "attention.eval()" ], - "id": "b25c661a511d7763", "outputs": [ { "data": { @@ -989,21 +1012,22 @@ ")" ] }, - "execution_count": 33, + "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], - "execution_count": 33 + "execution_count": 31 }, { + "cell_type": "code", + "id": "49877c929741ec94", "metadata": { "ExecuteTime": { - "end_time": "2026-05-29T15:20:40.455647106Z", - "start_time": "2026-05-29T15:20:40.307682855Z" + "end_time": "2026-06-30T03:22:58.957157613Z", + "start_time": "2026-06-30T03:22:58.850698625Z" } }, - "cell_type": "code", "source": [ "batch_size, num_queries = 2, 4\n", "num_kvpairs, valid_lens = 6, torch.tensor([3, 2])\n", @@ -1011,7 +1035,6 @@ "Y = torch.ones((batch_size, num_kvpairs, num_hiddens))\n", "attention(X, Y, Y, valid_lens).shape" ], - "id": "49877c929741ec94", "outputs": [ { "data": { @@ -1019,25 +1042,143 @@ "torch.Size([2, 4, 100])" ] }, - "execution_count": 34, + "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], - "execution_count": 34 + "execution_count": 32 + }, + { + "cell_type": "code", + "id": "b8a10576c7789750", + "metadata": { + "ExecuteTime": { + "end_time": "2026-06-30T03:24:09.856237679Z", + "start_time": "2026-06-30T03:24:09.793761501Z" + } + }, + "source": [ + "class PositionalEncoding(nn.Module):\n", + " def __init__(self,num_hiddens,dropout,max_len=1000):\n", + " super().__init__()\n", + " self.dropout = nn.Dropout(dropout)\n", + " self.dropout = nn.Dropout(dropout)\n", + " X = torch.arange(max_len, dtype=torch.float32).reshape(\n", + " -1, 1) / torch.pow(10000, torch.arange(\n", + " 0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)\n", + " self.P = torch.zeros((1,max_len,num_hiddens))\n", + " self.P[:,:,0::2] = torch.sin(X)\n", + " self.P[:,:,1::2] = torch.cos(X)\n", + " def forward(self,X):\n", + " X = X + self.P[:,:X.shape[1],:].to(X.device)\n", + " return self.dropout(X)" + ], + "outputs": [], + "execution_count": 35 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2026-06-30T03:24:10.924435124Z", + "start_time": "2026-06-30T03:24:10.759631840Z" + } + }, + "cell_type": "code", + "source": [ + "encoding_dim, num_steps = 32, 60\n", + "pos_encoding = PositionalEncoding(encoding_dim, 0)\n", + "pos_encoding.eval()\n", + "X = pos_encoding(torch.zeros((1, num_steps, encoding_dim)))\n", + "P = pos_encoding.P[:, :X.shape[1], :]\n", + "d2l.plot(torch.arange(num_steps), P[0, :, 6:10].T, xlabel='Row (position)',\n", + " figsize=(6, 2.5), legend=[\"Col %d\" % d for d in torch.arange(6, 10)])" + ], + "id": "fe6cd8d94205977f", + "outputs": [ + { + "data": { + "text/plain": [ + "
" + ], + "image/svg+xml": "\n\n\n \n \n \n \n 2026-06-30T11:24:10.850286\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" + }, + "metadata": {}, + "output_type": "display_data", + "jetTransient": { + "display_id": null + } + } + ], + "execution_count": 36 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2026-06-30T03:29:41.868213373Z", + "start_time": "2026-06-30T03:29:41.818146905Z" + } + }, + "cell_type": "code", + "source": [ + "class PositionWiseFFN(nn.Module):\n", + " def __init__(self,ffn_num_input,ffn_num_hiddens,ffn_num_outputs,**kwargs):\n", + " super().__init__(**kwargs)\n", + " self.dense1 = nn.Linear(ffn_num_input,ffn_num_hiddens)\n", + " self.relu = nn.ReLU()\n", + " self.dense2 = nn.Linear(ffn_num_hiddens,ffn_num_outputs)\n", + " def forward(self,X):\n", + " return self.dense2(self.relu(self.dense1(X)))\n" + ], + "id": "b764f19d408049b", + "outputs": [], + "execution_count": 38 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2026-06-30T03:35:03.713494687Z", + "start_time": "2026-06-30T03:35:03.630606833Z" + } + }, + "cell_type": "code", + "source": [ + "class AddNorm(nn.Module):\n", + " def __init__(self,normalized_shape,dropout,**kwargs):\n", + " super().__init__()\n", + " self.dropout=nn.Dropout()\n", + " self.ln = nn.LayerNorm(normalized_shape)\n", + " def forward(self,X,Y):\n", + " return self.ln(self.dropout(Y)+X)\n" + ], + "id": "377f4d9eea428068", + "outputs": [], + "execution_count": 39 }, { "metadata": {}, "cell_type": "code", "outputs": [], "execution_count": null, - "source": "", - "id": "b8a10576c7789750" + "source": [ + "class EncoderBlock(nn.Moudle):\n", + " def __init(self,key_size,query_size,value_size,num_hiddens,norm_shape,ffn_num_input,ffn_num_hiddens,num_heads,dropout,use_bias=False,**kwargs):\n", + " super(EncoderBlock,self).__init__(**kwargs)\n", + " self.attention = MultiHeadAttention(key_size,query_size,value_size,num_hiddens,num_heads,dropout,bias=use_bias)\n", + " self.addnorm1 = AddNorm(norm_shape,dropout)\n", + " self.ffn = PositionWiseFFN(\n", + " ffn_num_input, ffn_num_hiddens, num_hiddens)\n", + " self.addnorm2 = AddNorm(norm_shape, dropout)\n", + " def forward(self, X, valid_lens):\n", + " Y = self.addnorm1(X, self.attention(X, X, X, valid_lens))\n", + " return self.addnorm2(Y, self.ffn(Y))" + ], + "id": "c54dc4e20e9ffb90" } ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" },