{ "cells": [ { "cell_type": "code", "id": "initial_id", "metadata": { "collapsed": true, "ExecuteTime": { "end_time": "2026-05-29T15:19:42.541639835Z", "start_time": "2026-05-29T15:19:40.265547407Z" } }, "source": [ "import torch\n", "from d2l import torch as d2l\n" ], "outputs": [], "execution_count": 2 }, { "metadata": { "ExecuteTime": { "end_time": "2026-05-29T15:19:42.593785744Z", "start_time": "2026-05-29T15:19:42.545273168Z" } }, "cell_type": "code", "source": [ "def show_heatmaps(matrices,xlabel,ylabel,titles=None,figsize=(2.5,2.5),cmap='Reds'):\n", " d2l.use_svg_display()\n", " num_rows,num_cols = matrices.shape[0],matrices.shape[1]\n", " fig,axes = d2l.plt.subplots(num_rows,num_cols,figsize=figsize,sharex=True,squeeze=False)\n", " for i,(row_axes,row_matrices) in enumerate(zip(axes,matrices)):\n", " for j,(ax,matrix) in enumerate(zip(row_axes,row_matrices)):\n", " pcm = ax.imshow(matrix.detach().numpy(),cmap=cmap)\n", " if i == num_rows - 1:\n", " ax.set_xlabel(xlabel)\n", " if j == 0:\n", " ax.set_ylabel(ylabel)\n", " if titles:\n", " ax.set_title(titles[j])\n", " fig.colorbar(pcm,ax=axes,shrink=0.6)" ], "id": "a0e3f725b7764f08", "outputs": [], "execution_count": 3 }, { "metadata": { "ExecuteTime": { "end_time": "2026-05-29T15:19:42.832298248Z", "start_time": "2026-05-29T15:19:42.594979091Z" } }, "cell_type": "code", "source": [ "attention_weights = torch.eye(10).reshape((1, 1, 10, 10))\n", "\n", "show_heatmaps(attention_weights, xlabel='Keys', ylabel='Queries')" ], "id": "bb798de34fc4d0fa", "outputs": [ { "data": { "text/plain": [ "
" ], "image/svg+xml": "\n\n\n \n \n \n \n 2026-05-29T23:19:42.792345\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" }, "metadata": {}, "output_type": "display_data", "jetTransient": { "display_id": null } } ], "execution_count": 4 }, { "metadata": { "ExecuteTime": { "end_time": "2026-05-29T15:19:42.925242334Z", "start_time": "2026-05-29T15:19:42.844425757Z" } }, "cell_type": "code", "source": [ "n_train = 50\n", "x_train,_ = torch.sort(torch.rand(n_train)*5)" ], "id": "a4070c75847fb887", "outputs": [], "execution_count": 5 }, { "metadata": { "ExecuteTime": { "end_time": "2026-05-29T15:19:42.978395734Z", "start_time": "2026-05-29T15:19:42.926738710Z" } }, "cell_type": "code", "source": [ "def f(x):\n", " return 2*torch.sin(x)+x**0.8" ], "id": "2aca6952876d9cf5", "outputs": [], "execution_count": 6 }, { "metadata": { "ExecuteTime": { "end_time": "2026-05-29T15:19:43.056742866Z", "start_time": "2026-05-29T15:19:42.982949395Z" } }, "cell_type": "code", "source": [ "y_train = f(x_train)+torch.normal(0.0,0.5,(n_train,))\n", "x_test = torch.arange(0,5,0.1)\n", "y_truth = 
f(x_test)\n", "n_test = len(x_test)\n", "n_test" ], "id": "ea7b94585dcde934", "outputs": [ { "data": { "text/plain": [ "50" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 7 },
{ "metadata": { "ExecuteTime": { "end_time": "2026-05-29T15:19:43.117629637Z", "start_time": "2026-05-29T15:19:43.058324467Z" } }, "cell_type": "code", "source": [
  "def plot_kernel_reg(y_hat, x=None, y_true=None, x_obs=None, y_obs=None):\n",
  "    \"\"\"Plot kernel-regression predictions against the ground truth.\n",
  "\n",
  "    Draws `y_true` and `y_hat` as curves over `x` on fixed axes\n",
  "    (x in [0, 5], y in [-1, 5]) and overlays the noisy observations\n",
  "    (`x_obs`, `y_obs`) as semi-transparent dots.\n",
  "\n",
  "    Any argument left as None falls back to the notebook global with the\n",
  "    same role (`x_test`, `y_truth`, `x_train`, `y_train`), so the original\n",
  "    one-argument call `plot_kernel_reg(y_hat)` behaves exactly as before.\n",
  "    \"\"\"\n",
  "    x = x_test if x is None else x\n",
  "    y_true = y_truth if y_true is None else y_true\n",
  "    x_obs = x_train if x_obs is None else x_obs\n",
  "    y_obs = y_train if y_obs is None else y_obs\n",
  "    d2l.plot(x, [y_true, y_hat], 'x', 'y', legend=['Truth', 'Pred'],\n",
  "             xlim=[0, 5], ylim=[-1, 5])\n",
  "    d2l.plt.plot(x_obs, y_obs, 'o', alpha=0.5)"
 ], "id": "35c2403bbb5509cf", "outputs": [], "execution_count": 8 },
{ "metadata": { "ExecuteTime": { "end_time": "2026-05-29T15:19:43.224202029Z", "start_time": "2026-05-29T15:19:43.119223163Z" } }, "cell_type": "code", "source": [ "y_hat = torch.repeat_interleave(y_train.mean(),n_test)\n", "plot_kernel_reg(y_hat)" ], "id": "f27ecb9c68aed894", "outputs": [ { "data": { "text/plain": [ "
" ], "image/svg+xml": "\n\n\n \n \n \n \n 2026-05-29T23:19:43.188981\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" }, "metadata": {}, "output_type": "display_data", "jetTransient": { "display_id": null } } ], "execution_count": 9 }, { "metadata": { "ExecuteTime": { "end_time": "2026-05-29T15:19:43.317103476Z", "start_time": "2026-05-29T15:19:43.225907041Z" } }, "cell_type": "code", "source": [ "from torch import nn\n", "X_repeat = x_test.repeat_interleave(n_train).reshape((-1, n_train))\n", "attention_weights = nn.functional.softmax(-(X_repeat - x_train)**2 / 2, dim=1)\n", "y_hat = torch.matmul(attention_weights,y_train)\n", "x_train,X_repeat,attention_weights" ], "id": "fa7579d7cd8bfcb7", "outputs": [ { "data": { "text/plain": [ "(tensor([9.4801e-04, 5.2621e-02, 
2.3643e-01, 2.9541e-01, 3.2529e-01, 3.5294e-01,\n", " 3.6399e-01, 4.1190e-01, 5.8751e-01, 6.8494e-01, 7.2698e-01, 8.0101e-01,\n", " 9.2743e-01, 9.7225e-01, 1.1656e+00, 1.4397e+00, 1.4788e+00, 1.6088e+00,\n", " 1.7147e+00, 1.8255e+00, 1.8951e+00, 1.9322e+00, 2.3849e+00, 2.4183e+00,\n", " 2.4902e+00, 2.5441e+00, 2.6483e+00, 2.8371e+00, 2.9705e+00, 3.1383e+00,\n", " 3.2709e+00, 3.3751e+00, 3.3849e+00, 3.4902e+00, 3.5679e+00, 3.5782e+00,\n", " 3.5821e+00, 3.8648e+00, 4.1552e+00, 4.3914e+00, 4.4348e+00, 4.4414e+00,\n", " 4.6038e+00, 4.6099e+00, 4.7481e+00, 4.7528e+00, 4.8005e+00, 4.8134e+00,\n", " 4.8934e+00, 4.9044e+00]),\n", " tensor([[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", " [0.1000, 0.1000, 0.1000, ..., 0.1000, 0.1000, 0.1000],\n", " [0.2000, 0.2000, 0.2000, ..., 0.2000, 0.2000, 0.2000],\n", " ...,\n", " [4.7000, 4.7000, 4.7000, ..., 4.7000, 4.7000, 4.7000],\n", " [4.8000, 4.8000, 4.8000, ..., 4.8000, 4.8000, 4.8000],\n", " [4.9000, 4.9000, 4.9000, ..., 4.9000, 4.9000, 4.9000]]),\n", " tensor([[6.8662e-02, 6.8567e-02, 6.6769e-02, ..., 6.3916e-07, 4.3363e-07,\n", " 4.1091e-07],\n", " [6.4257e-02, 6.4501e-02, 6.3975e-02, ..., 9.6787e-07, 6.6191e-07,\n", " 6.2792e-07],\n", " [5.9935e-02, 6.0473e-02, 6.1093e-02, ..., 1.4608e-06, 1.0070e-06,\n", " 9.5634e-07],\n", " ...,\n", " [9.4878e-07, 1.2079e-06, 2.7906e-06, ..., 5.8779e-02, 5.8062e-02,\n", " 5.7935e-02],\n", " [6.1962e-07, 7.9294e-07, 1.8658e-06, ..., 6.2113e-02, 6.1848e-02,\n", " 6.1781e-02],\n", " [4.0301e-07, 5.1842e-07, 1.2425e-06, ..., 6.5370e-02, 6.5614e-02,\n", " 6.5615e-02]]))" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 10 }, { "metadata": { "ExecuteTime": { "end_time": "2026-05-29T15:19:43.435846223Z", "start_time": "2026-05-29T15:19:43.319732830Z" } }, "cell_type": "code", "source": "plot_kernel_reg(y_hat)", "id": "f76573b2eb74ba7", "outputs": [ { "data": { "text/plain": [ "
" ], "image/svg+xml": "\n\n\n \n \n \n \n 2026-05-29T23:19:43.394543\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" }, "metadata": {}, "output_type": "display_data", "jetTransient": { "display_id": null } } ], "execution_count": 11 }, { "metadata": { "ExecuteTime": { "end_time": "2026-05-29T15:19:43.558539946Z", "start_time": "2026-05-29T15:19:43.438360863Z" } }, "cell_type": "code", "source": [ "d2l.show_heatmaps(attention_weights.unsqueeze(0).unsqueeze(0),\n", " xlabel='Sorted training inputs',\n", " ylabel='Sorted testing inputs')" ], "id": "1bbe3af9e8a0f412", "outputs": [ { "data": { "text/plain": [ "
" ], "image/svg+xml": "\n\n\n \n \n \n \n 2026-05-29T23:19:43.512317\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" }, "metadata": {}, "output_type": "display_data", "jetTransient": { "display_id": null } } ], "execution_count": 12 }, { "metadata": { "ExecuteTime": { "end_time": "2026-05-29T15:19:43.610782245Z", "start_time": "2026-05-29T15:19:43.560607937Z" } }, "cell_type": "code", "source": [ "class NWKernelRegression(nn.Module):\n", " def __init__(self,**kwargs):\n", " super().__init__(**kwargs)\n", " self.w = nn.Parameter(torch.randn((1,),requires_grad=True))\n", " def forward(self,queries,keys,values):\n", " queries = queries.repeat_interleave(keys.shape[1]).reshape(-1,keys.shape[1])\n", " self.attention_weights = nn.functional.softmax(\n", " -((queries - keys) * self.w)**2 / 2, dim=1)\n", " return torch.bmm(self.attention_weights.unsqueeze(1),values.unsqueeze(-1)).reshape(-1)" ], "id": "aeb79a3bd64c826", "outputs": [], "execution_count": 13 }, { "metadata": { "ExecuteTime": { "end_time": 
"2026-05-29T15:19:43.661276280Z", "start_time": "2026-05-29T15:19:43.611752619Z" } }, "cell_type": "code", "source": [ "# X_tile的形状:(n_train,n_train),每一行都包含着相同的训练输入\n", "X_tile = x_train.repeat((n_train, 1))\n", "# Y_tile的形状:(n_train,n_train),每一行都包含着相同的训练输出\n", "Y_tile = y_train.repeat((n_train, 1))\n", "# keys的形状:('n_train','n_train'-1)\n", "keys = X_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))\n", "# values的形状:('n_train','n_train'-1)\n", "values = Y_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))" ], "id": "7061ed7f139bc78e", "outputs": [], "execution_count": 14 }, { "metadata": { "ExecuteTime": { "end_time": "2026-05-29T15:19:44.080669795Z", "start_time": "2026-05-29T15:19:43.662206275Z" } }, "cell_type": "code", "source": [ "net = NWKernelRegression()\n", "loss = nn.MSELoss(reduction='none')\n", "trainer = torch.optim.SGD(net.parameters(), lr=0.5)\n", "animator = d2l.Animator(xlabel='epoch', ylabel='loss', xlim=[1, 5])\n", "for epoch in range(5):\n", " trainer.zero_grad()\n", " l = loss(net(x_train, keys, values), y_train)\n", " l.sum().backward()\n", " trainer.step()\n", " print(f'epoch {epoch + 1}, loss {float(l.sum()):.6f}')\n", " animator.add(epoch + 1, float(l.sum()))" ], "id": "3dca11d4e962d859", "outputs": [ { "data": { "text/plain": [ "
" ], "image/svg+xml": "\n\n\n \n \n \n \n 2026-05-29T23:19:44.050719\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" }, "metadata": {}, "output_type": "display_data", "jetTransient": { "display_id": null } } ], "execution_count": 15 }, { "metadata": { "ExecuteTime": { "end_time": "2026-05-29T15:19:44.196985129Z", "start_time": "2026-05-29T15:19:44.082458369Z" } }, "cell_type": "code", "source": [ "# keys的形状:(n_test,n_train),每一行包含着相同的训练输入(例如,相同的键)\n", "keys = x_train.repeat((n_test, 1))\n", "# value的形状:(n_test,n_train)\n", "values = y_train.repeat((n_test, 1))\n", "y_hat = net(x_test, keys, values).unsqueeze(1).detach()\n", "plot_kernel_reg(y_hat)" ], "id": "4ff00f0b4983b55e", "outputs": [ { "data": { "text/plain": [ "
" ], "image/svg+xml": "\n\n\n \n \n \n \n 2026-05-29T23:19:44.158025\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" }, "metadata": {}, "output_type": "display_data", "jetTransient": { "display_id": null } } ], "execution_count": 16 }, { "metadata": { "ExecuteTime": { "end_time": "2026-05-29T15:19:44.303651198Z", "start_time": "2026-05-29T15:19:44.199584310Z" } }, "cell_type": "code", "source": [ "d2l.show_heatmaps(net.attention_weights.unsqueeze(0).unsqueeze(0),\n", "xlabel='Sorted training inputs',\n", "ylabel='Sorted testing inputs')" ], "id": "97a690478767c3c3", "outputs": [ { "data": { "text/plain": [ "
" ], "image/svg+xml": "\n\n\n \n \n \n \n 2026-05-29T23:19:44.256961\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" }, "metadata": {}, "output_type": "display_data", "jetTransient": { "display_id": null } } ], "execution_count": 17 }, { "metadata": { "ExecuteTime": { "end_time": "2026-05-29T15:19:44.358514818Z", "start_time": "2026-05-29T15:19:44.306386112Z" } }, "cell_type": "code", "source": [ "def masked_softmax(X,valid_lens):\n", " # X:3D valid_len 1D or 2D\n", " if valid_lens is None:\n", " return nn.functional.softmax(X,dim=-1)\n", " else:\n", " shape = X.shape\n", " if valid_lens.dim()==1:\n", " valid_lens = torch.repeat_interleave(valid_lens,shape[1])\n", " else:\n", " valid_lens = valid_lens.reshape(-1)\n", " #print(valid_lens)\n", " X = d2l.sequence_mask(X.reshape(-1,shape[-1]),valid_lens,value=-1e6)\n", " return nn.functional.softmax(X.reshape(shape),dim=-1)" ], "id": "8d03d0c7f735b755", "outputs": [], "execution_count": 18 }, { "metadata": { 
"ExecuteTime": { "end_time": "2026-05-29T15:19:44.420710191Z", "start_time": "2026-05-29T15:19:44.359351187Z" } }, "cell_type": "code", "source": "masked_softmax(torch.rand(2,2,4),torch.tensor([2,3]))", "id": "27f8c24ec572b267", "outputs": [ { "data": { "text/plain": [ "tensor([[[0.6937, 0.3063, 0.0000, 0.0000],\n", " [0.6165, 0.3835, 0.0000, 0.0000]],\n", "\n", " [[0.3526, 0.3291, 0.3183, 0.0000],\n", " [0.2735, 0.2359, 0.4906, 0.0000]]])" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 19 },
{ "metadata": { "ExecuteTime": { "end_time": "2026-05-29T15:19:44.470391356Z", "start_time": "2026-05-29T15:19:44.422576814Z" } }, "cell_type": "code", "source": [
  "class AdditiveAttention(nn.Module):\n",
  "    \"\"\"Additive attention: score(q, k) = w_v^T tanh(W_q q + W_k k).\n",
  "\n",
  "    Queries and keys may have different feature sizes; both are projected\n",
  "    to `num_hiddens` before being combined. `dropout` is applied to the\n",
  "    attention weights.\n",
  "    \"\"\"\n",
  "    def __init__(self, key_size, query_size, num_hiddens, dropout, **kwargs):\n",
  "        super().__init__(**kwargs)\n",
  "        self.W_k = nn.Linear(key_size, num_hiddens, bias=False)\n",
  "        self.W_q = nn.Linear(query_size, num_hiddens, bias=False)\n",
  "        self.w_v = nn.Linear(num_hiddens, 1, bias=False)\n",
  "        self.dropout = nn.Dropout(dropout)\n",
  "\n",
  "    def forward(self, queries, keys, value, valid_lens):\n",
  "        \"\"\"Return the attention-weighted sum of `value`.\n",
  "\n",
  "        queries: (batch_size, n_q, query_size)\n",
  "        keys: (batch_size, n_k, key_size)\n",
  "        value: (batch_size, n_k, value_dim)\n",
  "        valid_lens: None, (batch_size,) or (batch_size, n_q); forwarded to\n",
  "            masked_softmax so key positions past each valid length get ~0 weight.\n",
  "        Returns (batch_size, n_q, value_dim) and stores the weights,\n",
  "        (batch_size, n_q, n_k), in `self.attention_weights`.\n",
  "        \"\"\"\n",
  "        queries, keys = self.W_q(queries), self.W_k(keys)\n",
  "        # Broadcast-add (batch_size, n_q, 1, num_hiddens) and\n",
  "        # (batch_size, 1, n_k, num_hiddens) -> (batch_size, n_q, n_k, num_hiddens).\n",
  "        features = torch.tanh(queries.unsqueeze(2) + keys.unsqueeze(1))\n",
  "        # w_v projects num_hiddens -> 1; drop that axis: (batch_size, n_q, n_k).\n",
  "        scores = self.w_v(features).squeeze(-1)\n",
  "        self.attention_weights = masked_softmax(scores, valid_lens)\n",
  "        return torch.bmm(self.dropout(self.attention_weights), value)"
 ], "id": "c7e5254395b14b1", "outputs": [], "execution_count": 20 },
{ "metadata": { "ExecuteTime": { "end_time": "2026-05-29T15:19:44.536311739Z", "start_time": "2026-05-29T15:19:44.471725853Z" } }, "cell_type": "code", "source": [ "queries,keys = torch.normal(0,1,(2,1,20)) , 
torch.ones((2,10,2))\n", "values = torch.arange(40,dtype=torch.float32).reshape(1,10,4).repeat(2,1,1)\n", "valid_lens = torch.tensor([2, 6])\n", "attention = AdditiveAttention(key_size=2, query_size=20, num_hiddens=8,\n", " dropout=0.1)\n", "attention.eval()\n", "attention(queries,keys,values,valid_lens)" ], "id": "77e325a032e70d98", "outputs": [ { "data": { "text/plain": [ "tensor([[[ 2.0000, 3.0000, 4.0000, 5.0000]],\n", "\n", " [[10.0000, 11.0000, 12.0000, 13.0000]]], grad_fn=)" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 21 }, { "metadata": { "ExecuteTime": { "end_time": "2026-05-29T15:19:44.618227172Z", "start_time": "2026-05-29T15:19:44.538570102Z" } }, "cell_type": "code", "source": [ "d2l.show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),\n", " xlabel='Keys', ylabel='Queries')" ], "id": "93ceedfd6e03982f", "outputs": [ { "data": { "text/plain": [ "
" ], "image/svg+xml": "\n\n\n \n \n \n \n 2026-05-29T23:19:44.586919\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" }, "metadata": {}, "output_type": "display_data", "jetTransient": { "display_id": null } } ], "execution_count": 22 }, { "metadata": { "ExecuteTime": { "end_time": "2026-05-29T15:19:44.668038302Z", "start_time": "2026-05-29T15:19:44.620075434Z" } }, "cell_type": "code", "source": [ "import math\n", "class DotProductAttention(nn.Module):\n", " \"\"\"缩放点积注意力\"\"\"\n", " def __init__(self, dropout, **kwargs):\n", " super(DotProductAttention, self).__init__(**kwargs)\n", " self.dropout = nn.Dropout(dropout)\n", "# queries的形状:(batch_size,查询的个数,d)\n", "# keys的形状:(batch_size,“键-值”对的个数,d)\n", "# values的形状:(batch_size,“键-值”对的个数,值的维度)\n", "# valid_lens的形状:(batch_size,)或者(batch_size,查询的个数)\n", " def forward(self, queries, keys, values, valid_lens=None):\n", " d = queries.shape[-1]\n", " # 设置transpose_b=True为了交换keys的最后两个维度\n", " scores = torch.bmm(queries, keys.transpose(1,2)) / math.sqrt(d)\n", " self.attention_weights = masked_softmax(scores, valid_lens)\n", "\n", " return torch.bmm(self.dropout(self.attention_weights), values)" ], "id": "8453c76623a5b435", "outputs": [], "execution_count": 23 }, { "metadata": { "ExecuteTime": { "end_time": "2026-05-29T15:19:44.754268677Z", "start_time": 
"2026-05-29T15:19:44.668929892Z" } }, "cell_type": "code", "source": [ "queries = torch.normal(0, 1, (2, 1, 2))\n", "attention = DotProductAttention(dropout=0.5)\n", "attention.eval()\n", "attention(queries, keys, values, valid_lens)" ], "id": "8af6b28944977a62", "outputs": [ { "data": { "text/plain": [ "tensor([[[ 2.0000, 3.0000, 4.0000, 5.0000]],\n", "\n", " [[10.0000, 11.0000, 12.0000, 13.0000]]])" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 24 }, { "metadata": { "ExecuteTime": { "end_time": "2026-05-29T15:19:44.822905172Z", "start_time": "2026-05-29T15:19:44.756082157Z" } }, "cell_type": "code", "source": [ "d2l.show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),\n", "xlabel='Keys', ylabel='Queries')" ], "id": "28f7e16771076e33", "outputs": [ { "data": { "text/plain": [ "
" ], "image/svg+xml": "\n\n\n \n \n \n \n 2026-05-29T23:19:44.802392\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" }, "metadata": {}, "output_type": "display_data", "jetTransient": { "display_id": null } } ], "execution_count": 25 }, { "metadata": { "ExecuteTime": { "end_time": "2026-05-29T15:19:44.874393721Z", "start_time": "2026-05-29T15:19:44.824430337Z" } }, "cell_type": "code", "source": [ "class AttentionDecoder(d2l.Decoder):\n", " def __init__(self,**kwargs):\n", " super(AttentionDecoder, self).__init__(**kwargs)\n", " @property\n", " def attention_weight(self):\n", " raise NotImplementedError" ], "id": "33507bd2e1917a29", "outputs": [], "execution_count": 26 }, { "metadata": { "ExecuteTime": { "end_time": "2026-05-29T15:19:44.926627628Z", "start_time": "2026-05-29T15:19:44.875567483Z" } }, "cell_type": "code", "source": [ "class Seq2SeqAttentionDecoder(AttentionDecoder):\n", " def __init__(self,vocab_size,embed_size,num_hiddens,num_layers,dropout=0,**kwargs):\n", " super(Seq2SeqAttentionDecoder, self).__init__(**kwargs)\n", " self.attention = AdditiveAttention(\n", " num_hiddens, num_hiddens, num_hiddens,dropout)\n", " self.embedding = nn.Embedding(vocab_size, embed_size)\n", " self.rnn = nn.GRU(\n", " embed_size + num_hiddens, num_hiddens, num_layers,\n", " dropout=dropout)\n", " self.dense = 
nn.Linear(num_hiddens, vocab_size)\n",
  "\n",
  "    def init_state(self, enc_outputs, enc_valid_lens, *args):\n",
  "        \"\"\"Build the decoder state from the encoder's (outputs, hidden_state).\n",
  "\n",
  "        The encoder outputs arrive time-major, (num_steps, batch_size,\n",
  "        num_hiddens), and are permuted to (batch_size, num_steps, num_hiddens)\n",
  "        so they can serve as batch-first keys/values for the attention.\n",
  "        hidden_state, (num_layers, batch_size, num_hiddens), passes through.\n",
  "        \"\"\"\n",
  "        outputs, hidden_state = enc_outputs\n",
  "        return (outputs.permute(1, 0, 2), hidden_state, enc_valid_lens)\n",
  "\n",
  "    def forward(self, X, state):\n",
  "        \"\"\"Decode X one time step at a time, attending over the encoder outputs.\n",
  "\n",
  "        X: (batch_size, num_steps) token indices.\n",
  "        state: (enc_outputs, hidden_state, enc_valid_lens) as built by init_state.\n",
  "        Returns the dense-layer outputs, (batch_size, num_steps, vocab_size),\n",
  "        and the updated state as a list [enc_outputs, hidden_state, enc_valid_lens].\n",
  "        \"\"\"\n",
  "        # enc_outputs: (batch_size, num_steps, num_hiddens)\n",
  "        # hidden_state: (num_layers, batch_size, num_hiddens)\n",
  "        enc_outputs, hidden_state, enc_valid_lens = state\n",
  "        # Time-major embeddings: (num_steps, batch_size, embed_size)\n",
  "        X = self.embedding(X).permute(1, 0, 2)\n",
  "        outputs, self._attention_weights = [], []\n",
  "        for x in X:\n",
  "            # Last GRU layer's hidden state is the query: (batch_size, 1, num_hiddens)\n",
  "            query = torch.unsqueeze(hidden_state[-1], dim=1)\n",
  "            # context: (batch_size, 1, num_hiddens)\n",
  "            context = self.attention(\n",
  "                query, enc_outputs, enc_outputs, enc_valid_lens)\n",
  "            # Concatenate context and the step's embedding on the feature axis.\n",
  "            x = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=-1)\n",
  "            # Feed the GRU (1, batch_size, embed_size + num_hiddens).\n",
  "            out, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)\n",
  "            outputs.append(out)\n",
  "            self._attention_weights.append(self.attention.attention_weights)\n",
  "        # After the dense layer: (num_steps, batch_size, vocab_size)\n",
  "        outputs = self.dense(torch.cat(outputs, dim=0))\n",
  "        return outputs.permute(1, 0, 2), [enc_outputs, hidden_state,\n",
  "                                          enc_valid_lens]\n",
  "\n",
  "    @property\n",
  "    def attention_weights(self):\n",
  "        \"\"\"Per-step attention weights recorded during the last forward() call.\"\"\"\n",
  "        return self._attention_weights"
 ], "id": "a8a497c9041910a1", "outputs": [], "execution_count": 27
}, { "metadata": { "ExecuteTime": { "end_time": "2026-05-29T15:19:44.996664037Z", "start_time": "2026-05-29T15:19:44.927689328Z" } }, "cell_type": "code", "source": [ "encoder = d2l.Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16,\n", "num_layers=2)\n", "encoder.eval()\n", "decoder = Seq2SeqAttentionDecoder(vocab_size=10, embed_size=8, num_hiddens=16,\n", "num_layers=2)\n", "decoder.eval()\n", "X = torch.zeros((4, 7), dtype=torch.long) # (batch_size,num_steps)\n", "state = decoder.init_state(encoder(X), None)\n", "output, state = decoder(X, state)\n", "output.shape, len(state), state[0].shape, len(state[1]), state[1][0].shape" ], "id": "b2a6bc735743bd0f", "outputs": [ { "data": { "text/plain": [ "(torch.Size([4, 7, 10]), 3, torch.Size([4, 7, 16]), 2, torch.Size([4, 16]))" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 28 }, { "metadata": { "ExecuteTime": { "end_time": "2026-05-29T15:19:45.047118759Z", "start_time": "2026-05-29T15:19:44.998053092Z" } }, "cell_type": "code", "source": "", "id": "4c7bf86095f15c98", "outputs": [], "execution_count": 28 }, { "metadata": { "ExecuteTime": { "end_time": "2026-05-29T15:20:14.678359202Z", "start_time": "2026-05-29T15:19:45.048442204Z" } }, "cell_type": "code", "source": [ "embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1\n", "batch_size, num_steps = 64, 10\n", "lr, num_epochs, device = 0.005, 250, d2l.try_gpu()\n", "train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)\n", "encoder = d2l.Seq2SeqEncoder(\n", "len(src_vocab), embed_size, num_hiddens, num_layers, dropout)\n", "decoder = Seq2SeqAttentionDecoder(\n", "len(tgt_vocab), embed_size, num_hiddens, num_layers, dropout)\n", "net = d2l.EncoderDecoder(encoder, decoder)\n", "d2l.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)" ], "id": "5f5cedf74b97bd12", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "loss 0.019, 17862.0 tokens/sec on 
cuda:0\n" ] }, { "data": { "text/plain": [ "
" ], "image/svg+xml": "\n\n\n \n \n \n \n 2026-05-29T23:20:14.649299\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" }, "metadata": {}, "output_type": "display_data", "jetTransient": { "display_id": null } } ], "execution_count": 29 }, { "metadata": { "ExecuteTime": { "end_time": "2026-05-29T15:20:14.794039857Z", "start_time": "2026-05-29T15:20:14.727727180Z" } }, "cell_type": "code", "source": [ "engs = ['go .', \"i lost .\", 'he\\'s calm .', 'i\\'m home .']\n", "fras = ['va !', 'j\\'ai perdu .', 'il est calme .', 'je suis chez moi .']\n", "for eng, fra in zip(engs, fras):\n", " translation, dec_attention_weight_seq = d2l.predict_seq2seq(\n", " net, eng, src_vocab, tgt_vocab, num_steps, device, True)\n", " print(f'{eng} => {translation}, ',\n", " f'bleu {d2l.bleu(translation, fra, k=2):.3f}')" ], "id": "ddaed8bd284ee01b", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "go . => va !, bleu 1.000\n", "i lost . => j'ai perdu ., bleu 1.000\n", "he's calm . => il est ., bleu 0.658\n", "i'm home . 
=> je suis chez moi ., bleu 1.000\n" ] } ], "execution_count": 30 }, { "metadata": { "ExecuteTime": { "end_time": "2026-05-29T15:20:38.342057516Z", "start_time": "2026-05-29T15:20:38.288567783Z" } }, "cell_type": "code", "source": [
 "from torch import nn\n",
 "\n",
 "\n",
 "def transpose_qkv(X, num_heads):\n",
 "    \"\"\"Reshape projected Q/K/V so that each attention head becomes its own batch entry.\n",
 "\n",
 "    Input shape: (batch_size, seq_len, num_hiddens), where seq_len is the number\n",
 "    of queries or of key-value pairs.\n",
 "    Output shape: (batch_size * num_heads, seq_len, num_hiddens / num_heads).\n",
 "    \"\"\"\n",
 "    # (batch_size, seq_len, num_heads, num_hiddens / num_heads)\n",
 "    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)\n",
 "    # (batch_size, num_heads, seq_len, num_hiddens / num_heads)\n",
 "    X = X.permute(0, 2, 1, 3)\n",
 "    # Fold heads into the batch axis: row b * num_heads + h is head h of example b.\n",
 "    return X.reshape(-1, X.shape[2], X.shape[3])\n",
 "\n",
 "\n",
 "def transpose_output(X, num_heads):\n",
 "    \"\"\"Invert transpose_qkv.\n",
 "\n",
 "    Input shape: (batch_size * num_heads, seq_len, num_hiddens / num_heads).\n",
 "    Output shape: (batch_size, seq_len, num_hiddens), heads concatenated on the\n",
 "    last axis.\n",
 "    \"\"\"\n",
 "    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])\n",
 "    X = X.permute(0, 2, 1, 3)\n",
 "    return X.reshape(X.shape[0], X.shape[1], -1)\n",
 "\n",
 "\n",
 "class MultiHeadAttention(nn.Module):\n",
 "    \"\"\"Multi-head attention built on d2l.DotProductAttention.\n",
 "\n",
 "    Queries, keys and values are each projected to num_hiddens features, split\n",
 "    into num_heads equal slices, attended independently per head, then\n",
 "    concatenated and projected by W_o.\n",
 "    Output shape: (batch_size, num_queries, num_hiddens).\n",
 "    \"\"\"\n",
 "\n",
 "    def __init__(self, key_size, query_size, value_size, num_hiddens,\n",
 "                 num_heads, dropout, bias=False, **kwargs):\n",
 "        super().__init__(**kwargs)\n",
 "        if num_hiddens % num_heads != 0:\n",
 "            # transpose_qkv splits the feature axis into num_heads equal slices;\n",
 "            # fail here instead of with an opaque reshape error in forward().\n",
 "            raise ValueError(f'num_hiddens ({num_hiddens}) must be divisible '\n",
 "                             f'by num_heads ({num_heads})')\n",
 "        self.num_heads = num_heads\n",
 "        self.attention = d2l.DotProductAttention(dropout)\n",
 "        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)\n",
 "        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)\n",
 "        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)\n",
 "        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)\n",
 "\n",
 "    def forward(self, queries, keys, values, valid_lens):\n",
 "        \"\"\"Return attention output of shape (batch_size, num_queries, num_hiddens).\n",
 "\n",
 "        valid_lens (or None) is repeated once per head and passed to the\n",
 "        underlying d2l.DotProductAttention for masking.\n",
 "        \"\"\"\n",
 "        # 1. Linear projections, then split into heads folded into the batch axis.\n",
 "        queries = transpose_qkv(self.W_q(queries), self.num_heads)\n",
 "        keys = transpose_qkv(self.W_k(keys), self.num_heads)\n",
 "        values = transpose_qkv(self.W_v(values), self.num_heads)\n",
 "\n",
 "        # 2. Copy each example's valid lengths once per head. repeat_interleave\n",
 "        #    (not repeat) matches the b * num_heads + h row order of transpose_qkv.\n",
 "        if valid_lens is not None:\n",
 "            valid_lens = torch.repeat_interleave(\n",
 "                valid_lens, repeats=self.num_heads, dim=0)\n",
 "\n",
 "        # 3. Attention is computed independently for every head.\n",
 "        output = self.attention(queries, keys, values, valid_lens)\n",
 "\n",
 "        # 4. Concatenate the heads and apply the output projection.\n",
 "        output_concat = transpose_output(output, self.num_heads)\n",
 "        return self.W_o(output_concat)"
 ], "id": "13d19f8b5048c74d", "outputs": [], "execution_count": 32 }, { "metadata": { "ExecuteTime": { "end_time": "2026-05-29T15:20:39.251763154Z", "start_time": "2026-05-29T15:20:39.201715213Z" } }, "cell_type": "code", "source": [
 "# Key, query and value sizes all equal num_hiddens; attention dropout p=0.5.\n",
 "num_hiddens, num_heads = 100, 5\n",
 "attention = MultiHeadAttention(\n",
 "    num_hiddens, num_hiddens, num_hiddens, num_hiddens, num_heads, 0.5)\n",
 "attention.eval()"
 ], "id": "b25c661a511d7763", "outputs": [ { "data": { "text/plain": [
 "MultiHeadAttention(\n",
 "  (attention): DotProductAttention(\n",
 "    (dropout): Dropout(p=0.5, inplace=False)\n",
 "  )\n",
 "  (W_q): Linear(in_features=100, out_features=100, bias=False)\n",
 "  (W_k): Linear(in_features=100, out_features=100, bias=False)\n",
 "  (W_v): Linear(in_features=100, out_features=100, bias=False)\n",
 "  (W_o): Linear(in_features=100, out_features=100, bias=False)\n",
 ")"
 ] }, "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 33 }, { "metadata": { "ExecuteTime": { "end_time": "2026-05-29T15:20:40.455647106Z", "start_time": "2026-05-29T15:20:40.307682855Z" } }, "cell_type": "code", "source": [ 
"batch_size, num_queries = 2, 4\n",
 "num_kvpairs, valid_lens = 6, torch.tensor([3, 2])  # per-example valid lengths\n",
 "X = torch.ones((batch_size, num_queries, num_hiddens))\n",
 "Y = torch.ones((batch_size, num_kvpairs, num_hiddens))\n",
 "output_shape = attention(X, Y, Y, valid_lens).shape\n",
 "# Heads are concatenated back and projected by W_o, so the query layout is kept.\n",
 "assert output_shape == (batch_size, num_queries, num_hiddens)\n",
 "output_shape"
 ], "id": "49877c929741ec94", "outputs": [ { "data": { "text/plain": [ "torch.Size([2, 4, 100])" ] }, "execution_count": 34, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 34 } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.6" } }, "nbformat": 4, "nbformat_minor": 5 }