Source code for src.common.model.backbone

import torch
import torchvision
import torch.nn.modules


[docs]class resnet(torch.nn.Module): def __init__(self,layers,pretrained = False): super(resnet,self).__init__() if layers == '18': model = torchvision.models.resnet18(pretrained=pretrained) elif layers == '34': model = torchvision.models.resnet34(pretrained=pretrained) elif layers == '50': model = torchvision.models.resnet50(pretrained=pretrained) elif layers == '101': model = torchvision.models.resnet101(pretrained=pretrained) elif layers == '152': model = torchvision.models.resnet152(pretrained=pretrained) elif layers == '50next': model = torchvision.models.resnext50_32x4d(pretrained=pretrained) elif layers == '101next': model = torchvision.models.resnext101_32x8d(pretrained=pretrained) elif layers == '50wide': model = torchvision.models.wide_resnet50_2(pretrained=pretrained) elif layers == '101wide': model = torchvision.models.wide_resnet101_2(pretrained=pretrained) else: raise NotImplementedError self.conv1 = model.conv1 self.bn1 = model.bn1 self.relu = model.relu self.maxpool = model.maxpool self.layer1 = model.layer1 self.layer2 = model.layer2 self.layer3 = model.layer3 self.layer4 = model.layer4
[docs] def forward(self,x): x = self.conv1(x) x = self.bn1(x) x = self.relu(x) x = self.maxpool(x) x = self.layer1(x) x2 = self.layer2(x) x3 = self.layer3(x2) x4 = self.layer4(x3) return x2,x3,x4