diff --git a/examples/chat_server/chat.cpp b/examples/chat_server/chat.cpp index cb6e6ffe69..85b3b35ba7 100644 --- a/examples/chat_server/chat.cpp +++ b/examples/chat_server/chat.cpp @@ -30,20 +30,21 @@ using websocketchat::chat_handler; -bool chat_handler::validate(websocketpp::session_ptr client) { +void chat_handler::validate(websocketpp::session_ptr client) { + std::stringstream err; + // We only know about the chat resource if (client->get_request() != "/chat") { - client->set_http_error(404); - return false; + err << "Request for unknown resource " << client->get_request(); + throw(handshake_error(err.str(),404)); } // Require specific origin example if (client->get_header("Sec-WebSocket-Origin") != "http://zaphoyd.com") { - client->set_http_error(403); - return false; + err << "Request from unrecognized origin: " + << client->get_header("Sec-WebSocket-Origin"); + throw(handshake_error(err.str(),403)); } - - return true; } diff --git a/examples/chat_server/chat.hpp b/examples/chat_server/chat.hpp index 74e0f4b66b..7dd97f9a29 100644 --- a/examples/chat_server/chat.hpp +++ b/examples/chat_server/chat.hpp @@ -53,7 +53,7 @@ public: chat_handler() {} virtual ~chat_handler() {} - bool validate(websocketpp::session_ptr client); + void validate(websocketpp::session_ptr client); // add new connection to the lobby void connect(session_ptr client) { diff --git a/examples/echo_server/echo.cpp b/examples/echo_server/echo.cpp index 7ea9faaadf..313b5d100f 100644 --- a/examples/echo_server/echo.cpp +++ b/examples/echo_server/echo.cpp @@ -29,9 +29,7 @@ using websocketecho::echo_handler; -bool echo_handler::validate(websocketpp::session_ptr client) { - return true; -} +void echo_handler::validate(websocketpp::session_ptr client) {} void echo_handler::message(websocketpp::session_ptr client, const std::string &msg) { client->send(msg); diff --git a/examples/echo_server/echo.hpp b/examples/echo_server/echo.hpp index da0f02b062..01e0c6869d 100644 --- a/examples/echo_server/echo.hpp +++ b/examples/echo_server/echo.hpp @@ -42,7 +42,7 @@ public: virtual ~echo_handler() {} // The echo server allows all domains is protocol free. - bool validate(websocketpp::session_ptr client); + void validate(websocketpp::session_ptr client); // an echo server is stateless. The handler has no need to keep track of connected // clients. diff --git a/examples/echo_server/echo_server b/examples/echo_server/echo_server index 0b341d7092..60540ce967 100755 Binary files a/examples/echo_server/echo_server and b/examples/echo_server/echo_server differ diff --git a/src/websocket_connection_handler.hpp b/src/websocket_connection_handler.hpp index 297f0a8ec2..4ee93aeb97 100644 --- a/src/websocket_connection_handler.hpp +++ b/src/websocket_connection_handler.hpp @@ -44,18 +44,17 @@ namespace websocketpp { class connection_handler { public: - // validate will be called after a websocket handshake has been received and + // validate will be called after a websocket handshake has been received and // before it is accepted. It provides a handler the ability to refuse a // connection based on application specific logic (ex: restrict domains or - // negotiate subprotocols) + // negotiate subprotocols). To reject the connection throw a handshake_error // - // resource is the resource requested by the connection - // headers is a map containing the HTTP headers included in the client - // handshake as key/value strings - // - // if validate returns true the connection is allowed, otherwise the - // connection is closed. - virtual bool validate(session_ptr client) = 0; + // handshake_error parameters: + // log_message - error message to send to server log + // http_error_code - numeric HTTP error code to return to the client + // http_error_msg - (optional) string HTTP error code to return to the + // client (useful for returning non-standard error codes) + virtual void validate(session_ptr client) = 0; // this will be called once the connected websocket is avaliable for // writing messages. client may be a new websocket session or an existing diff --git a/src/websocket_session.cpp b/src/websocket_session.cpp index 5ccc1e1a4b..7f2f720ef6 100644 --- a/src/websocket_session.cpp +++ b/src/websocket_session.cpp @@ -33,9 +33,10 @@ #include #include -#include +#include #include #include +#include using websocketpp::session; @@ -108,6 +109,14 @@ std::string session::get_request() const { return m_request; } +std::string session::get_origin() const { + if (m_version < 13) { + return get_header("Sec-WebSocket-Origin"); + } else { + return get_header("Origin"); + } +} + void session::set_http_error(int code, std::string msg) { m_http_error_code = code; m_http_error_string = (msg != "" ? msg : lookup_http_error_string(code)); @@ -250,78 +259,80 @@ void session::handle_read_handshake(const boost::system::error_code& e, } } - // error checking - - // check the method - if (m_request.substr(0,4) != "GET ") { - // invalid method - std::cout << "Websocket handshake has invalid method: " << m_request.substr(0,4) << ", killing connection." << std::endl; - // TODO: exception - this->set_http_error(400); - this->write_http_error(); - return; - } - - // check the HTTP version - // TODO: allow versions greater than 1.1 - end = m_request.find(" HTTP/1.1",4); - if (end == std::string::npos) { - std::cout << "Websocket handshake has invalid HTTP version, killing connection." << std::endl; - this->set_http_error(400); // or error 505 HTTP Version Not Supported - this->write_http_error(); - return; - } + // handshake error checking + try { + std::stringstream err; + std::string h; - m_request = m_request.substr(4,end-4); - - // TODO: use exceptions or a helper function or something better here - - // verify the presence of required headers - if (m_hosts.find(get_header("Host")) == m_hosts.end()) { - std::cerr << "Invalid or missing Host header: " - << get_header("Host") << std::endl; - this->set_http_error(400); - } - if (!boost::iequals(get_header("Upgrade"),"websocket")) { - std::cerr << "Invalid or missing Upgrade header." << std::endl; - this->set_http_error(400); - } - if (get_header("Connection").find("Upgrade") == std::string::npos) { - // TODO: case insensitive? - std::cerr << "Invalid or missing Connection header." << std::endl; - this->set_http_error(400); - } - if (get_header("Sec-WebSocket-Key") == "") { - std::cerr << "Invalid or missing Sec-Websocket-Key header." << std::endl; - this->set_http_error(400); - } - - if (get_header("Sec-WebSocket-Version") == "" ) { - std::cerr << "Missing Sec-Websocket-Version header." << std::endl; - this->set_http_error(400); - } else { - m_version = strtoul(get_header("Sec-WebSocket-Version").c_str(),NULL,10); - - if (m_version != 7 && m_version != 8 && m_version != 13) { - this->set_http_error(400); + // check the method + if (m_request.substr(0,4) != "GET ") { + err << "Websocket handshake has invalid method: " + << m_request.substr(0,4); + + throw(handshake_error(err.str(),400)); } - } - - if (m_http_error_code != 0) { - this->write_http_error(); - return; - } - - // optional headers (delegated to the local interface) - if (m_local_interface && !m_local_interface->validate(shared_from_this())) { - std::cerr << "Local interface rejected the connection." << std::endl; - if (m_http_error_code == 0) { - this->set_http_error(400); - } - } - if (m_http_error_code != 0) { - this->write_http_error(); + // check the HTTP version + // TODO: allow versions greater than 1.1 + end = m_request.find(" HTTP/1.1",4); + if (end == std::string::npos) { + err << "Websocket handshake has invalid HTTP version"; + throw(handshake_error(err.str(),400)); + } + + m_request = m_request.substr(4,end-4); + + // verify the presence of required headers + h = get_header("Host"); + if (h == "") { + throw(handshake_error("Required Host header is missing",400)); + } else if (m_hosts.find(h) == m_hosts.end()) { + err << "Host " << h << " is not one of this server's names."; + throw(handshake_error(err.str(),400)); + } + + h = get_header("Upgrade"); + if (h == "") { + throw(handshake_error("Required Upgrade header is missing",400)); + } else if (!boost::iequals(h,"websocket")) { + err << "Upgrade header was " << h << " instead of \"websocket\""; + throw(handshake_error(err.str(),400)); + } + + h = get_header("Connection"); + if (h == "") { + throw(handshake_error("Required Connection header is missing",400)); + } else if (!boost::ifind_first(h,"upgrade")) { + err << "Connection header, \"" << h + << "\", does not contain required token \"upgrade\""; + throw(handshake_error(err.str(),400)); + } + + if (get_header("Sec-WebSocket-Key") == "") { + throw(handshake_error("Required Sec-WebSocket-Key header is missing",400)); + } + + h = get_header("Sec-WebSocket-Version"); + if (h == "") { + throw(handshake_error("Required Sec-WebSocket-Version header is missing",400)); + } else { + m_version = atoi(h.c_str()); + + if (m_version != 7 && m_version != 8 && m_version != 13) { + err << "This server doesn't support WebSocket protocol version " + << m_version; + throw(handshake_error(err.str(),400)); + } + } + + // optional headers (delegated to the local interface) + if (m_local_interface) { + m_local_interface->validate(shared_from_this()); + } + + } catch (const handshake_error& e) { + std::cerr << "Caught handshake exception: " << e.what() << std::endl; + e.write(shared_from_this()); return; } @@ -381,7 +392,7 @@ void session::handle_write_handshake(const boost::system::error_code& error) { return; } - //std::cout << "WebSocket Version 8 connection opened." << std::endl; + std::cout << "WebSocket Version " << m_version << " connection opened." << std::endl; m_status = OPEN; @@ -393,11 +404,12 @@ void session::handle_write_handshake(const boost::system::error_code& error) { this->read_frame(); } -void session::write_http_error() { +void session::write_http_error(int code,const std::string &msg) { std::stringstream server_handshake; - - server_handshake << "HTTP/1.1 " << m_http_error_code << " " - << m_http_error_string << "\r\n"; + + server_handshake << "HTTP/1.1 " << code << " " + << (msg != "" ? msg : lookup_http_error_string(code)) + << "\r\n"; // additional headers? diff --git a/src/websocket_session.hpp b/src/websocket_session.hpp index 750e726c34..71b291a458 100644 --- a/src/websocket_session.hpp +++ b/src/websocket_session.hpp @@ -37,16 +37,19 @@ #include #include +#include #include #include #include -#include #include +#include #include namespace websocketpp { class session; typedef boost::shared_ptr session_ptr; + + class handshake_error; } #include "websocket_frame.hpp" @@ -69,7 +72,9 @@ public: }; typedef enum ws_status status_code; - + + friend class handshake_error; + session (boost::asio::io_service& io_service, connection_handler_ptr defc); tcp::socket& socket(); @@ -103,19 +108,13 @@ public: // memory buffers. Maximum value is 2^63, default value is ?? void set_max_message_size(uint64_t); - // By default a failed handshake validation will return an HTTP 400 Bad - // Request error. If your application wants to reject the connection for - // another reason it can be set here. Example: 404 if the resource request - // is not recognized or 403 forbidden if the origin does not match. If only - // a code is supplied a msg will be generated automatically. - void set_http_error(int code,std::string msg = ""); - // gets the value of a header or the empty string if not present. std::string get_header(const std::string &key) const; // adds an arbitrary header to the server handshake HTTP response. void add_header(const std::string &key,const std::string &value); std::string get_request() const; + std::string get_origin() const; // sets the subprotocol being used. This will result in the appropriate // Sec-WebSocket-Protocol header being sent back to the client. The value @@ -153,7 +152,7 @@ private: void handle_write_handshake(const boost::system::error_code& error); // construct and write an HTTP error in the case the handshake goes poorly - void write_http_error(); + void write_http_error(int http_code,const std::string &http_err_str); void handle_write_http_error(const boost::system::error_code& error); // start async read for a websocket frame (2 bytes) to handle_frame_header @@ -236,6 +235,30 @@ private: frame::opcode m_current_opcode; }; +// Exception classes + +class handshake_error : public std::exception { +public: + handshake_error(const std::string& msg, + int http_error, + const std::string& http_msg = "") + : m_msg(msg),m_http_error_code(http_error),m_http_error_msg(http_msg) {} + ~handshake_error() throw() {} + + virtual const char* what() const throw() { + return m_msg.c_str(); + } + + void write(session_ptr s) const { + s->write_http_error(m_http_error_code,m_http_error_msg); + } + +private: + std::string m_msg; + int m_http_error_code; + std::string m_http_error_msg; +}; + } #endif // WEBSOCKET_SESSION_HPP