summaryrefslogtreecommitdiff
path: root/trade_queue.py
blob: e14ea1c50702b3fdd055df75c8a31616ed61567d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import logging
import bisect
from collections import deque
from copy import deepcopy
from decimal import Decimal
from typing import Callable, Deque, List

from exceptions import TradeNotFound
from trade import Trade

# Module-level logger, named after this module (__name__); FIFOQueue uses it
# to record queue operations and rejected requests.
logger = logging.getLogger(__name__)


class FIFOQueue:
    """
    First-in, first-out queue of cryptocurrency purchase trades.

    Trades are kept sorted by their ``timestamp``; trades sharing a timestamp
    stay in insertion order. :meth:`remove_coins` consumes coins from the
    oldest trades first and returns the trades the coins came from, which can
    be used to calculate profit/loss.
    """

    def __init__(self) -> None:
        """
        Create an empty FIFO queue of cryptocurrency trades.
        """
        self.__queue: Deque[Trade] = deque()
        # Memoised sum of all trade amounts. It is only trusted while
        # _cache_valid is True; every method that mutates the queue clears
        # the flag.
        self._cached_total: Decimal = Decimal(0)
        self._cache_valid: bool = True
        logger.info("FIFOQueue initialized with empty queue.")

    def __len__(self) -> int:
        """
        Get amount of trades in the queue.
        """
        return len(self.__queue)

    def __repr__(self) -> str:
        """
        Get string representation of queue (for debugging).
        """
        return f"FIFOQueue(len={len(self)})"

    def get_copy(self) -> List[Trade]:
        """
        Return a deep copy of the queued trades, oldest first.

        Helper for unit tests to inspect internal state without being able to
        mutate it.
        """
        return list(deepcopy(self.__queue))

    def add(self, amount: Decimal, total_cost: Decimal, timestamp: str) -> None:
        """
        Build a :class:`Trade` from its properties and add it to the queue.

        Args:
            amount: Number of coins in the trade.
            total_cost: Total cost of ``amount`` coins.
            timestamp: Time of the trade; used as the ordering key.
        """
        self.add_trade(Trade(amount, total_cost, timestamp))

    def add_trade(self, trade: Trade) -> None:
        """
        Insert an existing trade, keeping the queue sorted by timestamp.

        ``bisect.insort`` inserts to the right of equal keys, so trades that
        share a timestamp keep the order in which they were added. The trade
        object itself is stored, not a copy.

        Args:
            trade: The trade to insert.
        """
        bisect.insort(self.__queue, trade, key=lambda t: t.timestamp)
        self._cache_valid = False
        logger.info("Added trade: %s.", trade)

    @staticmethod
    def _to_decimal(value: float | Decimal) -> Decimal:
        """
        Convert ``value`` to Decimal without float binary-representation noise.
        """
        if isinstance(value, float):
            # Decimal(0.1) is 0.1000000000000000055511151231257827..., which
            # would exceed a held Decimal("0.1"); str(0.1) is "0.1".
            return Decimal(str(value))
        return Decimal(value)

    def remove_coins(self, amount: float | Decimal) -> List[Trade]:
        """
        Remove ``amount`` coins from the front (oldest trades) of the queue.

        Whole trades are popped while they fit in what is left to remove. If
        the next trade holds more coins than are still needed, it is reduced
        in place (it stays in the queue), and a new :class:`Trade` for the
        consumed coins is returned, priced at that trade's price per coin.

        Args:
            amount: Number of coins to remove. Floats are converted via their
                shortest string form, so ``0.1`` becomes ``Decimal("0.1")``.

        Returns:
            The trades (or trade fragments) the coins were taken from, oldest
            first. This can be used to calculate profit/loss.

        Raises:
            ValueError: If ``amount`` is not finite, not positive, or exceeds
                the total amount held in the queue.
        """
        amount = self._to_decimal(amount)

        # Check finiteness first: ordering comparisons on a Decimal NaN raise
        # InvalidOperation instead of returning False.
        if not amount.is_finite():
            logger.error("Attempted to remove non-finite amount.")
            raise ValueError("The amount to remove must be a finite number.")
        if amount <= 0:
            logger.error("Attempted to remove non-positive amount.")
            raise ValueError("The amount to remove must be positive.")

        if amount > self.get_remaining_amount():
            logger.error("Insufficient assets to process sale of %s.", amount)
            raise ValueError(
                f"Insufficient assets in queue to process sale of {amount}."
            )

        logger.debug("Removing %s coins from the queue.", amount)
        self._cache_valid = False
        logger.debug("Cache invalidated before removing coins.")

        remaining: Decimal = amount
        entries: List[Trade] = []

        # Terminates: the sufficiency check above guarantees the queue holds
        # at least `amount` coins.
        while remaining > 0:
            trade = self.__queue[0]
            logger.debug("Processing trade: %s", trade)

            if trade.amount > remaining:
                # Split the oldest trade. The unit price is captured before
                # trade.remove_coins() in case that call changes it. Use
                # `timestamp`, the same attribute the queue orders by.
                ppc = trade.price_per_coin
                trade.remove_coins(remaining)
                entries.append(Trade(remaining, remaining * ppc, trade.timestamp))
                logger.info("Partial removal from trade: %s coins.", remaining)
                break

            remaining -= trade.amount
            entries.append(self.__queue.popleft())
            logger.info(
                "Removed full trade: %s. Remaining coins to remove: %s",
                trade,
                remaining,
            )

        return entries

    def get_remaining_amount(self) -> Decimal:
        """
        Return the total number of coins held across all queued trades.

        The sum is cached and only recomputed after this queue has been
        mutated through its own methods. Changes made directly to a stored
        Trade object from outside are not detected.
        """
        if not self._cache_valid:
            logger.debug("Cache invalid, recalculating remaining amount.")
            self._cached_total = sum(
                (trade.amount for trade in self.__queue), Decimal(0)
            )
            self._cache_valid = True
            logger.debug("Cache recalculated: %s", self._cached_total)

        logger.debug("Returning cached remaining amount: %s", self._cached_total)
        return self._cached_total

    def remove(self, predicate: Callable[[Trade], bool]) -> Trade:
        """
        Remove the single trade for which ``predicate`` returns True.

        Args:
            predicate: Called once for every queued trade; it must match
                exactly one of them.

        Returns:
            Trade: The removed trade.

        Raises:
            TradeNotFound: If no trade matches the predicate.
            ValueError: If more than one trade matches the predicate.
        """
        matches = [
            index for index, trade in enumerate(self.__queue) if predicate(trade)
        ]

        if not matches:
            logger.error("No matching trade found for removal.")
            raise TradeNotFound("No trade matches the given predicate.")
        if len(matches) > 1:
            logger.error("Multiple matching trades found for removal.")
            raise ValueError(
                "Multiple trades match the given predicate. Please refine your criteria."
            )

        # Delete by position rather than deque.remove(), which would drop the
        # first trade that compares *equal*, not necessarily the one matched.
        index = matches[0]
        trade_to_remove = self.__queue[index]
        del self.__queue[index]
        self._cache_valid = False
        logger.info("Removed trade: %s.", trade_to_remove)
        return trade_to_remove