nn/chapter10.ipynb

1201 lines
307 KiB
Text
Raw Permalink Normal View History

2026-06-28 14:39:18 +00:00
{
"cells": [
{
"cell_type": "code",
"id": "initial_id",
"metadata": {
"ExecuteTime": {
2026-06-30 03:57:35 +00:00
"end_time": "2026-06-30T02:59:03.101034703Z",
"start_time": "2026-06-30T02:58:59.676757575Z"
2026-06-28 14:39:18 +00:00
}
},
"source": [
"import torch\n",
"from d2l import torch as d2l\n"
],
2026-06-30 03:57:35 +00:00
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/yukun/.conda/envs/nn/lib/python3.11/site-packages/torch/cuda/__init__.py:1007: UserWarning: Can't initialize NVML\n",
" raw_cnt = _raw_device_count_nvml()\n"
]
}
],
"execution_count": 1
2026-06-28 14:39:18 +00:00
},
{
2026-06-30 03:57:35 +00:00
"cell_type": "code",
"id": "a0e3f725b7764f08",
2026-06-28 14:39:18 +00:00
"metadata": {
"ExecuteTime": {
2026-06-30 03:57:35 +00:00
"end_time": "2026-06-30T02:59:03.137220451Z",
"start_time": "2026-06-30T02:59:03.118596105Z"
2026-06-28 14:39:18 +00:00
}
},
"source": [
"def show_heatmaps(matrices,xlabel,ylabel,titles=None,figsize=(2.5,2.5),cmap='Reds'):\n",
"    \"\"\"Draw a grid of heatmaps, one subplot per matrix.\n",
"\n",
"    Args:\n",
"        matrices: tensor indexed as (grid row, grid column, ...); every\n",
"            matrices[i, j] is rendered with imshow in subplot (i, j).\n",
"        xlabel: x-axis label, set on the bottom row of subplots only.\n",
"        ylabel: y-axis label, set on the first column of subplots only.\n",
"        titles: optional sequence with one title per grid column.\n",
"        figsize: figure size handed to plt.subplots.\n",
"        cmap: matplotlib colormap name.\n",
"    \"\"\"\n",
"    d2l.use_svg_display()\n",
"    num_rows, num_cols = matrices.shape[0], matrices.shape[1]\n",
"    fig, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize,\n",
"                                 sharex=True, squeeze=False)\n",
"    for i, (row_axes, row_matrices) in enumerate(zip(axes, matrices)):\n",
"        for j, (ax, matrix) in enumerate(zip(row_axes, row_matrices)):\n",
"            # .cpu() is a no-op for CPU tensors and lets GPU tensors be plotted too\n",
"            pcm = ax.imshow(matrix.detach().cpu().numpy(), cmap=cmap)\n",
"            if i == num_rows - 1:\n",
"                ax.set_xlabel(xlabel)\n",
"            if j == 0:\n",
"                ax.set_ylabel(ylabel)\n",
"            if titles:\n",
"                ax.set_title(titles[j])\n",
"    # one colorbar for the whole grid, taken from the last image drawn\n",
"    fig.colorbar(pcm, ax=axes, shrink=0.6)"
],
"outputs": [],
2026-06-30 03:57:35 +00:00
"execution_count": 2
2026-06-28 14:39:18 +00:00
},
{
2026-06-30 03:57:35 +00:00
"cell_type": "code",
"id": "bb798de34fc4d0fa",
2026-06-28 14:39:18 +00:00
"metadata": {
"ExecuteTime": {
2026-06-30 03:57:35 +00:00
"end_time": "2026-06-30T02:59:03.336399820Z",
"start_time": "2026-06-30T02:59:03.139187972Z"
2026-06-28 14:39:18 +00:00
}
},
"source": [
"# 10x10 identity matrix wrapped as a 1x1 grid: query i puts all of its weight on key i\n",
"attention_weights = torch.eye(10)[None, None]\n",
"show_heatmaps(attention_weights, 'Keys', 'Queries')"
],
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 250x250 with 2 Axes>"
],
2026-06-30 03:57:35 +00:00
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"193.43925pt\" height=\"156.35625pt\" viewBox=\"0 0 193.43925 156.35625\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n <metadata>\n <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n <cc:Work>\n <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n <dc:date>2026-06-30T10:59:03.270032</dc:date>\n <dc:format>image/svg+xml</dc:format>\n <dc:creator>\n <cc:Agent>\n <dc:title>Matplotlib v3.7.2, https://matplotlib.org/</dc:title>\n </cc:Agent>\n </dc:creator>\n </cc:Work>\n </rdf:RDF>\n </metadata>\n <defs>\n <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n </defs>\n <g id=\"figure_1\">\n <g id=\"patch_1\">\n <path d=\"M -0 156.35625 \nL 193.43925 156.35625 \nL 193.43925 0 \nL -0 0 \nz\n\"/>\n </g>\n <g id=\"axes_1\">\n <g id=\"patch_2\">\n <path d=\"M 34.240625 118.8 \nL 145.840625 118.8 \nL 145.840625 7.2 \nL 34.240625 7.2 \nz\n\"/>\n </g>\n <g clip-path=\"url(#pa25cfb4ac9)\">\n <image 
xlink:href=\"data:image/png;base64,\niVBORw0KGgoAAAANSUhEUgAAAJsAAACbCAYAAAB1YemMAAAB+ElEQVR4nO3csY3CQBBAUUBkXAn0ByVRIAWQEPtKuDOL/srye7nlDb4mGs1xeb+WA/zD/XId+v70pXfAn8RGRmxkxEZGbGTERkZsZMRGRmxkxEZGbGTERkZsZMRGRmxkzrMfQGtkJ+3xfg7922QjIzYyYiMjNjJiIyM2MmIjIzYyYiMjNjJiIyM2MmIjIzYyVow2ZvRs1eia0AiTjYzYyIiNjNjIiI2M2MiIjYzYyIiNjNjIiI2M2MiIjYzYyIiNjH22CWaerZrJZCMjNjJiIyM2MmIjIzYyYiMjNjJiIyM2MmIjIzYyYiMjNjJWjD6w5bNVM5lsZMRGRmxkxEZGbGTERkZsZMRGRmxkxEZGbGTERkZsZMRGRmxkdrvPttezVTOZbGTERkZsZMRGRmxkxEZGbGTERkZsZMRGRmxkxEZGbGTERmazK0bOVm2PyUZGbGTERkZsZMRGRmxkxEZGbGTERkZsZMRGRmxkxEZGbGTERmbqPpuzVftispERGxmxkREbGbGRERsZsZERGxmxkREbGbGRERsZsZERG5mhFSNnq1jDZCMjNjJiIyM2MmIjIzYyYiMjNjJiIyM2MmIjIzYyYiMjNjJiI3O8HX6WTz+2j8YaJhsZsZERGxmxkREbGbGRERsZsZERGxmxkREbGbGRERsZsZH5BQ6wIZohMfZYAAAAAElFTkSuQmCC\" id=\"image0da7a678b4\" transform=\"scale(1 -1) translate(0 -111.6)\" x=\"34.240625\" y=\"-7.2\" width=\"111.6\" height=\"111.6\"/>\n </g>\n <g id=\"matplotlib.axis_1\">\n <g id=\"xtick_1\">\n <g id=\"line2d_1\">\n <defs>\n <path id=\"m132af9a215\" d=\"M 0 0 \nL 0 3.5 \n\" style=\"stroke: #ffffff; stroke-width: 0.8\"/>\n </defs>\n <g>\n <use xlink:href=\"#m132af9a215\" x=\"39.820625\" y=\"118.8\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_1\">\n <!-- 0 -->\n <g style=\"fill: #ffffff\" transform=\"translate(36.639375 133.398438) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-30\" d=\"M 2034 4250 \nQ 1547 4250 1301 3770 \nQ 1056 3291 1056 2328 \nQ 1056 1369 1301 889 \nQ 1547 409 2034 409 \nQ 2525 409 2770 889 \nQ 3016 1369 3016 2328 \nQ 3016 3291 2770 3770 \nQ 2525 4250 2034 4250 \nz\nM 2034 4750 \nQ 2819 4750 3233 4129 \nQ 3647 3509 3647 2328 \nQ 3647 1150 3233 529 \nQ 2819 -91 2034 -91 \nQ 1250 -91 836 529 \nQ 422 1150 422 2328 \nQ 422 3509 836 4129 \nQ 1250 4750 2034 4750 \nz\n\" transform=\"scale(0.015625)\"/>\n </defs>\n <use xlink:href=\"#DejaVuSans-30\"/>\n </g>\n </g>\n </g>\n <g id=\"xtick_2\">\n <g id=\"line2d_2\">\n <g>\n <use 
xlink:href=\"#m132af9a215\" x=\"95.620625\" y=\"118.8\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_2\">\n <!-- 5 -->\n <g style=\"fill: #ffffff\" transform=\"translate(92.439375 133.398438) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-35\" d=\"M 691 4666 \nL 3169 4666 \nL 3169 4134 \nL 1269 4134 \nL 1269 2991 \nQ 1406 3038 1543 3061 \nQ 1681 3084 1819 3084 \nQ 2600 3084 3056 2656 \nQ 3513 2228 3513 1497 \nQ 3513 744 3044 326 \nQ 2575 -91 1722 -91 \nQ 1428 -91 1123 -41 \nQ 819 9 494 109 \nL 494 744 \nQ 775 591 1075 516 \nQ
2026-06-28 14:39:18 +00:00
},
"metadata": {},
"output_type": "display_data",
"jetTransient": {
"display_id": null
}
}
],
2026-06-30 03:57:35 +00:00
"execution_count": 3
2026-06-28 14:39:18 +00:00
},
{
2026-06-30 03:57:35 +00:00
"cell_type": "code",
"id": "a4070c75847fb887",
2026-06-28 14:39:18 +00:00
"metadata": {
"ExecuteTime": {
2026-06-30 03:57:35 +00:00
"end_time": "2026-06-30T02:59:03.393286045Z",
"start_time": "2026-06-30T02:59:03.339609950Z"
2026-06-28 14:39:18 +00:00
}
},
"source": [
"torch.manual_seed(0)  # fix the RNG so the sampled data is the same on every run\n",
"n_train = 50  # number of training samples\n",
"# training inputs: uniform samples from [0, 5), sorted in ascending order\n",
"x_train, _ = torch.sort(torch.rand(n_train) * 5)"
],
"outputs": [],
2026-06-30 03:57:35 +00:00
"execution_count": 4
2026-06-28 14:39:18 +00:00
},
{
2026-06-30 03:57:35 +00:00
"cell_type": "code",
"id": "2aca6952876d9cf5",
2026-06-28 14:39:18 +00:00
"metadata": {
"ExecuteTime": {
2026-06-30 03:57:35 +00:00
"end_time": "2026-06-30T02:59:03.465395938Z",
"start_time": "2026-06-30T02:59:03.397781049Z"
2026-06-28 14:39:18 +00:00
}
},
"source": [
"def f(x):\n",
"    \"\"\"Ground-truth function 2*sin(x) + x**0.8, applied element-wise to tensor `x`.\"\"\"\n",
"    sine_term = torch.sin(x) * 2\n",
"    power_term = torch.pow(x, 0.8)\n",
"    return sine_term + power_term"
],
"outputs": [],
2026-06-30 03:57:35 +00:00
"execution_count": 5
2026-06-28 14:39:18 +00:00
},
{
2026-06-30 03:57:35 +00:00
"cell_type": "code",
"id": "ea7b94585dcde934",
2026-06-28 14:39:18 +00:00
"metadata": {
"ExecuteTime": {
2026-06-30 03:57:35 +00:00
"end_time": "2026-06-30T02:59:03.554188488Z",
"start_time": "2026-06-30T02:59:03.467856816Z"
2026-06-28 14:39:18 +00:00
}
},
"source": [
"# noisy training targets: f(x) plus Gaussian noise with mean 0 and std 0.5\n",
"noise = torch.normal(0.0, 0.5, (n_train,))\n",
"y_train = f(x_train) + noise\n",
"# evenly spaced test inputs and their noise-free targets\n",
"x_test = torch.arange(0, 5, 0.1)\n",
"y_truth = f(x_test)\n",
"n_test = x_test.shape[0]\n",
"n_test"
],
"outputs": [
{
"data": {
"text/plain": [
"50"
]
},
2026-06-30 03:57:35 +00:00
"execution_count": 6,
2026-06-28 14:39:18 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-06-30 03:57:35 +00:00
"execution_count": 6
2026-06-28 14:39:18 +00:00
},
{
2026-06-30 03:57:35 +00:00
"cell_type": "code",
"id": "35c2403bbb5509cf",
2026-06-28 14:39:18 +00:00
"metadata": {
"ExecuteTime": {
2026-06-30 03:57:35 +00:00
"end_time": "2026-06-30T02:59:03.578849203Z",
"start_time": "2026-06-30T02:59:03.563506811Z"
2026-06-28 14:39:18 +00:00
}
},
"source": [
"def plot_kernel_reg(y_hat):\n",
"    \"\"\"Plot the true curve and the predictions `y_hat` over `x_test`, with the training points overlaid.\"\"\"\n",
"    curves = [y_truth, y_hat]\n",
"    d2l.plot(x_test, curves, xlabel='x', ylabel='y', legend=['Truth', 'Pred'],\n",
"             xlim=[0, 5], ylim=[-1, 5])\n",
"    # semi-transparent circles mark the noisy training samples\n",
"    d2l.plt.plot(x_train, y_train, 'o', alpha=0.5)"
],
"outputs": [],
2026-06-30 03:57:35 +00:00
"execution_count": 7
2026-06-28 14:39:18 +00:00
},
{
2026-06-30 03:57:35 +00:00
"cell_type": "code",
"id": "f27ecb9c68aed894",
2026-06-28 14:39:18 +00:00
"metadata": {
"ExecuteTime": {
2026-06-30 03:57:35 +00:00
"end_time": "2026-06-30T02:59:03.728510578Z",
"start_time": "2026-06-30T02:59:03.580633476Z"
2026-06-28 14:39:18 +00:00
}
},
"source": [
"# baseline: predict the mean training output for every test input\n",
"y_hat = y_train.mean().repeat(n_test)\n",
"plot_kernel_reg(y_hat)"
],
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 350x250 with 1 Axes>"
],
2026-06-30 03:57:35 +00:00
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"248.301563pt\" height=\"187.155469pt\" viewBox=\"0 0 248.301563 187.155469\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n <metadata>\n <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n <cc:Work>\n <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n <dc:date>2026-06-30T10:59:03.678092</dc:date>\n <dc:format>image/svg+xml</dc:format>\n <dc:creator>\n <cc:Agent>\n <dc:title>Matplotlib v3.7.2, https://matplotlib.org/</dc:title>\n </cc:Agent>\n </dc:creator>\n </cc:Work>\n </rdf:RDF>\n </metadata>\n <defs>\n <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n </defs>\n <g id=\"figure_1\">\n <g id=\"patch_1\">\n <path d=\"M 0 187.155469 \nL 248.301563 187.155469 \nL 248.301563 0 \nL 0 0 \nz\n\"/>\n </g>\n <g id=\"axes_1\">\n <g id=\"patch_2\">\n <path d=\"M 42.620312 149.599219 \nL 237.920313 149.599219 \nL 237.920313 10.999219 \nL 42.620312 10.999219 \nz\n\"/>\n </g>\n <g id=\"matplotlib.axis_1\">\n <g id=\"xtick_1\">\n <g id=\"line2d_1\">\n <path d=\"M 42.620312 149.599219 \nL 42.620312 10.999219 \n\" clip-path=\"url(#p4d701e4fe2)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_2\">\n <defs>\n <path id=\"mae512846f2\" d=\"M 0 0 \nL 0 3.5 \n\" style=\"stroke: #ffffff; stroke-width: 0.8\"/>\n </defs>\n <g>\n <use xlink:href=\"#mae512846f2\" x=\"42.620312\" y=\"149.599219\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_1\">\n <!-- 0 -->\n <g style=\"fill: #ffffff\" transform=\"translate(39.439062 164.197656) scale(0.1 -0.1)\">\n <defs>\n <path 
id=\"DejaVuSans-30\" d=\"M 2034 4250 \nQ 1547 4250 1301 3770 \nQ 1056 3291 1056 2328 \nQ 1056 1369 1301 889 \nQ 1547 409 2034 409 \nQ 2525 409 2770 889 \nQ 3016 1369 3016 2328 \nQ 3016 3291 2770 3770 \nQ 2525 4250 2034 4250 \nz\nM 2034 4750 \nQ 2819 4750 3233 4129 \nQ 3647 3509 3647 2328 \nQ 3647 1150 3233 529 \nQ 2819 -91 2034 -91 \nQ 1250 -91 836 529 \nQ 422 1150 422 2328 \nQ 422 3509 836 4129 \nQ 1250 4750 2034 4750 \nz\n\" transform=\"scale(0.015625)\"/>\n </defs>\n <use xlink:href=\"#DejaVuSans-30\"/>\n </g>\n </g>\n </g>\n <g id=\"xtick_2\">\n <g id=\"line2d_3\">\n <path d=\"M 81.680312 149.599219 \nL 81.680312 10.999219 \n\" clip-path=\"url(#p4d701e4fe2)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_4\">\n <g>\n <use xlink:href=\"#mae512846f2\" x=\"81.680312\" y=\"149.599219\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_2\">\n <!-- 1 -->\n <g style=\"fill: #ffffff\" transform=\"translate(78.499062 164.197656) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-31\" d=\"M 794 531 \nL 1825 531 \nL 1825 4091 \nL 703 3866 \nL 703 4441 \nL 1819 4666 \nL 2450 4666 \nL 2450 531 \nL 3481 531 \nL 3481 0 \nL 794 0 \nL 794 531 \nz\n\" transform=\"scale(0.015625)\"/>\n </defs>\n <use xlink:href=\"#DejaVuSans-31\"/>\n </g>\n </g>\n </g>\n <g id=\"xtick_3\">\n <g id=\"line2d_5\">\n <path d=\"M 120.740313 149.599219 \nL 120.740313 10.999219 \n\" clip-path=\"url(#p4d701e4fe2)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_6\">\n <g>\n <use xlink:href=\"#mae512846f2\" x=\"120.740313\" y=\"149.599219\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_3\">\n <!-- 2 -->\n <g styl
2026-06-28 14:39:18 +00:00
},
"metadata": {},
"output_type": "display_data",
"jetTransient": {
"display_id": null
}
}
],
2026-06-30 03:57:35 +00:00
"execution_count": 8
2026-06-28 14:39:18 +00:00
},
{
2026-06-30 03:57:35 +00:00
"cell_type": "code",
"id": "fa7579d7cd8bfcb7",
2026-06-28 14:39:18 +00:00
"metadata": {
"ExecuteTime": {
2026-06-30 03:57:35 +00:00
"end_time": "2026-06-30T02:59:03.807801559Z",
"start_time": "2026-06-30T02:59:03.732916546Z"
2026-06-28 14:39:18 +00:00
}
},
"source": [
"from torch import nn\n",
"# row i holds the i-th test input repeated n_train times\n",
"X_repeat = x_test.unsqueeze(1).repeat(1, n_train)\n",
"# negative squared distances pass through softmax, so each row of weights sums to 1\n",
"scores = -(X_repeat - x_train)**2 / 2\n",
"attention_weights = torch.softmax(scores, dim=1)\n",
"# each prediction is the attention-weighted average of the training targets\n",
"y_hat = attention_weights @ y_train\n",
"x_train,X_repeat,attention_weights"
],
"outputs": [
{
"data": {
"text/plain": [
2026-06-30 03:57:35 +00:00
"(tensor([0.0153, 0.1180, 0.1875, 0.4215, 0.5745, 0.6163, 0.6557, 0.6917, 0.7939,\n",
" 0.9230, 1.0240, 1.2932, 1.3246, 1.5167, 1.6193, 1.6496, 1.6613, 1.7104,\n",
" 1.7485, 1.7840, 1.7889, 2.0187, 2.0398, 2.0413, 2.2665, 2.4287, 2.4440,\n",
" 2.6026, 2.6988, 2.7922, 2.8129, 2.9356, 3.0271, 3.0442, 3.0646, 3.1996,\n",
" 3.2037, 3.3222, 3.4918, 3.5313, 3.7002, 3.7714, 3.7949, 3.9609, 4.3272,\n",
" 4.4440, 4.6667, 4.7213, 4.8963, 4.9971]),\n",
2026-06-28 14:39:18 +00:00
" tensor([[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
" [0.1000, 0.1000, 0.1000, ..., 0.1000, 0.1000, 0.1000],\n",
" [0.2000, 0.2000, 0.2000, ..., 0.2000, 0.2000, 0.2000],\n",
" ...,\n",
" [4.7000, 4.7000, 4.7000, ..., 4.7000, 4.7000, 4.7000],\n",
" [4.8000, 4.8000, 4.8000, ..., 4.8000, 4.8000, 4.8000],\n",
" [4.9000, 4.9000, 4.9000, ..., 4.9000, 4.9000, 4.9000]]),\n",
2026-06-30 03:57:35 +00:00
" tensor([[7.9004e-02, 7.8465e-02, 7.7637e-02, ..., 1.1413e-06, 4.9195e-07,\n",
" 2.9875e-07],\n",
" [7.2616e-02, 7.2865e-02, 7.2599e-02, ..., 1.6794e-06, 7.3669e-07,\n",
" 4.5190e-07],\n",
" [6.6459e-02, 6.7376e-02, 6.7597e-02, ..., 2.4606e-06, 1.0985e-06,\n",
" 6.8065e-07],\n",
2026-06-28 14:39:18 +00:00
" ...,\n",
2026-06-30 03:57:35 +00:00
" [1.3743e-06, 2.2120e-06, 3.0334e-06, ..., 8.0086e-02, 7.8576e-02,\n",
" 7.6646e-02],\n",
" [9.2161e-07, 1.4986e-06, 2.0694e-06, ..., 8.5977e-02, 8.5846e-02,\n",
" 8.4585e-02],\n",
" [6.1460e-07, 1.0097e-06, 1.4040e-06, ..., 9.1792e-02, 9.3269e-02,\n",
" 9.2831e-02]]))"
2026-06-28 14:39:18 +00:00
]
},
2026-06-30 03:57:35 +00:00
"execution_count": 9,
2026-06-28 14:39:18 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-06-30 03:57:35 +00:00
"execution_count": 9
2026-06-28 14:39:18 +00:00
},
{
2026-06-30 03:57:35 +00:00
"cell_type": "code",
"id": "f76573b2eb74ba7",
2026-06-28 14:39:18 +00:00
"metadata": {
"ExecuteTime": {
2026-06-30 03:57:35 +00:00
"end_time": "2026-06-30T02:59:03.923768674Z",
"start_time": "2026-06-30T02:59:03.813863850Z"
2026-06-28 14:39:18 +00:00
}
},
2026-06-30 03:57:35 +00:00
"source": [
"plot_kernel_reg(y_hat)"
],
2026-06-28 14:39:18 +00:00
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 350x250 with 1 Axes>"
],
2026-06-30 03:57:35 +00:00
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"248.301563pt\" height=\"187.155469pt\" viewBox=\"0 0 248.301563 187.155469\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n <metadata>\n <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n <cc:Work>\n <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n <dc:date>2026-06-30T10:59:03.882515</dc:date>\n <dc:format>image/svg+xml</dc:format>\n <dc:creator>\n <cc:Agent>\n <dc:title>Matplotlib v3.7.2, https://matplotlib.org/</dc:title>\n </cc:Agent>\n </dc:creator>\n </cc:Work>\n </rdf:RDF>\n </metadata>\n <defs>\n <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n </defs>\n <g id=\"figure_1\">\n <g id=\"patch_1\">\n <path d=\"M 0 187.155469 \nL 248.301563 187.155469 \nL 248.301563 0 \nL 0 0 \nz\n\"/>\n </g>\n <g id=\"axes_1\">\n <g id=\"patch_2\">\n <path d=\"M 42.620312 149.599219 \nL 237.920313 149.599219 \nL 237.920313 10.999219 \nL 42.620312 10.999219 \nz\n\"/>\n </g>\n <g id=\"matplotlib.axis_1\">\n <g id=\"xtick_1\">\n <g id=\"line2d_1\">\n <path d=\"M 42.620312 149.599219 \nL 42.620312 10.999219 \n\" clip-path=\"url(#pdb04ad0de6)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_2\">\n <defs>\n <path id=\"ma117b8ec92\" d=\"M 0 0 \nL 0 3.5 \n\" style=\"stroke: #ffffff; stroke-width: 0.8\"/>\n </defs>\n <g>\n <use xlink:href=\"#ma117b8ec92\" x=\"42.620312\" y=\"149.599219\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_1\">\n <!-- 0 -->\n <g style=\"fill: #ffffff\" transform=\"translate(39.439062 164.197656) scale(0.1 -0.1)\">\n <defs>\n <path 
id=\"DejaVuSans-30\" d=\"M 2034 4250 \nQ 1547 4250 1301 3770 \nQ 1056 3291 1056 2328 \nQ 1056 1369 1301 889 \nQ 1547 409 2034 409 \nQ 2525 409 2770 889 \nQ 3016 1369 3016 2328 \nQ 3016 3291 2770 3770 \nQ 2525 4250 2034 4250 \nz\nM 2034 4750 \nQ 2819 4750 3233 4129 \nQ 3647 3509 3647 2328 \nQ 3647 1150 3233 529 \nQ 2819 -91 2034 -91 \nQ 1250 -91 836 529 \nQ 422 1150 422 2328 \nQ 422 3509 836 4129 \nQ 1250 4750 2034 4750 \nz\n\" transform=\"scale(0.015625)\"/>\n </defs>\n <use xlink:href=\"#DejaVuSans-30\"/>\n </g>\n </g>\n </g>\n <g id=\"xtick_2\">\n <g id=\"line2d_3\">\n <path d=\"M 81.680312 149.599219 \nL 81.680312 10.999219 \n\" clip-path=\"url(#pdb04ad0de6)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_4\">\n <g>\n <use xlink:href=\"#ma117b8ec92\" x=\"81.680312\" y=\"149.599219\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_2\">\n <!-- 1 -->\n <g style=\"fill: #ffffff\" transform=\"translate(78.499062 164.197656) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-31\" d=\"M 794 531 \nL 1825 531 \nL 1825 4091 \nL 703 3866 \nL 703 4441 \nL 1819 4666 \nL 2450 4666 \nL 2450 531 \nL 3481 531 \nL 3481 0 \nL 794 0 \nL 794 531 \nz\n\" transform=\"scale(0.015625)\"/>\n </defs>\n <use xlink:href=\"#DejaVuSans-31\"/>\n </g>\n </g>\n </g>\n <g id=\"xtick_3\">\n <g id=\"line2d_5\">\n <path d=\"M 120.740313 149.599219 \nL 120.740313 10.999219 \n\" clip-path=\"url(#pdb04ad0de6)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_6\">\n <g>\n <use xlink:href=\"#ma117b8ec92\" x=\"120.740313\" y=\"149.599219\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_3\">\n <!-- 2 -->\n <g styl
2026-06-28 14:39:18 +00:00
},
"metadata": {},
"output_type": "display_data",
"jetTransient": {
"display_id": null
}
}
],
2026-06-30 03:57:35 +00:00
"execution_count": 10
2026-06-28 14:39:18 +00:00
},
{
2026-06-30 03:57:35 +00:00
"cell_type": "code",
"id": "1bbe3af9e8a0f412",
2026-06-28 14:39:18 +00:00
"metadata": {
"ExecuteTime": {
2026-06-30 03:57:35 +00:00
"end_time": "2026-06-30T02:59:04.041592329Z",
"start_time": "2026-06-30T02:59:03.938730755Z"
2026-06-28 14:39:18 +00:00
}
},
"source": [
"# add two leading singleton axes so the weight matrix becomes a 1x1 grid of heatmaps\n",
"weights_grid = attention_weights[None, None]\n",
"d2l.show_heatmaps(weights_grid, xlabel='Sorted training inputs',\n",
"                  ylabel='Sorted testing inputs')"
],
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 250x250 with 2 Axes>"
],
2026-06-30 03:57:35 +00:00
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"206.16425pt\" height=\"159.039469pt\" viewBox=\"0 0 206.16425 159.039469\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n <metadata>\n <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n <cc:Work>\n <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n <dc:date>2026-06-30T10:59:04.008946</dc:date>\n <dc:format>image/svg+xml</dc:format>\n <dc:creator>\n <cc:Agent>\n <dc:title>Matplotlib v3.7.2, https://matplotlib.org/</dc:title>\n </cc:Agent>\n </dc:creator>\n </cc:Work>\n </rdf:RDF>\n </metadata>\n <defs>\n <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n </defs>\n <g id=\"figure_1\">\n <g id=\"patch_1\">\n <path d=\"M 0 159.039469 \nL 206.16425 159.039469 \nL 206.16425 0 \nL 0 0 \nz\n\"/>\n </g>\n <g id=\"axes_1\">\n <g id=\"patch_2\">\n <path d=\"M 40.603125 121.483219 \nL 152.203125 121.483219 \nL 152.203125 9.883219 \nL 40.603125 9.883219 \nz\n\"/>\n </g>\n <g clip-path=\"url(#p05be1612c0)\">\n <image 
xlink:href=\"data:image/png;base64,\niVBORw0KGgoAAAANSUhEUgAAAJsAAACbCAYAAAB1YemMAAAQOElEQVR4nO2d3XYbtxVGAc6QFGU5Tp32su//IH2M9KZZqzeJ69qxZYsSyelFamEfcD7oUJJBZa2zr0ZDzAzpIPjm/CJP1x+m9Cdhmp7xq/Je0+Hh8Yd9Od7f6XE7fLbfzV+DZ0831+X8VxzzO+HZ08d3+tkf/lOO//0vPOOmnP/ll/KVPnwuj7grz/j0z1/vj//xc7nnbxiTUko/fy2/aYd/zi2+e075/nihv3kQPC8x2YJujOf+Ag/xKOl8ityqa42sNWTXyO1u9tj8prttOX8HuTtgDO95W8Yffa+br+U0pfP2tpzflmNK54TjA569XpT16M1o/23+vi7TR/2LrIqKxsoW9CMmW9CNFy+jkkdJpUMiFZTOejxvS8mjFUkrVcniLaSPv2+P8ZDKI/gZpDNBUg+3kHPK6A7HePR6UXRwqtamYZ3THMtczq9yWKPBGYjJFnTjzyWjbuk8UWI9921Zo9O85TjRGjUSWa6feJ6Saqxayi7Gp2StVl6/xfFdkfBph2fv5q1RcglrdExWNjeL8mw6bzeQ3otxuD+OlS3oRky2oBsvRkal89ac/04OXs8zRKzy6HpK5w4W4R2OD0IWeaxk9KaSUbKdd+Qapy6lcw9JxW+AAZk2kMH1wf7bLCCXPL56VabVZlOOY2ULuhGTLehGTLagGy/mne1kvG6QU90aavzB985mIgV4BzNBdr6DIRBvIgCPcH1Mwt2Rdoga7Pk7ptnjYSxr0OWlniLrVXmfG8byzvbD61U5f7W+P46VLehGTLagGy9TRpUrYhLH7Zs9fI1LRkWeWv0M5cq4FTJqgufb+TF89raRzyak0xwfxO+D62Ic6MZYliHV0nR5WT7LqzKVlj9d3R8Pry/K9fNPDoLnJyZb0I2zyujkkUU5piGjrmgEMIH1h2V0qmWUUkYL1OSqiSA75ZWSKvLZpqO0cKaYs7IL16g0dkhnxpDFskyLV6/mz6eU0vi2fLigjP71dRl0VSQ1VragGzHZgm68TGtUWqCO9O36Go816pFRFWCvrp/4GSVyS4nE87ZCajnGOHWrZxMho4mOXErngLUml9+9uEDw/FVx0C7WxfpMKaXV34pcZkrsTz+V45DR4BzEZAu60V1GddGxw+p8jDXqucbEOufHyBTvlCpZhFwyHroVuWomlVtUVx2cMroThdDASCdy1fg85qYNl4htbqyM5rd/KX8s8dnbt2XMqyK1sbIF3YjJFnSji4y6Ur7VsSoObsZGPfFUIVNKapWVmarfZyxQYWmqlCGkjk/q32A/XwX14Gf/J0M6c54vMuazxx8vywebjR345k05XhWrNb+BvF6FjAZnICZb0I3vJqO+ail1XlmjThlV49QlB1ia6r4q5lk/z1iXjvShLVtmzXenbMqoeE2gRE44ziOt0fm1htZo+uGHcnx5acblNz+WP9bFak0/Fms0XYZTNzgDMdmCbjxJRp/eUNkR9zyoHhvVs5VceqRXdYUkqrVV/YytcOQKh+3E4mWTnSt+d6vrJb97no+BLi6sY3aOTAdtQ0aNXPKaK1ipIaPBOYjJFnTjZBl9ekPlJzhcPalA9Tgpo6LWU8i5aW1152xbRetSyaUsTBHfuyWjdNKO+E8rrNQknLrpohSpZKQIpVdXdtwVJBZOXcpovkA27/zTguD5ickWdMMlo27pPDV9SGXhqvOtGOap9yKiK6Qc06rdVFarklF17Ikb1ygZJTxfF4J+gzHQ1/OW5dFnA2pINxi3LhZsrGxBN2KyBd2IyRZ04/QIwlOjBtIN4jD1+c7WqrQ6qHcfcY1qQWrGtCIIdH2I9qIqyO55T+O/gXJXpJTSMMwf8xp6+tV7HSuTW+9sGxQjL+H6gLsjrcv7X6xsQTdisgXdeGI+22NcIh6JFN23D8LDXj9O7B+lJFJ2hSR
u1wfvxS0g1e8T5xXKXVF/xmPKKKIDRkYZuL+EDFI6Kxk1Lg7cK18gYA+XSKxsQTdisgXdkDKq21k9plWVyMcy5x1RA+adtQLSe9HzQlmaTN9W91WVUn98mXLMQuGdkFFz7ROlU0Grk5YpogPZRBMwhgF2FBkbKzOllDb4e8C9huXs+VjZgm7EZAu6YWT09D4c9TBPMbKSUVHhhPMTHav1s0wFkti12GNpOqqrjBM4pUpGRRNllZ/myS8jyhGbkpVL5pfhfN7AUqQjls+mNQrpzLU1uoJlm1m1BRldhIwGZyAmW9CN8fmKiRvjpHQq5ys6Od45tlGs70uLUMmoykFTUmu26ml0nvTEOhUeq5NSWQMLNNN5S0uTMrpCYTFlEFZmNnHOqrqKVqdxKM/nzMXKFnQjJlvQDWvanCqd9XhXDFQ4chlHVLFG1QQ5VZb0TlijPE9UV0hC6azv44l10trjeZUKpFjqImNTXEy5pIzSopTWKMZQUml9pmStzizisnh2rGxBN2KyBd0YfRVRj+m34bieDltVKGxaVTnkLiUrcyZFSTl1HdYou0LSWVvjsTopWSotSIzPtYzyGlqXyJA1jmDl1KWc0wJd4p5D/WzxCsDzYY0G5yAmW9CNKtD2hGLi1vWezNndfBHvpIp+G/u0u1KMVH8OlWKk4pwtKC3KGUsZZLar7MOxmT9ff0a5NClGkM
2026-06-28 14:39:18 +00:00
},
"metadata": {},
"output_type": "display_data",
"jetTransient": {
"display_id": null
}
}
],
2026-06-30 03:57:35 +00:00
"execution_count": 11
2026-06-28 14:39:18 +00:00
},
{
2026-06-30 03:57:35 +00:00
"cell_type": "code",
"id": "aeb79a3bd64c826",
2026-06-28 14:39:18 +00:00
"metadata": {
"ExecuteTime": {
2026-06-30 03:57:35 +00:00
"end_time": "2026-06-30T02:59:04.105924659Z",
"start_time": "2026-06-30T02:59:04.055849961Z"
2026-06-28 14:39:18 +00:00
}
},
"source": [
"class NWKernelRegression(nn.Module):\n",
"    \"\"\"Kernel regression with one learnable scalar `w` that scales query-key distances.\n",
"\n",
"    After every forward pass the softmax weights are kept in `self.attention_weights`.\n",
"    \"\"\"\n",
"    def __init__(self, **kwargs):\n",
"        super().__init__(**kwargs)\n",
"        # nn.Parameter enables gradient tracking by default\n",
"        self.w = nn.Parameter(torch.randn(1))\n",
"\n",
"    def forward(self, queries, keys, values):\n",
"        \"\"\"Return one prediction per query.\n",
"\n",
"        `queries` is flattened to one scalar per row of `keys`; `keys` and\n",
"        `values` are 2-D with one row of key-value pairs per query.\n",
"        \"\"\"\n",
"        # a column of queries broadcasts against each query's own row of keys\n",
"        distances = queries.reshape(-1, 1) - keys\n",
"        scores = -(distances * self.w)**2 / 2\n",
"        self.attention_weights = nn.functional.softmax(scores, dim=1)\n",
"        # batched (1 x m) times (m x 1) product per query, flattened to one value each\n",
"        weighted = torch.bmm(self.attention_weights.unsqueeze(1), values.unsqueeze(-1))\n",
"        return weighted.reshape(-1)"
],
"outputs": [],
2026-06-30 03:57:35 +00:00
"execution_count": 12
2026-06-28 14:39:18 +00:00
},
{
2026-06-30 03:57:35 +00:00
"cell_type": "code",
"id": "7061ed7f139bc78e",
2026-06-28 14:39:18 +00:00
"metadata": {
"ExecuteTime": {
2026-06-30 03:57:35 +00:00
"end_time": "2026-06-30T02:59:04.158366434Z",
"start_time": "2026-06-30T02:59:04.107607867Z"
2026-06-28 14:39:18 +00:00
}
},
"source": [
"# True everywhere except the diagonal, so sample i never uses itself as a key\n",
"off_diagonal = torch.eye(n_train) == 0\n",
"# X_tile: shape (n_train, n_train); every row holds the same training inputs\n",
"X_tile = x_train.repeat((n_train, 1))\n",
"# Y_tile: shape (n_train, n_train); every row holds the same training outputs\n",
"Y_tile = y_train.repeat((n_train, 1))\n",
"# keys: shape (n_train, n_train - 1); row i holds all training inputs except the i-th\n",
"keys = X_tile[off_diagonal].reshape((n_train, -1))\n",
"# values: shape (n_train, n_train - 1); row i holds all training outputs except the i-th\n",
"values = Y_tile[off_diagonal].reshape((n_train, -1))"
],
"outputs": [],
2026-06-30 03:57:35 +00:00
"execution_count": 13
2026-06-28 14:39:18 +00:00
},
{
2026-06-30 03:57:35 +00:00
"cell_type": "code",
"id": "3dca11d4e962d859",
2026-06-28 14:39:18 +00:00
"metadata": {
"ExecuteTime": {
2026-06-30 03:57:35 +00:00
"end_time": "2026-06-30T02:59:04.750466831Z",
"start_time": "2026-06-30T02:59:04.162845978Z"
2026-06-28 14:39:18 +00:00
}
},
"source": [
"net = NWKernelRegression()\n",
"loss = nn.MSELoss(reduction='none')\n",
"trainer = torch.optim.SGD(net.parameters(), lr=0.5)\n",
"num_epochs = 5\n",
"animator = d2l.Animator(xlabel='epoch', ylabel='loss', xlim=[1, num_epochs])\n",
"for epoch in range(num_epochs):\n",
"    trainer.zero_grad()\n",
"    # per-sample squared errors (reduction='none'), summed into one scalar for backprop\n",
"    l = loss(net(x_train, keys, values), y_train)\n",
"    total = l.sum()\n",
"    total.backward()\n",
"    trainer.step()\n",
"    # the reported loss is the value computed before this epoch's parameter update\n",
"    print(f'epoch {epoch + 1}, loss {float(l.sum()):.6f}')\n",
"    animator.add(epoch + 1, float(total))"
],
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 350x250 with 1 Axes>"
],
2026-06-30 03:57:35 +00:00
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"246.284375pt\" height=\"183.35625pt\" viewBox=\"0 0 246.284375 183.35625\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n <metadata>\n <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n <cc:Work>\n <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n <dc:date>2026-06-30T10:59:04.709978</dc:date>\n <dc:format>image/svg+xml</dc:format>\n <dc:creator>\n <cc:Agent>\n <dc:title>Matplotlib v3.7.2, https://matplotlib.org/</dc:title>\n </cc:Agent>\n </dc:creator>\n </cc:Work>\n </rdf:RDF>\n </metadata>\n <defs>\n <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n </defs>\n <g id=\"figure_1\">\n <g id=\"patch_1\">\n <path d=\"M 0 183.35625 \nL 246.284375 183.35625 \nL 246.284375 0 \nL 0 0 \nz\n\"/>\n </g>\n <g id=\"axes_1\">\n <g id=\"patch_2\">\n <path d=\"M 40.603125 145.8 \nL 235.903125 145.8 \nL 235.903125 7.2 \nL 40.603125 7.2 \nz\n\"/>\n </g>\n <g id=\"matplotlib.axis_1\">\n <g id=\"xtick_1\">\n <g id=\"line2d_1\">\n <path d=\"M 40.603125 145.8 \nL 40.603125 7.2 \n\" clip-path=\"url(#p99eebed7bc)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_2\">\n <defs>\n <path id=\"m1cbf40412a\" d=\"M 0 0 \nL 0 3.5 \n\" style=\"stroke: #ffffff; stroke-width: 0.8\"/>\n </defs>\n <g>\n <use xlink:href=\"#m1cbf40412a\" x=\"40.603125\" y=\"145.8\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_1\">\n <!-- 1 -->\n <g style=\"fill: #ffffff\" transform=\"translate(37.421875 160.398438) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-31\" d=\"M 794 531 \nL 1825 531 \nL 1825 4091 
\nL 703 3866 \nL 703 4441 \nL 1819 4666 \nL 2450 4666 \nL 2450 531 \nL 3481 531 \nL 3481 0 \nL 794 0 \nL 794 531 \nz\n\" transform=\"scale(0.015625)\"/>\n </defs>\n <use xlink:href=\"#DejaVuSans-31\"/>\n </g>\n </g>\n </g>\n <g id=\"xtick_2\">\n <g id=\"line2d_3\">\n <path d=\"M 89.428125 145.8 \nL 89.428125 7.2 \n\" clip-path=\"url(#p99eebed7bc)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_4\">\n <g>\n <use xlink:href=\"#m1cbf40412a\" x=\"89.428125\" y=\"145.8\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_2\">\n <!-- 2 -->\n <g style=\"fill: #ffffff\" transform=\"translate(86.246875 160.398438) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-32\" d=\"M 1228 531 \nL 3431 531 \nL 3431 0 \nL 469 0 \nL 469 531 \nQ 828 903 1448 1529 \nQ 2069 2156 2228 2338 \nQ 2531 2678 2651 2914 \nQ 2772 3150 2772 3378 \nQ 2772 3750 2511 3984 \nQ 2250 4219 1831 4219 \nQ 1534 4219 1204 4116 \nQ 875 4013 500 3803 \nL 500 4441 \nQ 881 4594 1212 4672 \nQ 1544 4750 1819 4750 \nQ 2544 4750 2975 4387 \nQ 3406 4025 3406 3419 \nQ 3406 3131 3298 2873 \nQ 3191 2616 2906 2266 \nQ 2828 2175 2409 1742 \nQ 1991 1309 1228 531 \nz\n\" transform=\"scale(0.015625)\"/>\n </defs>\n <use xlink:href=\"#DejaVuSans-32\"/>\n </g>\n </g>\n </g>\n <g id=\"xtick_3\">\n <g id=\"line2d_5\">\n <path d=\"M 138.253125 145.8 \nL 138.253125 7.2 \n\" clip-path=\"url(#p99eebed7bc)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_6\">\n <g>\n <use xlink:href=\"#m1cbf40412a\" x=\"138.253125\" y=\"145.8\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_3\">\n <!-- 3 -->\n <g style=\"fill: #ffffff\" tran
2026-06-28 14:39:18 +00:00
},
"metadata": {},
"output_type": "display_data",
"jetTransient": {
"display_id": null
}
}
],
2026-06-30 03:57:35 +00:00
"execution_count": 14
2026-06-28 14:39:18 +00:00
},
{
2026-06-30 03:57:35 +00:00
"cell_type": "code",
"id": "4ff00f0b4983b55e",
2026-06-28 14:39:18 +00:00
"metadata": {
"ExecuteTime": {
2026-06-30 03:57:35 +00:00
"end_time": "2026-06-30T02:59:04.976760199Z",
"start_time": "2026-06-30T02:59:04.799699382Z"
2026-06-28 14:39:18 +00:00
}
},
"source": [
"# keys: shape (n_test, n_train); every row holds all training inputs (the same keys for each query)\n",
"keys = x_train.unsqueeze(0).repeat(n_test, 1)\n",
"# values: shape (n_test, n_train); every row holds all training outputs\n",
"values = y_train.unsqueeze(0).repeat(n_test, 1)\n",
"# detach so the predictions carry no autograd history when plotted\n",
"y_hat = net(x_test, keys, values).detach().unsqueeze(1)\n",
"plot_kernel_reg(y_hat)"
],
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 350x250 with 1 Axes>"
],
2026-06-30 03:57:35 +00:00
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"248.301563pt\" height=\"187.155469pt\" viewBox=\"0 0 248.301563 187.155469\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n <metadata>\n <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n <cc:Work>\n <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n <dc:date>2026-06-30T10:59:04.872782</dc:date>\n <dc:format>image/svg+xml</dc:format>\n <dc:creator>\n <cc:Agent>\n <dc:title>Matplotlib v3.7.2, https://matplotlib.org/</dc:title>\n </cc:Agent>\n </dc:creator>\n </cc:Work>\n </rdf:RDF>\n </metadata>\n <defs>\n <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n </defs>\n <g id=\"figure_1\">\n <g id=\"patch_1\">\n <path d=\"M 0 187.155469 \nL 248.301563 187.155469 \nL 248.301563 0 \nL 0 0 \nz\n\"/>\n </g>\n <g id=\"axes_1\">\n <g id=\"patch_2\">\n <path d=\"M 42.620312 149.599219 \nL 237.920313 149.599219 \nL 237.920313 10.999219 \nL 42.620312 10.999219 \nz\n\"/>\n </g>\n <g id=\"matplotlib.axis_1\">\n <g id=\"xtick_1\">\n <g id=\"line2d_1\">\n <path d=\"M 42.620312 149.599219 \nL 42.620312 10.999219 \n\" clip-path=\"url(#p61b5a123fd)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_2\">\n <defs>\n <path id=\"m57c54f2125\" d=\"M 0 0 \nL 0 3.5 \n\" style=\"stroke: #ffffff; stroke-width: 0.8\"/>\n </defs>\n <g>\n <use xlink:href=\"#m57c54f2125\" x=\"42.620312\" y=\"149.599219\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_1\">\n <!-- 0 -->\n <g style=\"fill: #ffffff\" transform=\"translate(39.439062 164.197656) scale(0.1 -0.1)\">\n <defs>\n <path 
id=\"DejaVuSans-30\" d=\"M 2034 4250 \nQ 1547 4250 1301 3770 \nQ 1056 3291 1056 2328 \nQ 1056 1369 1301 889 \nQ 1547 409 2034 409 \nQ 2525 409 2770 889 \nQ 3016 1369 3016 2328 \nQ 3016 3291 2770 3770 \nQ 2525 4250 2034 4250 \nz\nM 2034 4750 \nQ 2819 4750 3233 4129 \nQ 3647 3509 3647 2328 \nQ 3647 1150 3233 529 \nQ 2819 -91 2034 -91 \nQ 1250 -91 836 529 \nQ 422 1150 422 2328 \nQ 422 3509 836 4129 \nQ 1250 4750 2034 4750 \nz\n\" transform=\"scale(0.015625)\"/>\n </defs>\n <use xlink:href=\"#DejaVuSans-30\"/>\n </g>\n </g>\n </g>\n <g id=\"xtick_2\">\n <g id=\"line2d_3\">\n <path d=\"M 81.680312 149.599219 \nL 81.680312 10.999219 \n\" clip-path=\"url(#p61b5a123fd)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_4\">\n <g>\n <use xlink:href=\"#m57c54f2125\" x=\"81.680312\" y=\"149.599219\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_2\">\n <!-- 1 -->\n <g style=\"fill: #ffffff\" transform=\"translate(78.499062 164.197656) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-31\" d=\"M 794 531 \nL 1825 531 \nL 1825 4091 \nL 703 3866 \nL 703 4441 \nL 1819 4666 \nL 2450 4666 \nL 2450 531 \nL 3481 531 \nL 3481 0 \nL 794 0 \nL 794 531 \nz\n\" transform=\"scale(0.015625)\"/>\n </defs>\n <use xlink:href=\"#DejaVuSans-31\"/>\n </g>\n </g>\n </g>\n <g id=\"xtick_3\">\n <g id=\"line2d_5\">\n <path d=\"M 120.740313 149.599219 \nL 120.740313 10.999219 \n\" clip-path=\"url(#p61b5a123fd)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_6\">\n <g>\n <use xlink:href=\"#m57c54f2125\" x=\"120.740313\" y=\"149.599219\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_3\">\n <!-- 2 -->\n <g styl
2026-06-28 14:39:18 +00:00
},
"metadata": {},
"output_type": "display_data",
"jetTransient": {
"display_id": null
}
}
],
2026-06-30 03:57:35 +00:00
"execution_count": 15
2026-06-28 14:39:18 +00:00
},
{
2026-06-30 03:57:35 +00:00
"cell_type": "code",
"id": "97a690478767c3c3",
2026-06-28 14:39:18 +00:00
"metadata": {
"ExecuteTime": {
2026-06-30 03:57:35 +00:00
"end_time": "2026-06-30T02:59:05.088839245Z",
"start_time": "2026-06-30T02:59:04.978734877Z"
2026-06-28 14:39:18 +00:00
}
},
"source": [
"d2l.show_heatmaps(net.attention_weights.unsqueeze(0).unsqueeze(0),\n",
"xlabel='Sorted training inputs',\n",
"ylabel='Sorted testing inputs')"
],
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 250x250 with 2 Axes>"
],
2026-06-30 03:57:35 +00:00
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"199.80175pt\" height=\"159.039469pt\" viewBox=\"0 0 199.80175 159.039469\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n <metadata>\n <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n <cc:Work>\n <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n <dc:date>2026-06-30T10:59:05.049223</dc:date>\n <dc:format>image/svg+xml</dc:format>\n <dc:creator>\n <cc:Agent>\n <dc:title>Matplotlib v3.7.2, https://matplotlib.org/</dc:title>\n </cc:Agent>\n </dc:creator>\n </cc:Work>\n </rdf:RDF>\n </metadata>\n <defs>\n <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n </defs>\n <g id=\"figure_1\">\n <g id=\"patch_1\">\n <path d=\"M 0 159.039469 \nL 199.80175 159.039469 \nL 199.80175 0 \nL 0 0 \nz\n\"/>\n </g>\n <g id=\"axes_1\">\n <g id=\"patch_2\">\n <path d=\"M 40.603125 121.483219 \nL 152.203125 121.483219 \nL 152.203125 9.883219 \nL 40.603125 9.883219 \nz\n\"/>\n </g>\n <g clip-path=\"url(#p08007d42f7)\">\n <image 
xlink:href=\"data:image/png;base64,\niVBORw0KGgoAAAANSUhEUgAAAJsAAACbCAYAAAB1YemMAAAF0UlEQVR4nO3db2hVdRzH8XPuncuYf4Y0S7DUKS0KpWJSRjqkIDMzYg+KoJ4E4hDToKR64oNKpD8PXCmD6ElQTyKD/tE/LJHCJDMLcba2tc1laso2Nptz95yefX/nqBu7272fe++579ejz73n3N0f+uX7vef+OccPh/pCD8iht2tqLW9qO2g5VYjFoDxRbJCpKPQCULrCkf8sn75/teWmd1+w7FfPtUxngwzFBhnGKCYts3Or5eq6Gyyn12+46v50NshQbJDxeVMXExWG8VLZVbPY8pbOny37M+dc9fF0NshQbJDhaBQT1jTjxtjtlqGTWT2ezgYZig0yFBtkeM2GcWU+fcfyc3XXT+lv0dkgQ7FBhk8QMK6mKvd2x57B7tg23/ez+lt0NshQbJDhaBTjCj33KivbsXk5OhtkKDbIMEZxhY1V8y1n+2H7eOhskKHYIEOxQYZigwzFBhmORstY9Fwd4bm/La+bMyMvz0dngwzFBhnGaJkJh4csZ5pfcrm71/K6nta8PDedDTIUG2QYo2Xm+PJ7Ldd90Gy5csGteX9uOhtkKDbIMEYTKjjTZfmvtY2W6z5ssZy+5S7pmuhskKHYIMMYTajh5zdZvqnxbsvq0RlFZ4MMxQYZig0yvGZLkL4HV1me+eQjllONTYVYzhXobJCh2CDDKbNKTDhwLn5HMGrxi6UNlh/K03fSpoLOBhmKDTIcjZaA6DWjMju2jrnf2q5jgtVMHp0NMhQbZBijRSTo/M1y+GckH/rRcnr7nviD/FQkpvO3uBygs0GGYoMMb+oWWHR0Bt/utewvXOLykmWWU4tcLjV0NshQbJDhaFQkvNDv8hl3WZ7D9z1huf6PI+4BafdfM9XrDxQLOhtkKDbIMEbzKLx4wXLmlWfchgr3z17f/rvlpIzLsdDZIEOxQYYxmmPRb9LuX7bScsORfW6nqtkWkz46o+hskKHYIMNnoznWXFNreXP7Icv+rOsKsZyiQmeDDMUGGY5GJyHz/uux2xe/+t7y5raDlhmdcXQ2yFBskKHYIMNrtgkKg8By9DWa53le5bKbLfvVc1VLKjl0NshQbJBhjE7QtuqFll873x7fWOQ/Di4WdDbIUGyQYYxeJnp6qqjhwN3vV0xTLSdR6GyQodggwxj1PC/T+pPlcN8nbsOlSxbfGuxRLimR6GyQodggwxj1PC/8OnKqqlr3OWf0x8SYOjobZCg2yCR+ToSjI5aDAx+7Db1dFv0Vqy2nblsReXT5/IBYgc4GGYoNMokco+GFActBhztBstfd4fKoe8M2vXyNYlllj84GGYoNMiV1ro+xToLseZ4XDpx3N3o7Lfp1t7s8b7HbJ3oZnulVuVskxkRngwzFBpmiH6Nh/1nLmZaXLfv3NMR3rJlvMbXkDrdfRWX+Foes0NkgQ7FBpujH6IuzF1jecfaE2zDtmth+5XQi5FJFZ4MMxQYZig0yRfNBfNB22PLxR5+2vKMzctrQyunSNSG36GyQodggUzRjdOTNVy3XNdZb9ufMK8RykAd0NshQbJApmjH65WfHLK//Ye84e6JU0dkgQ7FBpqBjdGOV+w5ay9DJAq4ECnQ2yFBskJGP0czR7yw/XjNL/fQoIDobZCg2yEi+Fh4O9lnevehOy5tOtbqF8CuoxKOzQYZig0xOj0Zjl+IZHrS4a6EbnVv+cb+Q4rI85YXOBhmKDTJZj9HotdK9gbOxbcG+j1w+sN/ylu4jlhmd5YvOBhmKDTJZv6mbeW+n5XBoMLbNX7XWcqp2qbv/2pmTXR8ShM4GGYoNMhM6Gh1941l34/Rpi+ntu+M7VlVb5BRWuBydDTIUG2QoNshM6K0PfgWFXKCzQYZig0zsrY/o99H6H1hpeXfzBt2KkFh0NshQbJCJHY0GHb/ahs8bHrP8cM8JD5gqOht
kKDbIVASn2u3G0TVPWV73yzeFWA8SjM4GGYoNMv629Gw7Gt35b+QHxFy6BzlGZ4MMxQaZ/wFtKUJZVGIByQAAAABJRU5ErkJggg==\" id=\"image3bd4b5b534\" transform=\"scale(1 -1) translate(0 -111.6)\" x=\"40.603125\" y=\"-9.883219\" width=\"111.6\" height=\"111.6\"/>\n </g>\n <g id=\"matplotlib.axis_1\">\n <g id=\"xtick_1\">\n <g id=\"line2d_1\">\n <defs>\n <path id=\"ma61c8f862e\" d=\"M 0 0 \nL 0 3.5 \n\" style=\"stroke: #ffffff; stroke-width: 0.8\"/>\n </defs>\n <g>\n <use xlink:href=\"#ma61c8f862e\" x=\"41.719125\" y=\"121.483219\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_1\">\n <!-- 0 -->\n <g style=\"fill: #ffff
2026-06-28 14:39:18 +00:00
},
"metadata": {},
"output_type": "display_data",
"jetTransient": {
"display_id": null
}
}
],
2026-06-30 03:57:35 +00:00
"execution_count": 16
2026-06-28 14:39:18 +00:00
},
{
2026-06-30 03:57:35 +00:00
"cell_type": "code",
"id": "8d03d0c7f735b755",
2026-06-28 14:39:18 +00:00
"metadata": {
"ExecuteTime": {
2026-06-30 03:57:35 +00:00
"end_time": "2026-06-30T02:59:05.142764987Z",
"start_time": "2026-06-30T02:59:05.093810141Z"
2026-06-28 14:39:18 +00:00
}
},
"source": [
"def masked_softmax(X,valid_lens):\n",
" # X:3D valid_len 1D or 2D\n",
" if valid_lens is None:\n",
" return nn.functional.softmax(X,dim=-1)\n",
" else:\n",
" shape = X.shape\n",
" if valid_lens.dim()==1:\n",
" valid_lens = torch.repeat_interleave(valid_lens,shape[1])\n",
" else:\n",
" valid_lens = valid_lens.reshape(-1)\n",
" #print(valid_lens)\n",
" X = d2l.sequence_mask(X.reshape(-1,shape[-1]),valid_lens,value=-1e6)\n",
" return nn.functional.softmax(X.reshape(shape),dim=-1)"
],
"outputs": [],
2026-06-30 03:57:35 +00:00
"execution_count": 17
2026-06-28 14:39:18 +00:00
},
{
2026-06-30 03:57:35 +00:00
"cell_type": "code",
"id": "27f8c24ec572b267",
2026-06-28 14:39:18 +00:00
"metadata": {
"ExecuteTime": {
2026-06-30 03:57:35 +00:00
"end_time": "2026-06-30T02:59:05.208569156Z",
"start_time": "2026-06-30T02:59:05.144491160Z"
2026-06-28 14:39:18 +00:00
}
},
2026-06-30 03:57:35 +00:00
"source": [
"masked_softmax(torch.rand(2,2,4),torch.tensor([2,3]))"
],
2026-06-28 14:39:18 +00:00
"outputs": [
{
"data": {
"text/plain": [
2026-06-30 03:57:35 +00:00
"tensor([[[0.5584, 0.4416, 0.0000, 0.0000],\n",
" [0.6205, 0.3795, 0.0000, 0.0000]],\n",
2026-06-28 14:39:18 +00:00
"\n",
2026-06-30 03:57:35 +00:00
" [[0.4421, 0.3398, 0.2181, 0.0000],\n",
" [0.3929, 0.3625, 0.2446, 0.0000]]])"
2026-06-28 14:39:18 +00:00
]
},
2026-06-30 03:57:35 +00:00
"execution_count": 18,
2026-06-28 14:39:18 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-06-30 03:57:35 +00:00
"execution_count": 18
2026-06-28 14:39:18 +00:00
},
{
2026-06-30 03:57:35 +00:00
"cell_type": "code",
"id": "c7e5254395b14b1",
2026-06-28 14:39:18 +00:00
"metadata": {
"ExecuteTime": {
2026-06-30 03:57:35 +00:00
"end_time": "2026-06-30T02:59:05.262719443Z",
"start_time": "2026-06-30T02:59:05.210360072Z"
2026-06-28 14:39:18 +00:00
}
},
"source": [
"class AdditiveAttention(nn.Module):\n",
" def __init__(self,key_size,query_size,num_hiddens,dropout,**kwargs):\n",
" super(AdditiveAttention,self).__init__(**kwargs)\n",
" self.W_k = nn.Linear(key_size,num_hiddens,bias=False)\n",
" self.W_q = nn.Linear(query_size,num_hiddens,bias=False)\n",
" self.w_v = nn.Linear(num_hiddens,1,bias=False)\n",
" self.dropout=nn.Dropout(dropout)\n",
" def forward(self,queries,keys,value,valid_lens):\n",
" queries,keys=self.W_q(queries),self.W_k(keys)\n",
" # queries (batch_size,n_q,1,num_hidden)\n",
" # key (batch_size,1,n_k,num_hiddens)\n",
" features = queries.unsqueeze(2) + keys.unsqueeze(1)\n",
" #features (batch_size,n_q,n_k,num_hidden)\n",
" features = torch.tanh(features)\n",
" scores = self.w_v(features).squeeze(-1)\n",
" #print(f\"Inside AdditiveAttention: value.shape = {value.shape}\") # 检查此处形状\n",
" self.attention_weights = masked_softmax(scores, valid_lens)\n",
"\n",
" return torch.bmm(self.dropout(self.attention_weights), value)"
],
"outputs": [],
2026-06-30 03:57:35 +00:00
"execution_count": 19
2026-06-28 14:39:18 +00:00
},
{
2026-06-30 03:57:35 +00:00
"cell_type": "code",
"id": "77e325a032e70d98",
2026-06-28 14:39:18 +00:00
"metadata": {
"ExecuteTime": {
2026-06-30 03:57:35 +00:00
"end_time": "2026-06-30T02:59:05.332247710Z",
"start_time": "2026-06-30T02:59:05.264012533Z"
2026-06-28 14:39:18 +00:00
}
},
"source": [
"queries,keys = torch.normal(0,1,(2,1,20)) , torch.ones((2,10,2))\n",
"values = torch.arange(40,dtype=torch.float32).reshape(1,10,4).repeat(2,1,1)\n",
"valid_lens = torch.tensor([2, 6])\n",
"attention = AdditiveAttention(key_size=2, query_size=20, num_hiddens=8,\n",
" dropout=0.1)\n",
"attention.eval()\n",
"attention(queries,keys,values,valid_lens)"
],
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[ 2.0000, 3.0000, 4.0000, 5.0000]],\n",
"\n",
" [[10.0000, 11.0000, 12.0000, 13.0000]]], grad_fn=<BmmBackward0>)"
]
},
2026-06-30 03:57:35 +00:00
"execution_count": 20,
2026-06-28 14:39:18 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-06-30 03:57:35 +00:00
"execution_count": 20
2026-06-28 14:39:18 +00:00
},
{
2026-06-30 03:57:35 +00:00
"cell_type": "code",
"id": "93ceedfd6e03982f",
2026-06-28 14:39:18 +00:00
"metadata": {
"ExecuteTime": {
2026-06-30 03:57:35 +00:00
"end_time": "2026-06-30T02:59:05.426180301Z",
"start_time": "2026-06-30T02:59:05.338806296Z"
2026-06-28 14:39:18 +00:00
}
},
"source": [
"d2l.show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),\n",
" xlabel='Keys', ylabel='Queries')"
],
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 250x250 with 2 Axes>"
],
2026-06-30 03:57:35 +00:00
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"187.07675pt\" height=\"103.438906pt\" viewBox=\"0 0 187.07675 103.438906\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n <metadata>\n <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n <cc:Work>\n <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n <dc:date>2026-06-30T10:59:05.393045</dc:date>\n <dc:format>image/svg+xml</dc:format>\n <dc:creator>\n <cc:Agent>\n <dc:title>Matplotlib v3.7.2, https://matplotlib.org/</dc:title>\n </cc:Agent>\n </dc:creator>\n </cc:Work>\n </rdf:RDF>\n </metadata>\n <defs>\n <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n </defs>\n <g id=\"figure_1\">\n <g id=\"patch_1\">\n <path d=\"M -0 103.438906 \nL 187.07675 103.438906 \nL 187.07675 0 \nL -0 0 \nz\n\"/>\n </g>\n <g id=\"axes_1\">\n <g id=\"patch_2\">\n <path d=\"M 34.240625 59.94 \nL 145.840625 59.94 \nL 145.840625 37.62 \nL 34.240625 37.62 \nz\n\"/>\n </g>\n <g clip-path=\"url(#pc674f34b08)\">\n <image xlink:href=\"data:image/png;base64,\niVBORw0KGgoAAAANSUhEUgAAAJsAAAAfCAYAAADwQL9CAAAAjUlEQVR4nO3aMRHCAAxA0Zau1YcFxmqoCzbMYYGNO+riZ+A9A8nwL1PW7+v8LaS2+zG9wojb9AL8D7GRERsZsZERGxmxkREbGbGRERsZsZERGxmxkREbGbGRERuZ9bHsY/9sz897ajQDXDYyYiMjNjJiIyM2MmIjIzYyYiMjNjJiIyM2MmIjIzYyYiNzAasBCAMqMzMRAAAAAElFTkSuQmCC\" id=\"image958b197a22\" transform=\"scale(1 -1) translate(0 -22.32)\" x=\"34.240625\" y=\"-37.62\" width=\"111.6\" height=\"22.32\"/>\n </g>\n <g id=\"matplotlib.axis_1\">\n <g id=\"xtick_1\">\n <g id=\"line2d_1\">\n <defs>\n <path id=\"m78a5423376\" d=\"M 0 0 \nL 0 3.5 \n\" style=\"stroke: #ffffff; stroke-width: 0.8\"/>\n </defs>\n <g>\n <use xlink:href=\"#m78a5423376\" 
x=\"39.820625\" y=\"59.94\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_1\">\n <!-- 0 -->\n <g style=\"fill: #ffffff\" transform=\"translate(36.639375 74.538438) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-30\" d=\"M 2034 4250 \nQ 1547 4250 1301 3770 \nQ 1056 3291 1056 2328 \nQ 1056 1369 1301 889 \nQ 1547 409 2034 409 \nQ 2525 409 2770 889 \nQ 3016 1369 3016 2328 \nQ 3016 3291 2770 3770 \nQ 2525 4250 2034 4250 \nz\nM 2034 4750 \nQ 2819 4750 3233 4129 \nQ 3647 3509 3647 2328 \nQ 3647 1150 3233 529 \nQ 2819 -91 2034 -91 \nQ 1250 -91 836 529 \nQ 422 1150 422 2328 \nQ 422 3509 836 4129 \nQ 1250 4750 2034 4750 \nz\n\" transform=\"scale(0.015625)\"/>\n </defs>\n <use xlink:href=\"#DejaVuSans-30\"/>\n </g>\n </g>\n </g>\n <g id=\"xtick_2\">\n <g id=\"line2d_2\">\n <g>\n <use xlink:href=\"#m78a5423376\" x=\"95.620625\" y=\"59.94\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_2\">\n <!-- 5 -->\n <g style=\"fill: #ffffff\" transform=\"translate(92.439375 74.538438) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-35\" d=\"M 691 4666 \nL 3169 4666 \nL 3169 4134 \nL 1269 4134 \nL 1269 2991 \nQ 1406 3038 1543 3061 \nQ 1681 3084 1819 3084 \nQ 2600 3084 3056 2656 \nQ 3513 2228 3513 1497 \nQ 3513 744 3044 326 \nQ 2575 -91 1722 -91 \nQ 1428 -91 1123 -41 \nQ 819 9 494 109 \nL 494 744 \nQ 775 591 1075 516 \nQ 1375 441 1709 441 \nQ 2250 441 2565 725 \nQ 2881 1009 2881 1497 \nQ 2881 1984 2565 2268 \nQ 2250 2553 1709 2553 \nQ 1456 2553 1204 2497 \nQ 953 2441 691 2322 \nL 691 4666 \nz\n\" transform=\"scale(0.015625)\"/>\n </defs>\n <use xlink:href=\"#DejaVuSans-35\"/>\n </g>\n </g>\n </g>\n <g id=\"text_3\">\n <!-- Keys -->\n <g style=\"fill: #ffffff\" transform=\"translate(78.371094 88.216563) scale(0.1 -0.1)\">\n <defs>\n <path id=\"
2026-06-28 14:39:18 +00:00
},
"metadata": {},
"output_type": "display_data",
"jetTransient": {
"display_id": null
}
}
],
2026-06-30 03:57:35 +00:00
"execution_count": 21
2026-06-28 14:39:18 +00:00
},
{
2026-06-30 03:57:35 +00:00
"cell_type": "code",
"id": "8453c76623a5b435",
2026-06-28 14:39:18 +00:00
"metadata": {
"ExecuteTime": {
2026-06-30 03:57:35 +00:00
"end_time": "2026-06-30T02:59:05.479115341Z",
"start_time": "2026-06-30T02:59:05.429811516Z"
2026-06-28 14:39:18 +00:00
}
},
"source": [
"import math\n",
"class DotProductAttention(nn.Module):\n",
" \"\"\"缩放点积注意力\"\"\"\n",
" def __init__(self, dropout, **kwargs):\n",
" super(DotProductAttention, self).__init__(**kwargs)\n",
" self.dropout = nn.Dropout(dropout)\n",
"# queries的形状(batch_size查询的个数d)\n",
"# keys的形状(batch_size“键值”对的个数d)\n",
"# values的形状(batch_size“键值”对的个数值的维度)\n",
"# valid_lens的形状:(batch_size)或者(batch_size查询的个数)\n",
" def forward(self, queries, keys, values, valid_lens=None):\n",
" d = queries.shape[-1]\n",
" # 设置transpose_b=True为了交换keys的最后两个维度\n",
" scores = torch.bmm(queries, keys.transpose(1,2)) / math.sqrt(d)\n",
" self.attention_weights = masked_softmax(scores, valid_lens)\n",
"\n",
" return torch.bmm(self.dropout(self.attention_weights), values)"
],
"outputs": [],
2026-06-30 03:57:35 +00:00
"execution_count": 22
2026-06-28 14:39:18 +00:00
},
{
2026-06-30 03:57:35 +00:00
"cell_type": "code",
"id": "8af6b28944977a62",
2026-06-28 14:39:18 +00:00
"metadata": {
"ExecuteTime": {
2026-06-30 03:57:35 +00:00
"end_time": "2026-06-30T02:59:05.547775309Z",
"start_time": "2026-06-30T02:59:05.480327680Z"
2026-06-28 14:39:18 +00:00
}
},
"source": [
"queries = torch.normal(0, 1, (2, 1, 2))\n",
"attention = DotProductAttention(dropout=0.5)\n",
"attention.eval()\n",
"attention(queries, keys, values, valid_lens)"
],
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[ 2.0000, 3.0000, 4.0000, 5.0000]],\n",
"\n",
" [[10.0000, 11.0000, 12.0000, 13.0000]]])"
]
},
2026-06-30 03:57:35 +00:00
"execution_count": 23,
2026-06-28 14:39:18 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-06-30 03:57:35 +00:00
"execution_count": 23
2026-06-28 14:39:18 +00:00
},
{
2026-06-30 03:57:35 +00:00
"cell_type": "code",
"id": "28f7e16771076e33",
2026-06-28 14:39:18 +00:00
"metadata": {
"ExecuteTime": {
2026-06-30 03:57:35 +00:00
"end_time": "2026-06-30T02:59:05.663459211Z",
"start_time": "2026-06-30T02:59:05.549703930Z"
2026-06-28 14:39:18 +00:00
}
},
"source": [
"d2l.show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),\n",
"xlabel='Keys', ylabel='Queries')"
],
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 250x250 with 2 Axes>"
],
2026-06-30 03:57:35 +00:00
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"187.07675pt\" height=\"103.438906pt\" viewBox=\"0 0 187.07675 103.438906\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n <metadata>\n <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n <cc:Work>\n <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n <dc:date>2026-06-30T10:59:05.624810</dc:date>\n <dc:format>image/svg+xml</dc:format>\n <dc:creator>\n <cc:Agent>\n <dc:title>Matplotlib v3.7.2, https://matplotlib.org/</dc:title>\n </cc:Agent>\n </dc:creator>\n </cc:Work>\n </rdf:RDF>\n </metadata>\n <defs>\n <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n </defs>\n <g id=\"figure_1\">\n <g id=\"patch_1\">\n <path d=\"M -0 103.438906 \nL 187.07675 103.438906 \nL 187.07675 0 \nL -0 0 \nz\n\"/>\n </g>\n <g id=\"axes_1\">\n <g id=\"patch_2\">\n <path d=\"M 34.240625 59.94 \nL 145.840625 59.94 \nL 145.840625 37.62 \nL 34.240625 37.62 \nz\n\"/>\n </g>\n <g clip-path=\"url(#p46e400e1fd)\">\n <image xlink:href=\"data:image/png;base64,\niVBORw0KGgoAAAANSUhEUgAAAJsAAAAfCAYAAADwQL9CAAAAjUlEQVR4nO3aMRHCAAxA0Zau1YcFxmqoCzbMYYGNO+riZ+A9A8nwL1PW7+v8LaS2+zG9wojb9AL8D7GRERsZsZERGxmxkREbGbGRERsZsZERGxmxkREbGbGRERuZ9bHsY/9sz897ajQDXDYyYiMjNjJiIyM2MmIjIzYyYiMjNjJiIyM2MmIjIzYyYiNzAasBCAMqMzMRAAAAAElFTkSuQmCC\" id=\"image7cf199ddd5\" transform=\"scale(1 -1) translate(0 -22.32)\" x=\"34.240625\" y=\"-37.62\" width=\"111.6\" height=\"22.32\"/>\n </g>\n <g id=\"matplotlib.axis_1\">\n <g id=\"xtick_1\">\n <g id=\"line2d_1\">\n <defs>\n <path id=\"md0e3694cd8\" d=\"M 0 0 \nL 0 3.5 \n\" style=\"stroke: #ffffff; stroke-width: 0.8\"/>\n </defs>\n <g>\n <use xlink:href=\"#md0e3694cd8\" 
x=\"39.820625\" y=\"59.94\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_1\">\n <!-- 0 -->\n <g style=\"fill: #ffffff\" transform=\"translate(36.639375 74.538438) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-30\" d=\"M 2034 4250 \nQ 1547 4250 1301 3770 \nQ 1056 3291 1056 2328 \nQ 1056 1369 1301 889 \nQ 1547 409 2034 409 \nQ 2525 409 2770 889 \nQ 3016 1369 3016 2328 \nQ 3016 3291 2770 3770 \nQ 2525 4250 2034 4250 \nz\nM 2034 4750 \nQ 2819 4750 3233 4129 \nQ 3647 3509 3647 2328 \nQ 3647 1150 3233 529 \nQ 2819 -91 2034 -91 \nQ 1250 -91 836 529 \nQ 422 1150 422 2328 \nQ 422 3509 836 4129 \nQ 1250 4750 2034 4750 \nz\n\" transform=\"scale(0.015625)\"/>\n </defs>\n <use xlink:href=\"#DejaVuSans-30\"/>\n </g>\n </g>\n </g>\n <g id=\"xtick_2\">\n <g id=\"line2d_2\">\n <g>\n <use xlink:href=\"#md0e3694cd8\" x=\"95.620625\" y=\"59.94\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_2\">\n <!-- 5 -->\n <g style=\"fill: #ffffff\" transform=\"translate(92.439375 74.538438) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-35\" d=\"M 691 4666 \nL 3169 4666 \nL 3169 4134 \nL 1269 4134 \nL 1269 2991 \nQ 1406 3038 1543 3061 \nQ 1681 3084 1819 3084 \nQ 2600 3084 3056 2656 \nQ 3513 2228 3513 1497 \nQ 3513 744 3044 326 \nQ 2575 -91 1722 -91 \nQ 1428 -91 1123 -41 \nQ 819 9 494 109 \nL 494 744 \nQ 775 591 1075 516 \nQ 1375 441 1709 441 \nQ 2250 441 2565 725 \nQ 2881 1009 2881 1497 \nQ 2881 1984 2565 2268 \nQ 2250 2553 1709 2553 \nQ 1456 2553 1204 2497 \nQ 953 2441 691 2322 \nL 691 4666 \nz\n\" transform=\"scale(0.015625)\"/>\n </defs>\n <use xlink:href=\"#DejaVuSans-35\"/>\n </g>\n </g>\n </g>\n <g id=\"text_3\">\n <!-- Keys -->\n <g style=\"fill: #ffffff\" transform=\"translate(78.371094 88.216563) scale(0.1 -0.1)\">\n <defs>\n <path id=\"
2026-06-28 14:39:18 +00:00
},
"metadata": {},
"output_type": "display_data",
"jetTransient": {
"display_id": null
}
}
],
2026-06-30 03:57:35 +00:00
"execution_count": 24
2026-06-28 14:39:18 +00:00
},
{
2026-06-30 03:57:35 +00:00
"cell_type": "code",
"id": "33507bd2e1917a29",
2026-06-28 14:39:18 +00:00
"metadata": {
"ExecuteTime": {
2026-06-30 03:57:35 +00:00
"end_time": "2026-06-30T02:59:05.730333551Z",
"start_time": "2026-06-30T02:59:05.680995441Z"
2026-06-28 14:39:18 +00:00
}
},
"source": [
"class AttentionDecoder(d2l.Decoder):\n",
" def __init__(self,**kwargs):\n",
" super(AttentionDecoder, self).__init__(**kwargs)\n",
" @property\n",
" def attention_weight(self):\n",
" raise NotImplementedError"
],
"outputs": [],
2026-06-30 03:57:35 +00:00
"execution_count": 25
2026-06-28 14:39:18 +00:00
},
{
2026-06-30 03:57:35 +00:00
"cell_type": "code",
"id": "a8a497c9041910a1",
2026-06-28 14:39:18 +00:00
"metadata": {
"ExecuteTime": {
2026-06-30 03:57:35 +00:00
"end_time": "2026-06-30T02:59:05.781440228Z",
"start_time": "2026-06-30T02:59:05.731688196Z"
2026-06-28 14:39:18 +00:00
}
},
"source": [
"class Seq2SeqAttentionDecoder(AttentionDecoder):\n",
" def __init__(self,vocab_size,embed_size,num_hiddens,num_layers,dropout=0,**kwargs):\n",
" super(Seq2SeqAttentionDecoder, self).__init__(**kwargs)\n",
" self.attention = AdditiveAttention(\n",
" num_hiddens, num_hiddens, num_hiddens,dropout)\n",
" self.embedding = nn.Embedding(vocab_size, embed_size)\n",
" self.rnn = nn.GRU(\n",
" embed_size + num_hiddens, num_hiddens, num_layers,\n",
" dropout=dropout)\n",
" self.dense = nn.Linear(num_hiddens, vocab_size)\n",
" def init_state(self, enc_outputs, enc_valid_lens, *args):\n",
"# outputs的形状为(batch_sizenum_stepsnum_hiddens).\n",
"# hidden_state的形状为(num_layersbatch_sizenum_hiddens)\n",
" outputs, hidden_state = enc_outputs\n",
" #print(f\"Encoder outputs shape before permute: {outputs.shape}\") # 应为 (num_steps, batch_size, num_hiddens) 或 (batch_size, num_steps, num_hiddens)\n",
" enc_outputs_permuted = outputs.permute(1, 0, 2)\n",
" #print(f\"After permute: {enc_outputs_permuted.shape}\") # 期望 (batch_size, num_steps, num_hiddens)\n",
" return (enc_outputs_permuted, hidden_state, enc_valid_lens)\n",
" def forward(self, X, state):\n",
" # enc_outputs的形状为(batch_size,num_steps,num_hiddens).\n",
" # hidden_state的形状为(num_layers,batch_size,\n",
" # num_hiddens)\n",
" enc_outputs, hidden_state, enc_valid_lens = state\n",
" # 输出X的形状为(num_steps,batch_size,embed_size)\n",
" X = self.embedding(X).permute(1, 0, 2)\n",
" outputs, self._attention_weights = [], []\n",
" for x in X:\n",
" # query的形状为(batch_size,1,num_hiddens)\n",
" query = torch.unsqueeze(hidden_state[-1], dim=1)\n",
" # context的形状为(batch_size,1,num_hiddens)\n",
" #print(f\"values shape before attention: {enc_outputs.shape}\") # 应为 (4, 7, 16)\n",
" context = self.attention(\n",
" query, enc_outputs, enc_outputs, enc_valid_lens)\n",
" # 在特征维度上连结\n",
" x = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=-1)\n",
" # 将x变形为(1,batch_size,embed_size+num_hiddens)\n",
" out, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)\n",
" outputs.append(out)\n",
" self._attention_weights.append(self.attention.attention_weights)\n",
" # 全连接层变换后outputs的形状为\n",
" # (num_steps,batch_size,vocab_size)\n",
" outputs = self.dense(torch.cat(outputs, dim=0))\n",
" return outputs.permute(1, 0, 2), [enc_outputs, hidden_state,\n",
" enc_valid_lens]\n",
" @property\n",
" def attention_weights(self):\n",
" return self._attention_weights"
],
"outputs": [],
2026-06-30 03:57:35 +00:00
"execution_count": 26
2026-06-28 14:39:18 +00:00
},
{
2026-06-30 03:57:35 +00:00
"cell_type": "code",
"id": "b2a6bc735743bd0f",
2026-06-28 14:39:18 +00:00
"metadata": {
"ExecuteTime": {
2026-06-30 03:57:35 +00:00
"end_time": "2026-06-30T02:59:05.842620719Z",
"start_time": "2026-06-30T02:59:05.782837732Z"
2026-06-28 14:39:18 +00:00
}
},
"source": [
"encoder = d2l.Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16,\n",
"num_layers=2)\n",
"encoder.eval()\n",
"decoder = Seq2SeqAttentionDecoder(vocab_size=10, embed_size=8, num_hiddens=16,\n",
"num_layers=2)\n",
"decoder.eval()\n",
"X = torch.zeros((4, 7), dtype=torch.long) # (batch_size,num_steps)\n",
"state = decoder.init_state(encoder(X), None)\n",
"output, state = decoder(X, state)\n",
"output.shape, len(state), state[0].shape, len(state[1]), state[1][0].shape"
],
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([4, 7, 10]), 3, torch.Size([4, 7, 16]), 2, torch.Size([4, 16]))"
]
},
2026-06-30 03:57:35 +00:00
"execution_count": 27,
2026-06-28 14:39:18 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-06-30 03:57:35 +00:00
"execution_count": 27
2026-06-28 14:39:18 +00:00
},
{
2026-06-30 03:57:35 +00:00
"cell_type": "code",
"id": "4c7bf86095f15c98",
2026-06-28 14:39:18 +00:00
"metadata": {
"ExecuteTime": {
2026-06-30 03:57:35 +00:00
"end_time": "2026-06-30T02:59:05.903307506Z",
"start_time": "2026-06-30T02:59:05.854726928Z"
2026-06-28 14:39:18 +00:00
}
},
2026-06-30 03:57:35 +00:00
"source": [],
2026-06-28 14:39:18 +00:00
"outputs": [],
2026-06-30 03:57:35 +00:00
"execution_count": 27
2026-06-28 14:39:18 +00:00
},
{
2026-06-30 03:57:35 +00:00
"cell_type": "code",
"id": "5f5cedf74b97bd12",
2026-06-28 14:39:18 +00:00
"metadata": {
"ExecuteTime": {
2026-06-30 03:57:35 +00:00
"end_time": "2026-06-30T03:22:39.836556973Z",
"start_time": "2026-06-30T02:59:05.904535279Z"
2026-06-28 14:39:18 +00:00
}
},
"source": [
"embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1\n",
"batch_size, num_steps = 64, 10\n",
"lr, num_epochs, device = 0.005, 250, d2l.try_gpu()\n",
"train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)\n",
"encoder = d2l.Seq2SeqEncoder(\n",
"len(src_vocab), embed_size, num_hiddens, num_layers, dropout)\n",
"decoder = Seq2SeqAttentionDecoder(\n",
"len(tgt_vocab), embed_size, num_hiddens, num_layers, dropout)\n",
"net = d2l.EncoderDecoder(encoder, decoder)\n",
"d2l.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)"
],
"outputs": [
{
2026-06-30 03:57:35 +00:00
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001B[31m---------------------------------------------------------------------------\u001B[39m",
"\u001B[31mKeyboardInterrupt\u001B[39m Traceback (most recent call last)",
"\u001B[36mCell\u001B[39m\u001B[36m \u001B[39m\u001B[32mIn[28]\u001B[39m\u001B[32m, line 10\u001B[39m\n\u001B[32m 7\u001B[39m decoder = Seq2SeqAttentionDecoder(\n\u001B[32m 8\u001B[39m \u001B[38;5;28mlen\u001B[39m(tgt_vocab), embed_size, num_hiddens, num_layers, dropout)\n\u001B[32m 9\u001B[39m net = d2l.EncoderDecoder(encoder, decoder)\n\u001B[32m---> \u001B[39m\u001B[32m10\u001B[39m \u001B[43md2l\u001B[49m\u001B[43m.\u001B[49m\u001B[43mtrain_seq2seq\u001B[49m\u001B[43m(\u001B[49m\u001B[43mnet\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mtrain_iter\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mlr\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mnum_epochs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mtgt_vocab\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mdevice\u001B[49m\u001B[43m)\u001B[49m\n",
"\u001B[36mFile \u001B[39m\u001B[32m~/.conda/envs/nn/lib/python3.11/site-packages/d2l/torch.py:3421\u001B[39m, in \u001B[36mtrain_seq2seq\u001B[39m\u001B[34m(net, data_iter, lr, num_epochs, tgt_vocab, device)\u001B[39m\n\u001B[32m 3418\u001B[39m bos = torch.tensor([tgt_vocab[\u001B[33m'\u001B[39m\u001B[33m<bos>\u001B[39m\u001B[33m'\u001B[39m]] * Y.shape[\u001B[32m0\u001B[39m],\n\u001B[32m 3419\u001B[39m device=device).reshape(-\u001B[32m1\u001B[39m, \u001B[32m1\u001B[39m)\n\u001B[32m 3420\u001B[39m dec_input = d2l.concat([bos, Y[:, :-\u001B[32m1\u001B[39m]], \u001B[32m1\u001B[39m) \u001B[38;5;66;03m# Teacher forcing\u001B[39;00m\n\u001B[32m-> \u001B[39m\u001B[32m3421\u001B[39m Y_hat, _ = \u001B[43mnet\u001B[49m\u001B[43m(\u001B[49m\u001B[43mX\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mdec_input\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mX_valid_len\u001B[49m\u001B[43m)\u001B[49m\n\u001B[32m 3422\u001B[39m l = loss(Y_hat, Y, Y_valid_len)\n\u001B[32m 3423\u001B[39m l.sum().backward() \u001B[38;5;66;03m# Make the loss scalar for `backward`\u001B[39;00m\n",
"\u001B[36mFile \u001B[39m\u001B[32m~/.conda/envs/nn/lib/python3.11/site-packages/torch/nn/modules/module.py:1776\u001B[39m, in \u001B[36mModule._wrapped_call_impl\u001B[39m\u001B[34m(self, *args, **kwargs)\u001B[39m\n\u001B[32m 1774\u001B[39m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mself\u001B[39m._compiled_call_impl(*args, **kwargs) \u001B[38;5;66;03m# type: ignore[misc]\u001B[39;00m\n\u001B[32m 1775\u001B[39m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[32m-> \u001B[39m\u001B[32m1776\u001B[39m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43mself\u001B[39;49m\u001B[43m.\u001B[49m\u001B[43m_call_impl\u001B[49m\u001B[43m(\u001B[49m\u001B[43m*\u001B[49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43m*\u001B[49m\u001B[43m*\u001B[49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n",
"\u001B[36mFile \u001B[39m\u001B[32m~/.conda/envs/nn/lib/python3.11/site-packages/torch/nn/modules/module.py:1787\u001B[39m, in \u001B[36mModule._call_impl\u001B[39m\u001B[34m(self, *args, **kwargs)\u001B[39m\n\u001B[32m 1782\u001B[39m \u001B[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001B[39;00m\n\u001B[32m 1783\u001B[39m \u001B[38;5;66;03m# this function, and just call forward.\u001B[39;00m\n\u001B[32m 1784\u001B[39m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m (\u001B[38;5;28mself\u001B[39m._backward_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m._backward_pre_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m._forward_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m._forward_pre_hooks\n\u001B[32m 1785\u001B[39m \u001B[38;5;129;01mor\u001B[39;00m _global_backward_pre_hooks \u001B[38;5;129;01mor\u001B[39;00m _global_backward_hooks\n\u001B[32m 1786\u001B[39m \u001B[38;5;129;01mor\u001B[39;00m _global_forward_hooks \u001B[38;5;129;01mor\u001B[39;00m _global_forward_pre_hooks):\n\u001B[32m-> \u001B[39m\u001B[32m1787\u001B[39m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mforward_call\u001B[49m\u001B[43m(\u001B[49m\u001B[43m*\u001B[49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43m*\u001B[49m\u001B[43m*\u001B[49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[32m 1789\u001B[39m result = \u001B[38;5;28;01mNone\u001B[39;00m\n\u001B[32m 1790\u001B[39m called_always_called_hooks = \u001B[38;5;28mset\u001B[39m()\n",
"\u001B[36mFile \u001B[39m\u001B[32m~/.conda/envs/nn/lib/python3.11/site-packages/d2l/torch.py:964\u001B[39m, in \u001B[36mEncoderDecoder.forward\u001B[39m\u001B[34m(self, enc_X, dec_X, *args)\u001B[39m\n\u001B[32m 962\u001B[39m dec_state = \u001B[38;5;28mself\u001B[39m.decoder.init_state(enc_all_outputs, *args)\n\u001B[32m 963\u001B[39m \u001B[38;5;66;03m# Return decoder output only\u001B[39;00m\n\u001B[32m--> \u001B[39m\u001B[32m964\u001B[39m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43mself\u001B[39;49m\u001B[43m.\u001B[49m\u001B[43mdecoder\u001B[49m\u001B[43m(\u001B[49m\u001B[43mdec_X\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mdec_state\u001B[49m\u001B[43m)\u001B[49m\n",
"\u001B[36mFile \u001B[39m\u001B[32m~/.conda/envs/nn/lib/python3.11/site-packages/torch/nn/modules/module.py:1776\u001B[39m, in \u001B[36mModule._wrapped_call_impl\u001B[39m\u001B[34m(self, *args, **kwargs)\u001B[39m\n\u001B[32m 1774\u001B[39m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mself\u001B[39m._compiled_call_impl(*args, **kwargs) \u001B[38;5;66;03m# type: ignore[misc]\u001B[39;00m\n\u001B[32m 1775\u001B[39m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[32m-> \u001B[39m\u001B[32m1776\u001B[39m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43mself\u001B[39;49m\u001B[43m.\u001B[49m\u001B[43m_call_impl\u001B[49m\u001B[43m(\u001B[49m\u001B[43m*\u001B[49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43m*\u001B[49m\u001B[43m*\u001B[49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n",
"\u001B[36mFile \u001B[39m\u001B[32m~/.conda/envs/nn/lib/python3.11/site-packages/torch/nn/modules/module.py:1787\u001B[39m, in \u001B[36mModule._call_impl\u001B[39m\u001B[34m(self, *args, **kwargs)\u001B[39m\n\u001B[32m 1782\u001B[39m \u001B[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001B[39;00m\n\u001B[32m 1783\u001B[39m \u001B[38;5;66;03m# this function, and just call forward.\u001B[39;00m\n\u001B[32m 1784\u001B[39m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m (\u001B[38;5;28mself\u001B[39m._backward_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m._backward_pre_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m._forward_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m._forward_pre_hooks\n\u001B[32m 1785\u001B[39m \u001B[38;5;129;01mor\u001B[39;00m _global_backward_pre_hooks \u001B[38;5;129;01mor\u001B[39;00m _global_backward_hooks\n\u001B[32m 1786\u001B[39m \u001B[38;5;129;01mor\u001B[39;00m _global_forward_hooks \u001B[38;5;129;01mor\u001B[39;00m _global_forward_pre_hooks):\n\u001B[32m-> \u001B[39m\u001B[32m1787\u001B[39m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mforward_call\u001B[49m\u001B[43m(\u001B[49m\u001B[43m*\u001B[49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43m*\u001B[49m\u001B[43m*\u001B[49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[32m 1789\u001B[39m result = \u001B[38;5;28;01mNone\u001B[39;00m\n\u001B[32m 1790\u001B[39m called_always_called_hooks = \u001B[38;5;28mset\u001B[39m()\n",
"\u001B[36mCell\u001B[39m\u001B[36m \u001B[39m\u001B[32mIn[26]\u001B[39m\u001B[32m, line 37\u001B[39m, in \u001B[36mSeq2SeqAttentionDecoder.forward\u001B[39m\u001B[34m(self, X, state)\u001B[39m\n\u001B[32m 35\u001B[39m x = torch.cat((context, torch.unsqueeze(x, dim=\u001B[32m1\u001B[39m)), dim=-\u001B[32m1\u001B[39m)\n\u001B[32m 36\u001B[39m \u001B[38;5;66;03m# 将x变形为(1,batch_size,embed_size+num_hiddens)\u001B[39;00m\n\u001B[32m---> \u001B[39m\u001B[32m37\u001B[39m out, hidden_state = \u001B[38;5;28;43mself\u001B[39;49m\u001B[43m.\u001B[49m\u001B[43mrnn\u001B[49m\u001B[43m(\u001B[49m\u001B[43mx\u001B[49m\u001B[43m.\u001B[49m\u001B[43mpermute\u001B[49m\u001B[43m(\u001B[49m\u001B[32;43m1\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[32;43m0\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[32;43m2\u001B[39;49m\u001B[43m)\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mhidden_state\u001B[49m\u001B[43m)\u001B[49m\n\u001B[32m 38\u001B[39m outputs.append(out)\n\u001B[32m 39\u001B[39m \u001B[38;5;28mself\u001B[39m._attention_weights.append(\u001B[38;5;28mself\u001B[39m.attention.attention_weights)\n",
"\u001B[36mFile \u001B[39m\u001B[32m~/.conda/envs/nn/lib/python3.11/site-packages/torch/nn/modules/module.py:1776\u001B[39m, in \u001B[36mModule._wrapped_call_impl\u001B[39m\u001B[34m(self, *args, **kwargs)\u001B[39m\n\u001B[32m 1774\u001B[39m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mself\u001B[39m._compiled_call_impl(*args, **kwargs) \u001B[38;5;66;03m# type: ignore[misc]\u001B[39;00m\n\u001B[32m 1775\u001B[39m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[32m-> \u001B[39m\u001B[32m1776\u001B[39m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43mself\u001B[39;49m\u001B[43m.\u001B[49m\u001B[43m_call_impl\u001B[49m\u001B[43m(\u001B[49m\u001B[43m*\u001B[49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43m*\u001B[49m\u001B[43m*\u001B[49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n",
"\u001B[36mFile \u001B[39m\u001B[32m~/.conda/envs/nn/lib/python3.11/site-packages/torch/nn/modules/module.py:1787\u001B[39m, in \u001B[36mModule._call_impl\u001B[39m\u001B[34m(self, *args, **kwargs)\u001B[39m\n\u001B[32m 1782\u001B[39m \u001B[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001B[39;00m\n\u001B[32m 1783\u001B[39m \u001B[38;5;66;03m# this function, and just call forward.\u001B[39;00m\n\u001B[32m 1784\u001B[39m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m (\u001B[38;5;28mself\u001B[39m._backward_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m._backward_pre_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m._forward_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m._forward_pre_hooks\n\u001B[32m 1785\u001B[39m \u001B[38;5;129;01mor\u001B[39;00m _global_backward_pre_hooks \u001B[38;5;129;01mor\u001B[39;00m _global_backward_hooks\n\u001B[32m 1786\u001B[39m \u001B[38;5;129;01mor\u001B[39;00m _global_forward_hooks \u001B[38;5;129;01mor\u001B[39;00m _global_forward_pre_hooks):\n\u001B[32m-> \u001B[39m\u001B[32m1787\u001B[39m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mforward_call\u001B[49m\u001B[43m(\u001B[49m\u001B[43m*\u001B[49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43m*\u001B[49m\u001B[43m*\u001B[49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[32m 1789\u001B[39m result = \u001B[38;5;28;01mNone\u001B[39;00m\n\u001B[32m 1790\u001B[39m called_always_called_hooks = \u001B[38;5;28mset\u001B[39m()\n",
"\u001B[36mFile \u001B[39m\u001B[32m~/.conda/envs/nn/lib/python3.11/site-packages/torch/nn/modules/rnn.py:1415\u001B[39m, in \u001B[36mGRU.forward\u001B[39m\u001B[34m(self, input, hx)\u001B[39m\n\u001B[32m 1413\u001B[39m \u001B[38;5;28mself\u001B[39m.check_forward_args(\u001B[38;5;28minput\u001B[39m, hx, batch_sizes)\n\u001B[32m 1414\u001B[39m \u001B[38;5;28;01mif\u001B[39;00m batch_sizes \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m:\n\u001B[32m-> \u001B[39m\u001B[32m1415\u001B[39m result = \u001B[43m_VF\u001B[49m\u001B[43m.\u001B[49m\u001B[43mgru\u001B[49m\u001B[43m(\u001B[49m\n\u001B[32m 1416\u001B[39m \u001B[43m \u001B[49m\u001B[38;5;28;43minput\u001B[39;49m\u001B[43m,\u001B[49m\n\u001B[32m 1417\u001B[39m \u001B[43m \u001B[49m\u001B[43mhx\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 1418\u001B[39m \u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[43m.\u001B[49m\u001B[43m_flat_weights\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;66;43;03m# type: ignore[arg-type]\u001B[39;49;00m\n\u001B[32m 1419\u001B[39m \u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[43m.\u001B[49m\u001B[43mbias\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 1420\u001B[39m \u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[43m.\u001B[49m\u001B[43mnum_layers\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 1421\u001B[39m \u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[43m.\u001B[49m\u001B[43mdropout\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 1422\u001B[39m \u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[43m.\u001B[49m\u001B[43mtraining\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 1423\u001B[39m \u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[43m.\u001B[49m\u001B[43mbidirectional\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 1424\u001B[39m \u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[43m.\u001B[49m\u001B[43mbatch_first\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 
1425\u001B[39m \u001B[43m \u001B[49m\u001B[43m)\u001B[49m\n\u001B[32m 1426\u001B[39m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[32m 1427\u001B[39m result = _VF.gru(\n\u001B[32m 1428\u001B[39m \u001B[38;5;28minput\u001B[39m,\n\u001B[32m 1429\u001B[39m batch_sizes,\n\u001B[32m (...)\u001B[39m\u001B[32m 1436\u001B[39m \u001B[38;5;28mself\u001B[39m.bidirectional,\n\u001B[32m 1437\u001B[39m )\n",
"\u001B[31mKeyboardInterrupt\u001B[39m: "
2026-06-28 14:39:18 +00:00
]
},
{
"data": {
"text/plain": [
"<Figure size 350x250 with 1 Axes>"
],
2026-06-30 03:57:35 +00:00
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"262.1875pt\" height=\"183.35625pt\" viewBox=\"0 0 262.1875 183.35625\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n <metadata>\n <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n <cc:Work>\n <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n <dc:date>2026-06-30T11:22:39.515794</dc:date>\n <dc:format>image/svg+xml</dc:format>\n <dc:creator>\n <cc:Agent>\n <dc:title>Matplotlib v3.7.2, https://matplotlib.org/</dc:title>\n </cc:Agent>\n </dc:creator>\n </cc:Work>\n </rdf:RDF>\n </metadata>\n <defs>\n <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n </defs>\n <g id=\"figure_1\">\n <g id=\"patch_1\">\n <path d=\"M 0 183.35625 \nL 262.1875 183.35625 \nL 262.1875 0 \nL 0 0 \nz\n\"/>\n </g>\n <g id=\"axes_1\">\n <g id=\"patch_2\">\n <path d=\"M 50.14375 145.8 \nL 245.44375 145.8 \nL 245.44375 7.2 \nL 50.14375 7.2 \nz\n\"/>\n </g>\n <g id=\"matplotlib.axis_1\">\n <g id=\"xtick_1\">\n <g id=\"line2d_1\">\n <path d=\"M 82.69375 145.8 \nL 82.69375 7.2 \n\" clip-path=\"url(#p5f11d43d4e)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_2\">\n <defs>\n <path id=\"ma58d5e1ace\" d=\"M 0 0 \nL 0 3.5 \n\" style=\"stroke: #ffffff; stroke-width: 0.8\"/>\n </defs>\n <g>\n <use xlink:href=\"#ma58d5e1ace\" x=\"82.69375\" y=\"145.8\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_1\">\n <!-- 50 -->\n <g style=\"fill: #ffffff\" transform=\"translate(76.33125 160.398438) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-35\" d=\"M 691 4666 \nL 3169 4666 \nL 3169 4134 \nL 1269 
4134 \nL 1269 2991 \nQ 1406 3038 1543 3061 \nQ 1681 3084 1819 3084 \nQ 2600 3084 3056 2656 \nQ 3513 2228 3513 1497 \nQ 3513 744 3044 326 \nQ 2575 -91 1722 -91 \nQ 1428 -91 1123 -41 \nQ 819 9 494 109 \nL 494 744 \nQ 775 591 1075 516 \nQ 1375 441 1709 441 \nQ 2250 441 2565 725 \nQ 2881 1009 2881 1497 \nQ 2881 1984 2565 2268 \nQ 2250 2553 1709 2553 \nQ 1456 2553 1204 2497 \nQ 953 2441 691 2322 \nL 691 4666 \nz\n\" transform=\"scale(0.015625)\"/>\n <path id=\"DejaVuSans-30\" d=\"M 2034 4250 \nQ 1547 4250 1301 3770 \nQ 1056 3291 1056 2328 \nQ 1056 1369 1301 889 \nQ 1547 409 2034 409 \nQ 2525 409 2770 889 \nQ 3016 1369 3016 2328 \nQ 3016 3291 2770 3770 \nQ 2525 4250 2034 4250 \nz\nM 2034 4750 \nQ 2819 4750 3233 4129 \nQ 3647 3509 3647 2328 \nQ 3647 1150 3233 529 \nQ 2819 -91 2034 -91 \nQ 1250 -91 836 529 \nQ 422 1150 422 2328 \nQ 422 3509 836 4129 \nQ 1250 4750 2034 4750 \nz\n\" transform=\"scale(0.015625)\"/>\n </defs>\n <use xlink:href=\"#DejaVuSans-35\"/>\n <use xlink:href=\"#DejaVuSans-30\" x=\"63.623047\"/>\n </g>\n </g>\n </g>\n <g id=\"xtick_2\">\n <g id=\"line2d_3\">\n <path d=\"M 123.38125 145.8 \nL 123.38125 7.2 \n\" clip-path=\"url(#p5f11d43d4e)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_4\">\n <g>\n <use xlink:href=\"#ma58d5e1ace\" x=\"123.38125\" y=\"145.8\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_2\">\n <!-- 100 -->\n <g style=\"fill: #ffffff\" transform=\"translate(113.8375 160.398438) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-31\" d=\"M 794 531 \nL 1825 531 \nL 1825 4091 \nL 703 3866 \nL 703 4441 \nL 1819 4666 \nL 2450 4666 \nL 2450 531 \nL 3481 531 \nL 3481 0 \nL 794 0 \nL 794 531 \nz\n\" transform=\"scale(0.015625)\"/>\n </defs>\n <use xlink:href=\"#DejaVuSans-31\"/>\n <
2026-06-28 14:39:18 +00:00
},
"metadata": {},
"output_type": "display_data",
"jetTransient": {
"display_id": null
}
}
],
2026-06-30 03:57:35 +00:00
"execution_count": 28
2026-06-28 14:39:18 +00:00
},
{
2026-06-30 03:57:35 +00:00
"cell_type": "code",
"id": "ddaed8bd284ee01b",
2026-06-28 14:39:18 +00:00
"metadata": {
"ExecuteTime": {
2026-06-30 03:57:35 +00:00
"end_time": "2026-06-30T03:22:50.427691735Z",
"start_time": "2026-06-30T03:22:50.314718050Z"
2026-06-28 14:39:18 +00:00
}
},
"source": [
"engs = ['go .', \"i lost .\", 'he\\'s calm .', 'i\\'m home .']\n",
"fras = ['va !', 'j\\'ai perdu .', 'il est calme .', 'je suis chez moi .']\n",
"for eng, fra in zip(engs, fras):\n",
" translation, dec_attention_weight_seq = d2l.predict_seq2seq(\n",
" net, eng, src_vocab, tgt_vocab, num_steps, device, True)\n",
" print(f'{eng} => {translation}, ',\n",
" f'bleu {d2l.bleu(translation, fra, k=2):.3f}')"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"go . => va !, bleu 1.000\n",
"i lost . => j'ai perdu ., bleu 1.000\n",
2026-06-30 03:57:35 +00:00
"he's calm . => il est riche ., bleu 0.658\n",
"i'm home . => je suis calme ., bleu 0.512\n"
2026-06-28 14:39:18 +00:00
]
}
],
2026-06-30 03:57:35 +00:00
"execution_count": 29
2026-06-28 14:39:18 +00:00
},
{
2026-06-30 03:57:35 +00:00
"cell_type": "code",
"id": "13d19f8b5048c74d",
2026-06-28 14:39:18 +00:00
"metadata": {
"ExecuteTime": {
2026-06-30 03:57:35 +00:00
"end_time": "2026-06-30T03:22:57.445168273Z",
"start_time": "2026-06-30T03:22:57.376381899Z"
2026-06-28 14:39:18 +00:00
}
},
"source": [
"class MultiHeadAttention(nn.Module):\n",
" def __init__(self, key_size, query_size, value_size, num_hiddens,\n",
" num_heads, dropout, bias=False, **kwargs):\n",
" super(MultiHeadAttention, self).__init__(**kwargs)\n",
" self.num_heads = num_heads\n",
" self.attention = d2l.DotProductAttention(dropout)\n",
" self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)\n",
" self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)\n",
" self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)\n",
" self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)\n",
" def forward(self, queries, keys, values, valid_lens):\n",
" # 1. 线性投影 + 变换形状以分割多头\n",
" queries = transpose_qkv(self.W_q(queries), self.num_heads)\n",
" keys = transpose_qkv(self.W_k(keys), self.num_heads)\n",
" values = transpose_qkv(self.W_v(values), self.num_heads)\n",
"\n",
" # 2. 处理有效长度掩码valid_lens以适配多头\n",
" if valid_lens is not None:\n",
" valid_lens = torch.repeat_interleave(valid_lens, repeats=self.num_heads, dim=0)\n",
"\n",
" # 3. 计算注意力(每个头独立计算)\n",
" output = self.attention(queries, keys, values, valid_lens)\n",
"\n",
" # 4. 合并多头,并通过输出线性层\n",
" output_concat = transpose_output(output, self.num_heads)\n",
" return self.W_o(output_concat)\n",
"def transpose_qkv(X, num_heads):\n",
" \"\"\"为了多注意力头的并行计算而变换形状\"\"\"\n",
" # 输入X的形状:(batch_size查询或者“键值”对的个数num_hiddens)\n",
" # 输出X的形状:(batch_size查询或者“键值”对的个数num_heads\n",
" # num_hiddens/num_heads)\n",
" X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)\n",
" # 输出X的形状:(batch_sizenum_heads查询或者“键值”对的个数,\n",
" # num_hiddens/num_heads)\n",
" X = X.permute(0, 2, 1, 3)\n",
" # 最终输出的形状:(batch_size*num_heads,查询或者“键-值”对的个数,\n",
" # num_hiddens/num_heads)\n",
" return X.reshape(-1, X.shape[2], X.shape[3])\n",
"\n",
"def transpose_output(X, num_heads):\n",
" \"\"\"逆转transpose_qkv函数的操作\"\"\"\n",
" X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])\n",
" X = X.permute(0, 2, 1, 3)\n",
" return X.reshape(X.shape[0], X.shape[1], -1)\n"
],
"outputs": [],
2026-06-30 03:57:35 +00:00
"execution_count": 30
2026-06-28 14:39:18 +00:00
},
{
2026-06-30 03:57:35 +00:00
"cell_type": "code",
"id": "b25c661a511d7763",
2026-06-28 14:39:18 +00:00
"metadata": {
"ExecuteTime": {
2026-06-30 03:57:35 +00:00
"end_time": "2026-06-30T03:22:58.271805155Z",
"start_time": "2026-06-30T03:22:58.200822898Z"
2026-06-28 14:39:18 +00:00
}
},
"source": [
"num_hiddens, num_heads = 100, 5\n",
"attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,\n",
"num_hiddens, num_heads, 0.5)\n",
"attention.eval()"
],
"outputs": [
{
"data": {
"text/plain": [
"MultiHeadAttention(\n",
" (attention): DotProductAttention(\n",
" (dropout): Dropout(p=0.5, inplace=False)\n",
" )\n",
" (W_q): Linear(in_features=100, out_features=100, bias=False)\n",
" (W_k): Linear(in_features=100, out_features=100, bias=False)\n",
" (W_v): Linear(in_features=100, out_features=100, bias=False)\n",
" (W_o): Linear(in_features=100, out_features=100, bias=False)\n",
")"
]
},
2026-06-30 03:57:35 +00:00
"execution_count": 31,
2026-06-28 14:39:18 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-06-30 03:57:35 +00:00
"execution_count": 31
2026-06-28 14:39:18 +00:00
},
{
2026-06-30 03:57:35 +00:00
"cell_type": "code",
"id": "49877c929741ec94",
2026-06-28 14:39:18 +00:00
"metadata": {
"ExecuteTime": {
2026-06-30 03:57:35 +00:00
"end_time": "2026-06-30T03:22:58.957157613Z",
"start_time": "2026-06-30T03:22:58.850698625Z"
2026-06-28 14:39:18 +00:00
}
},
"source": [
"batch_size, num_queries = 2, 4\n",
"num_kvpairs, valid_lens = 6, torch.tensor([3, 2])\n",
"X = torch.ones((batch_size, num_queries, num_hiddens))\n",
"Y = torch.ones((batch_size, num_kvpairs, num_hiddens))\n",
"attention(X, Y, Y, valid_lens).shape"
],
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([2, 4, 100])"
]
},
2026-06-30 03:57:35 +00:00
"execution_count": 32,
2026-06-28 14:39:18 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2026-06-30 03:57:35 +00:00
"execution_count": 32
},
{
"cell_type": "code",
"id": "b8a10576c7789750",
"metadata": {
"ExecuteTime": {
"end_time": "2026-06-30T03:24:09.856237679Z",
"start_time": "2026-06-30T03:24:09.793761501Z"
}
},
"source": [
"class PositionalEncoding(nn.Module):\n",
" def __init__(self,num_hiddens,dropout,max_len=1000):\n",
" super().__init__()\n",
" self.dropout = nn.Dropout(dropout)\n",
" self.dropout = nn.Dropout(dropout)\n",
" X = torch.arange(max_len, dtype=torch.float32).reshape(\n",
" -1, 1) / torch.pow(10000, torch.arange(\n",
" 0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)\n",
" self.P = torch.zeros((1,max_len,num_hiddens))\n",
" self.P[:,:,0::2] = torch.sin(X)\n",
" self.P[:,:,1::2] = torch.cos(X)\n",
" def forward(self,X):\n",
" X = X + self.P[:,:X.shape[1],:].to(X.device)\n",
" return self.dropout(X)"
],
"outputs": [],
"execution_count": 35
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-06-30T03:24:10.924435124Z",
"start_time": "2026-06-30T03:24:10.759631840Z"
}
},
"cell_type": "code",
"source": [
"encoding_dim, num_steps = 32, 60\n",
"pos_encoding = PositionalEncoding(encoding_dim, 0)\n",
"pos_encoding.eval()\n",
"X = pos_encoding(torch.zeros((1, num_steps, encoding_dim)))\n",
"P = pos_encoding.P[:, :X.shape[1], :]\n",
"d2l.plot(torch.arange(num_steps), P[0, :, 6:10].T, xlabel='Row (position)',\n",
" figsize=(6, 2.5), legend=[\"Col %d\" % d for d in torch.arange(6, 10)])"
],
"id": "fe6cd8d94205977f",
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 600x250 with 1 Axes>"
],
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"380.482812pt\" height=\"183.35625pt\" viewBox=\"0 0 380.482812 183.35625\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n <metadata>\n <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n <cc:Work>\n <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n <dc:date>2026-06-30T11:24:10.850286</dc:date>\n <dc:format>image/svg+xml</dc:format>\n <dc:creator>\n <cc:Agent>\n <dc:title>Matplotlib v3.7.2, https://matplotlib.org/</dc:title>\n </cc:Agent>\n </dc:creator>\n </cc:Work>\n </rdf:RDF>\n </metadata>\n <defs>\n <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n </defs>\n <g id=\"figure_1\">\n <g id=\"patch_1\">\n <path d=\"M -0 183.35625 \nL 380.482812 183.35625 \nL 380.482812 0 \nL -0 0 \nz\n\"/>\n </g>\n <g id=\"axes_1\">\n <g id=\"patch_2\">\n <path d=\"M 38.482813 145.8 \nL 373.282813 145.8 \nL 373.282813 7.2 \nL 38.482813 7.2 \nz\n\"/>\n </g>\n <g id=\"matplotlib.axis_1\">\n <g id=\"xtick_1\">\n <g id=\"line2d_1\">\n <path d=\"M 53.700994 145.8 \nL 53.700994 7.2 \n\" clip-path=\"url(#pf6cc2d1b0c)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_2\">\n <defs>\n <path id=\"mece1b97dac\" d=\"M 0 0 \nL 0 3.5 \n\" style=\"stroke: #ffffff; stroke-width: 0.8\"/>\n </defs>\n <g>\n <use xlink:href=\"#mece1b97dac\" x=\"53.700994\" y=\"145.8\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_1\">\n <!-- 0 -->\n <g style=\"fill: #ffffff\" transform=\"translate(50.519744 160.398438) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-30\" d=\"M 2034 4250 \nQ 1547 4250 1301 
3770 \nQ 1056 3291 1056 2328 \nQ 1056 1369 1301 889 \nQ 1547 409 2034 409 \nQ 2525 409 2770 889 \nQ 3016 1369 3016 2328 \nQ 3016 3291 2770 3770 \nQ 2525 4250 2034 4250 \nz\nM 2034 4750 \nQ 2819 4750 3233 4129 \nQ 3647 3509 3647 2328 \nQ 3647 1150 3233 529 \nQ 2819 -91 2034 -91 \nQ 1250 -91 836 529 \nQ 422 1150 422 2328 \nQ 422 3509 836 4129 \nQ 1250 4750 2034 4750 \nz\n\" transform=\"scale(0.015625)\"/>\n </defs>\n <use xlink:href=\"#DejaVuSans-30\"/>\n </g>\n </g>\n </g>\n <g id=\"xtick_2\">\n <g id=\"line2d_3\">\n <path d=\"M 105.288051 145.8 \nL 105.288051 7.2 \n\" clip-path=\"url(#pf6cc2d1b0c)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_4\">\n <g>\n <use xlink:href=\"#mece1b97dac\" x=\"105.288051\" y=\"145.8\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_2\">\n <!-- 10 -->\n <g style=\"fill: #ffffff\" transform=\"translate(98.925551 160.398438) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-31\" d=\"M 794 531 \nL 1825 531 \nL 1825 4091 \nL 703 3866 \nL 703 4441 \nL 1819 4666 \nL 2450 4666 \nL 2450 531 \nL 3481 531 \nL 3481 0 \nL 794 0 \nL 794 531 \nz\n\" transform=\"scale(0.015625)\"/>\n </defs>\n <use xlink:href=\"#DejaVuSans-31\"/>\n <use xlink:href=\"#DejaVuSans-30\" x=\"63.623047\"/>\n </g>\n </g>\n </g>\n <g id=\"xtick_3\">\n <g id=\"line2d_5\">\n <path d=\"M 156.875108 145.8 \nL 156.875108 7.2 \n\" clip-path=\"url(#pf6cc2d1b0c)\" style=\"fill: none; stroke: #ffffff; stroke-width: 0.8; stroke-linecap: square\"/>\n </g>\n <g id=\"line2d_6\">\n <g>\n <use xlink:href=\"#mece1b97dac\" x=\"156.875108\" y=\"145.8\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_3\">\n <!-- 20 -->\n <g style=\"fi
},
"metadata": {},
"output_type": "display_data",
"jetTransient": {
"display_id": null
}
}
],
"execution_count": 36
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-06-30T03:29:41.868213373Z",
"start_time": "2026-06-30T03:29:41.818146905Z"
}
},
"cell_type": "code",
"source": [
"class PositionWiseFFN(nn.Module):\n",
" def __init__(self,ffn_num_input,ffn_num_hiddens,ffn_num_outputs,**kwargs):\n",
" super().__init__(**kwargs)\n",
" self.dense1 = nn.Linear(ffn_num_input,ffn_num_hiddens)\n",
" self.relu = nn.ReLU()\n",
" self.dense2 = nn.Linear(ffn_num_hiddens,ffn_num_outputs)\n",
" def forward(self,X):\n",
" return self.dense2(self.relu(self.dense1(X)))\n"
],
"id": "b764f19d408049b",
"outputs": [],
"execution_count": 38
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-06-30T03:35:03.713494687Z",
"start_time": "2026-06-30T03:35:03.630606833Z"
}
},
"cell_type": "code",
"source": [
"class AddNorm(nn.Module):\n",
" def __init__(self,normalized_shape,dropout,**kwargs):\n",
" super().__init__()\n",
" self.dropout=nn.Dropout()\n",
" self.ln = nn.LayerNorm(normalized_shape)\n",
" def forward(self,X,Y):\n",
" return self.ln(self.dropout(Y)+X)\n"
],
"id": "377f4d9eea428068",
"outputs": [],
"execution_count": 39
2026-06-28 14:39:18 +00:00
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
2026-06-30 03:57:35 +00:00
"source": [
"class EncoderBlock(nn.Moudle):\n",
" def __init(self,key_size,query_size,value_size,num_hiddens,norm_shape,ffn_num_input,ffn_num_hiddens,num_heads,dropout,use_bias=False,**kwargs):\n",
" super(EncoderBlock,self).__init__(**kwargs)\n",
" self.attention = MultiHeadAttention(key_size,query_size,value_size,num_hiddens,num_heads,dropout,bias=use_bias)\n",
" self.addnorm1 = AddNorm(norm_shape,dropout)\n",
" self.ffn = PositionWiseFFN(\n",
" ffn_num_input, ffn_num_hiddens, num_hiddens)\n",
" self.addnorm2 = AddNorm(norm_shape, dropout)\n",
" def forward(self, X, valid_lens):\n",
" Y = self.addnorm1(X, self.attention(X, X, X, valid_lens))\n",
" return self.addnorm2(Y, self.ffn(Y))"
],
"id": "c54dc4e20e9ffb90"
2026-06-28 14:39:18 +00:00
}
],
"metadata": {
"kernelspec": {
2026-06-30 03:57:35 +00:00
"display_name": "Python 3 (ipykernel)",
2026-06-28 14:39:18 +00:00
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}