#!/usr/bin/env python3

import argparse
from io import TextIOWrapper
import logging
from pathlib import Path

from sarge import run, shell_format, capture_both
from systemd.journal import JournalHandler

# Module-wide logger: INFO and above go to a systemd JournalHandler, and
# records are not propagated to ancestor loggers.
logger = logging.getLogger("letsencrypt-renew")
logger.setLevel(logging.INFO)
logger.addHandler(JournalHandler())
logger.propagate = False


def parse_domain_list(domainfile):
    """Read the domain file and build one ``-d <domain>`` argument per entry.

    Blank lines are skipped and everything from a ``#`` to the end of a line
    is treated as a comment. Whitespace is removed on both sides of every
    entry, so an indented domain is not passed on with its leading blanks.

    Args:
        domainfile: Path of the text file listing one domain per line.

    Returns:
        A list with ``shell_format("-d {0}", domain)`` for every domain in
        file order, or ``None`` (after logging a warning) when the file
        contains no domains.
    """
    with open(domainfile, "r") as handle:
        # Cut the comment part first, then the surrounding whitespace.
        entries = [row.partition("#")[0].strip() for row in handle]

    # Lines that were blank or comment-only are empty strings by now.
    domains = [entry for entry in entries if entry]

    if not domains:
        logger.warning("No domains found in configuration.")
        return None

    return [shell_format("-d {0}", domain) for domain in domains]


def renew_domains(letsencrypt_path, domains, dry_run=False):
    """Run ``<letsencrypt_path> certonly <domains>`` and restart services.

    The command's stdout and stderr are forwarded line by line to the
    logger. ``restart_services()`` is called only when the command exits
    with return code 0; otherwise an error is logged.

    Args:
        letsencrypt_path: Path of the letsencrypt executable.
        domains: Iterable of already formatted ``-d <domain>`` arguments.
        dry_run: When true, print the command instead of running it.

    Returns:
        None.
    """
    command = " ".join([letsencrypt_path, "certonly", " ".join(domains)])

    logger.info("Renewing domain certificates...")

    if dry_run:
        print("The following command will be performed:")
        print(command)
        return

    process = capture_both(command)

    # The stdout loop used to log the undefined name ``stdin`` and raised
    # NameError on the first output line; both streams now share one loop.
    for stream in (process.stdout, process.stderr):
        for line in TextIOWrapper(stream):
            # Drop the line terminator so log records carry no trailing
            # newline.
            logger.info(line.rstrip("\r\n"))

    if process.returncode != 0:
        logger.error("Let's Encrypt domain renewal failed.")
        return

    logger.info("Domain renewal succeeded.")
    restart_services()


def get_letsencrypt_path(configfile):
    """Return the letsencrypt executable path named in the config file.

    The file is scanned for the first ``LETSENCRYPT_COMMAND=<path>`` line.
    Everything from a ``#`` to the end of a line is a comment; whitespace
    around the key and the value is ignored and double quotes are removed
    from the value.

    Args:
        configfile: Path of the configuration file to read.

    Returns:
        The configured path as a string, or ``None`` (after logging an
        error) when the key is missing, its value is empty, or the path
        does not exist.
    """
    command_path = None

    with open(configfile, "r") as handle:
        for row in handle:
            # Ignore everything after # (comment)
            row = row.partition("#")[0].strip()

            # partition() keeps any further "=" inside the value and copes
            # with rows that have no "=" at all.
            key, separator, value = row.partition("=")

            # Require the exact key so that e.g. LETSENCRYPT_COMMAND_ARGS
            # is not mistaken for the command setting.
            if not separator or key.strip() != "LETSENCRYPT_COMMAND":
                continue

            command_path = value.replace('"', '').strip()
            break

    # An empty value must be rejected explicitly: Path("") is the current
    # directory, which always exists.
    if not command_path or not Path(command_path).exists():
        logger.error("No letsencrypt command found.")
        return None

    return command_path


def restart_services():
    """Restart nginx and postfix via systemctl and log the outcome.

    Both commands are always attempted. A non-zero return code is logged as
    an error, and the final "Services restarted." message is only logged
    when every command succeeded.
    """
    # NOTE(review): the second message says "Reloading" while the command
    # performs a restart - confirm which one is intended.
    steps = (
        ("Restarting web server...", "/bin/systemctl restart nginx"),
        ("Reloading mail server configuration...",
         "/bin/systemctl restart postfix"),
    )

    all_succeeded = True

    for message, command in steps:
        logger.info(message)

        # The result of run() used to be discarded, so failures went
        # unnoticed and success was reported regardless.
        if run(command).returncode != 0:
            logger.error("Command failed: %s", command)
            all_succeeded = False

    if all_succeeded:
        logger.info("Services restarted.")


def main():
    """Parse the command line and renew the configured domains.

    Nothing is renewed when get_letsencrypt_path() returns None or when
    parse_domain_list() returns no domains.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "-c", "--config-file",
        default="/etc/sysconfig/letsencrypt/config",
        help="Configuration file to use",
    )
    arg_parser.add_argument(
        "-d", "--domain-file",
        default="/etc/sysconfig/letsencrypt/domains",
        help="File including domains (one per line)",
    )
    arg_parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Just print out the command, don't do anything",
    )
    args = arg_parser.parse_args()

    command_path = get_letsencrypt_path(args.config_file)

    if command_path is not None:
        domain_args = parse_domain_list(args.domain_file)

        if domain_args:
            renew_domains(command_path, domain_args, dry_run=args.dry_run)


# Run the renewal only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()