refactor streams for fetching/addrgen

write end to end test
This commit is contained in:
saji 2024-10-29 10:02:35 -05:00
parent 5f54b8acd8
commit ae1ad4633c
2 changed files with 79 additions and 30 deletions

View file

@ -2,13 +2,13 @@
# to know its location.
# during operation, it is given a row index, and responds with the data.
from amaranth import Module, Signal, unsigned, Cat
from amaranth import Module, Signal, unsigned
from amaranth.build import Platform
from amaranth.lib import wiring, data
from amaranth.lib.wiring import In, Out
from amaranth.lib import stream
import logging
from itertools import pairwise
from .common import Rgb888Layout
from .geom import DisplayString
@ -16,11 +16,9 @@ from .geom import DisplayString
logger = logging.getLogger(__name__)
# FIXME: sizing should be based off of screen size.
CoordLayout = data.StructLayout({"x": unsigned(10), "y": unsigned(10)})
class AddressConverter(wiring.Component):
"""Translates display (x,y) into full screen (x,y) based on geometry"""
@ -53,7 +51,7 @@ class AddressGenerator(wiring.Component):
self.geom = geom
super().__init__(
{
"coordstream": Out(
"output": Out(
stream.Signature(data.ArrayLayout(CoordLayout, geom.dimensions.mux))
),
"start": In(1),
@ -75,41 +73,58 @@ class AddressGenerator(wiring.Component):
m.d.comb += translate.input_x.eq(counter)
m.d.comb += translate.addr.eq(addr)
m.d.comb += self.coordstream.payload.eq(translate.output)
m.d.comb += self.output.payload.eq(translate.output)
with m.FSM():
with m.State("init"):
m.d.comb += [self.done.eq(0), self.coordstream.valid.eq(0)]
m.d.comb += [self.done.eq(0), self.output.valid.eq(0)]
m.d.sync += [counter.eq(0), addr.eq(self.addr)]
with m.If(self.start):
m.next = "run"
with m.State("run"):
m.d.comb += self.coordstream.valid.eq(1)
m.d.comb += self.output.valid.eq(1)
# stream data out as long as it's valid.
with m.If(
self.coordstream.ready
& (counter == self.geom.dimensions.length - 1)
self.output.ready & (counter == self.geom.dimensions.length - 1)
):
m.next = "done"
with m.Elif(self.coordstream.ready):
with m.Elif(self.output.ready):
m.d.sync += counter.eq(counter + 1)
pass
with m.State("done"):
m.d.comb += self.coordstream.valid.eq(0)
m.d.comb += self.output.valid.eq(0)
m.d.comb += self.done.eq(1)
m.next = "init"
return m
def example_rgb_transform(x, y):
return {
"red": x + y,
"green": x - y,
"blue": x ^ y,
}
class BasicFetcher(wiring.Component):
"""A generic function-based fetcher. Takes a function of the form f(x,y: int) -> RGB."""
"""A generic function-based fetcher. Takes a function of the form f(x,y: int) -> dict rgb values.
If no function is provided it uses a basic coordinate-driven rgb transform where red = x+y,
green = x - y, and blue = x ^ y.
When providing a function, it must return a dictionary with the keys "red", "green", "blue"."""
def __init__(
self, geom: DisplayString, dfunc, data_shape=Rgb888Layout, *, src_loc_at=0
self,
geom: DisplayString,
dfunc=example_rgb_transform,
data_shape=Rgb888Layout,
*,
src_loc_at=0,
):
self.geom = geom
self.dfunc = dfunc
@ -118,7 +133,7 @@ class BasicFetcher(wiring.Component):
"input": In(
stream.Signature(data.ArrayLayout(CoordLayout, geom.dimensions.mux))
),
"pixstream": Out(
"output": Out(
stream.Signature(data.ArrayLayout(data_shape, geom.dimensions.mux))
),
},
@ -128,20 +143,24 @@ class BasicFetcher(wiring.Component):
def elaborate(self, platform: Platform) -> Module:
m = Module()
# test mode - pass through, r = x + y, g = x - y, b = {y,x}
colors = self.pixstream.payload
colors = self.output.payload
m.d.comb += [
self.input.valid.eq(self.pixstream.valid),
self.input.ready.eq(self.pixstream.ready),
self.output.valid.eq(self.input.valid),
self.input.ready.eq(self.output.ready),
]
for i in range(self.geom.dimensions.mux):
inp = self.input.payload[i]
output = self.dfunc(inp.x, inp.y)
m.d.comb += [
colors[i].red.eq(inp.x + inp.y),
colors[i].green.eq(inp.x - inp.y),
colors[i].blue.eq(inp.x ^ inp.y),
colors[i].red.eq(output["red"]),
colors[i].green.eq(output["green"]),
colors[i].blue.eq(output["blue"]),
]
return m
def chain_streams(m, streams):
for pair in pairwise(streams):
wiring.connect(m, pair[0].output, pair[1].input)

View file

@ -1,10 +1,9 @@
from amaranth.lib import wiring, data
from amaranth import Module
from amaranth.lib import wiring
from amaranth.sim import Simulator
import random
from random import randrange
import pytest
from groovylight.fetcher import AddressConverter, AddressGenerator, BasicFetcher
from groovylight.fetcher import AddressConverter, AddressGenerator, BasicFetcher, chain_streams
from groovylight.geom import DisplayString, Coord, DisplayDimensions, DisplayRotation
ds_testdata = [
@ -82,7 +81,7 @@ def test_generator(addr, rot):
async def stream_checker(ctx):
while ctx.get(dut.done) == 0:
payload = await stream_get(ctx, dut.coordstream)
payload = await stream_get(ctx, dut.output)
assert expected.pop() == payload
sim.add_testbench(runner)
@ -103,12 +102,12 @@ def test_basic_fetcher(inp, expected):
ds = DisplayString(
Coord(3, 0), DisplayDimensions(128, 64, mux=1), DisplayRotation.R0
)
dut = BasicFetcher(ds, None)
dut = BasicFetcher(ds)
sim = Simulator(dut)
async def test(ctx):
ctx.set(dut.input.payload[0], inp)
res = ctx.get(dut.pixstream.payload)[0]
res = ctx.get(dut.output.payload)[0]
assert res["red"] == expected["red"]
assert res["green"] == expected["green"]
assert res["blue"] == expected["blue"]
@ -117,3 +116,34 @@ def test_basic_fetcher(inp, expected):
with sim.write_vcd("fetcher.vcd"):
sim.run()
def test_stream_e2e():
ds = DisplayString(
Coord(3, 0), DisplayDimensions(128, 64, mux=1), DisplayRotation.R0
)
m = Module()
m.submodules.gen = addrgen = AddressGenerator(ds)
m.submodules.fetch = fetch = BasicFetcher(ds)
chain_streams(m, [addrgen, fetch])
sim = Simulator(m)
sim.add_clock(1e-6)
async def stim(ctx):
await ctx.tick()
ctx.set(addrgen.start, 1)
await ctx.tick()
ctx.set(addrgen.start, 0)
payload = await stream_get(ctx, fetch.output)
assert payload[0] == {"red": 3, "green": 3, "blue": 3}
sim.add_testbench(stim)
with sim.write_vcd("stream_e2e.vcd"):
sim.run()