gotelem/py/pytelem/skylab.py

468 lines
14 KiB
Python
Raw Normal View History

2023-06-20 02:21:47 +00:00
# this file describes a skylab yaml and it's associated fields. It also
# provides functions to parse a skylab packet folder and a few AST operators.
from abc import ABC, abstractmethod
import re
from pathlib import Path
from typing import Callable, Iterable, NewType, TypedDict, List, Protocol, Union, Set
2023-09-19 19:17:22 +00:00
from pydantic import field_validator, BaseModel, validator, model_validator
2023-06-20 02:21:47 +00:00
from enum import Enum
import yaml
import jinja2
# This part of the file is dedicated to parsing the skylab yaml files. We define
# classes that represent objects in the yaml files, and perform basic validation on
# the input data. We also define a load_yamls function that loads a directory of skylab files.
class FieldType(str, Enum):
"""FieldType indicates the type of the field - the enum represents the C type,
but you can use a map to convert the type to another language."""
# used to ensure types are valid, and act as representations for other languages/mappings.
U8 = "uint8_t"
U16 = "uint16_t"
U32 = "uint32_t"
U64 = "uint64_t"
I8 = "int8_t"
I16 = "int16_t"
I32 = "int32_t"
I64 = "int64_t"
F32 = "float"
Bitfield = "bitfield"
def size(self) -> int:
"""Returns the size, in bytes, of the type."""
match self:
case FieldType.U8:
return 1
case FieldType.U16:
return 2
case FieldType.U32:
return 4
case FieldType.U64:
return 8
case FieldType.I8:
return 1
case FieldType.I16:
return 2
case FieldType.I32:
return 4
case FieldType.I64:
return 8
case FieldType.F32:
return 4
case FieldType.Bitfield:
return 1
return -1
# A FieldTypeMapper is any function that takes a field type and returns
# either a mapped string (based on the type) or none if there was not a match.
FieldTypeMapper = Callable[[FieldType], str | None]
class _Bits(TypedDict):
"""Internal class: a bits object just has a name."""
name: str
class SkylabField(BaseModel):
"""Represents a field (data element) inside a Skylab Packet."""
name: str
"the name of the field. must be alphanumeric and underscores"
type: FieldType
"the type of the field"
2023-09-19 19:17:22 +00:00
units: str | None = None
2023-06-20 02:21:47 +00:00
"optional descriptor of the unit representation"
2023-09-19 19:17:22 +00:00
conversion: float | None = None
2023-06-20 02:21:47 +00:00
"optional conversion factor to be applied when parsing"
2023-09-19 19:17:22 +00:00
bits: List[_Bits] | None = None
2023-06-20 02:21:47 +00:00
"if the type if a bitfield, "
2023-09-19 19:17:22 +00:00
@model_validator(mode='after')
def bits_must_exist_if_bitfield(self) -> 'SkylabField':
if self.bits is None and self.type == FieldType.Bitfield:
2023-06-20 02:21:47 +00:00
raise ValueError("bits are not present on bitfield type")
2023-09-19 19:17:22 +00:00
if self.bits is not None and self.type != FieldType.Bitfield:
2023-06-20 02:21:47 +00:00
raise ValueError("bits are present on non-bitfield type")
2023-09-19 19:17:22 +00:00
return self
2023-06-20 02:21:47 +00:00
2023-09-19 19:17:22 +00:00
@field_validator("name")
@classmethod
def name_valid_string(cls, v: str) -> str:
2023-06-20 02:21:47 +00:00
if not re.match(r"^[A-Za-z0-9_]+$", v):
2023-09-19 19:17:22 +00:00
raise ValueError("invalid name")
2023-06-20 02:21:47 +00:00
return v
2023-09-19 19:17:22 +00:00
@field_validator("name")
@classmethod
def name_nonzero_length(cls, v: str) -> str:
2023-06-20 02:21:47 +00:00
if len(v) == 0:
2023-09-19 19:17:22 +00:00
raise ValueError("name cannot be empty string")
2023-06-20 02:21:47 +00:00
return v
class Endian(str, Enum):
"""Symbol representing the endianness of the packet"""
Big = "big"
Little = "little"
class SkylabPacket(BaseModel):
"""Represents a CAN packet. Contains SkylabFields with information on the structure of the data."""
name: str
2023-09-19 19:17:22 +00:00
description: str | None = None
2023-06-20 02:21:47 +00:00
id: int
endian: Endian
2023-09-19 19:17:22 +00:00
repeat: int | None = None
offset: int | None = None
2023-06-20 02:21:47 +00:00
data: List[SkylabField]
# @validator("data")
# def packet_size_limit(cls, v: List[SkylabField]):
# tot = sum([f.type.size() for f in v])
# if tot > 8:
# return ValueError("Total packet size cannot exceed 8 bytes")
# return v
2023-09-19 19:17:22 +00:00
@field_validator("id")
@classmethod
2023-06-20 02:21:47 +00:00
def id_non_negative(cls, v: int) -> int:
if v < 0:
raise ValueError("id must be above zero")
return v
2023-09-19 19:17:22 +00:00
@field_validator("name")
@classmethod
2023-06-20 02:21:47 +00:00
def name_valid_string(cls, v: str) -> str:
if not re.match(r"^[A-Za-z0-9_]+$", v):
raise ValueError("invalid name", v)
return v
2023-09-19 19:17:22 +00:00
@field_validator("name")
@classmethod
2023-06-20 02:21:47 +00:00
def name_nonzero_length(cls, v: str) -> str:
if len(v) == 0:
raise ValueError("name cannot be empty string")
return v
2023-09-19 19:17:22 +00:00
# TODO[pydantic]: We couldn't refactor the `validator`, please replace it by `field_validator` manually.
# Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-validators for more information.
@model_validator(mode='after')
def offset_must_have_repeat(self) -> "SkylabPacket":
if self.offset is not None and self.repeat is not None:
2023-06-20 02:21:47 +00:00
raise ValueError("field with offset must have repeat defined")
return v
2023-09-19 19:17:22 +00:00
@field_validator("repeat")
@classmethod
2023-06-20 02:21:47 +00:00
def repeat_gt_one(cls, v: int | None):
if v is not None and v <= 1:
raise ValueError("repeat must be strictly greater than one")
return v
class SkylabBoard(BaseModel):
"""Represents a single board. Each board has packets that it sends and receives
Validations:
- There can only be one sender of a packet, but multiple receivers
- every name in the transmit/receive list must have a corresponding packet.
"""
name: str
"The name of the board"
transmit: List[str]
"The packets sent by this board"
receive: List[str]
"The packets received by this board."
2023-09-19 19:17:22 +00:00
@field_validator("name")
@classmethod
2023-06-20 02:21:47 +00:00
def name_valid_string(cls, v: str):
if not re.match(r"^[A-Za-z0-9_]+$", v):
return ValueError("invalid name", v)
return v
2023-09-19 19:17:22 +00:00
@field_validator("name")
@classmethod
2023-06-20 02:21:47 +00:00
def name_nonzero_length(cls, v: str):
if len(v) == 0:
return ValueError("name cannot be empty string")
return v
class SkylabBus(BaseModel):
name: str
"The name of the bus"
baud_rate: int
"Baud rate setting for the bus"
extended_id: bool
"If the bus uses extended ids"
2023-09-19 19:17:22 +00:00
@field_validator("name")
@classmethod
2023-06-20 02:21:47 +00:00
def name_valid_string(cls, v: str):
if not re.match(r"^[A-Za-z0-9_]+$", v):
return ValueError("invalid name", v)
return v
2023-09-19 19:17:22 +00:00
@field_validator("baud_rate")
@classmethod
2023-06-20 02:21:47 +00:00
def baud_rate_supported(cls, v: int):
if v not in [125000, 250000, 500000, 750000, 1000000]:
raise ValueError("unsupported baud rate", v)
return v
class SkylabFile(BaseModel):
"""Represents an entire skylab yaml file. Performs additional cross-validation between
boards and packets."""
packets: List[SkylabPacket] = []
boards: List[SkylabBoard] = []
busses: List[SkylabBus] = []
# TODO: add extra validators here?
def load_skylab_dir(path: Path) -> SkylabFile:
"""Loads all the .yaml files in a directory and merges them into one large SkylabFile, which is then returned."""
files = [f for f in path.iterdir() if re.search(r".*\.ya?ml$", str(f))]
sky_files: List[SkylabFile] = []
for file in files:
with open(file, "r") as f:
obj = yaml.load(f, Loader=yaml.Loader)
sky_files.append(SkylabFile.parse_obj(obj))
# merge the files
# this is not very fast or elegant but who cares.
all_pkts = []
all_boards = []
for sky_f in sky_files:
d = sky_f.dict()
all_pkts.append(d["packets"])
all_boards.append(d["boards"])
collected_skyfile = SkylabFile.parse_obj(
{"packets": all_pkts, "boards": all_boards}
)
return collected_skyfile
# hey. The next bit of code is entirely optional for you to use! It's totally acceptable to just skip it and manually
# iterate over the SkylabFile yourself. The reason we use the walk tree/visitor pattern here is to abstract away
# traversing the tree from the functions that process it.
# While this is generally a pretty common use case (hence the abstraction), it
# can be difficult to wrap certain processes around it.
SkylabObject = Union[SkylabFile, SkylabPacket, SkylabBoard, SkylabField, SkylabBus]
"SkylabObject is any object that will be walked when making a parsing pass on the AST"
class SkylabWalker(Protocol):
"""A SkylabWalker is any class that implements walk(self, SkylabObject)."""
def walk(self, obj: SkylabObject):
...
"walk is called for each SkylabObject in the SkylabFile tree"
class SkylabVisitor(ABC):
"""SkylabVisitor is an abstract class that makes children walkable. Children must
implement the visit_* functions which contain explicit signatures for each discrete unit in a Skylabfile.
"""
@abstractmethod
def visit_file(self, file: SkylabFile):
...
@abstractmethod
def visit_board(self, board: SkylabBoard):
...
@abstractmethod
def visit_packet(self, packet: SkylabPacket):
...
@abstractmethod
def visit_field(self, field: SkylabField, parent: SkylabPacket):
...
@abstractmethod
def visit_bus(self, bus: SkylabBus):
...
_last_parent: SkylabPacket | None = None
"internal variable storing the parent packet for visit_field"
def walk(self, obj: SkylabObject):
match obj:
case SkylabFile():
self.visit_file(obj)
case SkylabBoard():
self.visit_board(obj)
case SkylabPacket():
self._last_parent = obj
self.visit_packet(obj)
case SkylabBus():
self.visit_bus(obj)
case SkylabField():
if self._last_parent is None:
raise Error("Unexpected field without parent")
self.visit_field(obj, self._last_parent)
def walk_tree(tree: SkylabFile, walker: SkylabWalker):
"""Walks the tree using the given walker"""
walker.walk(tree)
for bus in tree.busses:
walker.walk(bus)
for board in tree.boards:
walker.walk(board)
for packet in tree.packets:
walker.walk(packet)
for field in packet.data:
walker.walk(field)
class Error(Exception):
"""An exception when processing the tree"""
class RelationValidator:
"""This class is a processor that validates the relation between boards and
packets.
- Each packet MAY have AT MOST one transmitter.
- Each packet MUST have AT LEAST one board reference.
- Boards MUST ONLY reference packets that exist in 'packets'
- Boards MUST have a UNIQUE name
- Packets MUST have a UNIQUE name"""
seen_packets: Set[str]
"A set of all the packets in the 'packets' field"
sent_packets: Set[str]
"A set of all the packet names that are sent by boards"
recv_packets: Set[str]
"A set of all packets recv'd by boards"
board_names: Set[str]
def __init__(self):
self.board_names = set()
self.seen_packets = set()
self.sent_packets = set()
self.board_names = set()
# The first test: make sure that no two boards send the same packet.
# check sent_packets for existing element before adding.
# the second test: each packet should have a unique name.
# the third test -> Union of sent_packets and recv_packets should
# be exactly equal to seen_packets
def walk(self, obj: SkylabObject):
match obj:
case SkylabPacket(name=n):
if n in self.seen_packets:
raise Error(f"packet {n} declared twice")
self.seen_packets.add(n)
case SkylabBoard(transmit=tx, receive=rx, name=n):
if n in self.board_names:
raise Error(f"board {n} declared twice")
self.board_names.add(n)
for r_pkt in rx:
self.recv_packets.add(r_pkt)
for t_pkt in tx:
if t_pkt in self.sent_packets:
raise Error(f"packet {t_pkt} is sent from two sources")
self.sent_packets.add(t_pkt)
case _:
... # skip others.
def validate(self):
"""runs final checks"""
# perform third check: packets must all be used.
board_ref_packets = self.recv_packets.union(self.sent_packets)
# xor = symmetric_difference, which is packets in one or the other but not both.
unref_packets = board_ref_packets ^ self.seen_packets
if len(unref_packets) > 0:
raise Error(f"packets missing a link: {unref_packets}")
class CollisionDetector:
"""This class detects ID collisions of packets. It expands the repeated packet to ensure that there is
no overlap"""
seen_ids: Set[int] = set()
def add_or_fail(self, idx: int):
if idx in self.seen_ids:
raise ValueError(f"Collision on packet {idx}")
self.seen_ids.add(idx)
def walk(self, obj: SkylabObject):
match obj:
case SkylabPacket(id=idx, offset=None, repeat=None):
# matches single packets - just add it directly.
self.add_or_fail(idx)
# we need the guard clause cause pyright/mypy isn't smart enough to read the entire match block.
case SkylabPacket(
id=base_idx, offset=off, repeat=rpt
) if off is not None and rpt is not None:
for i in range(0, rpt):
self.add_or_fail(base_idx + i * off)
case _:
... # do nothing for packets or fields.
def validate(self):
print(f"{len(self.seen_ids)} packet IDs discovered with no collisions")
class ExampleGenerator:
"""Demonstrates how to use Jinja templates with a custom environment to generate output documents from
the skylab objects."""
env: jinja2.Environment
def __init__(self):
self.env = jinja2.Environment(
loader=jinja2.loaders.PackageLoader(".templates.c")
)
def render(self, skylab: SkylabFile):
...
class GraphvizGenerator:
"""This class converts the Skylab files into a GraphViz document
detailing the flow of data as well as information about the data."""
class CGenerator:
"""This class generates C files for our microcontrollers."""
class PyGenerator:
"""This class generates a python module that can serialize/deserialize packets."""