From ae1ad4633c2278d41de804aa5fd0fdc21f494f23 Mon Sep 17 00:00:00 2001 From: saji Date: Tue, 29 Oct 2024 10:02:35 -0500 Subject: [PATCH] refactor streams for fetching/addrgen write end to end test --- src/groovylight/fetcher.py | 65 +++++++++++++++++---------- src/groovylight/tests/test_fetcher.py | 44 +++++++++++++++--- 2 files changed, 79 insertions(+), 30 deletions(-) diff --git a/src/groovylight/fetcher.py b/src/groovylight/fetcher.py index 56ba523..353c325 100644 --- a/src/groovylight/fetcher.py +++ b/src/groovylight/fetcher.py @@ -2,13 +2,13 @@ # to know its location. # during operation, it is given a row index, and responds with the data. - -from amaranth import Module, Signal, unsigned, Cat +from amaranth import Module, Signal, unsigned from amaranth.build import Platform from amaranth.lib import wiring, data from amaranth.lib.wiring import In, Out from amaranth.lib import stream import logging +from itertools import pairwise from .common import Rgb888Layout from .geom import DisplayString @@ -16,11 +16,9 @@ from .geom import DisplayString logger = logging.getLogger(__name__) - # FIXME: sizing should be based off of screen size. CoordLayout = data.StructLayout({"x": unsigned(10), "y": unsigned(10)}) - class AddressConverter(wiring.Component): """Translates display (x,y) into full screen (x,y) based on geometry""" @@ -53,7 +51,7 @@ class AddressGenerator(wiring.Component): self.geom = geom super().__init__( { - "coordstream": Out( + "output": Out( stream.Signature(data.ArrayLayout(CoordLayout, geom.dimensions.mux)) ), "start": In(1), @@ -75,41 +73,58 @@ class AddressGenerator(wiring.Component): m.d.comb += translate.input_x.eq(counter) m.d.comb += translate.addr.eq(addr) - m.d.comb += self.coordstream.payload.eq(translate.output) + m.d.comb += self.output.payload.eq(translate.output) with m.FSM(): with m.State("init"): - m.d.comb += [self.done.eq(0), self.coordstream.valid.eq(0)] + m.d.comb += [self.done.eq(0), self.output.valid.eq(0)] m.d.sync += [counter.eq(0), addr.eq(self.addr)] with m.If(self.start): m.next = "run" with m.State("run"): - m.d.comb += self.coordstream.valid.eq(1) + m.d.comb += self.output.valid.eq(1) # stream data out as long as it's valid. with m.If( - self.coordstream.ready - & (counter == self.geom.dimensions.length - 1) + self.output.ready & (counter == self.geom.dimensions.length - 1) ): m.next = "done" - with m.Elif(self.coordstream.ready): + with m.Elif(self.output.ready): m.d.sync += counter.eq(counter + 1) pass with m.State("done"): - m.d.comb += self.coordstream.valid.eq(0) + m.d.comb += self.output.valid.eq(0) m.d.comb += self.done.eq(1) m.next = "init" return m +def example_rgb_transform(x, y): + return { + "red": x + y, + "green": x - y, + "blue": x ^ y, + } + + class BasicFetcher(wiring.Component): - """A generic function-based fetcher. Takes a function of the form f(x,y: int) -> RGB.""" + """A generic function-based fetcher. Takes a function of the form f(x,y: int) -> dict rgb values. + + If no function is provided it uses a basic coordinate-driven rgb transform where red = x+y, + green = x - y, and blue = x ^ y. + + When providing a function, it must return a dictionary with the keys "red", "green", "blue".""" def __init__( - self, geom: DisplayString, dfunc, data_shape=Rgb888Layout, *, src_loc_at=0 + self, + geom: DisplayString, + dfunc=example_rgb_transform, + data_shape=Rgb888Layout, + *, + src_loc_at=0, ): self.geom = geom self.dfunc = dfunc @@ -118,7 +133,7 @@ class BasicFetcher(wiring.Component): "input": In( stream.Signature(data.ArrayLayout(CoordLayout, geom.dimensions.mux)) ), - "pixstream": Out( + "output": Out( stream.Signature(data.ArrayLayout(data_shape, geom.dimensions.mux)) ), }, @@ -128,20 +143,24 @@ class BasicFetcher(wiring.Component): def elaborate(self, platform: Platform) -> Module: m = Module() - # test mode - pass through, r = x + y, g = x - y, b = {y,x} - - colors = self.pixstream.payload + colors = self.output.payload m.d.comb += [ - self.input.valid.eq(self.pixstream.valid), - self.input.ready.eq(self.pixstream.ready), + self.output.valid.eq(self.input.valid), + self.input.ready.eq(self.output.ready), ] for i in range(self.geom.dimensions.mux): inp = self.input.payload[i] + output = self.dfunc(inp.x, inp.y) m.d.comb += [ - colors[i].red.eq(inp.x + inp.y), - colors[i].green.eq(inp.x - inp.y), - colors[i].blue.eq(inp.x ^ inp.y), + colors[i].red.eq(output["red"]), + colors[i].green.eq(output["green"]), + colors[i].blue.eq(output["blue"]), ] return m + + +def chain_streams(m, streams): + for pair in pairwise(streams): + wiring.connect(m, pair[0].output, pair[1].input) diff --git a/src/groovylight/tests/test_fetcher.py b/src/groovylight/tests/test_fetcher.py index 728d2f3..051e95f 100644 --- a/src/groovylight/tests/test_fetcher.py +++ b/src/groovylight/tests/test_fetcher.py @@ -1,10 +1,9 @@ -from amaranth.lib import wiring, data +from amaranth import Module +from amaranth.lib import wiring from amaranth.sim import Simulator -import random -from random import randrange import pytest -from groovylight.fetcher import AddressConverter, AddressGenerator, BasicFetcher +from groovylight.fetcher import AddressConverter, AddressGenerator, BasicFetcher, chain_streams from groovylight.geom import DisplayString, Coord, DisplayDimensions, DisplayRotation ds_testdata = [ @@ -82,7 +81,7 @@ def test_generator(addr, rot): async def stream_checker(ctx): while ctx.get(dut.done) == 0: - payload = await stream_get(ctx, dut.coordstream) + payload = await stream_get(ctx, dut.output) assert expected.pop() == payload sim.add_testbench(runner) @@ -103,12 +102,12 @@ def test_basic_fetcher(inp, expected): ds = DisplayString( Coord(3, 0), DisplayDimensions(128, 64, mux=1), DisplayRotation.R0 ) - dut = BasicFetcher(ds, None) + dut = BasicFetcher(ds) sim = Simulator(dut) async def test(ctx): ctx.set(dut.input.payload[0], inp) - res = ctx.get(dut.pixstream.payload)[0] + res = ctx.get(dut.output.payload)[0] assert res["red"] == expected["red"] assert res["green"] == expected["green"] assert res["blue"] == expected["blue"] @@ -117,3 +116,34 @@ def test_basic_fetcher(inp, expected): with sim.write_vcd("fetcher.vcd"): sim.run() + + + +def test_stream_e2e(): + ds = DisplayString( + Coord(3, 0), DisplayDimensions(128, 64, mux=1), DisplayRotation.R0 + ) + m = Module() + m.submodules.gen = addrgen = AddressGenerator(ds) + m.submodules.fetch = fetch = BasicFetcher(ds) + + chain_streams(m, [addrgen, fetch]) + + sim = Simulator(m) + sim.add_clock(1e-6) + + async def stim(ctx): + await ctx.tick() + ctx.set(addrgen.start, 1) + await ctx.tick() + ctx.set(addrgen.start, 0) + payload = await stream_get(ctx, fetch.output) + assert payload[0] == {"red": 3, "green": 3, "blue": 3} + + + sim.add_testbench(stim) + + with sim.write_vcd("stream_e2e.vcd"): + sim.run() + +