实现

为了让加载器可以加载自定义数据集,
要为自定义数据集实现四个核心功能:

    1. 继承 Dataset
from torch.utils.data import Dataset
class XsDataset(Dataset):
    1. __len__

获取数据集中的要训练数据的个数

    1. __getitem__

打开图片,返回图片tensorlabelstensor

    1. transforms
from torch.utils.data import Dataset
import os
from PIL import Image
import torch

class LetterDataset(Dataset):
    def __init__(self, root: str, transform=None):
        super(LetterDataset, self).__init__()
        self.path = root
        self.transform = transform
        # 可优化
        self.mapping = [i for i in "_0123456789加减乘+-*"]

    def load_picture_path(self):
        picture_list = list(os.walk(self.path))[0][-1]
        # 这里可以增加很多的错误判断
        return picture_list

    def __len__(self):
        return len(self.load_picture_path())

    def __getitem__(self, item):
        load_picture = self.load_picture_path()
        image = Image.open(self.path + '/' +load_picture[item])
        if self.transform:
            image = self.transform(image)
        labels = [self.mapping.index(i) for i in load_picture[item].split('_')[0]]
        for i in range(9-len(labels)):
            labels.insert(0, 0)
        print(labels)
        labels = torch.as_tensor(labels, dtype=torch.int64)
        return image, labels

站点统计

  • 文章总数:316 篇
  • 分类总数:20 个
  • 标签总数:193 个
  • 运行天数:1194 天
  • 访问总数:92395 人次

浙公网安备33011302000604

辽ICP备20003309号