Added rospy coroutines. Added command line interface.
This commit is contained in:
@@ -0,0 +1,76 @@
|
||||
"""
|
||||
DisplayArray.
|
||||
Display NumPy arrays.
|
||||
|
||||
Usage:
|
||||
displayarray (-w <webcam-number> | -v <video-filename> | -t <topic-name>[,dtype])... [-m <msg-backend>]
|
||||
displayarray -h
|
||||
displayarray --version
|
||||
|
||||
|
||||
Options:
|
||||
-h, --help Show this help text.
|
||||
--version Show version number.
|
||||
-w <webcam-number>, --webcam=<webcam-number> Display video from a webcam.
|
||||
-v <video-filename>, --video=<video-filename> Display frames from a video file.
|
||||
-t <topic-name>, --topic=<topic-name> Display frames from a topic using the chosen message broker.
|
||||
-m <msg-backend>, --message-backend <msg-backend> Choose message broker backend. [Default: ROS]
|
||||
Currently supported: ROS, ZeroMQ
|
||||
--ros Use ROS as the backend message broker.
|
||||
--zeromq Use ZeroMQ as the backend message broker.
|
||||
"""
|
||||
|
||||
from docopt import docopt
|
||||
|
||||
|
||||
def main(argv=None):
|
||||
arguments = docopt(__doc__, argv=argv)
|
||||
if arguments["--version"]:
|
||||
from displayarray import __version__
|
||||
|
||||
print(f"DisplayArray V{__version__}")
|
||||
return
|
||||
from displayarray import display
|
||||
|
||||
vids = [int(w) for w in arguments["--webcam"]] + arguments["--video"]
|
||||
v_disps = None
|
||||
if vids:
|
||||
v_disps = display(*vids, blocking=False)
|
||||
from displayarray.frame.frame_updater import read_updates_ros, read_updates_zero_mq
|
||||
|
||||
topics = arguments["--topic"]
|
||||
topics_split = [t.split(",") for t in topics]
|
||||
d = display()
|
||||
|
||||
async def msg_recv():
|
||||
nonlocal d
|
||||
while True:
|
||||
if arguments["--message-backend"] == "ROS":
|
||||
async for v_name, frame in read_updates_ros(
|
||||
[t for t, d in topics_split], [d for t, d in topics_split]
|
||||
):
|
||||
d.update(arr=frame, id=v_name)
|
||||
if arguments["--message-backend"] == "ZeroMQ":
|
||||
async for v_name, frame in read_updates_zero_mq(
|
||||
*[bytes(t, encoding="ascii") for t in topics]
|
||||
):
|
||||
d.update(arr=frame, id=v_name)
|
||||
|
||||
async def update_vids():
|
||||
while True:
|
||||
if v_disps:
|
||||
v_disps.update()
|
||||
await asyncio.sleep(0)
|
||||
|
||||
async def runner():
|
||||
await asyncio.wait([msg_recv(), update_vids()])
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.run_until_complete(runner())
|
||||
loop.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
main()
|
||||
@@ -1,18 +1,19 @@
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import asyncio
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from displayarray import read_updates
|
||||
from displayarray.frame import subscriber_dictionary
|
||||
from displayarray.frame.frame_updater import FrameCallable
|
||||
from .np_to_opencv import NpCam
|
||||
from displayarray.uid import uid_for_source
|
||||
|
||||
from typing import Union, Tuple, Optional, Dict, Any, List, Callable
|
||||
|
||||
FrameCallable = Callable[[np.ndarray], Optional[np.ndarray]]
|
||||
|
||||
|
||||
def pub_cam_loop(
|
||||
cam_id: Union[int, str, np.ndarray],
|
||||
@@ -97,7 +98,7 @@ def pub_cam_thread(
|
||||
return t
|
||||
|
||||
|
||||
def publish_updates_zero_mq(
|
||||
async def publish_updates_zero_mq(
|
||||
*vids,
|
||||
callbacks: Optional[
|
||||
Union[Dict[Any, FrameCallable], List[FrameCallable], FrameCallable]
|
||||
@@ -105,30 +106,47 @@ def publish_updates_zero_mq(
|
||||
fps_limit=float("inf"),
|
||||
size=(-1, -1),
|
||||
end_callback: Callable[[], bool] = lambda: False,
|
||||
blocking=True,
|
||||
blocking=False,
|
||||
publishing_address="tcp://127.0.0.1:5600",
|
||||
prepend_topic=""
|
||||
prepend_topic="",
|
||||
flags=0,
|
||||
copy=True,
|
||||
track=False
|
||||
):
|
||||
import zmq
|
||||
from displayarray import read_updates
|
||||
|
||||
ctx = zmq.Context()
|
||||
s = ctx.socket(zmq.PUB)
|
||||
s.bind(publishing_address)
|
||||
|
||||
if not blocking:
|
||||
flags |= zmq.NOBLOCK
|
||||
|
||||
try:
|
||||
for v in read_updates(vids, callbacks, fps_limit, size, end_callback, blocking):
|
||||
for vid_name, frame in v.items():
|
||||
s.send_multipart([prepend_topic + vid_name, frame])
|
||||
if v:
|
||||
for vid_name, frame in v.items():
|
||||
md = dict(
|
||||
dtype=str(frame.dtype),
|
||||
shape=frame.shape,
|
||||
name=prepend_topic + vid_name,
|
||||
)
|
||||
s.send_json(md, flags | zmq.SNDMORE)
|
||||
s.send(frame, flags, copy=copy, track=track)
|
||||
if fps_limit:
|
||||
await asyncio.sleep(1.0 / fps_limit)
|
||||
else:
|
||||
await asyncio.sleep(0)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
vid_names = [uid_for_source(name) for name in vids]
|
||||
for v in vid_names:
|
||||
subscriber_dictionary.stop_cam(v)
|
||||
sys.exit()
|
||||
|
||||
|
||||
def publish_updates_ros(
|
||||
async def publish_updates_ros(
|
||||
*vids,
|
||||
callbacks: Optional[
|
||||
Union[Dict[Any, FrameCallable], List[FrameCallable], FrameCallable]
|
||||
@@ -136,54 +154,64 @@ def publish_updates_ros(
|
||||
fps_limit=float("inf"),
|
||||
size=(-1, -1),
|
||||
end_callback: Callable[[], bool] = lambda: False,
|
||||
blocking=True,
|
||||
node_name="displayarray"
|
||||
blocking=False,
|
||||
node_name="displayarray",
|
||||
publisher_name="npy",
|
||||
rate_hz=None,
|
||||
dtype=None
|
||||
):
|
||||
# mostly copied from:
|
||||
# https://answers.ros.org/question/289557/custom-message-including-numpy-arrays/?answer=321122#post-id-321122
|
||||
|
||||
import rospy
|
||||
from std_msgs.msg import Float32MultiArray, MultiArrayDimension, UInt8MultiArray
|
||||
from rospy.numpy_msg import numpy_msg
|
||||
import std_msgs.msg
|
||||
from displayarray import read_updates
|
||||
|
||||
vid_names = [uid_for_source(name) for name in vids]
|
||||
def get_msg_type(dtype):
|
||||
if dtype is None:
|
||||
msg_type = {
|
||||
np.float32: std_msgs.msg.Float32(),
|
||||
np.float64: std_msgs.msg.Float64(),
|
||||
np.bool: std_msgs.msg.Bool(),
|
||||
np.char: std_msgs.msg.Char(),
|
||||
np.int16: std_msgs.msg.Int16(),
|
||||
np.int32: std_msgs.msg.Int32(),
|
||||
np.int64: std_msgs.msg.Int64(),
|
||||
np.str: std_msgs.msg.String(),
|
||||
np.uint16: std_msgs.msg.UInt16(),
|
||||
np.uint32: std_msgs.msg.UInt32(),
|
||||
np.uint64: std_msgs.msg.UInt64(),
|
||||
np.uint8: std_msgs.msg.UInt8(),
|
||||
}[dtype]
|
||||
else:
|
||||
msg_type = (
|
||||
dtype
|
||||
) # allow users to use their own custom messages in numpy arrays
|
||||
return msg_type
|
||||
|
||||
publishers = {}
|
||||
rospy.init_node(node_name, anonymous=True)
|
||||
pubs = {
|
||||
vid_name: rospy.Publisher(vid_name, Float32MultiArray, queue_size=1)
|
||||
for vid_name in vid_names
|
||||
}
|
||||
try:
|
||||
for v in read_updates(vids, callbacks, fps_limit, size, end_callback, blocking):
|
||||
if rospy.is_shutdown():
|
||||
print("ROS is shutdown.")
|
||||
break
|
||||
for vid_name, frame in v.items():
|
||||
if frame.dtype == np.uint8:
|
||||
frame_msg = UInt8MultiArray()
|
||||
elif frame.dtype == np.float32:
|
||||
frame_msg = Float32MultiArray()
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Only uint8 and float32 types supported currently."
|
||||
)
|
||||
frame_msg.layout.dim = []
|
||||
dims = np.array(frame.shape)
|
||||
frame_size = dims.prod() / float(
|
||||
frame.nbytes
|
||||
) # this is my attempt to normalize the strides size depending on .nbytes. not sure this is correct
|
||||
|
||||
for i in range(0, len(dims)): # should be rather fast.
|
||||
# gets the num. of dims of nparray to construct the message
|
||||
frame_msg.layout.dim.append(MultiArrayDimension())
|
||||
frame_msg.layout.dim[i].size = dims[i]
|
||||
frame_msg.layout.dim[i].stride = dims[i:].prod() / frame_size
|
||||
frame_msg.layout.dim[i].label = "dim_%d" % i
|
||||
|
||||
frame_msg.data = np.frombuffer(frame.tobytes())
|
||||
pubs[vid_name].publish(frame_msg)
|
||||
if v:
|
||||
if rospy.is_shutdown():
|
||||
break
|
||||
for vid_name, frame in v.items():
|
||||
if vid_name not in publishers:
|
||||
dty = frame.dtype if dtype is None else dtype
|
||||
publishers[vid_name] = rospy.Publisher(
|
||||
publisher_name + vid_name,
|
||||
numpy_msg(get_msg_type(dty)),
|
||||
queue_size=10,
|
||||
)
|
||||
publishers[vid_name].publish(frame)
|
||||
if rate_hz:
|
||||
await asyncio.sleep(1.0 / rate_hz)
|
||||
else:
|
||||
await asyncio.sleep(0)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
vid_names = [uid_for_source(name) for name in vids]
|
||||
for v in vid_names:
|
||||
subscriber_dictionary.stop_cam(v)
|
||||
sys.exit()
|
||||
if rospy.core.is_shutdown():
|
||||
raise rospy.exceptions.ROSInterruptException("rospy shutdown")
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import threading
|
||||
import asyncio
|
||||
from typing import Union, Tuple, Any, Callable, List, Optional, Dict
|
||||
|
||||
import numpy as np
|
||||
@@ -111,7 +112,7 @@ class FrameUpdater(threading.Thread):
|
||||
raise self.exception_raised
|
||||
|
||||
|
||||
def read_updates(
|
||||
async def read_updates(
|
||||
*vids,
|
||||
callbacks: Optional[
|
||||
Union[Dict[Any, FrameCallable], List[FrameCallable], FrameCallable]
|
||||
@@ -158,7 +159,136 @@ def read_updates(
|
||||
dict_was_updated = True
|
||||
if dict_was_updated or not blocking:
|
||||
yield vid_update_dict
|
||||
await asyncio.sleep(0)
|
||||
for v in vid_names:
|
||||
subscriber_dictionary.stop_cam(v)
|
||||
for v in vid_threads:
|
||||
v.join()
|
||||
|
||||
|
||||
async def read_updates_zero_mq(
|
||||
*topic_names,
|
||||
address="tcp://127.0.0.1:5600",
|
||||
flags=0,
|
||||
copy=True,
|
||||
track=False,
|
||||
blocking=False,
|
||||
end_callback: Callable[[Any], bool] = lambda: False,
|
||||
):
|
||||
import zmq
|
||||
|
||||
ctx = zmq.Context()
|
||||
s = ctx.socket(zmq.SUB)
|
||||
s.connect(address)
|
||||
if not blocking:
|
||||
flags |= zmq.NOBLOCK
|
||||
|
||||
for topic in topic_names:
|
||||
s.setsockopt(zmq.SUBSCRIBE, topic)
|
||||
cb_val = False
|
||||
while not cb_val:
|
||||
try:
|
||||
md = s.recv_json(flags=flags)
|
||||
msg = s.recv(flags=flags, copy=copy, track=track)
|
||||
buf = memoryview(msg)
|
||||
arr = np.frombuffer(buf, dtype=md["dtype"])
|
||||
arr.reshape(md["shape"])
|
||||
name = md["name"]
|
||||
cb_val = end_callback(md)
|
||||
yield name, arr
|
||||
except zmq.ZMQError as e:
|
||||
if isinstance(e, zmq.Again):
|
||||
pass # no messages to receive
|
||||
else:
|
||||
raise e
|
||||
finally:
|
||||
await asyncio.sleep(0)
|
||||
|
||||
|
||||
async def read_updates_ros(
|
||||
*topic_names,
|
||||
dtypes=None,
|
||||
listener_node_name=None,
|
||||
poll_rate_hz=None,
|
||||
end_callback: Callable[[Any], bool] = lambda: False,
|
||||
):
|
||||
import rospy
|
||||
from rospy.numpy_msg import numpy_msg
|
||||
from rospy.client import _WFM
|
||||
import std_msgs.msg
|
||||
import random
|
||||
import string
|
||||
|
||||
if dtypes is None:
|
||||
raise ValueError(
|
||||
"ROS cannot automatically determine the types of incoming numpy arrays. Please specify.\n"
|
||||
"Options are: \n"
|
||||
"\tfloat32, float64, bool, char, int16, "
|
||||
"\tint32, int64, str, uint16, uint32, uint64, uint8"
|
||||
)
|
||||
|
||||
if listener_node_name is None:
|
||||
# https://stackoverflow.com/a/2257449
|
||||
listener_node_name = "".join(
|
||||
random.choices(string.ascii_uppercase + string.digits, k=8)
|
||||
)
|
||||
|
||||
rospy.init_node(listener_node_name)
|
||||
|
||||
msg_types = [
|
||||
{
|
||||
np.float32: std_msgs.msg.Float32(),
|
||||
np.float64: std_msgs.msg.Float64(),
|
||||
np.bool: std_msgs.msg.Bool(),
|
||||
np.char: std_msgs.msg.Char(),
|
||||
np.int16: std_msgs.msg.Int16(),
|
||||
np.int32: std_msgs.msg.Int32(),
|
||||
np.int64: std_msgs.msg.Int64(),
|
||||
np.str: std_msgs.msg.String(),
|
||||
np.uint16: std_msgs.msg.UInt16(),
|
||||
np.uint32: std_msgs.msg.UInt32(),
|
||||
np.uint64: std_msgs.msg.UInt64(),
|
||||
np.uint8: std_msgs.msg.UInt8(),
|
||||
"float32": std_msgs.msg.Float32(),
|
||||
"float64": std_msgs.msg.Float64(),
|
||||
"bool": std_msgs.msg.Bool(),
|
||||
"char": std_msgs.msg.Char(),
|
||||
"int16": std_msgs.msg.Int16(),
|
||||
"int32": std_msgs.msg.Int32(),
|
||||
"int64": std_msgs.msg.Int64(),
|
||||
"str": std_msgs.msg.String(),
|
||||
"uint16": std_msgs.msg.UInt16(),
|
||||
"uint32": std_msgs.msg.UInt32(),
|
||||
"uint64": std_msgs.msg.UInt64(),
|
||||
"uint8": std_msgs.msg.UInt8(),
|
||||
}.get(
|
||||
dtype, dtype
|
||||
) # allow users to use their own custom messages in numpy arrays
|
||||
for dtype in dtypes
|
||||
]
|
||||
s = None
|
||||
cb_val = False
|
||||
try:
|
||||
wfms = {t: _WFM() for t in topic_names}
|
||||
s = {
|
||||
t: rospy.Subscriber(t, numpy_msg(msg_types[i]), wfms[t].cb)
|
||||
for i, t in enumerate(topic_names)
|
||||
}
|
||||
while not cb_val:
|
||||
while not rospy.core.is_shutdown():
|
||||
if poll_rate_hz:
|
||||
await asyncio.sleep(1.0 / poll_rate_hz)
|
||||
else:
|
||||
await asyncio.sleep(0)
|
||||
for t, w in wfms.items():
|
||||
if w.msg is not None:
|
||||
yield t, w.msg
|
||||
cb_val = end_callback(w.msg)
|
||||
w.msg = None
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
if s is not None:
|
||||
s.unregister()
|
||||
if rospy.core.is_shutdown():
|
||||
raise rospy.exceptions.ROSInterruptException("rospy shutdown")
|
||||
|
||||
+6
-3
@@ -1,9 +1,9 @@
|
||||
[metadata]
|
||||
name = 'displayarray'
|
||||
version = '0.0.2'
|
||||
description = 'Simple tool for working with multiple streams from OpenCV.'
|
||||
version = '0.0.3'
|
||||
description = 'Tool for displaying numpy arrays.'
|
||||
author = 'SimLeek'
|
||||
author_email = 'josh.miklos@gmail.com'
|
||||
author_email = 'simulator.leek@gmail.com'
|
||||
license = 'MIT/Apache-2.0'
|
||||
url = 'https://github.com/simleek/displayarray'
|
||||
|
||||
@@ -15,3 +15,6 @@ requires = ['setuptools', 'wheel']
|
||||
|
||||
[tool.hatch.commands]
|
||||
prerelease = 'hatch build'
|
||||
|
||||
[scripts]
|
||||
displayarray = "displayarray.__main__:main"
|
||||
|
||||
Reference in New Issue
Block a user