#!/usr/bin/env python3

# SPDX-FileCopyrightText: 2021 Luca Beltrame <lbeltrame@kde.org>
#
# SPDX-License-Identifier: BSD-3-Clause

import argparse
from dataclasses import dataclass
from datetime import datetime, timedelta
from itertools import zip_longest
import re
from typing import Optional, List, Union
from urllib.parse import urlparse, urlencode, urlunparse

from bs4 import BeautifulSoup
import more_itertools as mlt
import pandas as pd
import pytz
import requests
import simplejson as json
from tabulate import tabulate

# Form handler that receives the route search (POSTed by get_hyperdia_data)
HYPERDIA_CGI = "http://www.hyperdia.com/en/cgi/search/en/hyperdia2.cgi"
# Search page URL; only used as the base of the Referer header
HYPERDIA_SEARCH = "http://www.hyperdia.com/en/cgi/en/search.html"
# Captures the digits after "No." in a track label (see parse_track_number)
GROUP_MATCHER = re.compile(r".*No\.(?P<tracknum>[0-9]{1,}).*")

# Template of the search form. get_hyperdia_data copies it, fills in the
# stations, date, time, route count and via stations, then sends it both as
# the POST body and urlencoded into the Referer header, so the key order
# below is also the order of the Referer query string.
HYPERDIA_PARAMS = {
    "dep_node": "",  # departure station
    "arv_node": "",  # arrival station
    "year": "",
    "month": "",
    "day": "",
    "hour": "",
    "minute": "",
    "search_type": "0",
    "transtime": "undefined",
    "max_route": "5",  # overwritten with get_hyperdia_data's max_route
    "sort": "0",
    "faretype": "0",
    "ship": "off",
    # NOTE(review): requests appears to drop None values from form data,
    # while urlencode renders this as "lmlimit=None" in the Referer;
    # confirm this difference is intended
    "lmlimit": None,
    "sum_target": "7",
    "facility": "reserved",
    "search_target": "route",
    "sprexprs": "on",  # Shinkansen
    "sprnozomi": "on",  # Shinkansen plus Nozomi/Mizuho
    "slpexprs": "on",  # Limited express (tokkyu, 特急)
    "jr": "on",  # JR lines
    "privately": "on",  # Non-JR lines
    "search_way": ""
}

# Browser-like headers applied to the search session; get_hyperdia_data
# adds a Referer header on top of these
HEADERS = {
    'User-Agent': 'Mozilla/5.0 (X11; Linux x86_64; rv:71.0) Gecko/20100101 Firefox/71.0',
    'Content-type': 'application/x-www-form-urlencoded; charset=UTF-8',
    'Host': 'www.hyperdia.com',
    'Origin': 'http://www.hyperdia.com'
}


def required_length(nmin, nmax):
    """Return an argparse Action class accepting between nmin and nmax values.

    Meant to be combined with ``nargs='+'`` (or ``'*'``), which cannot put
    an upper bound on the number of values by themselves.

    :param nmin: minimum number of values (inclusive)
    :param nmax: maximum number of values (inclusive)
    :return: an ``argparse.Action`` subclass to pass as ``action=``
    """

    class RequiredLength(argparse.Action):
        def __call__(self, parser, args, values, option_string=None):
            if not nmin <= len(values) <= nmax:
                # argparse turns ArgumentError into a usage message and exit
                # status 2; any other exception type would escape
                # parse_args() as a raw traceback
                raise argparse.ArgumentError(
                    self, f"requires between {nmin} and {nmax} arguments")
            setattr(args, self.dest, values)

    return RequiredLength


@dataclass
class HyperdiaStep:
    """A single leg of a route: one train ride, a walk, or a through run."""

    start_station: str
    end_station: str
    # Departure / arrival times (timezone-aware when built by
    # parse_station_time)
    start_time: datetime
    end_time: datetime
    # Length of the leg in whole minutes
    duration: Optional[int] = None
    # Train name as shown by Hyperdia ("Walk" for transfers on foot)
    train_name: Optional[str] = None
    # True when the leg is a walk (train_name == "Walk")
    is_transfer: Optional[bool] = False
    # True when the train continues from the previous leg under another
    # line name
    go_through: Optional[bool] = False
    # Track numbers; None when Hyperdia shows no track
    start_track_number: Optional[int] = None
    end_track_number: Optional[int] = None


@dataclass
class HyperdiaTrip:
    """One complete route returned by a Hyperdia search.

    The summary fields are filled from parse_hyperdia_heading, which keeps
    total time, distance and transfers as the stripped text of the page and
    only converts the cost to an integer; the annotations reflect that.
    """

    # Legs of the route, in travel order
    steps: List[HyperdiaStep]
    # Total distance in Km, as scraped text
    total_distance: str
    # Total travel time in minutes, as scraped text
    total_time: str
    # Total cost in JPY
    total_cost: int
    # Number of transfers, as scraped text
    transfers: str
    # 1-based position of the route in the search results
    result_number: Optional[int] = None
    # Date in format ISO (YYYY-MM-DD)
    travel_date: Optional[str] = None


def _serialize(trip: HyperdiaTrip) -> dict:
    """Turn a trip into a plain dict suitable for JSON output.

    The trip summary values are copied as they are. Each step becomes a
    dict whose start/end times are converted to POSIX timestamps via
    ``datetime.timestamp()``. The "steps" key comes first in the result.
    """

    summary_names = ("total_distance", "total_time", "total_cost",
                     "transfers", "result_number", "travel_date")
    summary = {name: getattr(trip, name) for name in summary_names}

    def step_as_dict(leg):
        return {
            "start_station": leg.start_station,
            "end_station": leg.end_station,
            "start_time": leg.start_time.timestamp(),
            "end_time": leg.end_time.timestamp(),
            "duration": leg.duration,
            "train_name": leg.train_name,
            "is_transfer": leg.is_transfer,
            "go_through": leg.go_through,
            "start_track_number": leg.start_track_number,
            "end_track_number": leg.end_track_number,
        }

    serialized_steps = [step_as_dict(leg) for leg in trip.steps]

    return {"steps": serialized_steps, **summary}


def get_hyperdia_data(start_station, end_station, hour, minute, day="15",
                      month="08", year="2020", max_route=5, via=None,
                      use_shinkansen=True, timeout=None):
    """POST a route search to Hyperdia and return the raw response.

    :param start_station: departure station name
    :param end_station: arrival station name
    :param hour, minute: departure time
    :param day, month, year: travel date
    :param max_route: maximum number of routes to request
    :param via: up to three stations the route must go through, or None
    :param use_shinkansen: when False, Shinkansen services are disabled
    :param timeout: optional ``requests`` timeout in seconds; None (the
        default) waits indefinitely, as before
    :return: the ``requests.Response`` of the search
    :raises ValueError: if more than three via stations are given
    """

    via_nodes = ("via_node01", "via_node02", "via_node03")

    post_params = HYPERDIA_PARAMS.copy()
    post_params["dep_node"] = start_station
    post_params["arv_node"] = end_station
    post_params["year"] = year
    post_params["day"] = day
    post_params["month"] = month
    post_params["hour"] = hour
    post_params["minute"] = minute
    post_params["max_route"] = max_route

    if not use_shinkansen:
        post_params["sprexprs"] = "off"
        post_params["sprnozomi"] = "off"

    if via is None:
        via = ()

    if len(via) > len(via_nodes):
        raise ValueError("Only up to three through stations are allowed")

    # Unused via slots are still sent, as empty strings
    for node, station in zip_longest(via_nodes, via, fillvalue=""):
        post_params[node] = station

    # The Referer mimics the search page with the form encoded in its query
    referer = list(urlparse(HYPERDIA_SEARCH))
    referer[4] = urlencode(post_params)
    headers = HEADERS.copy()
    headers["Referer"] = urlunparse(referer)

    # Close the session (and its connection pool) once the request is done.
    # Without stream=True the body is fully read before post() returns, so
    # the response stays usable after the session is closed.
    with requests.Session() as session:
        session.headers.update(headers)
        return session.post(HYPERDIA_CGI, data=post_params, timeout=timeout)


def parse_hyperdia_heading(soup):
    """Read the summary of one route from its heading (div class="title_r").

    The first four spans of the heading hold, in order: total time in
    minutes, number of transfers, total distance in Km and total cost in
    JPY. All values are returned as stripped text except the cost, which
    is converted to an int after removing thousands separators.
    """

    first_spans = soup.select("span")[:4]
    total_time, transfers, distance, cost = (
        span.text.strip() for span in first_spans)

    return {
        "total_time": total_time,
        "transfers": transfers,
        "total_distance": distance,
        "total_cost": int(cost.replace(",", "")),
    }


def parse_station_time(element, year, month, day, start=True):
    """Build a Japan-time datetime from the time cell of a station row.

    A row can hold several times (at a transfer point). The first one is
    taken as the arrival and the last one as the departure. Walking times
    are not considered separately.

    :param element: time cell of a departure/arrival row; its stripped
        strings are expected to be "HH:MM" times
    :param year, month, day: travel date, as ints or numeric strings
    :param start: True for the departure of a step, False for its arrival
    :return: datetime localized to the "Japan" pytz zone
    """

    times = list(element.stripped_strings)
    current_time = times[-1] if start else times[0]

    hour, minutes = current_time.split(":")

    # year is converted like month and day, so string dates work as well
    station_time = datetime(int(year), int(month), int(day),
                            int(hour), int(minutes))

    # Passing a pytz zone as tzinfo= to the constructor gives the wrong
    # offset, so create a naive datetime and localize it with pytz instead
    # (no DST, there's no such thing in Japan)
    tz = pytz.timezone("Japan")
    return tz.localize(station_time, is_dst=False)


def parse_train_name(element):
    """Return the train name from the train information cell.

    The name is the first string of the single span inside the cell's
    list. Long names (e.g. "... (for XXX)") may contain newlines and tabs,
    which are removed.
    """

    train_span = element.select("td > ul > li > span")[0]
    name_text = list(train_span.stripped_strings)[0]

    # Delete every newline and tab in one pass
    return name_text.translate(str.maketrans("", "", "\n\t"))


def parse_track_number(element):
    """Extract the track number from a station name cell.

    The second span of the cell contains a "No.<digits>" label when a
    track is known, and is empty otherwise.

    :param element: station name cell of a departure/arrival row
    :return: the track number as an int, or None when there is none
    """

    spans = element.select("span")

    # No second span at all: nothing to parse
    if len(spans) < 2:
        return None

    track_data = spans[1].text.strip()

    if not track_data:
        return None

    match = GROUP_MATCHER.search(track_data)

    # Some text but no "No.<digits>" label: treat as unknown track
    # instead of failing on a None match
    if match is None:
        return None

    return int(match["tracknum"])


def parse_hyperdia_table(soup, year, month, day):
    """Parse one Hyperdia route table into a list of HyperdiaStep.

    After the two heading rows, rows come in overlapping groups of three:
      - departure row: time in the first column, station in the third;
      - train row: train name in the third column;
      - arrival row: same layout as the departure row.
    The arrival row of one step is the departure row of the next.

    :param soup: table element of a single route
    :param year, month, day: travel date, forwarded to parse_station_time
    :return: list of HyperdiaStep in travel order
    """

    data = list()

    # True when the previous step ended on a direct connection
    # (icon_choku.gif): the next step is then the same train continuing
    # under another line name
    previous_is_direct = False

    # Every time is parsed on the requested date. A time earlier than the
    # previous one means the route crossed midnight, so that time and all
    # the following ones are moved forward by one more day.
    day_offset = timedelta(0)
    last_time = None

    def _on_route_day(station_time):
        nonlocal day_offset, last_time
        station_time += day_offset
        if last_time is not None and station_time < last_time:
            day_offset += timedelta(days=1)
            station_time += timedelta(days=1)
        last_time = station_time
        return station_time

    # Skip the heading and the row immediately afterwards (commuter pass)
    rows = soup.find_all("tr")[2:]

    for group in mlt.windowed(rows, n=3, step=2):

        # windowed() pads an incomplete final window with None: there is
        # no complete step left to parse
        if any(row is None for row in group):
            break

        start_info, journey_info, end_info = group
        startdata = start_info.find_all("td")[0:3]
        traindata = journey_info.find_all("td")[2]
        enddata = end_info.find_all("td")[0:3]

        # Only the first string is the station name; the rest
        # ("add to favorites") is ignored
        start_station_name = list(startdata[2].stripped_strings)[0]
        end_station_name = list(enddata[2].stripped_strings)[0]

        start_track_number = parse_track_number(startdata[2])
        end_track_number = parse_track_number(enddata[2])

        # Departure first, then arrival: _on_route_day relies on this order
        start_station_time = _on_route_day(
            parse_station_time(startdata[0], year, month, day, start=True))
        end_station_time = _on_route_day(
            parse_station_time(enddata[0], year, month, day, start=False))

        go_through = previous_is_direct
        if go_through:
            train_name = "Line name change, train goes through"
        else:
            train_name = parse_train_name(traindata)

        # The "src" of the element inside the arrival row's second cell
        # (presumably an icon image) marks a direct connection into the
        # next step
        direct_connection = enddata[1].next_element.get("src")
        previous_is_direct = (direct_connection is not None
                              and "icon_choku.gif" in direct_connection)

        # Never negative: _on_route_day keeps times non-decreasing
        duration = int(
            (end_station_time - start_station_time).total_seconds()) // 60

        data.append(HyperdiaStep(
            start_station=start_station_name,
            end_station=end_station_name,
            start_time=start_station_time,
            end_time=end_station_time,
            train_name=train_name,
            is_transfer=train_name == "Walk",
            duration=duration,
            start_track_number=start_track_number,
            end_track_number=end_track_number,
            go_through=go_through))

    return data


def parse_hyperdia_html(soup, *args, **kwargs):
    """Parse every route of a Hyperdia result page.

    Extra arguments are forwarded to parse_hyperdia_table. ``year``,
    ``month`` and ``day`` must be passed as keywords, because they are also
    used to build each trip's travel date.

    :param soup: parsed result page
    :return: list of HyperdiaTrip, one per route table on the page
    """

    tables = soup.find_all("table", {"class": "table"})
    headings = soup.find_all("div", {"class": "title_r"})

    results = list()

    for heading, table in zip(headings, tables):

        parsed_heading = parse_hyperdia_heading(heading)
        parsed_table = parse_hyperdia_table(table, *args, **kwargs)

        # ISO format (YYYY-MM-DD) with zero-padded month and day, whether
        # the values came in as ints or as strings
        travel_date = datetime(int(kwargs["year"]), int(kwargs["month"]),
                               int(kwargs["day"])).date().isoformat()

        trip = HyperdiaTrip(steps=parsed_table, travel_date=travel_date,
                            **parsed_heading)
        results.append(trip)

    return results


def convert_trip_to_table(trip: HyperdiaTrip) -> pd.DataFrame:
    """Convert the steps of a trip into a DataFrame, one row per step.

    Missing track numbers, durations and train names are shown as "-".

    :param trip: the trip to convert
    :return: DataFrame with departure/arrival stations, times and tracks,
        duration and train name columns
    """

    columns = ["From", "Departure time", "Departure track",
               "To", "Arrival time", "Arrival track", "Duration",
               "Train / Transfer"]

    def format_track(track_number):
        # Compare with None: a parsed track number can legitimately be 0
        return "-" if track_number is None else f"{track_number:.0f}"

    def format_duration(duration):
        return "-" if duration is None else f"{duration:.0f} minutes"

    rows = [
        (element.start_station,
         f"{element.start_time:%H:%M}",
         format_track(element.start_track_number),
         element.end_station,
         f"{element.end_time:%H:%M}",
         format_track(element.end_track_number),
         format_duration(element.duration),
         element.train_name)
        for element in trip.steps
    ]

    df = pd.DataFrame.from_records(rows, columns=columns)
    # Only train names can still be missing at this point
    return df.fillna("-")


def trip_summary(trip: HyperdiaTrip) -> str:
    """Render a trip as a GitHub-style Markdown table followed by its totals.

    The text ends with a blank line, so consecutive summaries stay
    separated when printed one after the other.
    """

    rendered_table = tabulate(convert_trip_to_table(trip), tablefmt="github",
                              headers="keys", showindex=False)

    totals = ", ".join((
        f"Total time: {trip.total_time} minutes",
        f"Total distance: {trip.total_distance} Km",
        f"Total cost: {trip.total_cost} JPY",
    ))

    return f"{rendered_table}\n\n{totals}\n\n"


def hyperdia_search(start_station: str, end_station: str, hour: int,
                    minute: int, day: Union[int, str] = "15",
                    month: Union[int, str] = "08",
                    year: Union[int, str] = 2020, max_route: int = 5,
                    via: Optional[List[str]] = None, output_type: str = "md",
                    use_shinkansen: bool = True) -> None:
    """Search Hyperdia for routes and print them to standard output.

    :param start_station: departure station name
    :param end_station: arrival station name
    :param hour, minute: departure time
    :param day, month, year: travel date
    :param max_route: maximum number of routes to request
    :param via: up to three stations the route must go through
    :param output_type: "md" prints one Markdown table per route;
        "json" prints all routes as a single JSON document
    :param use_shinkansen: whether Shinkansen services may be used
    :raises ValueError: on an unknown output_type, or more than three via
        stations
    :raises requests.HTTPError: if Hyperdia answers with an HTTP error
    """

    # Validate before hitting the network: an unknown type used to print
    # nothing at all
    if output_type not in ("md", "json"):
        raise ValueError(f"Unknown output type: {output_type!r}")

    raw_result = get_hyperdia_data(start_station, end_station,
                                   hour, minute, day, month, year, max_route,
                                   via, use_shinkansen)
    # An error page would parse into zero routes and silently print
    # nothing: report the HTTP failure instead
    raw_result.raise_for_status()

    soup = BeautifulSoup(raw_result.text, "html.parser")
    results = parse_hyperdia_html(soup, year=year, month=month, day=day)

    json_data = {"result": []}

    for index, trip in enumerate(results, start=1):

        trip.result_number = index

        if output_type == "md":
            print(f"##### Route {index}", end="\n\n")
            print(trip_summary(trip))
        else:
            json_data["result"].append(_serialize(trip))

    if output_type == "json":
        print(json.dumps(json_data, indent=2))


def main():
    """Command line entry point: parse the arguments and run the search."""

    parser = argparse.ArgumentParser()
    # Both are needed to build the request; without required=True a missing
    # option crashed with AttributeError on None further down
    parser.add_argument("-t", "--time", required=True,
                        help="Time of travel (HH.MM)",
                        type=lambda d: datetime.strptime(d, '%H.%M').time())
    parser.add_argument("-d", "--date", required=True,
                        help="Date of travel (YYYY-MM-DD)",
                        type=lambda d: datetime.strptime(d, "%Y-%m-%d").date())
    # Same default as hyperdia_search: the value is always passed on, so a
    # None default would override it
    parser.add_argument("--max-routes", help="Maximum number of routes",
                        type=int, default=5)
    parser.add_argument("--no-shinkansen", action="store_false",
                        dest="use_shinkansen",
                        help="Do not use shinkansen routes")
    parser.add_argument("--via", nargs='+', action=required_length(1, 3),
                        help="Stations to force route through (min 1, max 3)")
    parser.add_argument("--output-type", choices=("md", "json"), default="md",
                        help="Output type (markdown or JSON)")
    parser.add_argument("start_station", help="Start station")
    parser.add_argument("end_station", help="End station")

    options = parser.parse_args()

    hour, minute = options.time.hour, options.time.minute
    day, year = options.date.day, options.date.year
    # Add "0" in front of single-digit months ("08"), like the defaults of
    # get_hyperdia_data and hyperdia_search
    month = f"{options.date.month:02d}"

    hyperdia_search(options.start_station, options.end_station, hour, minute,
                    day, month, year, options.max_routes, via=options.via,
                    output_type=options.output_type,
                    use_shinkansen=options.use_shinkansen)


if __name__ == "__main__":
    main()