File size: 3,893 Bytes
3a81605 513e1fb 3a81605 513e1fb 3a81605 513e1fb 3a81605 513e1fb 3a81605 513e1fb 3a81605 513e1fb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 |
import cv2
import torch as th
import os
import numpy as np
from decord import VideoReader, cpu
class Normalize(object):
    """Channel-wise normalization: ``(x - mean) / (std + 1e-8)``.

    ``mean`` and ``std`` are per-channel sequences of equal length. They are
    stored as float32 tensors of shape ``(1, C, 1, 1)`` so that they broadcast
    against an input whose channel axis is dimension 1.
    """

    def __init__(self, mean, std):
        # reshape(1, -1, 1, 1) instead of a hard-coded 3 so that inputs with
        # any number of channels work; 3-channel callers are unaffected.
        self.mean = th.as_tensor(mean, dtype=th.float32).reshape(1, -1, 1, 1)
        self.std = th.as_tensor(std, dtype=th.float32).reshape(1, -1, 1, 1)

    def __call__(self, tensor):
        """Return ``tensor`` normalized per channel.

        The small epsilon added to ``std`` guards against division by zero.
        """
        return (tensor - self.mean) / (self.std + 1e-8)
class Preprocessing(object):
    """Scale pixel values by 1/255 and then apply channel-wise normalization."""

    def __init__(self):
        # Fixed per-channel statistics handed to Normalize.
        # NOTE(review): these values look like the statistics published with
        # OpenAI CLIP -- confirm against the model this loader feeds.
        channel_mean = [0.48145466, 0.4578275, 0.40821073]
        channel_std = [0.26862954, 0.26130258, 0.27577711]
        self.norm = Normalize(mean=channel_mean, std=channel_std)

    def __call__(self, tensor):
        """Return ``tensor / 255.0`` passed through ``self.norm``."""
        scaled = tensor / 255.0
        return self.norm(scaled)
class VideoLoader:
    """Pytorch video loader.

    Decodes a video file with decord, subsamples frames in time, resizes
    (and optionally center-crops) them with OpenCV, then returns a fixed
    number of preprocessed frames.

    Args:
        framerate: target number of sampled frames per second of video.
        size: an ``int`` (the shorter side is resized to ``size``) or a
            ``(height, width)`` tuple giving the exact output size.
        centercrop: when ``size`` is an int, crop the resized frames to a
            centered ``size x size`` square.
    """

    def __init__(
        self,
        framerate=1,
        size=224,
        centercrop=True,
    ):
        self.centercrop = centercrop
        self.size = size
        self.framerate = framerate
        self.preprocess = Preprocessing()
        # Number of frames every call returns (longer clips are subsampled,
        # shorter ones are zero-padded).
        self.max_feats = 10
        # Kept for backward compatibility; not used by this class any more.
        self.features_dim = 768

    def _get_video_dim(self, video_path):
        """Return ``(height, width, average_fps)`` read from the video file."""
        vr = VideoReader(video_path, ctx=cpu(0))
        height, width, _ = vr[0].shape
        frame_rate = vr.get_avg_fps()
        return height, width, frame_rate

    def _get_output_dim(self, h, w):
        """Return the ``(height, width)`` frames are resized to.

        A 2-tuple ``self.size`` is returned as-is. Otherwise the shorter side
        becomes ``self.size`` and the longer side is scaled to keep the
        aspect ratio (truncated to an int).
        """
        if isinstance(self.size, tuple) and len(self.size) == 2:
            return self.size
        elif h >= w:
            return int(h * self.size / w), self.size
        else:
            return self.size, int(w * self.size / h)

    def _frame_step(self, fps):
        """Return the frame-index stride that yields ``self.framerate`` fps."""
        if self.framerate and self.framerate > 0:
            # With the default framerate=1 this equals int(fps).
            return max(1, int(fps / self.framerate))
        return max(1, int(fps))

    def _placeholder_hw(self):
        """Return the ``(height, width)`` used for frames of a failed decode."""
        if isinstance(self.size, tuple) and len(self.size) == 2:
            return self.size
        return self.size, self.size

    def _getvideo(self, video_path):
        """Decode ``video_path`` into a float32 tensor of frames.

        Returns:
            A dict with keys ``"video"`` and ``"input"``. On success
            ``"video"`` has layout ``(t, c, h, w)``; on any failure (missing
            file, unreadable video, bad frame rate, decode error) it is the
            1-element placeholder ``th.zeros(1)``.
        """
        failed = {"video": th.zeros(1), "input": video_path}
        if not os.path.isfile(video_path):
            return failed

        print("Decoding video: {}".format(video_path))
        try:
            h, w, fr = self._get_video_dim(video_path)
        except Exception:
            print("ffprobe failed at: {}".format(video_path))
            return failed
        # "not >=" rather than "<" so that a NaN frame rate is rejected too.
        if not fr >= 1:
            print("Corrupted Frame Rate: {}".format(video_path))
            return failed

        height, width = self._get_output_dim(h, w)
        step = self._frame_step(fr)
        try:
            # Decoding and resizing are the steps that can actually fail on
            # a damaged file, so they are what the try block protects.
            vr = VideoReader(video_path, ctx=cpu(0))
            video = vr.get_batch(list(range(0, len(vr), step))).asnumpy()
            # cv2.resize takes the target size as (width, height).
            video = np.array(
                [cv2.resize(frame, (width, height)) for frame in video]
            )
        except Exception:
            print("ffmpeg error at: {}".format(video_path))
            return failed

        # Center crop is only defined for an integer size; with a tuple size
        # the frames already have the exact requested dimensions.
        if self.centercrop and isinstance(self.size, int):
            x = (width - self.size) // 2
            y = (height - self.size) // 2
            video = video[:, y:y + self.size, x:x + self.size, :]

        video = th.from_numpy(video.astype("float32"))
        video = video.permute(0, 3, 1, 2)  # t,c,h,w
        return {"video": video, "input": video_path}

    def _fit_length(self, video):
        """Subsample or zero-pad ``video`` to exactly ``self.max_feats`` frames.

        Returns:
            ``(video, video_len)`` where ``video_len`` is the number of real
            (non-padding) frames, capped at ``self.max_feats``.
        """
        total = len(video)
        video_len = min(total, self.max_feats)
        if total > self.max_feats:
            # Evenly spaced frame indices across the whole clip.
            picks = [(j * total) // self.max_feats for j in range(self.max_feats)]
            video = video[picks]
        elif total < self.max_feats:
            # Padding must share the frame shape of the video itself.
            pad = video.new_zeros((self.max_feats - total,) + tuple(video.shape[1:]))
            video = th.cat([video, pad], 0)
        return video, video_len

    def __call__(self, video_path):
        """Load ``video_path`` and return ``(frames, video_len)``.

        ``frames`` always holds ``self.max_feats`` preprocessed frames.
        ``video_len`` is the number of those that come from the video; it is
        0 when the video could not be decoded, in which case all frames are
        padding.
        """
        video = self._getvideo(video_path)["video"]
        if video.dim() != 4:
            # _getvideo signals failure with a 1-D placeholder; replace it
            # with an empty clip so the padding below yields a valid batch.
            ph, pw = self._placeholder_hw()
            video = th.zeros(0, 3, ph, pw)
        video, video_len = self._fit_length(video)
        video = self.preprocess(video)
        return video, video_len
|