1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
|
import unittest
from decimal import Decimal
from trade import PriceAdaption, Trade
class TestTrade(unittest.TestCase):
    """
    Unit tests for ``Trade``: construction, coin removal and ``repr``.

    Every test starts from the fixture built in ``setUp``: 10 coins with a
    total cost of 100, which the tests expect to yield a price per coin of 10.
    """

    def setUp(self) -> None:
        """
        Set up a Trade instance for testing.
        """
        # Build the Decimals from strings rather than float literals so the
        # fixture values are exact by construction.
        self.trade = Trade(
            amount=Decimal("10.0"),
            total_cost=Decimal("100.0"),
            timestamp="2025-04-14 00:00:00",
            refid="abcd1234",
        )

    def test_initialization(self) -> None:
        """
        Test that the Trade instance initializes correctly.
        """
        self.assertEqual(self.trade.amount, Decimal("10.0"))
        self.assertEqual(self.trade.total_cost, Decimal("100.0"))
        # Only the date part of the timestamp passed in setUp is expected.
        self.assertEqual(self.trade.date, "2025-04-14")
        self.assertEqual(self.trade.price_per_coin, Decimal("10.0"))

    def test_remove_coins_ppc_valid(self) -> None:
        """
        Test removing a valid amount of coins. (PPC const).

        With ``PriceAdaption.KeepPricePerCoin`` the price per coin must stay
        the same while the amount and the total cost both shrink.
        """
        coin_price_before = self.trade.price_per_coin
        self.trade.remove_coins(4.0, PriceAdaption.KeepPricePerCoin)
        coin_price_after = self.trade.price_per_coin

        self.assertEqual(coin_price_before, coin_price_after)
        self.assertEqual(self.trade.amount, Decimal("6.0"))
        self.assertEqual(self.trade.total_cost, Decimal("60.0"))
        self.assertEqual(self.trade.price_per_coin, Decimal("10.0"))

    def test_remove_coins_tc_valid(self) -> None:
        """
        Test removing a valid amount of coins. (TC const).

        With ``PriceAdaption.KeepTotalCost`` the total cost must stay the
        same, so the price per coin has to change.
        """
        coin_price_before = self.trade.price_per_coin
        self.trade.remove_coins(4.0, PriceAdaption.KeepTotalCost)
        coin_price_after = self.trade.price_per_coin

        self.assertNotEqual(coin_price_before, coin_price_after)
        self.assertEqual(self.trade.amount, Decimal("6.0"))
        self.assertEqual(self.trade.total_cost, Decimal("100.0"))
        self.assertEqual(
            self.trade.price_per_coin,
            Decimal("100.0") / Decimal("6.0"),
        )

    def test_remove_coins_exceeds_amount(self) -> None:
        """
        Test removing more coins than available.
        """
        # The fixture holds 10 coins, so asking for 15 must be rejected.
        with self.assertRaises(ValueError):
            self.trade.remove_coins(15.0)

    def test_price_per_coin_division_by_zero(self) -> None:
        """
        Test the price_per_coin property when the amount is zero.
        """
        self.trade.remove_coins(10.0)  # Reduce amount to zero
        with self.assertRaises(ZeroDivisionError):
            _ = self.trade.price_per_coin

    def test_repr(self) -> None:
        """
        Test the __repr__ method for correct string representation.
        """
        # Split over two literals only to keep the line short; the
        # concatenated text is the exact expected representation.
        expected_repr = (
            "Trade(amount=10.00, price_per_coin=10.00, total_cost=100.00, "
            "date=2025-04-14, refid=abcd1234)"
        )
        self.assertEqual(repr(self.trade), expected_repr)
# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|