import numpy as np
from scipy import interpolate
import matplotlib.pyplot as plt
from typing import List, Tuple
import datetime
import pickle

class LinearPiecewiseFunction:
    """Piecewise-linear function defined by a set of ``(x, y)`` breakpoints.

    The breakpoints are sorted on construction and linearly interpolated
    between neighbours.  Evaluating outside the covered x-range yields
    ``+inf`` (see ``fill_value`` below).
    """

    def __init__(self, points: List[Tuple[float, float]]):
        """Build the function from its breakpoints.

        :param points: ``(x, y)`` pairs, in any order.
        :raises ValueError: if ``points`` is empty.
        """
        ordered = sorted(points)
        if not ordered:
            # Without this guard the tuple-unpacking below fails with an
            # unhelpful "not enough values to unpack" message.
            raise ValueError("points must not be empty")

        self.points = np.array(ordered)
        self.x_values, self.y_values = self.points.T
        # bounds_error=False + fill_value=inf: queries outside the breakpoints
        # return +inf instead of raising.
        self.interpolator = interpolate.interp1d(
            self.x_values,
            self.y_values,
            kind='linear',
            bounds_error=False,
            fill_value=np.inf,
        )

    def evaluate(self, x: float) -> float:
        """Return the interpolated value at ``x`` (``inf`` outside the range)."""
        return float(self.interpolator(x))

    def plot(self, start: float = None, end: float = None):
        """Show the function with every breakpoint marked and labelled.

        :param start: left x-limit of the view; defaults to the first breakpoint.
        :param end: right x-limit of the view; defaults to the last breakpoint.
        """
        plt.figure(figsize=(10, 6))
        plt.plot(self.x_values, self.y_values)
        plt.scatter(self.x_values, self.y_values, color='red', zorder=5)

        # Label each point with its coordinates
        for x, y in zip(self.x_values, self.y_values):
            label = f'({x:.2f}, {y:.2f})'
            plt.annotate(label, (x, y), textcoords="offset points", xytext=(0,10), ha='center')

        plt.title("Linear Piecewise Function")
        plt.xlabel("x")
        plt.ylabel("f(x)")
        plt.grid(True)

        # Honour the caller-supplied window; fall back to the full domain.
        left = self.x_values[0] if start is None else start
        right = self.x_values[-1] if end is None else end
        plt.xlim(left, right)
        plt.show()

def combine_non_overlapping_functions(func1: LinearPiecewiseFunction, func2: LinearPiecewiseFunction) -> LinearPiecewiseFunction:
    """Concatenate two piecewise-linear functions whose x-ranges do not overlap.

    The ranges may touch at exactly one x value; in that case the shared
    breakpoint is taken from the function that comes first on the x-axis and
    the duplicate from the other function is dropped.

    :param func1: one of the two functions (order does not matter).
    :param func2: the other function.
    :return: a new function built from the breakpoints of both inputs.
    :raises ValueError: if the interiors of the two x-ranges intersect.
    """
    lo1, hi1 = func1.x_values[0], func1.x_values[-1]
    lo2, hi2 = func2.x_values[0], func2.x_values[-1]

    # Open-interval intersection test; it is symmetric in the two functions,
    # so a single clause covers both orderings.
    if hi1 > lo2 and lo1 < hi2:
        raise ValueError("The functions must not overlap except possibly at a single point")

    # Put the function that starts further left first.
    first, second = (func1, func2) if lo1 <= lo2 else (func2, func1)

    # When the ranges touch, skip the second function's first breakpoint so
    # the shared x value appears only once.
    skip = 1 if first.x_values[-1] == second.x_values[0] else 0

    merged = list(zip(first.x_values, first.y_values))
    merged.extend(zip(second.x_values[skip:], second.y_values[skip:]))

    return LinearPiecewiseFunction(merged)

    
class Book:
    """Snapshot of an order book for a single token.

    Each level is a ``(price, size)`` tuple.  Bids are stored sorted by price
    descending and asks by price ascending, so index 0 is the best level on
    each side.  Several methods use ``1 - price`` as the price of the opposite
    side, i.e. they presumably assume a binary market whose two outcome prices
    sum to 1 -- confirm against the data source.
    """

    def __init__(self,  token_id: str, bids: List[Tuple[float, float]], asks: List[Tuple[float, float]]):
        """Store the levels sorted best-first and stamp the creation time.

        :param token_id: identifier of the token this book belongs to.
        :param bids: ``(price, size)`` bid levels, in any order.
        :param asks: ``(price, size)`` ask levels, in any order.
        """
        self.token_id = token_id
        # Local wall-clock creation time; also used in the pickle filename.
        self.time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        self.bids: List[Tuple[float, float]] = sorted(bids, key=lambda x: x[0], reverse=True)
        self.asks: List[Tuple[float, float]] = sorted(asks, key=lambda x: x[0])

    @staticmethod
    def _format_level(level) -> str:
        """Render one ``(price, size)`` level as a 39-character table cell."""
        blank = " " * 39
        if level is None:
            return blank
        price, size = level
        if not (price or size):
            # A level with neither price nor size is shown as an empty cell.
            return blank
        value = price * size if price and size else 0
        return f"{price:<11.2f} {size:<12.2f} {value:<14.2f}"

    def print_book(self, n: int = 5):
        """Print the best ``n`` levels of each side as a two-column table.

        Rows beyond the depth of a side are left blank.  A non-positive ``n``
        prints only the header and footer.

        :param n: number of rows to print.
        """
        print("Book", self.token_id, "@", self.time)
        print("-" * 80)
        print(f"{'Bids':^39} | {'Asks':^39}")
        print(f"{'Price':^11} {'Size':^12} {'Value ($)':^14} | {'Price':^11} {'Size':^12} {'Value ($)':^14}")
        print("-" * 80)

        for i in range(n):
            bid = self.bids[i] if i < len(self.bids) else None
            ask = self.asks[i] if i < len(self.asks) else None
            print(f"{self._format_level(bid)} | {self._format_level(ask)}")

        print("-" * 80)

    def save_as_pickle(self):
        """Pickle this book to ``books/<token_id>_<time>.pkl``.

        The ``books`` directory is created if it does not exist yet.
        """
        import os

        os.makedirs("books", exist_ok=True)
        with open(f"books/{self.token_id}_{self.time}.pkl", "wb") as f:
            pickle.dump(self, f)

    def get_inverse_book(self) -> "Book":
        """Return the complementary book with every price mapped to ``1 - price``.

        Asks of this book become the bids of the result and vice versa; sizes
        are unchanged.  The result carries the same ``token_id``.
        """
        bids = [(1 - price, amount) for price, amount in self.asks]
        asks = [(1 - price, amount) for price, amount in self.bids]
        return Book(token_id=self.token_id, bids=bids, asks=asks)

    def position_vs_cost(self, initial_position: float = 0.0) -> "LinearPiecewiseFunction":
        """Map every reachable position to the cost of getting there.

        The curve starts at ``(initial_position, 0.0)`` and gets one breakpoint
        per consumed book level:

        * moving further away from zero walks the asks (long) or the bids
          (short), paying ``price`` respectively ``1 - price`` per unit;
        * moving towards zero unwinds the existing position against the
          opposite side, which lowers the cost (cash is received);
        * once the position has been unwound to zero, the remaining liquidity
          of that side is used to build a position of the opposite sign.

        :param initial_position: current holding; negative means short.
        :return: piecewise-linear cost as a function of the final position.
        """
        points = [(initial_position, 0.0)]

        if initial_position >= 0.0:
            # Add to the long position by lifting asks.
            position = initial_position
            cost = 0.0
            for price, amount in self.asks:
                position += amount
                cost += amount * price
                points.append((position, cost))

            # Unwind the long position by hitting bids.
            cost = 0.0
            remaining = initial_position
            leftover_bids = []
            for price, volume in self.bids:
                if remaining <= 0:
                    leftover_bids.append((price, volume))
                    continue
                sold = min(remaining, volume)
                cost -= price * sold
                remaining -= sold
                points.append((remaining, cost))
                if volume > sold:
                    leftover_bids.append((price, volume - sold))

            # Go short with whatever bid liquidity is left; each unit costs
            # (1 - price).
            position = 0.0
            for price, amount in leftover_bids:
                position -= amount
                cost += amount * (1 - price)
                points.append((position, cost))

            return LinearPiecewiseFunction(points)

        # Add to the short position by hitting bids.
        position = initial_position
        cost = 0.0
        for price, amount in self.bids:
            position -= amount
            cost += amount * (1 - price)
            points.append((position, cost))

        # Unwind the short position against the asks.
        cost = 0.0
        remaining = initial_position
        leftover_asks = []
        for price, volume in self.asks:
            if remaining >= 0:
                leftover_asks.append((price, volume))
                continue
            bought = min(abs(remaining), volume)
            cost -= (1 - price) * bought
            remaining += bought
            points.append((remaining, cost))
            if volume > bought:
                leftover_asks.append((price, volume - bought))

        # Go long with whatever ask liquidity is left.
        position = 0.0
        for price, amount in leftover_asks:
            position += amount
            cost += amount * price
            points.append((position, cost))

        return LinearPiecewiseFunction(points)
            

    # def _simulate_trade(self, amount: float, levels: List[Tuple[float, float]])  -> Tuple[List[Tuple[float, float]], float]:
    #     if amount < 0.0:
    #         raise Exception(f"Negative {amount=} in _simulate_trade")
        
    #     total_cost = 0.0
    #     new_levels = []
        
    #     for price, volume in levels:
    #         if amount <= 0:
    #             new_levels.append((price, volume))
    #         else:
    #             transacted_volume = min(amount, volume)
    #             total_cost += price * transacted_volume
    #             amount -= transacted_volume
    #             if volume > transacted_volume:
    #                 new_levels.append((price, volume - transacted_volume))
        
    #     if amount > 0:
    #         raise ValueError(f"Not enough liquidity fulfill {amount=}")

    #     return new_levels, total_cost

    
    # def simulate_buy(self, amount: float = 0.0) -> Tuple[Book, float]:
    #     new_asks, total_cost = self._simulate_trade(amount, self.asks)
    #     return Book(bids=self.bids, asks=new_asks), total_cost

    # def simulate_sell(self, amount: float = 0.0) -> Tuple[Book, float]:
    #     new_bids, total_cost = self._simulate_trade(amount, self.bids)
    #     return Book(bids=new_bids, asks=self.asks), total_cost


def calculate_arbitrage(book1: "Book", book2: "Book") -> Tuple[float, float, float]:
    """
    Calculate arbitrage opportunity between two books, considering all price levels.

    Bids of one book are crossed against asks of the other for as long as the
    best remaining bid is strictly higher than the best remaining ask.  Both
    directions are evaluated.  The input books are not modified.

    As a side effect one CSV-style line is printed: the current timestamp
    followed by the combined limit order ``(total_qty, limit_price)`` (or
    ``None``) for book1 buy, book1 sell, book2 buy and book2 sell.  The limit
    price is the deepest price touched: the highest ask for a buy and the
    lowest bid for a sell, so that a single limit order covers every fill.

    :param book1: First Book instance
    :param book2: Second Book instance
    :return: Tuple of (position_in_book1, position_in_book2, profit), where a
        position is the net quantity bought (positive) or sold (negative) on
        that book.
    """

    def match_levels(bids: List[Tuple[float, float]], asks: List[Tuple[float, float]]):
        """Cross ``bids`` against ``asks``; return ``(buys, sells, gain)``.

        ``buys`` holds the ``(ask_price, volume)`` fills taken from ``asks``,
        ``sells`` the ``(bid_price, volume)`` fills taken from ``bids``.
        """
        # Work on copies so the caller's books stay untouched.
        bids, asks = list(bids), list(asks)
        buys, sells = [], []
        gain = 0
        bid_index, ask_index = 0, 0

        while bid_index < len(bids) and ask_index < len(asks):
            bid_price, bid_volume = bids[bid_index]
            ask_price, ask_volume = asks[ask_index]

            if bid_price <= ask_price:
                break  # No more arbitrage opportunity

            trade_volume = min(bid_volume, ask_volume)
            if trade_volume > 0:
                # Empty levels produce no fill; they are only stepped over.
                buys.append((ask_price, trade_volume))
                sells.append((bid_price, trade_volume))
                gain += trade_volume * (bid_price - ask_price)

            # Update volumes
            bids[bid_index] = (bid_price, bid_volume - trade_volume)
            asks[ask_index] = (ask_price, ask_volume - trade_volume)

            # Move to next level if volume is exhausted
            if bids[bid_index][1] == 0:
                bid_index += 1
            if asks[ask_index][1] == 0:
                ask_index += 1

        return buys, sells, gain

    def combine_orders(orders, is_buy: bool):
        """Collapse fills into one ``(total_qty, limit_price)`` order, or None."""
        if not orders:
            return None
        total_qty = sum(qty for _, qty in orders)
        prices = [price for price, _ in orders]
        # A buy must reach the highest ask, a sell the lowest bid.
        limit_price = max(prices) if is_buy else min(prices)
        return total_qty, limit_price

    # Buy from book2, sell to book1
    buys_book2, sells_book1, gain_a = match_levels(book1.bids, book2.asks)
    # Buy from book1, sell to book2
    buys_book1, sells_book2, gain_b = match_levels(book2.bids, book1.asks)

    orders_book1 = {'buy': buys_book1, 'sell': sells_book1}
    orders_book2 = {'buy': buys_book2, 'sell': sells_book2}

    def net_position(orders):
        bought = sum(qty for _, qty in orders['buy'])
        sold = sum(qty for _, qty in orders['sell'])
        return bought - sold

    position_book1 = net_position(orders_book1)
    position_book2 = net_position(orders_book2)
    profit = gain_a + gain_b

    # Print final combined limit orders
    print(f"{datetime.datetime.now()},{combine_orders(orders_book1['buy'], True)},{combine_orders(orders_book1['sell'], False)},{combine_orders(orders_book2['buy'], True)},{combine_orders(orders_book2['sell'], False)}")

    return position_book1, position_book2, profit
    

# class OrderBook:
#     def __init__(self, yes_asks: Asks, no_asks: Asks):
#         self.yes_asks: Asks = yes_asks
#         self.no_asks: Asks = no_asks

#     def position_vs_cost(self) -> LinearPiecewiseFunction:
#         yes_piecewise =
