diff --git a/examples/chat_server/chat.cpp b/examples/chat_server/chat.cpp index da25180b1b..52d3e72eb6 100644 --- a/examples/chat_server/chat.cpp +++ b/examples/chat_server/chat.cpp @@ -32,35 +32,36 @@ using websocketchat::chat_server_handler; using websocketpp::session::server_session_ptr; + void chat_server_handler::validate(server_session_ptr session) { std::stringstream err; // We only know about the chat resource - if (session->get_uri().resource != "/chat") { - err << "Request for unknown resource " << client->get_resource(); - throw(websocketpp::handshake_error(err.str(),404)); + if (session->get_resource() != "/chat") { + err << "Request for unknown resource " << session->get_resource(); + throw(websocketpp::http::exception(err.str(),websocketpp::http::status_code::NOT_FOUND)); } // Require specific origin example - if (client->get_origin() != "http://zaphoyd.com") { - err << "Request from unrecognized origin: " << client->get_origin(); - throw(websocketpp::handshake_error(err.str(),403)); + if (session->get_origin() != "http://zaphoyd.com") { + err << "Request from unrecognized origin: " << session->get_origin(); + throw(websocketpp::http::exception(err.str(),websocketpp::http::status_code::FORBIDDEN)); } } -void chat_server_handler::on_open(session_ptr client) { - std::cout << "client " << client << " joined the lobby." << std::endl; - m_connections.insert(std::pair(client,get_con_id(client))); +void chat_server_handler::on_open(server_session_ptr session) { + std::cout << "client " << session << " joined the lobby." << std::endl; + m_connections.insert(std::pair(session,get_con_id(session))); // send user list and signon message to all clients send_to_all(serialize_state()); - client->send(encode_message("server","Welcome, use the /alias command to set a name, /help for a list of other commands.")); - send_to_all(encode_message("server",m_connections[client]+" has joined the chat.")); + session->send(encode_message("server","Welcome, use the /alias command to set a name, /help for a list of other commands.")); + send_to_all(encode_message("server",m_connections[session]+" has joined the chat.")); } -void chat_server_handler::on_close(session_ptr client) { - std::map::iterator it = m_connections.find(client); +void chat_server_handler::on_close(server_session_ptr session) { + std::map::iterator it = m_connections.find(session); if (it == m_connections.end()) { // this client has already disconnected, we can ignore this. @@ -70,7 +71,7 @@ void chat_server_handler::on_close(session_ptr client) { return; } - std::cout << "client " << client << " left the lobby." << std::endl; + std::cout << "client " << session << " left the lobby." << std::endl; const std::string alias = it->second; m_connections.erase(it); @@ -80,31 +81,31 @@ void chat_server_handler::on_close(session_ptr client) { send_to_all(encode_message("server",alias+" has left the chat.")); } -void chat_server_handler::on_message(session_ptr client,const std::string &msg) { - std::cout << "message from client " << client << ": " << msg << std::endl; +void chat_server_handler::on_message(server_session_ptr session,websocketpp::utf8_string_ptr msg) { + std::cout << "message from client " << session << ": " << *msg << std::endl; // check for special command messages - if (msg == "/help") { + if (*msg == "/help") { // print command list - client->send(encode_message("server","avaliable commands:
    /help - show this help
    /alias foo - set alias to foo",false)); + session->send(encode_message("server","avaliable commands:
    /help - show this help
    /alias foo - set alias to foo",false)); return; } - if (msg.substr(0,7) == "/alias ") { + if (msg->substr(0,7) == "/alias ") { std::string response; std::string alias; - if (msg.size() == 7) { + if (msg->size() == 7) { response = "You must enter an alias."; - client->send(encode_message("server",response)); + session->send(encode_message("server",response)); return; } else { - alias = msg.substr(7); + alias = msg->substr(7); } - response = m_connections[client] + " is now known as "+alias; + response = m_connections[session] + " is now known as "+alias; // store alias pre-escaped so we don't have to do this replacing every time this // user sends a message @@ -118,7 +119,7 @@ void chat_server_handler::on_message(session_ptr client,const std::string &msg) boost::algorithm::replace_all(alias,"<","<"); boost::algorithm::replace_all(alias,">",">"); - m_connections[client] = alias; + m_connections[session] = alias; // set alias send_to_all(serialize_state()); @@ -127,13 +128,13 @@ void chat_server_handler::on_message(session_ptr client,const std::string &msg) } // catch other slash commands - if (msg[0] == '/') { - client->send(encode_message("server","unrecognized command")); + if ((*msg)[0] == '/') { + session->send(encode_message("server","unrecognized command")); return; } // create JSON message to send based on msg - send_to_all(encode_message(m_connections[client],msg)); + send_to_all(encode_message(m_connections[session],*msg)); } // {"type":"participants","value":[,...]} @@ -142,7 +143,7 @@ std::string chat_server_handler::serialize_state() { s << "{\"type\":\"participants\",\"value\":["; - std::map::iterator it; + std::map::iterator it; for (it = m_connections.begin(); it != m_connections.end(); it++) { s << "\"" << (*it).second << "\""; @@ -178,14 +179,14 @@ std::string chat_server_handler::encode_message(std::string sender,std::string m return s.str(); } -std::string chat_server_handler::get_con_id(session_ptr s) { +std::string chat_server_handler::get_con_id(server_session_ptr s) { std::stringstream endpoint; - endpoint << s->socket().remote_endpoint(); + endpoint << s->get_endpoint(); return endpoint.str(); } void chat_server_handler::send_to_all(std::string data) { - std::map::iterator it; + std::map::iterator it; for (it = m_connections.begin(); it != m_connections.end(); it++) { (*it).first->send(data); } diff --git a/examples/chat_server/chat_server.cpp b/examples/chat_server/chat_server.cpp index 80d65d9fab..474197bda5 100644 --- a/examples/chat_server/chat_server.cpp +++ b/examples/chat_server/chat_server.cpp @@ -36,46 +36,33 @@ using boost::asio::ip::tcp; using namespace websocketchat; int main(int argc, char* argv[]) { - std::string host = "localhost"; short port = 9003; - std::string full_host; - if (argc == 3) { + if (argc == 2) { // TODO: input validation? - host = argv[1]; port = atoi(argv[2]); } - - std::stringstream temp; - - temp << host << ":" << port; - full_host = temp.str(); - - - chat_server_handler_ptr chat_handler(new chat_server_handler()); try { - boost::asio::io_service io_service; - tcp::endpoint endpoint(tcp::v6(), port); + // create an instance of our handler + server_handler_ptr default_handler(new chat_server_handler()); - websocketpp::server_ptr server( - new websocketpp::server(io_service,endpoint,chat_handler) - ); + // create a server that listens on port `port` and uses our handler + websocketpp::basic_server_ptr server(new websocketpp::basic_server(port,default_handler)); + + server->elog().set_levels(websocketpp::log::elevel::DEVEL,websocketpp::log::elevel::FATAL); + + server->alog().set_level(websocketpp::log::alevel::ALL); // setup server settings - server->add_host(host); - server->add_host(full_host); // Chat server should only be receiving small text messages, reduce max // message size limit slightly to save memory, improve performance, and // guard against DoS attacks. - server->set_max_message_size(0xFFFF); // 64KiB + //server->set_max_message_size(0xFFFF); // 64KiB - // start the server - server->start_accept(); + std::cout << "Starting chat server on port " << port << std::endl; - std::cout << "Starting chat server on " << full_host << std::endl; - - io_service.run(); + server->run(); } catch (std::exception& e) { std::cerr << "Exception: " << e.what() << std::endl; } diff --git a/examples/echo_server/echo_server.cpp b/examples/echo_server/echo_server.cpp index 97e3431136..b16cf88329 100644 --- a/examples/echo_server/echo_server.cpp +++ b/examples/echo_server/echo_server.cpp @@ -35,31 +35,19 @@ using boost::asio::ip::tcp; int main(int argc, char* argv[]) { - std::string host = "localhost"; short port = 9002; - std::string full_host; - if (argc == 3) { + if (argc == 2) { // TODO: input validation? - host = argv[1]; port = atoi(argv[2]); } - std::stringstream temp; - - temp << host << ":" << port; - full_host = temp.str(); - - - try { // create an instance of our handler server_handler_ptr default_handler(new websocketecho::echo_server_handler()); // create a server that listens on port `port` and uses our handler - typedef boost::shared_ptr< websocketpp::server::server<> > server_ptr; - - server_ptr server(new websocketpp::server::server<>(port,default_handler)); + websocketpp::basic_server_ptr server(new websocketpp::basic_server(port,default_handler)); server->elog().set_levels(websocketpp::log::elevel::DEVEL,websocketpp::log::elevel::FATAL); @@ -75,10 +63,10 @@ int main(int argc, char* argv[]) { // server to test performance and protocol extremes. //server->set_max_message_size(websocketpp::frame::limits::PAYLOAD_SIZE_JUMBO); + std::cout << "Starting echo server on port " << port << std::endl; + // start the server server->run(); - - std::cout << "Starting echo server on " << full_host << std::endl; } catch (std::exception& e) { std::cerr << "Exception: " << e.what() << std::endl; } diff --git a/src/http/constants.hpp b/src/http/constants.hpp index 283e57e1b1..bbdd76e793 100644 --- a/src/http/constants.hpp +++ b/src/http/constants.hpp @@ -85,7 +85,8 @@ namespace http { NETWORK_AUTHENTICATION_REQUIRED = 511 }; - std::string get_string(value c) { + // TODO: should this be inline? + inline std::string get_string(value c) { switch (c) { case CONTINUE: return "Continue"; diff --git a/src/hybi_legacy_processor.hpp b/src/hybi_legacy_processor.hpp index 3dba3c6b55..e3adc73293 100644 --- a/src/hybi_legacy_processor.hpp +++ b/src/hybi_legacy_processor.hpp @@ -45,7 +45,7 @@ namespace hybi_legacy_state { class hybi_legacy_processor : public processor { public: - hybi_legacy_processor() : m_state(hybi_legacy_state::INIT) { + hybi_legacy_processor(bool secure) : m_secure(secure),m_state(hybi_legacy_state::INIT) { reset(); } @@ -86,14 +86,26 @@ public: // set a different one. if (response.header("Sec-WebSocket-Location") == "") { // TODO: extract from host header rather than hard code - ws_uri uri; - uri.secure = false; - uri.host = "localhost"; - uri.port = 9002; + ws_uri uri = get_uri(request); response.add_header("Sec-WebSocket-Location",uri.base()); } } + std::string get_origin(const http::parser::request& request) const { + return request.header("Origin"); + } + + ws_uri get_uri(const http::parser::request& request) const { + ws_uri uri; + + uri.secure = m_secure; + // TODO: check if get_uri is a full uri + // TODO: host and port + uri.resource = request.uri(); + + return uri; + } + void consume(std::istream& s) { unsigned char c; while (s.good() && m_state != hybi_legacy_state::DONE) { @@ -251,6 +263,8 @@ private: } } + bool m_secure; + hybi_legacy_state::value m_state; frame::opcode::value m_opcode; diff --git a/src/hybi_processor.hpp b/src/hybi_processor.hpp index 138170f81d..2a61c5fb04 100644 --- a/src/hybi_processor.hpp +++ b/src/hybi_processor.hpp @@ -52,7 +52,7 @@ namespace hybi_state { template class hybi_processor : public processor { public: - hybi_processor(rng_policy &rng) : m_fragmented_opcode(frame::opcode::CONTINUATION),m_utf8_payload(new utf8_string()),m_binary_payload(new binary_string()),m_read_frame(rng),m_write_frame(rng) { + hybi_processor(bool secure,rng_policy &rng) : m_secure(secure),m_fragmented_opcode(frame::opcode::CONTINUATION),m_utf8_payload(new utf8_string()),m_binary_payload(new binary_string()),m_read_frame(rng),m_write_frame(rng) { reset(); } @@ -115,6 +115,30 @@ public: } } + std::string get_origin(const http::parser::request& request) const { + std::string h = request.header("Sec-WebSocket-Version"); + int version = atoi(h.c_str()); + + if (version == 13) { + return request.header("Origin"); + } else if (version == 7 || version == 8) { + return request.header("Sec-WebSocket-Origin"); + } else { + throw(http::exception("Could not determine origin header.",http::status_code::BAD_REQUEST)); + } + } + + ws_uri get_uri(const http::parser::request& request) const { + ws_uri uri; + + uri.secure = m_secure; + // TODO: check if get_uri is a full uri + // TODO: host and port + uri.resource = request.uri(); + + return uri; + } + void handshake_response(const http::parser::request& request,http::parser::response& response) { // TODO: @@ -416,6 +440,7 @@ public: } private: + bool m_secure; int m_state; frame::opcode::value m_opcode; frame::opcode::value m_fragmented_opcode; diff --git a/src/interfaces/protocol.hpp b/src/interfaces/protocol.hpp index d1749c1009..1b2f6892f4 100644 --- a/src/interfaces/protocol.hpp +++ b/src/interfaces/protocol.hpp @@ -49,7 +49,12 @@ public: virtual void validate_handshake(const http::parser::request& headers) const = 0; virtual void handshake_response(const http::parser::request& request,http::parser::response& response) = 0; - + + // Extracts client origin from a handshake request + virtual std::string get_origin(const http::parser::request& request) const = 0; + // Extracts client uri from a handshake request + virtual ws_uri get_uri(const http::parser::request& request) const = 0; + // consume bytes, throw on exception virtual void consume(std::istream& s) = 0; diff --git a/src/interfaces/session.hpp b/src/interfaces/session.hpp index 36b0452101..b6d931707b 100644 --- a/src/interfaces/session.hpp +++ b/src/interfaces/session.hpp @@ -87,10 +87,17 @@ public: virtual session::state::value get_state() const = 0; virtual unsigned int get_version() const = 0; - virtual std::string get_origin() const = 0; virtual std::string get_request_header(const std::string& key) const = 0; - virtual const ws_uri& get_uri() const = 0; + virtual std::string get_origin() const = 0; + + // Information about the requested URI virtual bool get_secure() const = 0; + virtual std::string get_host() const = 0; + virtual std::string get_resource() const = 0; + virtual uint16_t get_port() const = 0; + + // Information about the connected endpoint + /* Tentative API member function */ virtual boost::asio::ip::tcp::endpoint get_endpoint() const = 0; // Valid for CONNECTING state virtual void add_response_header(const std::string& key, const std::string& value) = 0; diff --git a/src/utf8_validator/utf8_validator.hpp b/src/utf8_validator/utf8_validator.hpp index 7df583c3a7..9595cef745 100644 --- a/src/utf8_validator/utf8_validator.hpp +++ b/src/utf8_validator/utf8_validator.hpp @@ -78,7 +78,8 @@ private: // convenience function that creates a validator, validates a complete string // and returns the result. -bool validate(const std::string& s) { +// TODO: should this be inline? +inline bool validate(const std::string& s) { validator v; if (!v.decode(s.begin(),s.end())) { return false; diff --git a/src/websocket_server.hpp b/src/websocket_server.hpp index 304c109fa6..94173424e7 100644 --- a/src/websocket_server.hpp +++ b/src/websocket_server.hpp @@ -58,7 +58,7 @@ namespace po = boost::program_options; #include "http/parser.hpp" #include "logger/logger.hpp" -using boost::asio ::ip::tcp; +using boost::asio::ip::tcp; using websocketpp::session::server_handler_ptr; using websocketpp::protocol::processor_ptr; @@ -114,15 +114,27 @@ public: return m_request.header(key); } - const ws_uri& get_uri() const { - return m_uri; - } - bool get_secure() const { // TODO return false; } + std::string get_host() const { + return m_uri.host; + } + + uint16_t get_port() const { + return m_uri.port; + } + + std::string get_resource() const { + return m_uri.resource; + } + + tcp::endpoint get_endpoint() const { + return m_socket.remote_endpoint(); + } + // Valid for CONNECTING state void add_response_header(const std::string& key, const std::string& value) { m_response.add_header(key,value); @@ -325,16 +337,18 @@ public: } m_request.add_header("Sec-WebSocket-Key3",std::string(foo)); - m_processor = processor_ptr(new protocol::hybi_legacy_processor()); + m_processor = processor_ptr(new protocol::hybi_legacy_processor(false)); } else if (m_version == 7 || m_version == 8 || m_version == 13) { // create hybi 17 processor - m_processor = processor_ptr(new protocol::hybi_processor(m_rng)); + m_processor = processor_ptr(new protocol::hybi_processor(false,m_rng)); } else { // TODO: respond with unknown version message per spec } // ask new protocol whether this set of headers is valid m_processor->validate_handshake(m_request); + m_origin = m_processor->get_origin(m_request); + m_uri = m_processor->get_uri(m_request); // ask local application to confirm that it wants to accept m_handler->validate(boost::static_pointer_cast(connection_type::shared_from_this())); @@ -1029,8 +1043,20 @@ private: po::options_description m_desc; po::variables_map m_vm; }; - + } + +// convenience type interface + +// websocketpp::server_ptr represents a basic non-secure websocket server +typedef server::server<> basic_server; +typedef basic_server::ptr basic_server_ptr; + +// websocketpp::secure_server_ptr represents a basic secure websocket server +// TODO: +typedef server::server<> secure_server; +typedef secure_server::ptr secure_server_ptr; + } #endif // WEBSOCKET_SERVER_HPP diff --git a/src/websocketpp.hpp b/src/websocketpp.hpp index 356148b43b..125401a81f 100644 --- a/src/websocketpp.hpp +++ b/src/websocketpp.hpp @@ -32,9 +32,6 @@ #include "websocket_constants.hpp" -//#include "websocket_session.hpp" -//#include "websocket_server_session.hpp" -//#include "websocket_client_session.hpp" #include "websocket_server.hpp" //#include "websocket_client.hpp"