From 5ed67c88b608a3ee10381635fcff799eebbfc201 Mon Sep 17 00:00:00 2001
From: uvok
Date: Thu, 17 Apr 2025 11:00:30 +0200
Subject: Add removal methods

---
 test_trade.py | 26 ++++++++++++++++++++------
 trade.py      | 22 +++++++++++++++++-----
 2 files changed, 37 insertions(+), 11 deletions(-)

diff --git a/test_trade.py b/test_trade.py
index e6bb445..c3d7764 100644
--- a/test_trade.py
+++ b/test_trade.py
@@ -1,7 +1,7 @@
 import unittest
 from decimal import Decimal
 
-from trade import Trade
+from trade import PriceAdaption, Trade
 
 class TestTrade(unittest.TestCase):
     def setUp(self) -> None:
@@ -19,20 +19,34 @@ class TestTrade(unittest.TestCase):
         self.assertEqual(self.trade.date, "2025-04-14")
         self.assertEqual(self.trade.price_per_coin, 10.0)
 
-    def test_remove_coins_valid(self):
+    def test_remove_coins_ppc_valid(self):
         """
-        Test removing a valid amount of coins.
+        Test removing a valid amount of coins. (PPC const).
         """
         coin_price_before = self.trade.price_per_coin
-        self.trade.remove_coins(5.0)
+        self.trade.remove_coins(4.0, PriceAdaption.KeepPricePerCoin)
         coin_price_after = self.trade.price_per_coin
 
         self.assertEqual(coin_price_before, coin_price_after)
 
-        self.assertEqual(self.trade.amount, 5.0)
-        self.assertEqual(self.trade.total_cost, 50.0)
+        self.assertEqual(self.trade.amount, 6.0)
+        self.assertEqual(self.trade.total_cost, 60.0)
         self.assertEqual(self.trade.price_per_coin, 10.0)
 
+    def test_remove_coins_tc_valid(self):
+        """
+        Test removing a valid amount of coins. (TC const).
+        """
+        coin_price_before = self.trade.price_per_coin
+        self.trade.remove_coins(4.0, PriceAdaption.KeepTotalCost)
+        coin_price_after = self.trade.price_per_coin
+
+        self.assertNotEqual(coin_price_before, coin_price_after)
+
+        self.assertEqual(self.trade.amount, 6.0)
+        self.assertEqual(self.trade.total_cost, 100.0)
+        self.assertEqual(self.trade.price_per_coin, Decimal("100.0")/Decimal("6.0"))
+
     def test_remove_coins_exceeds_amount(self):
         """
         Test removing more coins than available.
diff --git a/trade.py b/trade.py
index 1a39b2d..ed646d1 100644
--- a/trade.py
+++ b/trade.py
@@ -1,4 +1,10 @@
 from decimal import Decimal
+from enum import Enum
+
+class PriceAdaption(Enum):
+    KeepTotalCost = 1
+    KeepPricePerCoin = 2
+
 
 class Trade:
     """
@@ -22,11 +28,11 @@ class Trade:
         self.__total_cost: Decimal = Decimal(total_cost)
         self.__date: str = date
 
-    def remove_coins(self, amount: float|Decimal) -> None:
+    def remove_coins(self, amount: float|Decimal, adapt: PriceAdaption = PriceAdaption.KeepPricePerCoin) -> None:
         """
         Reduce the amount of cryptocurrency in the trade by a specified amount.
 
-        This effectively "loses" coins. The price-per-coin remains the same.
+        This effectively "loses" coins.
 
         Args:
             amount (Decimal): The amount of cryptocurrency to remove.
@@ -37,9 +43,15 @@ class Trade:
         if amount > self.__amount:
             raise ValueError(f"Can't remove more than {self.__amount}")
         
-        amount = Decimal(amount)
-        self.__total_cost -= amount * self.price_per_coin
-        self.__amount -= amount
+        if adapt == PriceAdaption.KeepPricePerCoin:
+            amount = Decimal(amount)
+            self.__total_cost -= amount * self.price_per_coin
+            self.__amount -= amount
+        elif adapt == PriceAdaption.KeepTotalCost:
+            amount = Decimal(amount)
+            self.__amount -= amount
+        else:
+            raise ValueError("Unknown adaptation strategy")
 
     @property
     def amount(self) -> Decimal:
-- 
cgit v1.2.3