From b8d13f49968562db05f80b46d423fd02de7839cd Mon Sep 17 00:00:00 2001 From: yukun-hh Date: Sun, 28 Jun 2026 22:39:18 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AE=8C=E6=88=90=E7=AC=AC=E5=8D=81=E7=AB=A0?= =?UTF-8?q?=E5=AD=A6=E4=B9=A0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- chapter10.ipynb | 1059 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 1059 insertions(+) create mode 100644 chapter10.ipynb diff --git a/chapter10.ipynb b/chapter10.ipynb new file mode 100644 index 0000000..cd1ba10 --- /dev/null +++ b/chapter10.ipynb @@ -0,0 +1,1059 @@ +{ + "cells": [ + { + "cell_type": "code", + "id": "initial_id", + "metadata": { + "collapsed": true, + "ExecuteTime": { + "end_time": "2026-05-29T15:19:42.541639835Z", + "start_time": "2026-05-29T15:19:40.265547407Z" + } + }, + "source": [ + "import torch\n", + "from d2l import torch as d2l\n" + ], + "outputs": [], + "execution_count": 2 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2026-05-29T15:19:42.593785744Z", + "start_time": "2026-05-29T15:19:42.545273168Z" + } + }, + "cell_type": "code", + "source": [ + "def show_heatmaps(matrices,xlabel,ylabel,titles=None,figsize=(2.5,2.5),cmap='Reds'):\n", + " d2l.use_svg_display()\n", + " num_rows,num_cols = matrices.shape[0],matrices.shape[1]\n", + " fig,axes = d2l.plt.subplots(num_rows,num_cols,figsize=figsize,sharex=True,squeeze=False)\n", + " for i,(row_axes,row_matrices) in enumerate(zip(axes,matrices)):\n", + " for j,(ax,matrix) in enumerate(zip(row_axes,row_matrices)):\n", + " pcm = ax.imshow(matrix.detach().numpy(),cmap=cmap)\n", + " if i == num_rows - 1:\n", + " ax.set_xlabel(xlabel)\n", + " if j == 0:\n", + " ax.set_ylabel(ylabel)\n", + " if titles:\n", + " ax.set_title(titles[j])\n", + " fig.colorbar(pcm,ax=axes,shrink=0.6)" + ], + "id": "a0e3f725b7764f08", + "outputs": [], + "execution_count": 3 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2026-05-29T15:19:42.832298248Z", + 
"start_time": "2026-05-29T15:19:42.594979091Z" + } + }, + "cell_type": "code", + "source": [ + "attention_weights = torch.eye(10).reshape((1, 1, 10, 10))\n", + "\n", + "show_heatmaps(attention_weights, xlabel='Keys', ylabel='Queries')" + ], + "id": "bb798de34fc4d0fa", + "outputs": [ + { + "data": { + "text/plain": [ + "
" + ], + "image/svg+xml": "\n\n\n \n \n \n \n 2026-05-29T23:19:42.792345\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" + }, + "metadata": {}, + "output_type": "display_data", + "jetTransient": { + "display_id": null + } + } + ], + "execution_count": 4 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2026-05-29T15:19:42.925242334Z", + "start_time": "2026-05-29T15:19:42.844425757Z" + } + }, + "cell_type": "code", + "source": [ + "n_train = 50\n", + "x_train,_ = torch.sort(torch.rand(n_train)*5)" + ], + "id": "a4070c75847fb887", + "outputs": [], + "execution_count": 5 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2026-05-29T15:19:42.978395734Z", + "start_time": "2026-05-29T15:19:42.926738710Z" + } + }, + "cell_type": "code", + "source": [ + "def f(x):\n", + " return 2*torch.sin(x)+x**0.8" + ], + "id": "2aca6952876d9cf5", + "outputs": [], + "execution_count": 6 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2026-05-29T15:19:43.056742866Z", + "start_time": "2026-05-29T15:19:42.982949395Z" + } + }, + "cell_type": "code", + "source": [ + 
"y_train = f(x_train)+torch.normal(0.0,0.5,(n_train,))\n", + "x_test = torch.arange(0,5,0.1)\n", + "y_truth = f(x_test)\n", + "n_test = len(x_test)\n", + "n_test" + ], + "id": "ea7b94585dcde934", + "outputs": [ + { + "data": { + "text/plain": [ + "50" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 7 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2026-05-29T15:19:43.117629637Z", + "start_time": "2026-05-29T15:19:43.058324467Z" + } + }, + "cell_type": "code", + "source": [ + "def plot_kernel_reg(y_hat):\n", + " d2l.plot(x_test, [y_truth, y_hat], 'x', 'y', legend=['Truth', 'Pred'],\n", + " xlim=[0, 5], ylim=[-1, 5])\n", + " d2l.plt.plot(x_train, y_train, 'o', alpha=0.5);" + ], + "id": "35c2403bbb5509cf", + "outputs": [], + "execution_count": 8 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2026-05-29T15:19:43.224202029Z", + "start_time": "2026-05-29T15:19:43.119223163Z" + } + }, + "cell_type": "code", + "source": [ + "y_hat = torch.repeat_interleave(y_train.mean(),n_test)\n", + "plot_kernel_reg(y_hat)" + ], + "id": "f27ecb9c68aed894", + "outputs": [ + { + "data": { + "text/plain": [ + "
" + ], + "image/svg+xml": "\n\n\n \n \n \n \n 2026-05-29T23:19:43.188981\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" + }, + "metadata": {}, + "output_type": "display_data", + "jetTransient": { + "display_id": null + } + } + ], + "execution_count": 9 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2026-05-29T15:19:43.317103476Z", + "start_time": "2026-05-29T15:19:43.225907041Z" + } + }, + "cell_type": "code", + "source": [ + "from torch import nn\n", + "X_repeat = x_test.repeat_interleave(n_train).reshape((-1, n_train))\n", + "attention_weights = nn.functional.softmax(-(X_repeat - x_train)**2 / 2, dim=1)\n", + "y_hat = torch.matmul(attention_weights,y_train)\n", + "x_train,X_repeat,attention_weights" + ], + "id": "fa7579d7cd8bfcb7", + "outputs": [ + { + 
"data": { + "text/plain": [ + "(tensor([9.4801e-04, 5.2621e-02, 2.3643e-01, 2.9541e-01, 3.2529e-01, 3.5294e-01,\n", + " 3.6399e-01, 4.1190e-01, 5.8751e-01, 6.8494e-01, 7.2698e-01, 8.0101e-01,\n", + " 9.2743e-01, 9.7225e-01, 1.1656e+00, 1.4397e+00, 1.4788e+00, 1.6088e+00,\n", + " 1.7147e+00, 1.8255e+00, 1.8951e+00, 1.9322e+00, 2.3849e+00, 2.4183e+00,\n", + " 2.4902e+00, 2.5441e+00, 2.6483e+00, 2.8371e+00, 2.9705e+00, 3.1383e+00,\n", + " 3.2709e+00, 3.3751e+00, 3.3849e+00, 3.4902e+00, 3.5679e+00, 3.5782e+00,\n", + " 3.5821e+00, 3.8648e+00, 4.1552e+00, 4.3914e+00, 4.4348e+00, 4.4414e+00,\n", + " 4.6038e+00, 4.6099e+00, 4.7481e+00, 4.7528e+00, 4.8005e+00, 4.8134e+00,\n", + " 4.8934e+00, 4.9044e+00]),\n", + " tensor([[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", + " [0.1000, 0.1000, 0.1000, ..., 0.1000, 0.1000, 0.1000],\n", + " [0.2000, 0.2000, 0.2000, ..., 0.2000, 0.2000, 0.2000],\n", + " ...,\n", + " [4.7000, 4.7000, 4.7000, ..., 4.7000, 4.7000, 4.7000],\n", + " [4.8000, 4.8000, 4.8000, ..., 4.8000, 4.8000, 4.8000],\n", + " [4.9000, 4.9000, 4.9000, ..., 4.9000, 4.9000, 4.9000]]),\n", + " tensor([[6.8662e-02, 6.8567e-02, 6.6769e-02, ..., 6.3916e-07, 4.3363e-07,\n", + " 4.1091e-07],\n", + " [6.4257e-02, 6.4501e-02, 6.3975e-02, ..., 9.6787e-07, 6.6191e-07,\n", + " 6.2792e-07],\n", + " [5.9935e-02, 6.0473e-02, 6.1093e-02, ..., 1.4608e-06, 1.0070e-06,\n", + " 9.5634e-07],\n", + " ...,\n", + " [9.4878e-07, 1.2079e-06, 2.7906e-06, ..., 5.8779e-02, 5.8062e-02,\n", + " 5.7935e-02],\n", + " [6.1962e-07, 7.9294e-07, 1.8658e-06, ..., 6.2113e-02, 6.1848e-02,\n", + " 6.1781e-02],\n", + " [4.0301e-07, 5.1842e-07, 1.2425e-06, ..., 6.5370e-02, 6.5614e-02,\n", + " 6.5615e-02]]))" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 10 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2026-05-29T15:19:43.435846223Z", + "start_time": "2026-05-29T15:19:43.319732830Z" + } + }, + "cell_type": "code", + 
"source": "plot_kernel_reg(y_hat)", + "id": "f76573b2eb74ba7", + "outputs": [ + { + "data": { + "text/plain": [ + "
" + ], + "image/svg+xml": "\n\n\n \n \n \n \n 2026-05-29T23:19:43.394543\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" + }, + "metadata": {}, + "output_type": "display_data", + "jetTransient": { + "display_id": null + } + } + ], + "execution_count": 11 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2026-05-29T15:19:43.558539946Z", + "start_time": "2026-05-29T15:19:43.438360863Z" + } + }, + "cell_type": "code", + "source": [ + "d2l.show_heatmaps(attention_weights.unsqueeze(0).unsqueeze(0),\n", + " xlabel='Sorted training inputs',\n", + " ylabel='Sorted testing inputs')" + ], + "id": "1bbe3af9e8a0f412", + "outputs": [ + { + "data": { + "text/plain": [ + "
" + ], + "image/svg+xml": "\n\n\n \n \n \n \n 2026-05-29T23:19:43.512317\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" + }, + "metadata": {}, + "output_type": "display_data", + "jetTransient": { + "display_id": null + } + } + ], + "execution_count": 12 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2026-05-29T15:19:43.610782245Z", + "start_time": "2026-05-29T15:19:43.560607937Z" + } + }, + "cell_type": "code", + "source": [ + "class NWKernelRegression(nn.Module):\n", + " def __init__(self,**kwargs):\n", + " super().__init__(**kwargs)\n", + " self.w = nn.Parameter(torch.randn((1,),requires_grad=True))\n", + " def forward(self,queries,keys,values):\n", + " queries = queries.repeat_interleave(keys.shape[1]).reshape(-1,keys.shape[1])\n", + " self.attention_weights = nn.functional.softmax(\n", + " -((queries - keys) * self.w)**2 / 2, dim=1)\n", + " return torch.bmm(self.attention_weights.unsqueeze(1),values.unsqueeze(-1)).reshape(-1)" + ], + "id": "aeb79a3bd64c826", + "outputs": [], + "execution_count": 13 + }, + { 
+ "metadata": { + "ExecuteTime": { + "end_time": "2026-05-29T15:19:43.661276280Z", + "start_time": "2026-05-29T15:19:43.611752619Z" + } + }, + "cell_type": "code", + "source": [ + "# X_tile的形状:(n_train,n_train),每一行都包含着相同的训练输入\n", + "X_tile = x_train.repeat((n_train, 1))\n", + "# Y_tile的形状:(n_train,n_train),每一行都包含着相同的训练输出\n", + "Y_tile = y_train.repeat((n_train, 1))\n", + "# keys的形状:('n_train','n_train'-1)\n", + "keys = X_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))\n", + "# values的形状:('n_train','n_train'-1)\n", + "values = Y_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))" + ], + "id": "7061ed7f139bc78e", + "outputs": [], + "execution_count": 14 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2026-05-29T15:19:44.080669795Z", + "start_time": "2026-05-29T15:19:43.662206275Z" + } + }, + "cell_type": "code", + "source": [ + "net = NWKernelRegression()\n", + "loss = nn.MSELoss(reduction='none')\n", + "trainer = torch.optim.SGD(net.parameters(), lr=0.5)\n", + "animator = d2l.Animator(xlabel='epoch', ylabel='loss', xlim=[1, 5])\n", + "for epoch in range(5):\n", + " trainer.zero_grad()\n", + " l = loss(net(x_train, keys, values), y_train)\n", + " l.sum().backward()\n", + " trainer.step()\n", + " print(f'epoch {epoch + 1}, loss {float(l.sum()):.6f}')\n", + " animator.add(epoch + 1, float(l.sum()))" + ], + "id": "3dca11d4e962d859", + "outputs": [ + { + "data": { + "text/plain": [ + "
" + ], + "image/svg+xml": "\n\n\n \n \n \n \n 2026-05-29T23:19:44.050719\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" + }, + "metadata": {}, + "output_type": "display_data", + "jetTransient": { + "display_id": null + } + } + ], + "execution_count": 15 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2026-05-29T15:19:44.196985129Z", + "start_time": "2026-05-29T15:19:44.082458369Z" + } + }, + "cell_type": "code", + "source": [ + "# keys的形状:(n_test,n_train),每一行包含着相同的训练输入(例如,相同的键)\n", + "keys = x_train.repeat((n_test, 1))\n", + "# value的形状:(n_test,n_train)\n", + "values = y_train.repeat((n_test, 1))\n", + "y_hat = net(x_test, keys, values).unsqueeze(1).detach()\n", + "plot_kernel_reg(y_hat)" + ], + "id": "4ff00f0b4983b55e", + "outputs": [ + { + "data": { + "text/plain": [ + "
" + ], + "image/svg+xml": "\n\n\n \n \n \n \n 2026-05-29T23:19:44.158025\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" + }, + "metadata": {}, + "output_type": "display_data", + "jetTransient": { + "display_id": null + } + } + ], + "execution_count": 16 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2026-05-29T15:19:44.303651198Z", + "start_time": "2026-05-29T15:19:44.199584310Z" + } + }, + "cell_type": "code", + "source": [ + "d2l.show_heatmaps(net.attention_weights.unsqueeze(0).unsqueeze(0),\n", + "xlabel='Sorted training inputs',\n", + "ylabel='Sorted testing inputs')" + ], + "id": "97a690478767c3c3", + "outputs": [ + { + "data": { + "text/plain": [ + "
" + ], + "image/svg+xml": "\n\n\n \n \n \n \n 2026-05-29T23:19:44.256961\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" + }, + "metadata": {}, + "output_type": "display_data", + "jetTransient": { + "display_id": null + } + } + ], + "execution_count": 17 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2026-05-29T15:19:44.358514818Z", + "start_time": "2026-05-29T15:19:44.306386112Z" + } + }, + "cell_type": "code", + "source": [ + "def masked_softmax(X,valid_lens):\n", + " # X:3D valid_len 1D or 2D\n", + " if valid_lens is None:\n", + " return nn.functional.softmax(X,dim=-1)\n", + " else:\n", + " shape = X.shape\n", + " if valid_lens.dim()==1:\n", + " valid_lens = torch.repeat_interleave(valid_lens,shape[1])\n", + " else:\n", + " valid_lens = valid_lens.reshape(-1)\n", + " #print(valid_lens)\n", + " X = d2l.sequence_mask(X.reshape(-1,shape[-1]),valid_lens,value=-1e6)\n", + " return nn.functional.softmax(X.reshape(shape),dim=-1)" + ], + "id": 
"8d03d0c7f735b755", + "outputs": [], + "execution_count": 18 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2026-05-29T15:19:44.420710191Z", + "start_time": "2026-05-29T15:19:44.359351187Z" + } + }, + "cell_type": "code", + "source": "masked_softmax(torch.rand(2,2,4),torch.tensor([2,3]))", + "id": "27f8c24ec572b267", + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[0.6937, 0.3063, 0.0000, 0.0000],\n", + " [0.6165, 0.3835, 0.0000, 0.0000]],\n", + "\n", + " [[0.3526, 0.3291, 0.3183, 0.0000],\n", + " [0.2735, 0.2359, 0.4906, 0.0000]]])" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 19 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2026-05-29T15:19:44.470391356Z", + "start_time": "2026-05-29T15:19:44.422576814Z" + } + }, + "cell_type": "code", + "source": [ + "class AdditiveAttention(nn.Module):\n", + " def __init__(self,key_size,query_size,num_hiddens,dropout,**kwargs):\n", + " super(AdditiveAttention,self).__init__(**kwargs)\n", + " self.W_k = nn.Linear(key_size,num_hiddens,bias=False)\n", + " self.W_q = nn.Linear(query_size,num_hiddens,bias=False)\n", + " self.w_v = nn.Linear(num_hiddens,1,bias=False)\n", + " self.dropout=nn.Dropout(dropout)\n", + " def forward(self,queries,keys,value,valid_lens):\n", + " queries,keys=self.W_q(queries),self.W_k(keys)\n", + " # queries (batch_size,n_q,1,num_hidden)\n", + " # key (batch_size,1,n_k,num_hiddens)\n", + " features = queries.unsqueeze(2) + keys.unsqueeze(1)\n", + " #features (batch_size,n_q,n_k,num_hidden)\n", + " features = torch.tanh(features)\n", + " scores = self.w_v(features).squeeze(-1)\n", + " #print(f\"Inside AdditiveAttention: value.shape = {value.shape}\") # 检查此处形状\n", + " self.attention_weights = masked_softmax(scores, valid_lens)\n", + "\n", + " return torch.bmm(self.dropout(self.attention_weights), value)" + ], + "id": "c7e5254395b14b1", + "outputs": [], + "execution_count": 20 + }, + { + "metadata": { + 
"ExecuteTime": { + "end_time": "2026-05-29T15:19:44.536311739Z", + "start_time": "2026-05-29T15:19:44.471725853Z" + } + }, + "cell_type": "code", + "source": [ + "queries,keys = torch.normal(0,1,(2,1,20)) , torch.ones((2,10,2))\n", + "values = torch.arange(40,dtype=torch.float32).reshape(1,10,4).repeat(2,1,1)\n", + "valid_lens = torch.tensor([2, 6])\n", + "attention = AdditiveAttention(key_size=2, query_size=20, num_hiddens=8,\n", + " dropout=0.1)\n", + "attention.eval()\n", + "attention(queries,keys,values,valid_lens)" + ], + "id": "77e325a032e70d98", + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[ 2.0000, 3.0000, 4.0000, 5.0000]],\n", + "\n", + " [[10.0000, 11.0000, 12.0000, 13.0000]]], grad_fn=)" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 21 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2026-05-29T15:19:44.618227172Z", + "start_time": "2026-05-29T15:19:44.538570102Z" + } + }, + "cell_type": "code", + "source": [ + "d2l.show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),\n", + " xlabel='Keys', ylabel='Queries')" + ], + "id": "93ceedfd6e03982f", + "outputs": [ + { + "data": { + "text/plain": [ + "
" + ], + "image/svg+xml": "\n\n\n \n \n \n \n 2026-05-29T23:19:44.586919\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" + }, + "metadata": {}, + "output_type": "display_data", + "jetTransient": { + "display_id": null + } + } + ], + "execution_count": 22 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2026-05-29T15:19:44.668038302Z", + "start_time": "2026-05-29T15:19:44.620075434Z" + } + }, + "cell_type": "code", + "source": [ + "import math\n", + "class DotProductAttention(nn.Module):\n", + " \"\"\"缩放点积注意力\"\"\"\n", + " def __init__(self, dropout, **kwargs):\n", + " super(DotProductAttention, self).__init__(**kwargs)\n", + " self.dropout = nn.Dropout(dropout)\n", + "# queries的形状:(batch_size,查询的个数,d)\n", + "# keys的形状:(batch_size,“键-值”对的个数,d)\n", + "# values的形状:(batch_size,“键-值”对的个数,值的维度)\n", + "# valid_lens的形状:(batch_size,)或者(batch_size,查询的个数)\n", + " def forward(self, queries, keys, values, valid_lens=None):\n", + " d = queries.shape[-1]\n", + " # 设置transpose_b=True为了交换keys的最后两个维度\n", + " scores = torch.bmm(queries, keys.transpose(1,2)) / math.sqrt(d)\n", + " self.attention_weights = masked_softmax(scores, valid_lens)\n", + "\n", + " return torch.bmm(self.dropout(self.attention_weights), values)" + ], + "id": "8453c76623a5b435", + "outputs": [], + "execution_count": 23 + }, + { + "metadata": { + 
"ExecuteTime": { + "end_time": "2026-05-29T15:19:44.754268677Z", + "start_time": "2026-05-29T15:19:44.668929892Z" + } + }, + "cell_type": "code", + "source": [ + "queries = torch.normal(0, 1, (2, 1, 2))\n", + "attention = DotProductAttention(dropout=0.5)\n", + "attention.eval()\n", + "attention(queries, keys, values, valid_lens)" + ], + "id": "8af6b28944977a62", + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[ 2.0000, 3.0000, 4.0000, 5.0000]],\n", + "\n", + " [[10.0000, 11.0000, 12.0000, 13.0000]]])" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 24 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2026-05-29T15:19:44.822905172Z", + "start_time": "2026-05-29T15:19:44.756082157Z" + } + }, + "cell_type": "code", + "source": [ + "d2l.show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),\n", + "xlabel='Keys', ylabel='Queries')" + ], + "id": "28f7e16771076e33", + "outputs": [ + { + "data": { + "text/plain": [ + "
" + ], + "image/svg+xml": "\n\n\n \n \n \n \n 2026-05-29T23:19:44.802392\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" + }, + "metadata": {}, + "output_type": "display_data", + "jetTransient": { + "display_id": null + } + } + ], + "execution_count": 25 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2026-05-29T15:19:44.874393721Z", + "start_time": "2026-05-29T15:19:44.824430337Z" + } + }, + "cell_type": "code", + "source": [ + "class AttentionDecoder(d2l.Decoder):\n", + " def __init__(self,**kwargs):\n", + " super(AttentionDecoder, self).__init__(**kwargs)\n", + " @property\n", + " def attention_weight(self):\n", + " raise NotImplementedError" + ], + "id": "33507bd2e1917a29", + "outputs": [], + "execution_count": 26 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2026-05-29T15:19:44.926627628Z", + "start_time": "2026-05-29T15:19:44.875567483Z" + } + }, + "cell_type": "code", + "source": [ + "class Seq2SeqAttentionDecoder(AttentionDecoder):\n", + " def __init__(self,vocab_size,embed_size,num_hiddens,num_layers,dropout=0,**kwargs):\n", + " super(Seq2SeqAttentionDecoder, self).__init__(**kwargs)\n", + " self.attention = AdditiveAttention(\n", + " num_hiddens, num_hiddens, num_hiddens,dropout)\n", + " self.embedding = nn.Embedding(vocab_size, embed_size)\n", + " self.rnn = nn.GRU(\n", + " 
embed_size + num_hiddens, num_hiddens, num_layers,\n", + " dropout=dropout)\n", + " self.dense = nn.Linear(num_hiddens, vocab_size)\n", + " def init_state(self, enc_outputs, enc_valid_lens, *args):\n", + "# outputs的形状为(batch_size,num_steps,num_hiddens).\n", + "# hidden_state的形状为(num_layers,batch_size,num_hiddens)\n", + " outputs, hidden_state = enc_outputs\n", + " #print(f\"Encoder outputs shape before permute: {outputs.shape}\") # 应为 (num_steps, batch_size, num_hiddens) 或 (batch_size, num_steps, num_hiddens)\n", + " enc_outputs_permuted = outputs.permute(1, 0, 2)\n", + " #print(f\"After permute: {enc_outputs_permuted.shape}\") # 期望 (batch_size, num_steps, num_hiddens)\n", + " return (enc_outputs_permuted, hidden_state, enc_valid_lens)\n", + " def forward(self, X, state):\n", + " # enc_outputs的形状为(batch_size,num_steps,num_hiddens).\n", + " # hidden_state的形状为(num_layers,batch_size,\n", + " # num_hiddens)\n", + " enc_outputs, hidden_state, enc_valid_lens = state\n", + " # 输出X的形状为(num_steps,batch_size,embed_size)\n", + " X = self.embedding(X).permute(1, 0, 2)\n", + " outputs, self._attention_weights = [], []\n", + " for x in X:\n", + " # query的形状为(batch_size,1,num_hiddens)\n", + " query = torch.unsqueeze(hidden_state[-1], dim=1)\n", + " # context的形状为(batch_size,1,num_hiddens)\n", + " #print(f\"values shape before attention: {enc_outputs.shape}\") # 应为 (4, 7, 16)\n", + " context = self.attention(\n", + " query, enc_outputs, enc_outputs, enc_valid_lens)\n", + " # 在特征维度上连结\n", + " x = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=-1)\n", + " # 将x变形为(1,batch_size,embed_size+num_hiddens)\n", + " out, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)\n", + " outputs.append(out)\n", + " self._attention_weights.append(self.attention.attention_weights)\n", + " # 全连接层变换后,outputs的形状为\n", + " # (num_steps,batch_size,vocab_size)\n", + " outputs = self.dense(torch.cat(outputs, dim=0))\n", + " return outputs.permute(1, 0, 2), [enc_outputs, hidden_state,\n", + " 
enc_valid_lens]\n", + " @property\n", + " def attention_weights(self):\n", + " return self._attention_weights" + ], + "id": "a8a497c9041910a1", + "outputs": [], + "execution_count": 27 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2026-05-29T15:19:44.996664037Z", + "start_time": "2026-05-29T15:19:44.927689328Z" + } + }, + "cell_type": "code", + "source": [ + "encoder = d2l.Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16,\n", + "num_layers=2)\n", + "encoder.eval()\n", + "decoder = Seq2SeqAttentionDecoder(vocab_size=10, embed_size=8, num_hiddens=16,\n", + "num_layers=2)\n", + "decoder.eval()\n", + "X = torch.zeros((4, 7), dtype=torch.long) # (batch_size,num_steps)\n", + "state = decoder.init_state(encoder(X), None)\n", + "output, state = decoder(X, state)\n", + "output.shape, len(state), state[0].shape, len(state[1]), state[1][0].shape" + ], + "id": "b2a6bc735743bd0f", + "outputs": [ + { + "data": { + "text/plain": [ + "(torch.Size([4, 7, 10]), 3, torch.Size([4, 7, 16]), 2, torch.Size([4, 16]))" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 28 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2026-05-29T15:19:45.047118759Z", + "start_time": "2026-05-29T15:19:44.998053092Z" + } + }, + "cell_type": "code", + "source": "", + "id": "4c7bf86095f15c98", + "outputs": [], + "execution_count": 28 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2026-05-29T15:20:14.678359202Z", + "start_time": "2026-05-29T15:19:45.048442204Z" + } + }, + "cell_type": "code", + "source": [ + "embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1\n", + "batch_size, num_steps = 64, 10\n", + "lr, num_epochs, device = 0.005, 250, d2l.try_gpu()\n", + "train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)\n", + "encoder = d2l.Seq2SeqEncoder(\n", + "len(src_vocab), embed_size, num_hiddens, num_layers, dropout)\n", + "decoder = Seq2SeqAttentionDecoder(\n", + 
"len(tgt_vocab), embed_size, num_hiddens, num_layers, dropout)\n", + "net = d2l.EncoderDecoder(encoder, decoder)\n", + "d2l.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)" + ], + "id": "5f5cedf74b97bd12", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "loss 0.019, 17862.0 tokens/sec on cuda:0\n" + ] + }, + { + "data": { + "text/plain": [ + "
" + ], + "image/svg+xml": "\n\n\n \n \n \n \n 2026-05-29T23:20:14.649299\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" + }, + "metadata": {}, + "output_type": "display_data", + "jetTransient": { + "display_id": null + } + } + ], + "execution_count": 29 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2026-05-29T15:20:14.794039857Z", + "start_time": "2026-05-29T15:20:14.727727180Z" + } + }, + "cell_type": "code", + "source": [ + "engs = ['go .', \"i lost .\", 'he\\'s calm .', 'i\\'m home .']\n", + "fras = ['va !', 'j\\'ai perdu .', 'il est calme .', 'je suis chez moi .']\n", + "for eng, fra in zip(engs, fras):\n", + " translation, dec_attention_weight_seq = d2l.predict_seq2seq(\n", + " net, eng, src_vocab, tgt_vocab, num_steps, device, True)\n", + " print(f'{eng} => {translation}, ',\n", + " f'bleu {d2l.bleu(translation, fra, k=2):.3f}')" + ], + "id": "ddaed8bd284ee01b", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "go . => va !, bleu 1.000\n", + "i lost . => j'ai perdu ., bleu 1.000\n", + "he's calm . => il est ., bleu 0.658\n", + "i'm home . 
class MultiHeadAttention(nn.Module):
    """Multi-head attention built on ``d2l.DotProductAttention``.

    Queries, keys and values are each linearly projected to ``num_hiddens``
    features, split into ``num_heads`` equal slices that are attended to in
    parallel, then concatenated again and passed through an output
    projection.

    Args:
        key_size: Feature size of the incoming keys.
        query_size: Feature size of the incoming queries.
        value_size: Feature size of the incoming values.
        num_hiddens: Total projected feature size shared by all heads; must
            be divisible by ``num_heads``.
        num_heads: Number of attention heads; must be positive.
        dropout: Dropout rate handed to ``d2l.DotProductAttention``.
        bias: Whether the four linear projections use a bias term.
        **kwargs: Forwarded to ``nn.Module.__init__``.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``num_hiddens``.
    """

    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 num_heads, dropout, bias=False, **kwargs):
        # Fail fast: an indivisible size would otherwise only surface later
        # as an opaque reshape error inside transpose_qkv.
        if num_heads <= 0 or num_hiddens % num_heads != 0:
            raise ValueError(
                f"num_hiddens ({num_hiddens}) must be divisible by a "
                f"positive num_heads ({num_heads})")
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = d2l.DotProductAttention(dropout)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        """Apply multi-head attention.

        Args:
            queries: Tensor of shape ``(batch_size, num_queries, query_size)``.
            keys: Tensor of shape ``(batch_size, num_kvpairs, key_size)``.
            values: Tensor of shape ``(batch_size, num_kvpairs, value_size)``.
            valid_lens: Tensor whose first dimension is ``batch_size``, or
                ``None`` to disable masking.

        Returns:
            Tensor of shape ``(batch_size, num_queries, num_hiddens)``.
        """
        # 1. Project, then fold the heads into the batch dimension.
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)

        # 2. Repeat each sample's valid length once per head so the mask
        #    lines up with the (batch_size * num_heads) leading dimension.
        if valid_lens is not None:
            valid_lens = torch.repeat_interleave(
                valid_lens, repeats=self.num_heads, dim=0)

        # 3. Every head attends independently.
        output = self.attention(queries, keys, values, valid_lens)

        # 4. Concatenate the heads and apply the output projection.
        output_concat = transpose_output(output, self.num_heads)
        return self.W_o(output_concat)


def transpose_qkv(X, num_heads):
    """Reshape projections so all attention heads are computed in parallel.

    Args:
        X: Tensor of shape ``(batch_size, n, num_hiddens)`` where ``n`` is
            the number of queries or key-value pairs.
        num_heads: Number of heads; must divide ``num_hiddens``.

    Returns:
        Tensor of shape ``(batch_size * num_heads, n,
        num_hiddens / num_heads)``.
    """
    # (batch_size, n, num_heads, num_hiddens / num_heads)
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
    # (batch_size, num_heads, n, num_hiddens / num_heads)
    X = X.permute(0, 2, 1, 3)
    # (batch_size * num_heads, n, num_hiddens / num_heads)
    return X.reshape(-1, X.shape[2], X.shape[3])


def transpose_output(X, num_heads):
    """Reverse ``transpose_qkv``.

    Args:
        X: Tensor of shape ``(batch_size * num_heads, n,
            num_hiddens / num_heads)``.
        num_heads: Number of heads that were folded into the batch axis.

    Returns:
        Tensor of shape ``(batch_size, n, num_hiddens)``.
    """
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    X = X.permute(0, 2, 1, 3)
    return X.reshape(X.shape[0], X.shape[1], -1)
"end_time": "2026-05-29T15:20:40.455647106Z", + "start_time": "2026-05-29T15:20:40.307682855Z" + } + }, + "cell_type": "code", + "source": [ + "batch_size, num_queries = 2, 4\n", + "num_kvpairs, valid_lens = 6, torch.tensor([3, 2])\n", + "X = torch.ones((batch_size, num_queries, num_hiddens))\n", + "Y = torch.ones((batch_size, num_kvpairs, num_hiddens))\n", + "attention(X, Y, Y, valid_lens).shape" + ], + "id": "49877c929741ec94", + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([2, 4, 100])" + ] + }, + "execution_count": 34, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 34 + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": "", + "id": "b8a10576c7789750" + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}