# coding: utf-8
"""Utilities for protocol package.
Copyright (c) 2017, Matthew Edwards. This file is subject to the 3-clause BSD
license, as found in the LICENSE file in the top-level directory of this
distribution and at https://github.com/mje-nz/python_natnet/blob/master/LICENSE.
No part of python_natnet, including this file, may be copied, modified,
propagated, or distributed except according to the terms contained in the
LICENSE file.
"""
import collections
import enum
import functools
import struct
import attr
class MessageId(enum.IntEnum):
    """Numeric identifiers of the NatNet message types (mirroring NatNetTypes.h).

    Members are plain integers on the wire, so ``MessageId(value)`` converts a
    received ID and a member can be packed directly as an integer. Only the
    members described below are documented by this package; the remaining
    members are listed with their values only.

    Attributes:
        Connect: Request for server info
        ServerInfo: Motive version, NatNet version, clock frequency, data port, and multicast
            address
        RequestModelDef: Request for model definitions
        ModelDef: List of definitions of rigid bodies, markersets, skeletons etc
        FrameOfData: Frame of motion capture data
        EchoRequest: Request server to immediately respond with its current time (used for clock
            sync)
        EchoResponse: Current server time (and time contained in EchoRequest message)
    """

    Connect = 0
    ServerInfo = 1
    Request = 2
    Response = 3
    RequestModelDef = 4
    ModelDef = 5
    RequestFrameOfData = 6
    FrameOfData = 7
    MessageString = 8
    Disconnect = 9
    KeepAlive = 10
    DisconnectByTimeout = 11
    EchoRequest = 12
    EchoResponse = 13
    Discovery = 14
    # NatNetTypes.h gives this as decimal 100, but that's incorrect
    UnrecognizedRequest = 0x100
# Field types
# Pre-compiled struct.Struct instances, one per wire field type, so the format
# string is parsed once at import time. Every multi-byte format carries the '<'
# prefix (little-endian, standard sizes, no padding).
bool_t = struct.Struct('?')  # one-byte boolean
int16_t = struct.Struct('<h')  # signed 16-bit integer
int32_t = struct.Struct('<i')  # signed 32-bit integer
uint16_t = struct.Struct('<H')  # unsigned 16-bit integer
uint32_t = struct.Struct('<I')  # unsigned 32-bit integer
uint64_t = struct.Struct('<Q')  # unsigned 64-bit integer
float_t = struct.Struct('<f')  # 32-bit float
double_t = struct.Struct('<d')  # 64-bit float
vector3_t = struct.Struct('<fff')  # three 32-bit floats, unpacked as a 3-tuple
quaternion_t = struct.Struct('<ffff')  # four 32-bit floats, unpacked as a 4-tuple
class ParseBuffer(object):
    """Read cursor over a binary buffer.

    Holds a ``memoryview`` of the data plus the offset of the next unread byte,
    and provides methods for unpacking data types (as ``struct.Struct``
    instances), strings and raw bytes while advancing that offset.

    Attributes:
        data (memoryview): View of the whole buffer passed to the constructor
        offset (int): Index of the next byte to be consumed
    """

    def __init__(self, data):
        """Wrap `data` (any object supporting the buffer protocol) without copying it."""
        self.data = memoryview(data)
        self.offset = 0

    def __len__(self):
        """Length of remaining part of buffer."""
        return len(self.data) - self.offset

    def skip(self, struct_type, n=1):
        """Skip `n` fields of the given type.

        Args:
            struct_type (struct.Struct): Type of the fields to skip
            n (int): Number of consecutive fields to skip
        """
        self.offset += struct_type.size * n

    def unpack(self, struct_type):
        """Unpack a field and advance past it.

        Args:
            struct_type (struct.Struct): Type of field to unpack

        Returns:
            The unpacked value for single-element formats, otherwise a tuple of values.

        Raises:
            struct.error: If fewer than ``struct_type.size`` bytes remain. The offset is left
                unchanged in that case.
        """
        # unpack_from reads in place at the current offset, so no slice object is needed.
        fields = struct_type.unpack_from(self.data, self.offset)
        self.offset += struct_type.size
        return fields[0] if len(fields) == 1 else fields

    def unpack_cstr(self, size=None):
        """Unpack a null-terminated string field.

        If size is given then always consume that many bytes (the string is whatever precedes
        the first null within them), otherwise consume up to and including the first null.

        Args:
            size (int): Fixed width of the field in bytes, or None for a variable-length string

        Returns:
            str: The field contents decoded as UTF-8, without the terminator
        """
        # Compare against None explicitly: a fixed width of 0 is a valid (empty) field and must
        # not fall through to the scan-for-terminator path.
        if size is None:
            remainder = self.data[self.offset:].tobytes()
            text, terminator, _ = remainder.partition(b'\0')
            # Only step over the terminator if one was actually found, so an unterminated
            # string at the end of the buffer does not push the offset past the end.
            self.offset += len(text) + len(terminator)
        else:
            field = self.data[self.offset:self.offset + size].tobytes()
            text = field.partition(b'\0')[0]
            self.offset += size
        return text.decode('utf-8')

    def unpack_bytes(self, size):
        """Unpack a fixed-length field of bytes.

        Args:
            size (int): Number of bytes to consume

        Returns:
            bytes: A copy of the next `size` bytes (fewer if the buffer is shorter)
        """
        value = self.data[self.offset:self.offset + size].tobytes()
        self.offset += size
        return value
class Version(collections.namedtuple('Version', ('major', 'minor', 'build', 'revision'))):
    """NatNet version, with correct comparison operator.

    Being a tuple subclass, versions compare field by field in the order
    (major, minor, build, revision).

    Believe it or not, this is performance-critical.

    Attributes:
        major (int):
        minor (int):
        build (int):
        revision (int):
    """

    # Without this, each instance of a namedtuple subclass gets its own __dict__, which costs
    # memory and allocation time on every construction.
    __slots__ = ()

    # Wire format: four unsigned bytes in field order.
    _version_t = struct.Struct('BBBB')

    def __new__(cls, major, minor=0, build=0, revision=0):
        """Create a version; every component after `major` defaults to 0."""
        return super(Version, cls).__new__(cls, major, minor, build, revision)

    @classmethod
    def deserialize(cls, data, version=None):
        """Deserialize a Version from a ParseBuffer.

        Args:
            data (ParseBuffer): Buffer positioned at the four version bytes
            version: Unused by this method
        """
        return cls(*data.unpack(cls._version_t))

    def serialize(self):
        """Serialize a Version to bytes.

        Returns:
            bytes: Four bytes, one per component

        Raises:
            struct.error: If a component does not fit in an unsigned byte
        """
        return self._version_t.pack(*self)
@attr.s
class SerDesRegistry(object):
    """Registry of message implementations, which can serialize messages and deserialize packets.

    A packet is a header of two little-endian unsigned 16-bit integers (message ID, then payload
    length in bytes) followed by the payload.

    An instance of this is used to provide the module-level function.
    """

    # Maps message ID -> class implementing that message.
    _implementation_types = attr.ib(default=attr.Factory(dict))
    # Protocol version used when a deserialize call does not specify one.
    _version = attr.ib(default=Version(3))

    def register_message(self, id_):
        """Decorator to register the class which implements a given message.

        The decorated class also gets a ``message_id`` attribute set to `id_`.

        Args:
            id_ (:class:`MessageId`):
        """
        def register_message_impl(cls):
            cls.message_id = id_
            self._implementation_types[id_] = cls
            return cls
        return register_message_impl

    @staticmethod
    def serialize(message):
        """Serialize a message instance into a binary packet.

        Args:
            message: A message instance

        Returns:
            bytes: The message serialized as a packet, ready to be sent
        """
        message_id = message.message_id
        payload = message.serialize()
        return uint16_t.pack(message_id) + uint16_t.pack(len(payload)) + payload

    @staticmethod
    def deserialize_header(data):
        """Split a packet into its message ID and a buffer positioned at the payload.

        This is the inverse of the header written by :meth:`serialize`.

        Args:
            data (bytes): A NatNet packet

        Returns:
            tuple: ``(MessageId, ParseBuffer)`` where the buffer's remaining data is the payload

        Raises:
            struct.error: If `data` is too short to contain a header
            ValueError: If the message ID is unknown, or the packet holds fewer payload bytes
                than its header declares
        """
        payload_data = ParseBuffer(data)
        message_id = MessageId(payload_data.unpack(uint16_t))
        declared_length = payload_data.unpack(uint16_t)
        # Only truncation is rejected here; surplus bytes are left for strict mode to report.
        if len(payload_data) < declared_length:
            raise ValueError("Packet truncated: header declares {} payload bytes, got {}".format(
                declared_length, len(payload_data)))
        return message_id, payload_data

    def deserialize_payload(self, message_id, payload_data, version=None, strict=False):
        """Deserialize the payload of a packet into a message instance.

        Args:
            message_id (MessageId)
            payload_data (ParseBuffer): raw payload
            version (Version): Protocol version to use when deserializing
            strict (bool): Raise an exception if there is data left in the buffer after parsing

        Returns:
            Message instance

        Raises:
            KeyError: If no implementation is registered for `message_id`
            AssertionError: If `strict` is set and data is left over after parsing
        """
        if version is None:
            version = self._version
        message_type = self._implementation_types[message_id]
        message = message_type.deserialize(payload_data, version)
        # Raised explicitly rather than with an assert statement so the check still runs when
        # Python is started with -O.
        if strict and len(payload_data) != 0:
            name = getattr(message_id, 'name', message_id)
            raise AssertionError(
                "{} bytes remaining after parsing {} message".format(len(payload_data), name))
        return message

    def deserialize(self, data, version=None, strict=False):
        """Deserialize a packet into a message instance.

        Args:
            data (bytes): A NatNet packet
            version (Version): Protocol version to use when deserializing
            strict (bool): Raise an exception if there is data left in the buffer after parsing.

        Returns:
            Message instance
        """
        if version is None:
            version = self._version
        message_id, payload_data = self.deserialize_header(data)
        return self.deserialize_payload(message_id, payload_data, version, strict)
# Shared module-level registry instance; the module-level functions defined
# below forward to it.
_registry = SerDesRegistry()
# Wrap these so sphinx documents them as proper functions
@functools.wraps(_registry.register_message)
def register_message(*args, **kwargs):
    # Forward every argument unchanged to the module-level registry instance.
    bound_method = _registry.register_message
    return bound_method(*args, **kwargs)
@functools.wraps(_registry.serialize)
def serialize(*args, **kwargs):
    # Forward every argument unchanged to the module-level registry instance.
    bound_method = _registry.serialize
    return bound_method(*args, **kwargs)
@functools.wraps(_registry.deserialize)
def deserialize(*args, **kwargs):
    # Forward every argument unchanged to the module-level registry instance.
    bound_method = _registry.deserialize
    return bound_method(*args, **kwargs)
@functools.wraps(_registry.deserialize_payload)
def deserialize_payload(*args, **kwargs):
    # Forward every argument unchanged to the module-level registry instance.
    bound_method = _registry.deserialize_payload
    return bound_method(*args, **kwargs)