summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--test_trade.py8
-rw-r--r--trade.py10
-rw-r--r--trade_queue.py15
3 files changed, 19 insertions, 14 deletions
diff --git a/test_trade.py b/test_trade.py
index 691ca1a..dab8a78 100644
--- a/test_trade.py
+++ b/test_trade.py
@@ -17,7 +17,7 @@ class TestTrade(unittest.TestCase):
self.assertEqual(self.trade.amount, 10.0)
self.assertEqual(self.trade.total_cost, 100.0)
self.assertEqual(self.trade.date, "2025-04-14")
- self.assertAlmostEqual(self.trade.price_per_coin, 10.0)
+ self.assertEqual(self.trade.price_per_coin, 10.0)
def test_remove_coins_valid(self):
"""
@@ -25,8 +25,8 @@ class TestTrade(unittest.TestCase):
"""
self.trade.remove_coins(5.0)
self.assertEqual(self.trade.amount, 5.0)
- self.assertAlmostEqual(self.trade.total_cost, 50.0)
- self.assertAlmostEqual(self.trade.price_per_coin, 10.0)
+ self.assertEqual(self.trade.total_cost, 50.0)
+ self.assertEqual(self.trade.price_per_coin, 10.0)
def test_remove_coins_exceeds_amount(self):
"""
@@ -47,7 +47,7 @@ class TestTrade(unittest.TestCase):
"""
Test the __repr__ method for correct string representation.
"""
- expected_repr = "Trade(amount=10.0, price_per_coin=10.00, total_cost=100.00, date=2025-04-14)"
+ expected_repr = "Trade(amount=10, price_per_coin=10.00, total_cost=100.00, date=2025-04-14)"
self.assertEqual(repr(self.trade), expected_repr)
diff --git a/trade.py b/trade.py
index eedc9c4..e46d2b6 100644
--- a/trade.py
+++ b/trade.py
@@ -5,7 +5,7 @@ class Trade:
Represents a cryptocurrency trade, including the amount traded, total cost, and the date of trade.
Provides methods to modify the trade and access various attributes.
"""
- def __init__(self, amount: Decimal, total_cost: Decimal, date: str) -> None:
+ def __init__(self, amount: float|Decimal, total_cost: float|Decimal, date: str) -> None:
"""
Initialize a new Trade instance.
@@ -18,11 +18,11 @@ class Trade:
if amount <= 0 or total_cost <= 0:
raise ValueError("Amount and total cost must be > 0")
- self.__amount: Decimal = amount
- self.__total_cost: Decimal = total_cost
+ self.__amount: Decimal = Decimal(amount)
+ self.__total_cost: Decimal = Decimal(total_cost)
self.__date: str = date
- def remove_coins(self, amount: Decimal) -> None:
+ def remove_coins(self, amount: float|Decimal) -> None:
"""
Reduce the amount of cryptocurrency in the trade by a specified amount.
@@ -35,6 +35,7 @@ class Trade:
if amount > self.__amount:
raise ValueError(f"Can't remove more than {self.__amount}")
+ amount = Decimal(amount)
self.__total_cost -= amount * self.price_per_coin
self.__amount -= amount
@@ -81,6 +82,7 @@ class Trade:
"""
if self.amount == 0:
raise ZeroDivisionError("Price per coin cannot be calculated when the amount is zero")
+
return self.total_cost / self.amount
def __repr__(self) -> str:
diff --git a/trade_queue.py b/trade_queue.py
index 1048b8a..ed9902e 100644
--- a/trade_queue.py
+++ b/trade_queue.py
@@ -1,4 +1,5 @@
from collections import deque
+from decimal import Decimal
from typing import Deque, List
from trade import Trade
@@ -13,14 +14,14 @@ class FIFOQueue:
def __init__(self) -> None:
self.queue: Deque[Trade] = deque()
- def add(self, amount: float, total_cost: float, date: str) -> None:
+ def add(self, amount: float|Decimal, total_cost: float|Decimal, date: str) -> None:
"""
Add a trade to the queue.
"""
trade = Trade(amount, total_cost, date)
self.queue.append(trade)
- def remove(self, amount: float) -> List[Trade]:
+ def remove(self, amount: float|Decimal) -> List[Trade]:
"""
Remove a specified amount from the queue, returning the
trades used to buy.
@@ -28,7 +29,9 @@ class FIFOQueue:
if amount <= 0:
raise ValueError("The amount to remove must be positive.")
- remaining: float = amount
+ amount = Decimal(amount)
+
+ remaining: Decimal = amount
entries: List[Trade] = []
while remaining > 0:
@@ -39,7 +42,7 @@ class FIFOQueue:
if trade.amount > remaining:
trade.remove_coins(remaining)
entries.append(Trade(remaining, trade.total_cost, trade.date))
- remaining = 0
+ remaining = Decimal(0)
else:
remaining -= trade.amount
entries.append(trade)
@@ -47,8 +50,8 @@ class FIFOQueue:
return entries
- def get_remaining_amount(self) -> float:
+ def get_remaining_amount(self) -> Decimal:
"""
Calculate the total remaining amount in the queue.
"""
- return sum(trade.amount for trade in self.queue)
+ return sum((trade.amount for trade in self.queue), Decimal(0))