Merge branch 'dev'

This commit is contained in:
yukun-hh 2026-04-10 16:12:16 +08:00
commit 5d51c20a1d

43
model.py Normal file
View file

@ -0,0 +1,43 @@
"""
这个文件是模型的定义文件请不要擅自修改如有疑问微信群里反馈
author : yukun-hh
date : 2026-4-10
"""
#神经网络模型库
import torch
from modelscope.msdatasets.dataset_cls.custom_datasets.audio.kws_nearfield_processor import padding
from torch import nn
from torch.nn import functional as F
#残差块
class Resblock(nn.Module):
def __init__(self, input_channels,output_channels,use_1x1conv=False,strides=1):
"""
:param input_channels: 进入残差块时的原通道
:param output_channels: 输出时的通道数
:param use_1x1conv: 如果输入和输出通道不相等时需要用一个1x1的卷积层对原来的输入进行一个通道提升
:param strides: 默认1如果大于1起到缩小张量的作用
"""
super().__init__()
self.conv1 = nn.Conv2d(input_channels,output_channels,kernel_size=3,padding=1,stride=strides)
self.conv2 = nn.Conv2d(output_channels,output_channels,kernel_size=3,padding=1,stride=strides)
if use_1x1conv:
self.conv3 = nn.Conv2d(input_channels, output_channels,kernel_size=1, stride=strides)
else:
self.conv3 = None
self.bn1 = nn.BatchNorm2d(output_channels)
self.bn2 = nn.BatchNorm2d(output_channels)
def forward(self,X):
Y = F.relu(self.bn1(self.conv1(X)))
Y = self.bn2(self.conv2(Y))
if self.conv3 is not None:
X = self.conv3(X)
Y += X
return F.relu(Y)
class Net():
def