{ "cells": [ { "cell_type": "code", "id": "initial_id", "metadata": { "ExecuteTime": { "end_time": "2026-06-30T02:59:03.101034703Z", "start_time": "2026-06-30T02:58:59.676757575Z" } }, "source": [
 "# All imports for the notebook live here so Restart & Run All never depends on cell order.\n",
 "import math\n",
 "\n",
 "import torch\n",
 "from torch import nn\n",
 "from d2l import torch as d2l\n",
 "\n",
 "# Seed once so torch.rand / torch.normal / parameter init / dropout below are reproducible.\n",
 "torch.manual_seed(0);" ], "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/yukun/.conda/envs/nn/lib/python3.11/site-packages/torch/cuda/__init__.py:1007: UserWarning: Can't initialize NVML\n", " raw_cnt = _raw_device_count_nvml()\n" ] } ], "execution_count": 1 }, { "cell_type": "code", "id": "a0e3f725b7764f08", "metadata": { "ExecuteTime": { "end_time": "2026-06-30T02:59:03.137220451Z", "start_time": "2026-06-30T02:59:03.118596105Z" } }, "source": [
 "def show_heatmaps(matrices, xlabel, ylabel, titles=None, figsize=(2.5, 2.5), cmap='Reds'):\n",
 "    \"\"\"Draw a grid of heatmaps that share one colorbar.\n",
 "\n",
 "    matrices: 4-D tensor (num_rows, num_cols, rows, cols); matrices[i][j] is drawn\n",
 "        with imshow in subplot (i, j).\n",
 "    xlabel / ylabel: shown only on the bottom row / first column of subplots.\n",
 "    titles: optional per-column titles, indexed by column j.\n",
 "    \"\"\"\n",
 "    d2l.use_svg_display()\n",
 "    num_rows, num_cols = matrices.shape[0], matrices.shape[1]\n",
 "    # squeeze=False keeps `axes` 2-D even for a 1x1 grid, so the nested zip below always works.\n",
 "    fig, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize,\n",
 "                                 sharex=True, sharey=True, squeeze=False)\n",
 "    for i, (row_axes, row_matrices) in enumerate(zip(axes, matrices)):\n",
 "        for j, (ax, matrix) in enumerate(zip(row_axes, row_matrices)):\n",
 "            pcm = ax.imshow(matrix.detach().numpy(), cmap=cmap)\n",
 "            if i == num_rows - 1:\n",
 "                ax.set_xlabel(xlabel)\n",
 "            if j == 0:\n",
 "                ax.set_ylabel(ylabel)\n",
 "            if titles:\n",
 "                ax.set_title(titles[j])\n",
 "    # One colorbar for the whole grid, attached to the last image drawn.\n",
 "    fig.colorbar(pcm, ax=axes, shrink=0.6)" ], "outputs": [], "execution_count": 2 }, { "cell_type": "code", "id": "bb798de34fc4d0fa", "metadata": { "ExecuteTime": { "end_time": "2026-06-30T02:59:03.336399820Z", "start_time": "2026-06-30T02:59:03.139187972Z" } }, "source": [ "attention_weights = torch.eye(10).reshape((1, 1, 10, 10))\n", "\n", "show_heatmaps(attention_weights, xlabel='Keys', ylabel='Queries')" ], "outputs": [ { "data": { "text/plain": [ "
" ], "image/svg+xml": "\n\n\n \n \n \n \n 2026-06-30T10:59:03.270032\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" }, "metadata": {}, "output_type": "display_data", "jetTransient": { "display_id": null } } ], "execution_count": 3 }, { "cell_type": "code", "id": "a4070c75847fb887", "metadata": { "ExecuteTime": { "end_time": "2026-06-30T02:59:03.393286045Z", "start_time": "2026-06-30T02:59:03.339609950Z" } }, "source": [ "n_train = 50\n", "x_train,_ = torch.sort(torch.rand(n_train)*5)" ], "outputs": [], "execution_count": 4 }, { "cell_type": "code", "id": "2aca6952876d9cf5", "metadata": { "ExecuteTime": { "end_time": "2026-06-30T02:59:03.465395938Z", "start_time": "2026-06-30T02:59:03.397781049Z" } }, "source": [ "def f(x):\n", " return 2*torch.sin(x)+x**0.8" ], "outputs": [], "execution_count": 5 }, { "cell_type": "code", "id": "ea7b94585dcde934", "metadata": { "ExecuteTime": { "end_time": "2026-06-30T02:59:03.554188488Z", "start_time": "2026-06-30T02:59:03.467856816Z" } }, "source": [ "y_train = f(x_train)+torch.normal(0.0,0.5,(n_train,))\n", "x_test = 
torch.arange(0,5,0.1)\n", "y_truth = f(x_test)\n", "n_test = len(x_test)\n", "n_test" ], "outputs": [ { "data": { "text/plain": [ "50" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 6 }, { "cell_type": "code", "id": "35c2403bbb5509cf", "metadata": { "ExecuteTime": { "end_time": "2026-06-30T02:59:03.578849203Z", "start_time": "2026-06-30T02:59:03.563506811Z" } }, "source": [
 "def plot_kernel_reg(y_hat):\n",
 "    \"\"\"Plot predictions y_hat against the ground truth on x_test, with training points overlaid.\n",
 "\n",
 "    Reads the notebook globals x_test, y_truth, x_train and y_train.\n",
 "    \"\"\"\n",
 "    d2l.plot(x_test, [y_truth, y_hat], 'x', 'y', legend=['Truth', 'Pred'],\n",
 "             xlim=[0, 5], ylim=[-1, 5])\n",
 "    # Semi-transparent markers so overlapping training samples remain visible.\n",
 "    d2l.plt.plot(x_train, y_train, 'o', alpha=0.5);" ], "outputs": [], "execution_count": 7 }, { "cell_type": "code", "id": "f27ecb9c68aed894", "metadata": { "ExecuteTime": { "end_time": "2026-06-30T02:59:03.728510578Z", "start_time": "2026-06-30T02:59:03.580633476Z" } }, "source": [ "y_hat = torch.repeat_interleave(y_train.mean(),n_test)\n", "plot_kernel_reg(y_hat)" ], "outputs": [ { "data": { "text/plain": [ "
" ], "image/svg+xml": "\n\n\n \n \n \n \n 2026-06-30T10:59:03.678092\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" }, "metadata": {}, "output_type": "display_data", "jetTransient": { "display_id": null } } ], "execution_count": 8 }, { "cell_type": "code", "id": "fa7579d7cd8bfcb7", "metadata": { "ExecuteTime": { "end_time": "2026-06-30T02:59:03.807801559Z", "start_time": "2026-06-30T02:59:03.732916546Z" } }, "source": [ "from torch import nn\n", "X_repeat = x_test.repeat_interleave(n_train).reshape((-1, n_train))\n", "attention_weights = nn.functional.softmax(-(X_repeat - x_train)**2 / 2, dim=1)\n", "y_hat = torch.matmul(attention_weights,y_train)\n", "x_train,X_repeat,attention_weights" ], "outputs": [ { "data": { "text/plain": [ "(tensor([0.0153, 0.1180, 0.1875, 
0.4215, 0.5745, 0.6163, 0.6557, 0.6917, 0.7939,\n", " 0.9230, 1.0240, 1.2932, 1.3246, 1.5167, 1.6193, 1.6496, 1.6613, 1.7104,\n", " 1.7485, 1.7840, 1.7889, 2.0187, 2.0398, 2.0413, 2.2665, 2.4287, 2.4440,\n", " 2.6026, 2.6988, 2.7922, 2.8129, 2.9356, 3.0271, 3.0442, 3.0646, 3.1996,\n", " 3.2037, 3.3222, 3.4918, 3.5313, 3.7002, 3.7714, 3.7949, 3.9609, 4.3272,\n", " 4.4440, 4.6667, 4.7213, 4.8963, 4.9971]),\n", " tensor([[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", " [0.1000, 0.1000, 0.1000, ..., 0.1000, 0.1000, 0.1000],\n", " [0.2000, 0.2000, 0.2000, ..., 0.2000, 0.2000, 0.2000],\n", " ...,\n", " [4.7000, 4.7000, 4.7000, ..., 4.7000, 4.7000, 4.7000],\n", " [4.8000, 4.8000, 4.8000, ..., 4.8000, 4.8000, 4.8000],\n", " [4.9000, 4.9000, 4.9000, ..., 4.9000, 4.9000, 4.9000]]),\n", " tensor([[7.9004e-02, 7.8465e-02, 7.7637e-02, ..., 1.1413e-06, 4.9195e-07,\n", " 2.9875e-07],\n", " [7.2616e-02, 7.2865e-02, 7.2599e-02, ..., 1.6794e-06, 7.3669e-07,\n", " 4.5190e-07],\n", " [6.6459e-02, 6.7376e-02, 6.7597e-02, ..., 2.4606e-06, 1.0985e-06,\n", " 6.8065e-07],\n", " ...,\n", " [1.3743e-06, 2.2120e-06, 3.0334e-06, ..., 8.0086e-02, 7.8576e-02,\n", " 7.6646e-02],\n", " [9.2161e-07, 1.4986e-06, 2.0694e-06, ..., 8.5977e-02, 8.5846e-02,\n", " 8.4585e-02],\n", " [6.1460e-07, 1.0097e-06, 1.4040e-06, ..., 9.1792e-02, 9.3269e-02,\n", " 9.2831e-02]]))" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 9 }, { "cell_type": "code", "id": "f76573b2eb74ba7", "metadata": { "ExecuteTime": { "end_time": "2026-06-30T02:59:03.923768674Z", "start_time": "2026-06-30T02:59:03.813863850Z" } }, "source": [ "plot_kernel_reg(y_hat)" ], "outputs": [ { "data": { "text/plain": [ "
" ], "image/svg+xml": "\n\n\n \n \n \n \n 2026-06-30T10:59:03.882515\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" }, "metadata": {}, "output_type": "display_data", "jetTransient": { "display_id": null } } ], "execution_count": 10 }, { "cell_type": "code", "id": "1bbe3af9e8a0f412", "metadata": { "ExecuteTime": { "end_time": "2026-06-30T02:59:04.041592329Z", "start_time": "2026-06-30T02:59:03.938730755Z" } }, "source": [ "d2l.show_heatmaps(attention_weights.unsqueeze(0).unsqueeze(0),\n", " xlabel='Sorted training inputs',\n", " ylabel='Sorted testing inputs')" ], "outputs": [ { "data": { "text/plain": [ "
" ], "image/svg+xml": "\n\n\n \n \n \n \n 2026-06-30T10:59:04.008946\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" }, "metadata": {}, "output_type": "display_data", "jetTransient": { "display_id": null } } ], "execution_count": 11 }, { "cell_type": "code", "id": "aeb79a3bd64c826", "metadata": { "ExecuteTime": { "end_time": "2026-06-30T02:59:04.105924659Z", "start_time": "2026-06-30T02:59:04.055849961Z" } }, "source": [ "class NWKernelRegression(nn.Module):\n", " def __init__(self,**kwargs):\n", " super().__init__(**kwargs)\n", " self.w = nn.Parameter(torch.randn((1,),requires_grad=True))\n", " def forward(self,queries,keys,values):\n", " queries = queries.repeat_interleave(keys.shape[1]).reshape(-1,keys.shape[1])\n", " self.attention_weights = nn.functional.softmax(\n", " -((queries - keys) * self.w)**2 / 2, dim=1)\n", " return torch.bmm(self.attention_weights.unsqueeze(1),values.unsqueeze(-1)).reshape(-1)" ], "outputs": [], "execution_count": 12 }, { "cell_type": "code", "id": "7061ed7f139bc78e", "metadata": 
{ "ExecuteTime": { "end_time": "2026-06-30T02:59:04.158366434Z", "start_time": "2026-06-30T02:59:04.107607867Z" } }, "source": [
 "# Leave-one-out training data: training query i is compared with every OTHER\n",
 "# training example, so it never attends to its own (x, y) pair.\n",
 "# X_tile / Y_tile: (n_train, n_train); every row is a copy of x_train / y_train.\n",
 "X_tile = x_train.repeat((n_train, 1))\n",
 "Y_tile = y_train.repeat((n_train, 1))\n",
 "# True everywhere except the diagonal; built once and reused for keys and values.\n",
 "off_diag_mask = ~torch.eye(n_train, dtype=torch.bool)\n",
 "# keys / values: (n_train, n_train - 1); row i omits training example i.\n",
 "keys = X_tile[off_diag_mask].reshape((n_train, -1))\n",
 "values = Y_tile[off_diag_mask].reshape((n_train, -1))" ], "outputs": [], "execution_count": 13 }, { "cell_type": "code", "id": "3dca11d4e962d859", "metadata": { "ExecuteTime": { "end_time": "2026-06-30T02:59:04.750466831Z", "start_time": "2026-06-30T02:59:04.162845978Z" } }, "source": [
 "net = NWKernelRegression()\n",
 "loss = nn.MSELoss(reduction='none')\n",
 "trainer = torch.optim.SGD(net.parameters(), lr=0.5)\n",
 "animator = d2l.Animator(xlabel='epoch', ylabel='loss', xlim=[1, 5])\n",
 "for epoch in range(5):\n",
 "    trainer.zero_grad()\n",
 "    l = loss(net(x_train, keys, values), y_train)\n",
 "    # Reduce once: the same scalar drives backward() and the logging below.\n",
 "    epoch_loss = l.sum()\n",
 "    epoch_loss.backward()\n",
 "    trainer.step()\n",
 "    print(f'epoch {epoch + 1}, loss {float(epoch_loss):.6f}')\n",
 "    animator.add(epoch + 1, float(epoch_loss))" ], "outputs": [ { "data": { "text/plain": [ "
" ], "image/svg+xml": "\n\n\n \n \n \n \n 2026-06-30T10:59:04.709978\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" }, "metadata": {}, "output_type": "display_data", "jetTransient": { "display_id": null } } ], "execution_count": 14 }, { "cell_type": "code", "id": "4ff00f0b4983b55e", "metadata": { "ExecuteTime": { "end_time": "2026-06-30T02:59:04.976760199Z", "start_time": "2026-06-30T02:59:04.799699382Z" } }, "source": [ "# keys的形状:(n_test,n_train),每一行包含着相同的训练输入(例如,相同的键)\n", "keys = x_train.repeat((n_test, 1))\n", "# value的形状:(n_test,n_train)\n", "values = y_train.repeat((n_test, 1))\n", "y_hat = net(x_test, keys, values).unsqueeze(1).detach()\n", "plot_kernel_reg(y_hat)" ], "outputs": [ { "data": { "text/plain": [ "
" ], "image/svg+xml": "\n\n\n \n \n \n \n 2026-06-30T10:59:04.872782\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" }, "metadata": {}, "output_type": "display_data", "jetTransient": { "display_id": null } } ], "execution_count": 15 }, { "cell_type": "code", "id": "97a690478767c3c3", "metadata": { "ExecuteTime": { "end_time": "2026-06-30T02:59:05.088839245Z", "start_time": "2026-06-30T02:59:04.978734877Z" } }, "source": [ "d2l.show_heatmaps(net.attention_weights.unsqueeze(0).unsqueeze(0),\n", "xlabel='Sorted training inputs',\n", "ylabel='Sorted testing inputs')" ], "outputs": [ { "data": { "text/plain": [ "
" ], "image/svg+xml": "\n\n\n \n \n \n \n 2026-06-30T10:59:05.049223\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" }, "metadata": {}, "output_type": "display_data", "jetTransient": { "display_id": null } } ], "execution_count": 16 }, { "cell_type": "code", "id": "8d03d0c7f735b755", "metadata": { "ExecuteTime": { "end_time": "2026-06-30T02:59:05.142764987Z", "start_time": "2026-06-30T02:59:05.093810141Z" } }, "source": [ "def masked_softmax(X,valid_lens):\n", " # X:3D valid_len 1D or 2D\n", " if valid_lens is None:\n", " return nn.functional.softmax(X,dim=-1)\n", " else:\n", " shape = X.shape\n", " if valid_lens.dim()==1:\n", " valid_lens = torch.repeat_interleave(valid_lens,shape[1])\n", " else:\n", " valid_lens = valid_lens.reshape(-1)\n", " #print(valid_lens)\n", " X = d2l.sequence_mask(X.reshape(-1,shape[-1]),valid_lens,value=-1e6)\n", " return nn.functional.softmax(X.reshape(shape),dim=-1)" ], "outputs": [], "execution_count": 17 }, { "cell_type": "code", "id": 
"27f8c24ec572b267", "metadata": { "ExecuteTime": { "end_time": "2026-06-30T02:59:05.208569156Z", "start_time": "2026-06-30T02:59:05.144491160Z" } }, "source": [ "masked_softmax(torch.rand(2,2,4),torch.tensor([2,3]))" ], "outputs": [ { "data": { "text/plain": [ "tensor([[[0.5584, 0.4416, 0.0000, 0.0000],\n", " [0.6205, 0.3795, 0.0000, 0.0000]],\n", "\n", " [[0.4421, 0.3398, 0.2181, 0.0000],\n", " [0.3929, 0.3625, 0.2446, 0.0000]]])" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 18 }, { "cell_type": "code", "id": "c7e5254395b14b1", "metadata": { "ExecuteTime": { "end_time": "2026-06-30T02:59:05.262719443Z", "start_time": "2026-06-30T02:59:05.210360072Z" } }, "source": [
 "class AdditiveAttention(nn.Module):\n",
 "    \"\"\"Additive attention: score(q, k) = w_v^T tanh(W_q q + W_k k).\n",
 "\n",
 "    Scores are normalised over the key axis with masked_softmax, so padded key\n",
 "    positions (per valid_lens) get ~zero weight, and then used to average `value`.\n",
 "    \"\"\"\n",
 "    def __init__(self, key_size, query_size, num_hiddens, dropout, **kwargs):\n",
 "        super().__init__(**kwargs)\n",
 "        self.W_k = nn.Linear(key_size, num_hiddens, bias=False)\n",
 "        self.W_q = nn.Linear(query_size, num_hiddens, bias=False)\n",
 "        self.w_v = nn.Linear(num_hiddens, 1, bias=False)\n",
 "        self.dropout = nn.Dropout(dropout)\n",
 "\n",
 "    def forward(self, queries, keys, value, valid_lens):\n",
 "        # Project both sides into the shared hidden space:\n",
 "        # queries -> (batch_size, n_q, num_hiddens), keys -> (batch_size, n_k, num_hiddens)\n",
 "        queries, keys = self.W_q(queries), self.W_k(keys)\n",
 "        # Broadcast (batch_size, n_q, 1, h) + (batch_size, 1, n_k, h)\n",
 "        # -> features: (batch_size, n_q, n_k, num_hiddens), one vector per (query, key) pair.\n",
 "        features = torch.tanh(queries.unsqueeze(2) + keys.unsqueeze(1))\n",
 "        # w_v collapses the hidden dim to 1; squeeze it -> scores: (batch_size, n_q, n_k)\n",
 "        scores = self.w_v(features).squeeze(-1)\n",
 "        # Kept on the module so callers can inspect / plot the weights afterwards.\n",
 "        self.attention_weights = masked_softmax(scores, valid_lens)\n",
 "        # value: (batch_size, n_k, value_dim) -> output: (batch_size, n_q, value_dim)\n",
 "        return torch.bmm(self.dropout(self.attention_weights), value)" ], "outputs": [], "execution_count": 19 }, { "cell_type": "code", "id": "77e325a032e70d98", "metadata": { "ExecuteTime": { "end_time": "2026-06-30T02:59:05.332247710Z", "start_time": "2026-06-30T02:59:05.264012533Z" } }, "source": [ "queries,keys = torch.normal(0,1,(2,1,20)) , 
torch.ones((2,10,2))\n", "values = torch.arange(40,dtype=torch.float32).reshape(1,10,4).repeat(2,1,1)\n", "valid_lens = torch.tensor([2, 6])\n", "attention = AdditiveAttention(key_size=2, query_size=20, num_hiddens=8,\n", " dropout=0.1)\n", "attention.eval()\n", "attention(queries,keys,values,valid_lens)" ], "outputs": [ { "data": { "text/plain": [ "tensor([[[ 2.0000, 3.0000, 4.0000, 5.0000]],\n", "\n", " [[10.0000, 11.0000, 12.0000, 13.0000]]], grad_fn=)" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 20 }, { "cell_type": "code", "id": "93ceedfd6e03982f", "metadata": { "ExecuteTime": { "end_time": "2026-06-30T02:59:05.426180301Z", "start_time": "2026-06-30T02:59:05.338806296Z" } }, "source": [ "d2l.show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),\n", " xlabel='Keys', ylabel='Queries')" ], "outputs": [ { "data": { "text/plain": [ "
" ], "image/svg+xml": "\n\n\n \n \n \n \n 2026-06-30T10:59:05.393045\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" }, "metadata": {}, "output_type": "display_data", "jetTransient": { "display_id": null } } ], "execution_count": 21 }, { "cell_type": "code", "id": "8453c76623a5b435", "metadata": { "ExecuteTime": { "end_time": "2026-06-30T02:59:05.479115341Z", "start_time": "2026-06-30T02:59:05.429811516Z" } }, "source": [ "import math\n", "class DotProductAttention(nn.Module):\n", " \"\"\"缩放点积注意力\"\"\"\n", " def __init__(self, dropout, **kwargs):\n", " super(DotProductAttention, self).__init__(**kwargs)\n", " self.dropout = nn.Dropout(dropout)\n", "# queries的形状:(batch_size,查询的个数,d)\n", "# keys的形状:(batch_size,“键-值”对的个数,d)\n", "# values的形状:(batch_size,“键-值”对的个数,值的维度)\n", "# valid_lens的形状:(batch_size,)或者(batch_size,查询的个数)\n", " def forward(self, queries, keys, values, valid_lens=None):\n", " d = queries.shape[-1]\n", " # 设置transpose_b=True为了交换keys的最后两个维度\n", " scores = torch.bmm(queries, keys.transpose(1,2)) / math.sqrt(d)\n", " self.attention_weights = masked_softmax(scores, valid_lens)\n", "\n", " return torch.bmm(self.dropout(self.attention_weights), values)" ], "outputs": [], "execution_count": 22 }, { "cell_type": "code", "id": "8af6b28944977a62", "metadata": { "ExecuteTime": { "end_time": 
"2026-06-30T02:59:05.547775309Z", "start_time": "2026-06-30T02:59:05.480327680Z" } }, "source": [ "queries = torch.normal(0, 1, (2, 1, 2))\n", "attention = DotProductAttention(dropout=0.5)\n", "attention.eval()\n", "attention(queries, keys, values, valid_lens)" ], "outputs": [ { "data": { "text/plain": [ "tensor([[[ 2.0000, 3.0000, 4.0000, 5.0000]],\n", "\n", " [[10.0000, 11.0000, 12.0000, 13.0000]]])" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 23 }, { "cell_type": "code", "id": "28f7e16771076e33", "metadata": { "ExecuteTime": { "end_time": "2026-06-30T02:59:05.663459211Z", "start_time": "2026-06-30T02:59:05.549703930Z" } }, "source": [ "d2l.show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),\n", "xlabel='Keys', ylabel='Queries')" ], "outputs": [ { "data": { "text/plain": [ "
" ], "image/svg+xml": "\n\n\n \n \n \n \n 2026-06-30T10:59:05.624810\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" }, "metadata": {}, "output_type": "display_data", "jetTransient": { "display_id": null } } ], "execution_count": 24 }, { "cell_type": "code", "id": "33507bd2e1917a29", "metadata": { "ExecuteTime": { "end_time": "2026-06-30T02:59:05.730333551Z", "start_time": "2026-06-30T02:59:05.680995441Z" } }, "source": [ "class AttentionDecoder(d2l.Decoder):\n", " def __init__(self,**kwargs):\n", " super(AttentionDecoder, self).__init__(**kwargs)\n", " @property\n", " def attention_weight(self):\n", " raise NotImplementedError" ], "outputs": [], "execution_count": 25 }, { "cell_type": "code", "id": "a8a497c9041910a1", "metadata": { "ExecuteTime": { "end_time": "2026-06-30T02:59:05.781440228Z", "start_time": "2026-06-30T02:59:05.731688196Z" } }, "source": [ "class Seq2SeqAttentionDecoder(AttentionDecoder):\n", " def __init__(self,vocab_size,embed_size,num_hiddens,num_layers,dropout=0,**kwargs):\n", " super(Seq2SeqAttentionDecoder, self).__init__(**kwargs)\n", " self.attention = AdditiveAttention(\n", " num_hiddens, num_hiddens, num_hiddens,dropout)\n", " self.embedding = nn.Embedding(vocab_size, embed_size)\n", " self.rnn = nn.GRU(\n", " embed_size + num_hiddens, num_hiddens, num_layers,\n", " dropout=dropout)\n", " 
self.dense = nn.Linear(num_hiddens, vocab_size)\n", " def init_state(self, enc_outputs, enc_valid_lens, *args):\n", "# outputs的形状为(batch_size,num_steps,num_hiddens).\n", "# hidden_state的形状为(num_layers,batch_size,num_hiddens)\n", " outputs, hidden_state = enc_outputs\n", " #print(f\"Encoder outputs shape before permute: {outputs.shape}\") # 应为 (num_steps, batch_size, num_hiddens) 或 (batch_size, num_steps, num_hiddens)\n", " enc_outputs_permuted = outputs.permute(1, 0, 2)\n", " #print(f\"After permute: {enc_outputs_permuted.shape}\") # 期望 (batch_size, num_steps, num_hiddens)\n", " return (enc_outputs_permuted, hidden_state, enc_valid_lens)\n", " def forward(self, X, state):\n", " # enc_outputs的形状为(batch_size,num_steps,num_hiddens).\n", " # hidden_state的形状为(num_layers,batch_size,\n", " # num_hiddens)\n", " enc_outputs, hidden_state, enc_valid_lens = state\n", " # 输出X的形状为(num_steps,batch_size,embed_size)\n", " X = self.embedding(X).permute(1, 0, 2)\n", " outputs, self._attention_weights = [], []\n", " for x in X:\n", " # query的形状为(batch_size,1,num_hiddens)\n", " query = torch.unsqueeze(hidden_state[-1], dim=1)\n", " # context的形状为(batch_size,1,num_hiddens)\n", " #print(f\"values shape before attention: {enc_outputs.shape}\") # 应为 (4, 7, 16)\n", " context = self.attention(\n", " query, enc_outputs, enc_outputs, enc_valid_lens)\n", " # 在特征维度上连结\n", " x = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=-1)\n", " # 将x变形为(1,batch_size,embed_size+num_hiddens)\n", " out, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)\n", " outputs.append(out)\n", " self._attention_weights.append(self.attention.attention_weights)\n", " # 全连接层变换后,outputs的形状为\n", " # (num_steps,batch_size,vocab_size)\n", " outputs = self.dense(torch.cat(outputs, dim=0))\n", " return outputs.permute(1, 0, 2), [enc_outputs, hidden_state,\n", " enc_valid_lens]\n", " @property\n", " def attention_weights(self):\n", " return self._attention_weights" ], "outputs": [], "execution_count": 26 }, { 
"cell_type": "code", "id": "b2a6bc735743bd0f", "metadata": { "ExecuteTime": { "end_time": "2026-06-30T02:59:05.842620719Z", "start_time": "2026-06-30T02:59:05.782837732Z" } }, "source": [
 "encoder = d2l.Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16,\n",
 "                             num_layers=2)\n",
 "encoder.eval()\n",
 "decoder = Seq2SeqAttentionDecoder(vocab_size=10, embed_size=8, num_hiddens=16,\n",
 "                                  num_layers=2)\n",
 "decoder.eval()\n",
 "X = torch.zeros((4, 7), dtype=torch.long)  # (batch_size, num_steps)\n",
 "state = decoder.init_state(encoder(X), None)\n",
 "output, state = decoder(X, state)\n",
 "output.shape, len(state), state[0].shape, len(state[1]), state[1][0].shape" ], "outputs": [ { "data": { "text/plain": [ "(torch.Size([4, 7, 10]), 3, torch.Size([4, 7, 16]), 2, torch.Size([4, 16]))" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 27 }, { "cell_type": "code", "id": "5f5cedf74b97bd12", "metadata": { "ExecuteTime": { "end_time": "2026-06-30T03:22:39.836556973Z", "start_time": "2026-06-30T02:59:05.904535279Z" } }, "source": [
 "# NOTE: the saved run of this cell was interrupted (KeyboardInterrupt) after ~23 minutes;\n",
 "# the NVML warning in the first cell suggests no GPU was available. On CPU, lower\n",
 "# num_epochs for a quick run.\n",
 "embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1\n",
 "batch_size, num_steps = 64, 10\n",
 "lr, num_epochs, device = 0.005, 250, d2l.try_gpu()\n",
 "train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)\n",
 "encoder = d2l.Seq2SeqEncoder(\n",
 "    len(src_vocab), embed_size, num_hiddens, num_layers, dropout)\n",
 "decoder = Seq2SeqAttentionDecoder(\n",
 "    len(tgt_vocab), embed_size, num_hiddens, num_layers, dropout)\n",
 "net = d2l.EncoderDecoder(encoder, decoder)\n",
 "d2l.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)" ], "outputs": [ { "ename": "KeyboardInterrupt", "evalue": "", "output_type": "error", "traceback": [ 
"\u001B[31m---------------------------------------------------------------------------\u001B[39m", "\u001B[31mKeyboardInterrupt\u001B[39m Traceback (most recent call last)", "\u001B[36mCell\u001B[39m\u001B[36m \u001B[39m\u001B[32mIn[28]\u001B[39m\u001B[32m, line 10\u001B[39m\n\u001B[32m 7\u001B[39m decoder = Seq2SeqAttentionDecoder(\n\u001B[32m 8\u001B[39m \u001B[38;5;28mlen\u001B[39m(tgt_vocab), embed_size, num_hiddens, num_layers, dropout)\n\u001B[32m 9\u001B[39m net = d2l.EncoderDecoder(encoder, decoder)\n\u001B[32m---> \u001B[39m\u001B[32m10\u001B[39m \u001B[43md2l\u001B[49m\u001B[43m.\u001B[49m\u001B[43mtrain_seq2seq\u001B[49m\u001B[43m(\u001B[49m\u001B[43mnet\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mtrain_iter\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mlr\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mnum_epochs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mtgt_vocab\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mdevice\u001B[49m\u001B[43m)\u001B[49m\n", "\u001B[36mFile \u001B[39m\u001B[32m~/.conda/envs/nn/lib/python3.11/site-packages/d2l/torch.py:3421\u001B[39m, in \u001B[36mtrain_seq2seq\u001B[39m\u001B[34m(net, data_iter, lr, num_epochs, tgt_vocab, device)\u001B[39m\n\u001B[32m 3418\u001B[39m bos = torch.tensor([tgt_vocab[\u001B[33m'\u001B[39m\u001B[33m\u001B[39m\u001B[33m'\u001B[39m]] * Y.shape[\u001B[32m0\u001B[39m],\n\u001B[32m 3419\u001B[39m device=device).reshape(-\u001B[32m1\u001B[39m, \u001B[32m1\u001B[39m)\n\u001B[32m 3420\u001B[39m dec_input = d2l.concat([bos, Y[:, :-\u001B[32m1\u001B[39m]], \u001B[32m1\u001B[39m) \u001B[38;5;66;03m# Teacher forcing\u001B[39;00m\n\u001B[32m-> \u001B[39m\u001B[32m3421\u001B[39m Y_hat, _ = \u001B[43mnet\u001B[49m\u001B[43m(\u001B[49m\u001B[43mX\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mdec_input\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mX_valid_len\u001B[49m\u001B[43m)\u001B[49m\n\u001B[32m 
3422\u001B[39m l = loss(Y_hat, Y, Y_valid_len)\n\u001B[32m 3423\u001B[39m l.sum().backward() \u001B[38;5;66;03m# Make the loss scalar for `backward`\u001B[39;00m\n", "\u001B[36mFile \u001B[39m\u001B[32m~/.conda/envs/nn/lib/python3.11/site-packages/torch/nn/modules/module.py:1776\u001B[39m, in \u001B[36mModule._wrapped_call_impl\u001B[39m\u001B[34m(self, *args, **kwargs)\u001B[39m\n\u001B[32m 1774\u001B[39m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mself\u001B[39m._compiled_call_impl(*args, **kwargs) \u001B[38;5;66;03m# type: ignore[misc]\u001B[39;00m\n\u001B[32m 1775\u001B[39m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[32m-> \u001B[39m\u001B[32m1776\u001B[39m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43mself\u001B[39;49m\u001B[43m.\u001B[49m\u001B[43m_call_impl\u001B[49m\u001B[43m(\u001B[49m\u001B[43m*\u001B[49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43m*\u001B[49m\u001B[43m*\u001B[49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n", "\u001B[36mFile \u001B[39m\u001B[32m~/.conda/envs/nn/lib/python3.11/site-packages/torch/nn/modules/module.py:1787\u001B[39m, in \u001B[36mModule._call_impl\u001B[39m\u001B[34m(self, *args, **kwargs)\u001B[39m\n\u001B[32m 1782\u001B[39m \u001B[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001B[39;00m\n\u001B[32m 1783\u001B[39m \u001B[38;5;66;03m# this function, and just call forward.\u001B[39;00m\n\u001B[32m 1784\u001B[39m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m (\u001B[38;5;28mself\u001B[39m._backward_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m._backward_pre_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m._forward_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m._forward_pre_hooks\n\u001B[32m 1785\u001B[39m \u001B[38;5;129;01mor\u001B[39;00m _global_backward_pre_hooks \u001B[38;5;129;01mor\u001B[39;00m _global_backward_hooks\n\u001B[32m 
1786\u001B[39m \u001B[38;5;129;01mor\u001B[39;00m _global_forward_hooks \u001B[38;5;129;01mor\u001B[39;00m _global_forward_pre_hooks):\n\u001B[32m-> \u001B[39m\u001B[32m1787\u001B[39m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mforward_call\u001B[49m\u001B[43m(\u001B[49m\u001B[43m*\u001B[49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43m*\u001B[49m\u001B[43m*\u001B[49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[32m 1789\u001B[39m result = \u001B[38;5;28;01mNone\u001B[39;00m\n\u001B[32m 1790\u001B[39m called_always_called_hooks = \u001B[38;5;28mset\u001B[39m()\n", "\u001B[36mFile \u001B[39m\u001B[32m~/.conda/envs/nn/lib/python3.11/site-packages/d2l/torch.py:964\u001B[39m, in \u001B[36mEncoderDecoder.forward\u001B[39m\u001B[34m(self, enc_X, dec_X, *args)\u001B[39m\n\u001B[32m 962\u001B[39m dec_state = \u001B[38;5;28mself\u001B[39m.decoder.init_state(enc_all_outputs, *args)\n\u001B[32m 963\u001B[39m \u001B[38;5;66;03m# Return decoder output only\u001B[39;00m\n\u001B[32m--> \u001B[39m\u001B[32m964\u001B[39m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43mself\u001B[39;49m\u001B[43m.\u001B[49m\u001B[43mdecoder\u001B[49m\u001B[43m(\u001B[49m\u001B[43mdec_X\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mdec_state\u001B[49m\u001B[43m)\u001B[49m\n", "\u001B[36mFile \u001B[39m\u001B[32m~/.conda/envs/nn/lib/python3.11/site-packages/torch/nn/modules/module.py:1776\u001B[39m, in \u001B[36mModule._wrapped_call_impl\u001B[39m\u001B[34m(self, *args, **kwargs)\u001B[39m\n\u001B[32m 1774\u001B[39m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mself\u001B[39m._compiled_call_impl(*args, **kwargs) \u001B[38;5;66;03m# type: ignore[misc]\u001B[39;00m\n\u001B[32m 1775\u001B[39m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[32m-> \u001B[39m\u001B[32m1776\u001B[39m \u001B[38;5;28;01mreturn\u001B[39;00m 
\u001B[38;5;28;43mself\u001B[39;49m\u001B[43m.\u001B[49m\u001B[43m_call_impl\u001B[49m\u001B[43m(\u001B[49m\u001B[43m*\u001B[49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43m*\u001B[49m\u001B[43m*\u001B[49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n", "\u001B[36mFile \u001B[39m\u001B[32m~/.conda/envs/nn/lib/python3.11/site-packages/torch/nn/modules/module.py:1787\u001B[39m, in \u001B[36mModule._call_impl\u001B[39m\u001B[34m(self, *args, **kwargs)\u001B[39m\n\u001B[32m 1782\u001B[39m \u001B[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001B[39;00m\n\u001B[32m 1783\u001B[39m \u001B[38;5;66;03m# this function, and just call forward.\u001B[39;00m\n\u001B[32m 1784\u001B[39m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m (\u001B[38;5;28mself\u001B[39m._backward_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m._backward_pre_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m._forward_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m._forward_pre_hooks\n\u001B[32m 1785\u001B[39m \u001B[38;5;129;01mor\u001B[39;00m _global_backward_pre_hooks \u001B[38;5;129;01mor\u001B[39;00m _global_backward_hooks\n\u001B[32m 1786\u001B[39m \u001B[38;5;129;01mor\u001B[39;00m _global_forward_hooks \u001B[38;5;129;01mor\u001B[39;00m _global_forward_pre_hooks):\n\u001B[32m-> \u001B[39m\u001B[32m1787\u001B[39m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mforward_call\u001B[49m\u001B[43m(\u001B[49m\u001B[43m*\u001B[49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43m*\u001B[49m\u001B[43m*\u001B[49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[32m 1789\u001B[39m result = \u001B[38;5;28;01mNone\u001B[39;00m\n\u001B[32m 1790\u001B[39m called_always_called_hooks = \u001B[38;5;28mset\u001B[39m()\n", "\u001B[36mCell\u001B[39m\u001B[36m \u001B[39m\u001B[32mIn[26]\u001B[39m\u001B[32m, line 37\u001B[39m, in 
\u001B[36mSeq2SeqAttentionDecoder.forward\u001B[39m\u001B[34m(self, X, state)\u001B[39m\n\u001B[32m 35\u001B[39m x = torch.cat((context, torch.unsqueeze(x, dim=\u001B[32m1\u001B[39m)), dim=-\u001B[32m1\u001B[39m)\n\u001B[32m 36\u001B[39m \u001B[38;5;66;03m# 将x变形为(1,batch_size,embed_size+num_hiddens)\u001B[39;00m\n\u001B[32m---> \u001B[39m\u001B[32m37\u001B[39m out, hidden_state = \u001B[38;5;28;43mself\u001B[39;49m\u001B[43m.\u001B[49m\u001B[43mrnn\u001B[49m\u001B[43m(\u001B[49m\u001B[43mx\u001B[49m\u001B[43m.\u001B[49m\u001B[43mpermute\u001B[49m\u001B[43m(\u001B[49m\u001B[32;43m1\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[32;43m0\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[32;43m2\u001B[39;49m\u001B[43m)\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mhidden_state\u001B[49m\u001B[43m)\u001B[49m\n\u001B[32m 38\u001B[39m outputs.append(out)\n\u001B[32m 39\u001B[39m \u001B[38;5;28mself\u001B[39m._attention_weights.append(\u001B[38;5;28mself\u001B[39m.attention.attention_weights)\n", "\u001B[36mFile \u001B[39m\u001B[32m~/.conda/envs/nn/lib/python3.11/site-packages/torch/nn/modules/module.py:1776\u001B[39m, in \u001B[36mModule._wrapped_call_impl\u001B[39m\u001B[34m(self, *args, **kwargs)\u001B[39m\n\u001B[32m 1774\u001B[39m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mself\u001B[39m._compiled_call_impl(*args, **kwargs) \u001B[38;5;66;03m# type: ignore[misc]\u001B[39;00m\n\u001B[32m 1775\u001B[39m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[32m-> \u001B[39m\u001B[32m1776\u001B[39m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43mself\u001B[39;49m\u001B[43m.\u001B[49m\u001B[43m_call_impl\u001B[49m\u001B[43m(\u001B[49m\u001B[43m*\u001B[49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43m*\u001B[49m\u001B[43m*\u001B[49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n", "\u001B[36mFile 
\u001B[39m\u001B[32m~/.conda/envs/nn/lib/python3.11/site-packages/torch/nn/modules/module.py:1787\u001B[39m, in \u001B[36mModule._call_impl\u001B[39m\u001B[34m(self, *args, **kwargs)\u001B[39m\n\u001B[32m 1782\u001B[39m \u001B[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001B[39;00m\n\u001B[32m 1783\u001B[39m \u001B[38;5;66;03m# this function, and just call forward.\u001B[39;00m\n\u001B[32m 1784\u001B[39m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m (\u001B[38;5;28mself\u001B[39m._backward_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m._backward_pre_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m._forward_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m._forward_pre_hooks\n\u001B[32m 1785\u001B[39m \u001B[38;5;129;01mor\u001B[39;00m _global_backward_pre_hooks \u001B[38;5;129;01mor\u001B[39;00m _global_backward_hooks\n\u001B[32m 1786\u001B[39m \u001B[38;5;129;01mor\u001B[39;00m _global_forward_hooks \u001B[38;5;129;01mor\u001B[39;00m _global_forward_pre_hooks):\n\u001B[32m-> \u001B[39m\u001B[32m1787\u001B[39m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mforward_call\u001B[49m\u001B[43m(\u001B[49m\u001B[43m*\u001B[49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43m*\u001B[49m\u001B[43m*\u001B[49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[32m 1789\u001B[39m result = \u001B[38;5;28;01mNone\u001B[39;00m\n\u001B[32m 1790\u001B[39m called_always_called_hooks = \u001B[38;5;28mset\u001B[39m()\n", "\u001B[36mFile \u001B[39m\u001B[32m~/.conda/envs/nn/lib/python3.11/site-packages/torch/nn/modules/rnn.py:1415\u001B[39m, in \u001B[36mGRU.forward\u001B[39m\u001B[34m(self, input, hx)\u001B[39m\n\u001B[32m 1413\u001B[39m \u001B[38;5;28mself\u001B[39m.check_forward_args(\u001B[38;5;28minput\u001B[39m, hx, batch_sizes)\n\u001B[32m 1414\u001B[39m \u001B[38;5;28;01mif\u001B[39;00m batch_sizes 
\u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m:\n\u001B[32m-> \u001B[39m\u001B[32m1415\u001B[39m result = \u001B[43m_VF\u001B[49m\u001B[43m.\u001B[49m\u001B[43mgru\u001B[49m\u001B[43m(\u001B[49m\n\u001B[32m 1416\u001B[39m \u001B[43m \u001B[49m\u001B[38;5;28;43minput\u001B[39;49m\u001B[43m,\u001B[49m\n\u001B[32m 1417\u001B[39m \u001B[43m \u001B[49m\u001B[43mhx\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 1418\u001B[39m \u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[43m.\u001B[49m\u001B[43m_flat_weights\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;66;43;03m# type: ignore[arg-type]\u001B[39;49;00m\n\u001B[32m 1419\u001B[39m \u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[43m.\u001B[49m\u001B[43mbias\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 1420\u001B[39m \u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[43m.\u001B[49m\u001B[43mnum_layers\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 1421\u001B[39m \u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[43m.\u001B[49m\u001B[43mdropout\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 1422\u001B[39m \u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[43m.\u001B[49m\u001B[43mtraining\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 1423\u001B[39m \u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[43m.\u001B[49m\u001B[43mbidirectional\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 1424\u001B[39m \u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[43m.\u001B[49m\u001B[43mbatch_first\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 1425\u001B[39m \u001B[43m \u001B[49m\u001B[43m)\u001B[49m\n\u001B[32m 1426\u001B[39m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[32m 1427\u001B[39m result = _VF.gru(\n\u001B[32m 1428\u001B[39m \u001B[38;5;28minput\u001B[39m,\n\u001B[32m 1429\u001B[39m batch_sizes,\n\u001B[32m (...)\u001B[39m\u001B[32m 1436\u001B[39m \u001B[38;5;28mself\u001B[39m.bidirectional,\n\u001B[32m 1437\u001B[39m )\n", 
"\u001B[31mKeyboardInterrupt\u001B[39m: " ] }, { "data": { "text/plain": [ "
" ], "image/svg+xml": "\n\n\n \n \n \n \n 2026-06-30T11:22:39.515794\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" }, "metadata": {}, "output_type": "display_data", "jetTransient": { "display_id": null } } ], "execution_count": 28 }, { "cell_type": "code", "id": "ddaed8bd284ee01b", "metadata": { "ExecuteTime": { "end_time": "2026-06-30T03:22:50.427691735Z", "start_time": "2026-06-30T03:22:50.314718050Z" } }, "source": [ "engs = ['go .', \"i lost .\", 'he\\'s calm .', 'i\\'m home .']\n", "fras = ['va !', 'j\\'ai perdu .', 'il est calme .', 'je suis chez moi .']\n", "for eng, fra in zip(engs, fras):\n", " translation, dec_attention_weight_seq = d2l.predict_seq2seq(\n", " net, eng, src_vocab, tgt_vocab, num_steps, device, True)\n", " print(f'{eng} => {translation}, ',\n", " f'bleu {d2l.bleu(translation, fra, k=2):.3f}')" ], "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "go . => va !, bleu 1.000\n", "i lost . => j'ai perdu ., bleu 1.000\n", "he's calm . => il est riche ., bleu 0.658\n", "i'm home . 
=> je suis calme ., bleu 0.512\n"
     ]
    }
   ],
   "execution_count": 29
  },
  {
   "cell_type": "code",
   "id": "13d19f8b5048c74d",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-06-30T03:22:57.445168273Z",
     "start_time": "2026-06-30T03:22:57.376381899Z"
    }
   },
   "source": [
    "class MultiHeadAttention(nn.Module):\n",
    "    \"\"\"Multi-head attention: project Q/K/V, attend per head in parallel, then merge the heads.\"\"\"\n",
    "    def __init__(self, key_size, query_size, value_size, num_hiddens,\n",
    "                 num_heads, dropout, bias=False, **kwargs):\n",
    "        super(MultiHeadAttention, self).__init__(**kwargs)\n",
    "        self.num_heads = num_heads\n",
    "        self.attention = d2l.DotProductAttention(dropout)\n",
    "        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)\n",
    "        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)\n",
    "        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)\n",
    "        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)\n",
    "\n",
    "    def forward(self, queries, keys, values, valid_lens):\n",
    "        # 1. Linear projections, then reshape to split into heads\n",
    "        queries = transpose_qkv(self.W_q(queries), self.num_heads)\n",
    "        keys = transpose_qkv(self.W_k(keys), self.num_heads)\n",
    "        values = transpose_qkv(self.W_v(values), self.num_heads)\n",
    "\n",
    "        # 2. Repeat valid_lens once per head, since heads are folded into the batch axis\n",
    "        if valid_lens is not None:\n",
    "            valid_lens = torch.repeat_interleave(valid_lens, repeats=self.num_heads, dim=0)\n",
    "\n",
    "        # 3. Attention, computed independently for every head\n",
    "        output = self.attention(queries, keys, values, valid_lens)\n",
    "\n",
    "        # 4. Merge the heads and apply the output projection\n",
    "        output_concat = transpose_output(output, self.num_heads)\n",
    "        return self.W_o(output_concat)\n",
    "\n",
    "\n",
    "def transpose_qkv(X, num_heads):\n",
    "    \"\"\"Reshape X so that all attention heads can be computed in parallel.\"\"\"\n",
    "    # Input X: (batch_size, num_queries or num_kv_pairs, num_hiddens)\n",
    "    # After reshape: (batch_size, num_queries or num_kv_pairs, num_heads, num_hiddens/num_heads)\n",
    "    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)\n",
    "    # After permute: (batch_size, num_heads, num_queries or num_kv_pairs, num_hiddens/num_heads)\n",
    "    X = X.permute(0, 2, 1, 3)\n",
    "    # Output: (batch_size*num_heads, num_queries or num_kv_pairs, num_hiddens/num_heads)\n",
    "    return X.reshape(-1, X.shape[2], X.shape[3])\n",
    "\n",
    "\n",
    "def transpose_output(X, num_heads):\n",
    "    \"\"\"Reverse the reshaping performed by transpose_qkv.\"\"\"\n",
    "    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])\n",
    "    X = X.permute(0, 2, 1, 3)\n",
    "    return X.reshape(X.shape[0], X.shape[1], -1)\n"
   ],
   "outputs": [],
   "execution_count": 30
  },
  {
   "cell_type": "code",
   "id": "b25c661a511d7763",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-06-30T03:22:58.271805155Z",
     "start_time": "2026-06-30T03:22:58.200822898Z"
    }
   },
   "source": [
    "num_hiddens, num_heads = 100, 5\n",
    "attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,\n",
    "                               num_hiddens, num_heads, 0.5)\n",
    "attention.eval()"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MultiHeadAttention(\n",
       " (attention): DotProductAttention(\n",
       " (dropout): Dropout(p=0.5, inplace=False)\n",
       " )\n",
       " (W_q): Linear(in_features=100, out_features=100, bias=False)\n",
       " (W_k): Linear(in_features=100, out_features=100, bias=False)\n",
       " (W_v): Linear(in_features=100, out_features=100, bias=False)\n",
       " (W_o): Linear(in_features=100, out_features=100, bias=False)\n",
       ")"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 31
  },
  {
   "cell_type": "code",
   "id": "49877c929741ec94",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-06-30T03:22:58.957157613Z",
     "start_time": "2026-06-30T03:22:58.850698625Z"
    }
   },
   "source": [
"batch_size, num_queries = 2, 4\n", "num_kvpairs, valid_lens = 6, torch.tensor([3, 2])\n", "X = torch.ones((batch_size, num_queries, num_hiddens))\n", "Y = torch.ones((batch_size, num_kvpairs, num_hiddens))\n", "attention(X, Y, Y, valid_lens).shape" ], "outputs": [ { "data": { "text/plain": [ "torch.Size([2, 4, 100])" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 32 }, { "cell_type": "code", "id": "b8a10576c7789750", "metadata": { "ExecuteTime": { "end_time": "2026-06-30T03:24:09.856237679Z", "start_time": "2026-06-30T03:24:09.793761501Z" } }, "source": [ "class PositionalEncoding(nn.Module):\n", " def __init__(self,num_hiddens,dropout,max_len=1000):\n", " super().__init__()\n", " self.dropout = nn.Dropout(dropout)\n", " self.dropout = nn.Dropout(dropout)\n", " X = torch.arange(max_len, dtype=torch.float32).reshape(\n", " -1, 1) / torch.pow(10000, torch.arange(\n", " 0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)\n", " self.P = torch.zeros((1,max_len,num_hiddens))\n", " self.P[:,:,0::2] = torch.sin(X)\n", " self.P[:,:,1::2] = torch.cos(X)\n", " def forward(self,X):\n", " X = X + self.P[:,:X.shape[1],:].to(X.device)\n", " return self.dropout(X)" ], "outputs": [], "execution_count": 35 }, { "metadata": { "ExecuteTime": { "end_time": "2026-06-30T03:24:10.924435124Z", "start_time": "2026-06-30T03:24:10.759631840Z" } }, "cell_type": "code", "source": [ "encoding_dim, num_steps = 32, 60\n", "pos_encoding = PositionalEncoding(encoding_dim, 0)\n", "pos_encoding.eval()\n", "X = pos_encoding(torch.zeros((1, num_steps, encoding_dim)))\n", "P = pos_encoding.P[:, :X.shape[1], :]\n", "d2l.plot(torch.arange(num_steps), P[0, :, 6:10].T, xlabel='Row (position)',\n", " figsize=(6, 2.5), legend=[\"Col %d\" % d for d in torch.arange(6, 10)])" ], "id": "fe6cd8d94205977f", "outputs": [ { "data": { "text/plain": [ "
" ], "image/svg+xml": "\n\n\n \n \n \n \n 2026-06-30T11:24:10.850286\n image/svg+xml\n \n \n Matplotlib v3.7.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" }, "metadata": {}, "output_type": "display_data", "jetTransient": { "display_id": null } } ], "execution_count": 36 }, { "metadata": { "ExecuteTime": { "end_time": "2026-06-30T03:29:41.868213373Z", "start_time": "2026-06-30T03:29:41.818146905Z" } }, "cell_type": "code", "source": [ "class PositionWiseFFN(nn.Module):\n", " def __init__(self,ffn_num_input,ffn_num_hiddens,ffn_num_outputs,**kwargs):\n", " super().__init__(**kwargs)\n", " self.dense1 = nn.Linear(ffn_num_input,ffn_num_hiddens)\n", " self.relu = nn.ReLU()\n", " self.dense2 = nn.Linear(ffn_num_hiddens,ffn_num_outputs)\n", " def forward(self,X):\n", " return self.dense2(self.relu(self.dense1(X)))\n" ], "id": 
"b764f19d408049b", "outputs": [], "execution_count": 38 }, { "metadata": { "ExecuteTime": { "end_time": "2026-06-30T03:35:03.713494687Z", "start_time": "2026-06-30T03:35:03.630606833Z" } }, "cell_type": "code", "source": [ "class AddNorm(nn.Module):\n", " def __init__(self,normalized_shape,dropout,**kwargs):\n", " super().__init__()\n", " self.dropout=nn.Dropout()\n", " self.ln = nn.LayerNorm(normalized_shape)\n", " def forward(self,X,Y):\n", " return self.ln(self.dropout(Y)+X)\n" ], "id": "377f4d9eea428068", "outputs": [], "execution_count": 39 }, { "metadata": {}, "cell_type": "code", "outputs": [], "execution_count": null, "source": [ "class EncoderBlock(nn.Moudle):\n", " def __init(self,key_size,query_size,value_size,num_hiddens,norm_shape,ffn_num_input,ffn_num_hiddens,num_heads,dropout,use_bias=False,**kwargs):\n", " super(EncoderBlock,self).__init__(**kwargs)\n", " self.attention = MultiHeadAttention(key_size,query_size,value_size,num_hiddens,num_heads,dropout,bias=use_bias)\n", " self.addnorm1 = AddNorm(norm_shape,dropout)\n", " self.ffn = PositionWiseFFN(\n", " ffn_num_input, ffn_num_hiddens, num_hiddens)\n", " self.addnorm2 = AddNorm(norm_shape, dropout)\n", " def forward(self, X, valid_lens):\n", " Y = self.addnorm1(X, self.attention(X, X, X, valid_lens))\n", " return self.addnorm2(Y, self.ffn(Y))" ], "id": "c54dc4e20e9ffb90" } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.6" } }, "nbformat": 4, "nbformat_minor": 5 }