This commit is contained in:
yukun-hh 2026-03-05 15:31:25 +08:00
parent d45d466fd1
commit 369db12fc6

21
main.py
View file

@ -2,15 +2,28 @@
# 按 Shift+F10 执行或将其替换为您的代码。 # 按 Shift+F10 执行或将其替换为您的代码。
# 按 双击 Shift 在所有地方搜索类、文件、工具窗口、操作和设置。 # 按 双击 Shift 在所有地方搜索类、文件、工具窗口、操作和设置。
import torch
from torch import nn
# 按 Ctrl+F8 切换断点。
class MyDictDense(nn.Module):
def __init__(self):
super(MyDictDense, self).__init__()
self.params = nn.ParameterDict({
'linear1': nn.Parameter(torch.randn(4, 4)),
'linear2': nn.Parameter(torch.randn(4, 1))
})
self.params.update({'linear3': nn.Parameter(torch.randn(4, 2))}) # 新增
def forward(self, x, choice='linear1'):
return torch.mm(x, self.params[choice])
def print_hi(name):
# 在下面的代码行中使用断点来调试脚本。
print(f'Hi, {name}') # 按 Ctrl+F8 切换断点。
# 按装订区域中的绿色按钮以运行脚本。 # 按装订区域中的绿色按钮以运行脚本。
if __name__ == '__main__': if __name__ == '__main__':
print_hi('PyCharm')
net = MyDictDense()
print(net)
# 访问 https://www.jetbrains.com/help/pycharm/ 获取 PyCharm 帮助 # 访问 https://www.jetbrains.com/help/pycharm/ 获取 PyCharm 帮助