Files
Stomach_Cancer_Pytorch/testing.py

64 lines
2.2 KiB
Python

# import paramiko
# from scp import SCPClient
# import os
# import pexpect
# # def createSSHClient(server, port, user, password):
# # client = paramiko.SSHClient()
# # client.load_system_host_keys()
# # client.set_missing_host_key_policy(paramiko.AutoAddPolicy)
# # client.connect(server, port, user, password)
# # return client
# # ssh = createSSHClient("10.1.29.28", 31931, "root", "whitekirin")
# # # os.mkdir("Original_ResNet101V2_with_NPC_Augmentation_Image")
# # # with open("Original_ResNet101V2_with_NPC_Augmentation_Image_train3.txt", "w") as file:
# # # pass
# # with SCPClient(ssh.get_transport()) as scp:
# # scp.get("/mnt/c/張晉嘉/stomach_cancer/Original_ResNet101V2_with_NPC_Augmentation_Image_train3.txt", "/raid/whitekirin/stomach_cancer/Model_result/save_the_train_result(2024-10-05)/Original_ResNet101V2_with_NPC_Augmentation_Image_train3.txt")
# def upload(port, filename, user, ip, dst_path):
# cmdline = "scp %s -r %s %s@%s:%s" % (port, filename, user, ip, dst_path)
# try:
# child = pexpect.spawn(cmdline)
# child.expect("whitekirin109316118")
# child.sendline()
# child.expect(pexpect.EOF)
# print("file upload Finish")
# except Exception as e:
# print("upload faild: ", e)
# upload(2222, "/raid/whitekirin/stomach_cancer/Model_result/save_the_train_result(2024-10-05)", "whitekirin", "203.64.84.39", "/mnt/c/張晉嘉/stomach_cancer")
import torch
from torch.utils.data import Dataset
from torch.utils.data import Subset, DataLoader
class ListDataset(Dataset):
    """Map-style dataset pairing each element of ``data_list`` with the label
    at the same position in ``labels_list``.

    Items are returned exactly as stored; no loading or transforms are
    applied here.
    """

    def __init__(self, data_list, labels_list):
        """Store the samples and their labels.

        Args:
            data_list: Indexable sequence of samples (e.g. image paths).
            labels_list: Indexable sequence of labels, aligned with
                ``data_list`` by position.

        Raises:
            ValueError: If the two sequences differ in length. Without this
                check a shorter ``labels_list`` only fails with an
                ``IndexError`` partway through an epoch, and a longer one
                silently drops its extra labels.
        """
        if len(data_list) != len(labels_list):
            raise ValueError(
                "data_list and labels_list must have the same length "
                f"(got {len(data_list)} and {len(labels_list)})"
            )
        self.data = data_list
        self.labels = labels_list

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(sample, label)`` pair stored at position ``idx``."""
        return self.data[idx], self.labels[idx]
# Sample data: image file names and their class labels.
data_list = ["image1.jpg", "image2.jpg", "image3.jpg"]
labels_list = [0, 1, 0]

# Build the dataset.
dataset = ListDataset(data_list, labels_list)

# Quick check: a single item is a (sample, label) tuple.
# print(dataset[0])  # ('image1.jpg', 0)

# Pinned memory only matters for host-to-GPU transfers; requesting it on a
# machine without CUDA just makes PyTorch warn and ignore it.
dataloader = DataLoader(
    dataset=dataset,
    batch_size=1,
    shuffle=True,
    num_workers=0,
    pin_memory=torch.cuda.is_available(),
)

# Print the batches themselves rather than the DataLoader object (whose repr
# only shows a memory address). With the default collate function each batch
# is (list of sample strings, tensor of labels).
for samples, labels in dataloader:
    print(samples, labels)