#!/usr/bin/python3

import sys

# Inject the configured Python stdlib path
sys.path.insert(0, "/usr/lib64/python3.9")

from azfilesauth import azfiles_set_oauth, get_oauth_token
import subprocess
import os
import json
import time
import signal
import logging
import re

# Time units, expressed in seconds.
SECONDS = 1
MINUTES = SECONDS * 60

# Pause between two polling passes of the daemon loop.
SLEEP_TIME = MINUTES
# Margin before a ticket's end time at which a refresh is triggered.
REFRESH_BEFORE_EXPIRY = MINUTES * 5
# Loop flag; the shutdown handler sets it to False.
RUNNING = True

# Send every record at INFO level or above to the daemon's log file,
# appending across restarts. DEBUG records (log_verbose) are filtered out.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    filemode='a',
    filename="/var/log/azfilesrefresh.log",
)

def log_status(message):
    """Record a routine status *message* at INFO level on the root logger."""
    logging.log(logging.INFO, message)

def log_error(message):
    """Record a failure *message* at ERROR level on the root logger."""
    logging.log(logging.ERROR, message)

def log_verbose(message):
    """Record a diagnostic *message* at DEBUG level on the root logger."""
    logging.log(logging.DEBUG, message)

# Graceful shutdown handler
def handle_shutdown(signum, frame):
    """Signal handler: log the received signal and terminate with status 0.

    Args:
        signum: number of the signal that was delivered.
        frame: current stack frame; required by the ``signal.signal``
            handler signature and not used here.

    RUNNING is cleared as well, but the process does not wait for the main
    loop to observe it: SystemExit is raised immediately.
    """
    global RUNNING
    log_status(f"Received signal {signum}. Shutting down gracefully.")
    RUNNING = False
    # sys.exit, not the site-module exit(): the latter is intended for the
    # interactive shell and is not defined when Python runs with -S.
    sys.exit(0)

# Route both SIGTERM and SIGINT to the same shutdown handler.
for _shutdown_signal in (signal.SIGTERM, signal.SIGINT):
    signal.signal(_shutdown_signal, handle_shutdown)
del _shutdown_signal

def get_tickets():
    """Return the tickets reported by ``azfilesauthmanager list --json``.

    Returns:
        The decoded JSON list. Callers iterate it and call ``.get`` on each
        element, so each element is presumably a dict.

    On any failure the error is logged and ``[]`` is returned, so the caller
    simply has nothing to check on this pass. Failures include the command
    being missing or exiting non-zero, undecodable or invalid output, and JSON
    whose top level is not a list.
    """
    try:
        result = subprocess.run(["azfilesauthmanager", "list", "--json"],
                                check=True, capture_output=True)
        tickets = json.loads(result.stdout.decode('utf-8'))
    except subprocess.CalledProcessError as e:
        # Decode leniently: an exception raised inside this handler would not
        # be caught by the sibling `except Exception` below and would crash
        # the daemon loop.
        stderr_output = (e.stderr.decode('utf-8', errors='replace')
                         if e.stderr else "No stderr output")
        log_error(f"Error getting tickets: {e}. stderr: {stderr_output}")
        return []
    except Exception as e:
        log_error(f"Error getting tickets: {e}")
        return []

    # e.g. `null` or a bare number would make the caller's for-loop raise.
    if not isinstance(tickets, list):
        log_error("Error getting tickets: expected a JSON list, "
                  f"got {type(tickets).__name__}")
        return []
    return tickets

EPOCH_RE = re.compile(r"\(epoch:\s*(\d+)\)")

def _parse_epoch(field_value: str) -> int:
    """Extract epoch from a string of the form '... (epoch: NNN)'.
    Returns 0 if pattern not found."""
    if not field_value:
        return 0
    m = EPOCH_RE.search(field_value)
    if not m:
        return 0
    try:
        return int(m.group(1))
    except Exception:
        return 0

def is_expiring(ticket):
    """Return True when *ticket* should be refreshed on this polling pass.

    A ticket is due once its end time falls within REFRESH_BEFORE_EXPIRY plus
    SLEEP_TIME of now. The extra SLEEP_TIME covers the gap until the next
    polling pass.

    Args:
        ticket: one entry from get_tickets(); its 'ticket_end_time' field is
            expected to contain '(epoch: NNN)'.

    Fail-safe policy: if the end time is missing or unparsable, or the ticket
    cannot be inspected at all, the error is logged and True is returned.
    Forcing a refresh avoids leaving credentials to go stale.
    """
    try:
        end_raw = ticket.get('ticket_end_time', '')
        valid_till = _parse_epoch(end_raw)
    except Exception as e:
        # e.g. a non-dict ticket or a non-string end time. Apply the same
        # policy as for an unparsable end time instead of silently never
        # refreshing this ticket.
        log_error(f"Failed to check ticket expiration: {e}; forcing refresh.")
        return True

    if valid_till == 0:
        log_error(f"Unparsable or missing ticket_end_time '{end_raw}'; forcing refresh.")
        return True

    current_time = int(time.time())
    return (current_time + REFRESH_BEFORE_EXPIRY + SLEEP_TIME) >= valid_till

def get_endpoint_from_principal(principal_string: str) -> str:
    """Extract the endpoint host from a server principal string.

    Accepts ``cifs/<host>@<REALM>`` or ``https://<host>[@...]``. Returns the
    text between the prefix and the first '@', or up to the end of the string
    when there is no '@'.

    Returns:
        The host, or "" (after logging an error) when neither prefix matches.
    """
    for prefix in ("cifs/", "https://"):
        if principal_string.startswith(prefix):
            # partition() keeps the whole remainder when '@' is absent. Slicing
            # up to str.find('@') would use -1 there and drop the last
            # character of the host.
            host, _, _ = principal_string[len(prefix):].partition('@')
            return host
    log_error(f"Invalid principal string: {principal_string}")
    return ""

def get_mount_options():
    """Map each Kerberos CIFS mount's server host to its ``username=`` option.

    Runs ``mount -t cifs`` and considers only lines containing ``sec=krb5``;
    note that this substring also matches ``sec=krb5i``. For each such line,
    the host from the ``//host/share`` source is paired with the value of the
    ``username=`` option. If the same host appears on several lines, the last
    one wins.

    Returns:
        dict of endpoint host -> username; empty when nothing matches.
    """
    result = subprocess.run(['mount', '-t', 'cifs'], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    # Lenient decode: one mount path containing non-UTF-8 bytes must not
    # make the whole mount table unreadable.
    mount_output = result.stdout.decode('utf-8', errors='replace')

    id_map = {}
    for line in mount_output.splitlines():
        # Only consider CIFS mounts that explicitly use Kerberos auth
        if 'sec=krb5' not in line:
            continue
        m_ep = re.search(r'//([^/\s]+)', line)
        # Also stop at ')' so an option that is last inside "(...)" is not
        # captured together with the closing parenthesis.
        m_user = re.search(r'username=([^,\s)]+)', line)
        if m_ep and m_user:
            id_map[m_ep.group(1)] = m_user.group(1)

    return id_map

def get_client_id(file_endpoint_uri):
    """Return the mount ``username=`` value recorded for *file_endpoint_uri*.

    Returns None when no matching Kerberos CIFS mount exists. Also returns
    None, after logging the error, when the mount table could not be read.
    """
    try:
        return get_mount_options().get(file_endpoint_uri)
    except Exception as e:
        log_error(f"Error getting mount options {e}")
    return None

def refresh_ticket(ticket):
    """Fetch a new OAuth token for *ticket*'s endpoint and install it.

    The endpoint is derived from the ticket's 'server' principal. The
    managed identity is chosen from the ``username=`` option of the matching
    Kerberos CIFS mount. Exceptions are logged rather than propagated, so
    one failing ticket does not affect the others.
    """
    try:
        endpoint = get_endpoint_from_principal(ticket['server'])
        if not endpoint:
            # No host to look up; skip the `mount` scan entirely.
            log_error(f"Skipping refresh. No endpoint in principal: {ticket['server']}")
            return

        client_id = get_client_id(endpoint)
        if client_id is None:
            log_error(f"Skipping refresh. Username not found for endpoint: {endpoint}")
            return

        log_status(f"Refresh token for {endpoint} initiated (mount username={client_id})")

        # A mount with username=root is treated as using the system-assigned
        # managed identity (no client_id). Any other username is passed as
        # the client_id of a user-assigned managed identity.
        if client_id == "root":
            log_status(f"username=root for {endpoint}; using system-assigned managed identity")
            oauth_token = get_oauth_token()
        else:
            log_status(f"Using client_id '{client_id}' for user-assigned managed identity")
            oauth_token = get_oauth_token(client_id)

        if oauth_token is None:
            raise Exception(f"Unable to obtain OAuth token for {endpoint}")
        azfiles_set_oauth("https://" + endpoint, oauth_token)
        log_status(f"Ticket for {endpoint} refreshed")
    except Exception as e:
        log_error(f"Error refreshing ticket: {e}")

def start_daemon():
    """Run the polling loop until RUNNING becomes False.

    Each pass lists the tickets, refreshes the ones that are expiring, then
    sleeps SLEEP_TIME seconds. An unexpected error during a pass is logged
    and the loop continues, so one bad pass (e.g. ticket output that is not
    iterable) does not stop the daemon. SystemExit raised by the signal
    handler is not an Exception subclass, so it still ends the process.
    """
    log_status("AZ Files Refresh Daemon started.")
    while RUNNING:
        try:
            tickets = get_tickets()
            log_status("Checking for expiring tickets...")
            for ticket in tickets:
                if is_expiring(ticket):
                    refresh_ticket(ticket)
            log_status("Done checking. Sleeping...")
        except Exception as e:
            log_error(f"Unexpected error during refresh pass: {e}")
        time.sleep(SLEEP_TIME)
    log_status("Daemon exited.")

if __name__ == "__main__":
    # Refuse to start unless running with an effective UID of root.
    if os.geteuid() != 0:
        log_error("Script is not running as root. Please run as root")
        # sys.exit rather than the interactive-shell helper exit().
        sys.exit(1)
    start_daemon()
