diff --git a/simclr/modules/resnet.py b/simclr/modules/resnet.py index 7c9590a..2e09b8e 100644 --- a/simclr/modules/resnet.py +++ b/simclr/modules/resnet.py @@ -3,8 +3,10 @@ def get_resnet(name, pretrained=False): resnets = { - "resnet18": torchvision.models.resnet18(pretrained=pretrained), - "resnet50": torchvision.models.resnet50(pretrained=pretrained), + "resnet18": torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT + if pretrained else None), + "resnet50": torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT + if pretrained else None), } if name not in resnets.keys(): raise KeyError(f"{name} is not a valid ResNet version")