diff --git a/displayarray/__main__.py b/displayarray/__main__.py new file mode 100644 index 0000000..0529801 --- /dev/null +++ b/displayarray/__main__.py @@ -0,0 +1,76 @@ +""" +DisplayArray. +Display NumPy arrays. + +Usage: + displayarray (-w | -v | -t [,dtype])... [-m ] + displayarray -h + displayarray --version + + +Options: + -h, --help Show this help text. + --version Show version number. + -w , --webcam= Display video from a webcam. + -v , --video= Display frames from a video file. + -t , --topic= Display frames from a topic using the chosen message broker. + -m , --message-backend Choose message broker backend. [Default: ROS] + Currently supported: ROS, ZeroMQ + --ros Use ROS as the backend message broker. + --zeromq Use ZeroMQ as the backend message broker. +""" + +from docopt import docopt + + +def main(argv=None): + arguments = docopt(__doc__, argv=argv) + if arguments["--version"]: + from displayarray import __version__ + + print(f"DisplayArray V{__version__}") + return + from displayarray import display + + vids = [int(w) for w in arguments["--webcam"]] + arguments["--video"] + v_disps = None + if vids: + v_disps = display(*vids, blocking=False) + from displayarray.frame.frame_updater import read_updates_ros, read_updates_zero_mq + + topics = arguments["--topic"] + topics_split = [t.split(",") for t in topics] + d = display() + + async def msg_recv(): + nonlocal d + while True: + if arguments["--message-backend"] == "ROS": + async for v_name, frame in read_updates_ros( + [t for t, d in topics_split], [d for t, d in topics_split] + ): + d.update(arr=frame, id=v_name) + if arguments["--message-backend"] == "ZeroMQ": + async for v_name, frame in read_updates_zero_mq( + *[bytes(t, encoding="ascii") for t in topics] + ): + d.update(arr=frame, id=v_name) + + async def update_vids(): + while True: + if v_disps: + v_disps.update() + await asyncio.sleep(0) + + async def runner(): + await asyncio.wait([msg_recv(), update_vids()]) + + loop = asyncio.get_event_loop() + loop.run_until_complete(runner()) + loop.close() + + +if __name__ == "__main__": + import asyncio + + main() diff --git a/displayarray/frame/frame_publishing.py b/displayarray/frame/frame_publishing.py index 1ea3899..83f4565 100644 --- a/displayarray/frame/frame_publishing.py +++ b/displayarray/frame/frame_publishing.py @@ -1,18 +1,19 @@ import sys import threading import time +import asyncio import cv2 import numpy as np -from displayarray import read_updates from displayarray.frame import subscriber_dictionary -from displayarray.frame.frame_updater import FrameCallable from .np_to_opencv import NpCam from displayarray.uid import uid_for_source from typing import Union, Tuple, Optional, Dict, Any, List, Callable +FrameCallable = Callable[[np.ndarray], Optional[np.ndarray]] + def pub_cam_loop( cam_id: Union[int, str, np.ndarray], @@ -97,7 +98,7 @@ def pub_cam_thread( return t -def publish_updates_zero_mq( +async def publish_updates_zero_mq( *vids, callbacks: Optional[ Union[Dict[Any, FrameCallable], List[FrameCallable], FrameCallable] @@ -105,30 +106,47 @@ def publish_updates_zero_mq( fps_limit=float("inf"), size=(-1, -1), end_callback: Callable[[], bool] = lambda: False, - blocking=True, + blocking=False, publishing_address="tcp://127.0.0.1:5600", - prepend_topic="" + prepend_topic="", + flags=0, + copy=True, + track=False ): import zmq + from displayarray import read_updates ctx = zmq.Context() s = ctx.socket(zmq.PUB) s.bind(publishing_address) + if not blocking: + flags |= zmq.NOBLOCK + try: for v in read_updates(vids, callbacks, fps_limit, size, end_callback, blocking): - for vid_name, frame in v.items(): - s.send_multipart([prepend_topic + vid_name, frame]) + if v: + for vid_name, frame in v.items(): + md = dict( + dtype=str(frame.dtype), + shape=frame.shape, + name=prepend_topic + vid_name, + ) + s.send_json(md, flags | zmq.SNDMORE) + s.send(frame, flags, copy=copy, track=track) + if fps_limit: + await asyncio.sleep(1.0 / fps_limit) + else: + await asyncio.sleep(0) except KeyboardInterrupt: pass finally: vid_names = [uid_for_source(name) for name in vids] for v in vid_names: subscriber_dictionary.stop_cam(v) - sys.exit() -def publish_updates_ros( +async def publish_updates_ros( *vids, callbacks: Optional[ Union[Dict[Any, FrameCallable], List[FrameCallable], FrameCallable] @@ -136,54 +154,64 @@ def publish_updates_ros( fps_limit=float("inf"), size=(-1, -1), end_callback: Callable[[], bool] = lambda: False, - blocking=True, - node_name="displayarray" + blocking=False, + node_name="displayarray", + publisher_name="npy", + rate_hz=None, + dtype=None ): - # mostly copied from: - # https://answers.ros.org/question/289557/custom-message-including-numpy-arrays/?answer=321122#post-id-321122 - import rospy - from std_msgs.msg import Float32MultiArray, MultiArrayDimension, UInt8MultiArray + from rospy.numpy_msg import numpy_msg + import std_msgs.msg + from displayarray import read_updates - vid_names = [uid_for_source(name) for name in vids] + def get_msg_type(dtype): + if dtype is None: + msg_type = { + np.float32: std_msgs.msg.Float32(), + np.float64: std_msgs.msg.Float64(), + np.bool: std_msgs.msg.Bool(), + np.char: std_msgs.msg.Char(), + np.int16: std_msgs.msg.Int16(), + np.int32: std_msgs.msg.Int32(), + np.int64: std_msgs.msg.Int64(), + np.str: std_msgs.msg.String(), + np.uint16: std_msgs.msg.UInt16(), + np.uint32: std_msgs.msg.UInt32(), + np.uint64: std_msgs.msg.UInt64(), + np.uint8: std_msgs.msg.UInt8(), + }[dtype] + else: + msg_type = ( + dtype + ) # allow users to use their own custom messages in numpy arrays + return msg_type + + publishers = {} rospy.init_node(node_name, anonymous=True) - pubs = { - vid_name: rospy.Publisher(vid_name, Float32MultiArray, queue_size=1) - for vid_name in vid_names - } try: for v in read_updates(vids, callbacks, fps_limit, size, end_callback, blocking): - if rospy.is_shutdown(): - print("ROS is shutdown.") - break - for vid_name, frame in v.items(): - if frame.dtype == np.uint8: - frame_msg = UInt8MultiArray() - elif frame.dtype == np.float32: - frame_msg = Float32MultiArray() - else: - raise NotImplementedError( - "Only uint8 and float32 types supported currently." - ) - frame_msg.layout.dim = [] - dims = np.array(frame.shape) - frame_size = dims.prod() / float( - frame.nbytes - ) # this is my attempt to normalize the strides size depending on .nbytes. not sure this is correct - - for i in range(0, len(dims)): # should be rather fast. - # gets the num. of dims of nparray to construct the message - frame_msg.layout.dim.append(MultiArrayDimension()) - frame_msg.layout.dim[i].size = dims[i] - frame_msg.layout.dim[i].stride = dims[i:].prod() / frame_size - frame_msg.layout.dim[i].label = "dim_%d" % i - - frame_msg.data = np.frombuffer(frame.tobytes()) - pubs[vid_name].publish(frame_msg) + if v: + if rospy.is_shutdown(): + break + for vid_name, frame in v.items(): + if vid_name not in publishers: + dty = frame.dtype if dtype is None else dtype + publishers[vid_name] = rospy.Publisher( + publisher_name + vid_name, + numpy_msg(get_msg_type(dty)), + queue_size=10, + ) + publishers[vid_name].publish(frame) + if rate_hz: + await asyncio.sleep(1.0 / rate_hz) + else: + await asyncio.sleep(0) except KeyboardInterrupt: pass finally: vid_names = [uid_for_source(name) for name in vids] for v in vid_names: subscriber_dictionary.stop_cam(v) - sys.exit() + if rospy.core.is_shutdown(): + raise rospy.exceptions.ROSInterruptException("rospy shutdown") diff --git a/displayarray/frame/frame_updater.py b/displayarray/frame/frame_updater.py index fa602dd..f6746b5 100644 --- a/displayarray/frame/frame_updater.py +++ b/displayarray/frame/frame_updater.py @@ -1,4 +1,5 @@ import threading +import asyncio from typing import Union, Tuple, Any, Callable, List, Optional, Dict import numpy as np @@ -111,7 +112,7 @@ class FrameUpdater(threading.Thread): raise self.exception_raised -def read_updates( +async def read_updates( *vids, callbacks: Optional[ Union[Dict[Any, FrameCallable], List[FrameCallable], FrameCallable] @@ -158,7 +159,136 @@ def read_updates( dict_was_updated = True if dict_was_updated or not blocking: yield vid_update_dict + await asyncio.sleep(0) for v in vid_names: subscriber_dictionary.stop_cam(v) for v in vid_threads: v.join() + + +async def read_updates_zero_mq( + *topic_names, + address="tcp://127.0.0.1:5600", + flags=0, + copy=True, + track=False, + blocking=False, + end_callback: Callable[[Any], bool] = lambda: False, +): + import zmq + + ctx = zmq.Context() + s = ctx.socket(zmq.SUB) + s.connect(address) + if not blocking: + flags |= zmq.NOBLOCK + + for topic in topic_names: + s.setsockopt(zmq.SUBSCRIBE, topic) + cb_val = False + while not cb_val: + try: + md = s.recv_json(flags=flags) + msg = s.recv(flags=flags, copy=copy, track=track) + buf = memoryview(msg) + arr = np.frombuffer(buf, dtype=md["dtype"]) + arr.reshape(md["shape"]) + name = md["name"] + cb_val = end_callback(md) + yield name, arr + except zmq.ZMQError as e: + if isinstance(e, zmq.Again): + pass # no messages to receive + else: + raise e + finally: + await asyncio.sleep(0) + + +async def read_updates_ros( + *topic_names, + dtypes=None, + listener_node_name=None, + poll_rate_hz=None, + end_callback: Callable[[Any], bool] = lambda: False, +): + import rospy + from rospy.numpy_msg import numpy_msg + from rospy.client import _WFM + import std_msgs.msg + import random + import string + + if dtypes is None: + raise ValueError( + "ROS cannot automatically determine the types of incoming numpy arrays. Please specify.\n" + "Options are: \n" + "\tfloat32, float64, bool, char, int16, " + "\tint32, int64, str, uint16, uint32, uint64, uint8" + ) + + if listener_node_name is None: + # https://stackoverflow.com/a/2257449 + listener_node_name = "".join( + random.choices(string.ascii_uppercase + string.digits, k=8) + ) + + rospy.init_node(listener_node_name) + + msg_types = [ + { + np.float32: std_msgs.msg.Float32(), + np.float64: std_msgs.msg.Float64(), + np.bool: std_msgs.msg.Bool(), + np.char: std_msgs.msg.Char(), + np.int16: std_msgs.msg.Int16(), + np.int32: std_msgs.msg.Int32(), + np.int64: std_msgs.msg.Int64(), + np.str: std_msgs.msg.String(), + np.uint16: std_msgs.msg.UInt16(), + np.uint32: std_msgs.msg.UInt32(), + np.uint64: std_msgs.msg.UInt64(), + np.uint8: std_msgs.msg.UInt8(), + "float32": std_msgs.msg.Float32(), + "float64": std_msgs.msg.Float64(), + "bool": std_msgs.msg.Bool(), + "char": std_msgs.msg.Char(), + "int16": std_msgs.msg.Int16(), + "int32": std_msgs.msg.Int32(), + "int64": std_msgs.msg.Int64(), + "str": std_msgs.msg.String(), + "uint16": std_msgs.msg.UInt16(), + "uint32": std_msgs.msg.UInt32(), + "uint64": std_msgs.msg.UInt64(), + "uint8": std_msgs.msg.UInt8(), + }.get( + dtype, dtype + ) # allow users to use their own custom messages in numpy arrays + for dtype in dtypes + ] + s = None + cb_val = False + try: + wfms = {t: _WFM() for t in topic_names} + s = { + t: rospy.Subscriber(t, numpy_msg(msg_types[i]), wfms[t].cb) + for i, t in enumerate(topic_names) + } + while not cb_val: + while not rospy.core.is_shutdown(): + if poll_rate_hz: + await asyncio.sleep(1.0 / poll_rate_hz) + else: + await asyncio.sleep(0) + for t, w in wfms.items(): + if w.msg is not None: + yield t, w.msg + cb_val = end_callback(w.msg) + w.msg = None + except KeyboardInterrupt: + pass + finally: + if s is not None: + s.unregister() + if rospy.core.is_shutdown(): + raise rospy.exceptions.ROSInterruptException("rospy shutdown") diff --git a/pyproject.toml b/pyproject.toml index 57c4831..6d793e0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,9 +1,9 @@ [metadata] name = 'displayarray' -version = '0.0.2' -description = 'Simple tool for working with multiple streams from OpenCV.' +version = '0.0.3' +description = 'Tool for displaying numpy arrays.' author = 'SimLeek' -author_email = 'josh.miklos@gmail.com' +author_email = 'simulator.leek@gmail.com' license = 'MIT/Apache-2.0' url = 'https://github.com/simleek/displayarray' @@ -15,3 +15,6 @@ requires = ['setuptools', 'wheel'] [tool.hatch.commands] prerelease = 'hatch build' + +[scripts] +displayarray = "displayarray.__main__:main"