#!/usr/local/bin/python2.7

from __future__ import print_function
from json import loads
import re
import os
import time
import sys
import socket
import signal
import posix
import errno
import optparse
import random
import subprocess

VERSION = "0.88" # Automatically filled in.

random.seed()

import SubnetTree

class IntervalUpdate:
    """Attribute bag describing a single traffic observation.

    Fields (pkts, bytes, payload, frags, src_ip, src_port, dst_ip,
    dst_port, prot, service, state, start, end) are attached dynamically
    by readPcap() and parseConnLine() before being fed to
    Interval.update().
    """
    pass

class IntervalList:
    """Ordered sequence of per-slot Interval objects (None = empty slot).

    Interval start/end values are kept relative to self.start until
    finish() converts them to absolute timestamps.
    """

    def __init__(self):
        # No intervals yet; a negative start means the base time is
        # still unknown (set on the first findInterval() call).
        self.ints = []
        self.start = -1

    def finish(self):
        """Shift intervals to absolute time and apply the sampling factor."""
        for interval in self.ints:
            if not interval:
                continue
            interval.start += self.start
            interval.end += self.start
            interval.applySampleFactor()

    def writeR(self, file, top_ports):
        """Write the intervals to path *file* as an R-readable table.

        Empty slots are written as zero-filled intervals so the rows form
        a contiguous time series.
        """
        out = open(file, "w")
        Interval.makeRHeader(out, top_ports)
        next_start = self.start

        for interval in self.ints:
            if interval:
                interval.formatForR(out, top_ports)
                next_start = interval.end
            else:
                gap = Interval()
                gap.start = next_start
                gap.end = next_start + Options.ilen
                gap.formatForR(out, top_ports)
                next_start = gap.end

        out.close()

class Interval:
    """Accumulated traffic statistics for one time span.

    Used both for whole-input aggregates (Total/Incoming/Outgoing) and
    for fixed-length sub-intervals (see IntervalList/findInterval).
    Per-category tallies are weighted by byte counts instead of unit
    counts when -b/--bytes is in effect (see update()).
    """

    def __init__(self):
        # start > end initially so the first update() narrows the range;
        # 1e20 doubles as the "no data" sentinel recognized by isoTime().
        self.start = 1e20
        self.end = 0
        self.bytes = 0
        self.payload = 0
        self.pkts = 0        # packets, or connections in -c mode
        self.frags = 0
        self.updates = 0     # number of update() calls folded in
        self.ports = {}      # port number -> weight
        self.prots = {}      # IP protocol number -> weight
        self.servs = {}      # service name -> weight
        self.srcs = {}       # packed source address -> weight
        self.dsts = {}       # packed destination address -> weight
        self.states = {}     # connection state -> weight

    def update(self, iupdate, adjusttime=True):
        """Fold one IntervalUpdate into this interval's counters.

        With adjusttime=False the start/end bounds are left untouched
        (used for fixed-width sub-intervals whose bounds are preset).
        """

        self.updates += 1
        self.pkts += iupdate.pkts
        self.bytes += iupdate.bytes
        self.payload += iupdate.payload
        self.frags += iupdate.frags

        # Weight the per-category tallies by bytes when -b was given.
        if Options.bytes:
            incr = iupdate.bytes
        else:
            incr = 1

        # For packets, we need to look at the source port, too.
        if not Options.conns:
            # Track a port only if it is privileged (<1024), explicitly
            # requested, or we are not trying to save memory.
            if ( iupdate.src_port < 1024 ) or \
                ( not Ports and not Options.save_mem ) or \
                ( Ports and iupdate.src_port in Ports ) or \
                Options.storeports:
                try:
                    self.ports[iupdate.src_port] += incr
                except KeyError:
                    self.ports[iupdate.src_port] = incr

        if ( iupdate.dst_port < 1024 ) or \
            ( not Ports and not Options.save_mem ) or \
            ( Ports and iupdate.dst_port in Ports ) or \
            Options.storeports:
            try:
                self.ports[iupdate.dst_port] += incr
            except KeyError:
                self.ports[iupdate.dst_port] = incr

        try:
            self.prots[iupdate.prot] += incr
        except KeyError:
            self.prots[iupdate.prot] = incr

        try:
            self.servs[iupdate.service] += incr
        except KeyError:
            self.servs[iupdate.service] = incr

        try:
            self.states[iupdate.state] += incr
        except KeyError:
            self.states[iupdate.state] = incr

        if adjusttime:
            if iupdate.start < self.start:
                self.start = iupdate.start

            if iupdate.end > self.end:
                self.end = iupdate.end

        # Per-address tallies are skipped in save-memory and R-output modes.
        if not Options.save_mem and not Options.R:
            try:
                self.srcs[iupdate.src_ip] += incr
            except KeyError:
                self.srcs[iupdate.src_ip] = incr

            try:
                self.dsts[iupdate.dst_ip] += incr
            except KeyError:
                self.dsts[iupdate.dst_ip] = incr

    def applySampleFactor(self):
        """Scale every counter by the sampling factor to estimate totals."""
        if Options.factor == 1:
            return

        self.bytes *= Options.factor
        self.payload *= Options.factor
        self.pkts *= Options.factor
        self.frags *= Options.factor
        self.updates *= Options.factor

        for i in self.ports.keys():
             self.ports[i] *= Options.factor
        for i in self.prots.keys():
             self.prots[i] *= Options.factor
        for i in self.servs.keys():
             self.servs[i] *= Options.factor
        for i in self.srcs.keys():
             self.srcs[i] *= Options.factor
        for i in self.dsts.keys():
             self.dsts[i] *= Options.factor
        for i in self.states.keys():
             self.states[i] *= Options.factor

    def format(self, conns=False, title=""):
        """Return a multi-line human-readable summary of this interval.

        conns selects connection-style output (connection counts and a
        states column instead of packet rates); with -v, top-N tables of
        ports/sources/destinations/services/protocols are appended.
        """
        def fmt(tag, count, total=-1, sep=" - "):
            # With a total, render count as a percentage of it; otherwise
            # render the raw count via formatVal().
            if total >= 0:
                try:
                    return "%s %5.1f%%%s" % (tag, (float(count) / total) * 100, sep)
                except ZeroDivisionError:
                    return "%s (??%%)%s" % (tag, sep)

            return "%s %s%s" % (tag, formatVal(count), sep)

        s = "\n>== %s === %s - %s\n   - " % (title, isoTime(self.start), isoTime(self.end))

        if not conns:
            # Information for packet traces.
            s += fmt("Bytes", self.bytes) + \
                 fmt("Payload", self.payload) + \
                 fmt("Pkts", self.pkts) + \
                 fmt("Frags", self.frags, self.pkts)

            try:
                mbit = self.bytes * 8 / 1024.0 / 1024.0 / (self.end - self.start)
            except ZeroDivisionError:
                mbit = 0

            s += "MBit/s %8.1f - " % mbit

        else:
            # Information for connection summaries.
            s += fmt("Connections", self.pkts) + \
                 fmt("Payload", self.payload)

        if Options.factor != 1:
            s += "Sampling %.2f%% -" % ( 100.0 / Options.factor )

        if Options.verbose:
            ports = topx(self.ports)
            srcs = topx(self.srcs)
            dsts = topx(self.dsts)
            prots = topx(self.prots)
            servs = topx(self.servs)

            # Abbreviate long service names so they fit their column.
            servs = [ (count, svc.replace("icmp-", "i-").replace("netbios", "nb")) for count, svc in servs ]

            # Default column widths for IP addresses.
            srcwidth = 18
            dstwidth = 18

            # Check all IP addrs to see if column widths need to be increased
            # (due to the presence of long IPv6 addresses).
            src_over = 0
            dst_over = 0
            for i in range(Options.topx):
                for dict in (srcs, dsts):
                    try:
                        item = inet_ntox(dict[i][1])
                    except IndexError:
                        continue

                    # Note: 15 is longest possible IPv4 address.
                    oversize = len(item) - 15
                    if oversize > 0:
                        if dict == srcs:
                            src_over = max(src_over, oversize)
                        elif dict == dsts:
                            dst_over = max(dst_over, oversize)

            # Increase column widths, if necessary.
            srcwidth += src_over
            dstwidth += dst_over

            s += "\n     %-5s        | %-*s        | %-*s        | %-18s | %1s |" \
                % ("Ports", srcwidth, "Sources", dstwidth, "Destinations", "Services", "Protocols")

            if conns:
                s += " States        |"
                states = (topx(self.states), 6)
            else:
                states = ({}, 0)

            s += "\n"

            addrs = []

            for i in range(Options.topx):

                s += "     "

                for (dict, length) in ((ports, 5), (srcs, srcwidth), (dsts, dstwidth), (servs, 11), (prots, 2), states):
                    # NOTE(review): the bare except below pads the column
                    # when a top-N list has fewer than Options.topx entries
                    # (IndexError), but it also swallows any other error.
                    try:
                        item = None
                        if dict == srcs or dict == dsts:
                            item = inet_ntox(dict[i][1])
                            if Options.resolve:
                                # Defer DNS resolution; print "#n" markers
                                # now and the hostnames below.
                                addrs += [dict[i][1]]
                                item += "#%d" % len(addrs)
                        else:
                            item = str(dict[i][1])

                        s += fmt("%-*s" % (length, item), dict[i][0], (Options.bytes and self.bytes or self.pkts), sep=" | ")
                    except:
                        s += " " * length + "        | "

                s += "\n"

            # Footnote list resolving the "#n" address markers, 3 per line.
            if Options.resolve:
                s += "\n        "
                for i in range(1, len(addrs)+1):
                    s +=  "#%d=%s  " % (i, gethostbyaddr(inet_ntox(addrs[i-1])))
                    if i % 3 == 0:
                        s += "\n        "

            s += "\n"


        return s

    def makeRHeader(f, top_ports):
        """Write the column header row of the R output table to *f*."""
        f.write("start end count bytes payload frags srcs dsts prot.tcp prot.udp prot.icmp ")
        f.write("%s " % " ".join(["state.%s" % s.lower() for s in States]))
        f.write("%s " % " ".join(["top.port.%d" % (i+1) for i in range(0,Options.topx)]))
        f.write(" ".join(["port.%d" % i for i in range(0,1024)]))
        if not Options.save_mem:
            f.write(" %s" % " ".join(["port.%d" % p[1] for p in top_ports if p[1] >= 1024]))
        f.write("\n")

    # Old-style staticmethod declaration (pre-decorator syntax).
    makeRHeader = staticmethod(makeRHeader)

    def formatForR(self, f, top_ports):
        """Write this interval as one table row to *f*, matching makeRHeader()."""
        f.write("%.16g %.16g %s %s %s %s " % (self.start, self.end, self.pkts, self.bytes, self.payload, self.frags))
        f.write("%s %s " % (len(self.srcs), len(self.dsts)))
        f.write("%s %s %s " % (self.prots.get(6, 0), self.prots.get(17, 0), self.prots.get(1, 0)))
        f.write("%s " % " ".join([str(self.states.get(i, 0)) for i in States]))
        f.write("%s " % " ".join([str(p[1]) for p in topx(self.ports, True)]))
        f.write(" ".join([str(self.ports.get(i, 0)) for i in range(0,1024)]))
        if not Options.save_mem:
            f.write(" %s" % " ".join([str(self.ports.get(p[1], 0)) for p in top_ports if p[1] >= 1024]))
        f.write("\n")


    def __str__(self):
        # Connection-style summary with no title.
        return self.format(True)

def topx(dict, fill_if_empty=False):
    """Return up to Options.topx (count, value) pairs, largest count first.

    *dict* maps value -> count.  Entries whose value is 0 (e.g. port 0)
    are dropped.  With fill_if_empty, the result is padded with (0, 0)
    pairs so it always has Options.topx entries.
    """
    ranked = sorted(((n, key) for key, n in dict.items()), reverse=True)

    # Filter out zero vals.
    ranked = [(n, key) for (n, key) in ranked if key != 0]

    if fill_if_empty and len(ranked) < Options.topx:
        ranked = ranked + [(0, 0)] * Options.topx

    return ranked[:Options.topx]

def findInterval(time, intervals):
    """Return the Interval in *intervals* covering timestamp *time*.

    The list is extended (or shifted, for timestamps earlier than the
    current base) as needed.  New intervals store start/end relative to
    intervals.start until IntervalList.finish() makes them absolute.
    """

    if intervals.start < 0:
        # First record: anchor the list on an Options.ilen boundary.
        intervals.start = int(time / Options.ilen) * Options.ilen

    i = (time - intervals.start) / Options.ilen
    idx = int(i)

    # Interval may be earlier than current start
    if i < 0:

        if float(idx) != i:
            # int() truncates toward zero; step one slot further back so
            # the shift below covers the full distance.
            # minus 1 since we will multiply by -1
            idx -= 1

        # idx is now the number of slots to prepend.
        idx *= -1

        # Re-base the relative offsets of all existing intervals.
        for j in intervals.ints:
            if j:
                j.start += Options.ilen * idx
                j.end += Options.ilen * idx

        intervals.ints = ([None] * idx) + intervals.ints
        intervals.start = int(time / Options.ilen) * Options.ilen
        idx = 0

    # Interval may be later than current end
    while idx >= len(intervals.ints):
        intervals.ints += [None]

    if not intervals.ints[idx]:
        interv = Interval()
        interv.start = float(idx * Options.ilen)
        interv.end = float((idx + 1) * Options.ilen)
        intervals.ints[idx] = interv
        return interv

    return intervals.ints[idx]

def isoTime(t):
    """Render timestamp *t* as "YYYY-MM-DD-HH-MM-SS" in local time.

    The sentinel values 0 and 1e20 (an Interval that never saw data)
    yield "N/A".
    """
    if t in (0, 1e20):
        return "N/A"
    return time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime(t))

def readPcap(file):
    """Read packets from pcap *file* via ipsumdump and update the global
    statistics (Total/Incoming/Outgoing and, when Options.ilen > 0, the
    corresponding interval lists)."""
    global Total
    global Incoming
    global Outgoing

    # NOTE(review): shell=True builds the command by string interpolation,
    # so a filename containing shell metacharacters would be interpreted
    # by the shell.
    proc = subprocess.Popen("ipsumdump -r %s --timestamp --src --dst --sport --dport --length --protocol --fragment --payload-length -Q" % file, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)

    # Output columns, in the order of the options above:
    # 0=ts 1=src 2=dst 3=sport 4=dport 5=len 6=proto 7=frag 8=payload-len
    # (f[9] presumably comes from -Q and holds the TCP seq, as suggested
    # by the -C option help — TODO confirm against the ipsumdump manual.)
    for line in proc.stdout:
        # The pipe yields bytes under Python 3.
        if using_py3:
            line = line.decode()

        # ipsumdump metadata lines start with '!'.
        if line.startswith("!"):
            continue

        # Random sampling: keep each packet with probability Options.sample.
        if Options.sample > 0 and random.random() > Options.sample:
            continue

        f = line.split()

        if len(f) < 10:
            print("Ignoring corrupt line: '%s'" % line.strip(), file=sys.stderr)
            continue

        try:
            # NOTE: this local shadows the 'time' module inside readPcap().
            time = float(f[0])
        except ValueError:
            print("Ignoring corrupt line: '%s'" % line.strip(), file=sys.stderr)
            continue

        if time < Options.mintime or time > Options.maxtime:
            continue

        # -C/--chema: TCP only, and skip when the seq column is "0".
        if Options.chema:
            if f[6] != "T" or f[9] == "0":
                continue

        if Options.tcp and f[6] != "T":
            continue

        if Options.udp and f[6] != "U":
            continue

        iupdate = IntervalUpdate()
        iupdate.pkts = 1
        iupdate.bytes = int(f[5])
        iupdate.payload = int(f[8])
        iupdate.service = ""  # services are only known for conn summaries

        # Unparsable addresses fall back to the unspecified address.
        try:
            iupdate.src_ip = inet_xton(f[1])
        except socket.error:
            iupdate.src_ip = unspecified_addr(f[1])

        if iupdate.src_ip in ExcludeNets:
            continue

        try:
            iupdate.src_port = int(f[3])
        except:
            iupdate.src_port = 0

        try:
            iupdate.dst_ip = inet_xton(f[2])
        except socket.error:
            iupdate.dst_ip = unspecified_addr(f[2])

        if iupdate.dst_ip in ExcludeNets:
            continue

        try:
            iupdate.dst_port = int(f[4])
        except:
            iupdate.dst_port = 0

        try:
            iupdate.prot = Protocols[f[6]]
        except KeyError:
            iupdate.prot = 0
        iupdate.state = 0
        iupdate.start = time
        iupdate.end = time

        # The fragment column is "." for unfragmented packets.
        if f[7] != ".":
            iupdate.frags = 1
        else:
            iupdate.frags = 0

        # -e: skip traffic that is local on both ends.
        if Options.external:
            if iupdate.src_ip in LocalNetsIntervals and iupdate.dst_ip in LocalNetsIntervals:
                continue

        Total.update(iupdate)

        if Options.ilen > 0:
            interval = findInterval(time, TotalIntervals)
            interval.update(iupdate, adjusttime=False)

        if Options.localnets:
            try:
                # Outgoing if the source address is in a local net ...
                LocalNetsIntervals[iupdate.src_ip].update(iupdate)
                Outgoing.update(iupdate)
                if Options.ilen > 0:
                    interval = findInterval(time, OutgoingIntervals)
                    interval.update(iupdate, adjusttime=False)
            except KeyError:
                try:
                    # ... incoming if only the destination is local.
                    LocalNetsIntervals[iupdate.dst_ip].update(iupdate)
                    Incoming.update(iupdate)
                    if Options.ilen > 0:
                        interval = findInterval(time, IncomingIntervals)
                        interval.update(iupdate, adjusttime=False)
                except KeyError:
                    # Neither endpoint is local: count it and remember a
                    # sample of such address pairs.
                    global NonLocalCount
                    NonLocalCount += 1
                    if NonLocalCount < Options.topx:
                        NonLocalConns[(iupdate.src_ip, iupdate.dst_ip)] = 1

    status = proc.wait()
    if status != 0:
        print("ipsumdump returned exit status of %d" % status, file=sys.stderr)


# Map ipsumdump single-letter and conn.log protocol names to IP protocol numbers.
Protocols = { "T": 6, "tcp": 6, "U": 17, "udp": 17, "I": 1, "icmp": 1 }
# Bro/Zeek connection states, in the fixed column order used for the R output.
States = ["OTH", "REJ", "RSTO", "RSTOS0", "RSTR", "RSTRH", "S0", "S1", "S2", "S3", "SF", "SH", "SHR"]

def readConnSummaries(file):
    """Parse every data line of the conn.log at path *file*.

    Parsing parameters are detected once via getLogInfo(); each
    non-metadata line is handed to parseConnLine(), which updates the
    global statistics.
    """
    # Determine the field separator, unset field string, and field indices
    # for the specified conn.log file.
    (field_sep, unset_field, idx, max_idx_1, is_json, scope_separator) = getLogInfo(file)

    while True:
        try:
            # 'with' guarantees the handle is closed even if parsing raises.
            with open(file) as fin:
                for line in fin:
                    # Skip log metadata lines.
                    if line[0] != "#":
                        parseConnLine(line, field_sep, unset_field, idx, max_idx_1, is_json, scope_separator)

        except IOError as e:
            # Retry on interrupted reads; note this restarts the whole file.
            if e.errno == errno.EINTR or e.errno == errno.EAGAIN:
                continue

            print(e, file=sys.stderr)

        return

def getLogInfo(file):
    """Inspect the conn.log at path *file* and return its parsing parameters.

    Returns a tuple (field_sep, unset_field, idx, max_idx_1, is_json,
    scope_separator): the column separator, the marker used for unset
    fields, a map from the field names this script needs to column
    indices, the largest used index plus one, whether the log is
    JSON-formatted, and the character separating "id" from "orig_h" etc.
    in field names.
    """
    is_json = False
    scope_separator = "."

    # conn_version 0 means "guess the format from the file contents".
    if Options.conn_version == 0:
        with open(file, "r") as fin:
            line = fin.readline()

            # Guess the conn.log version by checking for a metadata line.
            if line[0] == "#":
                Options.conn_version = 2
            elif line[0] == "{":
                Options.conn_version = 2
                is_json = True
                f = loads(line)
                # Detect a non-default field scope separator
                # (e.g. "id_orig_h" instead of "id.orig_h").
                if "id.orig_h" not in f:
                    pattern = re.compile("id(.)orig_h$")
                    for field in f:
                        m = pattern.match(field)
                        if m:
                            scope_separator = m.group(1)
            else:
                # Guess the conn.log version by looking at the number of
                # fields we have.
                m = line.split()

                if len(m) < 15:
                    Options.conn_version = 1
                else:
                    Options.conn_version = 2

    if Options.conn_version == 1:
        # Field names needed by this script, listed here in same order as
        # found in bro version 1.x conn.log.
        field_names = ("ts", "duration", "id.orig_h", "id.resp_h", "service", "id.orig_p", "id.resp_p", "proto", "orig_bytes", "resp_bytes", "conn_state")

        idx = {}
        for field in field_names:
            idx[field] = len(idx)

        # max_idx_1 is max. index value plus 1
        max_idx_1 = len(field_names)

        field_sep = " "
        unset_field = "?"

        return (field_sep, unset_field, idx, max_idx_1, is_json, scope_separator)

    # Field names needed by this script, listed here in same order as
    # found in conn.log.
    field_names = ("ts", "uid", "id.orig_h", "id.orig_p", "id.resp_h", "id.resp_p", "proto", "service", "duration", "orig_bytes", "resp_bytes", "conn_state")

    # Version 2.x ASCII defaults; overridden by the #separator and
    # #unset_field metadata lines parsed below.
    field_sep = "\t"
    unset_field = "-"
    idx = {}

    with open(file, "r") as fin:
        firstline = True

        for line in fin:
            if firstline:
                firstline = False
                if line[0] == "{":
                    # JSON log: only the scope separator needs detecting.
                    is_json = True
                    f = loads(line)
                    if "id.orig_h" not in f:
                        pattern = re.compile("id(.)orig_h$")
                        for field in f:
                            m = pattern.match(field)
                            if m:
                                scope_separator = m.group(1)
                    break

            # Metadata lines all start with '#'; stop at the first data line.
            if line[0] != "#":
                break

            # Remove trailing '\n' so that it's not included in last item of
            # results from split().
            if line[-1] == "\n":
                line = line[:-1]

            if line.startswith("#separator"):
                try:
                    field_sep = line.split()[1]
                    # Separators are written escaped, e.g. "\x09" for tab.
                    if field_sep.startswith("\\x"):
                        field_sep = chr(int(field_sep[2:], 16))
                except (IndexError, ValueError):
                    # If no value found, then just use default.
                    print("Ignoring bad '#separator' line", file=sys.stderr)

            elif line.startswith("#unset_field"):
                try:
                    unset_field = line.split(field_sep)[1]
                except IndexError:
                    # If no value found, then just use default.
                    print("Ignoring bad '#unset_field' line", file=sys.stderr)

            elif line.startswith("#fields"):
                fields = line.split(field_sep)[1:]

                if "id.orig_h" not in fields:
                    # Either the "#fields" line is corrupt, or we're using a
                    # non-default field scope separator.
                    pattern = re.compile("id(.)orig_h$")
                    for field in fields:
                        m = pattern.match(field)
                        if m:
                            scope_separator = m.group(1)

                max_idx_1 = 0
                idx = {}
                for field in field_names:
                    try:
                        # Use original field name in "idx" (even if there is a
                        # non-default field scope separator).
                        idx[field] = fields.index(field.replace(".", scope_separator))
                    except ValueError as err:
                        # If any field is missing, then just use defaults.
                        idx = {}
                        print("Ignoring bad '#fields' line: %s" % err, file=sys.stderr)
                        break

                    max_idx_1 = max(max_idx_1, idx[field])

                max_idx_1 += 1

    # If no fields metadata was found, then just use default values.
    if not idx:
        # max_idx_1 is max. index value plus 1
        max_idx_1 = len(field_names)

        for field in field_names:
            idx[field] = len(idx)

    return (field_sep, unset_field, idx, max_idx_1, is_json, scope_separator)

def parseConnLine(line, field_sep, unset_field, idx, max_idx_1, is_json, scope_separator):
    """Parse one conn.log data line and fold it into the global statistics.

    The parsing parameters (field separator, unset-field marker, column
    index map, JSON flag, field scope separator) come from getLogInfo().
    Lines failing the sampling, protocol, time-range, or exclude-net
    filters are silently skipped; corrupt lines are reported to stderr.
    """
    global Total, Incoming, Outgoing, LastOutputTime, BaseTime

    # Random sampling: keep each connection with probability Options.sample.
    if Options.sample > 0 and random.random() > Options.sample:
        return

    # Remove trailing '\n' so that it's not included in last item of
    # results from split().
    if line[-1] == "\n":
        line = line[:-1]

    if is_json:
        f = loads(line)
    else:
        f = line.split(field_sep, max_idx_1)
        if len(f) < max_idx_1:
            print("Ignoring corrupt line: %s" % line, file=sys.stderr)
            return

    if is_json:
        proto_val = f["proto"]
    else:
        proto_val = f[idx["proto"]]

    if Options.tcp and proto_val != "tcp":
        return

    if Options.udp and proto_val != "udp":
        return

    try:
        if is_json:
            time = float(f["ts"])
        else:
            time = float(f[idx["ts"]])
    except ValueError:
        print("Invalid starting time on line: %s" % line, file=sys.stderr)
        return

    if is_json:
        try:
            duration_str = f["duration"]
        except KeyError:
            duration_str = unset_field
    else:
        duration_str = f[idx["duration"]]

    if duration_str != unset_field:
        try:
            duration = float(duration_str)
        except ValueError:
            # The default unset/empty field string can be changed from "-"
            # and in that case, it's hard to know if this exception is due
            # to that or because we're looking at the wrong column entirely,
            # so just print an error and continue with the assumption of
            # an unset/empty duration column.
            print("Invalid duration on line: %s" % line, file=sys.stderr)
            duration = 0
    else:
        duration = 0

    if time < Options.mintime or (time + duration) > Options.maxtime:
        return

    if not BaseTime:
        BaseTime = time
        LastOutputTime = time

    if time - LastOutputTime > 3600:
        # print("%d hours processed" % int((time - BaseTime) / 3600), file=sys.stderr)
        LastOutputTime = time

    if is_json:
        try:
            orig_bytes_str = f["orig_bytes"]
        except KeyError:
            orig_bytes_str = unset_field

        try:
            resp_bytes_str = f["resp_bytes"]
        except KeyError:
            resp_bytes_str = unset_field
    else:
        orig_bytes_str = f[idx["orig_bytes"]]
        resp_bytes_str = f[idx["resp_bytes"]]

    # Unset/invalid byte counts are treated as zero.
    try:
        bytes_orig = int(orig_bytes_str)
    except ValueError:
        bytes_orig = 0

    try:
        bytes_resp = int(resp_bytes_str)
    except ValueError:
        bytes_resp = 0

    iupdate = IntervalUpdate()
    iupdate.pkts = 1 # no. connections
    iupdate.bytes = bytes_orig + bytes_resp

    try:
        if is_json:
            iupdate.src_ip = inet_xton(f["id" + scope_separator + "orig_h"])
            iupdate.src_port = int(f["id" + scope_separator + "orig_p"])
            iupdate.dst_ip = inet_xton(f["id" + scope_separator + "resp_h"])
            iupdate.dst_port = int(f["id" + scope_separator + "resp_p"])
        else:
            iupdate.src_ip = inet_xton(f[idx["id.orig_h"]])
            iupdate.src_port = int(f[idx["id.orig_p"]])
            iupdate.dst_ip = inet_xton(f[idx["id.resp_h"]])
            iupdate.dst_port = int(f[idx["id.resp_p"]])

        if iupdate.src_ip in ExcludeNets:
            return
        if iupdate.dst_ip in ExcludeNets:
            return
        iupdate.prot = Protocols[proto_val]

    except (KeyError, ValueError):
        print("Ignoring corrupt line: %s" % line, file=sys.stderr)
        return

    try:
        if is_json:
            iupdate.service = f["service"]
        else:
            iupdate.service = f[idx["service"]]

        # Strip a trailing '?' from the service name.
        if iupdate.service[-1] == "?":
            iupdate.service = iupdate.service[:-1]
    except (KeyError, IndexError):
        iupdate.service = unset_field

    iupdate.frags = 0
    if is_json:
        iupdate.state = f["conn_state"]
    else:
        iupdate.state = f[idx["conn_state"]]
    iupdate.start = time
    iupdate.end = time + duration

    payload_orig = bytes_orig
    payload_resp = bytes_resp

    # Discard payload figures implying more than 700 Mbps in one
    # direction; these are bogus values produced by a Bro bug.
    if duration:
        bytes_to_mbps = 8 / (1024 * 1024 * duration)

        if payload_orig * bytes_to_mbps > 700:
            # Bandwidth exceed due to Bro bug.
            if Options.conn_version == 2:
                if is_json:
                    uid = f["uid"]
                else:
                    uid = f[idx["uid"]]
                print("UID %s originator exceeds bandwidth" % uid, file=sys.stderr)
            else:
                print("%.6f originator exceeds bandwidth" % time, file=sys.stderr)
            payload_orig = 0

        if payload_resp * bytes_to_mbps > 700:
            # Bandwidth exceed due to Bro bug.  (Fixed copy-paste bug:
            # this branch previously reported "originator".)
            if Options.conn_version == 2:
                if is_json:
                    uid = f["uid"]
                else:
                    uid = f[idx["uid"]]
                print("UID %s responder exceeds bandwidth" % uid, file=sys.stderr)
            else:
                print("%.6f responder exceeds bandwidth" % time, file=sys.stderr)
            payload_resp = 0

    iupdate.payload = payload_orig + payload_resp

    # -e: skip traffic that is local on both ends.
    if Options.external:
        if iupdate.src_ip in LocalNetsIntervals and iupdate.dst_ip in LocalNetsIntervals:
            return

    Total.update(iupdate)

    if Options.ilen > 0:
        interval = findInterval(time, TotalIntervals)
        interval.update(iupdate, adjusttime=False)

    if Options.localnets:

        try:
            # Outgoing if the source address is in a local net ...
            LocalNetsIntervals[iupdate.src_ip].update(iupdate)
            Outgoing.update(iupdate)
            if Options.ilen > 0:
                interval = findInterval(time, OutgoingIntervals)
                interval.update(iupdate, adjusttime=False)
        except KeyError:
            try:
                # ... incoming if only the destination is local.
                LocalNetsIntervals[iupdate.dst_ip].update(iupdate)
                Incoming.update(iupdate)
                if Options.ilen > 0:
                    interval = findInterval(time, IncomingIntervals)
                    interval.update(iupdate, adjusttime=False)
            except KeyError:
                # Neither endpoint is local: count it and remember a
                # sample of such address pairs.
                global NonLocalCount
                NonLocalCount += 1
                if NonLocalCount < Options.topx:
                    NonLocalConns[(iupdate.src_ip, iupdate.dst_ip)] = 1

# Reverse-DNS results keyed by IP address string (see gethostbyaddr()).
Cache = {}

def gethostbyaddr( ip, timeout = 5, default = "<???>" ):
    """Reverse-resolve *ip* with a hard *timeout*, caching results.

    The lookup runs in a forked child so that a hung resolver can be
    killed via SIGALRM; on timeout or lookup failure *default* is
    returned (and cached).
    """

    try:
        return Cache[ip]
    except LookupError:
        pass

    host = default
    ( pin, pout ) = os.pipe()

    pid = os.fork()

    if not pid:
        # Child: resolve and send the hostname back over the pipe.
        os.close( pin )
        try:
            host = socket.gethostbyaddr( ip )[0]
        except socket.herror:
            pass

        if using_py3:
            # os.write() needs bytes under Python 3.
            host = host.encode()

        os.write( pout, host )
        # Exit status 127 signals the parent that the pipe holds a result.
        posix._exit(127)

    # Parent: arm an alarm that kills the child after *timeout* seconds.
    os.close( pout )

    signal.signal( signal.SIGALRM, lambda sig, frame: os.kill( pid, signal.SIGKILL ) )
    signal.alarm( timeout )

    try:
        childpid, status = os.waitpid(pid, 0)

        # Only trust the pipe contents if the child finished normally.
        if os.WIFEXITED(status) and os.WEXITSTATUS(status) == 127:
            host = os.read(pin, 8192)
            if using_py3:
                host = host.decode()
    except OSError:
        # If the child process is killed while waitpid() is waiting, then
        # only Python 2 (not Python 3) raises OSError.
        pass

    signal.alarm( 0 )

    os.close( pin )

    Cache[ip] = host

    return host

def formatVal(val):
    """Return *val* as a short string scaled to g/m/k units.

    Values below 1 (not expected in practice) are returned unmodified.
    """
    scales = ((" ", "", 1e0), ("", "k", 1e3), ("", "m", 1e6), ("", "g", 1e9))
    for prefix, unit, factor in reversed(scales):
        if val >= factor:
            return "%s%3.1f%s" % (prefix, val / factor, unit)
    return val # Should not happen

def readNetworks(file):
    """Read a networks file and return a list of (CIDR, description) tuples.

    Blank lines and lines starting with '#' are ignored.  The first
    whitespace-separated field of each line is the CIDR; the rest of the
    line (possibly empty) is the description.
    """

    nets = []

    # 'with' ensures the file handle is closed even on a parse error.
    with open(file) as fin:
        for line in fin:
            line = line.strip()
            if not line or line.startswith("#"):
                continue

            fields = line.split()
            nets.append((fields[0], " ".join(fields[1:])))

    return nets

def inet_xton(ipstr):
    """Convert an IPv4 or IPv6 address string to packed binary form.

    Any string containing ':' is treated as IPv6; raises socket.error
    (OSError) on malformed input.
    """
    if ':' in ipstr:
        return socket.inet_pton(socket.AF_INET6, ipstr)
    return socket.inet_pton(socket.AF_INET, ipstr)

def inet_ntox(ipaddr):
    """Convert a packed binary address back to presentation format.

    A 4-byte value is treated as IPv4; any other length as IPv6.
    """
    if len(ipaddr) == 4:
        return socket.inet_ntop(socket.AF_INET, ipaddr)
    return socket.inet_ntop(socket.AF_INET6, ipaddr)

def unspecified_addr(ipstr):
    """Return the packed wildcard ("any") address matching the family
    of ipstr: '::' if ipstr looks like IPv6 (contains a colon),
    '0.0.0.0' otherwise.
    """
    wildcard = "::" if ':' in ipstr else "0.0.0.0"
    return inet_xton(wildcard)


####### Main

# Check which version of Python is running
if sys.version_info[0] >= 3:
    using_py3 = True
else:
    using_py3 = False

# Overall traffic counters, plus per-interval lists of the same.
Total = Interval()
Incoming = Interval()
Outgoing = Interval()

TotalIntervals = IntervalList()
IncomingIntervals = IntervalList()
OutgoingIntervals = IntervalList()

# Time bookkeeping, filled in while reading the input (presumably the
# first timestamp seen and the last interval-output time — set by the
# reader functions outside this excerpt; confirm against readPcap).
BaseTime = None
LastOutputTime = None

# Local-network state (-l): CIDR -> (description, Interval), a prefix
# tree of the same Intervals for fast address lookups, and a record of
# connections with no local endpoint.
LocalNets = {}
LocalNetsIntervals = SubnetTree.SubnetTree(True)
NonLocalConns = {}
NonLocalCount = 0

# Set of ports to restrict the analysis to (-p), or None for all.
Ports = None

# Networks excluded from the analysis (-E).
ExcludeNets = SubnetTree.SubnetTree(True)

optparser = optparse.OptionParser(usage="%prog [options] <pcap-file>|<conn-summaries>", version=VERSION)
optparser.add_option("-b", "--bytes", action="store_true", dest="bytes", default=False,
                     help="count fractions in terms of bytes rather than packets/connections")
optparser.add_option("-c", "--conn-summaries", action="store_true", dest="conns", default=False,
                     help="input file contains Bro connection summaries")
optparser.add_option("--conn-version", action="store", type="int", dest="conn_version", default=0,
                     help="when used with -c, specify '1' for use with Bro version 1.x connection logs, or '2' for use with Bro 2.x format. '0' tries to guess the format")
optparser.add_option("-C", "--chema", action="store_true", dest="chema", default=False,
                     help="for packets: include only TCP, ignore when seq==0")
optparser.add_option("-e", "--external", action="store_true", dest="external", default=False,
                     help="ignore strictly internal traffic")
optparser.add_option("-E", "--exclude-nets", action="store", type="string", dest="excludenets", default=None,
                     help="excludes CIDRs in file from analysis")
optparser.add_option("-i", "--intervals", action="store", type="string", dest="ilen", default="0",
                     help="create summaries for time intervals of given length (seconds, or use suffix of 'h' for hours, or 'm' for minutes)")
optparser.add_option("-l", "--local-nets", action="store", type="string", dest="localnets", default=None,
                     help="differentiate in/out based on CIDRs in file")
optparser.add_option("-n", "--topn", action="store", type="int", dest="topx", default=10,
                     help="show top <n>")
optparser.add_option("-p", "--ports", action="store", type="string", dest="ports", default=None,
                     help="include only ports listed in file")
optparser.add_option("-P", "--write-ports", action="store", type="string", dest="storeports", default=None,
                     help="write top total/incoming/outgoing ports into file")
optparser.add_option("-r", "--resolve-host-names", action="store_true", dest="resolve", default=False,
                     help="resolve host names")
optparser.add_option("-R", "--R", action="store", type="string", dest="R", default=None, metavar="tag",
                     help="write output suitable for R into files <tag.*>")
optparser.add_option("-s", "--sample-factor", action="store", type="int", dest="factor", default=1,
                     help="sample factor of input")
optparser.add_option("-S", "--do-sample", action="store", type="float", dest="sample", default=-1.0,
                     help="sample input with probability (0.0 < prob < 1.0)")
optparser.add_option("-m", "--save-mem", action="store_true", dest="save_mem", default=False,
                     help="do not make memory-expensive statistics")
optparser.add_option("-t", "--tcp", action="store_true", dest="tcp", default=False,
                     help="include only TCP")
optparser.add_option("-u", "--udp", action="store_true", dest="udp", default=False,
                     help="include only UDP")
optparser.add_option("-U", "--min-time", action="store", type="string", dest="mintime", default=None,
                     help="minimum time in ISO format (e.g. 2005-12-31-23-59-00)")
optparser.add_option("-v", "--verbose", action="store_true", dest="verbose", default=False,
                     help="show top-n for every interval")
optparser.add_option("-V", "--max-time", action="store", type="string", dest="maxtime", default=None,
                     help="maximum time in ISO format")

(Options, args) = optparser.parse_args()

if len(args) > 2:
    optparser.error("Wrong number of arguments")

file = "-"

if len(args) > 0:
    file = args[0]

if Options.external and not Options.localnets:
    print("Need -l for -e.", file=sys.stderr)
    sys.exit(1)

if Options.topx < 0:
    print("Top-n value cannot be negative", file=sys.stderr)
    sys.exit(1)

# If reading pcap traces, then ipsumdump is required.
if not Options.conns:
    proc = subprocess.Popen("ipsumdump -v", shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
    if proc.wait() != 0:
        print("Can't read pcap trace: 'ipsumdump' is required.", file=sys.stderr)
        sys.exit(1)

# Make per-interval summaries.
if Options.ilen:
    if Options.ilen.endswith("m"):
        Options.ilen = int(Options.ilen[:-1]) * 60
    elif Options.ilen.endswith("h"):
        Options.ilen = int(Options.ilen[:-1]) * 60 * 60
    else:
        Options.ilen = int(Options.ilen)

    if Options.ilen < 0:
        print("Interval length cannot be negative", file=sys.stderr)
        sys.exit(1)


# Read local networks.
if Options.localnets:

    for (net, txt) in readNetworks(Options.localnets):
        try:
            i = Interval()
            LocalNetsIntervals[net] = i
            LocalNets[net] = (txt, i)
        except KeyError:
            print("Can't parse local network '%s'" % net, file=sys.stderr)

# Read networks to exclude.
if Options.excludenets:
    for (net, txt) in readNetworks(Options.excludenets):
        try:
            ExcludeNets[net] = txt
        except KeyError:
            print("Can't parse exclude network '%s'" % net, file=sys.stderr)

# Read ports file.
if Options.ports:
    Ports = {}
    for line in open(Options.ports):
        Ports[int(line.strip())] = 1

# Parse time-range if given.
if Options.mintime:
    Options.mintime = time.mktime(time.strptime(Options.mintime, "%Y-%m-%d-%H-%M-%S"))
else:
    Options.mintime = 0

if Options.maxtime:
    Options.maxtime = time.mktime(time.strptime(Options.maxtime, "%Y-%m-%d-%H-%M-%S"))
else:
    Options.maxtime = 1e20

if Options.factor <= 0:
    print("Sample factor must be > 0", file=sys.stderr)
    sys.exit(1)

if Options.sample > 0:
    if Options.sample > 1.0:
        print("Sample probability cannot be > 1", file=sys.stderr)
        sys.exit(1)
    Options.factor = 1.0 / Options.sample

if file == "-":
    file = "/dev/stdin"

# Read the input; Ctrl-C just stops reading and proceeds to the report.
try:
    if Options.conns:
        readConnSummaries(file)
    else:
        readPcap(file)
except KeyboardInterrupt:
    pass

# Fix up interval timestamps and scale all counters by the sample
# factor.
TotalIntervals.finish()
IncomingIntervals.finish()
OutgoingIntervals.finish()

Total.applySampleFactor()
Incoming.applySampleFactor()
Outgoing.applySampleFactor()

# Union of the top ports across total/incoming/outgoing, de-duplicated
# by port number and sorted by it (x = (count, port)).
unique = {}
for (count, port) in topx(Total.ports) + topx(Incoming.ports) + topx(Outgoing.ports):
    unique[port] = (count, port)

top_ports = sorted(unique.values(), key=lambda x: x[1])

# -P: dump the selected port numbers, one per line.
if Options.storeports:
    f = open(Options.storeports, "w")
    for p in top_ports:
        f.write("%s\n" % p[1])
    f.close()

# -R: write machine-readable output for R and exit without producing
# the human-readable report.
if Options.R:
    file = open(Options.R + ".dat", "w")

    file.write("tag ")
    Interval.makeRHeader(file, top_ports)
    file.write("total ")
    Total.formatForR(file, top_ports)

    file.write("incoming ")
    Incoming.formatForR(file, top_ports)
    file.write("outgoing ")
    Outgoing.formatForR(file, top_ports)

    # One row per local network that saw any traffic.
    for (net, data) in LocalNets.items():

        (txt, i) = data

        if i.updates:
            file.write("%s " % net.replace(" ", "_"))
            # Shift the per-net interval from trace-relative to
            # absolute time, mirroring IntervalList.finish().
            i.start += TotalIntervals.start
            i.end += TotalIntervals.start
            i.applySampleFactor()
            i.formatForR(file, top_ports)

    file.close()

    TotalIntervals.writeR(Options.R + ".total.dat", top_ports)
    IncomingIntervals.writeR(Options.R + ".incoming.dat", top_ports)
    OutgoingIntervals.writeR(Options.R + ".outgoing.dat", top_ports)

    sys.exit(0)

# Human-readable report: per-interval summaries first.
for i in TotalIntervals.ints:
    if i:
        print(i.format(conns=Options.conns))

# Force verbose for the remaining summaries (presumably so that
# Interval.format() includes the full top-n breakdowns — confirm
# against format()'s definition).
Options.verbose = True

print(Total.format(conns=Options.conns, title="Total"))

# NOTE(review): 'locals' shadows the builtin; kept as-is here.
locals = list(LocalNets.keys())

# Scale each active per-network counter by the sample factor.
for net in locals:
    (txt, i) = LocalNets[net]
    if i.updates:
        i.applySampleFactor()

if locals:

    type = "packets"
    if Options.conns:
        type = "connections"

    # Rank local networks by packet/connection count, descending.
    locals.sort(key=lambda x: LocalNets[x][1].pkts, reverse=True)

    print("\n>== Top %d local networks by number of %s\n" % (Options.topx, type))

    for i in range(min(len(locals), Options.topx)):
        print("    %2d %5s  %-16s %s " % (i+1, formatVal(LocalNets[locals[i]][1].pkts), locals[i], LocalNets[locals[i]][0]))
    print()

    # Connections where neither endpoint matched a local network.
    if len(NonLocalConns):
        print("\n>== %d %s did not have any local address. Here are the first %d:\n" % (NonLocalCount, type, Options.topx))

        for (src,dst) in sorted(NonLocalConns.keys()):
            print("    %s <-> %s" % (inet_ntox(src), inet_ntox(dst)))

# In/out breakdowns are only meaningful when -l was given.
if Options.localnets:
    print(Incoming.format(conns=Options.conns, title="Incoming"))
    print(Outgoing.format(conns=Options.conns, title="Outgoing"))

# Per-network detail summaries.
for net in locals:
    (txt, i) = LocalNets[net]

    if i.updates:
        print(i.format(conns=Options.conns, title=net + " " + txt))

print("First: %16s (%.6f) Last: %s %.6f" % (isoTime(Total.start), Total.start, isoTime(Total.end), Total.end))