From 941751b00784b65aa7d5bd7eae783fe596fa93ca Mon Sep 17 00:00:00 2001 From: simleek Date: Sun, 29 Sep 2019 22:12:00 -0700 Subject: [PATCH] Added more api to display function --- cvpubsubs/__init__.py | 2 +- cvpubsubs/serialize.py | 12 +++ cvpubsubs/webcam_pub/frame_handler.py | 29 +++--- cvpubsubs/window_sub/cv_window_sub.py | 144 +++++++++++++++++++------- tests/test_simple_api.py | 61 +++++++++++ 5 files changed, 196 insertions(+), 52 deletions(-) create mode 100644 cvpubsubs/serialize.py create mode 100644 tests/test_simple_api.py diff --git a/cvpubsubs/__init__.py b/cvpubsubs/__init__.py index 2abf7d7..b2741c7 100644 --- a/cvpubsubs/__init__.py +++ b/cvpubsubs/__init__.py @@ -1,3 +1,3 @@ -__version__ = '0.6.4' +__version__ = '0.6.5' from .window_sub.cv_window_sub import display diff --git a/cvpubsubs/serialize.py b/cvpubsubs/serialize.py new file mode 100644 index 0000000..0db2aab --- /dev/null +++ b/cvpubsubs/serialize.py @@ -0,0 +1,12 @@ +import numpy as np +from collections import Hashable + + +def uid_for_source(video_source): + if len(str(video_source)) <= 1000: + uid = str(video_source) + elif isinstance(video_source, Hashable): + uid = str(hash(video_source)) + else: + uid = str(hash(str(video_source))) + return uid diff --git a/cvpubsubs/webcam_pub/frame_handler.py b/cvpubsubs/webcam_pub/frame_handler.py index c218a82..84f02c7 100644 --- a/cvpubsubs/webcam_pub/frame_handler.py +++ b/cvpubsubs/webcam_pub/frame_handler.py @@ -5,7 +5,7 @@ import numpy as np from cvpubsubs.webcam_pub.pub_cam import pub_cam_thread from cvpubsubs.webcam_pub.camctrl import CamCtrl from cvpubsubs.window_sub.winctrl import WinCtrl - +from cvpubsubs.serialize import uid_for_source if False: from typing import Union, Tuple, Any, Callable, List, Optional @@ -21,7 +21,7 @@ class VideoHandlerThread(threading.Thread): "Thread for publishing frames from a video source." def __init__(self, video_source=0, # type: Union[int, str, np.ndarray] - callbacks=(global_cv_display_callback,), # type: Union[List[FrameCallable], FrameCallable] + callbacks=(), # type: Union[List[FrameCallable], FrameCallable] request_size=(-1, -1), # type: Tuple[int, int] high_speed=True, # type: bool fps_limit=240 # type: float @@ -40,13 +40,7 @@ class VideoHandlerThread(threading.Thread): :type fps_limit: float """ super(VideoHandlerThread, self).__init__(target=self.loop, args=()) - if isinstance(video_source, (int, str)): - self.cam_id = str(video_source) - elif isinstance(video_source, np.ndarray): - self.cam_id = str(hash(str(video_source))) - else: - raise TypeError( - "Only strings or ints representing cameras, or numpy arrays representing pictures supported.") + self.cam_id = uid_for_source(video_source) self.video_source = video_source if callable(callbacks): self.callbacks = [callbacks] @@ -68,17 +62,22 @@ class VideoHandlerThread(threading.Thread): while msg_owner != 'quit': frame = sub_cam.get(blocking=True, timeout=1.0) # type: np.ndarray if frame is not None: + frame_c = None for c in self.callbacks: try: - frame_c = c(frame, self.cam_id) + frame_c = c(frame) + except TypeError as te: + raise TypeError("Callback functions for cvpubsub need to accept two arguments: array and uid") except Exception as e: - import traceback + self.exception_raised = e + frame = frame_c = self.exception_raised CamCtrl.stop_cam(self.cam_id) WinCtrl.quit() - self.exception_raised = e - frame_c = self.exception_raised - if frame_c is not None: - frame = frame_c + raise e + if frame_c is not None: + global_cv_display_callback(frame_c, self.cam_id) + else: + global_cv_display_callback(frame, self.cam_id) msg_owner = sub_owner.get() sub_owner.release() sub_cam.release() diff --git a/cvpubsubs/window_sub/cv_window_sub.py b/cvpubsubs/window_sub/cv_window_sub.py index 4b30e43..7c32e4b 100644 --- a/cvpubsubs/window_sub/cv_window_sub.py +++ b/cvpubsubs/window_sub/cv_window_sub.py @@ -8,10 +8,11 @@ from cvpubsubs.webcam_pub.camctrl import CamCtrl from cvpubsubs.webcam_pub.frame_handler import VideoHandlerThread from localpubsub import NoData from cvpubsubs.window_sub.mouse_event import MouseEvent +from cvpubsubs.serialize import uid_for_source -if False: - from typing import List, Union, Callable, Any - import numpy as np +from typing import List, Union, Callable, Any, Dict +import numpy as np +from cvpubsubs.callbacks import global_cv_display_callback class SubscriberWindows(object): @@ -24,23 +25,32 @@ class SubscriberWindows(object): video_sources=(0,), # type: List[Union[str,int]] callbacks=(None,), # type: List[Callable[[List[np.ndarray]], Any]] ): - self.window_names = window_names self.source_names = [] - for name in video_sources: - if isinstance(name, np.ndarray): - self.source_names.append(str(hash(str(name)))) - self.input_vid_global_names = [str(hash(str(name))) + "frame" for name in video_sources] - elif len(str(name)) <= 1000: - self.source_names.append(str(name)) - self.input_vid_global_names = [str(name) + "frame" for name in video_sources] - else: - raise ValueError("Input window name too long.") + self.close_threads = None + self.frames = [] + self.input_vid_global_names = [] + self.window_names = [] + self.input_cams = [] + for name in video_sources: + self.add_source(name) self.callbacks = callbacks - self.input_cams = video_sources - for name in self.window_names: - cv2.namedWindow(name + " (press ESC to quit)") - cv2.setMouseCallback(name + " (press ESC to quit)", self.handle_mouse) + for name in window_names: + self.add_window(name) + + def add_source(self, name): + uid = uid_for_source(name) + self.source_names.append(uid) + self.input_vid_global_names.append(uid + "frame") + self.input_cams.append(name) + + def add_window(self, name): + self.window_names.append(name) + cv2.namedWindow(name + " (press ESC to quit)") + cv2.setMouseCallback(name + " (press ESC to quit)", self.handle_mouse) + + def add_callback(self, callback): + self.callbacks.append(callback) @staticmethod def set_global_frame_dict(name, *args): @@ -76,13 +86,13 @@ class SubscriberWindows(object): mousey = MouseEvent(event, x, y, flags, param) WinCtrl.mouse_pub.publish(mousey) - def _display_frames(self, frames, win_num): + def _display_frames(self, frames, win_num, ids=None): if isinstance(frames, Exception): raise frames for f in range(len(frames)): # detect nested: if isinstance(frames[f], (list, tuple)) or frames[f].dtype.num == 17 or len(frames[f].shape) > 3: - win_num = self._display_frames(frames[f], win_num) + win_num = self._display_frames(frames[f], win_num, ids) else: cv2.imshow(self.window_names[win_num % len(self.window_names)] + " (press ESC to quit)", frames[f]) win_num += 1 @@ -94,12 +104,39 @@ class SubscriberWindows(object): if self.input_vid_global_names[i] in self.frame_dict and \ not isinstance(self.frame_dict[self.input_vid_global_names[i]], NoData): if len(self.callbacks) > 0 and self.callbacks[i % len(self.callbacks)] is not None: - frames = self.callbacks[i % len(self.callbacks)](self.frame_dict[self.input_vid_global_names[i]]) + self.frames = self.callbacks[i % len(self.callbacks)]( + self.frame_dict[self.input_vid_global_names[i]]) else: - frames = self.frame_dict[self.input_vid_global_names[i]] - if isinstance(frames, np.ndarray) and len(frames.shape) <= 3: - frames = [frames] - win_num = self._display_frames(frames, win_num) + self.frames = self.frame_dict[self.input_vid_global_names[i]] + if isinstance(self.frames, np.ndarray) and len(self.frames.shape) <= 3: + self.frames = [self.frames] + win_num = self._display_frames(self.frames, win_num) + + def update(self, arr=None, id=None): + if arr is not None and id is not None: + global_cv_display_callback(arr, id) + if id not in self.input_cams: + self.add_source(id) + self.add_window(id) + sub_cmd = WinCtrl.win_cmd_sub() + self.update_window_frames() + msg_cmd = sub_cmd.get() + key = self.handle_keys(cv2.waitKey(1)) + return msg_cmd, key + + def wait_for_init(self): + msg_cmd="" + key = "" + while msg_cmd != 'quit' and key != 'quit' and len(self.frames)==0: + msg_cmd, key = self.update() + + def end(self): + if self.close_threads is not None: + for t in self.close_threads: + t.join() + + def __exit__(self, exc_type, exc_val, exc_tb): + self.end() # todo: figure out how to get the red x button to work. Try: https://stackoverflow.com/a/37881722/782170 def loop(self): @@ -107,22 +144,57 @@ class SubscriberWindows(object): msg_cmd = '' key = '' while msg_cmd != 'quit' and key != 'quit': - self.update_window_frames() - msg_cmd = sub_cmd.get() - key = self.handle_keys(cv2.waitKey(1)) + msg_cmd, key = self.update() sub_cmd.release() WinCtrl.quit(force_all_read=False) self.__stop_all_cams() -def display(*vids, names=[]): - vid_threads = [VideoHandlerThread(v) for v in vids] +from cvpubsubs.callbacks import global_cv_display_callback + +from threading import Thread + + +def display(*vids, + callbacks: Union[Dict[Any, Callable], List[Callable], Callable, None] = None, + window_names=[], + blocking=False): + vid_threads = [] + if isinstance(callbacks, Dict): + for v in vids: + v_name = uid_for_source(v) + v_callbacks = [] + if v_name in callbacks: + v_callbacks.extend(callbacks[v_name]) + if v in callbacks: + v_callbacks.extend(callbacks[v]) + vid_threads.append(VideoHandlerThread(v, callbacks=v_callbacks)) + elif isinstance(callbacks, List): + for v in vids: + vid_threads.append(VideoHandlerThread(v, callbacks=callbacks)) + elif isinstance(callbacks, Callable): + for v in vids: + vid_threads.append(VideoHandlerThread(v, callbacks=[callbacks])) + else: + for v in vids: + vid_threads.append(VideoHandlerThread(v)) for v in vid_threads: v.start() - if len(names) == 0: - names = ["window {}".format(i) for i in range(len(vids))] - SubscriberWindows(window_names=names, - video_sources=vids - ).loop() - for v in vid_threads: - v.join() + if len(window_names) == 0: + window_names = ["window {}".format(i) for i in range(len(vids))] + if blocking: + SubscriberWindows(window_names=window_names, + video_sources=vids + ).loop() + for v in vid_threads: + v.join() + else: + s = SubscriberWindows(window_names=window_names, + video_sources=vids + ) + s.close_threads = vid_threads + v_names = [] + for v in vids: + v_name = uid_for_source(v) + v_names.append(v_name) + return s, v_names diff --git a/tests/test_simple_api.py b/tests/test_simple_api.py new file mode 100644 index 0000000..b8893d3 --- /dev/null +++ b/tests/test_simple_api.py @@ -0,0 +1,61 @@ +import unittest as ut + +class TestSubWin(ut.TestCase): + + def test_display_numpy(self): + from cvpubsubs import display + import numpy as np + + display(np.random.normal(0.5, .1, (500,500,3))) + + def test_display_numpy_callback(self): + from cvpubsubs import display + import numpy as np + + arr = np.random.normal(0.5, .1, (500, 500, 3)) + + def fix_arr_cv(arr_in): + arr_in[:] += np.random.normal(0.01, .005, (500, 500, 3)) + arr_in%=1.0 + + display(arr, callbacks= fix_arr_cv, blocking=True) + + def test_display_numpy_loop(self): + from cvpubsubs import display + import numpy as np + + arr = np.random.normal(0.5, .1, (500, 500, 3)) + + displayer, ids = display(arr, blocking = False) + + while True: + arr[:] += np.random.normal(0.01, .005, (500, 500, 3)) + arr %= 1.0 + displayer.update(arr, ids[0]) + displayer.end() + + def test_display_tensorflow(self): + from cvpubsubs import display + import numpy as np + from tensorflow.keras import layers, models + import tensorflow as tf + + for gpu in tf.config.experimental.list_physical_devices("GPU"): + tf.compat.v2.config.experimental.set_memory_growth(gpu, True) + #tf.keras.backend.set_floatx("float16") + + displayer, ids = display(0, blocking = False) + displayer.wait_for_init() + autoencoder = models.Sequential() + autoencoder.add( + layers.Conv2D(20, (3, 3), activation="sigmoid", input_shape=displayer.frames[0].shape) + ) + autoencoder.add(layers.Conv2DTranspose(3, (3, 3), activation="sigmoid")) + + autoencoder.compile(loss="mse", optimizer="adam") + + while True: + grab = tf.convert_to_tensor(displayer.frame_dict['0frame'][np.newaxis, ...].astype(np.float32)/255.0) + autoencoder.fit(grab, grab, steps_per_epoch=1, epochs=1) + output_image = autoencoder.predict(grab, steps=1) + displayer.update((output_image[0]*255.0).astype(np.uint8), "uid for autoencoder output")