diff --git a/examples/chat_server/chat.cpp b/examples/chat_server/chat.cpp index 9b629b45fb..da25180b1b 100644 --- a/examples/chat_server/chat.cpp +++ b/examples/chat_server/chat.cpp @@ -30,13 +30,13 @@ #include using websocketchat::chat_server_handler; -using websocketpp::session_ptr; +using websocketpp::session::server_session_ptr; -void chat_server_handler::validate(session_ptr client) { +void chat_server_handler::validate(server_session_ptr session) { std::stringstream err; // We only know about the chat resource - if (client->get_resource() != "/chat") { + if (session->get_uri().resource != "/chat") { err << "Request for unknown resource " << client->get_resource(); throw(websocketpp::handshake_error(err.str(),404)); } diff --git a/examples/chat_server/chat.hpp b/examples/chat_server/chat.hpp index 0ef21c325c..7c32020d4e 100644 --- a/examples/chat_server/chat.hpp +++ b/examples/chat_server/chat.hpp @@ -38,43 +38,42 @@ // {"type":"msg","sender":"","value":"" } // {"type":"participants","value":[,...]} -#include #include "../../src/websocketpp.hpp" -#include "../../src/websocket_connection_handler.hpp" +#include "../../src/interfaces/session.hpp" +#include #include #include #include +using websocketpp::session::server_session_ptr; +using websocketpp::session::server_handler; + namespace websocketchat { -class chat_server_handler : public websocketpp::connection_handler { +class chat_server_handler : public server_handler { public: - chat_server_handler() {} - virtual ~chat_server_handler() {} - - void validate(websocketpp::session_ptr client); + void validate(server_session_ptr session); // add new connection to the lobby - void on_open(websocketpp::session_ptr client); + void on_open(server_session_ptr session); // someone disconnected from the lobby, remove them - void on_close(websocketpp::session_ptr client); + void on_close(server_session_ptr session); - void on_message(websocketpp::session_ptr client,const std::string &msg); + void on_message(server_session_ptr session,websocketpp::utf8_string_ptr msg); // lobby will ignore binary messages - void on_message(websocketpp::session_ptr client, - const std::vector &data) {} + void on_message(server_session_ptr session,websocketpp::binary_string_ptr data) {} private: std::string serialize_state(); std::string encode_message(std::string sender,std::string msg,bool escape = true); - std::string get_con_id(websocketpp::session_ptr s); + std::string get_con_id(server_session_ptr s); void send_to_all(std::string data); // list of outstanding connections - std::map m_connections; + std::map m_connections; }; typedef boost::shared_ptr chat_server_handler_ptr; diff --git a/examples/echo_server/echo.hpp b/examples/echo_server/echo.hpp index 1e3c00e706..48e8948d6c 100644 --- a/examples/echo_server/echo.hpp +++ b/examples/echo_server/echo.hpp @@ -40,7 +40,7 @@ using websocketpp::session::server_handler; namespace websocketecho { - class echo_server_handler : public server_handler { +class echo_server_handler : public server_handler { public: // The echo server allows all domains is protocol free. void validate(server_ptr session) {} diff --git a/src/hybi_legacy_processor.hpp b/src/hybi_legacy_processor.hpp index 398039db25..3dba3c6b55 100644 --- a/src/hybi_legacy_processor.hpp +++ b/src/hybi_legacy_processor.hpp @@ -219,7 +219,7 @@ public: binary_string_ptr prepare_close_frame(close::status::value code, bool mask, - const std::string& reason) const { + const std::string& reason) { binary_string_ptr response(new binary_string(2)); (*response)[0] = 0xFF; diff --git a/src/hybi_processor.hpp b/src/hybi_processor.hpp index c5534cfc74..138170f81d 100644 --- a/src/hybi_processor.hpp +++ b/src/hybi_processor.hpp @@ -52,9 +52,9 @@ namespace hybi_state { template class hybi_processor : public processor { public: - hybi_processor(rng_policy &rng) : m_read_frame(rng),m_write_frame(rng) { + hybi_processor(rng_policy &rng) : m_fragmented_opcode(frame::opcode::CONTINUATION),m_utf8_payload(new utf8_string()),m_binary_payload(new binary_string()),m_read_frame(rng),m_write_frame(rng) { reset(); - } + } void validate_handshake(const http::parser::request& request) const { std::stringstream err; @@ -159,31 +159,28 @@ public: if (m_read_frame.ready()) { switch (m_read_frame.get_opcode()) { case frame::opcode::CONTINUATION: - if (m_opcode == frame::opcode::BINARY) { - extract_binary(); - } else if (m_opcode == frame::opcode::TEXT) { - extract_utf8(); - } else { - // can't be here - } + process_continuation(); break; case frame::opcode::TEXT: - m_opcode = frame::opcode::TEXT; - extract_utf8(); + process_text(); break; case frame::opcode::BINARY: - m_opcode = frame::opcode::BINARY; - extract_binary(); + process_binary(); break; case frame::opcode::CLOSE: + if (!utf8_validator::validate(m_read_frame.get_close_msg())) { + throw session::exception("Invalid UTF8",session::error::PAYLOAD_VIOLATION); + } + m_opcode = frame::opcode::CLOSE; m_close_code = m_read_frame.get_close_status(); m_close_reason = m_read_frame.get_close_msg(); + break; case frame::opcode::PING: case frame::opcode::PONG: m_opcode = m_read_frame.get_opcode(); - extract_binary(); + extract_binary(m_control_payload); break; default: throw session::exception("Invalid Opcode",session::error::PROTOCOL_VIOLATION); @@ -191,6 +188,13 @@ public: } if (m_read_frame.get_fin()) { m_state = hybi_state::DONE; + if (m_opcode == frame::opcode::TEXT) { + if (!m_validator.complete()) { + m_validator.reset(); + throw session::exception("Invalid UTF8",session::error::PAYLOAD_VIOLATION); + } + m_validator.reset(); + } } m_read_frame.reset(); } @@ -204,17 +208,57 @@ public: } } - void extract_binary() { - binary_string &msg = m_read_frame.get_payload(); - m_binary_payload->resize(m_binary_payload->size()+msg.size()); - std::copy(msg.begin(),msg.end(),m_binary_payload->end()-msg.size()); + // frame type handlers: + void process_continuation() { + if (m_fragmented_opcode == frame::opcode::BINARY) { + extract_binary(m_binary_payload); + } else if (m_fragmented_opcode == frame::opcode::TEXT) { + extract_utf8(m_utf8_payload); + } else if (m_fragmented_opcode == frame::opcode::CONTINUATION) { + // got continuation frame without a message to continue. + throw session::exception("No message to continue.",session::error::PROTOCOL_VIOLATION); + } else { + + // can't be here + } + if (m_read_frame.get_fin()) { + m_opcode = m_fragmented_opcode; + } } - void extract_utf8() { + void process_text() { + if (m_fragmented_opcode != frame::opcode::CONTINUATION) { + throw session::exception("New message started without closing previous.",session::error::PROTOCOL_VIOLATION); + } + extract_utf8(m_utf8_payload); + m_opcode = frame::opcode::TEXT; + m_fragmented_opcode = frame::opcode::TEXT; + } + + void process_binary() { + if (m_fragmented_opcode != frame::opcode::CONTINUATION) { + throw session::exception("New message started without closing previous.",session::error::PROTOCOL_VIOLATION); + } + m_opcode = frame::opcode::BINARY; + m_fragmented_opcode = frame::opcode::BINARY; + extract_binary(m_binary_payload); + } + + void extract_binary(binary_string_ptr dest) { + binary_string &msg = m_read_frame.get_payload(); + dest->resize(dest->size() + msg.size()); + std::copy(msg.begin(),msg.end(),dest->end() - msg.size()); + } + + void extract_utf8(utf8_string_ptr dest) { binary_string &msg = m_read_frame.get_payload(); - m_utf8_payload->reserve(m_utf8_payload->size()+msg.size()); - m_utf8_payload->append(msg.begin(),msg.end()); + if (!m_validator.decode(msg.begin(),msg.end())) { + throw session::exception("Invalid UTF8",session::error::PAYLOAD_VIOLATION); + } + + dest->reserve(dest->size() + msg.size()); + dest->append(msg.begin(),msg.end()); } bool ready() const { @@ -223,8 +267,13 @@ public: void reset() { m_state = m_state = hybi_state::INIT; - m_utf8_payload = utf8_string_ptr(new utf8_string()); - m_binary_payload = binary_string_ptr(new binary_string()); + m_control_payload = binary_string_ptr(new binary_string()); + + if (m_fragmented_opcode == m_opcode) { + m_utf8_payload = utf8_string_ptr(new utf8_string()); + m_binary_payload = binary_string_ptr(new binary_string()); + m_fragmented_opcode = frame::opcode::CONTINUATION; + } } uint64_t get_bytes_needed() const { @@ -251,18 +300,18 @@ public: } binary_string_ptr get_binary_payload() const { - // TODO: opcode doesn't match - if (get_opcode() != frame::opcode::BINARY && - get_opcode() != frame::opcode::PING && - get_opcode() != frame::opcode::PONG) { - throw "opcode doesn't have a binary payload"; - } - if (!ready()) { throw "not ready"; } - return m_binary_payload; + if (get_opcode() == frame::opcode::BINARY) { + return m_binary_payload; + } else if (get_opcode() != frame::opcode::PING || + get_opcode() != frame::opcode::PONG) { + return m_control_payload; + } else { + throw "opcode doesn't have a binary payload"; + } } // legacy hybi doesn't have close codes @@ -345,10 +394,23 @@ public: binary_string_ptr prepare_close_frame(close::status::value code, bool mask, - const std::string& reason) const { + const std::string& reason) { binary_string_ptr response(new binary_string(0)); + m_write_frame.reset(); + m_write_frame.set_opcode(frame::opcode::CLOSE); + m_write_frame.set_masked(mask); + m_write_frame.set_fin(true); + m_write_frame.set_status(code,reason); + // TODO + response->resize(m_write_frame.get_header_len()+m_write_frame.get_payload().size()); + + // copy header + std::copy(m_write_frame.get_header(),m_write_frame.get_header()+m_write_frame.get_header_len(),response->begin()); + + // copy payload + std::copy(m_write_frame.get_payload().begin(),m_write_frame.get_payload().end(),response->begin()+m_write_frame.get_header_len()); return response; } @@ -356,13 +418,17 @@ public: private: int m_state; frame::opcode::value m_opcode; + frame::opcode::value m_fragmented_opcode; utf8_string_ptr m_utf8_payload; binary_string_ptr m_binary_payload; + binary_string_ptr m_control_payload; close::status::value m_close_code; std::string m_close_reason; + utf8_validator::validator m_validator; + frame::parser m_read_frame; frame::parser m_write_frame; }; diff --git a/src/interfaces/protocol.hpp b/src/interfaces/protocol.hpp index eca5ab8870..d1749c1009 100644 --- a/src/interfaces/protocol.hpp +++ b/src/interfaces/protocol.hpp @@ -77,7 +77,7 @@ public: virtual binary_string_ptr prepare_close_frame(close::status::value code, bool mask, - const std::string& reason) const = 0; + const std::string& reason) = 0; diff --git a/src/interfaces/session.hpp b/src/interfaces/session.hpp index d7d918e8e2..36b0452101 100644 --- a/src/interfaces/session.hpp +++ b/src/interfaces/session.hpp @@ -117,7 +117,8 @@ public: virtual bool get_closed_by_me() const = 0; }; typedef boost::shared_ptr server_ptr; - +typedef server_ptr server_session_ptr; + /* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * Server Handler API * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ diff --git a/src/utf8_validator/utf8_validator.hpp b/src/utf8_validator/utf8_validator.hpp index 2e1b40c0a7..7df583c3a7 100644 --- a/src/utf8_validator/utf8_validator.hpp +++ b/src/utf8_validator/utf8_validator.hpp @@ -39,7 +39,53 @@ decode(uint32_t* state, uint32_t* codep, uint32_t byte) { *state = utf8d[256 + *state*16 + type]; return *state; } + +class validator { +public: + validator() : m_state(UTF8_ACCEPT),m_codepoint(0) {} + + + + bool consume (uint32_t byte) { + if (utf8_validator::decode(&m_state,&m_codepoint,byte) == UTF8_REJECT) { + return false; + } + return true; + } + + template + bool decode (iterator_type b, iterator_type e) { + for (iterator_type i = b; i != e; i++) { + if (utf8_validator::decode(&m_state,&m_codepoint,*i) == UTF8_REJECT) { + return false; + } + } + return true; + } + + bool complete() { + return m_state == UTF8_ACCEPT; + } + + void reset() { + m_state = UTF8_ACCEPT; + m_codepoint = 0; + } +private: + uint32_t m_state; + uint32_t m_codepoint; +}; +// convenience function that creates a validator, validates a complete string +// and returns the result. +bool validate(const std::string& s) { + validator v; + if (!v.decode(s.begin(),s.end())) { + return false; + } + return v.complete(); +} + } // namespace utf8_validator #endif // UTF8_VALIDATOR_HPP \ No newline at end of file