// A non-blocking, event-driven echo server built around select().
//
// Each client connection is tracked by a Connection object that moves
// through a small state machine (READING_REQUEST -> WRITING_REPLY ->
// READING_REQUEST ...).  The main loop uses select() to find sockets that
// are ready and dispatches reads/writes to the matching Connection.

#include <arpa/inet.h>
#include <assert.h>
#include <errno.h>
#include <fcntl.h>
#include <netdb.h>
#include <netinet/in.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/select.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>

#include <cstddef>
#include <iostream>
#include <list>
#include <map>
#include <string>

// Reference-counted pointer used to own Connection objects.  Pre-C++11
// builds use Boost's implementation (the original dependency); C++11 and
// later use the equivalent standard-library type so Boost is not required.
#if __cplusplus >= 201103L
#include <memory>
using std::shared_ptr;
#else
#include <boost/shared_ptr.hpp>
using boost::shared_ptr;
#endif

using std::cerr;
using std::cout;
using std::endl;
using std::list;
using std::map;
using std::string;

// Prints the expected usage of the program, including command line
// arguments, and then exits.
void Usage(char *progname);

// Prints out the process ID, file descriptor number, IP(v4/v6) address and
// port number to stdout.
void PrintOut(int fd, struct sockaddr *addr, size_t addrlen);

// Does a reverse DNS lookup of the address in addr and prints it out
// to stdout.
void PrintReverseDNS(int fd, struct sockaddr *addr, size_t addrlen);

// Creates a non-blocking listening socket on port "portnum" and returns the
// file descriptor for it, or -1 on failure.
int Listen(char *portnum);

// Attempts to accept a new connection on the listening socket associated
// with listen_fd.  Returns the new (non-blocking) client file descriptor,
// or -1 if no connection could be accepted.
int Accept(int listen_fd);

// Uses fcntl() to set fd to O_NONBLOCK mode.  Returns true on success;
// on failure prints an error to stderr and returns false.
bool MakeNonBlocking(int fd);

// Prepares the file descriptor sets; uses the FD_ZERO and FD_SET
// macros to add the file descriptors of the various connections in
// connection_map to the three fd_set sets. Also adds listen_fd to
// the readfds and except_fds sets. Returns the maximum file descriptor of
// any that we add to any of the sets.
class Connection;
int PrepareFDSets(int listen_fd,
                  map<int, shared_ptr<Connection> > &connection_map,
                  fd_set *readfds, fd_set *writefds, fd_set *except_fds);

// A Connection object keeps track of the state of a client
// connection. We'll create an instance of the class for each client
// that connects to us, and use it to stash away data we've read from
// the client, the socket file descriptor, and the state-machine state
// of the client.  The Connection owns its file descriptor and closes it
// when destroyed.
class Connection {
 public:
  // This enum enumerates the different state-machine states the
  // client connection could be in.
  //
  // READING_REQUEST: we're waiting for data from the client
  //
  // WRITING_REPLY: we're waiting to be able to write our reply
  // to the client.
  //
  // CLOSED: the client connection has shut down.
  enum state { READING_REQUEST, WRITING_REPLY, CLOSED };

  explicit Connection(int fd);
  virtual ~Connection();

  // Attempt to read data from the client. If successful, the
  // side-effect will be to place the reply in the "client_data_"
  // string and to transition our connection_state_ to
  // WRITING_REPLY. If no data is available yet, we'll stay in the
  // READING_REQUEST state.
  //
  // Returns "true" if the socket is still alive, "false" if the
  // client disconnected or we experienced some kind of socket error.
  bool DoRead();

  // Attempt to write our reply to the client. Once we've written to the
  // client all of the data that we have, we'll transition back to the
  // READING_REQUEST state.
  //
  // Returns "true" if the socket is still alive, "false" if the
  // client disconnected or we experienced some kind of socket error.
  bool DoWrite();

  // Accessor functions.
  state get_state() const { return connection_state_; }
  int get_fd() const { return client_fd_; }
  string get_data() const { return client_data_; }

 private:
  // A Connection closes client_fd_ in its destructor, so a copy would
  // close the same descriptor twice.  Declared but intentionally never
  // defined (the C++03 way of deleting these operations).
  Connection(const Connection &);
  Connection &operator=(const Connection &);

  state connection_state_;
  int client_fd_;
  string client_data_;
  size_t amount_written_so_far_;
};

// A non-blocking, event-driven server.
int main(int argc, char **argv) {
  // Expect the port number as a command line argument.
  if (argc != 2) {
    Usage(argv[0]);
  }

  // Parse the port number or fail with an error message.
  unsigned short port = 0;
  if (sscanf(argv[1], "%hu", &port) != 1) {
    Usage(argv[0]);
  }

  // Writing to a socket whose peer has gone away raises SIGPIPE, whose
  // default action terminates the whole process.  Ignore it so write()
  // fails with EPIPE instead and DoWrite() drops just that connection.
  signal(SIGPIPE, SIG_IGN);

  // Create a non-blocking listening socket on port argv[1].  Note that 0 is
  // a valid descriptor; only negative values signal failure.
  int listen_fd = Listen(argv[1]);
  if (listen_fd < 0) {
    // We failed to bind/listen to a socket. Quit after printing
    // an error message to the console.
    cerr << "Couldn't bind to any addresses." << endl;
    return EXIT_FAILURE;
  }

  // We'll use this map to keep track of connection states, associated
  // with each file descriptor.  The map holds the only reference to each
  // Connection, so erasing an entry closes that client's socket.
  map<int, shared_ptr<Connection> > connection_map;

  // Event loop: wait for socket activity with select(), accept new
  // clients, and let each ready Connection read a request or write its
  // reply.  Only exits if select() fails unrecoverably.
  while (1) {
    fd_set readfds, writefds, exceptionfds;

    // Prepare these sets
    int max_fd = PrepareFDSets(listen_fd, connection_map,
                               &readfds, &writefds, &exceptionfds);

    // figure out what the next "event" that arrived is, using select
    int res = select(max_fd + 1, &readfds, &writefds, &exceptionfds, NULL);
    if (res < 0) {
      // On failure select() leaves the sets unmodified, so they must not
      // be interpreted as "ready" sets.
      if (errno == EINTR) {
        continue;
      }
      cerr << "select() failed: " << strerror(errno) << endl;
      break;
    }

    // Is the listening socket saying it has a new connection?
    if (FD_ISSET(listen_fd, &readfds)) {
      // Yes; a new connection may have arrived..accept it, add it to the
      // connection map.
      int client_fd = Accept(listen_fd);
      if (client_fd >= FD_SETSIZE) {
        // FD_SET() on a descriptor >= FD_SETSIZE is undefined behavior.
        cerr << "Too many open connections; rejecting client." << endl;
        close(client_fd);
      } else if (client_fd >= 0) {
        shared_ptr<Connection> newconn(new Connection(client_fd));
        connection_map[client_fd] = newconn;
      }
    }

    // This list will accumulate file descriptors that have closed;
    // we'll delete them from the connection_map after we're done with
    // select event processing (erasing while iterating would invalidate
    // the iterator).
    list<int> garbage_list;

    // Iterate through all open connections.
    map<int, shared_ptr<Connection> >::iterator it;
    for (it = connection_map.begin(); it != connection_map.end(); ++it) {
      int nextfd = it->first;
      Connection &nextconn = *(it->second);
      bool dead = false;

      // Is one of the sockets saying that it is readable?
      if (FD_ISSET(nextfd, &readfds) && !nextconn.DoRead()) {
        dead = true;
      }

      // Is one of the sockets saying that it is writeable?
      if (!dead && FD_ISSET(nextfd, &writefds) && !nextconn.DoWrite()) {
        dead = true;
      }

      // Is select() reporting an "exceptional condition" on the socket
      // (for TCP this is out-of-band data)?  This server treats that as
      // fatal and drops the connection.
      if (FD_ISSET(nextfd, &exceptionfds)) {
        dead = true;
      }

      if (dead) {
        garbage_list.push_back(nextfd);
      }
    }

    // Take out the trash.  Erasing drops the last reference to the
    // Connection, whose destructor closes the socket, so no explicit
    // close() is needed (it would be a double close).
    for (list<int>::iterator g = garbage_list.begin();
         g != garbage_list.end(); ++g) {
      connection_map.erase(*g);
    }
  }

  // Close up shop (only reached if select() failed).  Clearing the map
  // closes every client socket.
  connection_map.clear();
  close(listen_fd);
  return EXIT_FAILURE;
}

void Usage(char *progname) {
  cerr << "usage: " << progname << " port" << endl;
  exit(EXIT_FAILURE);
}

void PrintOut(int fd, struct sockaddr *addr, size_t addrlen) {
  if (addr->sa_family == AF_INET) {
    // Print out the IPV4 address and port
    char astring[INET_ADDRSTRLEN];
    struct sockaddr_in *in4 = reinterpret_cast<struct sockaddr_in *>(addr);
    inet_ntop(AF_INET, &(in4->sin_addr), astring, INET_ADDRSTRLEN);
    cout << " [" << getpid() << ":" << fd;
    cout << "] IPv4 address " << astring;
    // sin_port is in network byte order; convert it to host order.
    cout << " and port " << ntohs(in4->sin_port) << endl;
  } else if (addr->sa_family == AF_INET6) {
    // Print out the IPV6 address and port
    char astring[INET6_ADDRSTRLEN];
    struct sockaddr_in6 *in6 = reinterpret_cast<struct sockaddr_in6 *>(addr);
    inet_ntop(AF_INET6, &(in6->sin6_addr), astring, INET6_ADDRSTRLEN);
    cout << " [" << getpid() << ":" << fd;
    cout << "] IPv6 address " << astring;
    // sin6_port is in network byte order; convert it to host order.
    cout << " and port " << ntohs(in6->sin6_port) << endl;
  } else {
    cout << " [" << getpid() << ":" << fd;
    cout << "] ???? address and port ????" << endl;
  }
}

void PrintReverseDNS(int fd, struct sockaddr *addr, size_t addrlen) {
  char hostname[1024];  // ought to be big enough.
  if (getnameinfo(addr, static_cast<socklen_t>(addrlen), hostname,
                  sizeof(hostname), NULL, 0, 0) != 0) {
    sprintf(hostname, "[reverse DNS failed]");
  }
  cout << " [" << getpid() << ":" << fd;
  cout << "] DNS name: " << hostname << endl;
}

int Listen(char *portnum) {
  // Populate the "hints" addrinfo structure for getaddrinfo().
  // ("man addrinfo")
  struct addrinfo hints;
  memset(&hints, 0, sizeof(struct addrinfo));
  hints.ai_family = AF_UNSPEC;       // allow IPv4 or IPv6
  hints.ai_socktype = SOCK_STREAM;   // stream
  hints.ai_flags = AI_PASSIVE;       // use wildcard "INADDR_ANY"
  hints.ai_protocol = IPPROTO_TCP;   // tcp protocol
  hints.ai_canonname = NULL;
  hints.ai_addr = NULL;
  hints.ai_next = NULL;

  // Use argv[1] as the string representation of our portnumber to
  // pass in to getaddrinfo().  getaddrinfo() returns a list of
  // address structures via the output parameter "result".
  struct addrinfo *result;
  int res = getaddrinfo(NULL, portnum, &hints, &result);

  // Did addrinfo() fail?
  if (res != 0) {
    cerr << "getaddrinfo() failed: ";
    cerr << gai_strerror(res) << endl;
    return -1;
  }

  // Loop through the returned address structures until we are able
  // to create a socket and bind to one.  The address structures are
  // linked in a list through the "ai_next" field of result.
  int listen_fd = -1;
  for (struct addrinfo *rp = result; rp != NULL; rp = rp->ai_next) {
    listen_fd = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol);
    if (listen_fd == -1) {
      // Creating this socket failed.  So, loop to the next returned
      // result and try again.
      cerr << "socket() failed: " << strerror(errno) << endl;
      continue;
    }

    // Configure the socket; we're setting a socket "option."  In
    // particular, we set "SO_REUSEADDR", which tells the TCP stack
    // to make the port we bind to available again as soon as we
    // exit, rather than waiting for a few tens of seconds to recycle it.
    // The call is made outside assert() so it still runs under NDEBUG;
    // failure is only a convenience loss, so report it and carry on.
    int optval = 1;
    if (setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR,
                   &optval, sizeof(optval)) != 0) {
      cerr << "setsockopt(SO_REUSEADDR) failed: " << strerror(errno) << endl;
    }

    // Try binding the socket to the address and port number returned
    // by getaddrinfo().
    if (bind(listen_fd, rp->ai_addr, rp->ai_addrlen) == 0) {
      // Bind worked!  Print out the information about what
      // we bound to.
      PrintOut(listen_fd, rp->ai_addr, rp->ai_addrlen);
      break;
    }

    // The bind failed.  Close the socket, then loop back around and
    // try the next address/port returned by getaddrinfo().
    close(listen_fd);
    listen_fd = -1;
  }

  // Free the structure returned by getaddrinfo().
  freeaddrinfo(result);

  // If we failed to bind, return failure.  (0 is a valid descriptor.)
  if (listen_fd < 0)
    return -1;

  // Success. Tell the OS that we want this to be a listening socket.
  if (listen(listen_fd, SOMAXCONN) != 0) {
    cerr << "Failed to mark socket as listening: ";
    cerr << strerror(errno) << endl;
    close(listen_fd);
    return -1;
  }

  // Set the socket to non-blocking mode; the event loop relies on
  // accept() never blocking.
  if (!MakeNonBlocking(listen_fd)) {
    close(listen_fd);
    return -1;
  }
  return listen_fd;
}

int Accept(int listen_fd) {
  // Loop (only to retry on EINTR), attempting to accept a connection.
  while (1) {
    struct sockaddr_storage caddr;
    socklen_t caddr_len = sizeof(caddr);
    int client_fd = accept(listen_fd,
                           reinterpret_cast<struct sockaddr *>(&caddr),
                           &caddr_len);
    if (client_fd < 0) {
      if (errno == EINTR)
        continue;
      if ((errno == EAGAIN) || (errno == EWOULDBLOCK) ||
          (errno == ECONNABORTED)) {
        // select() reported readability, but the pending connection went
        // away (e.g. the client reset it) before we accepted it.  That is
        // expected occasionally and is not an error.
        return -1;
      }
      cerr << "Failure on accept: " << strerror(errno) << endl;
      return -1;
    }

    // We got a new client! Print out information about it.
    cout << endl;
    cout << " [" << getpid() << ":" << client_fd;
    cout << "] new client:" << endl;
    PrintOut(client_fd, reinterpret_cast<struct sockaddr *>(&caddr),
             caddr_len);
    PrintReverseDNS(client_fd, reinterpret_cast<struct sockaddr *>(&caddr),
                    caddr_len);

    // A blocking client socket would stall the whole event loop, so refuse
    // the client if it can't be made non-blocking.
    if (!MakeNonBlocking(client_fd)) {
      close(client_fd);
      return -1;
    }
    return client_fd;
  }
}

Connection::Connection(int fd)
    : connection_state_(READING_REQUEST),
      client_fd_(fd),
      client_data_(),
      amount_written_so_far_(0) {
}

Connection::~Connection() {
  if (client_fd_ != -1) {
    close(client_fd_);
    client_fd_ = -1;
  }
  connection_state_ = CLOSED;
}

bool Connection::DoRead() {
  assert(get_state() == READING_REQUEST);

  char clientbuf[1024];
  ssize_t res = read(get_fd(), clientbuf, sizeof(clientbuf));
  if (res == 0) {
    cout << " [" << getpid() << ":" << get_fd();
    cout << "] The client disconnected." << endl;
    connection_state_ = CLOSED;
    return false;
  }
  if (res < 0) {
    if ((errno == EAGAIN) || (errno == EWOULDBLOCK) || (errno == EINTR)) {
      // Rely on select() to dispatch us a "is readable" event again
      // the next time around.
      return true;
    }
    // we got some sort of unhandlable error, close the connection.
    cout << " [" << getpid() << ":" << get_fd();
    cout << "] The client experienced an error." << endl;
    connection_state_ = CLOSED;
    return false;
  }

  // We read some data!!  Copy exactly res bytes so that NUL bytes sent by
  // the client are echoed back instead of truncating the reply.
  client_data_ = "You typed: ";
  client_data_.append(clientbuf, static_cast<size_t>(res));
  amount_written_so_far_ = 0;
  connection_state_ = WRITING_REPLY;
  return true;
}

bool Connection::DoWrite() {
  assert(get_state() == WRITING_REPLY);

  ssize_t res = write(get_fd(),
                      client_data_.data() + amount_written_so_far_,
                      client_data_.size() - amount_written_so_far_);
  if (res == 0) {
    cout << " [" << getpid() << ":" << get_fd();
    cout << "] The client disconnected." << endl;
    connection_state_ = CLOSED;
    return false;
  }
  if (res < 0) {
    if ((errno == EAGAIN) || (errno == EWOULDBLOCK) || (errno == EINTR)) {
      // Rely on select() to dispatch us a "is writable" event again
      // the next time around.
      return true;
    }
    // We got some sort of unhandlable error (including EPIPE when the
    // client has gone away), close the connection.
    cout << " [" << getpid() << ":" << get_fd();
    cout << "] The client experienced an error." << endl;
    connection_state_ = CLOSED;
    return false;
  }

  // We wrote some data!  A short write leaves us in WRITING_REPLY so the
  // remainder is sent on the next writable event.
  amount_written_so_far_ += static_cast<size_t>(res);
  if (amount_written_so_far_ == client_data_.size()) {
    client_data_.clear();
    amount_written_so_far_ = 0;
    connection_state_ = READING_REQUEST;
  }
  return true;
}

// A utility routine to set the file descriptor "fd" to
// non-blocking mode.  The fcntl() calls are made outside assert() so they
// still run when the program is built with NDEBUG.
bool MakeNonBlocking(int fd) {
  // Get the current "flags" associated with the file descriptor.
  int flags = fcntl(fd, F_GETFL);
  if (flags == -1) {
    cerr << "fcntl(F_GETFL) failed: " << strerror(errno) << endl;
    return false;
  }

  // Add the O_NONBLOCK flag, and inform the OS of the new
  // flag settings.
  if (fcntl(fd, F_SETFL, flags | O_NONBLOCK) == -1) {
    cerr << "fcntl(F_SETFL, O_NONBLOCK) failed: " << strerror(errno) << endl;
    return false;
  }
  return true;
}

int PrepareFDSets(int listen_fd,
                  map<int, shared_ptr<Connection> > &connection_map,
                  fd_set *readfds, fd_set *writefds, fd_set *except_fds) {
  int max_fds = 0;

  // Zero out the sets
  FD_ZERO(readfds);
  FD_ZERO(writefds);
  FD_ZERO(except_fds);

  // Add the listening socket to the correct sets.
  FD_SET(listen_fd, readfds);
  FD_SET(listen_fd, except_fds);
  if (listen_fd > max_fds)
    max_fds = listen_fd;

  // Add the other connections into the appropriate sets.
  map<int, shared_ptr<Connection> >::iterator it;
  for (it = connection_map.begin(); it != connection_map.end(); ++it) {
    int nextfd = it->first;
    const Connection &conn = *(it->second);

    // Closed connections are removed at the end of every event-loop pass,
    // so finding one here means that invariant was broken.
    if (conn.get_state() == Connection::CLOSED) {
      cerr << "YOU HAVE A BUG. FIX IT." << endl;
      exit(EXIT_FAILURE);
    }
    if (conn.get_state() == Connection::READING_REQUEST) {
      // add it to the read set
      FD_SET(nextfd, readfds);
    }
    if (conn.get_state() == Connection::WRITING_REPLY) {
      // add it to the write set
      FD_SET(nextfd, writefds);
    }

    // Always add it to the exception set.
    FD_SET(nextfd, except_fds);

    // Make sure we return the correct max file descriptor.
    if (nextfd > max_fds)
      max_fds = nextfd;
  }
  return max_fds;
}