From 4586e9743f6bcd007a5dc157b63d8a2152dbec09 Mon Sep 17 00:00:00 2001 From: SimLeek Date: Sun, 24 Feb 2019 21:04:35 -0700 Subject: [PATCH] callbacks: Added pytorch function display and pytorch conway life example. Still need to test coords. --- cvpubsubs/webcam_pub/callbacks.py | 75 +++++++++++++++++++++++++++ cvpubsubs/window_sub/cv_window_sub.py | 4 +- tests/test_sub_win.py | 36 +++++++++++++ 3 files changed, 113 insertions(+), 2 deletions(-) diff --git a/cvpubsubs/webcam_pub/callbacks.py b/cvpubsubs/webcam_pub/callbacks.py index 767aa0f..ece4a73 100644 --- a/cvpubsubs/webcam_pub/callbacks.py +++ b/cvpubsubs/webcam_pub/callbacks.py @@ -65,3 +65,78 @@ class function_display_callback(object): # NOSONAR def __call__(self, *args, **kwargs): return self.inner_function(self, *args, **kwargs) + + +class pytorch_function_display_callback(object): # NOSONAR + def __init__(self, display_function, finish_function=None): + """Used for running arbitrary functions on pixels. + + >>> import random + >>> import torch + >>> from cvpubsubs.webcam_pub import VideoHandlerThread + >>> img = np.zeros((300, 300, 3)) + >>> def fun(array, coords, finished): + ... rgb = torch.empty(array.shape).uniform_(0,1).type(torch.DoubleTensor).to(array.device)/200.0 + ... array[coords] = (array[coords] + rgb[coords])%1.0 + >>> VideoHandlerThread(video_source=img, callbacks=pytorch_function_display_callback(fun)).display() + + thanks: https://medium.com/@awildtaber/building-a-rendering-engine-in-tensorflow-262438b2e062 + + :param display_function: + :param finish_function: + """ + + import torch + from torch.autograd import Variable + + self.looping = True + self.first_call = True + + def _run_finisher(self, frame, finished, *args, **kwargs): + if not callable(finish_function): + WinCtrl.quit() + else: + finished = finish_function(frame, Ellipsis, finished, *args, **kwargs) + if finished: + WinCtrl.quit() + + def _setup(self, frame, cam_id, *args, **kwargs): + + if "device" in kwargs: + self.device = torch.device(kwargs["device"]) + else: + if torch.cuda.is_available(): + self.device = torch.device("cuda") + else: + self.device = torch.device("cpu") + + self.min_bounds = [0 for _ in frame.shape] + self.max_bounds = list(frame.shape) + grid_slices = [slice(self.min_bounds[d], self.max_bounds[d]) for d in range(len(frame.shape))] + self.space_grid = np.mgrid[grid_slices] + x_tens = torch.LongTensor(self.space_grid[0, ...]).to(self.device) + y_tens = torch.LongTensor(self.space_grid[1, ...]).to(self.device) + c_tens = torch.LongTensor(self.space_grid[2, ...]).to(self.device) + self.x = Variable(x_tens, requires_grad=False) + self.y = Variable(y_tens, requires_grad=False) + self.c = Variable(c_tens, requires_grad=False) + + def _display_internal(self, frame, cam_id, *args, **kwargs): + finished = True + if self.first_call: + # return to display initial frame + _setup(self, frame, finished, *args, **kwargs) + self.first_call = False + return + if self.looping: + tor_frame = torch.from_numpy(frame).to(self.device) + finished = display_function(tor_frame, (self.x, self.y, self.c), finished, *args, **kwargs) + frame[...] = tor_frame.cpu().numpy()[...] + if finished: + self.looping = False + _run_finisher(self, frame, finished, *args, **kwargs) + + self.inner_function = _display_internal + + def __call__(self, *args, **kwargs): + return self.inner_function(self, *args, **kwargs) diff --git a/cvpubsubs/window_sub/cv_window_sub.py b/cvpubsubs/window_sub/cv_window_sub.py index 08b1562..6e0c4cf 100644 --- a/cvpubsubs/window_sub/cv_window_sub.py +++ b/cvpubsubs/window_sub/cv_window_sub.py @@ -40,9 +40,9 @@ class SubscriberWindows(object): @staticmethod def set_global_frame_dict(name, *args): if len(str(name)) <= 1000: - SubscriberWindows.frame_dict[str(name) + "frame"] = [*args] + SubscriberWindows.frame_dict[str(name) + "frame"] = list(args) elif isinstance(name, np.ndarray): - SubscriberWindows.frame_dict[str(hash(str(name))) + "frame"] = [*args] + SubscriberWindows.frame_dict[str(hash(str(name))) + "frame"] = list(args) else: raise ValueError("Input window name too long.") diff --git a/tests/test_sub_win.py b/tests/test_sub_win.py index 9fac5e6..addfedc 100644 --- a/tests/test_sub_win.py +++ b/tests/test_sub_win.py @@ -128,3 +128,39 @@ class TestSubWin(ut.TestCase): array[coords] = 1.0 VideoHandlerThread(video_source=img, callbacks=function_display_callback(conway)).display() + + def test_conway_life_pytorch(self): + import torch + from torch import functional as F + from cvpubsubs.webcam_pub import VideoHandlerThread + from cvpubsubs.webcam_pub.callbacks import pytorch_function_display_callback + + img = np.ones((600, 800, 1)) + img[10:590, 10:790, :] = 0 + + def fun(frame, coords, finished): + array = frame + neighbor_weights = torch.ones(torch.Size([3, 3])) + neighbor_weights[1, 1, ...] = 0 + neighbor_weights = torch.Tensor(neighbor_weights).type_as(array).to(array.device) + neighbor_weights = neighbor_weights.squeeze()[None, None, :, :] + array = array.permute(2, 1, 0)[None, ...] + neighbors = torch.nn.functional.conv2d(array, neighbor_weights, stride=1, padding=1) + live_array = torch.where((neighbors < 2) | (neighbors > 3), + torch.zeros_like(array), + torch.where((2 <= neighbors) & (neighbors <= 3), + torch.ones_like(array), + array + ) + ) + dead_array = torch.where(neighbors == 3, + torch.ones_like(array), + array) + array = torch.where(array == 1.0, + live_array, + dead_array + ) + array = array.squeeze().permute(1, 0)[...,None] + frame[...] = array[...] + + VideoHandlerThread(video_source=img, callbacks=pytorch_function_display_callback(fun)).display()