#include <arpa/inet.h>
#include <assert.h>
#include <boost/shared_ptr.hpp>
#include <errno.h>
#include <fcntl.h>
#include <netdb.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>

#include <iostream>
#include <list>
#include <map>
#include <string>

using boost::shared_ptr;
using std::cerr;
using std::cout;
using std::endl;
using std::list;
using std::map;
using std::string;

// Prints the expected usage of the program, including command line
// arguments, to stderr and then exits with EXIT_FAILURE.
void Usage(char *progname);

// Prints the process ID, file descriptor number, IP (v4/v6) address and
// port number held in addr to stdout.
void PrintOut(int fd, struct sockaddr *addr, size_t addrlen);

// Does a reverse DNS lookup of the address in addr and prints the name
// (or "[reverse DNS failed]") to stdout.
// NOTE(review): getnameinfo() may wait on a DNS query, which would stall
// this single-threaded event loop.
void PrintReverseDNS(int fd, struct sockaddr *addr, size_t addrlen);

// Creates a non-blocking listening TCP socket on port "portnum" (the
// port as a string, e.g. argv[1]) and returns its file descriptor, or -1
// on failure.
int  Listen(char *portnum);

// Attempts to accept a new connection on the listening socket
// associated with listen_fd.  Returns the new client's (non-blocking)
// file descriptor, or -1 if no connection could be accepted.
int  Accept(int listen_fd);

// Uses fcntl() to set fd to O_NONBLOCK mode.
void MakeNonBlocking(int fd);

// Forward declaration; the class is defined below.
class Connection;

// Prepares the file descriptor sets; uses the FD_ZERO and FD_SET
// macros to add the file descriptors of the various connections in
// connection_map to the three fd_set sets.  Also adds listen_fd to
// the readfds (and except_fds) sets.  Returns the maximum file
// descriptor of any that we add to any of the sets, so the caller can
// pass that value + 1 to select().
int PrepareFDSets(int listen_fd,
                  map<int, shared_ptr<Connection> > &connection_map,
                  fd_set *readfds, fd_set *writefds, fd_set *except_fds);

// A Connection object keeps track of the state of a client
// connection.  We'll create an instance of the class for each client
// that connects to us, and use it to stash away data we've read from
// the client, the socket file descriptor, and the state-machine state
// of the client.
//
// A Connection owns its socket: the destructor closes it.  Copying is
// therefore disabled, because two copies would close the same
// descriptor (and the second close could hit an unrelated, reused fd).
class Connection {
 public:
  // This enum enumerates the different state-machine states the
  // client connection could be in.
  //
  //   READING_REQUEST: we're waiting for data from the client.
  //
  //   WRITING_REPLY: we have a reply buffered and are waiting to be
  //   able to write it to the client.
  //
  //   CLOSED: the client connection has shut down (or failed).
  enum state {
    READING_REQUEST, WRITING_REPLY, CLOSED
  };

  // Takes ownership of the connected socket fd; starts in
  // READING_REQUEST.
  explicit Connection(int fd);
  virtual ~Connection();

  // Attempt to read data from the client.  If data arrives, the reply
  // built from it is placed in "client_data_" and connection_state_
  // transitions to WRITING_REPLY.  If the read would block or is
  // interrupted, we stay in READING_REQUEST.
  //
  // Returns "true" if the socket is still alive, "false" (with the
  // state set to CLOSED) if the client disconnected or we experienced
  // some kind of socket error.
  bool DoRead();

  // Attempt to write the pending reply to the client.  Once we've
  // written all of it, we transition back to the READING_REQUEST state.
  //
  // Returns "true" if the socket is still alive, "false" (with the
  // state set to CLOSED) if the client disconnected or we experienced
  // some kind of socket error.
  bool DoWrite();

  // Accessor functions.
  state get_state() const { return connection_state_; }
  int get_fd() const { return client_fd_; }
  // Returned by reference to avoid copying the buffered reply.
  const string &get_data() const { return client_data_; }

 private:
  // Non-copyable: declared but intentionally never defined (pre-C++11
  // idiom, matching the rest of this file).
  Connection(const Connection &);
  Connection &operator=(const Connection &);

  state connection_state_;
  int client_fd_;
  // Reply still to be sent to the client.
  string client_data_;
  // Number of bytes of client_data_ already written to the client.
  unsigned amount_written_so_far_;
};

// A non-blocking, event-driven echo server.  A single select() loop
// multiplexes the listening socket and every client socket; per-client
// state lives in a Connection object keyed by its file descriptor.
//
// Runs until select() fails, then returns EXIT_FAILURE.  Bad command line
// arguments or a failure to set up the listening socket also end the
// program with a failure status.
int main(int argc, char **argv) {
  // Expect the port number as a command line argument.
  if (argc != 2) {
    Usage(argv[0]);
  }

  // Check that argv[1] parses as a port number; Usage() exits otherwise.
  unsigned short port = 0;
  if (sscanf(argv[1], "%hu", &port) != 1) {
    Usage(argv[0]);
  }

  // Create a non-blocking listening socket on port argv[1].  Listen()
  // signals failure with -1; descriptor 0 is a valid socket, so only a
  // negative value means failure.
  int listen_fd = Listen(argv[1]);
  if (listen_fd < 0) {
    // We failed to bind/listen to a socket.  Quit after printing
    // an error message to the console.
    cerr << "Couldn't bind to any addresses." << endl;
    return EXIT_FAILURE;
  }

  // Connection state for each client, keyed by its file descriptor.  The
  // map holds the only long-lived reference to each Connection.
  map<int, shared_ptr<Connection> > connection_map;

  while (1) {
    fd_set readfds, writefds, exceptionfds;

    // Register interest in each descriptor according to its state.
    int max_fd = PrepareFDSets(listen_fd, connection_map,
                               &readfds, &writefds, &exceptionfds);

    // Block until at least one descriptor has an event.
    int res = select(max_fd + 1, &readfds, &writefds, &exceptionfds, NULL);
    if (res < 0) {
      // After a failed select() the contents of the fd sets are
      // unspecified, so they must not be inspected.  An interrupting
      // signal is harmless: rebuild the sets and wait again.
      if (errno == EINTR)
        continue;
      cerr << "select() failed: " << strerror(errno) << endl;
      break;
    }

    // Is the listening socket saying it has a new connection?
    if (FD_ISSET(listen_fd, &readfds)) {
      int client_fd = Accept(listen_fd);
      if (client_fd >= FD_SETSIZE) {
        // FD_SET/FD_ISSET are undefined for descriptors >= FD_SETSIZE,
        // so select() cannot manage this client.
        cerr << "Too many open connections; rejecting new client." << endl;
        close(client_fd);
      } else if (client_fd >= 0) {
        shared_ptr<Connection> newconn(new Connection(client_fd));
        connection_map[client_fd] = newconn;
      }
      // client_fd < 0: no connection was accepted.  Accept() reports real
      // errors itself, so just carry on.
    }

    // Descriptors whose connections must be dropped.  They are removed
    // after the scan so connection_map isn't modified while iterating.
    list<int> garbage_list;

    map<int, shared_ptr<Connection> >::iterator it;
    for (it = connection_map.begin(); it != connection_map.end(); ++it) {
      int nextfd = it->first;
      // Borrow the connection; the map keeps it alive for this scan.
      Connection *nextconn = it->second.get();

      // Is this socket readable?
      if (FD_ISSET(nextfd, &readfds)) {
        if (!nextconn->DoRead()) {
          garbage_list.push_back(nextfd);
        }
      }

      // Is this socket writable?
      if (FD_ISSET(nextfd, &writefds)) {
        if (!nextconn->DoWrite()) {
          garbage_list.push_back(nextfd);
        }
      }

      // Exceptional condition.  Per POSIX this is typically pending
      // out-of-band data on a TCP socket; this server drops the client.
      if (FD_ISSET(nextfd, &exceptionfds)) {
        garbage_list.push_back(nextfd);
      }
    }

    // Erasing a Connection's last reference runs ~Connection, which closes
    // the socket; closing the descriptor here as well would close it
    // twice.  A descriptor listed more than once is harmless, since
    // erasing an absent key does nothing.
    for (list<int>::iterator g = garbage_list.begin();
         g != garbage_list.end(); ++g) {
      connection_map.erase(*g);
    }
  }

  // Reached only when select() fails.  Clearing the map destroys every
  // Connection, closing the client sockets.
  connection_map.clear();
  close(listen_fd);
  return EXIT_FAILURE;
}

// Explains how to invoke the program on stderr, then terminates the
// process with a failure status.  Never returns to the caller.
void Usage(char *progname) {
  const string line = string("usage: ") + progname + " port";
  cerr << line << endl;
  exit(EXIT_FAILURE);
}

// Prints "[pid:fd]" followed by the numeric address and port held in addr
// to stdout.  IPv4 and IPv6 socket addresses are supported; any other
// address family prints a placeholder line instead.
//
// addrlen is accepted for symmetry with the socket APIs but is not used;
// the family field alone selects how addr is interpreted.
void PrintOut(int fd, struct sockaddr *addr, size_t addrlen) {
  (void) addrlen;

  if (addr->sa_family == AF_INET) {
    // Print out the IPv4 address and port.
    char astring[INET_ADDRSTRLEN];
    struct sockaddr_in *in4 = reinterpret_cast<struct sockaddr_in *>(addr);
    inet_ntop(AF_INET, &(in4->sin_addr), astring, INET_ADDRSTRLEN);
    cout << "  [" << getpid() << ":" << fd;
    cout << "] IPv4 address " << astring;
    // sin_port is stored in network byte order; convert it to host order.
    cout << "  and port " << ntohs(in4->sin_port) << endl;

  } else if (addr->sa_family == AF_INET6) {
    // Print out the IPv6 address and port.
    char astring[INET6_ADDRSTRLEN];
    struct sockaddr_in6 *in6 = reinterpret_cast<struct sockaddr_in6 *>(addr);
    inet_ntop(AF_INET6, &(in6->sin6_addr), astring, INET6_ADDRSTRLEN);
    cout << "  [" << getpid() << ":" << fd;
    cout << "] IPv6 address " << astring;
    // sin6_port is stored in network byte order; convert it to host order.
    cout << "  and port " << ntohs(in6->sin6_port) << endl;

  } else {
    cout << "  [" << getpid() << ":" << fd;
    cout << "] ???? address and port ????" << endl;
  }
}

// Does a reverse DNS lookup of addr and prints "[pid:fd] DNS name: <name>"
// to stdout, substituting "[reverse DNS failed]" when getnameinfo() fails.
//
// NOTE(review): getnameinfo() may wait on a network DNS query, which would
// stall this single-threaded event loop; consider NI_NUMERICHOST or an
// asynchronous resolver if that matters.
void PrintReverseDNS(int fd, struct sockaddr *addr, size_t addrlen) {
  char hostname[1024];  // ought to be big enough for any DNS name.

  // Derive the buffer size from the array so the two can't drift apart.
  int rc = getnameinfo(addr, static_cast<socklen_t>(addrlen),
                       hostname, sizeof(hostname), NULL, 0, 0);
  if (rc != 0) {
    snprintf(hostname, sizeof(hostname), "[reverse DNS failed]");
  }
  cout << "  [" << getpid() << ":" << fd;
  cout << "] DNS name: " << hostname << endl;
}

// Creates a TCP socket bound to the wildcard address on port portnum (a
// numeric string such as argv[1]), marks it as listening, and switches it
// to non-blocking mode.
//
// Tries each address getaddrinfo() returns (IPv4 or IPv6) until one binds,
// printing the bound address on success.  Returns the listening socket's
// file descriptor, or -1 on failure.
int Listen(char *portnum) {
  // Populate the "hints" addrinfo structure for getaddrinfo().
  // ("man getaddrinfo")
  struct addrinfo hints;
  memset(&hints, 0, sizeof(struct addrinfo));
  hints.ai_family = AF_UNSPEC;      // allow IPv4 or IPv6
  hints.ai_socktype = SOCK_STREAM;  // stream
  hints.ai_flags = AI_PASSIVE;      // use wildcard "INADDR_ANY"
  hints.ai_protocol = IPPROTO_TCP;  // tcp protocol
  hints.ai_canonname = NULL;
  hints.ai_addr = NULL;
  hints.ai_next = NULL;

  // getaddrinfo() returns a list of candidate address structures via the
  // output parameter "result".
  struct addrinfo *result;
  int res = getaddrinfo(NULL, portnum, &hints, &result);

  // Did getaddrinfo() fail?
  if (res != 0) {
    cerr << "getaddrinfo() failed: ";
    cerr << gai_strerror(res) << endl;
    return -1;
  }

  // Loop through the returned address structures until we are able
  // to create a socket and bind to one.  The address structures are
  // linked in a list through the "ai_next" field of result.
  int listen_fd = -1;
  for (struct addrinfo *rp = result; rp != NULL; rp = rp->ai_next) {
    listen_fd = socket(rp->ai_family,
                       rp->ai_socktype,
                       rp->ai_protocol);
    if (listen_fd == -1) {
      // Creating this socket failed.  So, loop to the next returned
      // result and try again.
      cerr << "socket() failed: " << strerror(errno) << endl;
      continue;
    }

    // Set SO_REUSEADDR, which tells the TCP stack to make the port we
    // bind to available again as soon as we exit, rather than waiting
    // for a few tens of seconds to recycle it.
    //
    // The call must NOT live inside assert(): with NDEBUG defined the
    // whole assert expression, setsockopt() included, is compiled out.
    int optval = 1;
    if (setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR,
                   &optval, sizeof(optval)) != 0) {
      // Not fatal: bind() may still succeed; restarts just might hit
      // EADDRINUSE for a while.
      cerr << "setsockopt(SO_REUSEADDR) failed: " << strerror(errno) << endl;
    }

    // Try binding the socket to the address and port number returned
    // by getaddrinfo().
    if (bind(listen_fd, rp->ai_addr, rp->ai_addrlen) == 0) {
      // Bind worked!  Print out the information about what
      // we bound to.
      PrintOut(listen_fd, rp->ai_addr, rp->ai_addrlen);
      break;
    }

    // The bind failed.  Close the socket, then loop back around and
    // try the next address/port returned by getaddrinfo().
    close(listen_fd);
    listen_fd = -1;
  }

  // Free the structure returned by getaddrinfo().
  freeaddrinfo(result);

  // If we failed to bind, return failure.  Descriptor 0 is a valid
  // socket, so only -1 signals failure here.
  if (listen_fd == -1)
    return -1;

  // Success. Tell the OS that we want this to be a listening socket.
  if (listen(listen_fd, SOMAXCONN) != 0) {
    cerr << "Failed to mark socket as listening: ";
    cerr << strerror(errno) << endl;
    close(listen_fd);
    return -1;
  }

  // Set the socket to non-blocking mode.
  MakeNonBlocking(listen_fd);
  return listen_fd;
}

// Accepts one pending connection on listen_fd, prints information about
// the new client, and switches the client socket to non-blocking mode.
//
// Returns the client's file descriptor, or -1 if no connection was
// accepted.  Because listen_fd is non-blocking, "no connection" is a
// normal outcome rather than a bug.  select() can report the listening
// socket readable for a connection the client aborts before accept()
// runs, in which case accept() fails with EAGAIN/EWOULDBLOCK or
// ECONNABORTED.
int Accept(int listen_fd) {
  // Retry only when accept() is interrupted by a signal.
  while (1) {
    struct sockaddr_storage caddr;
    socklen_t caddr_len = sizeof(caddr);
    struct sockaddr *caddr_ptr = reinterpret_cast<struct sockaddr *>(&caddr);
    int client_fd = accept(listen_fd, caddr_ptr, &caddr_len);
    if (client_fd < 0) {
      if (errno == EINTR)
        continue;
      if (errno == EAGAIN || errno == EWOULDBLOCK || errno == ECONNABORTED) {
        // Nothing to accept after all; not an error for a non-blocking
        // listening socket.
        return -1;
      }
      cerr << "Failure on accept: " << strerror(errno) << endl;
      return -1;
    }

    // We got a new client! Print out information about it.
    cout << endl;
    cout << "  [" << getpid() << ":" << client_fd;
    cout << "] new client:" << endl;
    PrintOut(client_fd, caddr_ptr, caddr_len);
    PrintReverseDNS(client_fd, caddr_ptr, caddr_len);

    MakeNonBlocking(client_fd);
    return client_fd;
  }
}

// Wraps the connected client socket fd (Accept() has already made it
// non-blocking).  The connection starts out waiting for a request, with
// no reply buffered.  Every member is initialized here.  Previously
// amount_written_so_far_ was left indeterminate until the first
// successful DoRead().
Connection::Connection(int fd)
    : connection_state_(READING_REQUEST),
      client_fd_(fd),
      client_data_(),
      amount_written_so_far_(0) {
}

// Releases the owned client socket if it is still open, leaving the
// object with no descriptor and in the CLOSED state.
Connection::~Connection() {
  const bool fd_is_open = (client_fd_ != -1);
  if (fd_is_open)
    close(client_fd_);
  client_fd_ = -1;
  connection_state_ = CLOSED;
}

// Reads whatever the client has sent (up to one buffer's worth).  If
// anything arrived, this stores the reply "You typed: <data>" in
// client_data_ and moves to WRITING_REPLY.
//
// Returns true if the connection is still usable.  That covers two cases:
// a reply is now pending, or the read would block or was interrupted and
// select() will report the socket readable again later.  Returns false,
// with the state set to CLOSED, if the client disconnected or a socket
// error occurred.
bool Connection::DoRead() {
  assert(get_state() == READING_REQUEST);
  char clientbuf[1024];
  ssize_t res = read(get_fd(), clientbuf, sizeof(clientbuf));

  if (res == 0) {
    // End of file: the client closed its side of the connection.
    cout << "  [" << getpid() << ":" << get_fd();
    cout << "] The client disconnected." << endl;
    connection_state_ = CLOSED;
    return false;
  }
  if (res == -1) {
    // EAGAIN and EWOULDBLOCK may be distinct values, so check both.
    if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) {
      // Rely on select() to dispatch us a "is readable" event again
      // the next time around.
      return true;
    }
    // We got some sort of unhandlable error, close the connection.
    cout << "  [" << getpid() << ":" << get_fd();
    cout << "] The client experienced an error." << endl;
    connection_state_ = CLOSED;
    return false;
  }

  // We read some data!  Build the reply from exactly the res bytes read
  // rather than treating the buffer as a C string.  Otherwise input
  // containing a NUL byte would be echoed truncated.
  client_data_ = string("You typed: ") +
                 string(clientbuf, static_cast<size_t>(res));
  amount_written_so_far_ = 0;
  connection_state_ = WRITING_REPLY;
  return true;
}

// Sends as much of the pending reply in client_data_ as the socket will
// currently accept.  Once the whole reply has been sent, the buffer is
// cleared and the connection returns to READING_REQUEST.
//
// Returns true if the connection is still usable.  This includes the cases
// where the socket is momentarily full or the call was interrupted;
// select() will report it writable again later.  Returns false, with the
// state set to CLOSED, if the client went away or a socket error occurred.
bool Connection::DoWrite() {
  assert(get_state() == WRITING_REPLY);

  // Writing to a socket whose peer has closed raises SIGPIPE, whose
  // default action terminates the entire server.  Where the platform
  // supports it, MSG_NOSIGNAL makes send() fail with EPIPE instead, so
  // only this one connection is dropped.
#ifdef MSG_NOSIGNAL
  const int send_flags = MSG_NOSIGNAL;
#else
  const int send_flags = 0;
#endif
  ssize_t res = send(get_fd(),
                     client_data_.data() + amount_written_so_far_,
                     client_data_.size() - amount_written_so_far_,
                     send_flags);

  if (res == 0) {
    // Not expected for a non-empty request; treat it as a disconnect
    // rather than retrying forever.
    cout << "  [" << getpid() << ":" << get_fd();
    cout << "] The client disconnected." << endl;
    connection_state_ = CLOSED;
    return false;
  }
  if (res == -1) {
    // EAGAIN and EWOULDBLOCK may be distinct values, so check both.
    if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) {
      // Rely on select() to dispatch us a "is writable" event again
      // the next time around.
      return true;
    }
    // Some unrecoverable error (e.g. EPIPE or ECONNRESET); close the
    // connection.
    cout << "  [" << getpid() << ":" << get_fd();
    cout << "] The client experienced an error." << endl;
    connection_state_ = CLOSED;
    return false;
  }

  // We wrote some data!  Only go back to reading once the whole reply
  // has been sent.
  amount_written_so_far_ += res;
  if (amount_written_so_far_ == client_data_.size()) {
    client_data_.clear();
    connection_state_ = READING_REQUEST;
  }
  return true;
}

// A utility routine to set the file descriptor "fd" to non-blocking mode.
// It adds O_NONBLOCK to the descriptor's file status flags and leaves any
// other flags intact.  Failures are reported on stderr and trip an
// assertion in debug builds.
void MakeNonBlocking(int fd) {
  // Get the current "flags" so we add to them rather than replace them.
  int flags = fcntl(fd, F_GETFL);
  if (flags == -1) {
    cerr << "fcntl(F_GETFL) failed: " << strerror(errno) << endl;
    assert(0);
    return;
  }

  // Keep the fcntl() call outside assert(): under NDEBUG the assert
  // expression is not evaluated and the descriptor would silently stay
  // blocking.  POSIX only promises "a value other than -1" on success,
  // so test for -1 rather than comparing against 0.
  if (fcntl(fd, F_SETFL, flags | O_NONBLOCK) == -1) {
    cerr << "fcntl(F_SETFL, O_NONBLOCK) failed: " << strerror(errno) << endl;
    assert(0);
  }
}


// Builds the fd_sets for the next select() call and returns the largest
// descriptor added to any of them (the caller passes that value + 1 to
// select()).
//
// listen_fd goes into the read and exception sets.  Each client goes into
// the read set while waiting for a request and into the write set while a
// reply is pending.  Every client always goes into the exception set.  A
// CLOSED connection should already have been removed from connection_map,
// so finding one here is reported as a bug.
int PrepareFDSets(int listen_fd,
                  map<int, shared_ptr<Connection> > &connection_map,
                  fd_set *readfds, fd_set *writefds, fd_set *except_fds) {
  // Start from empty sets.
  FD_ZERO(readfds);
  FD_ZERO(writefds);
  FD_ZERO(except_fds);

  // The listening socket becomes readable when a client is waiting.
  FD_SET(listen_fd, readfds);
  FD_SET(listen_fd, except_fds);

  int max_fd = 0;
  if (listen_fd > max_fd)
    max_fd = listen_fd;

  map<int, shared_ptr<Connection> >::iterator it;
  for (it = connection_map.begin(); it != connection_map.end(); ++it) {
    int nextfd = it->first;
    // Bind by reference: copying a shared_ptr costs an atomic refcount
    // increment/decrement per connection per loop.
    const shared_ptr<Connection> &conn = it->second;

    switch (conn->get_state()) {
      case Connection::READING_REQUEST:
        // Wait for the client's request.
        FD_SET(nextfd, readfds);
        break;
      case Connection::WRITING_REPLY:
        // Wait until the socket can accept our reply.
        FD_SET(nextfd, writefds);
        break;
      case Connection::CLOSED:
        cerr << "YOU HAVE A BUG.  FIX IT." << endl;
        assert(0);
        break;
    }

    // Always watch for exceptional conditions.
    FD_SET(nextfd, except_fds);
    if (nextfd > max_fd)
      max_fd = nextfd;
  }

  return max_fd;
}