Added rospy coroutines. Added command line interface.

This commit is contained in:
simleek
2019-11-06 20:44:25 -07:00
parent 268ecddbe0
commit 475786a7d2
4 changed files with 289 additions and 52 deletions
+76
View File
@@ -0,0 +1,76 @@
"""
DisplayArray.
Display NumPy arrays.
Usage:
displayarray (-w <webcam-number> | -v <video-filename> | -t <topic-name>[,dtype])... [-m <msg-backend>]
displayarray -h
displayarray --version
Options:
-h, --help Show this help text.
--version Show version number.
-w <webcam-number>, --webcam=<webcam-number> Display video from a webcam.
-v <video-filename>, --video=<video-filename> Display frames from a video file.
-t <topic-name>, --topic=<topic-name> Display frames from a topic using the chosen message broker.
-m <msg-backend>, --message-backend <msg-backend> Choose message broker backend. [Default: ROS]
Currently supported: ROS, ZeroMQ
--ros Use ROS as the backend message broker.
--zeromq Use ZeroMQ as the backend message broker.
"""
from docopt import docopt
def main(argv=None):
arguments = docopt(__doc__, argv=argv)
if arguments["--version"]:
from displayarray import __version__
print(f"DisplayArray V{__version__}")
return
from displayarray import display
vids = [int(w) for w in arguments["--webcam"]] + arguments["--video"]
v_disps = None
if vids:
v_disps = display(*vids, blocking=False)
from displayarray.frame.frame_updater import read_updates_ros, read_updates_zero_mq
topics = arguments["--topic"]
topics_split = [t.split(",") for t in topics]
d = display()
async def msg_recv():
nonlocal d
while True:
if arguments["--message-backend"] == "ROS":
async for v_name, frame in read_updates_ros(
[t for t, d in topics_split], [d for t, d in topics_split]
):
d.update(arr=frame, id=v_name)
if arguments["--message-backend"] == "ZeroMQ":
async for v_name, frame in read_updates_zero_mq(
*[bytes(t, encoding="ascii") for t in topics]
):
d.update(arr=frame, id=v_name)
async def update_vids():
while True:
if v_disps:
v_disps.update()
await asyncio.sleep(0)
async def runner():
await asyncio.wait([msg_recv(), update_vids()])
loop = asyncio.get_event_loop()
loop.run_until_complete(runner())
loop.close()
if __name__ == "__main__":
import asyncio
main()
+76 -48
View File
@@ -1,18 +1,19 @@
import sys
import threading
import time
import asyncio
import cv2
import numpy as np
from displayarray import read_updates
from displayarray.frame import subscriber_dictionary
from displayarray.frame.frame_updater import FrameCallable
from .np_to_opencv import NpCam
from displayarray.uid import uid_for_source
from typing import Union, Tuple, Optional, Dict, Any, List, Callable
FrameCallable = Callable[[np.ndarray], Optional[np.ndarray]]
def pub_cam_loop(
cam_id: Union[int, str, np.ndarray],
@@ -97,7 +98,7 @@ def pub_cam_thread(
return t
def publish_updates_zero_mq(
async def publish_updates_zero_mq(
*vids,
callbacks: Optional[
Union[Dict[Any, FrameCallable], List[FrameCallable], FrameCallable]
@@ -105,30 +106,47 @@ def publish_updates_zero_mq(
fps_limit=float("inf"),
size=(-1, -1),
end_callback: Callable[[], bool] = lambda: False,
blocking=True,
blocking=False,
publishing_address="tcp://127.0.0.1:5600",
prepend_topic=""
prepend_topic="",
flags=0,
copy=True,
track=False
):
import zmq
from displayarray import read_updates
ctx = zmq.Context()
s = ctx.socket(zmq.PUB)
s.bind(publishing_address)
if not blocking:
flags |= zmq.NOBLOCK
try:
for v in read_updates(vids, callbacks, fps_limit, size, end_callback, blocking):
for vid_name, frame in v.items():
s.send_multipart([prepend_topic + vid_name, frame])
if v:
for vid_name, frame in v.items():
md = dict(
dtype=str(frame.dtype),
shape=frame.shape,
name=prepend_topic + vid_name,
)
s.send_json(md, flags | zmq.SNDMORE)
s.send(frame, flags, copy=copy, track=track)
if fps_limit:
await asyncio.sleep(1.0 / fps_limit)
else:
await asyncio.sleep(0)
except KeyboardInterrupt:
pass
finally:
vid_names = [uid_for_source(name) for name in vids]
for v in vid_names:
subscriber_dictionary.stop_cam(v)
sys.exit()
def publish_updates_ros(
async def publish_updates_ros(
*vids,
callbacks: Optional[
Union[Dict[Any, FrameCallable], List[FrameCallable], FrameCallable]
@@ -136,54 +154,64 @@ def publish_updates_ros(
fps_limit=float("inf"),
size=(-1, -1),
end_callback: Callable[[], bool] = lambda: False,
blocking=True,
node_name="displayarray"
blocking=False,
node_name="displayarray",
publisher_name="npy",
rate_hz=None,
dtype=None
):
# mostly copied from:
# https://answers.ros.org/question/289557/custom-message-including-numpy-arrays/?answer=321122#post-id-321122
import rospy
from std_msgs.msg import Float32MultiArray, MultiArrayDimension, UInt8MultiArray
from rospy.numpy_msg import numpy_msg
import std_msgs.msg
from displayarray import read_updates
vid_names = [uid_for_source(name) for name in vids]
def get_msg_type(dtype):
if dtype is None:
msg_type = {
np.float32: std_msgs.msg.Float32(),
np.float64: std_msgs.msg.Float64(),
np.bool: std_msgs.msg.Bool(),
np.char: std_msgs.msg.Char(),
np.int16: std_msgs.msg.Int16(),
np.int32: std_msgs.msg.Int32(),
np.int64: std_msgs.msg.Int64(),
np.str: std_msgs.msg.String(),
np.uint16: std_msgs.msg.UInt16(),
np.uint32: std_msgs.msg.UInt32(),
np.uint64: std_msgs.msg.UInt64(),
np.uint8: std_msgs.msg.UInt8(),
}[dtype]
else:
msg_type = (
dtype
) # allow users to use their own custom messages in numpy arrays
return msg_type
publishers = {}
rospy.init_node(node_name, anonymous=True)
pubs = {
vid_name: rospy.Publisher(vid_name, Float32MultiArray, queue_size=1)
for vid_name in vid_names
}
try:
for v in read_updates(vids, callbacks, fps_limit, size, end_callback, blocking):
if rospy.is_shutdown():
print("ROS is shutdown.")
break
for vid_name, frame in v.items():
if frame.dtype == np.uint8:
frame_msg = UInt8MultiArray()
elif frame.dtype == np.float32:
frame_msg = Float32MultiArray()
else:
raise NotImplementedError(
"Only uint8 and float32 types supported currently."
)
frame_msg.layout.dim = []
dims = np.array(frame.shape)
frame_size = dims.prod() / float(
frame.nbytes
) # this is my attempt to normalize the strides size depending on .nbytes. not sure this is correct
for i in range(0, len(dims)): # should be rather fast.
# gets the num. of dims of nparray to construct the message
frame_msg.layout.dim.append(MultiArrayDimension())
frame_msg.layout.dim[i].size = dims[i]
frame_msg.layout.dim[i].stride = dims[i:].prod() / frame_size
frame_msg.layout.dim[i].label = "dim_%d" % i
frame_msg.data = np.frombuffer(frame.tobytes())
pubs[vid_name].publish(frame_msg)
if v:
if rospy.is_shutdown():
break
for vid_name, frame in v.items():
if vid_name not in publishers:
dty = frame.dtype if dtype is None else dtype
publishers[vid_name] = rospy.Publisher(
publisher_name + vid_name,
numpy_msg(get_msg_type(dty)),
queue_size=10,
)
publishers[vid_name].publish(frame)
if rate_hz:
await asyncio.sleep(1.0 / rate_hz)
else:
await asyncio.sleep(0)
except KeyboardInterrupt:
pass
finally:
vid_names = [uid_for_source(name) for name in vids]
for v in vid_names:
subscriber_dictionary.stop_cam(v)
sys.exit()
if rospy.core.is_shutdown():
raise rospy.exceptions.ROSInterruptException("rospy shutdown")
+131 -1
View File
@@ -1,4 +1,5 @@
import threading
import asyncio
from typing import Union, Tuple, Any, Callable, List, Optional, Dict
import numpy as np
@@ -111,7 +112,7 @@ class FrameUpdater(threading.Thread):
raise self.exception_raised
def read_updates(
async def read_updates(
*vids,
callbacks: Optional[
Union[Dict[Any, FrameCallable], List[FrameCallable], FrameCallable]
@@ -158,7 +159,136 @@ def read_updates(
dict_was_updated = True
if dict_was_updated or not blocking:
yield vid_update_dict
await asyncio.sleep(0)
for v in vid_names:
subscriber_dictionary.stop_cam(v)
for v in vid_threads:
v.join()
async def read_updates_zero_mq(
*topic_names,
address="tcp://127.0.0.1:5600",
flags=0,
copy=True,
track=False,
blocking=False,
end_callback: Callable[[Any], bool] = lambda: False,
):
import zmq
ctx = zmq.Context()
s = ctx.socket(zmq.SUB)
s.connect(address)
if not blocking:
flags |= zmq.NOBLOCK
for topic in topic_names:
s.setsockopt(zmq.SUBSCRIBE, topic)
cb_val = False
while not cb_val:
try:
md = s.recv_json(flags=flags)
msg = s.recv(flags=flags, copy=copy, track=track)
buf = memoryview(msg)
arr = np.frombuffer(buf, dtype=md["dtype"])
arr.reshape(md["shape"])
name = md["name"]
cb_val = end_callback(md)
yield name, arr
except zmq.ZMQError as e:
if isinstance(e, zmq.Again):
pass # no messages to receive
else:
raise e
finally:
await asyncio.sleep(0)
async def read_updates_ros(
*topic_names,
dtypes=None,
listener_node_name=None,
poll_rate_hz=None,
end_callback: Callable[[Any], bool] = lambda: False,
):
import rospy
from rospy.numpy_msg import numpy_msg
from rospy.client import _WFM
import std_msgs.msg
import random
import string
if dtypes is None:
raise ValueError(
"ROS cannot automatically determine the types of incoming numpy arrays. Please specify.\n"
"Options are: \n"
"\tfloat32, float64, bool, char, int16, "
"\tint32, int64, str, uint16, uint32, uint64, uint8"
)
if listener_node_name is None:
# https://stackoverflow.com/a/2257449
listener_node_name = "".join(
random.choices(string.ascii_uppercase + string.digits, k=8)
)
rospy.init_node(listener_node_name)
msg_types = [
{
np.float32: std_msgs.msg.Float32(),
np.float64: std_msgs.msg.Float64(),
np.bool: std_msgs.msg.Bool(),
np.char: std_msgs.msg.Char(),
np.int16: std_msgs.msg.Int16(),
np.int32: std_msgs.msg.Int32(),
np.int64: std_msgs.msg.Int64(),
np.str: std_msgs.msg.String(),
np.uint16: std_msgs.msg.UInt16(),
np.uint32: std_msgs.msg.UInt32(),
np.uint64: std_msgs.msg.UInt64(),
np.uint8: std_msgs.msg.UInt8(),
"float32": std_msgs.msg.Float32(),
"float64": std_msgs.msg.Float64(),
"bool": std_msgs.msg.Bool(),
"char": std_msgs.msg.Char(),
"int16": std_msgs.msg.Int16(),
"int32": std_msgs.msg.Int32(),
"int64": std_msgs.msg.Int64(),
"str": std_msgs.msg.String(),
"uint16": std_msgs.msg.UInt16(),
"uint32": std_msgs.msg.UInt32(),
"uint64": std_msgs.msg.UInt64(),
"uint8": std_msgs.msg.UInt8(),
}.get(
dtype, dtype
) # allow users to use their own custom messages in numpy arrays
for dtype in dtypes
]
s = None
cb_val = False
try:
wfms = {t: _WFM() for t in topic_names}
s = {
t: rospy.Subscriber(t, numpy_msg(msg_types[i]), wfms[t].cb)
for i, t in enumerate(topic_names)
}
while not cb_val:
while not rospy.core.is_shutdown():
if poll_rate_hz:
await asyncio.sleep(1.0 / poll_rate_hz)
else:
await asyncio.sleep(0)
for t, w in wfms.items():
if w.msg is not None:
yield t, w.msg
cb_val = end_callback(w.msg)
w.msg = None
except KeyboardInterrupt:
pass
finally:
if s is not None:
s.unregister()
if rospy.core.is_shutdown():
raise rospy.exceptions.ROSInterruptException("rospy shutdown")
+6 -3
View File
@@ -1,9 +1,9 @@
[metadata]
name = 'displayarray'
version = '0.0.2'
description = 'Simple tool for working with multiple streams from OpenCV.'
version = '0.0.3'
description = 'Tool for displaying numpy arrays.'
author = 'SimLeek'
author_email = 'josh.miklos@gmail.com'
author_email = 'simulator.leek@gmail.com'
license = 'MIT/Apache-2.0'
url = 'https://github.com/simleek/displayarray'
@@ -15,3 +15,6 @@ requires = ['setuptools', 'wheel']
[tool.hatch.commands]
prerelease = 'hatch build'
[scripts]
displayarray = "displayarray.__main__:main"