refactor streams for fetching/addrgen

write end to end test
This commit is contained in:
saji 2024-10-29 10:02:35 -05:00
parent 5f54b8acd8
commit ae1ad4633c
2 changed files with 79 additions and 30 deletions

View file

@ -2,13 +2,13 @@
# to know its location. # to know its location.
# during operation, it is given a row index, and responds with the data. # during operation, it is given a row index, and responds with the data.
from amaranth import Module, Signal, unsigned
from amaranth import Module, Signal, unsigned, Cat
from amaranth.build import Platform from amaranth.build import Platform
from amaranth.lib import wiring, data from amaranth.lib import wiring, data
from amaranth.lib.wiring import In, Out from amaranth.lib.wiring import In, Out
from amaranth.lib import stream from amaranth.lib import stream
import logging import logging
from itertools import pairwise
from .common import Rgb888Layout from .common import Rgb888Layout
from .geom import DisplayString from .geom import DisplayString
@ -16,11 +16,9 @@ from .geom import DisplayString
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# FIXME: sizing should be based off of screen size. # FIXME: sizing should be based off of screen size.
CoordLayout = data.StructLayout({"x": unsigned(10), "y": unsigned(10)}) CoordLayout = data.StructLayout({"x": unsigned(10), "y": unsigned(10)})
class AddressConverter(wiring.Component): class AddressConverter(wiring.Component):
"""Translates display (x,y) into full screen (x,y) based on geometry""" """Translates display (x,y) into full screen (x,y) based on geometry"""
@ -53,7 +51,7 @@ class AddressGenerator(wiring.Component):
self.geom = geom self.geom = geom
super().__init__( super().__init__(
{ {
"coordstream": Out( "output": Out(
stream.Signature(data.ArrayLayout(CoordLayout, geom.dimensions.mux)) stream.Signature(data.ArrayLayout(CoordLayout, geom.dimensions.mux))
), ),
"start": In(1), "start": In(1),
@ -75,41 +73,58 @@ class AddressGenerator(wiring.Component):
m.d.comb += translate.input_x.eq(counter) m.d.comb += translate.input_x.eq(counter)
m.d.comb += translate.addr.eq(addr) m.d.comb += translate.addr.eq(addr)
m.d.comb += self.coordstream.payload.eq(translate.output) m.d.comb += self.output.payload.eq(translate.output)
with m.FSM(): with m.FSM():
with m.State("init"): with m.State("init"):
m.d.comb += [self.done.eq(0), self.coordstream.valid.eq(0)] m.d.comb += [self.done.eq(0), self.output.valid.eq(0)]
m.d.sync += [counter.eq(0), addr.eq(self.addr)] m.d.sync += [counter.eq(0), addr.eq(self.addr)]
with m.If(self.start): with m.If(self.start):
m.next = "run" m.next = "run"
with m.State("run"): with m.State("run"):
m.d.comb += self.coordstream.valid.eq(1) m.d.comb += self.output.valid.eq(1)
# stream data out as long as it's valid. # stream data out as long as it's valid.
with m.If( with m.If(
self.coordstream.ready self.output.ready & (counter == self.geom.dimensions.length - 1)
& (counter == self.geom.dimensions.length - 1)
): ):
m.next = "done" m.next = "done"
with m.Elif(self.coordstream.ready): with m.Elif(self.output.ready):
m.d.sync += counter.eq(counter + 1) m.d.sync += counter.eq(counter + 1)
pass pass
with m.State("done"): with m.State("done"):
m.d.comb += self.coordstream.valid.eq(0) m.d.comb += self.output.valid.eq(0)
m.d.comb += self.done.eq(1) m.d.comb += self.done.eq(1)
m.next = "init" m.next = "init"
return m return m
def example_rgb_transform(x, y):
return {
"red": x + y,
"green": x - y,
"blue": x ^ y,
}
class BasicFetcher(wiring.Component): class BasicFetcher(wiring.Component):
"""A generic function-based fetcher. Takes a function of the form f(x,y: int) -> RGB.""" """A generic function-based fetcher. Takes a function of the form f(x,y: int) -> dict rgb values.
If no function is provided it uses a basic coordinate-driven rgb transform where red = x+y,
green = x - y, and blue = x ^ y.
When providing a function, it must return a dictionary with the keys "red", "green", "blue"."""
def __init__( def __init__(
self, geom: DisplayString, dfunc, data_shape=Rgb888Layout, *, src_loc_at=0 self,
geom: DisplayString,
dfunc=example_rgb_transform,
data_shape=Rgb888Layout,
*,
src_loc_at=0,
): ):
self.geom = geom self.geom = geom
self.dfunc = dfunc self.dfunc = dfunc
@ -118,7 +133,7 @@ class BasicFetcher(wiring.Component):
"input": In( "input": In(
stream.Signature(data.ArrayLayout(CoordLayout, geom.dimensions.mux)) stream.Signature(data.ArrayLayout(CoordLayout, geom.dimensions.mux))
), ),
"pixstream": Out( "output": Out(
stream.Signature(data.ArrayLayout(data_shape, geom.dimensions.mux)) stream.Signature(data.ArrayLayout(data_shape, geom.dimensions.mux))
), ),
}, },
@ -128,20 +143,24 @@ class BasicFetcher(wiring.Component):
def elaborate(self, platform: Platform) -> Module: def elaborate(self, platform: Platform) -> Module:
m = Module() m = Module()
# test mode - pass through, r = x + y, g = x - y, b = {y,x} colors = self.output.payload
colors = self.pixstream.payload
m.d.comb += [ m.d.comb += [
self.input.valid.eq(self.pixstream.valid), self.output.valid.eq(self.input.valid),
self.input.ready.eq(self.pixstream.ready), self.input.ready.eq(self.output.ready),
] ]
for i in range(self.geom.dimensions.mux): for i in range(self.geom.dimensions.mux):
inp = self.input.payload[i] inp = self.input.payload[i]
output = self.dfunc(inp.x, inp.y)
m.d.comb += [ m.d.comb += [
colors[i].red.eq(inp.x + inp.y), colors[i].red.eq(output["red"]),
colors[i].green.eq(inp.x - inp.y), colors[i].green.eq(output["green"]),
colors[i].blue.eq(inp.x ^ inp.y), colors[i].blue.eq(output["blue"]),
] ]
return m return m
def chain_streams(m, streams):
for pair in pairwise(streams):
wiring.connect(m, pair[0].output, pair[1].input)

View file

@ -1,10 +1,9 @@
from amaranth.lib import wiring, data from amaranth import Module
from amaranth.lib import wiring
from amaranth.sim import Simulator from amaranth.sim import Simulator
import random
from random import randrange
import pytest import pytest
from groovylight.fetcher import AddressConverter, AddressGenerator, BasicFetcher from groovylight.fetcher import AddressConverter, AddressGenerator, BasicFetcher, chain_streams
from groovylight.geom import DisplayString, Coord, DisplayDimensions, DisplayRotation from groovylight.geom import DisplayString, Coord, DisplayDimensions, DisplayRotation
ds_testdata = [ ds_testdata = [
@ -82,7 +81,7 @@ def test_generator(addr, rot):
async def stream_checker(ctx): async def stream_checker(ctx):
while ctx.get(dut.done) == 0: while ctx.get(dut.done) == 0:
payload = await stream_get(ctx, dut.coordstream) payload = await stream_get(ctx, dut.output)
assert expected.pop() == payload assert expected.pop() == payload
sim.add_testbench(runner) sim.add_testbench(runner)
@ -103,12 +102,12 @@ def test_basic_fetcher(inp, expected):
ds = DisplayString( ds = DisplayString(
Coord(3, 0), DisplayDimensions(128, 64, mux=1), DisplayRotation.R0 Coord(3, 0), DisplayDimensions(128, 64, mux=1), DisplayRotation.R0
) )
dut = BasicFetcher(ds, None) dut = BasicFetcher(ds)
sim = Simulator(dut) sim = Simulator(dut)
async def test(ctx): async def test(ctx):
ctx.set(dut.input.payload[0], inp) ctx.set(dut.input.payload[0], inp)
res = ctx.get(dut.pixstream.payload)[0] res = ctx.get(dut.output.payload)[0]
assert res["red"] == expected["red"] assert res["red"] == expected["red"]
assert res["green"] == expected["green"] assert res["green"] == expected["green"]
assert res["blue"] == expected["blue"] assert res["blue"] == expected["blue"]
@ -117,3 +116,34 @@ def test_basic_fetcher(inp, expected):
with sim.write_vcd("fetcher.vcd"): with sim.write_vcd("fetcher.vcd"):
sim.run() sim.run()
def test_stream_e2e():
ds = DisplayString(
Coord(3, 0), DisplayDimensions(128, 64, mux=1), DisplayRotation.R0
)
m = Module()
m.submodules.gen = addrgen = AddressGenerator(ds)
m.submodules.fetch = fetch = BasicFetcher(ds)
chain_streams(m, [addrgen, fetch])
sim = Simulator(m)
sim.add_clock(1e-6)
async def stim(ctx):
await ctx.tick()
ctx.set(addrgen.start, 1)
await ctx.tick()
ctx.set(addrgen.start, 0)
payload = await stream_get(ctx, fetch.output)
assert payload[0] == {"red": 3, "green": 3, "blue": 3}
sim.add_testbench(stim)
with sim.write_vcd("stream_e2e.vcd"):
sim.run()