#!/usr/bin/env python
# coding=utf-8

__author__ = "TrackMe Limited"
__copyright__ = "Copyright 2023-2025, TrackMe Limited, U.K."
__credits__ = "TrackMe Limited, U.K."
__license__ = "TrackMe Limited, all rights reserved"
__version__ = "0.1.0"
__maintainer__ = "TrackMe Limited, U.K."
__email__ = "support@trackme-solutions.com"
__status__ = "PRODUCTION"

# Standard library imports
import os
import sys
import time
import logging
import json
import itertools

# Networking and URL handling imports
import requests
from urllib.parse import urlencode
import urllib3

# multithreading
from concurrent.futures import ThreadPoolExecutor, as_completed

# Disable insecure request warnings for urllib3
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

# splunk home
splunkhome = os.environ["SPLUNK_HOME"]

# append lib
sys.path.append(os.path.join(splunkhome, "etc", "apps", "trackme", "lib"))

# import trackme libs
from trackme_libs import (
    run_splunk_search,
)

# logging:
# To avoid overriding the callers' logging destination, this library deliberately does not
# configure logging itself and relies on the callers to do so.


def search_kv_collection(
    service, collection_name, page=1, page_count=0, key_filter=None, object_filter=None
):
    """
    Get records from a KVstore collection using a Splunk search.

    The collection is read with ``| inputlookup`` and ``_key`` is copied into the
    ``keyid`` field of every returned record.

    :param service: The Splunk service object.
    :param collection_name: The name of the collection to query.
    :param page: The page number to retrieve (1-based), only used when page_count > 0.
    :param page_count: The number of records to retrieve per page; 0 returns every record
                       as a single page.
    :param key_filter: Optional keyid value to filter on (takes precedence over object_filter).
    :param object_filter: Optional object value to filter on.

    :return: A tuple containing the records of the requested page, the set of keyid values
             and the dictionary of records keyed by keyid (both covering the whole result,
             not just the page), and last_page.
    :raises Exception: if the Splunk search fails.
    """

    def spl_quote(value):
        """Escape backslashes and double quotes so value is safe inside a double-quoted SPL string."""
        return str(value).replace("\\", "\\\\").replace('"', '\\"')

    # the lookup name is derived by stripping every "kv_" from the collection name
    # NOTE(review): str.replace strips all occurrences, not just a prefix — confirm intended
    search = f'| inputlookup {collection_name.replace("kv_", "")}'

    # add filter, if any (user-provided values are escaped so they cannot break the search)
    if key_filter:
        search += f' where keyid="{spl_quote(key_filter)}"'
    elif object_filter:
        search += f' where object="{spl_quote(object_filter)}"'

    # complete the search
    search = f"{search} | eval keyid=_key"

    # kwargs
    kwargs_search = {
        "earliest_time": "-5m",
        "latest_time": "now",
        "preview": "false",
        "output_mode": "json",
        "count": 0,
    }

    collection_records = []
    collection_records_keys = set()
    collection_dict = {}

    start_time = time.time()

    try:
        # the trailing positional arguments are run_splunk_search tuning values,
        # see trackme_libs for their meaning
        reader = run_splunk_search(
            service,
            search,
            kwargs_search,
            24,
            5,
        )

        for item in reader:
            # only dict items are records; anything else the reader yields is ignored
            if isinstance(item, dict):
                collection_records.append(item)
                collection_records_keys.add(item.get("keyid"))
                collection_dict[item.get("keyid")] = item

    except Exception as e:
        msg = f'main search failed with exception="{str(e)}"'
        logging.error(msg)
        raise Exception(msg) from e

    logging.info(
        f'context="perf", search_kv_collection, KVstore select terminated, no_records="{len(collection_records)}", run_time="{round((time.time() - start_time), 3)}", collection="{collection_name}"'
    )

    # if size is 0, we consider all records as one page, simply return everything
    if page_count == 0:
        last_page = 1
        return collection_records, collection_records_keys, collection_dict, last_page

    # otherwise paginate: ceil division gives the number of pages
    total_record_count = len(collection_records)
    last_page = (total_record_count + page_count - 1) // page_count

    start_index = (page - 1) * page_count
    end_index = page * page_count

    return (
        collection_records[start_index:end_index],
        collection_records_keys,
        collection_dict,
        last_page,
    )


def get_full_kv_collection(
    collection,
    collection_name,
    limit=1000,
    total_record_count=0,
    multi_threading=False,
    max_workers=50,
):
    """
    Get all records from a KVstore collection.

    Two strategies are available:

    - sequential (default, and whenever total_record_count is 0): pages of ``limit``
      records are requested one after the other until an empty page is returned.
      The skip offset advances by the number of records actually returned, so no
      record is missed if the server returns fewer than ``limit`` per request.
    - multi-threaded (multi_threading=True and total_record_count > 0): one request per
      offset in ``range(0, total_record_count, limit)`` is issued concurrently. Pages
      that fail are logged and skipped, so the result may be incomplete.
      NOTE(review): this mode assumes total_record_count is current and that the
      server returns full ``limit``-sized pages.

    :param collection: The KVstore collection object.
    :param collection_name: The name of the collection to query (used for logging).
    :param limit: The number of records to fetch in each request.
    :param total_record_count: The total number of records in the collection (if known).
    :param multi_threading: Fetch pages concurrently (only when total_record_count > 0).
    :param max_workers: Maximum number of threads used in multi-threaded mode.

    :return: A tuple containing the records, the set of _key values, and a dictionary
             of the records keyed by _key.
    :raises Exception: if a query of the sequential mode fails.
    """
    collection_records = []
    collection_records_keys = set()
    collection_dict = {}

    start_time = time.time()

    def add_records(records):
        """Append the records not collected yet (deduplicated on _key)."""
        for item in records:
            key = item.get("_key")
            if key not in collection_records_keys:
                collection_records.append(item)
                collection_records_keys.add(key)
                collection_dict[key] = item

    def fetch_page(skip):
        """Fetch a single page; failures are logged and yield an empty page (best effort)."""
        try:
            return collection.data.query(limit=limit, skip=skip)
        except Exception as e:
            logging.error(f"Exception fetching records with skip {skip}: {e}")
            return []

    try:

        if total_record_count == 0 or not multi_threading:

            logging.info(
                f'calling get_full_kv_collection with no multi-threading, collection="{collection_name}", limit="{limit}", total_record_count="{total_record_count}", multi_threading="{multi_threading}"'
            )

            skip_tracker = 0
            while True:
                process_collection_records = collection.data.query(
                    limit=limit, skip=skip_tracker
                )
                if not process_collection_records:
                    break
                add_records(process_collection_records)
                # advance by what was actually returned: the server may cap the page size
                skip_tracker += len(process_collection_records)

        else:  # proceed with multi-threading

            logging.info(
                f'calling get_full_kv_collection with multi-threading, collection="{collection_name}", max_workers="{max_workers}"'
            )

            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                futures = {
                    executor.submit(fetch_page, skip): skip
                    for skip in range(0, total_record_count, limit)
                }

                # results are merged from this (single) thread, so no locking is needed
                for future in as_completed(futures):
                    skip = futures[future]
                    try:
                        process_collection_records = future.result()
                        if process_collection_records:
                            add_records(process_collection_records)
                            logging.debug(
                                f"Retrieved records with skip {skip}, total={len(process_collection_records)} records"
                            )
                    except Exception as e:
                        logging.error(
                            f"Exception processing records with skip {skip}: {e}"
                        )

        logging.info(
            f'context="perf", get_full_kv_collection, KVstore select terminated, no_records="{len(collection_records)}", run_time="{round((time.time() - start_time), 3)}", collection="{collection_name}"'
        )

        return collection_records, collection_records_keys, collection_dict

    except Exception as e:
        logging.error(
            f"Failed to call get_full_kv_collection, args={collection_name}, exception={str(e)}"
        )
        raise Exception(str(e)) from e


def get_kv_collection(
    collection, collection_name, total_record_count, page=1, page_count=100
):
    """
    Get records from a KVstore collection with support for pagination.

    :param collection: The KVstore collection object.
    :param collection_name: The name of the collection to query (used for logging).
    :param total_record_count: Total number of records in the collection, used to
                               compute last_page.
    :param page: The page number to retrieve (1-based).
    :param page_count: The number of records to retrieve per page; 0 retrieves every
                       record as a single page.

    :return: A tuple containing the records, the set of _key values, a dictionary of
             the records keyed by _key, and last_page.
    :raises Exception: if a KVstore query fails.
    """

    start_time = time.time()
    collection_records = []
    collection_records_keys = set()
    collection_dict = {}

    # 1 when everything is returned as one page, or when the total is unknown
    last_page = 1

    def collect(records):
        """Add the records not collected yet (deduplicated on _key); return how many were added."""
        added = 0
        for item in records:
            key = item.get("_key")
            if key not in collection_records_keys:
                collection_records.append(item)
                collection_records_keys.add(key)
                collection_dict[key] = item
                added += 1
        return added

    try:
        if page_count == 0:

            # Retrieve all records without pagination
            skip_tracker = 0
            while True:
                process_collection_records = collection.data.query(skip=skip_tracker)
                if not process_collection_records:
                    break
                collect(process_collection_records)
                # the server decides the batch size: advance by what it actually returned
                skip_tracker += len(process_collection_records)

        else:
            # Pagination logic
            skip_tracker = (page - 1) * page_count

            fetched_records = 0
            while fetched_records < page_count:
                # only ask for what is still missing so the page never overflows
                process_collection_records = collection.data.query(
                    limit=page_count - fetched_records, skip=skip_tracker
                )
                if not process_collection_records:
                    break  # End if no more records to fetch
                fetched_records += collect(process_collection_records)
                skip_tracker += len(process_collection_records)

            # Calculate the total number of pages (ceil division)
            if total_record_count > 0 and page_count > 0:
                last_page = (total_record_count + page_count - 1) // page_count

        logging.info(
            f'context="perf", KVstore select terminated, no_records="{len(collection_records)}", run_time="{round((time.time() - start_time), 3)}", collection="{collection_name}", last_page="{last_page}"'
        )

        return collection_records, collection_records_keys, collection_dict, last_page

    except Exception as e:
        logging.error(
            f"failed to call get_kv_collection, args={collection_name}, exception={str(e)}"
        )
        raise Exception(str(e)) from e


def get_target_from_kv_collection(
    filter_field, filter_value, collection, collection_name
):
    """
    Get the records of a KVstore collection matching a field value.

    :param filter_field: The field to filter the records by.
    :param filter_value: The value to filter by. A list, tuple or set matches any of
                         its values (``$in`` query); anything else is an equality match.
    :param collection: The KVstore collection object.
    :param collection_name: The name of the collection to query (used for logging).

    :return: A tuple containing the records, the set of _key values, and a dictionary
             of the records keyed by _key.
    :raises Exception: if the KVstore query fails.
    """
    collection_records = []
    collection_records_keys = set()
    collection_dict = {}

    # Multiple values: match any of them (sets are converted so they can be JSON-encoded)
    if isinstance(filter_value, (list, tuple, set, frozenset)):
        query_string = {filter_field: {"$in": list(filter_value)}}
    else:
        query_string = {filter_field: filter_value}

    try:
        process_collection_records = collection.data.query(
            query=json.dumps(query_string)
        )
        for item in process_collection_records or []:
            key = item.get("_key")
            if key not in collection_records_keys:
                collection_records.append(item)
                collection_records_keys.add(key)
                collection_dict[key] = item

        return collection_records, collection_records_keys, collection_dict

    except Exception as e:
        logging.error(
            f"failed to call get_target_from_kv_collection, args={collection_name}, exception={str(e)}"
        )
        raise Exception(str(e)) from e


def get_full_kv_collection_by_object(collection, collection_name):
    """
    Get all records from a KVstore collection, indexed by their "object" field.

    :param collection: The KVstore collection object.
    :param collection_name: The name of the collection to query (used for logging).

    :return: A tuple containing the records (deduplicated on _key), the set of "object"
             values, and a dictionary of the records keyed by "object" (when several
             records share the same object, the last one fetched wins).
    :raises Exception: if a KVstore query fails.
    """
    collection_records = []
    collection_records_keys = set()
    collection_dict = {}

    # _key values already collected: the returned key set holds objects, not _keys,
    # so it cannot be used to detect records seen twice
    seen_keys = set()

    try:
        skip_tracker = 0
        while True:
            process_collection_records = collection.data.query(skip=skip_tracker)
            if not process_collection_records:
                break
            for item in process_collection_records:
                key = item.get("_key")
                if key in seen_keys:
                    continue
                seen_keys.add(key)
                collection_records.append(item)
                collection_records_keys.add(item.get("object"))
                collection_dict[item.get("object")] = item
            # the server decides the batch size: advance by what it actually returned
            skip_tracker += len(process_collection_records)

        return collection_records, collection_records_keys, collection_dict

    except Exception as e:
        logging.error(
            f"failed to call get_full_kv_collection_by_object, args={collection_name}, exception={str(e)}"
        )
        raise Exception(str(e)) from e


def get_sampling_kv_collection(collection, collection_name):
    """
    Get records from the DSM sampling collection, indexed by their "object" field.

    :param collection: The KVstore collection object.
    :param collection_name: The name of the collection to query (used for logging).

    :return: A tuple containing the full records (deduplicated on _key), the set of
             "object" values, and a dictionary keyed by "object" holding a copy of each
             record without its "raw_sample" field (last record fetched wins per object).
    :raises Exception: if a KVstore query fails.
    """
    collection_records = []
    collection_records_keys = set()
    collection_dict = {}

    # _key values already collected: the returned key set holds objects, not _keys,
    # so it cannot be used to detect records seen twice
    seen_keys = set()

    try:
        skip_tracker = 0
        while True:
            process_collection_records = collection.data.query(skip=skip_tracker)
            if not process_collection_records:
                break
            for item in process_collection_records:
                key = item.get("_key")
                if key in seen_keys:
                    continue
                seen_keys.add(key)
                collection_records.append(item)
                collection_records_keys.add(item.get("object"))
                # add to the dict except for raw_sample
                collection_dict[item.get("object")] = {
                    k: v for k, v in item.items() if k != "raw_sample"
                }
            # the server decides the batch size: advance by what it actually returned
            skip_tracker += len(process_collection_records)

        return collection_records, collection_records_keys, collection_dict

    except Exception as e:
        logging.error(
            f"failed to call get_sampling_kv_collection, args={collection_name}, exception={str(e)}"
        )
        raise Exception(str(e)) from e


def get_collection_documents_count(server_rest_uri, session_key, collection_name):
    """
    Get the number of documents of a collection of the "trackme" app.

    The count is read from the KVstore introspection endpoint
    ``/services/server/introspection/kvstore/collectionstats``.

    :param server_rest_uri: Base URI of the Splunk REST API, prefixed to the endpoint path.
    :param session_key: Splunk session key, sent as ``Authorization: Splunk <key>``.
    :param collection_name: Name of the collection (matched as namespace "trackme.<name>").

    :return: The "count" reported for the collection, or 0 if it is not listed.
    :raises Exception: on a non-success HTTP status or an unexpected response payload.
    """

    header = {
        "Authorization": f"Splunk {session_key}",
        "Content-Type": "application/json",
    }
    url = f"{server_rest_uri}/services/server/introspection/kvstore/collectionstats?output_mode=json&count=0"
    target_ns = f"trackme.{collection_name}"

    try:
        response = requests.get(
            url,
            headers=header,
            verify=False,
            timeout=300,
        )
        if response.status_code not in (200, 201, 204):
            error_msg = f'failure to retrieve the KVstore collection document count, response.status_code="{response.status_code}", response.text="{response.text}"'
            raise Exception(error_msg)

        response_json = response.json()
        for item in response_json["entry"]:
            data = (item.get("content") or {}).get("data") or []
            # NOTE(review): guard in case a single stats document is serialised as a
            # plain string rather than a list; iterating it would yield characters
            if isinstance(data, str):
                data = [data]
            for subdata in data:
                stats = json.loads(subdata)
                if stats.get("ns") == target_ns:
                    # stop at the first match instead of scanning the remaining entries
                    return stats.get("count")

        return 0

    except Exception as e:
        logging.error(
            f'failure to retrieve the KVstore collection document count, exception="{str(e)}"'
        )
        raise Exception(str(e)) from e


def get_wlk_apps_enablement_kv_collection(collection, collection_name):
    """
    Get records from the Wlk apps enablement collection, indexed by their "app" field.

    :param collection: The KVstore collection object.
    :param collection_name: The name of the collection to query (used for logging).

    :return: A tuple containing the records (deduplicated on _key), the set of "app"
             values, and a dictionary of the records keyed by "app" (when several
             records share the same app, the last one fetched wins).
    :raises Exception: if a KVstore query fails.
    """
    collection_records = []
    collection_records_keys = set()
    collection_dict = {}

    # _key values already collected: the returned key set holds app names, not _keys,
    # so it cannot be used to detect records seen twice
    seen_keys = set()

    try:
        skip_tracker = 0
        while True:
            process_collection_records = collection.data.query(skip=skip_tracker)
            if not process_collection_records:
                break
            for item in process_collection_records:
                key = item.get("_key")
                if key in seen_keys:
                    continue
                seen_keys.add(key)
                collection_records.append(item)
                collection_records_keys.add(item.get("app"))
                collection_dict[item.get("app")] = item
            # the server decides the batch size: advance by what it actually returned
            skip_tracker += len(process_collection_records)

        return collection_records, collection_records_keys, collection_dict

    except Exception as e:
        logging.error(
            f"failed to call get_wlk_apps_enablement_kv_collection, args={collection_name}, exception={str(e)}"
        )
        raise Exception(str(e)) from e


def get_feeds_datagen_kv_collection(collection, collection_name, component):
    """
    Get all records from the feeds datagen KVstore collection, along with its blocklists.

    :param collection: The KVstore collection object.
    :param collection_name: The name of the collection to query (used for logging).
    :param component: Not used by this function; kept for interface compatibility.

    :return: A tuple containing:
             - the records (deduplicated on _key),
             - the set of _key values,
             - a dictionary of the records keyed by _key,
             - the blocklist entries with is_rex == "false", keyed by _key,
             - the blocklist entries with is_rex == "true", keyed by _key.
             Blocklist entries are records with action == "block", reduced to their
             "object" and "object_category" fields.
    :raises Exception: if a KVstore query fails.
    """
    datagen_collection_records = []
    datagen_collection_records_keys = set()
    datagen_collection_dict = {}

    datagen_collection_blocklist_not_regex_dict = {}
    datagen_collection_blocklist_regex_dict = {}

    try:
        skip_tracker = 0
        while True:
            process_collection_records = collection.data.query(skip=skip_tracker)
            if not process_collection_records:
                break
            for item in process_collection_records:
                key = item.get("_key")
                if key in datagen_collection_records_keys:
                    continue
                datagen_collection_records.append(item)
                datagen_collection_records_keys.add(key)
                datagen_collection_dict[key] = item

                # blocklist
                if item.get("action") == "block":
                    blocklist_entry = {
                        "object": item.get("object"),
                        "object_category": item.get("object_category"),
                    }
                    if item.get("is_rex") == "false":
                        datagen_collection_blocklist_not_regex_dict[key] = blocklist_entry
                    elif item.get("is_rex") == "true":
                        datagen_collection_blocklist_regex_dict[key] = blocklist_entry

            # the server decides the batch size: advance by what it actually returned
            skip_tracker += len(process_collection_records)

        return (
            datagen_collection_records,
            datagen_collection_records_keys,
            datagen_collection_dict,
            datagen_collection_blocklist_not_regex_dict,
            datagen_collection_blocklist_regex_dict,
        )

    except Exception as e:
        logging.error(
            f"failed to call get_feeds_datagen_kv_collection, args={collection_name}, exception={str(e)}"
        )
        raise Exception(str(e)) from e


def execute_batch_find_in_chunks(collection, dbqueries, chunk_size=500):
    """
    Run ``collection.data.batch_find`` over dbqueries, at most chunk_size queries per call.

    Splitting keeps every batch_find call under the server-side query limit.

    :param collection: The KVstore collection object exposing ``data.batch_find``.
    :param dbqueries: A list of query dictionaries, passed positionally to batch_find.
    :param chunk_size: Maximum number of queries sent in a single batch_find call.
    :return: The batch_find results of every chunk concatenated, in query order.
    :raises Exception: as soon as one chunk fails; the message embeds the original error.
    """
    batches = [
        dbqueries[offset : offset + chunk_size]
        for offset in range(0, len(dbqueries), chunk_size)
    ]

    combined_results = []
    for batch in batches:
        try:
            combined_results.extend(collection.data.batch_find(*batch))
        except Exception as err:
            failure_msg = f"Batch find failed for a chunk, exception={str(err)}"
            logging.error(failure_msg)
            raise Exception(failure_msg)

    return combined_results


def batch_find_records_by_object(collection, object_list):
    """
    Fetch the KVstore records whose "object" field equals any of the given values.

    One batch_find query is built per value; they are executed in chunks by
    execute_batch_find_in_chunks.

    :param collection: The KVstore collection object.
    :param object_list: The object values to look up.
    :return: A tuple (dictionary of the records keyed by _key, flat list of the records).
    :raises Exception: if the batch lookup fails.
    """
    dbqueries = [{"query": {"object": value}} for value in object_list]

    try:
        per_query_results = execute_batch_find_in_chunks(collection, dbqueries)

        # each query yields its own list of matches: flatten them into one list
        flat_records = [
            record for query_result in per_query_results for record in query_result
        ]
        records_by_key = {record["_key"]: record for record in flat_records}

        return records_by_key, flat_records

    except Exception as err:
        logging.error(
            f"Failed to call batch_find_records_by_object, args={object_list}, exception={str(err)}"
        )
        raise Exception(str(err))


def batch_find_records_by_key(collection, keys_list):
    """
    Fetch the KVstore records whose _key is one of the given keys.

    One batch_find query is built per key; they are executed in chunks by
    execute_batch_find_in_chunks.

    :param collection: The KVstore collection object.
    :param keys_list: The _key values to look up.
    :return: A tuple (dictionary of the records keyed by _key, flat list of the records).
    :raises Exception: if the batch lookup fails.
    """
    dbqueries = [{"query": {"_key": wanted_key}} for wanted_key in keys_list]

    try:
        per_query_results = execute_batch_find_in_chunks(collection, dbqueries)

        # each query yields its own list of matches: flatten them into one list
        flat_records = [
            record for query_result in per_query_results for record in query_result
        ]
        records_by_key = {record["_key"]: record for record in flat_records}

        return records_by_key, flat_records

    except Exception as err:
        logging.error(
            f"Failed to call batch_find_records_by_key, args={keys_list}, exception={str(err)}"
        )
        raise Exception(str(err))
