实现
为了让加载器可以加载自定义数据集,
要为自定义数据集实现四个核心功能:
- 继承 `Dataset` 类:
from torch.utils.data import Dataset
class XsDataset(Dataset):
- `__len__`:返回数据集中用于训练的样本个数
- `__getitem__`:打开图片,返回图片的 tensor 和 labels 的 tensor
- `transforms`:可选的图片预处理变换,在 `__getitem__` 中作用于打开的图片
from torch.utils.data import Dataset
import os
from PIL import Image
import torch
class LetterDataset(Dataset):
    """Image dataset whose file names encode the character labels.

    Every regular file directly inside ``root`` is one sample.  The label
    text is the part of the file name before the first ``'_'``; each of its
    characters is translated to its position in ``mapping`` and the resulting
    sequence is left-padded with ``0`` (the index of ``'_'``) up to
    ``label_length`` entries.

    The directory is listed once, on first use, and the sorted result is
    cached, so indices stay stable and ``__len__`` / ``__getitem__`` do not
    rescan the file system on every call.
    """

    def __init__(self, root: str, transform=None, label_length: int = 9):
        """Store the dataset configuration; no file-system access happens here.

        Args:
            root: Directory that holds the image files.
            transform: Optional callable applied to each opened PIL image.
            label_length: Length that shorter label sequences are left-padded
                to.  Longer sequences are returned unpadded and untruncated.
        """
        super().__init__()
        self.path = root
        self.transform = transform
        self.label_length = label_length
        # The position of a character in this list is its class index.
        self.mapping = list("_0123456789加减乘+-*")
        # Lazily filled cache of the sorted file names (see _pictures).
        self._picture_list = None

    def load_picture_path(self):
        """Return the sorted names of the files directly inside ``self.path``.

        Raises:
            FileNotFoundError: If ``self.path`` cannot be listed as a
                directory.
        """
        # os.walk swallows scandir errors and then yields nothing, so an
        # unusable root shows up as an empty generator.  Only the first
        # (top-level) entry is needed; the rest of the tree is never walked.
        top_level = next(os.walk(self.path), None)
        if top_level is None:
            raise FileNotFoundError(
                f"{self.path!r} is not a readable directory"
            )
        # Sort because the order reported by the OS is arbitrary.
        return sorted(top_level[-1])

    def _pictures(self):
        """Return the cached file list, scanning the directory on first use."""
        cached = getattr(self, "_picture_list", None)
        if cached is None:
            cached = self._picture_list = self.load_picture_path()
        return cached

    def _encode_label(self, file_name):
        """Translate the label part of ``file_name`` into padded class indices.

        Raises:
            ValueError: If the label contains a character not in ``mapping``.
        """
        text = file_name.split('_')[0]
        try:
            indices = [self.mapping.index(ch) for ch in text]
        except ValueError:
            raise ValueError(
                f"label {text!r} of file {file_name!r} contains a character "
                f"that is not in the mapping"
            ) from None
        # A negative count yields an empty list, so over-long labels pass
        # through unchanged.
        return [0] * (self.label_length - len(indices)) + indices

    def __len__(self):
        """Return the number of samples (files) in the dataset."""
        return len(self._pictures())

    def __getitem__(self, item):
        """Return ``(image, labels)`` for the sample at index ``item``.

        ``image`` is the opened PIL image, or the result of ``transform`` when
        one was given.  ``labels`` is an int64 tensor of class indices.
        """
        file_name = self._pictures()[item]
        image = Image.open(os.path.join(self.path, file_name))
        if self.transform is not None:
            image = self.transform(image)
        labels = torch.as_tensor(self._encode_label(file_name), dtype=torch.int64)
        return image, labels