/*
 * This is a super-simple implementation of the core functionality of wget, a Linux program whose
 * main job is to retrieve the file named by a URL on the command line and store it on the local
 * disk.  wget has many more advanced options for crawling websites and such, but the most common
 * use case is:
 *     wget http://somewhere.com/path/to/file.ext
 * which then puts 'file.ext' in the local working directory.
 *
 * To implement this, we need to know just enough of the HTTP protocol to send the server a
 * well-formed request for the right file, and parse the file itself out of the server's response.
 */
#include <arpa/inet.h>
#include <netdb.h>
#include <stdlib.h>
#include <string.h>
#include <libgen.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <iostream>
#include <stdio.h>
#include <errno.h>
#include <assert.h>
#include <unistd.h>
#include <sys/stat.h>
#include <fcntl.h>

// Size, in bytes, of the scratch buffer used to build the request and to read the response.
#define BUF_SIZE 4096

// This is the shape of the simplest HTTP request that will work on most webservers and get them to
// close the connection after sending a response.  We need to use this format string with a
// function like snprintf() to stitch in the file to request (first %s, without its leading '/')
// and the hostname we think we're talking to (second %s).
//
// The protocol version on the request line must be spelled "HTTP/1.1" (with the slash); a request
// line with a malformed version token is not valid HTTP and servers may reject it.
// NOTE(review): an HTTP/1.1 server may answer with "Transfer-Encoding: chunked"; confirm that
// every consumer of this request can cope with that, or consider requesting HTTP/1.0 instead.
#define GET_FMT_STRING "GET /%s HTTP/1.1\r\nHost: %s\r\nConnection: close\r\n\r\n"

// By convention, web servers run on port 80.  Some run on other ports, but it's uncommon so we'll
// hard-code this assumption.
#define HTTP_PORT 80

// Webserver responses are a bunch of field names with values, with information like the date that
// a page was last modified, and even useful information like the size of the response.  But to
// keep things simple, we'll just skip to the end of the header, and since we ask the server to
// close the connection after its response, everything after the header is the file we requested
// (or a nicely formatted error message if the file doesn't exist)
#define HEADER_FINISH "\r\n\r\n"

using namespace std;

// StableWrite() is a wrapper around write(2) that does the standard looping: it keeps calling
// write() until all buflen bytes of buf have been written to fd, retrying when a call is
// interrupted (EINTR).  Any other write failure is treated as fatal to the program rather than
// being reported to the caller, so there is no return value to check.
void StableWrite(int fd, char *buf, size_t buflen);

// Usage() prints the usage information for the program (using progname as the program's name)
// to stderr and then exits with EXIT_FAILURE; it does not return to the caller.
void Usage(char *progname);

// LookupName() does a DNS lookup on a server address, and initializes a sockaddr_storage with
// the IP and port to connect to one address for that DNS name.
//   name        - host name (or address string) to resolve with getaddrinfo()
//   port        - port number in host byte order; it is converted with htons() before being stored
//   ret_addr    - out-parameter receiving the resolved IPv4 or IPv6 socket address
//   ret_addrlen - out-parameter receiving the number of meaningful bytes in *ret_addr
// Returns true on success.  On failure a diagnostic is printed to stderr, false is returned, and
// the out-parameters are not written.
bool LookupName(char *name,
                unsigned short port,
                struct sockaddr_storage *ret_addr,
                size_t *ret_addrlen);

// Connect() connects to the server specified, and returns a file descriptor connected to the
// specified server through the out-parameter ret_fd.
//   addr    - address to connect to (as filled in by LookupName()); its ss_family selects the
//             socket's address family
//   addrlen - number of meaningful bytes in addr
//   ret_fd  - out-parameter receiving the connected SOCK_STREAM socket; written only on success,
//             after which the caller is responsible for close()ing it
// Returns true on success.  On failure a diagnostic is printed to stderr and false is returned.
bool Connect(const struct sockaddr_storage &addr,
             const size_t &addrlen,
             int *ret_fd);

// A scratch buffer of BUF_SIZE bytes.  main() first formats the outgoing HTTP request into it and
// then reuses it to receive the server's response.
char msgbuf[BUF_SIZE];

// main() fetches the URL given as the sole command-line argument over plain HTTP and stores the
// response body in a file in the current working directory.
//
// Steps: split the URL into host and path, resolve the host and connect to it on HTTP_PORT, send
// a GET request, echo the response header to stdout, then copy everything after the header into
// the output file until the server closes the connection.
//
// Returns EXIT_SUCCESS once the whole response has been stored and EXIT_FAILURE on a malformed
// URL or on any socket/file error.  A wrong argument count, a failed lookup, or a failed connect
// terminate the program through Usage().
int main(int argc, char **argv) {
  if (argc != 2) {
    Usage(argv[0]);
  }

  std::cout << "Fetching " << argv[1] << std::endl;

  /*
   * Now we need to parse the command line, to pull the server name and file path out of the full
   * url: http://www.cs.washington.edu/homes/gribble/index.html produces the following values:
   * hostname: www.cs.washington.edu
   * path: homes/gribble/index.html
   * base: index.html
   *
   * The URL is split in place inside argv[1]: the host name starts right after the "//" that
   * follows the scheme, and the '/' that ends it is overwritten with a NUL so that hostname and
   * path become two separate C strings.
   */
  char *authority = strstr(argv[1], "//");
  if (authority == NULL) {
    std::cerr << "Malformed URL" << std::endl;
    return EXIT_FAILURE;
  }
  char *hostname = authority + 2;  // hostname starts just after the second slash
  char *end_of_hostname = strchr(hostname, '/');
  if (end_of_hostname == NULL || end_of_hostname == hostname) {
    // Either there is no path component at all, or the host name is empty.
    std::cerr << "Malformed URL" << std::endl;
    return EXIT_FAILURE;
  }
  char *path = end_of_hostname + 1;
  *end_of_hostname = '\0';

  // Derive the local file name from the last path component.  This is done by hand rather than
  // with basename(3) because the POSIX basename() is allowed to modify its argument, and path
  // must reach the server exactly as the user typed it.  A path that is empty or ends in '/' has
  // no last component, so fall back to a conventional default name.
  const char *last_slash = strrchr(path, '/');
  const char *base = (last_slash != NULL) ? last_slash + 1 : path;
  if (*base == '\0') {
    base = "index.html";
  }

  std::cout << "Connecting to host: " << hostname << std::endl;
  std::cout << "Retrieving file: " << path << std::endl;
  std::cout << "Storing results to local file: " << base << std::endl;

  // Now, we must connect to the server
  struct sockaddr_storage addr;
  size_t addrlen;
  if (!LookupName(hostname, HTTP_PORT, &addr, &addrlen)) {
    Usage(argv[0]);
  }

  int socket_fd;
  if (!Connect(addr, addrlen, &socket_fd)) {
    Usage(argv[0]);
  }

  // Format the request according to the (limited) HTTP protocol we're using
  const int printed = snprintf(msgbuf, BUF_SIZE, GET_FMT_STRING, path, hostname);
  if (printed < 1 || printed >= BUF_SIZE) {
    std::cerr << "error formatting request; url too long?" << std::endl;
    close(socket_fd);
    return EXIT_FAILURE;
  }

  std::cout << "Sending HTTP request: " << std::endl << msgbuf << std::endl;

  // Send the request to the server.  write(2) may accept fewer bytes than requested, so loop
  // until the whole request has been handed to the kernel.  No explicit flush is needed (or
  // possible) afterwards: fsync(2) does not apply to sockets.
  int written = 0;
  while (written < printed) {
    const ssize_t sent = write(socket_fd, msgbuf + written, printed - written);
    if (sent == 0) {
      std::cerr << "socket close early in writing" << std::endl;
      close(socket_fd);
      return EXIT_FAILURE;
    }
    if (sent == -1) {
      if (errno == EINTR)
        continue;
      std::cerr << "socket write failure: " << strerror(errno) << std::endl;
      close(socket_fd);
      return EXIT_FAILURE;
    }
    written += static_cast<int>(sent);
  }

  // Read the response back from the server.  We will print the header to stdout, then open a file
  // and write the rest of the response there.  This is complicated by the fact that the
  // HEADER_FINISH string may be split across multiple reads.  To handle that, we read all header
  // bytes into a single buffer, contiguously, so eventually the header can be found through a
  // search through a consecutive byte string.  Whatever's left after that will be dumped to the
  // output file.
  std::cout << "Server response:" << std::endl;
  int bytes_read = 0;  // header bytes accumulated at the front of msgbuf by earlier reads
  int fd = -1;         // output file; opened once the end of the header has been seen
  bool header_complete = false;
  while (!header_complete) {
    // Always keep one byte in reserve for the NUL terminator that strstr() needs.
    const int space = BUF_SIZE - bytes_read - 1;
    if (space <= 0) {
      std::cerr << "response header too large" << std::endl;
      close(socket_fd);
      return EXIT_FAILURE;
    }
    const ssize_t res = read(socket_fd, msgbuf + bytes_read, space);
    if (res == 0) {
      std::cerr << "socket closed unexpectedly" << std::endl;
      close(socket_fd);
      return EXIT_FAILURE;
    }
    if (res == -1) {
      if (errno == EINTR)
        continue;
      std::cerr << "socket read failure: " << strerror(errno) << std::endl;
      close(socket_fd);
      return EXIT_FAILURE;
    }

    char *chunk = msgbuf + bytes_read;  // first byte delivered by this read
    char *data_end = chunk + res;       // one past the last valid byte in msgbuf
    *data_end = '\0';                   // null-terminate everything accumulated so far

    // Look for the end-of-header tag in the (full) message buffer
    char *mark = strstr(msgbuf, HEADER_FINISH);
    if (mark != NULL) {
      // We found the end of the header.  Finish displaying the header, open the file for
      // output, and write whatever this last read got that went past the header into the
      // file.  Then leave the loop so a tighter, simpler loop can slurp up the rest of the
      // result.  If the tag straddles two reads, mark lies before chunk and every header byte
      // has already been printed.
      if (mark > chunk) {
        StableWrite(1, chunk, static_cast<size_t>(mark - chunk));
      }
      std::cout << std::endl;
      fd = open(base, O_WRONLY | O_TRUNC | O_CREAT, S_IRUSR | S_IWUSR);
      if (fd < 0) {
        std::cerr << "file open failure: " << strerror(errno) << std::endl;
        close(socket_fd);
        return EXIT_FAILURE;
      }
      char *body = mark + strlen(HEADER_FINISH);
      if (data_end > body) {
        StableWrite(fd, body, static_cast<size_t>(data_end - body));
      }
      header_complete = true;
    } else {
      // We are still reading the header, print the part that just arrived
      StableWrite(1, chunk, static_cast<size_t>(res));
      bytes_read += static_cast<int>(res);
    }
  }

  // The file is open for writing, and this loop just needs to read from the server and write to
  // the output file.  The header is no longer needed, so the whole buffer is reused from the
  // start for each read.
  while (true) {
    const ssize_t res = read(socket_fd, msgbuf, BUF_SIZE);
    if (res == 0) {
      // Because we asked the server to close the connection when it was done, this is expected.
      // If we parsed the Content-length field of the header, we could do some error checking,
      // but we didn't.
      std::cerr << "socket closed." << std::endl;
      break;
    }
    if (res == -1) {
      if (errno == EINTR)
        continue;
      std::cerr << "socket read failure: " << strerror(errno) << std::endl;
      close(fd);
      close(socket_fd);
      return EXIT_FAILURE;
    }
    StableWrite(fd, msgbuf, static_cast<size_t>(res));
  }

  // A failing close() can be the only report of a deferred write error, so check it.
  if (close(fd) == -1) {
    std::cerr << "file close failure: " << strerror(errno) << std::endl;
    close(socket_fd);
    return EXIT_FAILURE;
  }
  close(socket_fd);
  return EXIT_SUCCESS;
}

// StableWrite() writes exactly buflen bytes from buf to fd, looping over write(2) because a
// single call may accept only part of the data.
//   fd     - open file descriptor to write to
//   buf    - start of the data to write
//   buflen - number of bytes to write; 0 is allowed and writes nothing
// A call interrupted by a signal (EINTR) is simply retried.  Any other failure, including a
// write() that makes no progress, is fatal: a diagnostic goes to stderr and the program is
// aborted.  The check is explicit rather than an assert() so that it is still performed when the
// program is built with NDEBUG; otherwise a failing descriptor would make this loop spin forever.
void StableWrite(int fd, char *buf, size_t buflen) {
    size_t written = 0;
    while (written < buflen) {
        const ssize_t res = write(fd, buf + written, buflen - written);
        if (res > 0) {
            written += static_cast<size_t>(res);
            continue;
        }
        if (res == -1 && errno == EINTR) {
            continue;  // interrupted before anything was written; try again
        }
        std::cerr << "write failure: "
                  << (res == -1 ? strerror(errno) : "no bytes written") << std::endl;
        abort();
    }
}

// Usage() tells the user how to invoke the program, naming it progname, and then ends the
// process with a failure status.  The message goes to stderr.  This function never returns.
void Usage(char *progname) {
  std::ostream &err = std::cerr;
  err << "usage: ";
  err << progname;
  err << " url" << std::endl;
  exit(EXIT_FAILURE);
}

// LookupName() resolves a host name (or numeric address string) with getaddrinfo() and hands
// back one socket address for it, ready to be passed to Connect().
//   name        - host name or address string to resolve
//   port        - port number in host byte order; stored in the result in network byte order
//   ret_addr    - out-parameter receiving the first IPv4 or IPv6 address that was found
//   ret_addrlen - out-parameter receiving the number of meaningful bytes in *ret_addr
// Returns true on success.  On failure a diagnostic is printed to stderr, false is returned, and
// the out-parameters are left untouched.  The list allocated by getaddrinfo() is released on
// every path that obtained one.
bool LookupName(char *name,
                unsigned short port,
                struct sockaddr_storage *ret_addr,
                size_t *ret_addrlen) {
  struct addrinfo hints;
  struct addrinfo *results = NULL;

  memset(&hints, 0, sizeof(hints));
  hints.ai_family = AF_UNSPEC;      // accept either IPv4 or IPv6
  hints.ai_socktype = SOCK_STREAM;  // one entry per address, not one per socket type

  // Do the lookup by invoking getaddrinfo().
  const int retval = getaddrinfo(name, NULL, &hints, &results);
  if (retval != 0) {
    std::cerr << "getaddrinfo failed: ";
    std::cerr << gai_strerror(retval) << std::endl;
    return false;
  }

  // Walk the list for the first address of a family we know how to fill in, and set its port.
  // No service name was passed to getaddrinfo(), so the port is patched in by hand here.
  bool found = false;
  for (struct addrinfo *cur = results; cur != NULL && !found; cur = cur->ai_next) {
    if (cur->ai_family == AF_INET) {
      struct sockaddr_in *v4addr = reinterpret_cast<struct sockaddr_in *>(cur->ai_addr);
      v4addr->sin_port = htons(port);
    } else if (cur->ai_family == AF_INET6) {
      struct sockaddr_in6 *v6addr = reinterpret_cast<struct sockaddr_in6 *>(cur->ai_addr);
      v6addr->sin6_port = htons(port);
    } else {
      continue;
    }
    memcpy(ret_addr, cur->ai_addr, cur->ai_addrlen);
    *ret_addrlen = cur->ai_addrlen;
    found = true;
  }

  if (!found) {
    std::cerr << "getaddrinfo failed to provide an IPv4 or IPv6 address";
    std::cerr << std::endl;
  }

  // Clean up.
  freeaddrinfo(results);
  return found;
}

// Connect() creates a stream socket of the address family found in addr and connects it to addr.
//   addr    - destination address, e.g. as filled in by LookupName()
//   addrlen - number of meaningful bytes in addr
//   ret_fd  - out-parameter receiving the connected socket; written only on success, after which
//             the caller owns the descriptor and must close() it
// Returns true on success.  On failure a diagnostic is printed to stderr, any socket that was
// created is closed again so that no descriptor leaks, and false is returned.
bool Connect(const struct sockaddr_storage &addr,
             const size_t &addrlen,
             int *ret_fd) {
  // Create the socket.
  const int socket_fd = socket(addr.ss_family, SOCK_STREAM, 0);
  if (socket_fd == -1) {
    std::cerr << "socket() failed: " << strerror(errno) << std::endl;
    return false;
  }

  // Connect the socket to the remote host.
  const int res = connect(socket_fd,
                          reinterpret_cast<const sockaddr *>(&addr),
                          static_cast<socklen_t>(addrlen));
  if (res == -1) {
    // Report first: close() below could overwrite errno.
    std::cerr << "connect() failed: " << strerror(errno) << std::endl;
    close(socket_fd);
    return false;
  }

  *ret_fd = socket_fd;
  return true;
}