#!/usr/bin/env python
# coding=utf-8

__author__ = "TrackMe Limited"
__copyright__ = "Copyright 2023-2025, TrackMe Limited, U.K."
__credits__ = "TrackMe Limited, U.K."
__license__ = "TrackMe Limited, all rights reserved"
__version__ = "0.1.0"
__maintainer__ = "TrackMe Limited, U.K."
__email__ = "support@trackme-solutions.com"
__status__ = "PRODUCTION"

# Standard library
import os
import sys
import time
import json
import hashlib

# External libraries
import urllib3

# Disable urllib3 warnings
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

# Configure logging
import logging
from logging.handlers import RotatingFileHandler

# Splunk installation root; a missing SPLUNK_HOME raises KeyError at import time
splunkhome = os.environ["SPLUNK_HOME"]

# Dedicated rotating log file for this command: one 10 MB file plus one backup
filehandler = RotatingFileHandler(
    filename=os.path.join(
        splunkhome, "var", "log", "splunk", "trackme_persistentfields.log"
    ),
    mode="a",
    maxBytes=10_000_000,
    backupCount=1,
)
formatter = logging.Formatter(
    fmt="%(asctime)s %(levelname)s %(filename)s %(funcName)s %(lineno)d %(message)s"
)

# Render %(asctime)s in UTC; this is set on the Formatter class itself, so it
# applies to every formatter created in this process
logging.Formatter.converter = time.gmtime
filehandler.setFormatter(formatter)

# Work on the root logger, fetched once and reused by the rest of the module
log = logging.getLogger()

# Drop any file handler already attached to the root logger (iterate over a
# copy because the handler list is mutated inside the loop)
for hdlr in list(log.handlers):
    if not isinstance(hdlr, logging.FileHandler):
        continue
    log.removeHandler(hdlr)

# Attach our handler and start at INFO; stream() later replaces this level
# with the one returned by trackme_reqinfo
log.addHandler(filehandler)
log.setLevel(logging.INFO)

# append current directory
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

# import libs
import import_declare_test

# import Splunk libs
from splunklib.searchcommands import (
    dispatch,
    StreamingCommand,
    Configuration,
    Option,
    validators,
)

# Import trackme libs
from trackme_libs import trackme_reqinfo
from trackme_libs_utils import decode_unicode

# import trackme libs persistent fields definition
from collections_data import (
    persistent_fields_dsm,
    persistent_fields_dhm,
    persistent_fields_mhm,
    persistent_fields_flx,
    persistent_fields_fqm,
    persistent_fields_wlk,
    persistent_fields_cim,
)


@Configuration(distributed=False)
class TrackMePersistentHandler(StreamingCommand):
    """
    Streaming command saving search results into a TrackMe KVstore collection.

    For every incoming result the command:

    * stamps ``ctime`` when the record is new or has no ``ctime`` yet,
    * normalises ``object`` / ``alias`` through ``decode_unicode`` and adds a
      ``_key`` (sha256 of the decoded object) when the record has none,
    * for the ``dsm`` / ``dhm`` components, drops the record when its
      ``data_last_time_seen`` is older than the value already stored,
    * when the stored ``mtime`` is more recent than the record's
      ``tracker_runtime``, keeps the stored value of each persistent field
      instead of the incoming one,
    * batch-saves the resulting records to the collection and yields them.

    NOTE(review): the ``update_collection`` option is declared and logged but
    is never consulted - the collection is always written. Confirm the intent
    before making the write conditional, as existing searches may rely on the
    unconditional save.
    """

    collection = Option(
        doc="""
        **Syntax:** **collection=****
        **Description:** Specify the collection.""",
        require=True,
        default="None",
        validate=validators.Match("collection", r"^.*$"),
    )

    key = Option(
        doc="""
        **Syntax:** **key=****
        **Description:** Specify the key.""",
        require=True,
        default="None",
        validate=validators.Match("key", r"^.*$"),
    )

    update_collection = Option(
        doc="""
        **Syntax:** **update_collection=****
        **Description:** Enables or disables updating and inserting innto the collection, this replaces the need from calling outputlookup.""",
        require=False,
        default=False,
        # the validator is named after the option it validates so that a
        # validation error points at the right argument
        validate=validators.Match("update_collection", r"^(True|False)$"),
    )

    def get_component(self, collection_name):
        """
        Determine the component name based on the collection name.

        Args:
            collection_name (str): The name of the collection.

        Returns:
            str: The component whose ``trackme_<component>_`` prefix starts
            the collection name, or None when no known prefix matches.
        """
        # the prefixes are mutually exclusive, so the lookup order is irrelevant
        for component in ("dsm", "dhm", "mhm", "flx", "fqm", "cim", "wlk"):
            if collection_name.startswith(f"trackme_{component}_"):
                return component

        return None

    @staticmethod
    def _load_collection_records(collection, page_step=5000):
        """
        Return every record of the KVstore collection, de-duplicated by _key.

        Args:
            collection: KVstore collection object exposing ``data.query``.
            page_step (int): Increment applied to ``skip`` between queries.

        Returns:
            list: The collection records, each ``_key`` appearing once.
        """
        loaded_records = []
        seen_keys = set()
        skip_tracker = 0

        while True:
            page = collection.data.query(skip=skip_tracker)
            if not page:
                break

            # no limit is passed to the query, so consecutive pages presumably
            # overlap; the set guarantees a record is only kept once
            for item in page:
                item_key = item.get("_key")
                if item_key not in seen_keys:
                    seen_keys.add(item_key)
                    loaded_records.append(item)

            skip_tracker += page_step

        return loaded_records

    @staticmethod
    def _index_collection_records(collection_records, persistent_fields, component):
        """
        Build the ``_key`` -> stored values index used while streaming.

        Args:
            collection_records (list): Records loaded from the collection.
            persistent_fields (list): Names of the fields to retain.
            component (str): Component name, or None when unknown.

        Returns:
            dict: For each ``_key``, a dict holding ``mtime`` (float, current
            time when missing or invalid), every persistent field, and for
            dsm/dhm ``data_last_time_seen`` as a float when convertible.
        """
        index = {}

        for stored_record in collection_records:
            stored_values = {}

            # time of the last update of this record
            try:
                stored_values["mtime"] = float(stored_record.get("mtime"))
            except (TypeError, ValueError):
                stored_values["mtime"] = float(time.time())

            # persistent fields, None when absent from the stored record
            for field in persistent_fields:
                stored_values[field] = stored_record.get(field)

            # for dsm/dhm, keep data_last_time_seen when it is a valid number
            if component in ("dsm", "dhm"):
                try:
                    stored_values["data_last_time_seen"] = float(
                        stored_record.get("data_last_time_seen")
                    )
                except (TypeError, ValueError):
                    # missing or non numeric: the outdated check is skipped
                    pass

            index[stored_record.get("_key")] = stored_values

        return index

    def _is_outdated_record(self, record, collection_dict, component, target_collection):
        """
        Tell whether a dsm/dhm record is older than what the KVstore holds.

        The record is outdated when both its ``data_last_time_seen`` and the
        stored one are set and the record value is strictly lower. Any
        exception is logged and treated as "not outdated".

        Args:
            record (dict): The incoming search result.
            collection_dict (dict): Index built by _index_collection_records.
            component (str): Component name, used in log messages.
            target_collection (str): Collection name, used in log messages.

        Returns:
            bool: True when the record must be rejected.
        """
        try:
            record_key = record.get(self.key)
            logging.debug(f'record key="{record_key}"')

            stored_values = collection_dict.get(record_key)
            logging.debug(f'rejected_record_dict="{stored_values}"')

            if stored_values is None:
                logging.warning(
                    f'collection="{target_collection}", component="{component}", record key="{record_key}", object="{record.get("object")}", this object could not be found in the dictionnary, most likely because this sourcetype is corrupted and index non printable characters as its name!'
                )
                return False

            # value currently stored in the KVstore
            kvstore_data_last_time_seen = stored_values.get(
                "data_last_time_seen", None
            )
            logging.info(
                f'kvstore_data_last_time_seen="{kvstore_data_last_time_seen}"'
            )
            if kvstore_data_last_time_seen:
                kvstore_data_last_time_seen = float(kvstore_data_last_time_seen)

            # value submitted by the record
            current_data_last_time_seen = record.get("data_last_time_seen", None)
            if current_data_last_time_seen:
                current_data_last_time_seen = float(current_data_last_time_seen)

            # the debug messages serialise the full record, only build them
            # when they will actually be emitted
            debug_enabled = log.isEnabledFor(logging.DEBUG)

            # nothing to compare when either side is missing
            if not (kvstore_data_last_time_seen and current_data_last_time_seen):
                if debug_enabled:
                    logging.debug(
                        f'collection="{target_collection}", component="{component}", record key="{record_key}", object="{record.get("object")}", rejected record not detected, epoch value in kVstore {kvstore_data_last_time_seen} and record submitted value {current_data_last_time_seen} are both None, record="{json.dumps(record, indent=2)}"'
                    )
                return False

            if current_data_last_time_seen < kvstore_data_last_time_seen:
                logging.warning(
                    f'collection="{target_collection}", component="{component}", record key="{record_key}", rejected record detected, epoch value in kVstore {kvstore_data_last_time_seen} is bigger than record submitted value {current_data_last_time_seen}, record="{json.dumps(record, indent=2)}"'
                )
                return True

            if debug_enabled:
                logging.debug(
                    f'collection="{target_collection}", component="{component}", record key="{record_key}", rejected record not detected, epoch value in kVstore {kvstore_data_last_time_seen} is not bigger than record submitted value {current_data_last_time_seen}, record="{json.dumps(record, indent=2)}"'
                )
            return False

        except Exception as e:
            # best effort by design: a failure here must never drop a record
            logging.error(
                f'collection="{target_collection}", component="{component}", failed to extract and convert data_last_time_seen, record key="{record.get(self.key)}", object="{record.get("object")}", exception message="{str(e)}"'
            )
            return False

    def stream(self, records):
        """
        Process the search results and save them into the KVstore collection.

        Args:
            records: Iterable of search results (dicts). Each result must
                provide the ``object`` and ``alias`` fields.

        Yields:
            dict: One summary record per accepted result, after each chunk of
            up to 500 records has been submitted to the KVstore. Records
            rejected as outdated (dsm/dhm only) are not yielded.
        """
        # performance counter
        start_time = time.time()

        # Get request info and set logging level
        reqinfo = trackme_reqinfo(
            self._metadata.searchinfo.session_key, self._metadata.searchinfo.splunkd_uri
        )
        log.setLevel(reqinfo["logging_level"])
        debug_enabled = log.isEnabledFor(logging.DEBUG)

        # connect to the KVstore
        target_collection = f"kv_{self.collection}"
        collection = self.service.kvstore[target_collection]

        # set the component and its persistent fields (none for an unknown component)
        component = self.get_component(self.collection)
        persistent_fields = list(
            {
                "dsm": persistent_fields_dsm,
                "dhm": persistent_fields_dhm,
                "mhm": persistent_fields_mhm,
                "flx": persistent_fields_flx,
                "fqm": persistent_fields_fqm,
                "cim": persistent_fields_cim,
                "wlk": persistent_fields_wlk,
            }.get(component, [])
        )
        # set used for membership tests, the list is kept for log messages
        persistent_fields_lookup = set(persistent_fields)

        # get all records
        get_collection_start = time.time()
        collection_records = self._load_collection_records(collection)

        logging.info(
            f'context="perf", get collection records, no_records="{len(collection_records)}", run_time="{round((time.time() - get_collection_start), 3)}", collection="{self.collection}"'
        )

        # dictionary of all keys with their persistent fields, also used for
        # constant time "is this key known" checks
        collection_dict = self._index_collection_records(
            collection_records, persistent_fields, component
        )

        final_records = []

        # Loop in the results
        for record in records:
            try:
                record_is_new = record.get(self.key) not in collection_dict
            except TypeError:
                # an unhashable (e.g. multivalue) key can never match a stored _key
                record_is_new = True

            # ctime: set it for new records, and for known records lacking one
            if record_is_new or not record.get("ctime", None):
                record["ctime"] = time.time()

            # get the event time, fall back to now when missing or not numeric
            try:
                time_event = float(record["_time"])
            except (KeyError, TypeError, ValueError):
                time_event = time.time()

            if debug_enabled:
                logging.debug(f"inspecting record={json.dumps(record, indent=2)}")

            # decode object and alias (decode_unicode is a TrackMe helper)
            record_object_value = decode_unicode(record["object"])
            record_alias = decode_unicode(record["alias"])
            logging.debug(
                f'object="{record["object"]}", decoded_object="{record_object_value}", alias="{record["alias"]}", decoded_alias="{record_alias}"'
            )

            # add the _key in the record if there is none
            if not record.get("_key"):
                object_256 = hashlib.sha256(
                    record_object_value.encode("utf-8")
                ).hexdigest()
                record["_key"] = object_256
                logging.debug(
                    f'adding _key="{object_256}" to record for object="{record_object_value}", alias="{record_alias}"'
                )

            # handle unicode for object and alias
            record["object"] = record_object_value
            record["alias"] = record_alias

            # key value as it stands once _key/object/alias have been normalised
            record_key = record.get(self.key)

            # get tracker_runtime, fall back to now when missing or not numeric
            tracker_runtime = record.get("tracker_runtime", None)
            if tracker_runtime:
                try:
                    tracker_runtime = float(tracker_runtime)
                except (TypeError, ValueError):
                    tracker_runtime = time.time()
            else:
                tracker_runtime = time.time()

            # rejected record: only applies to dsm/dhm, when data_last_time_seen
            # in the record is lower than the value currently in the KVstore
            rejected_record = False
            if component in ("dsm", "dhm"):
                rejected_record = self._is_outdated_record(
                    record, collection_dict, component, target_collection
                )

            # detect conflict update: the stored record was modified after the
            # tracker producing this result started
            conflict_update = False
            if not record_is_new:
                # initialised up front so the error message below can always
                # be rendered, whatever step failed
                mtime = None
                try:
                    mtime = float(collection_dict[record_key].get("mtime"))
                    if mtime > float(tracker_runtime):
                        conflict_update = True
                        logging.info(
                            f'record key="{record_key}", conflict update detected, preserving persistent fields="{persistent_fields}", record="{json.dumps(record, indent=2)}"'
                        )

                except Exception as e:
                    logging.error(
                        f'failed to extract and convert mtime="{mtime}", tracker_runtime="{tracker_runtime}", exception message="{str(e)}"'
                    )
                    conflict_update = False

            if rejected_record:
                continue

            # create a summary record, _time first
            summary_record = {"_time": time_event}

            # conflict_update is only ever True for a known record, so the
            # stored values are guaranteed to exist when they are needed
            stored_values = collection_dict[record_key] if conflict_update else None

            for k in record:
                logging.debug('field="%s", value="%s"', k, record[k])

                # the event time was already added
                if k == "_time":
                    continue

                if conflict_update and k in persistent_fields_lookup:
                    # preserve the stored value of persistent fields
                    summary_record[k] = stored_values.get(k)
                else:
                    # new record, normal field, or no conflict
                    summary_record[k] = record[k]

            # the summary record (not the raw result) is what gets saved and
            # yielded, otherwise the preserved persistent fields would be lost
            final_records.append(summary_record)

        # log
        if debug_enabled:
            logging.debug(f"final_records={json.dumps(final_records, indent=2)}")

        # batch update/insert
        batch_update_collection_start = time.time()

        # process by chunk
        chunks = [final_records[i : i + 500] for i in range(0, len(final_records), 500)]
        for chunk in chunks:
            try:
                collection.data.batch_save(*chunk)
            except Exception as e:
                # a failed chunk is logged, its records are still yielded
                logging.error(f'KVstore batch failed with exception="{str(e)}"')

            # yield
            for record in chunk:
                yield record

        # perf counter for the batch operation
        logging.info(
            f'context="perf", batch KVstore update terminated, no_records="{len(final_records)}", run_time="{round((time.time() - batch_update_collection_start), 3)}", collection="{self.collection}"'
        )

        # perf counter for the entire call
        logging.info(
            f'trackmepersistentfields has terminated, collection="{self.collection}", key="{self.key}", update_collection="{self.update_collection}", run_time="{round((time.time() - start_time), 3)}"'
        )


# Entry point: hand the command class to splunklib's dispatcher together with
# the process arguments and the standard streams. The module name is passed as
# well - presumably so that the command only runs when this file is executed
# as a script rather than imported; confirm against the splunklib version in use.
dispatch(TrackMePersistentHandler, sys.argv, sys.stdin, sys.stdout, __name__)
