# SPDX-FileCopyrightText: 2017 Scott Shawcroft, written for Adafruit Industries
# SPDX-FileCopyrightText: Copyright (c) 2024 Cooper Dalrymple
#
# SPDX-License-Identifier: MIT
"""
`pio_i2s`
================================================================================

Bidirectional I2S audio communication using PIO.

* Author(s): Cooper Dalrymple

Implementation Notes
--------------------

**Software and Dependencies:**

* Adafruit CircuitPython firmware for the supported boards (requires version 9.2.1+):
  https://circuitpython.org/downloads

* Adafruit's PIOASM library: https://github.com/adafruit/Adafruit_CircuitPython_PIOASM
"""

# Library version and source repository.

__version__ = "1.1.0"
__repo__ = "https://github.com/relic-se/CircuitPython_PIO_I2S.git"

import array

import adafruit_pioasm
import microcontroller
import rp2pio

try:
    import circuitpython_typing
except ImportError:
    pass


def _get_gpio_index(pin: microcontroller.Pin) -> int:
    """Look up the GPIO number of a pin by searching :mod:`microcontroller.pin`.

    :param pin: The pin to locate.
    :return: The integer parsed from the name of the matching :mod:`microcontroller.pin`
        attribute with its ``"GPIO"`` text removed (e.g. ``GPIO12`` -> ``12``), or `None` if no
        attribute of :mod:`microcontroller.pin` is this exact pin object.
    """
    pin_module = microcontroller.pin
    for attr_name in dir(pin_module):
        # Pins are singletons, so compare by identity rather than equality.
        if getattr(pin_module, attr_name) is not pin:
            continue
        return int(attr_name.replace("GPIO", ""))
    return None


class I2S:
    """Communicate with external audio devices using I2S protocol.

    The object can be used as a context manager, in which case :meth:`deinit` is called
    automatically when the ``with`` block exits.

    :param bit_clock: The bit clock (or serial clock) pin.
    :type bit_clock: :class:`microcontroller.Pin`
    :param word_select: The word select (or left/right clock) pin. Must be the next pin from
        bit_clock sequentially. If not specified, the next pin sequentially from bit_clock will be
        used automatically.
    :type word_select: :class:`microcontroller.Pin`
    :param data_out: The output data pin. If left unspecified, write functionality will be disabled.
    :type data_out: :class:`microcontroller.Pin`, optional
    :param data_in: The input data pin. If left unspecified, read functionality will be disabled.
    :type data_in: :class:`microcontroller.Pin`, optional
    :param channel_count: The number of channels. 1 = mono; 2 = stereo.
    :type channel_count: `int`, optional
    :param sample_rate: The sample rate to be used.
    :type sample_rate: `int`, optional
    :param bits_per_sample: The bits per sample to be used. Must be 8, 16, 24, or 32 bits.
    :type bits_per_sample: `int`, optional
    :param samples_signed: Whether the samples are signed (True) or unsigned (False).
    :type samples_signed: `bool`, optional
    :param buffer_size: The number of samples held by each of the two playback buffers and each of
        the two record buffers. In stereo, left and right samples alternate within a buffer.
    :type buffer_size: `int`, optional
    :param left_justified: True when data bits are aligned with the word select clock. False when
        they are shifted by one to match classic I2S protocol.
    :type left_justified: `bool`, optional
    :param peripheral: Whether the clock signals are generated by this device (False) or are read
        from the output of an external device (True). data_in must be specified if using peripheral
        mode and come before bit_clock sequentially.
    :type peripheral: `bool`, optional
    :raises ValueError: If the pin arrangement or any of the audio parameters are invalid.
    """

    def __init__(  # noqa: PLR0912, PLR0913
        self,
        bit_clock: microcontroller.Pin,
        word_select: microcontroller.Pin = None,
        data_out: microcontroller.Pin = None,
        data_in: microcontroller.Pin = None,
        channel_count: int = 2,
        sample_rate: int = 48000,
        bits_per_sample: int = 16,
        samples_signed: bool = True,
        buffer_size: int = 1024,
        left_justified: bool = False,
        peripheral: bool = False,
    ):
        if word_select and not rp2pio.pins_are_sequential([bit_clock, word_select]):
            raise ValueError("Word select pin must be sequential to bit clock pin")

        if peripheral and not data_in:
            raise ValueError("Data input pin must be specified in peripheral mode")

        if peripheral and not rp2pio.pins_are_sequential([data_in, bit_clock]):
            raise ValueError("Data input pin must come before bit clock pin sequentially")

        if channel_count < 1 or channel_count > 2:
            raise ValueError("Invalid channel count")

        if bits_per_sample % 8 != 0 or bits_per_sample < 8 or bits_per_sample > 32:
            raise ValueError("Invalid bits per sample")

        if buffer_size < 1:
            raise ValueError("Buffer size must be greater than 0")

        self._channel_count = channel_count
        self._sample_rate = sample_rate
        self._bits_per_sample = bits_per_sample
        self._samples_signed = samples_signed
        self._buffer_size = buffer_size

        self._writable = bool(data_out)
        self._readable = bool(data_in)

        # Unused directions (and the right slot in mono) become "nop" so the cycle timing of every
        # bit stays identical regardless of which data pins are in use.
        left_channel_out = "out pins 1" if self._writable else "nop"
        right_channel_out = "out pins 1" if self._writable and channel_count > 1 else "nop"

        left_channel_in = "in pins 1" if self._readable else "nop"
        right_channel_in = "in pins 1" if self._readable and channel_count > 1 else "nop"

        if not peripheral:
            # Controller mode: the two side-set bits drive bit clock (bit 0) and word select
            # (bit 1). Each data bit takes 4 PIO cycles: 2 with the bit clock low ("out" + [1]
            # delay) and 2 with it high ("in" followed by "jmp" or "set").
            pioasm = f"""
.program i2s_controller
.side_set 2
    nop                         side 0b{1 if left_justified else 0}1
    set x {bits_per_sample-2}   side 0b{1 if left_justified else 0}1
left_bit:
    {left_channel_out}          side 0b00 [1]
    {left_channel_in}           side 0b01
    jmp x-- left_bit            side 0b01
    {left_channel_out}          side 0b{0 if left_justified else 1}0 [1]
    {left_channel_in}           side 0b{0 if left_justified else 1}1
    set x {bits_per_sample-2}   side 0b{0 if left_justified else 1}1
right_bit:
    {right_channel_out}         side 0b10 [1]
    {right_channel_in}          side 0b11
    jmp x-- right_bit           side 0b11
    {right_channel_out}         side 0b{1 if left_justified else 0}0 [1]
    {right_channel_in}          side 0b{1 if left_justified else 0}1
"""
        else:
            # Peripheral mode: the program polls the external clocks with absolute "wait gpio"
            # instructions, so the raw GPIO numbers of the clock pins are required.
            bit_clock_gpio = _get_gpio_index(bit_clock)
            if bit_clock_gpio is None:
                raise ValueError("Unable to determine GPIO index of bit clock pin")
            word_select_gpio = _get_gpio_index(word_select) if word_select else bit_clock_gpio + 1
            if not left_justified:
                pioasm = f"""
.program i2s_peripheral
.side_set 2
    wait 1 gpio {word_select_gpio}
    wait 1 gpio {bit_clock_gpio}
    wait 0 gpio {word_select_gpio}
    wait 0 gpio {bit_clock_gpio}
    set x {bits_per_sample-2}
    wait 1 gpio {bit_clock_gpio}
left_bit:
    wait 0 gpio {bit_clock_gpio}
    {left_channel_out}
    wait 1 gpio {bit_clock_gpio}
    {left_channel_in}
    jmp x-- left_bit
    wait 1 gpio {word_select_gpio}
    wait 0 gpio {bit_clock_gpio}
    {left_channel_out}
    wait 1 gpio {bit_clock_gpio}
    {left_channel_in}
    set x {bits_per_sample-2}
right_bit:
    wait 0 gpio {bit_clock_gpio}
    {right_channel_out}
    wait 1 gpio {bit_clock_gpio}
    {right_channel_in}
    jmp x-- right_bit
    wait 0 gpio {word_select_gpio}
    wait 0 gpio {bit_clock_gpio}
    {right_channel_out}
    wait 1 gpio {bit_clock_gpio}
    {right_channel_in}
"""
            else:
                pioasm = f"""
.program i2s_peripheral_left_justified
.side_set 2
    wait 1 gpio {word_select_gpio}
    wait 1 gpio {bit_clock_gpio}
    set x {bits_per_sample-1}
    wait 0 gpio {word_select_gpio}
left_bit:
    wait 0 gpio {bit_clock_gpio}
    {left_channel_out}
    wait 1 gpio {bit_clock_gpio}
    {left_channel_in}
    jmp x-- left_bit
    set x {bits_per_sample-1}
    wait 1 gpio {word_select_gpio}
right_bit:
    wait 0 gpio {bit_clock_gpio}
    {right_channel_out}
    wait 1 gpio {bit_clock_gpio}
    {right_channel_in}
    jmp x-- right_bit
"""

        self._pio = rp2pio.StateMachine(
            program=adafruit_pioasm.assemble(pioasm),
            # Wrap past the one-time start-up instructions: the initial "nop" in controller mode,
            # and the initial synchronization waits (up to the first "set x") in peripheral mode.
            wrap_target=1 if not peripheral else (4 if not left_justified else 2),
            # 2 slots (left/right) per sample period. The controller program spends 4 cycles per
            # bit; peripheral mode presumably runs 16x the bit rate so the "wait" instructions
            # follow the external clocks closely.
            frequency=sample_rate * bits_per_sample * 2 * (4 if not peripheral else 16),
            first_out_pin=data_out,
            out_pin_count=1,
            first_in_pin=data_in,
            # In peripheral mode data_in, bit_clock and word_select are sequential; presumably
            # all three are claimed here as inputs.
            in_pin_count=1 if not peripheral else 3,
            first_sideset_pin=bit_clock if not peripheral else None,
            sideset_pin_count=2 if not peripheral else 1,
            auto_pull=True,
            pull_threshold=bits_per_sample,
            out_shift_right=False,
            auto_push=True,
            push_threshold=bits_per_sample,
            in_shift_right=False,
        )

        # Begin double-buffered background read/write operations

        # NOTE(review): 24-bit samples are stored in 32-bit ("l"/"L") buffers. With
        # out_shift_right=False the state machine shifts each word out from its top bit, so
        # written 24-bit samples presumably need to be left-aligned in the 32-bit word, while read
        # samples presumably land in the low 24 bits without sign extension. Confirm on hardware.
        self._buffer_format = (
            "b" if bits_per_sample == 8 else ("h" if bits_per_sample == 16 else "l")
        )
        if not samples_signed:
            self._buffer_format = self._buffer_format.upper()

        # Mid-scale value: 0 for signed samples, half of the full range for unsigned samples.
        self._silence = 0 if samples_signed else 2 ** (bits_per_sample - 1)

        if self._writable:
            self._buffer_out = [
                array.array(
                    self._buffer_format,
                    [self._silence] * buffer_size,
                )
                for _ in range(2)
            ]  # double-buffered
            self._pio.background_write(
                loop=self._buffer_out[0],
                loop2=self._buffer_out[1],
            )
            # Buffer 0 is passed as ``loop`` and is played first, so buffer 1 is the first one
            # that can be filled without overwriting samples currently being played. After that,
            # _get_write_index() follows the buffers reported by rp2pio's last_write.
            self._write_index = 1
            self._last_write_index = -1

        if self._readable:
            self._buffer_in = [
                array.array(self._buffer_format, [self._silence] * buffer_size) for _ in range(2)
            ]  # double-buffered
            self._pio.background_read(
                loop=self._buffer_in[0],
                loop2=self._buffer_in[1],
            )

    def deinit(self) -> None:
        """Stop I2S communication and de-initialize resources used by this object. Calling this
        method more than once has no further effect.
        """
        if not hasattr(self, "_pio"):
            return  # already deinitialized

        self._pio.stop()
        self._pio.deinit()
        del self._pio

        if hasattr(self, "_buffer_out"):
            del self._buffer_out

        if hasattr(self, "_buffer_in"):
            del self._buffer_in

    def __enter__(self) -> "I2S":
        """No-op used by the context manager protocol. Returns this object."""
        return self

    def __exit__(self, exception_type, exception_value, traceback) -> None:
        """Automatically deinitialize the hardware when exiting a context."""
        self.deinit()

    @property
    def channel_count(self) -> int:
        """The number of channels used by the I2S bus. 1 for mono, 2 for stereo. This property is
        read-only.
        """
        return self._channel_count

    @property
    def sample_rate(self) -> int:
        """The rate of the I2S bus in samples per second. This property is read-only."""
        return self._sample_rate

    @property
    def bits_per_sample(self) -> int:
        """The number of bits per sample. This property is read-only."""
        return self._bits_per_sample

    @property
    def samples_signed(self) -> bool:
        """Whether or not the samples are signed (True) or unsigned (False) integers. This property
        is read-only.
        """
        return self._samples_signed

    @property
    def buffer_size(self) -> int:
        """The number of samples per buffer. This property is read-only."""
        return self._buffer_size

    @property
    def buffer_format(self) -> str:
        """The format code of the :class:`array.array` buffers. For more information, refer to the
        original CPython documentation:
        `array <https://docs.python.org/3/library/array.html#module-array>`_. This property is
        read-only.
        """
        return self._buffer_format

    def _get_write_index(self) -> int:
        """Return the index of the output buffer that is free to be filled.

        The cached index is updated whenever rp2pio reports (via ``last_write``) which of the two
        output buffers was most recently handed back. Returns `None` if output is disabled.
        """
        if not self._writable:
            return None
        last_write = self._pio.last_write
        for index, buffer in enumerate(self._buffer_out):
            if last_write is buffer:
                self._write_index = index
                break
        return self._write_index

    def _set_write_buffer(
        self, data: circuitpython_typing.ReadableBuffer, double: bool = False
    ) -> None:
        """Copy ``data`` into the free output buffer (and the following one too if ``double``).

        At most :attr:`buffer_size` samples are copied; any remaining space is filled with
        silence.
        """
        if not self._writable:
            return
        idx = self._get_write_index()
        count = min(len(data), self._buffer_size)
        for _ in range(2 if double else 1):
            buffer = self._buffer_out[idx]
            for j in range(count):
                buffer[j] = data[j]
            for j in range(count, self._buffer_size):
                buffer[j] = self._silence
            self._last_write_index = idx
            idx = (idx + 1) % 2

    @property
    def write_ready(self) -> bool:
        """Whether or not the I2S bus has a buffer that is ready to be written to. This property is
        read-only.
        """
        return self._writable and self._get_write_index() != self._last_write_index

    def write(
        self, data: circuitpython_typing.ReadableBuffer, loop: bool = False, block: bool = True
    ) -> bool:
        """Write an array-like set of audio samples to the output buffer up to the maximum
        :attr:`buffer_size`. Shorter data is padded with silence.

        :param data: The array of sample data.
        :type data: :class:`circuitpython_typing.ReadableBuffer`
        :param loop: Whether or not to loop the sample data by copying it to both output buffers.
        :type loop: `bool`, optional
        :param block: Whether or not to wait until the I2S bus is ready to be written to.
        :type block: `bool`, optional
        :return: Whether or not the output buffer was successfully written to. Always False if no
            data output pin was specified or ``data`` is empty.
        """
        if not self._writable or not data:
            return False
        if block:
            for _ in range(2 if loop else 1):
                while not self.write_ready:
                    pass
                self._set_write_buffer(data)
        elif loop:
            self._set_write_buffer(data, True)
        else:
            if not self.write_ready:
                return False
            self._set_write_buffer(data)
        return True

    def play(self, source: circuitpython_typing.ReadableBuffer, source_length: int = None) -> bool:
        """Play samples from the source data through the output of the I2S bus, one
        :attr:`buffer_size` chunk at a time. This is blocking.

        :param source: The buffer of samples to play.
        :type source: :class:`circuitpython_typing.ReadableBuffer`
        :param source_length: The number of samples to play from the source buffer. If not
            provided, the full length of the source buffer will be played.
        :type source_length: `int`, optional
        :return: Whether or not playback was performed (False if no data output pin was
            specified).
        """
        if not self._writable:
            return False
        if source_length is None:
            source_length = len(source)
        index = 0
        while index < source_length:
            self.write(source[index : index + min(source_length - index, self._buffer_size)])
            index += self._buffer_size
        return True

    def read(self, block: bool = True) -> array.array:
        """Read the input data from the I2S bus as an array of audio samples.

        :param block: Whether or not to wait until data from the I2S bus can be read from.
        :type block: `bool`, optional
        :return: An :class:`array.array` object with :attr:`buffer_size` elements. If the
            :attr:`channel_count` is stereo (2), the left and right channels will alternate between
            even and odd indexes. When not blocking, the returned buffer may be empty if no new
            data is available. Returns `None` if no data input pin was specified.
        """
        if not self._readable:
            return None
        if block:
            while not (data := self._pio.last_read):
                pass
            return data
        return self._pio.last_read

    def record(
        self, destination: circuitpython_typing.WriteableBuffer, destination_length: int = None
    ) -> bool:
        """Records samples from the I2S bus to the destination. This is blocking.

        :param destination: The destination buffer to write the samples from the I2S bus to.
        :type destination: :class:`circuitpython_typing.WriteableBuffer`
        :param destination_length: The number of samples to write to the destination buffer. If not
            provided, the full size of the destination buffer will be written to.
        :type destination_length: `int`, optional
        :return: Whether or not the recording operation was successful.
        """
        if not self._readable:
            return False
        if destination_length is None:
            destination_length = len(destination)
        index = 0
        while index < destination_length:
            buffer = self.read()
            if not buffer:
                return False
            for i in range(min(destination_length - index, self._buffer_size)):
                destination[index + i] = buffer[i]
            index += self._buffer_size
        return True
