Merge branch 'dev'
This commit is contained in:
commit
5d51c20a1d
1 changed files with 43 additions and 0 deletions
43
model.py
Normal file
43
model.py
Normal file
|
|
@ -0,0 +1,43 @@
|
||||||
|
"""
|
||||||
|
这个文件是模型的定义文件,请不要擅自修改,如有疑问微信群里反馈
|
||||||
|
author : yukun-hh
|
||||||
|
date : 2026-4-10
|
||||||
|
|
||||||
|
"""
|
||||||
|
#神经网络模型库
|
||||||
|
import torch
|
||||||
|
from modelscope.msdatasets.dataset_cls.custom_datasets.audio.kws_nearfield_processor import padding
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
#残差块
|
||||||
|
class Resblock(nn.Module):
|
||||||
|
def __init__(self, input_channels,output_channels,use_1x1conv=False,strides=1):
|
||||||
|
"""
|
||||||
|
|
||||||
|
:param input_channels: 进入残差块时的原通道
|
||||||
|
:param output_channels: 输出时的通道数
|
||||||
|
:param use_1x1conv: 如果输入和输出通道不相等时,需要用一个1x1的卷积层对原来的输入进行一个通道提升
|
||||||
|
:param strides: 默认1,如果大于1起到缩小张量的作用
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.conv1 = nn.Conv2d(input_channels,output_channels,kernel_size=3,padding=1,stride=strides)
|
||||||
|
self.conv2 = nn.Conv2d(output_channels,output_channels,kernel_size=3,padding=1,stride=strides)
|
||||||
|
if use_1x1conv:
|
||||||
|
self.conv3 = nn.Conv2d(input_channels, output_channels,kernel_size=1, stride=strides)
|
||||||
|
else:
|
||||||
|
self.conv3 = None
|
||||||
|
self.bn1 = nn.BatchNorm2d(output_channels)
|
||||||
|
self.bn2 = nn.BatchNorm2d(output_channels)
|
||||||
|
def forward(self,X):
|
||||||
|
Y = F.relu(self.bn1(self.conv1(X)))
|
||||||
|
Y = self.bn2(self.conv2(Y))
|
||||||
|
if self.conv3 is not None:
|
||||||
|
X = self.conv3(X)
|
||||||
|
Y += X
|
||||||
|
return F.relu(Y)
|
||||||
|
|
||||||
|
class Net():
|
||||||
|
def
|
||||||
|
|
||||||
|
|
||||||
Loading…
Reference in a new issue