# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: MIT

from typing import Union
from mpp.core.types import RawDataFrame, RawDataFrameColumns as rdc

class UnitOutOfBoundsException(Exception):
    """
    Signals that a unit filter asked for units the data does not contain.

    Raised during unit filtering when one or more of the requested unit
    values cannot be found in the events DataFrame being filtered.
    """

    def __init__(self, message: str) -> None:
        """
        Create the exception carrying a human-readable description.

        Args:
            message (str): Text describing which filter values were rejected.
        """
        super().__init__(message)


class UnitFilter:
    """
    Restrict an events DataFrame to a requested set of unit values.

    The filter is driven by a mapping from a unit-type key to the unit values
    that should be kept (e.g. ``{'CORE': [0, 1, 2]}``). Only keys listed in
    ``valid_core_types`` are applied; any other key in the mapping is ignored
    by the current implementation.

    Rows whose CORE column is null are never removed. Rows with a non-null
    CORE value are kept only when their UNIT value is in the requested set.

    Attributes:
        unit_filter_map (dict | None): Mapping from unit-type key to a collection
                                       of unit values (list, range, ...) to keep.
                                       ``None`` or an empty mapping disables filtering.
    """

    # Maximum number of missing unit values spelled out in the error message
    # before the remainder is abbreviated with '...'.
    _MAX_DISPLAYED_OUT_OF_BOUNDS_UNITS = 5

    def __init__(self, unit_filter_map: Union[dict, None] = None,
                 valid_core_types: Union[list, None] = None):
        """
        Store the filter mapping and the unit-type keys that may be applied.

        Args:
            unit_filter_map (dict | None): Mapping from unit-type key to the unit values
                                           to keep. ``None`` or empty means "no filtering".
            valid_core_types (list | None): Keys of ``unit_filter_map`` that are treated as
                                            core filters. Falls back to ``[rdc.CORE]`` when
                                            ``None`` or empty.
        """
        self.unit_filter_map = unit_filter_map
        if valid_core_types is not None and len(valid_core_types) > 0:
            self.__valid_core_types = valid_core_types
        else:
            self.__valid_core_types = [rdc.CORE]

    def filter_units(self, events_df: RawDataFrame) -> RawDataFrame:
        """
        Apply the configured unit filter to an events DataFrame.

        Args:
            events_df (pandas.DataFrame): Events to filter. Must provide the ``rdc.CORE``
                                          and ``rdc.UNIT`` columns when a core filter
                                          is configured.

        Returns:
            pandas.DataFrame: ``events_df`` itself when no filter map is configured,
                              otherwise the subset of rows that pass the core filter(s).

        Raises:
            UnitOutOfBoundsException: If a requested unit value does not occur in the
                                      UNIT column of the rows that have a non-null CORE.
        """
        if not self.unit_filter_map:
            return events_df

        return self._filter_core_units(events_df)

    def _filter_core_units(self, events_df: RawDataFrame) -> RawDataFrame:
        """
        Apply every core filter present in the unit filter map.

        For each key in the valid core types that has a non-empty entry in the
        filter map, rows with a non-null CORE value are reduced to those whose
        UNIT value is requested. Rows with a null CORE value always pass through.

        Args:
            events_df: The events DataFrame to filter

        Returns:
            Filtered DataFrame

        Raises:
            UnitOutOfBoundsException: If core filter values are out of bounds
        """
        # NOTE(review): the null check always reads the rdc.CORE column, whatever
        # core_type is, and filters are applied one after another, so a later key
        # is bounds-checked against data already narrowed by an earlier key.
        # Confirm this is intended when more than one core type is configured.
        for core_type in self.__valid_core_types:
            core_filter = self.unit_filter_map.get(core_type)
            if not core_filter:
                continue

            # Compute the "has a core" mask once and reuse it for both the
            # bounds check and the final row selection.
            has_core = events_df[rdc.CORE].notna()

            # Bounds are validated only against rows that actually carry a core.
            self._check_for_out_of_bounds_units(core_filter, events_df[has_core])

            # Keep rows without a core, plus rows whose unit was requested.
            unit_requested = events_df[rdc.UNIT].isin(core_filter)
            events_df = events_df[(~has_core) | unit_requested]

        return events_df

    def _check_for_out_of_bounds_units(self, core_filter: list, core_df: RawDataFrame) -> None:
        """
        Verify that every requested unit value occurs in the data.

        Args:
            core_filter: Collection of unit values to filter for
            core_df: The events DataFrame whose UNIT column is checked

        Raises:
            UnitOutOfBoundsException: If any filter values are not present in the data
        """
        available_units = set(core_df[rdc.UNIT].unique())
        out_of_bounds_units = sorted(set(core_filter) - available_units)

        self.__handle_out_of_bounds_units(out_of_bounds_units)

    def __handle_out_of_bounds_units(self, out_of_bounds_units) -> None:
        """
        Raise if any requested units were missing from the data.

        Long lists are abbreviated in the message: only the first
        ``_MAX_DISPLAYED_OUT_OF_BOUNDS_UNITS`` values are shown, followed by '...'.

        Args:
            out_of_bounds_units: Sorted list of requested units absent from the data

        Raises:
            UnitOutOfBoundsException: If ``out_of_bounds_units`` is non-empty
        """
        if not out_of_bounds_units:
            return

        if self.__too_many_out_of_bounds_units_to_display(out_of_bounds_units):
            limit = self._MAX_DISPLAYED_OUT_OF_BOUNDS_UNITS
            out_of_bounds_units = out_of_bounds_units[:limit] + ['...']

        out_of_bounds_str = ', '.join(map(str, out_of_bounds_units))
        raise UnitOutOfBoundsException(
            f"Filter is out of bounds. Filtered units {out_of_bounds_str} are not present in the data."
        )

    @staticmethod
    def __too_many_out_of_bounds_units_to_display(out_of_bounds_units) -> bool:
        """
        Tell whether the list is too long to print in full.

        Args:
            out_of_bounds_units: List of missing unit values (may be empty or None)

        Returns:
            True when the list holds more than ``_MAX_DISPLAYED_OUT_OF_BOUNDS_UNITS`` entries
        """
        # bool(...) guards against None/empty and guarantees a real bool result.
        return bool(out_of_bounds_units) and \
            len(out_of_bounds_units) > UnitFilter._MAX_DISPLAYED_OUT_OF_BOUNDS_UNITS
