import logging
import bisect
from collections import deque
from copy import deepcopy
from decimal import Decimal
from typing import Callable, Deque, List, Tuple

from exceptions import TradeNotFound
from trade import Trade

# Module-level logger, named after this module, shared by all FIFOQueue instances.
logger: logging.Logger = logging.getLogger(__name__)


class FIFOQueue:
    """
    Crypto trading FIFO queue.

    Holds buy trades (positive ``amount``) and sell trades (negative
    ``amount``) ordered by timestamp, and matches sells against the oldest
    eligible buys first (first in, first out).
    """

    def __init__(self) -> None:
        """
        Create new FIFO queue holding cryptocurrency trades.
        """
        # Trades kept sorted by timestamp (see add_trade).
        self.__queue: Deque[Trade] = deque()
        # Memoized result of get_remaining_amount(); every method that mutates
        # the queue (or a trade inside it) must reset _cache_valid to False.
        self._cached_total: Decimal = Decimal(0)
        self._cache_valid: bool = True
        logger.info("FIFOQueue initialized with empty queue.")

    def __len__(self) -> int:
        """
        Get amount of trades in the queue.
        """
        return len(self.__queue)

    def __repr__(self) -> str:
        """
        Get string representation of queue (for debugging).
        """
        return f"FIFOQueue(len={len(self)})"

    def get_copy(self) -> List[Trade]:
        """
        Helper for unit tests, to check internal state.

        Returns a deep copy, so mutating the result does not affect the queue.
        """
        return list(deepcopy(self.__queue))

    def add(self, amount: Decimal, total_cost: Decimal, timestamp: str) -> None:
        """
        Add a trade to the queue by specifying properties.

        Equivalent to ``add_trade(Trade(amount, total_cost, timestamp))``.
        """
        self.add_trade(Trade(amount, total_cost, timestamp))

    def add_trade(self, trade: Trade) -> None:
        """
        Add a trade to the queue, keeping the queue ordered by timestamp.

        A trade whose timestamp equals that of queued trades is inserted after
        them, so insertion order is preserved among simultaneous trades.
        """
        bisect.insort(self.__queue, trade, key=lambda t: t.timestamp)
        self._cache_valid = False
        logger.info(f"Added trade: {trade}.")

    def match_trades(self) -> List[Tuple[Trade, Trade]]:
        """
        Match sell trades with buy trades in a FIFO manner, ensuring matched amounts are equal.
        If multiple buy trades are needed for one sell trade, split the sell trade accordingly.

        Sell trades are processed in queue (timestamp) order. Each matched sell
        trade, and the buy volume it consumed, is removed from the queue.

        Returns:
            List[Tuple[Trade, Trade]]: List of tuples with (buy_trade, sell_trade) of equal amounts.

        Raises:
            ValueError: If there is not enough buy volume to cover a sell
                trade. Sells matched before the failure remain removed.
        """
        matched_pairs: List[Tuple[Trade, Trade]] = []

        while (sell_trade := self.__first_sell_trade()) is not None:
            matched_pairs.extend(self.__match_trades_impl(sell_trade))
            self.__queue.remove(sell_trade)
            self._cache_valid = False

        logger.info("No more sell trades available for matching.")
        return matched_pairs

    def __first_sell_trade(self) -> Trade | None:
        """Return the earliest sell trade (negative amount), or None if none is queued."""
        return next((t for t in self.__queue if t.amount < 0), None)

    def __match_trades_impl(self, sell_trade: Trade) -> List[Tuple[Trade, Trade]]:
        """
        Match a single sell trade against the oldest buy trades.

        Consumes buy volume equal to ``-sell_trade.amount`` via remove_coins()
        and pairs each consumed buy with a new Trade of the same amount, priced
        at the sell's price per coin and carrying the sell's timestamp. The
        sell trade itself is NOT removed from the queue here.
        """
        matched_pairs: List[Tuple[Trade, Trade]] = []

        # Sell trades carry a negative amount; work with the positive volume.
        remaining_sell_amount = -sell_trade.amount

        # NOTE(review): no before_ts is passed, so buys timestamped after this
        # sell are also eligible. Pass before_ts=sell_trade.timestamp if a
        # sell must only be covered by earlier buys.
        try:
            buy_trades = self.remove_coins(remaining_sell_amount)
        except ValueError as e:
            logger.error(f"Failed to match trade: {e}")
            raise

        # Process the buy trades and split the sell trade if necessary
        for buy_trade in buy_trades:
            # remove_coins never returns more volume than requested.
            assert buy_trade.amount <= remaining_sell_amount

            matched_pairs.append(
                (
                    buy_trade,
                    Trade(
                        buy_trade.amount,
                        buy_trade.amount * sell_trade.price_per_coin,
                        sell_trade.timestamp,
                    ),
                )
            )
            remaining_sell_amount -= buy_trade.amount

        logger.info(f"Matched sell trade {sell_trade} with buy trades: {matched_pairs}")
        return matched_pairs

    def remove_coins(
        self, amount: float | Decimal, before_ts: str | None = None
    ) -> List[Trade]:
        """
        Remove a specified amount of coins from the queue, returning the
        trades used to buy. This can be used to calculate profit/loss.

        Only buy trades (positive amount) are consumed, oldest first. Sell
        trades in the queue are skipped and left untouched.

        Args:
            amount: Number of coins to remove; must be positive.
            before_ts: If given (truthy), only buys whose timestamp is strictly
                earlier than it are eligible.

        Returns:
            The consumed buys in FIFO order. A fully consumed buy is returned
            as-is and removed from the queue. A partially consumed buy is
            reduced in place and reported as a new Trade for the removed
            portion, at the same price per coin.

        Raises:
            ValueError: If ``amount`` is not positive or exceeds the eligible
                buy volume. The queue is not modified in that case.
        """
        if amount <= 0:
            logger.error("Attempted to remove non-positive amount.")
            raise ValueError("The amount to remove must be positive.")

        # Convert floats via str() so e.g. 0.1 becomes Decimal("0.1") rather
        # than the float's exact binary expansion.
        amount = Decimal(str(amount)) if isinstance(amount, float) else Decimal(amount)

        # Same eligibility rule for the availability check and for consumption,
        # so the check can never admit volume the loop below cannot reach.
        eligible = [
            trade
            for trade in self.__queue
            if trade.amount > 0 and (not before_ts or trade.timestamp < before_ts)
        ]
        available = sum((trade.amount for trade in eligible), Decimal(0))
        if amount > available:
            logger.error(f"Insufficient assets to process sale of {amount}.")
            raise ValueError(
                f"Insufficient assets in queue to process sale of {amount}."
            )

        logger.debug(f"Removing {amount:.2f} coins from the queue.")
        logger.info("Cache invalidated before removing coins.")
        self._cache_valid = False

        remaining: Decimal = amount
        entries: List[Trade] = []

        # `eligible` preserves queue (timestamp) order, which yields FIFO.
        for trade in eligible:
            if remaining <= 0:
                break
            logger.debug(f"Processing trade: {trade}")

            if trade.amount > remaining:
                # Split: shrink the queued trade in place and report the removed
                # portion. Price is read before the trade is mutated.
                ppc = trade.price_per_coin
                trade.remove_coins(remaining)
                entries.append(Trade(remaining, remaining * ppc, trade.timestamp))
                logger.info(f"Partial removal from trade: {remaining:.2f} coins.")
                break

            remaining -= trade.amount
            entries.append(trade)
            self.__queue.remove(trade)
            logger.info(
                f"Removed full trade: {trade}. Remaining coins to remove: {remaining}"
            )

        return entries

    def get_remaining_amount(self) -> Decimal:
        """
        Calculate the total remaining amount in the queue.

        This is the sum of all trade amounts (buys positive, sells negative).
        The result is cached until the queue is next modified.
        """
        if not self._cache_valid:
            logger.debug("Cache invalid, recalculating remaining amount.")
            self._cached_total = sum(
                (trade.amount for trade in self.__queue), Decimal(0)
            )
            self._cache_valid = True
            logger.info(f"Cache recalculated: {self._cached_total:.2f}")

        logger.debug(f"Returning cached remaining amount: {self._cached_total:.2f}")
        return self._cached_total

    def remove(self, predicate: Callable[[Trade], bool]) -> Trade:
        """
        Remove a trade from the queue based on a given predicate.

        Args:
            predicate (Callable[[Trade], bool]): A function that returns True for the trade to remove.

        Returns:
            Trade: The removed trade.

        Raises:
            TradeNotFound: If no trade matches the predicate.
            ValueError: If multiple trades match the predicate.
        """
        matching_trades = [trade for trade in self.__queue if predicate(trade)]

        if not matching_trades:
            logger.error("No matching trade found for removal.")
            raise TradeNotFound("No trade matches the given predicate.")
        if len(matching_trades) > 1:
            logger.error("Multiple matching trades found for removal.")
            raise ValueError(
                "Multiple trades match the given predicate. Please refine your criteria."
            )

        trade_to_remove = matching_trades[0]
        self.__queue.remove(trade_to_remove)
        self._cache_valid = False
        logger.info(f"Removed trade: {trade_to_remove}.")
        return trade_to_remove