#!/usr/bin/python3

# SPDX-FileCopyrightText: 2021 Luca Beltrame <lbeltrame@kde.org>
#
# SPDX-License-Identifier: BSD-3-Clause

import argparse
from collections import defaultdict
from datetime import datetime
from typing import List, Tuple
import sys

import lxml.etree as etree
import lxml.builder as builder
import more_itertools as mlt
from pytz import timezone
from requests_oauthlib import OAuth1Session
import simplejson as json

# lxml element factory: attribute access yields a builder for an element of
# that tag name (e.g. E.Segment(...) creates <Segment>...</Segment>).
E = builder.ElementMaker()

# Builders for the TripIt XML elements assembled by add_train_segment and
# parse_json.
Request = E.Request
Segment = E.Segment
StartDateTime = E.StartDateTime
EndDateTime = E.EndDateTime
RailObject = E.RailObject
# Lower-case <date>/<time> children of Start/EndDateTime; the trailing
# underscore presumably keeps them distinct from plain `date`/`time` names.
date_ = E.date
time_ = E.time

# TripIt API endpoint that main() POSTs the generated XML to.
TRIPIT_CREATE_URL = "https://api.tripit.com/v1/create"
# OAuth1 credentials file read by get_tokens(); this is a relative path, so it
# is looked up in the current working directory.
TOKEN = "tripit.credentials.json"


def get_tokens(filename: str) -> Tuple[str, str, str, str]:
    """Read the TripIt OAuth1 credentials from a JSON file.

    Args:
        filename: path to a JSON object holding ``client_key``,
            ``client_secret``, ``resource_owner_key`` and
            ``resource_owner_secret``; any other keys are ignored.

    Returns:
        ``(client_key, client_secret, resource_owner_key,
        resource_owner_secret)``, in the positional order main() passes
        them to ``OAuth1Session``.

    Raises:
        OSError: if the file cannot be opened.
        KeyError: if one of the four required keys is missing.
    """
    # Be explicit about the encoding instead of relying on the locale.
    with open(filename, encoding="utf-8") as handle:
        tokens = json.load(handle)

    return (tokens["client_key"], tokens["client_secret"],
            tokens["resource_owner_key"], tokens["resource_owner_secret"])


def add_train_segment(start_station: str,
                      end_station: str,
                      start_time: str,
                      end_time: str,
                      trip_date: str,
                      train: str = None,
                      number: int = None,
                      carrier: str = None) -> Segment:
    """Build a TripIt ``<Segment>`` element describing one train leg.

    Args:
        start_station: departure station name.
        end_station: arrival station name.
        start_time: departure time text for the ``<time>`` element.
        end_time: arrival time text for the ``<time>`` element.
        trip_date: date text used for BOTH the start and the end
            ``<date>`` elements.
        train: train type/name; ``<train_type>`` is omitted when falsy.
        number: train number; ``<train_number>`` is omitted when falsy.
        carrier: carrier name; ``<carrier_name>`` is omitted when falsy.

    Returns:
        The ``<Segment>`` lxml element.
    """
    # NOTE(review): the arrival shares trip_date with the departure, so a
    # leg that crosses midnight gets an arrival earlier than its departure.
    data = [
        StartDateTime(date_(trip_date), time_(start_time)),
        EndDateTime(date_(trip_date), time_(end_time)),
        E.start_station_name(start_station),
        E.end_station_name(end_station),
    ]

    if number:
        # lxml.builder rejects int children with TypeError, so convert the
        # number to text before adding it.
        data.append(E.train_number(str(number)))

    if carrier:
        data.append(E.carrier_name(carrier))

    # Test the argument, not the element factory (which is always truthy):
    # passing None to the factory makes lxml.builder raise TypeError.
    if train:
        data.append(E.train_type(train))

    return Segment(*data)


def _convert_datetimes(segment: int) -> Tuple[str, str]:
    """Split a Unix timestamp into TripIt-style date and time strings.

    The timestamp is rendered in pytz's ``"Japan"`` timezone. Seconds are
    deliberately discarded: the time is always formatted as ``HH:MM:00``.

    Args:
        segment: seconds since the epoch (a step's ``start_time`` or
            ``end_time`` value).

    Returns:
        ``(date, time)`` formatted as ``("YYYY-MM-DD", "HH:MM:00")``.
    """
    date_time = datetime.fromtimestamp(segment, timezone("Japan"))

    return date_time.strftime("%Y-%m-%d"), date_time.strftime("%H:%M:00")


def _parse_record(segment: dict) -> Tuple[str, ...]:
    """Extract the fields needed for a TripIt segment from one route step.

    Args:
        segment: one entry of a route's ``"steps"`` list; reads
            ``start_station``, ``end_station``, ``start_time``, ``end_time``,
            ``train_name`` and (optionally) ``train_number``.

    Returns:
        ``(start_station, start_date, start_time, end_station, end_date,
        end_time, train_type, train_number)``. Station names are
        title-cased, dates/times come from ``_convert_datetimes`` and a
        missing or null ``train_number`` becomes ``""``.

    Raises:
        KeyError: if one of the required keys is missing.
    """
    start_station = segment["start_station"].title()
    start_date, start_time = _convert_datetimes(segment["start_time"])

    end_station = segment["end_station"].title()
    end_date, end_time = _convert_datetimes(segment["end_time"])

    train_type = segment["train_name"]

    raw_number = segment.get("train_number")
    # Normalise numeric train numbers to text: the value ends up as lxml
    # element content, and lxml.builder raises TypeError for ints.
    train_number = "" if raw_number is None else str(raw_number)

    return (start_station, start_date, start_time, end_station, end_date,
            end_time, train_type, train_number)


def _through_runs(steps):
    """Group route steps into ``(first, last)`` pairs, one per booked train.

    A step is followed by zero or more steps whose ``go_through`` flag is
    set. Those steps continue the train of the step before them, so the
    whole run is collapsed into a single pair spanning it. A step not
    followed by a go-through step forms a run of its own (first is last).
    """
    iter_steps = mlt.peekable(steps)
    run_start = None

    for step in iter_steps:
        if run_start is None:
            # First step of the route, or first step after a finished run
            # (this also covers a route that starts with a go-through step).
            run_start = step

        next_step = iter_steps.peek(None)
        if next_step is not None and next_step["go_through"]:
            # The following step continues this train: keep extending.
            continue

        # The run ends here, including when it is the last step of the route.
        yield run_start, step
        run_start = None


def parse_json(json_data: dict, carrier: str = None,
               result_number: int = None) -> str:
    """Convert one route of the Hyperdia JSON into a TripIt XML request.

    Args:
        json_data: decoded JSON produced by hyperdia.py; routes are read
            from its ``"result"`` list and each route's ``"steps"``.
        carrier: carrier name written to every segment; ``None`` becomes
            an empty name, which add_train_segment leaves out of the XML.
        result_number: 1-based index of the route to convert; ``None``
            selects the first route.

    Returns:
        The serialized ``<Request><RailObject>...</RailObject></Request>``
        document.

    Raises:
        IndexError: if ``result_number`` does not select an existing route.
    """
    routes = json_data["result"]

    if result_number is None:
        result_number = 1
    # Reject 0 and negative values explicitly: Python would otherwise index
    # from the end of the list and silently convert the wrong route.
    if not 1 <= result_number <= len(routes):
        raise IndexError("route number {} out of range (1-{})".format(
            result_number, len(routes)))

    record = routes[result_number - 1]

    # FIXME: Impossible to extract it from the current Hyperdia data
    carrier = "" if carrier is None else carrier

    segments = []

    for first, last in _through_runs(record["steps"]):
        # Departure and train details come from the first step of the run,
        # arrival from the last one (the same step for a run of one).
        (start_station, start_date, start_time,
         _, _, _, train_type, train_number) = _parse_record(first)
        _, _, _, end_station, _, end_time, _, _ = _parse_record(last)

        # NOTE(review): the arrival date is dropped because add_train_segment
        # takes a single date for both ends; runs crossing midnight are
        # affected.
        segments.append(add_train_segment(start_station,
                                          end_station,
                                          start_time,
                                          end_time,
                                          start_date,
                                          train=train_type,
                                          number=train_number,
                                          carrier=carrier))

    request_object = Request(RailObject(*segments))

    return etree.tounicode(request_object)


def main():
    """Upload one Hyperdia route to TripIt from the command line.

    Reads the JSON written by hyperdia.py (from the given file or stdin),
    converts the route selected with ``--number`` into a TripIt XML request
    and POSTs it to ``TRIPIT_CREATE_URL`` using the OAuth1 credentials stored
    in ``TOKEN``. Exits with a non-zero status and an error message when
    TripIt does not accept the request.
    """

    parser = argparse.ArgumentParser()

    parser.add_argument("-c", "--carrier", default="Japan Rail",
                        help="Carrier name")
    parser.add_argument("-n", "--number", default=1, type=int,
                        help="Route number to insert into Tripit")
    parser.add_argument("source", type=argparse.FileType("r"),
                        default=sys.stdin, nargs="?",
                        help="JSON file produced by hyperdia.py")

    options = parser.parse_args()

    # Route numbers are 1-based; 0 or a negative value would otherwise
    # index from the end of the route list and pick the wrong route.
    if options.number < 1:
        parser.error("--number must be 1 or greater")

    # argparse.FileType opens the file but never closes it.
    with options.source as source:
        json_data = json.load(source)

    # Build the request first, so malformed input fails before any
    # credentials are read or a session is created.
    result = parse_json(json_data, carrier=options.carrier,
                        result_number=options.number)

    client_key, client_secret, resource_key, resource_secret = get_tokens(
        TOKEN)

    session = OAuth1Session(client_key, client_secret, resource_key,
                            resource_secret)

    response = session.post(TRIPIT_CREATE_URL, data={"xml": result})

    if not response.ok:
        # sys.exit with a string prints it to stderr and exits with status 1,
        # so a rejected upload is no longer silently ignored.
        sys.exit("Failed to add trip to Tripit: HTTP {} {}".format(
            response.status_code, response.reason))

    print("Trip added to Tripit")


# Run the command-line entry point only when executed as a script, not when
# the module is imported.
if __name__ == "__main__":
    main()