callbacks: add a PyTorch function-display callback and a PyTorch Conway's Game of Life example. Coordinate handling still needs testing.
This commit is contained in:
@@ -65,3 +65,78 @@ class function_display_callback(object): # NOSONAR
|
||||
|
||||
def __call__(self, *args, **kwargs):
    """Forward the call to the installed internal display function.

    NOTE: the instance is passed explicitly as the first argument because
    ``inner_function`` is a plain closure stored on the instance, not a
    bound method.
    """
    inner = self.inner_function
    return inner(self, *args, **kwargs)
||||
class pytorch_function_display_callback(object):  # NOSONAR
    def __call__(self, *args, **kwargs):
        """Forward the call to the internal display closure (passes self explicitly)."""
        return self.inner_function(self, *args, **kwargs)

    def __init__(self, display_function, finish_function=None):
        """Used for running arbitrary PyTorch functions on pixels.

        >>> import numpy as np
        >>> import torch
        >>> from cvpubsubs.webcam_pub import VideoHandlerThread
        >>> img = np.zeros((300, 300, 3))
        >>> def fun(array, coords, finished):
        ...     rgb = torch.empty(array.shape).uniform_(0, 1).type(torch.DoubleTensor).to(array.device) / 200.0
        ...     array[coords] = (array[coords] + rgb[coords]) % 1.0
        >>> VideoHandlerThread(video_source=img, callbacks=pytorch_function_display_callback(fun)).display()

        thanks: https://medium.com/@awildtaber/building-a-rendering-engine-in-tensorflow-262438b2e062

        :param display_function: callable(frame_tensor, (x, y, c), finished, *args, **kwargs);
            a truthy return value stops the loop and triggers the finisher.
        :param finish_function: optional callable run once display_function reports finished.
        """
        # Imported lazily so merely importing this module does not require torch.
        import torch
        from torch.autograd import Variable

        self.looping = True
        self.first_call = True

        def _run_finisher(self, frame, finished, *args, **kwargs):
            # No usable finisher: just close the display window.
            if not callable(finish_function):
                WinCtrl.quit()
            else:
                finished = finish_function(frame, Ellipsis, finished, *args, **kwargs)
                if finished:
                    WinCtrl.quit()

        def _setup(self, frame, cam_id, *args, **kwargs):
            # Device selection: explicit "device" kwarg wins, then CUDA, then CPU.
            if "device" in kwargs:
                self.device = torch.device(kwargs["device"])
            elif torch.cuda.is_available():
                self.device = torch.device("cuda")
            else:
                self.device = torch.device("cpu")

            # Build coordinate index tensors covering every element of the frame.
            # NOTE(review): the x/y/c split below assumes a 3-D (height, width,
            # channel) frame — confirm with callers.
            self.min_bounds = [0 for _ in frame.shape]
            self.max_bounds = list(frame.shape)
            grid_slices = [slice(self.min_bounds[d], self.max_bounds[d]) for d in range(len(frame.shape))]
            # BUG FIX: np.mgrid must be indexed with a tuple of slices, not a list.
            self.space_grid = np.mgrid[tuple(grid_slices)]
            x_tens = torch.LongTensor(self.space_grid[0, ...]).to(self.device)
            y_tens = torch.LongTensor(self.space_grid[1, ...]).to(self.device)
            c_tens = torch.LongTensor(self.space_grid[2, ...]).to(self.device)
            self.x = Variable(x_tens, requires_grad=False)
            self.y = Variable(y_tens, requires_grad=False)
            self.c = Variable(c_tens, requires_grad=False)

        def _display_internal(self, frame, cam_id, *args, **kwargs):
            finished = True
            if self.first_call:
                # First call only performs setup; return so the initial frame is shown.
                # BUG FIX: previously passed `finished` where `_setup` expected `cam_id`.
                _setup(self, frame, cam_id, *args, **kwargs)
                self.first_call = False
                return
            if self.looping:
                # torch.from_numpy shares memory with `frame` when staying on CPU.
                tor_frame = torch.from_numpy(frame).to(self.device)
                finished = display_function(tor_frame, (self.x, self.y, self.c), finished, *args, **kwargs)
                # Copy the (possibly GPU-side) result back into the shared numpy frame.
                frame[...] = tor_frame.cpu().numpy()[...]
            if finished:
                self.looping = False
                _run_finisher(self, frame, finished, *args, **kwargs)

        self.inner_function = _display_internal
||||
@@ -40,9 +40,9 @@ class SubscriberWindows(object):
|
||||
@staticmethod
def set_global_frame_dict(name, *args):
    """Publish *args as the global frame list for the window called `name`.

    Short names are used verbatim; overly long ndarray names are keyed by a
    hash of their string form; any other over-long name is rejected.
    """
    key_base = str(name)
    if len(key_base) <= 1000:
        SubscriberWindows.frame_dict[key_base + "frame"] = list(args)
    elif isinstance(name, np.ndarray):
        # The array's repr is too long to be a key — hash it instead.
        SubscriberWindows.frame_dict[str(hash(key_base)) + "frame"] = list(args)
    else:
        raise ValueError("Input window name too long.")
||||
|
||||
|
||||
@@ -128,3 +128,39 @@ class TestSubWin(ut.TestCase):
|
||||
array[coords] = 1.0
|
||||
|
||||
VideoHandlerThread(video_source=img, callbacks=function_display_callback(conway)).display()
|
||||
|
||||
def test_conway_life_pytorch(self):
    """Run Conway's Game of Life on a frame via pytorch_function_display_callback."""
    import torch
    from cvpubsubs.webcam_pub import VideoHandlerThread
    from cvpubsubs.webcam_pub.callbacks import pytorch_function_display_callback

    # Live (1.0) border around a dead (0.0) interior.
    img = np.ones((600, 800, 1))
    img[10:590, 10:790, :] = 0

    def fun(frame, coords, finished):
        array = frame
        # 3x3 neighborhood kernel with the center cell zeroed out, matching
        # the frame's dtype and device.
        neighbor_weights = torch.ones(torch.Size([3, 3]))
        neighbor_weights[1, 1] = 0
        neighbor_weights = neighbor_weights.type_as(array).to(array.device)
        # conv2d expects weights shaped (out_channels, in_channels, kH, kW).
        neighbor_weights = neighbor_weights[None, None, :, :]
        # Reorder (H, W, C) -> (N=1, C, W, H) so conv2d can count neighbors.
        array = array.permute(2, 1, 0)[None, ...]
        neighbors = torch.nn.functional.conv2d(array, neighbor_weights, stride=1, padding=1)
        # Survival rule: live cells with 2 or 3 neighbors stay alive, else die.
        live_array = torch.where(
            (neighbors < 2) | (neighbors > 3),
            torch.zeros_like(array),
            torch.where((2 <= neighbors) & (neighbors <= 3), torch.ones_like(array), array),
        )
        # Birth rule: dead cells with exactly 3 neighbors come alive.
        dead_array = torch.where(neighbors == 3, torch.ones_like(array), array)
        array = torch.where(array == 1.0, live_array, dead_array)
        # Back to (H, W, C) and write the new generation into the shared frame.
        array = array.squeeze().permute(1, 0)[..., None]
        frame[...] = array[...]

    VideoHandlerThread(video_source=img, callbacks=pytorch_function_display_callback(fun)).display()
||||
|
||||
Reference in New Issue
Block a user