21 lines
724 B
Python
21 lines
724 B
Python
import torchvision
|
|
import torch.nn as nn
|
|
|
|
def check_vit(model=None):
    """Print where attention modules are located inside a model.

    Every submodule that is an ``nn.MultiheadAttention`` instance is
    reported by its qualified name. If there are none, the function falls
    back to listing every submodule whose qualified name contains
    ``'attn'`` or ``'attention'``, together with its class name.

    Args:
        model: The ``nn.Module`` to inspect. When ``None`` (the default), a
            torchvision ViT-B/16 is built without loading any weights.

    Returns:
        None. The findings are written to standard output.
    """
    if model is None:
        # Called with no arguments on purpose: the builder then skips
        # loading weights, and we avoid the ``pretrained`` keyword, which
        # newer torchvision releases flag as deprecated.
        model = torchvision.models.vit_b_16()

    # Walk the module tree once; both passes below reuse this list.
    named = list(model.named_modules())

    print("Searching for MultiheadAttention modules:")
    attention_names = [
        name for name, module in named
        if isinstance(module, nn.MultiheadAttention)
    ]
    for name in attention_names:
        print(f"Found MultiheadAttention at: {name}")
    if attention_names:
        return

    print("No MultiheadAttention found directly. Checking modules with 'attn' or 'attention' in name:")
    for name, module in named:
        if 'attn' in name or 'attention' in name:
            print(f"Candidate: {name} ({type(module).__name__})")
|
|
|
|
# Run the attention-module scan only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    check_vit()
|