from decimal import Decimal
import unittest
from datetime import datetime

from trade_queue import FIFOQueue


class TestFIFOQueue(unittest.TestCase):
    """
    Tests for adding trades to a FIFOQueue and removing coins from it.

    Every test starts from the three-trade fixture built in ``setUp``.
    """

    def setUp(self):
        """
        Set up a FIFOQueue instance and some test trades.

        Each trade is added as (amount, total cost, date); the fixture
        holds 60 coins in total, all bought at 10 per coin.
        """
        self.queue = FIFOQueue()
        # Build Decimals from strings rather than floats: same values here,
        # but it avoids float representation surprises and matches the
        # construction style used by the other test class in this file.
        self.queue.add(Decimal("10"), Decimal("100"), "2025-04-14")
        self.queue.add(Decimal("20"), Decimal("200"), "2025-04-15")
        self.queue.add(Decimal("30"), Decimal("300"), "2025-04-16")

    def test_add(self):
        """
        Test adding trades to the queue.
        """
        tq = self.queue.get_copy()
        self.assertEqual(len(tq), 3)  # There should be 3 trades in the queue
        self.assertEqual(tq[0].amount, 10.0)  # Check the first trade's amount
        self.assertEqual(tq[1].date, "2025-04-15")  # Check the second trade's date

    def test_remove_exact_amount(self):
        """
        Test removing an exact amount from the queue.
        """
        trades = self.queue.remove_coins(10.0)
        self.assertEqual(len(trades), 1)  # One trade should be returned
        self.assertEqual(trades[0].amount, 10.0)  # Amount should match the request
        tq = self.queue.get_copy()
        self.assertEqual(len(tq), 2)  # Two trades should remain in the queue

    def test_remove_partial_trade(self):
        """
        Test removing an amount that partially consumes a trade.
        """
        trades = self.queue.remove_coins(5.0)
        self.assertEqual(len(trades), 1)  # One partial trade should be returned
        self.assertEqual(trades[0].amount, 5.0)  # Amount should match the request
        tq = self.queue.get_copy()
        # The first trade stays in the queue with 10 - 5 coins left
        self.assertEqual(tq[0].amount, 5.0)

    def test_remove_multiple_trades(self):
        """
        Test removing an amount that spans multiple trades.
        """
        trades = self.queue.remove_coins(25.0)
        self.assertEqual(len(trades), 2)  # Two trades should be returned
        # The first trade should be fully consumed
        self.assertEqual(trades[0].amount, 10.0)
        # The second trade should be partially consumed
        self.assertEqual(trades[1].amount, 15.0)
        tq = self.queue.get_copy()
        # The second trade is now at the front with 20 - 15 coins left
        self.assertEqual(tq[0].amount, 5.0)

    def test_remove_insufficient_amount(self):
        """
        Test trying to remove more than is available in the queue.
        """
        # Only 60 coins are available, so asking for 100 must fail
        with self.assertRaises(ValueError):
            self.queue.remove_coins(100.0)

    def test_remove_negative_amount(self):
        """
        Test trying to remove a negative amount.
        """
        with self.assertRaises(ValueError):
            self.queue.remove_coins(-5.0)

    def test_get_remaining_amount_initial(self):
        """
        Test the remaining amount in the queue after adding trades.
        """
        # Total of all amounts: 10 + 20 + 30
        self.assertEqual(self.queue.get_remaining_amount(), 60.0)

    def test_get_remaining_amount_after_removal(self):
        """
        Test the remaining amount after removing some assets.
        """
        # 15 spans the whole first trade plus part of the second
        self.queue.remove_coins(15.0)
        # Remaining: 60 - 15
        self.assertEqual(self.queue.get_remaining_amount(), 45.0)

    def test_get_remaining_amount_empty_queue(self):
        """
        Test the remaining amount in an empty queue.
        """
        empty_queue = FIFOQueue()  # New empty queue, independent of the fixture
        self.assertEqual(empty_queue.get_remaining_amount(), 0.0)  # No trades in queue

    def test_get_remaining_amount_partial_removal(self):
        """
        Test the remaining amount after partially consuming a trade.
        """
        self.queue.remove_coins(5.0)  # Remove 5 assets, leaving 5 in the first trade
        self.assertEqual(self.queue.get_remaining_amount(), 55.0)  # Remaining: 60 - 5

    def test_get_remaining_amount_full_removal(self):
        """
        Test the remaining amount after removing all trades.
        """
        self.queue.remove_coins(60.0)  # Remove all assets
        self.assertEqual(self.queue.get_remaining_amount(), 0.0)  # Remaining: 0

    def test_remove_partial_trade_correct_cost(self):
        """
        Test removing a partial trade and ensure the correct cost is calculated.
        """
        trades = self.queue.remove_coins(4.0)  # Remove 4 COIN from the first trade
        self.assertEqual(len(trades), 1)  # Only one trade should be returned

        # Coin-cost needs to stay constant
        self.assertEqual(trades[0].price_per_coin, 10)
        self.assertEqual(trades[0].amount, 4.0)  # Check the removed amount
        # Total cost should be proportional: (100 * 4 / 10)
        self.assertEqual(trades[0].total_cost, 40.0)

        tq = self.queue.get_copy()
        # Price per coin of the remaining part stays constant as well
        self.assertEqual(tq[0].price_per_coin, 10)
        # Remaining amount in the first trade should be updated: 10 - 4
        self.assertEqual(tq[0].amount, 6.0)
        # Remaining total cost shrinks proportionally: 100 - 40
        self.assertEqual(tq[0].total_cost, 60.0)


class TestFIFOQueueRemove(unittest.TestCase):
    """
    Tests for removing a single trade from a FIFOQueue via a predicate.
    """

    @staticmethod
    def _dated(target):
        """
        Build a predicate that accepts trades whose date equals ``target``.
        """

        def matches(trade):
            return trade.date == target

        return matches

    def setUp(self):
        """
        Fill a fresh FIFOQueue with four sample trades.
        """
        self.fifo_queue = FIFOQueue()
        fixture = (
            ("10", "100", "2025-03-17"),
            ("10", "100", "2025-04-17"),
            ("20", "200", "2025-04-18"),
            ("10", "100", "2025-04-18"),  # Duplicate date for testing
        )
        for amount, total_cost, date in fixture:
            self.fifo_queue.add(Decimal(amount), Decimal(total_cost), date)

    def test_remove_successful(self):
        """
        A predicate matching exactly one trade removes and returns it.
        """
        taken = self.fifo_queue.remove(self._dated("2025-04-17"))
        self.assertEqual(taken.date, "2025-04-17")
        self.assertEqual(taken.amount, Decimal("10"))
        self.assertEqual(taken.total_cost, Decimal("100"))
        # Three of the four trades are left
        self.assertEqual(len(self.fifo_queue), 3)

    def test_remove_no_match(self):
        """
        A predicate matching nothing raises ValueError and removes nothing.
        """
        # No trade in the fixture carries this date
        nothing_matches = self._dated("2024-04-17")
        with self.assertRaises(ValueError) as caught:
            self.fifo_queue.remove(nothing_matches)
        message = str(caught.exception)
        self.assertIn("No trade matches the given predicate.", message)
        # The queue still holds all four trades
        self.assertEqual(len(self.fifo_queue), 4)

    def test_remove_multiple_matches(self):
        """
        A predicate matching several trades raises ValueError and removes nothing.
        """
        # The last two fixture trades share this date
        ambiguous = self._dated("2025-04-18")
        with self.assertRaises(ValueError) as caught:
            self.fifo_queue.remove(ambiguous)
        message = str(caught.exception)
        self.assertIn("Multiple trades match the given predicate.", message)
        # The queue still holds all four trades
        self.assertEqual(len(self.fifo_queue), 4)


# Run this module's tests when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()