Code from a nice guy (a former colleague).
The implementation mainly relies on nvidia.dali.
The code is as follows:
from __future__ import division
import numpy as np
from random import shuffle
from nvidia.dali.pipeline import Pipeline
import nvidia.dali.ops as ops
import nvidia.dali.types as types
import nvidia.dali.plugin.pytorch as dalitorch
from nvidia.dali.plugin.pytorch import DALIGenericIterator as PyTorchIterator
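# Split an HWC image into its four quadrants; run inside the pipeline
# below as a DALI PythonFunction with num_outputs=4.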
def grid2x2(img):
h, w, c = img.shape
left_top = img[0:h//2, 0:w//2, :]
left_bottom = img[h//2:h, 0:w//2, :]
right_top = img[0:h//2, w//2:w, :]
right_bottom = img[h//2:h, w//2:w, :]
    return left_top, right_top, left_bottom, right_bottom
class ExternalInputIterator(object):
def __init__(self, images_dir, txt_path, batch_size, device_id, num_gpus):
self.images_dir = images_dir
self.batch_size = batch_size
with open(txt_path, 'r') as f:
            self.files = [line.rstrip() for line in f if line.strip()]
# whole data set size
self.data_set_len = len(self.files)
        # keep only this GPU's shard, based on device_id and the world size
        self.files = self.files[self.data_set_len * device_id // num_gpus:
                                self.data_set_len * (device_id + 1) // num_gpus]
        self.n = len(self.files)
def __iter__(self):
self.i = 0
shuffle(self.files)
return self
def __next__(self):
batch = []
labels = []
if self.i >= self.n:
raise StopIteration
        for _ in range(self.batch_size):
            jpeg_filename, label = self.files[self.i].split(',')
            # jpeg_filename, label = self.files[self.i], 1
            # f = open(jpeg_filename, 'rb')
            with open(self.images_dir + jpeg_filename, 'rb') as f:
                batch.append(np.frombuffer(f.read(), dtype=np.uint8))
            labels.append(np.array([int(label)], dtype=np.uint8))
            self.i = (self.i + 1) % self.n
        return (batch, labels)
@property
    def size(self):
return self.data_set_len
next = __next__
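# DALI pipeline fed by the external iterator: decodes JPEGs on the CPU,
# splits each frame into 2x2 quadrants, and resizes all five views on the GPU.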
class ExternalSourcePipeline(Pipeline):
def __init__(self, resize, batch_size, num_threads, device_id, external_data):
super(ExternalSourcePipeline, self).__init__(batch_size,
num_threads,
device_id,
seed=12,
exec_async=False,
exec_pipelined=False,
)
self.input = ops.ExternalSource()
self.input_label = ops.ExternalSource()
self.decode = ops.ImageDecoder(device = "cpu", output_type = types.RGB)
        # PythonFunction requires exec_async=False and exec_pipelined=False,
        # and its inputs must reside on the CPU.
self.grid = ops.PythonFunction(function=grid2x2, num_outputs=4)
        # self.grid = dalitorch.TorchPythonFunction(function=grid2x2, num_outputs=4)
self.resize = ops.Resize(device="gpu",
resize_x=resize,
resize_y=resize,
interp_type=types.INTERP_LINEAR)
# self.cast = ops.Cast(device = "gpu",
# dtype = types.UINT8)
self.external_data = external_data
self.iterator = iter(self.external_data)
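    # Graph definition: decode on the CPU (PythonFunction only accepts CPU
    # inputs), split into quadrants, then resize everything on the GPU.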
def define_graph(self):
self.jpegs = self.input()
self.labels = self.input_label()
images = self.decode(self.jpegs)
images1, images2, images3, images4 = self.grid(images)
images = self.resize(images.gpu())
images1 = self.resize(images1.gpu())
images2 = self.resize(images2.gpu())
images3 = self.resize(images3.gpu())
images4 = self.resize(images4.gpu())
return (images, images1, images2, images3, images4, self.labels)
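    # Pull the next host-side batch and feed it to the ExternalSource nodes;
    # when the source is exhausted, reset it for the next epoch and re-raise
    # StopIteration to signal the framework iterator.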
def iter_setup(self):
try:
            images, labels = next(self.iterator)
self.feed_input(self.jpegs, images)
self.feed_input(self.labels, labels)
except StopIteration:
self.iterator = iter(self.external_data)
raise StopIteration
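# Convenience wrapper: external iterator -> DALI pipeline -> PyTorch-style iterator.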
def create_dataloader(img_dir,
txt_path,
resize,
batch_size,
device_id=0,
num_gpus=1,
num_threads=6):
eii = ExternalInputIterator(img_dir,
txt_path,
batch_size=batch_size,
device_id=device_id,
num_gpus=num_gpus)
    pipe = ExternalSourcePipeline(resize=resize,
                                  batch_size=batch_size,
                                  num_threads=num_threads,
                                  device_id=device_id,
                                  external_data=eii)
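    # output_map names the six pipeline outputs in define_graph order:
    # the full image, the four quadrants, then the label.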
pii = PyTorchIterator(pipe,
output_map=["data0", "data1", "data2", "data3", "data4", "label"],
size=eii.size,
last_batch_padded=True,
fill_last_batch=False)
return pii
if __name__ == '__main__':
batch_size = 32
num_gpus = 1
num_threads = 8
epochs = 1
    pii = create_dataloader('/home/hanbing/hanbing_data/datasets/deepfake/train_videos/',
                            resize=224,
                            batch_size=batch_size,
                            txt_path='./txt/train_5.txt')
for e in range(epochs):
        print('iterations per epoch:', len(pii))
for i, data in enumerate(pii):
imgs = data[0]["data4"]
labels = data[0]["label"]
print("epoch: {}, iter {}".format(e, i), imgs.shape, labels.shape)
pii.reset()
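For reference, ExternalInputIterator expects each line of txt_path to hold a relative image path and an integer label separated by a comma. Hypothetical entries (the file names here are made up) would look like:

fake/frame_0001.jpg,1
real/frame_0042.jpg,0

Note that images_dir is concatenated directly with the path read from the file, so it should end with a trailing slash, as in the __main__ example above.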