trash-division/model.py
yukun-hh e101321ff1 Create model.py
模型初步搭建
2026-04-10 16:11:49 +08:00

43 lines
1.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
这个文件是模型的定义文件,请不要擅自修改,如有疑问微信群里反馈
author : yukun-hh
date : 2026-4-10
"""
#神经网络模型库
import torch
from modelscope.msdatasets.dataset_cls.custom_datasets.audio.kws_nearfield_processor import padding
from torch import nn
from torch.nn import functional as F
#残差块
class Resblock(nn.Module):
def __init__(self, input_channels,output_channels,use_1x1conv=False,strides=1):
"""
:param input_channels: 进入残差块时的原通道
:param output_channels: 输出时的通道数
:param use_1x1conv: 如果输入和输出通道不相等时需要用一个1x1的卷积层对原来的输入进行一个通道提升
:param strides: 默认1如果大于1起到缩小张量的作用
"""
super().__init__()
self.conv1 = nn.Conv2d(input_channels,output_channels,kernel_size=3,padding=1,stride=strides)
self.conv2 = nn.Conv2d(output_channels,output_channels,kernel_size=3,padding=1,stride=strides)
if use_1x1conv:
self.conv3 = nn.Conv2d(input_channels, output_channels,kernel_size=1, stride=strides)
else:
self.conv3 = None
self.bn1 = nn.BatchNorm2d(output_channels)
self.bn2 = nn.BatchNorm2d(output_channels)
def forward(self,X):
Y = F.relu(self.bn1(self.conv1(X)))
Y = self.bn2(self.conv2(Y))
if self.conv3 is not None:
X = self.conv3(X)
Y += X
return F.relu(Y)
class Net():
def