/* 
   socket handling routines
   Copyright (C) 1998, 1999, 2000, Joe Orton <joe@orton.demon.co.uk>, 
   except where otherwise indicated.

   Portions are:
   Copyright (C) 1999-2000 Tommi Komulainen <Tommi.Komulainen@iki.fi>
   Originally under GPL in Mutt, http://www.mutt.org/
   Relicensed under LGPL for neon, http://www.webdav.org/neon/

   This library is free software; you can redistribute it and/or
   modify it under the terms of the GNU Library General Public
   License as published by the Free Software Foundation; either
   version 2 of the License, or (at your option) any later version.
   
   This library is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
   Library General Public License for more details.

   You should have received a copy of the GNU Library General Public
   License along with this library; if not, write to the Free
   Software Foundation, Inc., 59 Temple Place - Suite 330, Boston,
   MA 02111-1307, USA

   The sock_readline() function is:

   Copyright (c) 1999 Eric S. Raymond

   Permission is hereby granted, free of charge, to any person
   obtaining a copy of this software and associated documentation
   files (the "Software"), to deal in the Software without
   restriction, including without limitation the rights to use, copy,
   modify, merge, publish, distribute, sublicense, and/or sell copies
   of the Software, and to permit persons to whom the Software is
   furnished to do so, subject to the following conditions:

   The above copyright notice and this permission notice shall be
   included in all copies or substantial portions of the Software.

   THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
   EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
   MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
   NONINFRINGEMENT.  IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
   HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
   WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
   OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
   DEALINGS IN THE SOFTWARE.

   $Id: socket.c,v 1.25.2.4 2000/10/26 21:28:45 joe Exp $ 
*/

#include <config.h>

#include <sys/types.h>
#ifdef HAVE_SYS_TIME_H
#include <sys/time.h>
#endif
#include <sys/socket.h>
#include <sys/stat.h>

#ifdef HAVE_SYS_SELECT_H
#include <sys/select.h>
#endif

#include <netinet/in.h>
#include <arpa/inet.h>
#include <netdb.h>

#include <errno.h>

#include <fcntl.h>
#include <stdio.h>
#ifdef HAVE_STRING_H
#include <string.h>
#endif
#ifdef HAVE_STRINGS_H
#include <strings.h>
#endif 
#ifdef HAVE_STDLIB_H
#include <stdlib.h>
#endif /* HAVE_STDLIB_H */
#ifdef HAVE_UNISTD_H
#include <unistd.h>
#endif /* HAVE_UNISTD_H */

#include "neon_i18n.h"
#include "string_utils.h"
#include "http_utils.h"
#include "nsocket.h"
#include "ne_alloc.h"

/* File-wide callbacks installed by sock_register_progress() and
 * sock_register_notify().  NULL means no callback is registered. */
static sock_progress progress_cb = NULL;
static sock_notify notify_cb = NULL;
/* Opaque userdata pointers passed back as the first argument of
 * progress_cb and notify_cb respectively. */
static void *progress_ud, *notify_ud;

#ifdef ENABLE_SSL
#include <openssl/ssl.h>
#include <openssl/err.h>

/* Whilst the OpenSSL interface *looks* like it is not thread-safe, it
 * appears to do horrible gymnastics to be thread-safe internally.
 *
 * ERROR_SSL_STRING expands to the reason string for the code returned
 * by a fresh call to ERR_get_error().
 * NOTE(review): per the OpenSSL documentation ERR_get_error() removes
 * the entry it returns from the error queue, so every expansion of
 * this macro consumes one queued error -- do not combine it with a
 * separate ERR_get_error() test on the same error. */
#define ERROR_SSL_STRING (ERR_reason_error_string(ERR_get_error()))

#endif

/* Per-connection state: the descriptor, the last error string, and
 * (in ENABLE_SSL builds) the OpenSSL objects layered over the fd. */
struct nsocket_s {
    int fd; /* underlying descriptor; closed by sock_close() */
    const char *error; /* Store error string here; it is assigned from
			* strerror(), OpenSSL reason strings or literals
			* and is never freed in this file */
#ifdef ENABLE_SSL
    SSL *ssl; /* assigned by sock_make_secure(); tested as a flag for
	       * "this socket speaks SSL" by the I/O routines */
    SSL_CTX *ssl_ctx; /* context that ssl was created from */
    nssl_context *ctx; /* not referenced by the code visible in this file */
#endif
};

/* Caller-supplied SSL preferences, consulted by sock_make_secure(). */
struct nssl_context_s {
    nssl_accept cert_accept; /* not referenced by the code visible in this file */
    void *accept_ud; /* userdata for callback */
    /* Protocol-version switches; set by the sock_disable_*() functions
     * and mapped onto SSL_OP_NO_* options in sock_make_secure(). */
    unsigned int disable_tlsv1:1;
    unsigned int disable_sslv2:1;
    unsigned int disable_sslv3:1;
    const char *cert_file; /* not referenced by the code visible in this file */
};
    
/* Install CB as the progress callback invoked by sock_call_progress().
 * USERDATA is stored and handed back to CB as its first argument.
 * Passing a NULL cb disables progress reporting. */
void sock_register_progress(sock_progress cb, void *userdata)
{
    progress_ud = userdata;
    progress_cb = cb;
}

/* Install CB as the notification callback.  In this file it is invoked
 * from sock_name_lookup(), sock_connect_u() and sock_make_secure().
 * USERDATA is stored and handed back to CB as its first argument. */
void sock_register_notify(sock_notify cb, void *userdata)
{
    notify_ud = userdata;
    notify_cb = cb;
}

/* Report PROGRESS out of TOTAL to the registered progress callback.
 * Does nothing when no callback has been registered. */
void sock_call_progress(off_t progress, off_t total)
{
    if (progress_cb == NULL) {
	return;
    }
    progress_cb(progress_ud, progress, total);
}

/* One-time library initialisation.  In ENABLE_SSL builds this loads
 * the OpenSSL error strings and initialises the SSL library; without
 * SSL support it does nothing.  Always returns 0. */
int sock_init(void)
{
#ifdef ENABLE_SSL

    SSL_load_error_strings();
    SSL_library_init();

    DEBUG(DEBUG_SOCKET, "Initialized SSL.\n");
#endif

    return 0;
}

/* sock_read is read() with a timeout of SOCKET_READ_TIMEOUT.
 *
 * Waits (via sock_block) until the socket is readable, then reads up
 * to COUNT bytes into BUFFER, using SSL_read() when the socket has
 * been made secure and read() otherwise.
 *
 * Returns:
 *    the number of bytes read (0 means the read itself returned 0)
 *    the non-zero code from sock_block() if waiting failed or timed out
 *    SOCK_ERROR if the read failed; sock->error then describes why
 */
int sock_read(nsocket *sock, char *buffer, size_t count) 
{
    int ret;

    ret = sock_block(sock, SOCKET_READ_TIMEOUT);
    if (ret != 0) {
	return ret;
    }

    /* Got data */
#ifdef ENABLE_SSL
    if (sock->ssl) {
	do {
	    ret = SSL_read(sock->ssl, buffer, count);
	} while (ret == -1 && errno == EINTR);
	if (ret < 0) {
	    /* An SSL failure is described by the OpenSSL error queue,
	     * not by errno; fall back to errno only when the queue is
	     * empty (e.g. a failure of the underlying system call). */
	    int saved_errno = errno;
	    sock->error = ERROR_SSL_STRING;
	    if (sock->error == NULL) {
		sock->error = strerror(saved_errno);
	    }
	    ret = SOCK_ERROR;
	}
	return ret;
    }
#endif

    do {
	ret = read(sock->fd, buffer, count);
    } while (ret == -1 && errno == EINTR);
    if (ret < 0) {
	sock->error = strerror(errno);
	ret = SOCK_ERROR;
    }
    return ret;
}

/* sock_peek is recv(...,MSG_PEEK) with a timeout of SOCKET_READ_TIMEOUT:
 * up to COUNT pending bytes are copied into BUFFER without being
 * consumed.  SSL sockets use SSL_peek() instead of recv().
 *
 * Returns:
 *    the number of bytes peeked (> 0)
 *    SOCK_CLOSED if the peer has closed the connection
 *    the negative code from sock_block() if waiting failed or timed out
 *    SOCK_ERROR on failure; sock->error then describes why
 */
int sock_peek(nsocket *sock, char *buffer, size_t count) 
{
    int ret;

    ret = sock_block(sock, SOCKET_READ_TIMEOUT);
    if (ret < 0) {
	return ret;
    }
    /* Got data */
#ifdef ENABLE_SSL
    if (sock->ssl) {
	ret = SSL_peek(sock->ssl, buffer, count);
	if (ret < 0) {
	    /* Report the OpenSSL error rather than a stale errno; use
	     * errno only when the OpenSSL queue has nothing to say. */
	    int saved_errno = errno;
	    sock->error = ERROR_SSL_STRING;
	    if (sock->error == NULL) {
		sock->error = strerror(saved_errno);
	    }
	    return SOCK_ERROR;
	}
	/* TODO: This is the fetchmail fix as in sock_readline.
	 * Do we really need it? */
	if (ret == 0) {
	    unsigned long sslerr;

	    if (sock->ssl->shutdown) {
		return SOCK_CLOSED;
	    }
	    /* ERR_get_error() removes the code it returns, so fetch it
	     * exactly once and build the message from that same code. */
	    sslerr = ERR_get_error();
	    if (sslerr != 0) {
		sock->error = ERR_reason_error_string(sslerr);
		return SOCK_ERROR;
	    }
	    /* No data and no error: treat like recv() returning zero. */
	    return SOCK_CLOSED;
	}
	return ret;
    }
#endif
    do {
	ret = recv(sock->fd, buffer, count, MSG_PEEK);
    } while (ret == -1 && errno == EINTR);
    /* According to the Single Unix Spec, recv() will return
     * zero if the socket has been closed the other end. */
    if (ret == 0) {
	ret = SOCK_CLOSED;
    } else if (ret < 0) {
	sock->error = strerror(errno);
	ret = SOCK_ERROR;
    } 
    return ret;
}

/* Blocks waiting for read input on the given socket for the given time
 * (TIMEOUT is in whole seconds).
 *
 * For SSL sockets, data already buffered inside OpenSSL counts as
 * "arrived" and is reported without touching the descriptor.
 *
 * Returns:
 *    0 if data arrived
 *    SOCK_TIMEOUT if data did not arrive before timeout
 *    SOCK_ERROR on error (sock->error is set)
 */
int sock_block(nsocket *sock, int timeout) 
{
    struct timeval tv;
    fd_set fds;
    int ret;

#ifdef ENABLE_SSL
    if (sock->ssl) {
	/* There may be data already available in the 
	 * SSL buffers */
	if (SSL_pending(sock->ssl)) {
	    return 0;
	}
	/* Otherwise, we should be able to select on
	 * the socket as per normal. Probably? */
    }
#endif

    /* FD_SET on a descriptor outside [0, FD_SETSIZE) is undefined
     * behaviour, so refuse it up front. */
    if (sock->fd < 0 || sock->fd >= FD_SETSIZE) {
	sock->error = strerror(EBADF);
	return SOCK_ERROR;
    }

    do {
	/* select() may modify both the fd set and the timeout, so
	 * both are re-armed before every attempt, including retries
	 * after EINTR.  A retry therefore waits the full timeout again. */
	FD_ZERO(&fds);
	FD_SET(sock->fd, &fds);
	tv.tv_sec = timeout;
	tv.tv_usec = 0;
	ret = select(sock->fd+1, &fds, NULL, NULL, &tv);
    } while (ret == -1 && errno == EINTR);

    switch(ret) {
    case 0:
	return SOCK_TIMEOUT;
    case -1:
	sock->error = strerror(errno);
	return SOCK_ERROR;
    default:
	return 0;
    }
}

/* Writes LINE followed by CRLF to the socket.
 * A temporary copy of the line with "\r\n" appended is built with
 * CONCAT2 and released before returning.
 * Returns whatever sock_send_string() returns for that copy. */
int sock_sendline(nsocket *sock, const char *line) 
{
    char *crlf_line;
    int status;

    CONCAT2(crlf_line, line, "\r\n");
    status = sock_send_string(sock, crlf_line);
    free(crlf_line);

    return status;
}

/* Reads a body of LENGTH bytes from the socket in blocks, passing each
 * block to READER (with USERDATA) and reporting the running total
 * through sock_call_progress().
 *
 * Reading stops once LENGTH bytes have been delivered or sock_read()
 * returns 0.  The final block is still passed to READER even when it
 * is empty.
 *
 * Returns 0 on success or SOCK_* on error. */
int sock_readfile_blocked(nsocket *sock, off_t length,
			  sock_block_reader reader, void *userdata) 
{
    char buffer[BUFSIZ];
    int ret;
    off_t done = 0;
    do {
	size_t want = BUFSIZ;
	off_t remaining = length - done;

	/* Never ask for more than is left of the body: reading a full
	 * buffer regardless would consume (and hand to the reader)
	 * bytes that lie beyond LENGTH.  A non-positive remainder
	 * keeps the historical full-buffer read. */
	if (remaining > 0 && remaining < (off_t)BUFSIZ) {
	    want = (size_t)remaining;
	}
	ret = sock_read(sock, buffer, want);
	if (ret < 0) {
	    return ret;
	} 
	done += ret;
	sock_call_progress(done, length);
	(*reader)(userdata, buffer, ret);
    } while ((done < length) && ret);
    return 0;
}


/* Send a block of LENGTH bytes starting at DATA down the given socket,
 * looping until everything has been written.
 *
 * Returns:
 *    0 on success
 *    SOCK_CLOSED if a plain write() fails with EPIPE
 *    SOCK_ERROR on any other failure (sock->error is set)
 */
int sock_fullwrite(nsocket *sock, const char *data, size_t length) 
{
    const char *pnt = data;
    size_t sent = 0;
    ssize_t wrote;

#ifdef ENABLE_SSL
    if (sock->ssl) {
	/* joe: ssl.h says SSL_MODE_ENABLE_PARTIAL_WRITE must 
	 * be enabled to have SSL_write return < length... 
	 * so, SSL_write should never return < length. */
	wrote = SSL_write(sock->ssl, data, length);
	if (wrote < 0) {
	    /* A negative result is a failed write and must not be
	     * reported to the caller as success. */
	    sock->error = ERROR_SSL_STRING;
	    return SOCK_ERROR;
	}
	if ((size_t)wrote < length) {
	    DEBUG(DEBUG_SOCKET, "SSL_write returned less than length!\n");
	    sock->error = ERROR_SSL_STRING;
	    return SOCK_ERROR;
	}
	return 0;
    }
#endif

    while (sent < length) {
	wrote = write(sock->fd, pnt, length-sent);
	if (wrote < 0) {
	    if (errno == EINTR) {
		/* Interrupted before anything was written: retry. */
		continue;
	    } else if (errno == EPIPE) {
		return SOCK_CLOSED;
	    } else {
		sock->error = strerror(errno);
		return SOCK_ERROR;
	    }
	}
	sent += (size_t)wrote;
	pnt += wrote;
    }
    return 0;
}

/* Writes the NUL-terminated string DATA (without its terminator) to
 * the socket.  The result is exactly what sock_fullwrite() returns
 * for those strlen(data) bytes. */
int sock_send_string(nsocket *sock, const char *data) 
{
    const size_t nbytes = strlen(data);

    return sock_fullwrite(sock, data, nbytes);
}

/* This is from Eric Raymond's fetchmail (SockRead() in socket.c)
 * since I wouldn't have a clue how to do it properly.
 * This function is Copyright 1999 (C) Eric Raymond.
 * Modifications Copyright 2000 (C) Joe Orton
 *
 * Reads one \n-terminated line into BUF, which holds LEN bytes.  The
 * line (including its newline) is NUL-terminated in BUF.
 *
 * Returns:
 *    the number of bytes stored, not counting the terminating NUL
 *    SOCK_FULL if the buffer filled before a newline was seen
 *    SOCK_CLOSED if the connection was closed
 *    -1 or another negative SOCK_* code on error
 */
int sock_readline(nsocket *sock, char *buf, int len)
{
    char *newline, *bp = buf;
    int n;

    do {
	/* 
	 * The reason for these gymnastics is that we want two things:
	 * (1) to read \n-terminated lines,
	 * (2) to return the true length of data read, even if the
	 *     data coming in has embedded NULS.
	 */
#ifdef	ENABLE_SSL

	if (sock->ssl) {
	    /* Hack alert! */
	    /* OK...  SSL_peek works a little different from MSG_PEEK
	       Problem is that SSL_peek can return 0 if there is no
	       data currently available.  If, on the other hand, we
	       lose the socket, we also get a zero, but the SSL_read
	       then SEGFAULTS!  To deal with this, we'll check the
	       error code any time we get a return of zero from
	       SSL_peek.  If we have an error, we bail.  If we don't,
	       we read one character in SSL_read and loop.  This
	       should continue to work even if they later change the
	       behavior of SSL_peek to "fix" this problem...  :-(*/
	    DEBUG(DEBUG_SOCKET, "SSL readline... \n");
	    if ((n = SSL_peek(sock->ssl, bp, len)) < 0) {
		sock->error = ERROR_SSL_STRING;
		return(-1);
	    }
	    if (0 == n) {
		unsigned long sslerr;

		/* SSL_peek says no data...  Does he mean no data
		   or did the connection blow up?  If we got an error
		   then bail! */
		DEBUG(DEBUG_SOCKET, "SSL_Peek says no data!\n");
		/* Check properly to see if the connection has closed */
		if (sock->ssl->shutdown) {
		    DEBUG(DEBUG_SOCKET, "SSL says shutdown.");
		    return SOCK_CLOSED;
		}
		/* ERR_get_error() removes the code it returns, so it
		 * is fetched once (at full width) and the message is
		 * built from that same code. */
		sslerr = ERR_get_error();
		if (sslerr != 0) {
		    DEBUG(DEBUG_SOCKET, "SSL error occured.\n");
		    sock->error = ERR_reason_error_string(sslerr);
		    return -1;
		}
		    
		/* We didn't get an error so read at least one
		   character at this point and loop */
		n = 1;
		/* Make sure newline start out NULL!  We don't have a
		 * string to pass through the memchr at this point yet
		 * */
		newline = NULL;
	    } else if ((newline = memchr(bp, '\n', n)) != NULL) {
		/* Only consume up to and including the newline. */
		n = newline - bp + 1;
	    }
	    n = SSL_read(sock->ssl, bp, n);
	    DEBUG(DEBUG_SOCKET, "SSL_read returned %d\n", n);
	    if (n == -1) {
		sock->error = ERROR_SSL_STRING;
		return(-1);
	    }
	    /* Check for case where our single character turned out to
	     * be a newline...  (It wasn't going to get caught by
	     * the memchr above if it came from the hack...). */
	    if (NULL == newline && 1 == n && '\n' == *bp) {
		/* Got our newline - this will break
				out of the loop now */
		newline = bp;
	    }
	} else {
#endif
	    /* Peek at what is pending, then consume only up to and
	     * including the first newline (or all of it if none). */
	    if ((n = sock_peek(sock, bp, len)) <= 0)
		return n;
	    if ((newline = memchr(bp, '\n', n)) != NULL)
		n = newline - bp + 1;
	    if ((n = sock_read(sock, bp, n)) < 0)
		return n;
#ifdef ENABLE_SSL
	}
#endif
	bp += n;
	len -= n;
	/* One byte must remain for the terminating NUL. */
	if (len < 1) {
	    sock->error = _("Line too long");
	    return SOCK_FULL;
	}
    } while (!newline && len);
    *bp = '\0';
    return bp - buf;
}

/*** End of ESR-copyrighted section ***/

/* Copies data from the descriptor FD to SOCK.
 * Reads READLEN bytes from fd and writes them to sock; if
 * readlen == -1, it reads from fd until EOF.  The progress callback is
 * invoked after every read and once more at the end.
 *
 * Returns the number of bytes written to sock, SOCK_CLOSED if read()
 * fails with EPIPE, SOCK_ERROR on any other read failure (sock->error
 * is set), or the negative code from sock_fullwrite() if writing fails.
 */
int sock_transfer(int fd, nsocket *sock, off_t readlen) 
{
    char buffer[BUFSIZ];
    size_t curlen; /* total bytes yet to read from fd */
    off_t sumwrlen; /* total bytes written to sock */

    if (readlen == -1) {
	curlen = BUFSIZ; /* so the buffer size test works */
    } else {
	curlen = readlen; /* everything to do */
    }
    sumwrlen = 0; /* nowt done yet */

    while (curlen > 0) {
	int rdlen, wrlen, read_errno;

	/* Get a chunk... if the number of bytes that are left to read
	 * is less than the buffer size, only read that many bytes.
	 * An interrupted read is simply retried. */
	do {
	    rdlen = read(fd, buffer, (readlen==-1)?BUFSIZ:(min(BUFSIZ, curlen)));
	} while (rdlen < 0 && errno == EINTR);
	/* Capture errno now: the progress callback below is arbitrary
	 * user code and may overwrite it before it is inspected. */
	read_errno = errno;
	sock_call_progress(sumwrlen, readlen);
	if (rdlen < 0) { 
	    if (read_errno == EPIPE) {
		return SOCK_CLOSED;
	    } else {
		sock->error = strerror(read_errno);
		return SOCK_ERROR;
	    }
	} else if (rdlen == 0) { 
	    /* End of file... get out of here */
	    break;
	}
	if (readlen != -1)
	    curlen -= rdlen;

	/* Otherwise, we have bytes!  Write them to sock */
	
	wrlen = sock_fullwrite(sock, buffer, rdlen);
	if (wrlen < 0) { 
	    return wrlen;
	}

	sumwrlen += rdlen;
    }
    sock_call_progress(sumwrlen, readlen);
    return sumwrlen;
}

/* Reads exactly BUFLEN bytes into BUFFER, calling sock_read() as many
 * times as needed.
 *
 * Returns:
 *    0 once the buffer has been filled
 *    SOCK_CLOSED if the stream ends before BUFLEN bytes have arrived
 *    the negative code from sock_read() on timeout or error
 */
int sock_fullread(nsocket *sock, char *buffer, int buflen) 
{
    char *pnt; /* current position within buffer */
    int len;
    pnt = buffer;
    while (buflen > 0) {
	len = sock_read(sock, pnt, buflen);
	if (len < 0) return len;
	if (len == 0) {
	    /* sock_read() returns 0 at end of stream.  Without this
	     * check buflen would never shrink and the loop would spin
	     * forever on a closed connection. */
	    return SOCK_CLOSED;
	}
	buflen -= len;
	pnt += len;
    }
    return 0;
}

/* Do a name lookup on given hostname, writes the address into
 * given address buffer.  HOSTNAME is first tried as a dotted-quad
 * via inet_addr(); if that fails it is resolved with gethostbyname().
 * The notify callback, if registered, is told about the lookup first.
 * Returns 0 on success or SOCK_ERROR on failure.
 */
int sock_name_lookup(const char *hostname, struct in_addr *addr) 
{
    struct hostent *hp;
    unsigned long laddr;
    
    if (notify_cb)
	(*notify_cb)(notify_ud, sock_namelookup, hostname);
    
    /* TODO?: a possible problem here, is that if we are passed an
     * invalid IP address e.g. "500.500.500.500", then this gets
     * passed to gethostbyname and returned as "Host not found".
     * Arguably wrong, but maybe difficult to detect correctly what is
     * an invalid IP address and what is a hostname... can hostnames
     * begin with a numeric character? */
    laddr = (unsigned long)inet_addr(hostname);
    if ((int)laddr == -1) {
	/* inet_addr failed. */
	hp = gethostbyname(hostname);
	if (hp == NULL) {
#if 0
	    /* Need to get this back somehow, but we don't have 
	     * an nsocket * yet... */
	    switch(h_errno) {
	    case HOST_NOT_FOUND:
		sock->error = _("Host not found");
		break;
	    case TRY_AGAIN:
		sock->error = _("Host not found (try again later?)");
		break;
	    case NO_ADDRESS:
		sock->error = _("Host exists but has no address.");
		break;
	    case NO_RECOVERY:
	    default:
		sock->error = _("Non-recoverable error in resolver library.");
		break;
	    }
#endif
	    return SOCK_ERROR;
	}
	/* Only an IPv4 result of exactly the right size fits in
	 * *addr; copying h_length bytes unchecked could write past
	 * the end of the caller's struct in_addr. */
	if (hp->h_addrtype != AF_INET
	    || hp->h_length != (int)sizeof(*addr)) {
	    return SOCK_ERROR;
	}
	memcpy(addr, hp->h_addr, sizeof(*addr));
    } else {
	addr->s_addr = laddr;
    }
    return 0;
}

static nsocket *create_sock(int fd)
{
    nsocket *sock = ne_calloc(sizeof *sock);
    sock->fd = fd;
    return sock;
}

/* Opens a TCP socket to the given port at the given IPv4 address.
 * portnum must be in HOST byte order.
 * When CALL_FE is non-zero and a notify callback is registered, it is
 * told when the connection attempt starts and when it succeeds.
 * Returns NULL on failure, or the new nsocket on success. */
nsocket *sock_connect_u(const struct in_addr addr, int portnum, int call_fe) 
{
    struct sockaddr_in sa;
    int fd;

    /* Create the socket */
    fd = socket(AF_INET, SOCK_STREAM, 0);
    if (fd < 0)
	return NULL;
    /* Connect the nsocket.  The address structure is cleared first so
     * that padding and any platform-specific members are not passed
     * to connect() with indeterminate contents. */
    memset(&sa, 0, sizeof sa);
    sa.sin_family = AF_INET;
    sa.sin_port = htons(portnum); /* host -> net byte orders */
    sa.sin_addr = addr;
    if (call_fe && notify_cb) (*notify_cb)(notify_ud, sock_connecting, NULL);
    if (connect(fd, (struct sockaddr *)&sa, sizeof(struct sockaddr_in)) < 0) {
	(void) close(fd);
	return NULL;
    }
    if (call_fe && notify_cb) (*notify_cb)(notify_ud, sock_connected, NULL);
    /* Success - return the nsocket */
    return create_sock(fd);
}

/* Accepts one pending connection on the listening descriptor LISTENER.
 * Returns a new nsocket wrapping the accepted descriptor, or NULL if
 * accept() fails. */
nsocket *sock_accept(int listener) 
{
    int fd = accept(listener, NULL, NULL);

    /* accept() signals failure with -1; 0 is a perfectly valid
     * descriptor and must be wrapped rather than dropped and leaked. */
    if (fd >= 0) {
	return create_sock(fd);
    } else {
	return NULL;
    }
}

/* Accessor: yields the descriptor stored in SOCK. */
int sock_get_fd(nsocket *sock)
{
    const int descriptor = sock->fd;

    return descriptor;
}

/* Convenience form of sock_connect_u() that always asks for the
 * notify callback to be told about connection progress. */
nsocket *sock_connect(const struct in_addr addr, int portnum)
{
    const int notify_frontend = 1;

    return sock_connect_u(addr, portnum, notify_frontend);
}

nssl_context *sock_create_ssl_context(void)
{
    nssl_context *ctx = ne_calloc(sizeof *ctx);
    return ctx;
}

/* Releases a context obtained from sock_create_ssl_context().
 * Only the structure itself is freed; nothing its members point to
 * is released here. */
void sock_destroy_ssl_context(nssl_context *ctx)
{
    free(ctx);
}

/* Flags context C so that sock_make_secure() sets SSL_OP_NO_TLSv1. */
void sock_disable_tlsv1(nssl_context *c)
{
    c->disable_tlsv1 = 1U;
}
/* Flags context C so that sock_make_secure() sets SSL_OP_NO_SSLv2. */
void sock_disable_sslv2(nssl_context *c)
{
    c->disable_sslv2 = 1U;
}
/* Flags context C so that sock_make_secure() sets SSL_OP_NO_SSLv3. */
void sock_disable_sslv3(nssl_context *c)
{
    c->disable_sslv3 = 1U;
}

/* Negotiates SSL on an already-connected socket.
 *
 * Creates an SSL context and connection for SOCK, applies the
 * protocol-disable flags from CTX (CTX may be NULL), attaches the
 * socket's descriptor and performs the client handshake.  On success
 * the notify callback, if registered, is given the negotiated protocol
 * version string.
 *
 * Returns 0 on success or SOCK_ERROR on failure (sock->error is set).
 * On failure every OpenSSL object created here is freed and the
 * pointers in SOCK are reset to NULL, so a later sock_close() does not
 * free them a second time.  In builds without ENABLE_SSL this always
 * fails. */
int sock_make_secure(nsocket *sock, nssl_context *ctx)
{
#ifdef ENABLE_SSL
    int ret;

    sock->ssl_ctx = SSL_CTX_new(SSLv23_client_method());
    if (!sock->ssl_ctx) {
	/* Without this check a NULL context would be handed to
	 * SSL_CTX_set_options() below. */
	sock->error = ERROR_SSL_STRING;
	/* Usually goes wrong because: */
	fprintf(stderr, "Have you called sock_init()!?\n");
	return SOCK_ERROR;
    }
    if (ctx) {
	if (ctx->disable_tlsv1) {
	    SSL_CTX_set_options(sock->ssl_ctx, SSL_OP_NO_TLSv1);
	}
	if (ctx->disable_sslv2) {
	    SSL_CTX_set_options(sock->ssl_ctx, SSL_OP_NO_SSLv2);
	}
	if (ctx->disable_sslv3) {
	    SSL_CTX_set_options(sock->ssl_ctx, SSL_OP_NO_SSLv3);
	}
    }

    sock->ssl = SSL_new(sock->ssl_ctx);
    if (!sock->ssl) {
	sock->error = ERROR_SSL_STRING;
	/* Usually goes wrong because: */
	fprintf(stderr, "Have you called sock_init()!?\n");
	SSL_CTX_free(sock->ssl_ctx);
	sock->ssl_ctx = NULL;
	return SOCK_ERROR;
    }
    
    /* SSL_set_fd() returns 1 on success; SSL_connect() returns 1 only
     * for a completed handshake, while 0 and negative values both
     * mean that no secure connection was established. */
    ret = SSL_set_fd(sock->ssl, sock->fd);
    if (ret == 1) {
	ret = SSL_connect(sock->ssl);
    }
    if (ret <= 0) {
	sock->error = ERROR_SSL_STRING;
	SSL_free(sock->ssl);
	sock->ssl = NULL;
	SSL_CTX_free(sock->ssl_ctx);
	sock->ssl_ctx = NULL;
	return SOCK_ERROR;
    }

#if 0
    /* Tommi Komulainen <Tommi.Komulainen@iki.fi> has donated his SSL
     * cert verification from the mutt IMAP/SSL code under the
     * LGPL... it will plug in here */
    ret = sock_check_certicate(sock);
    if (ret) {
	SSL_free(sock->ssl);
	sock->ssl = NULL;
	SSL_CTX_free(sock->ssl_ctx);
	sock->ssl_ctx = NULL;
	return ret;
    }
#endif

    if (notify_cb) (*notify_cb)(notify_ud, sock_secure_details, 
				  SSL_get_version(sock->ssl));
    DEBUG(DEBUG_SOCKET, "SSL connected: version %s\n", 
	   SSL_get_version(sock->ssl));
    return 0;
#else
    sock->error = _("This application does not have SSL support.");
    return SOCK_ERROR;
#endif
}

/* Accessor: yields the error string currently stored in SOCK.  The
 * string is not copied; the caller receives the stored pointer. */
const char *sock_get_error(nsocket *sock)
{
    const char *message = sock->error;

    return message;
}

/* Shuts down and releases SOCK.
 * Any SSL connection and SSL context attached to the socket are freed
 * first, then the descriptor is closed and the nsocket itself is
 * freed; SOCK must not be used afterwards.
 * Returns the result of close() on the descriptor. */
int sock_close(nsocket *sock)
{
    int close_status;

#ifdef ENABLE_SSL
    if (sock->ssl != NULL) {
	SSL_free(sock->ssl);
	SSL_CTX_free(sock->ssl_ctx);
    }
#endif
    close_status = close(sock->fd);
    free(sock);

    return close_status;
}

/* Looks up the TCP service called NAME in the services database.
 * Returns its port number in HOST byte order, or 0 when no such
 * service is known. */
int sock_service_lookup(const char *name)
{
    const struct servent *entry = getservbyname(name, "tcp");

    return (entry != NULL) ? ntohs(entry->s_port) : 0;
}
