#include <arpa/inet.h>
#include <assert.h>
#include <errno.h>
#include <netdb.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>
#include <wait.h>
#include <iostream>

// Prints a usage message naming progname to stderr and exits the process
// with EXIT_FAILURE.
void Usage(char *progname);
// Prints the process id, fd, and the IP address and port held in addr to
// stdout.
void PrintOut(int fd, struct sockaddr *addr, size_t addrlen);
// Prints the host name that getnameinfo() reports for addr to stdout, or a
// placeholder if the lookup fails.
void PrintReverseDNS(int fd, struct sockaddr *addr, size_t addrlen);
// Reads data from c_fd and writes it back prefixed with "You typed: " until
// the client disconnects or the socket fails, then closes c_fd.
void HandleClient(int c_fd);
// Creates a TCP socket bound to port portnum and marks it as listening.
// Returns the socket's file descriptor, or -1 on failure.
int  Listen(char *portnum);
// Blocks until a client connects to listen_fd.  Returns the client's file
// descriptor, or -1 if accept() fails with an error that is not retried.
int  Accept(int listen_fd);

// A multiprocessed echo server.
//
// Usage: <progname> port
//
// The main process listens on the given port.  For every accepted client
// it forks a child, which forks a grandchild and exits right away; the
// grandchild serves the client.  Because the intermediate child exits
// immediately, the main process can reap it without blocking on the
// client session, and the grandchild is never a child of the main process.
//
// Returns EXIT_FAILURE if the listening socket cannot be created or if
// accepting a connection fails with a non-retryable error.
int main(int argc, char **argv) {
  // Expect the port number as a command line argument.
  if (argc != 2) {
    Usage(argv[0]);
  }

  // Parse the port number or fail.  The parsed value is only used to
  // validate the argument; Listen() takes the string form.
  unsigned short port = 0;
  if (sscanf(argv[1], "%hu", &port) != 1) {
    Usage(argv[0]);
  }

  // Create a listening socket on port argv[1].
  int listen_fd = Listen(argv[1]);
  if (listen_fd <= 0) {
    // We failed to bind/listen to a socket.  Quit with failure.
    std::cerr << "Couldn't bind to any addresses." << std::endl;
    return EXIT_FAILURE;
  }

  // Loop, accepting a connection from a client and forking a grandchild
  // to handle the client.
  int exit_status = EXIT_SUCCESS;
  while (1) {
    int client_fd = Accept(listen_fd);
    if (client_fd < 0) {
      // Accept() retries transient errors itself and has already reported
      // this one.  Forking with an invalid descriptor would be pointless,
      // so shut the server down instead.
      exit_status = EXIT_FAILURE;
      break;
    }

    pid_t pid = fork();
    if (pid > 0) {
      // I'm the parent, pid is the child.  Wait for the child to exit.
      while (1) {
        int stat_loc;
        pid_t res = wait(&stat_loc);
        if ((res == -1) && (errno == EINTR))
          continue;
        if (res == -1) {
          std::cerr << "main processs failed on wait(): ";
          std::cerr << strerror(errno) << std::endl;
          exit(EXIT_FAILURE);
        }
        // The child exited and our wait succeeded.  Break back
        // out to the main accept loop after closing the client fd;
        // the grandchild holds its own copy of the descriptor.
        assert(res == pid);
        close(client_fd);
        break;
      }
    } else if (pid == 0) {
      // I'm the child.  Fork a grandchild and exit.
      pid = fork();
      if (pid > 0) {
        // I'm the child.  Exit!
        exit(EXIT_SUCCESS);
      } else if (pid == 0) {
        // I'm the grandchild.  I never accept connections, so drop my
        // copy of the listening socket before serving the client.
        close(listen_fd);
        HandleClient(client_fd);
        // Terminate here; returning would fall back into the accept loop
        // and make the grandchild compete with the main process for new
        // connections.
        exit(EXIT_SUCCESS);
      } else {
        // Error in child's fork.
        std::cerr << "child process coudln't fork(): ";
        std::cerr << strerror(errno) << std::endl;
        exit(EXIT_FAILURE);
      }
    } else {
      // Error in parent's fork.
      std::cerr << "main process couldn't fork(): ";
      std::cerr << strerror(errno) << std::endl;
      exit(EXIT_FAILURE);
    }
  }

  // Close up shop.
  close(listen_fd);
  return exit_status;
}

// Writes the expected command line for progname to stderr, then terminates
// the process with EXIT_FAILURE.  Never returns.
void Usage(char *progname) {
  std::cerr << "usage: ";
  std::cerr << progname;
  std::cerr << " port" << std::endl;
  exit(EXIT_FAILURE);
}

// Prints a one-line description of the socket address in addr to stdout,
// labelled with the current process id and fd.
//
// fd      - descriptor the address belongs to; only used as a label.
// addr    - address to describe.  AF_INET and AF_INET6 addresses are
//           printed in presentation form together with their port; any
//           other family is printed as unknown.
// addrlen - size of *addr; not needed to decode the address.
void PrintOut(int fd, struct sockaddr *addr, size_t addrlen) {
  (void) addrlen;

  if (addr->sa_family == AF_INET) {
    // Print out the IPv4 address and port.

    char astring[INET_ADDRSTRLEN];
    struct sockaddr_in *in4 = reinterpret_cast<struct sockaddr_in *>(addr);
    if (inet_ntop(AF_INET, &(in4->sin_addr), astring,
                  sizeof(astring)) == NULL) {
      // Never print an uninitialized buffer if the conversion fails.
      snprintf(astring, sizeof(astring), "????");
    }
    std::cout << "  [" << getpid() << ":" << fd;
    std::cout << "] IPv4 address " << astring;
    // sin_port is stored in network byte order; convert it for display.
    std::cout << " and port " << ntohs(in4->sin_port) << std::endl;

  } else if (addr->sa_family == AF_INET6) {
    // Print out the IPv6 address and port.

    char astring[INET6_ADDRSTRLEN];
    struct sockaddr_in6 *in6 = reinterpret_cast<struct sockaddr_in6 *>(addr);
    if (inet_ntop(AF_INET6, &(in6->sin6_addr), astring,
                  sizeof(astring)) == NULL) {
      // Never print an uninitialized buffer if the conversion fails.
      snprintf(astring, sizeof(astring), "????");
    }
    std::cout << "  [" << getpid() << ":" << fd;
    std::cout << "] IPv6 address " << astring;
    // sin6_port is stored in network byte order; convert it for display.
    std::cout << " and port " << ntohs(in6->sin6_port) << std::endl;

  } else {
    std::cout << "  [" << getpid() << ":" << fd;
    std::cout << "] ???? address and port ????" << std::endl;
  }
}

// Prints the host name that getnameinfo() reports for addr to stdout,
// labelled with the current process id and fd.  If the lookup fails, the
// placeholder "[reverse DNS failed]" is printed instead.
//
// fd      - descriptor the address belongs to; only used as a label.
// addr    - socket address to look up.
// addrlen - size of *addr in bytes.
void PrintReverseDNS(int fd, struct sockaddr *addr, size_t addrlen) {
  char hostname[1024];  // ought to be big enough.
  if (getnameinfo(addr, static_cast<socklen_t>(addrlen),
                  hostname, sizeof(hostname), NULL, 0, 0) != 0) {
    // Bounded write so the placeholder can never overrun the buffer.
    snprintf(hostname, sizeof(hostname), "[reverse DNS failed]");
  }
  std::cout << "  [" << getpid() << ":" << fd;
  std::cout << "] DNS name: " << hostname << std::endl;
}

// Creates a TCP listening socket bound to the wildcard address on port
// portnum.
//
// portnum - the port, as a string suitable for getaddrinfo().
//
// Walks the address list returned by getaddrinfo() and uses the first
// address for which socket(), setsockopt(SO_REUSEADDR) and bind() all
// succeed; the bound address is printed with PrintOut().
//
// Returns the listening socket's file descriptor on success, or -1 if
// name resolution, binding, or listen() fails.
int Listen(char *portnum) {
  // Populate the "hints" addrinfo structure for getaddrinfo().
  // ("man addrinfo")
  struct addrinfo hints;
  memset(&hints, 0, sizeof(struct addrinfo));
  hints.ai_family = AF_UNSPEC;      // allow IPv4 or IPv6
  hints.ai_socktype = SOCK_STREAM;  // stream
  hints.ai_flags = AI_PASSIVE;      // use wildcard "INADDR_ANY"
  hints.ai_protocol = IPPROTO_TCP;  // tcp protocol
  hints.ai_canonname = NULL;
  hints.ai_addr = NULL;
  hints.ai_next = NULL;

  // getaddrinfo() returns a list of address structures via the output
  // parameter "result".
  struct addrinfo *result;
  int res = getaddrinfo(NULL, portnum, &hints, &result);

  // Did getaddrinfo() fail?
  if (res != 0) {
    std::cerr << "getaddrinfo() failed: ";
    std::cerr << gai_strerror(res) << std::endl;
    return -1;
  }

  // Loop through the returned address structures until we are able
  // to create a socket and bind to one.  The address structures are
  // linked in a list through the "ai_next" field of result.
  int listen_fd = -1;
  for (struct addrinfo *rp = result; rp != NULL; rp = rp->ai_next) {
    listen_fd = socket(rp->ai_family,
                       rp->ai_socktype,
                       rp->ai_protocol);
    if (listen_fd == -1) {
      // Creating this socket failed.  So, loop to the next returned
      // result and try again.
      std::cerr << "socket() failed: " << strerror(errno) << std::endl;
      continue;
    }

    // Configure the socket; we're setting a socket "option."  In
    // particular, we set "SO_REUSEADDR", which tells the TCP stack
    // to make the port we bind to available again as soon as we
    // exit, rather than waiting for a few tens of seconds to recycle it.
    //
    // The call is made outside of assert() so that it still happens when
    // the program is compiled with NDEBUG.
    int optval = 1;
    if (setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR,
                   &optval, sizeof(optval)) != 0) {
      std::cerr << "setsockopt() failed: " << strerror(errno) << std::endl;
      close(listen_fd);
      listen_fd = -1;
      continue;
    }

    // Try binding the socket to the address and port number returned
    // by getaddrinfo().
    if (bind(listen_fd, rp->ai_addr, rp->ai_addrlen) == 0) {
      // Bind worked!  Print out the information about what
      // we bound to.
      PrintOut(listen_fd, rp->ai_addr, rp->ai_addrlen);
      break;
    }

    // The bind failed.  Close the socket, then loop back around and
    // try the next address/port returned by getaddrinfo().
    close(listen_fd);
    listen_fd = -1;
  }

  // Free the structure returned by getaddrinfo().
  freeaddrinfo(result);

  // If we failed to bind, return failure.
  if (listen_fd <= 0)
    return listen_fd;

  // Success. Tell the OS that we want this to be a listening socket.
  if (listen(listen_fd, SOMAXCONN) != 0) {
    std::cerr << "Failed to mark socket as listening: ";
    std::cerr << strerror(errno) << std::endl;
    close(listen_fd);
    return -1;
  }

  return listen_fd;
}

// Blocks until a client connects to listen_fd, then prints the client's
// address and reverse DNS name.
//
// listen_fd - a listening socket.
//
// accept() is retried when it fails with EAGAIN or EINTR, and also with
// ECONNABORTED, which only means that a pending connection was torn down
// before it could be accepted.
//
// Returns the connected client's file descriptor, or -1 if accept() fails
// with any other error (the error is reported on stderr).
int Accept(int listen_fd) {
  // Loop until a connection is accepted or a non-retryable error occurs.
  while (1) {
    struct sockaddr_storage caddr;
    socklen_t caddr_len = sizeof(caddr);
    struct sockaddr *caddr_ptr = reinterpret_cast<struct sockaddr *>(&caddr);

    int client_fd = accept(listen_fd, caddr_ptr, &caddr_len);
    if (client_fd < 0) {
      if ((errno == EAGAIN) || (errno == EINTR) || (errno == ECONNABORTED))
        continue;
      std::cerr << "Failure on accept: " << strerror(errno) << std::endl;
      return -1;
    }

    // We got a new client! Print out information about it.
    std::cout << std::endl;
    std::cout << "  [" << getpid() << ":" << client_fd;
    std::cout << "] new client" << std::endl;
    PrintOut(client_fd, caddr_ptr, caddr_len);
    PrintReverseDNS(client_fd, caddr_ptr, caddr_len);
    return client_fd;
  }
}

// Serves one client: repeatedly reads a chunk of data from c_fd and writes
// it back prefixed with "You typed: ".
//
// c_fd - connected client socket.  It is closed before returning.
//
// The session ends when the client disconnects (read() returns 0), when
// read() fails with an error other than EAGAIN/EINTR, or when the reply
// cannot be written.  Progress and errors are logged to stdout.
void HandleClient(int c_fd) {
  std::cout << "  [" << getpid() << ":" << c_fd;
  std::cout << "] Grandchild entering client read/write loop";
  std::cout << std::endl;

  bool session_alive = true;
  while (session_alive) {
    char clientbuf[1024];
    ssize_t res = read(c_fd, clientbuf, sizeof(clientbuf));
    if (res == 0) {
      std::cout << "  [" << getpid() << ":" << c_fd;
      std::cout << "] The client disconnected." << std::endl;
      break;
    }

    if (res == -1) {
      if ((errno == EAGAIN) || (errno == EINTR))
        continue;

      std::cout << "  [" << getpid() << ":" << c_fd;
      std::cout << "] Error on client socket: ";
      std::cout << strerror(errno) << "." << std::endl;
      break;
    }

    // Build the request from the byte count read() reported rather than
    // from strlen(), so that data containing NUL bytes is echoed in full.
    std::string req(clientbuf, static_cast<size_t>(res));
    std::cout << "  [" << getpid() << ":" << c_fd;
    std::cout << "] : " << req;

    // Generate and write the reply.
    std::string reply = "You typed: ";
    reply += req;

    // write() may accept fewer bytes than requested, so keep going until
    // the whole reply has been sent.
    size_t written_so_far = 0;
    const size_t writelen = reply.size();
    while (written_so_far < writelen) {
      ssize_t wres = write(c_fd,
                           reply.data() + written_so_far,
                           writelen - written_so_far);
      if (wres == -1) {
        if ((errno == EAGAIN) || (errno == EINTR))
          continue;

        // The reply cannot be delivered; report it and end the session
        // instead of silently dropping the rest of the reply.
        std::cout << "  [" << getpid() << ":" << c_fd;
        std::cout << "] Error writing to client socket: ";
        std::cout << strerror(errno) << "." << std::endl;
        session_alive = false;
        break;
      }
      if (wres == 0) {
        // No progress is possible; give up on this client.
        std::cout << "  [" << getpid() << ":" << c_fd;
        std::cout << "] Could not write to client socket." << std::endl;
        session_alive = false;
        break;
      }
      written_so_far += static_cast<size_t>(wres);
    }
  }

  close(c_fd);
}