summaryrefslogtreecommitdiff
path: root/test_trade_queue.py
blob: 9ce04101cc61a0635085f753dc71a603893f104c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
from decimal import Decimal
import unittest
from datetime import datetime

from exceptions import TradeNotFound
from trade import Trade
from trade_queue import FIFOQueue


class TestFIFOQueue(unittest.TestCase):
    """
    Tests for FIFOQueue.add, remove_coins and get_remaining_amount.

    Each test starts from the fixture built in setUp: three trades of
    10, 20 and 30 coins (60 in total) costing 100, 200 and 300.
    """

    def setUp(self):
        """
        Set up a FIFOQueue instance and some test trades.
        """
        self.queue = FIFOQueue()
        # Decimals are built from strings, not floats: Decimal(<float>) copies
        # the float's binary expansion, which is only exact for values that
        # happen to be representable (such as 10.0).
        self.queue.add(Decimal("10"), Decimal("100"), "2025-04-14")
        self.queue.add(Decimal("20"), Decimal("200"), "2025-04-15")
        self.queue.add(Decimal("30"), Decimal("300"), "2025-04-16")

    def test_add(self):
        """
        Test adding trades to the queue.
        """
        tq = self.queue.get_copy()
        self.assertEqual(len(tq), 3)  # There should be 3 trades in the queue
        self.assertEqual(tq[0].amount, 10.0)  # Check the first trade's amount
        self.assertEqual(tq[1].date, "2025-04-15")  # Check the second trade's date

    def test_remove_exact_amount(self):
        """
        Test removing an amount equal to the whole first trade.
        """
        trades = self.queue.remove_coins(10.0)
        self.assertEqual(len(trades), 1)  # One trade should be returned
        self.assertEqual(trades[0].amount, 10.0)  # Amount should match the request
        tq = self.queue.get_copy()
        self.assertEqual(len(tq), 2)  # Two trades should remain in the queue

    def test_remove_partial_trade(self):
        """
        Test removing an amount that partially consumes a trade.
        """
        trades = self.queue.remove_coins(5.0)
        self.assertEqual(len(trades), 1)  # One partial trade should be returned
        self.assertEqual(trades[0].amount, 5.0)  # Amount should match the request
        tq = self.queue.get_copy()
        # The first trade stays in the queue with 10 - 5 coins left
        self.assertEqual(tq[0].amount, 5.0)

    def test_remove_multiple_trades(self):
        """
        Test removing an amount that spans multiple trades.
        """
        trades = self.queue.remove_coins(25.0)
        self.assertEqual(len(trades), 2)  # Two trades should be returned
        # The first trade should be fully consumed
        self.assertEqual(trades[0].amount, 10.0)
        # The second trade should be partially consumed (25 - 10)
        self.assertEqual(trades[1].amount, 15.0)
        tq = self.queue.get_copy()
        # The second trade is now at the head with 20 - 15 coins left
        self.assertEqual(tq[0].amount, 5.0)

    def test_remove_insufficient_amount(self):
        """
        Test trying to remove more than is available in the queue.
        """
        # Only 60 coins are queued, so asking for 100 must fail
        with self.assertRaises(ValueError):
            self.queue.remove_coins(100.0)

    def test_remove_negative_amount(self):
        """
        Test trying to remove a negative amount.
        """
        with self.assertRaises(ValueError):
            self.queue.remove_coins(-5.0)

    def test_get_remaining_amount_initial(self):
        """
        Test the remaining amount in the queue after adding trades.
        """
        # Total of all amounts: 10 + 20 + 30
        self.assertEqual(self.queue.get_remaining_amount(), 60.0)

    def test_get_remaining_amount_after_removal(self):
        """
        Test the remaining amount after removing some assets.
        """
        self.queue.remove_coins(15.0)  # Remove 15 assets
        # Remaining: 60 - 15
        self.assertEqual(self.queue.get_remaining_amount(), 45.0)

    def test_get_remaining_amount_empty_queue(self):
        """
        Test the remaining amount in an empty queue.
        """
        empty_queue = FIFOQueue()  # New empty queue
        self.assertEqual(empty_queue.get_remaining_amount(), 0.0)  # No trades in queue

    def test_get_remaining_amount_partial_removal(self):
        """
        Test the remaining amount after partially consuming a trade.
        """
        self.queue.remove_coins(5.0)  # Remove 5 assets, leaving 5 in the first trade
        self.assertEqual(self.queue.get_remaining_amount(), 55.0)  # Remaining: 60 - 5

    def test_get_remaining_amount_full_removal(self):
        """
        Test the remaining amount after removing all trades.
        """
        self.queue.remove_coins(60.0)  # Remove all assets
        self.assertEqual(self.queue.get_remaining_amount(), 0.0)  # Remaining: 0

    def test_remove_partial_trade_correct_cost(self):
        """
        Test removing a partial trade and ensure the correct cost is calculated.
        """
        trades = self.queue.remove_coins(4.0)  # Remove 4 COIN from the first trade
        self.assertEqual(len(trades), 1)  # Only one trade should be returned

        # Coin-cost needs to stay constant
        self.assertEqual(trades[0].price_per_coin, 10)
        self.assertEqual(trades[0].amount, 4.0)  # Check the removed amount
        # Total cost should be proportional: 100 * 4 / 10
        self.assertEqual(trades[0].total_cost, 40.0)

        tq = self.queue.get_copy()
        # Price per coin of the trade left in the queue is unchanged
        self.assertEqual(tq[0].price_per_coin, 10)
        # Remaining amount in the first trade should be updated: 10 - 4
        self.assertEqual(tq[0].amount, 6.0)
        # Remaining total cost shrinks proportionally: 100 * 6 / 10
        self.assertEqual(tq[0].total_cost, 60.0)


class TestFIFOQueueRemove(unittest.TestCase):
    """
    Tests for date ordering on insert and for predicate-based removal.
    """

    @staticmethod
    def _dated(day):
        """
        Build a predicate that is true for trades whose ``date`` equals *day*.
        """
        return lambda t: t.date == day

    def setUp(self):
        """
        Build a queue of four trades, two of which fall on 2025-04-18.
        """
        queue = FIFOQueue()
        queue.add(Decimal("10"), Decimal("100"), "2025-03-17 12:00:00")
        queue.add(Decimal("10"), Decimal("100"), "2025-04-17 12:15:00")
        # The next two trades share a day; the multiple-match test relies on it
        queue.add(Decimal("20"), Decimal("200"), "2025-04-18 13:20:00")
        queue.add(Decimal("10"), Decimal("100"), "2025-04-18 14:50:00")
        self.fifo_queue = queue

    def test_add_in_order(self):
        """
        A trade dated between existing ones must leave the queue sorted by timestamp.
        """
        self.fifo_queue.add(Decimal("10"), Decimal("100"), "2025-04-01 11:11:00")
        ordered = list(self.fifo_queue.get_copy())
        # Compare every trade with its successor
        is_sorted = all(
            earlier.timestamp <= later.timestamp
            for earlier, later in zip(ordered, ordered[1:])
        )
        self.assertTrue(is_sorted)

    def test_remove_successful(self):
        """
        A predicate matching exactly one trade removes and returns that trade.
        """
        taken = self.fifo_queue.remove(self._dated("2025-04-17"))
        self.assertEqual(taken.date, "2025-04-17")
        self.assertEqual(taken.amount, Decimal("10"))
        self.assertEqual(taken.total_cost, Decimal("100"))
        # Four trades minus the removed one
        self.assertEqual(len(self.fifo_queue), 3)

    def test_remove_no_match(self):
        """
        A predicate matching nothing raises TradeNotFound and leaves the queue intact.
        """
        # No trade in the fixture is dated in 2024
        with self.assertRaises(TradeNotFound) as raised:
            self.fifo_queue.remove(self._dated("2024-04-17"))
        message = str(raised.exception)
        self.assertIn("No trade matches the given predicate.", message)
        # All four trades are still queued
        self.assertEqual(len(self.fifo_queue), 4)

    def test_remove_multiple_matches(self):
        """
        An ambiguous predicate raises ValueError and leaves the queue intact.
        """
        # Both 2025-04-18 trades satisfy the predicate
        with self.assertRaises(ValueError) as raised:
            self.fifo_queue.remove(self._dated("2025-04-18"))
        message = str(raised.exception)
        self.assertIn("Multiple trades match the given predicate.", message)
        # All four trades are still queued
        self.assertEqual(len(self.fifo_queue), 4)


class TestFIFOQueueMatchTrades(unittest.TestCase):
    """Tests for FIFOQueue.match_trades, pairing buy trades with sell trades."""

    def setUp(self):
        """Give every test case its own empty FIFOQueue."""
        self.fifo_queue = FIFOQueue()

    def _queue_trade(self, amount, cost, when):
        """
        Wrap the arguments in a Trade and hand it to the queue.

        The tests below pass a negative amount for a sell trade; ``cost`` is
        Trade's second positional argument (presumably the total cost).
        """
        trade = Trade(Decimal(amount), Decimal(cost), when)
        self.fifo_queue.add_trade(trade)

    def test_full_match_single_buy_trade(self):
        """A sell of 5 is matched completely by a single buy of 5."""
        self._queue_trade(5, 50, "2025-04-19 10:00:00")
        self._queue_trade(-5, -75, "2025-04-19 12:00:00")

        pairs = self.fifo_queue.match_trades()

        self.assertEqual(len(pairs), 1)
        only_pair = pairs[0]
        # Buy side of the pair
        self.assertEqual(only_pair[0].amount, Decimal(5))
        # Sell side of the pair
        self.assertEqual(only_pair[1].amount, Decimal(5))

    def test_error_single_buy_trade_invalid_order(self):
        """Matching raises ValueError when the sell is dated before the buy."""
        self._queue_trade(5, 50, "2025-04-19 10:00:00")
        self._queue_trade(-5, -75, "2025-04-18 12:00:00")

        with self.assertRaises(ValueError):
            self.fifo_queue.match_trades()

    def test_partial_match_multiple_buy_trades(self):
        """A sell of 5 consumes a buy of 3 entirely and 2 of a following buy of 4."""
        self._queue_trade(3, 30, "2025-04-19 10:00:00")
        self._queue_trade(4, 48, "2025-04-19 11:00:00")
        self._queue_trade(-5, -65, "2025-04-19 12:00:00")

        pairs = self.fifo_queue.match_trades()

        self.assertEqual(len(pairs), 2)

        first_pair = pairs[0]
        # The first buy is used up completely ...
        self.assertEqual(first_pair[0].amount, Decimal(3))
        # ... against an equally sized slice of the sell
        self.assertEqual(first_pair[1].amount, Decimal(3))

        second_pair = pairs[1]
        # Only part of the second buy is needed ...
        self.assertEqual(second_pair[0].amount, Decimal(2))
        # ... to cover the rest of the sell
        self.assertEqual(second_pair[1].amount, Decimal(2))

        # 4 - 2 coins of the second buy stay in the queue
        leftover = self.fifo_queue.get_remaining_amount()
        self.assertEqual(leftover, Decimal(2))

    def test_sell_trade_exceeds_buy_trades(self):
        """Matching raises ValueError when a sell of 5 meets only 3 bought coins."""
        self._queue_trade(3, 30, "2025-04-19 10:00:00")
        self._queue_trade(-5, -75, "2025-04-19 12:00:00")

        with self.assertRaises(ValueError):
            self.fifo_queue.match_trades()

    def test_no_sell_trade(self):
        """With only a buy trade queued, matching yields no pairs."""
        self._queue_trade(5, 50, "2025-04-19 10:00:00")

        pairs = self.fifo_queue.match_trades()
        # Nothing was sold, so nothing can be paired
        self.assertEqual(len(pairs), 0)

    def test_no_buy_trades(self):
        """Matching raises ValueError when a sell has no buy trade before it."""
        self._queue_trade(-5, -75, "2025-04-19 12:00:00")

        with self.assertRaises(ValueError):
            self.fifo_queue.match_trades()




# Run every test case in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()