Added more API to the display function

This commit is contained in:
simleek
2019-09-29 22:12:00 -07:00
parent ff34f51ddd
commit 941751b007
5 changed files with 196 additions and 52 deletions

View File

@ -1,3 +1,3 @@
__version__ = '0.6.4'
__version__ = '0.6.5'
from .window_sub.cv_window_sub import display

12
cvpubsubs/serialize.py Normal file
View File

@ -0,0 +1,12 @@
import numpy as np
from collections import Hashable
def uid_for_source(video_source, max_len=1000):
    """Return a stable string identifier for a video source.

    Short sources (camera indices, file paths) are used verbatim; anything
    longer — e.g. the repr of a numpy array — is reduced to a hash string.

    :param video_source: camera index, filename, or array-like source
    :param max_len: longest str() representation used directly as the uid
    :return: string uid for the source
    """
    # collections.Hashable was removed in Python 3.10; import the ABC from
    # collections.abc locally so this works regardless of the module import.
    from collections.abc import Hashable

    text = str(video_source)
    if len(text) <= max_len:
        return text
    if isinstance(video_source, Hashable):
        # NOTE: hash() of strings is salted per process (PYTHONHASHSEED),
        # so hash-derived uids are only stable within a single run.
        return str(hash(video_source))
    return str(hash(text))

View File

@ -5,7 +5,7 @@ import numpy as np
from cvpubsubs.webcam_pub.pub_cam import pub_cam_thread
from cvpubsubs.webcam_pub.camctrl import CamCtrl
from cvpubsubs.window_sub.winctrl import WinCtrl
from cvpubsubs.serialize import uid_for_source
if False:
from typing import Union, Tuple, Any, Callable, List, Optional
@ -21,7 +21,7 @@ class VideoHandlerThread(threading.Thread):
"Thread for publishing frames from a video source."
def __init__(self, video_source=0, # type: Union[int, str, np.ndarray]
callbacks=(global_cv_display_callback,), # type: Union[List[FrameCallable], FrameCallable]
callbacks=(), # type: Union[List[FrameCallable], FrameCallable]
request_size=(-1, -1), # type: Tuple[int, int]
high_speed=True, # type: bool
fps_limit=240 # type: float
@ -40,13 +40,7 @@ class VideoHandlerThread(threading.Thread):
:type fps_limit: float
"""
super(VideoHandlerThread, self).__init__(target=self.loop, args=())
if isinstance(video_source, (int, str)):
self.cam_id = str(video_source)
elif isinstance(video_source, np.ndarray):
self.cam_id = str(hash(str(video_source)))
else:
raise TypeError(
"Only strings or ints representing cameras, or numpy arrays representing pictures supported.")
self.cam_id = uid_for_source(video_source)
self.video_source = video_source
if callable(callbacks):
self.callbacks = [callbacks]
@ -68,17 +62,22 @@ class VideoHandlerThread(threading.Thread):
while msg_owner != 'quit':
frame = sub_cam.get(blocking=True, timeout=1.0) # type: np.ndarray
if frame is not None:
frame_c = None
for c in self.callbacks:
try:
frame_c = c(frame, self.cam_id)
frame_c = c(frame)
except TypeError as te:
raise TypeError("Callback functions for cvpubsub need to accept two arguments: array and uid")
except Exception as e:
import traceback
self.exception_raised = e
frame = frame_c = self.exception_raised
CamCtrl.stop_cam(self.cam_id)
WinCtrl.quit()
self.exception_raised = e
frame_c = self.exception_raised
raise e
if frame_c is not None:
frame = frame_c
global_cv_display_callback(frame_c, self.cam_id)
else:
global_cv_display_callback(frame, self.cam_id)
msg_owner = sub_owner.get()
sub_owner.release()
sub_cam.release()

View File

@ -8,10 +8,11 @@ from cvpubsubs.webcam_pub.camctrl import CamCtrl
from cvpubsubs.webcam_pub.frame_handler import VideoHandlerThread
from localpubsub import NoData
from cvpubsubs.window_sub.mouse_event import MouseEvent
from cvpubsubs.serialize import uid_for_source
if False:
from typing import List, Union, Callable, Any
import numpy as np
from typing import List, Union, Callable, Any, Dict
import numpy as np
from cvpubsubs.callbacks import global_cv_display_callback
class SubscriberWindows(object):
@ -24,24 +25,33 @@ class SubscriberWindows(object):
video_sources=(0,), # type: List[Union[str,int]]
callbacks=(None,), # type: List[Callable[[List[np.ndarray]], Any]]
):
self.window_names = window_names
self.source_names = []
for name in video_sources:
if isinstance(name, np.ndarray):
self.source_names.append(str(hash(str(name))))
self.input_vid_global_names = [str(hash(str(name))) + "frame" for name in video_sources]
elif len(str(name)) <= 1000:
self.source_names.append(str(name))
self.input_vid_global_names = [str(name) + "frame" for name in video_sources]
else:
raise ValueError("Input window name too long.")
self.close_threads = None
self.frames = []
self.input_vid_global_names = []
self.window_names = []
self.input_cams = []
for name in video_sources:
self.add_source(name)
self.callbacks = callbacks
self.input_cams = video_sources
for name in self.window_names:
for name in window_names:
self.add_window(name)
def add_source(self, name):
    """Register *name* as an input source, deriving its uid for frame lookup."""
    source_uid = uid_for_source(name)
    self.source_names += [source_uid]
    self.input_vid_global_names += [source_uid + "frame"]
    self.input_cams += [name]
def add_window(self, name):
    """Record a window name and create its OpenCV window with a mouse hook."""
    self.window_names.append(name)
    title = name + " (press ESC to quit)"
    cv2.namedWindow(title)
    cv2.setMouseCallback(title, self.handle_mouse)
def add_callback(self, callback):
    """Register an additional per-frame callback.

    NOTE(review): assumes ``self.callbacks`` is a list; the constructor
    default is a tuple ``(None,)`` — verify callers pass a list.
    """
    self.callbacks.append(callback)
@staticmethod
def set_global_frame_dict(name, *args):
if len(str(name)) <= 1000:
@ -76,13 +86,13 @@ class SubscriberWindows(object):
mousey = MouseEvent(event, x, y, flags, param)
WinCtrl.mouse_pub.publish(mousey)
def _display_frames(self, frames, win_num):
def _display_frames(self, frames, win_num, ids=None):
if isinstance(frames, Exception):
raise frames
for f in range(len(frames)):
# detect nested:
if isinstance(frames[f], (list, tuple)) or frames[f].dtype.num == 17 or len(frames[f].shape) > 3:
win_num = self._display_frames(frames[f], win_num)
win_num = self._display_frames(frames[f], win_num, ids)
else:
cv2.imshow(self.window_names[win_num % len(self.window_names)] + " (press ESC to quit)", frames[f])
win_num += 1
@ -94,12 +104,39 @@ class SubscriberWindows(object):
if self.input_vid_global_names[i] in self.frame_dict and \
not isinstance(self.frame_dict[self.input_vid_global_names[i]], NoData):
if len(self.callbacks) > 0 and self.callbacks[i % len(self.callbacks)] is not None:
frames = self.callbacks[i % len(self.callbacks)](self.frame_dict[self.input_vid_global_names[i]])
self.frames = self.callbacks[i % len(self.callbacks)](
self.frame_dict[self.input_vid_global_names[i]])
else:
frames = self.frame_dict[self.input_vid_global_names[i]]
if isinstance(frames, np.ndarray) and len(frames.shape) <= 3:
frames = [frames]
win_num = self._display_frames(frames, win_num)
self.frames = self.frame_dict[self.input_vid_global_names[i]]
if isinstance(self.frames, np.ndarray) and len(self.frames.shape) <= 3:
self.frames = [self.frames]
win_num = self._display_frames(self.frames, win_num)
def update(self, arr=None, id=None):
    """Optionally publish one frame, then run one iteration of the window loop.

    :param arr: frame to publish; ignored unless *id* is also given
    :param id: uid to publish the frame under (shadows the builtin ``id``)
    :return: (window-command message, key-handler result) tuple
    """
    if arr is not None and id is not None:
        global_cv_display_callback(arr, id)
        # First frame from an unseen source: register it and open a window.
        if id not in self.input_cams:
            self.add_source(id)
            self.add_window(id)
    sub_cmd = WinCtrl.win_cmd_sub()
    self.update_window_frames()
    msg_cmd = sub_cmd.get()
    key = self.handle_keys(cv2.waitKey(1))
    return msg_cmd, key
def wait_for_init(self):
    """Spin the update loop until a first frame arrives or quit is requested."""
    msg_cmd, key = "", ""
    # len() rather than truthiness: self.frames may end up as an ndarray,
    # whose boolean value is ambiguous.
    while len(self.frames) == 0 and msg_cmd != 'quit' and key != 'quit':
        msg_cmd, key = self.update()
def end(self):
    """Join any publisher threads registered for cleanup."""
    for worker in (self.close_threads or ()):
        worker.join()
def __exit__(self, exc_type, exc_val, exc_tb):
    """Context-manager exit: join worker threads via end()."""
    self.end()
# todo: figure out how to get the red x button to work. Try: https://stackoverflow.com/a/37881722/782170
def loop(self):
@ -107,22 +144,57 @@ class SubscriberWindows(object):
msg_cmd = ''
key = ''
while msg_cmd != 'quit' and key != 'quit':
self.update_window_frames()
msg_cmd = sub_cmd.get()
key = self.handle_keys(cv2.waitKey(1))
msg_cmd, key = self.update()
sub_cmd.release()
WinCtrl.quit(force_all_read=False)
self.__stop_all_cams()
from cvpubsubs.callbacks import global_cv_display_callback
from threading import Thread


def display(*vids,
            callbacks: Union[Dict[Any, Callable], List[Callable], Callable, None] = None,
            window_names=(),
            blocking=False):
    """Start publisher threads for *vids* and show them in OpenCV windows.

    :param vids: video sources (camera index, filename, or numpy array)
    :param callbacks: a dict mapping a source or its uid to a callback (or
        iterable of callbacks), a list of callbacks applied to every source,
        a single callable, or None
    :param window_names: display window titles; auto-generated when empty
    :param blocking: if True, run the display loop here and join the threads;
        if False, return (SubscriberWindows, list of source uids) so the
        caller can drive updates itself
    """
    def _callbacks_for(source, uid):
        # Collect callbacks registered under either the uid or the raw source.
        found = []
        for key in (uid, source):
            try:
                entry = callbacks[key]
            except (KeyError, TypeError):
                # TypeError: unhashable source, e.g. a numpy array, cannot
                # be a dict key — only its uid can.
                continue
            # Accept a single callable as well as an iterable of callables.
            found.extend(entry if isinstance(entry, (list, tuple)) else [entry])
        return found

    vid_threads = []
    # isinstance against concrete types, not typing aliases (Dict/List/...).
    if isinstance(callbacks, dict):
        for v in vids:
            vid_threads.append(
                VideoHandlerThread(v, callbacks=_callbacks_for(v, uid_for_source(v))))
    elif isinstance(callbacks, list):
        for v in vids:
            vid_threads.append(VideoHandlerThread(v, callbacks=callbacks))
    elif callable(callbacks):
        for v in vids:
            vid_threads.append(VideoHandlerThread(v, callbacks=[callbacks]))
    else:
        for v in vids:
            vid_threads.append(VideoHandlerThread(v))
    for v in vid_threads:
        v.start()
    if len(window_names) == 0:
        window_names = ["window {}".format(i) for i in range(len(vids))]
    if blocking:
        SubscriberWindows(window_names=window_names,
                          video_sources=vids
                          ).loop()
        for v in vid_threads:
            v.join()
    else:
        s = SubscriberWindows(window_names=window_names,
                              video_sources=vids
                              )
        s.close_threads = vid_threads
        # uids in source order, so callers can pair them with *vids*
        v_names = [uid_for_source(v) for v in vids]
        return s, v_names

61
tests/test_simple_api.py Normal file
View File

@ -0,0 +1,61 @@
import unittest as ut
class TestSubWin(ut.TestCase):
def test_display_numpy(self):
    """Smoke test: display a random image array (opens a GUI window)."""
    from cvpubsubs import display
    import numpy as np
    display(np.random.normal(0.5, .1, (500,500,3)))
def test_display_numpy_callback(self):
    """Display a random image mutated in-place by a callback (blocking)."""
    from cvpubsubs import display
    import numpy as np
    arr = np.random.normal(0.5, .1, (500, 500, 3))
    def fix_arr_cv(arr_in):
        # drift the image and wrap values back into [0, 1)
        arr_in[:] += np.random.normal(0.01, .005, (500, 500, 3))
        arr_in%=1.0
    display(arr, callbacks= fix_arr_cv, blocking=True)
def test_display_numpy_loop(self):
    """Drive the display manually via update() in non-blocking mode.

    NOTE(review): the loop below never breaks, so ``displayer.end()`` is
    unreachable — this is a manual/interactive test, not CI-safe.
    """
    from cvpubsubs import display
    import numpy as np
    arr = np.random.normal(0.5, .1, (500, 500, 3))
    displayer, ids = display(arr, blocking = False)
    while True:
        arr[:] += np.random.normal(0.01, .005, (500, 500, 3))
        arr %= 1.0
        displayer.update(arr, ids[0])
    displayer.end()
def test_display_tensorflow(self):
from cvpubsubs import display
import numpy as np
from tensorflow.keras import layers, models
import tensorflow as tf
for gpu in tf.config.experimental.list_physical_devices("GPU"):
tf.compat.v2.config.experimental.set_memory_growth(gpu, True)
#tf.keras.backend.set_floatx("float16")
displayer, ids = display(0, blocking = False)
displayer.wait_for_init()
autoencoder = models.Sequential()
autoencoder.add(
layers.Conv2D(20, (3, 3), activation="sigmoid", input_shape=displayer.frames[0].shape)
)
autoencoder.add(layers.Conv2DTranspose(3, (3, 3), activation="sigmoid"))
autoencoder.compile(loss="mse", optimizer="adam")
while True:
grab = tf.convert_to_tensor(displayer.frame_dict['0frame'][np.newaxis, ...].astype(np.float32)/255.0)
autoencoder.fit(grab, grab, steps_per_epoch=1, epochs=1)
output_image = autoencoder.predict(grab, steps=1)
displayer.update((output_image[0]*255.0).astype(np.uint8), "uid for autoencoder output")