#!/usr/local/bin/python3.6
# -*- coding: utf-8 -*-

""" Python code to start a RIPE Atlas UDM (User-Defined
Measurement). This one is to test X.509/PKIX certificates in TLS servers.

You'll need an API key in ~/.atlas/auth.

After launching the measurement, it downloads the results and analyzes
them, displaying the name ("subject" in X.509 parlance), the issuer,
the public key, the serial number or the expiration date.

Stéphane Bortzmeyer <stephane+framabortzmeyer.org>
"""

import json
import time
import os
import string
import re
import sys
import time
import socket
import collections

import Blaeu

# https://github.com/pyca/pyopenssl https://pyopenssl.readthedocs.org/en/stable/
import OpenSSL.crypto

# Shared configuration object from the Blaeu package; the assignments
# below set the defaults that are specific to this tool.
config = Blaeu.Config()
# Default values
config.display = "n" # Name (the certificate subject); specificParse() may replace this code
config.sni = True # Send the SNI unless --no-sni is given
# Override what's in the Blaeu package
config.port = 443

class Set:
    """Occurrence counter attached to one distinct displayed value."""

    def __init__(self):
        # Incremented once for every probe result that yields this value.
        self.total = 0

def usage(msg=None):
    """Print the command-line help on standard error.

    The synopsis line is printed first, then config.usage(msg) is called
    (msg is passed through unchanged), and finally the options that are
    specific to this tool are listed.
    """
    synopsis = "Usage: %s target-name-or-IP" % sys.argv[0]
    print(synopsis, file=sys.stderr)
    config.usage(msg)
    # One entry per output line; the leading spaces are part of the text.
    specific_options = (
        "Also:",
        "    --issuer or -I : displays the issuer (default is to display the name)",
        "    --key or -k : displays the public key (default is to display the name)",
        "    --serial or -S : displays the serial number (default is to display the name)",
        "    --expiration or -E : displays the expiration datetime (default is to display the name)",
        "    --no-sni : do not send the SNI (Server Name Indication) (default is to send it)",
    )
    print("\n".join(specific_options), file=sys.stderr)

def specificParse(config, option, value):
    """Handle the command-line options that are specific to this tool.

    config is updated in place: the display options store a one-letter
    code in config.display and --no-sni sets config.sni to False.
    value is accepted but not used, since none of these options takes
    an argument.

    Returns True when the option was recognized here, False otherwise.
    """
    display_codes = {
        "--issuer": "i", "-I": "i",
        "--key": "k", "-k": "k",
        "--serial": "s", "-S": "s",
        "--expiration": "e", "-E": "e",
    }
    if option in display_codes:
        config.display = display_codes[option]
        return True
    if option == "--no-sni":
        config.sni = False
        return True
    # Not one of ours: config is left untouched.
    return False

# Parse the command line: the short and long options specific to this
# tool are declared here and dispatched to specificParse().
short_options = "IkSE"
long_options = ["issuer", "serial", "expiration", "key", "no-sni"]
args, data = config.parse(short_options, long_options, specificParse, usage)

# Exactly one positional argument is expected: the server to test.
if len(args) != 1:
    usage("Not the good number of arguments")
    sys.exit(1)
target = args[0]

if config.measurement_id is None:
    # No existing measurement was requested: create a new "sslcert" one.
    definition = data["definitions"][0]
    definition["target"] = target
    definition["type"] = "sslcert"
    definition["description"] = "X.509 cert of %s" % target
    definition.pop("size", None) # Meaningless argument; tolerate its absence

    # Probe tag to require when the target is an IP address literal, so
    # that only probes with working connectivity for that family are used.
    required_tag = None
    if ':' in target: # TODO: or use is_ip_address(str) from blaeu-reach?
        config.ipv4 = False
        af = 6
        required_tag = "system-ipv6-works"
    elif re.match(r"^[0-9.]+$", target):
        config.ipv4 = True
        af = 4
        required_tag = "system-ipv4-works"
    else:
        # Hostname: the address family comes from the configuration
        af = 4 if config.ipv4 else 6
    if required_tag is not None:
        # list() makes the shallow copy directly: the "copy" module is not
        # imported by this file, so copy.copy() raised a NameError here.
        # config.include itself is left unmodified.
        include_tags = [] if config.include is None else list(config.include)
        include_tags.append(required_tag)
        data["probes"][0]["tags"]["include"] = include_tags
    definition['af'] = af
    if config.sni:
        definition['hostname'] = target

    if config.verbose:
        print(data)

    measurement = Blaeu.Measurement(data)
    if config.verbose:
        print("Measurement #%s to %s uses %i probes" % (measurement.id, target,
                                                        measurement.num_probes))
    rdata = measurement.results(wait=True, percentage_required=config.percentage_required)
else:
    # An existing measurement was requested: just fetch its results.
    measurement = Blaeu.Measurement(data=None, id=config.measurement_id)
    rdata = measurement.results(wait=False)

# Group the results by displayed value and count how many times each one
# was reported.
sets = collections.defaultdict(Set)
if config.display_probes:
    # Displayed value -> list of the IDs of the probes which reported it
    probes_sets = collections.defaultdict(list)
print(("%s probes reported" % len(rdata)))
for result in rdata:
    if 'cert' in result:
        # TODO: handle chains of certificates
        x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, str(result['cert'][0]))
        detail = ""
        if config.display == "i":
            content = x509.get_issuer()
        elif config.display == "k":
            key = x509.get_pubkey()
            # TODO #2
            content = "%s, type %s, %s bits" % (key, key.type(), key.bits())
        elif config.display == "s":
            content = format(x509.get_serial_number(), '05x')
        elif config.display == "e":
            if x509.has_expired():
                detail = " (EXPIRED)"
            # TODO: better format of the date?
            content = x509.get_notAfter()
            if isinstance(content, bytes):
                # pyOpenSSL returns the timestamp as bytes; decode it so
                # that the report does not show a b'...' literal.
                content = content.decode("ascii")
        else:
            # Default: the name (subject)
            content = x509.get_subject()
        # "detail" is appended only here, so " (EXPIRED)" appears once.
        value = "%s%s" % (content, detail) # TODO better display of the name? https://pyopenssl.readthedocs.org/en/stable/api/crypto.html#x509name-objects See also bug #2, which is related.
    else:
        # No certificate: report the first error field that is present,
        # instead of raising KeyError when there is neither 'err' nor 'alert'.
        error = "unknown error"
        for field in ("err", "alert", "dnserr"):
            if field in result:
                error = result[field]
                break
        value = "FAILED TO GET A CERT: %s" % error
    sets[value].total += 1
    if config.display_probes:
        probes_sets[value].append(result["prb_id"])

# Ascending order: the most frequent value is printed last.
sets_data = sorted(sets, key=lambda s: sets[s].total)
for myset in sets_data:
    detail = ""
    if config.display_probes:
        detail = "(probes %s)" % probes_sets[myset]
    print("[%s] : %i occurrences %s" % (myset, sets[myset].total, detail))

print(("Test #%s done at %s" % (measurement.id,
                                time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()))))
