From 9180e521037e9b37f189d64cbd225b967b677977 Mon Sep 17 00:00:00 2001 From: Peter Thorson Date: Fri, 7 Oct 2011 08:26:41 -0500 Subject: [PATCH] continues frame reading changes fixes a lot of underspecified close behavior. NOTE: this revision won't compile --- src/websocket_frame.cpp | 39 ++++++----- src/websocket_frame.hpp | 13 +++- src/websocket_session.cpp | 132 ++++++++++++++++++++++---------------- src/websocket_session.hpp | 20 +++--- 4 files changed, 122 insertions(+), 82 deletions(-) diff --git a/src/websocket_frame.cpp b/src/websocket_frame.cpp index 7fdd35cdcf..1496ea06a8 100644 --- a/src/websocket_frame.cpp +++ b/src/websocket_frame.cpp @@ -49,6 +49,7 @@ uint8_t frame::get_state() const { void frame::reset() { m_state = STATE_BASIC_HEADER; m_bytes_needed = BASIC_HEADER_LENGTH; + m_payload.empty(); } void frame::consume(std::istream &s) { @@ -78,6 +79,15 @@ void frame::consume(std::istream &s) { process_extended_header(); m_state = STATE_PAYLOAD; } + } else if (m_state == STATE_PAYLOAD) { + s.read(reinterpret_cast(&m_payload[0]),m_bytes_needed); + + m_bytes_needed -= s.gcount(); + + if (m_bytes_needed == 0) { + process_payload(); + m_state = STATE_READY; + } } } @@ -324,13 +334,10 @@ std::string frame::print_frame() const { } void frame::process_basic_header() { - m_payload.empty(); m_bytes_needed = get_header_len() - BASIC_HEADER_LENGTH; } void frame::process_extended_header() { - m_extended_header_bytes_needed = 0; - uint8_t s = get_basic_size(); uint64_t payload_size; int mask_index = BASIC_HEADER_LENGTH; @@ -345,7 +352,8 @@ void frame::process_extended_header() { )); if (payload_size < s) { - throw frame_error("payload length not minimally encoded"); + throw frame_error("payload length not minimally encoded", + FERR_PROTOCOL_VIOLATION); } mask_index += 2; @@ -357,7 +365,8 @@ void frame::process_extended_header() { )); if (payload_size <= PAYLOAD_16BIT_LIMIT) { - throw frame_error("payload length not minimally encoded"); + throw frame_error("payload length not minimally encoded", + FERR_PROTOCOL_VIOLATION); } mask_index += 8; @@ -378,9 +387,11 @@ void frame::process_extended_header() { } if (payload_size > max_payload_size) { + // TODO: frame/message size limits throw server_error("got frame with payload greater than maximum frame buffer size."); } m_payload.resize(payload_size); + m_bytes_needed = payload_size; } void frame::process_payload() { @@ -420,43 +431,39 @@ void frame::process_payload2() { } } -bool frame::validate_utf8(uint32_t* state,uint32_t* codep) const { +void frame::validate_utf8(uint32_t* state,uint32_t* codep) const { for (size_t i = 0; i < m_payload.size(); i++) { using utf8_validator::decode; - //std::cout << "decoding: " << std::hex << m_payload[i] << std::endl; if (decode(state,codep,m_payload[i]) == utf8_validator::UTF8_REJECT) { - // std::cout << "bad byte" << std::endl; - return false; + throw frame_error("Invalid UTF-8 Data",FERR_PAYLOAD_VIOLATION); } } - - return true; } void frame::validate_basic_header() const { // check for control frame size if (get_basic_size() > BASIC_PAYLOAD_LIMIT && is_control()) { - throw frame_error("Control Frame is too large"); + throw frame_error("Control Frame is too large",FERR_PROTOCOL_VIOLATION); } // check for reserved opcodes if (get_rsv1() || get_rsv2() || get_rsv3()) { - throw frame_error("Reserved bit used"); + throw frame_error("Reserved bit used",FERR_PROTOCOL_VIOLATION); } // check for reserved opcodes opcode op = get_opcode(); if (op > 0x02 && op < 0x08) { - throw frame_error("Reserved opcode used"); + throw frame_error("Reserved opcode used",FERR_PROTOCOL_VIOLATION); } if (op > 0x0A) { - throw frame_error("Reserved opcode used"); + throw frame_error("Reserved opcode used",FERR_PROTOCOL_VIOLATION); } // check for fragmented control message if (is_control() && !get_fin()) { - throw frame_error("Fragmented control message"); + throw frame_error("Fragmented control message",FERR_PROTOCOL_VIOLATION); } } diff --git a/src/websocket_frame.hpp b/src/websocket_frame.hpp index 0d4e4c6802..abd88ce30d 100644 --- a/src/websocket_frame.hpp +++ b/src/websocket_frame.hpp @@ -59,6 +59,11 @@ public: static const uint8_t STATE_PAYLOAD = 3; static const uint8_t STATE_READY = 4; + static const uint16_t FERR_FATAL_SESSION_ERROR = 0; // must end session + static const uint16_t FERR_SOFT_SESSION_ERROR = 1; // should log and ignore + static const uint16_t FERR_PROTOCOL_VIOLATION = 2; // must end session + static const uint16_t FERR_PAYLOAD_VIOLATION = 3; // should end session + // basic payload byte flags static const uint8_t BPB0_OPCODE = 0x0F; static const uint8_t BPB0_RSV3 = 0x10; @@ -87,6 +92,7 @@ public: } uint8_t get_state() const; + uint64_t get_bytes_needed() const; void reset(); void consume(std::istream &s); @@ -141,7 +147,7 @@ public: void process_payload(); void process_payload2(); // experiment with more efficient masking code. - bool validate_utf8(uint32_t* state,uint32_t* codep) const; + void validate_utf8(uint32_t* state,uint32_t* codep) const; void validate_basic_header() const; void generate_masking_key(); @@ -166,7 +172,9 @@ private: // Exception classes class frame_error : public std::exception { public: - frame_error(const std::string& msg) : m_msg(msg) {} + frame_error(const std::string& msg, + uint16_t code = frame::FERR_FATAL_SESSION_ERROR) + : m_msg(msg),m_code(code) {} ~frame_error() throw() {} virtual const char* what() const throw() { @@ -174,6 +182,7 @@ public: } std::string m_msg; + uint16_t m_code; }; } diff --git a/src/websocket_session.cpp b/src/websocket_session.cpp index 98131b2fad..e1f363377d 100644 --- a/src/websocket_session.cpp +++ b/src/websocket_session.cpp @@ -120,7 +120,7 @@ unsigned int session::get_version() const { } void session::send(const std::string &msg) { - if (m_status != OPEN) { + if (m_state != STATE_OPEN) { log("Tried to send a message from a session that wasn't open",LOG_WARN); return; } @@ -132,7 +132,7 @@ void session::send(const std::string &msg) { } void session::send(const std::vector &data) { - if (m_status != OPEN) { + if (m_state != STATE_OPEN) { log("Tried to send a message from a session that wasn't open",LOG_WARN); return; } @@ -143,19 +143,28 @@ void session::send(const std::vector &data) { write_frame(); } +// end user interface to close the connection void session::close(uint16_t status,const std::string& msg) { + validate_app_close_status(status); + disconnect(status,msg); // TODO: close behavior } // TODO: clean this up, needs to be broken out into more specific methods + +// This method initiates a clean disconnect with the given status code and reason +// it logs an error and is ignored if it is called from a state other than OPEN + +// called by process_close when an initiate close method is received. + void session::disconnect(uint16_t status,const std::string &message) { - if (m_status != OPEN) { + if (m_state != STATE_OPEN) { log("Tried to disconnect a session that wasn't open",LOG_WARN); return; } - m_status = CLOSING; + m_state = STATE_CLOSING; m_close_code = status; m_close_message = message; @@ -179,7 +188,7 @@ void session::disconnect(uint16_t status,const std::string &message) { } void session::ping(const std::string &msg) { - if (m_status != OPEN) { + if (m_state != STATE_OPEN) { log("Tried to send a ping from a session that wasn't open",LOG_WARN); return; } @@ -191,7 +200,7 @@ void session::ping(const std::string &msg) { } void session::pong(const std::string &msg) { - if (m_status != OPEN) { + if (m_state != STATE_OPEN) { log("Tried to send a pong from a session that wasn't open",LOG_WARN); return; } @@ -203,16 +212,28 @@ void session::pong(const std::string &msg) { } void session::handle_read_frame(const boost::system::error_code& error) { - // while - // if there are enough bytes to do something: - // do something - // read more - + if (error) { + handle_error("Error reading extended frame header",error); + // TODO: close behavior + return; + } + + if (m_state != STATE_OPEN && m_state != STATE_CLOSING) { + // stop processing frames. + log("handle_read_frame called in invalid state",LOG_ERROR); + return; + } + std::istream s(&m_buf); try { while (m_buf.size() > 0) { m_read_frame.consume(s); + if (m_read_frame.get_state() == frame::STATE_READY) { + process_frame(); + + // should we break? + } } } catch (const frame_error& e) { std::stringstream err; @@ -221,15 +242,24 @@ void session::handle_read_frame(const boost::system::error_code& error) { access_log(e.what(),ALOG_FRAME); log(err.str(),LOG_ERROR); + disconnect(CLOSE_STATUS_PROTOCOL_ERROR,""); + // TODO: close behavior return; } // we have read everything, check if we should read more + if (m_state != STATE_OPEN && m_state != STATE_CLOSING) { + // stop processing frames. + log("handle_read_frame called in invalid state"); + return; + } + + // read more boost::asio::async_read( m_socket, m_buf, - boost::asio::transfer_at_least(1), + boost::asio::transfer_at_least(m_read_frame.get_bytes_needed()), boost::bind( &session::handle_read_frame, shared_from_this(), @@ -313,16 +343,8 @@ void session::read_payload() { ); } -void session::handle_read_payload (const boost::system::error_code& error) { - if (error) { - handle_error("Error reading payload data frame header",error); - // TODO: close behavior - return; - } - - m_read_frame.process_payload(); - - if (m_status == OPEN) { +void session::process_frame () { + if (m_state == STATE_OPEN) { switch (m_read_frame.get_opcode()) { case frame::CONTINUATION_FRAME: process_continuation(); @@ -347,7 +369,7 @@ void session::handle_read_payload (const boost::system::error_code& error) { // TODO: close behavior break; } - } else if (m_status == CLOSING) { + } else if (m_state == STATE_CLOSING) { if (m_read_frame.get_opcode() == frame::CONNECTION_CLOSE) { process_close(); } else { @@ -377,8 +399,8 @@ void session::handle_read_payload (const boost::system::error_code& error) { // TODO: close behavior return; } - - this->read_frame(); + + m_read_frame.reset(); } void session::handle_write_frame (const boost::system::error_code& error) { @@ -408,22 +430,17 @@ void session::process_pong() { } void session::process_text() { - if (!m_read_frame.validate_utf8(&m_utf8_state,&m_utf8_codepoint)) { - disconnect(CLOSE_STATUS_INVALID_PAYLOAD,"Invalid UTF8 Data"); - // TODO: close behavior - return; - } - + // this will throw an exception if validation fails at any point + m_read_frame.validate_utf8(&m_utf8_state,&m_utf8_codepoint); + + // otherwise, treat as binary process_binary(); } void session::process_binary() { if (m_fragmented) { - handle_error("Got a new message before the previous was finished.", - boost::system::error_code()); - disconnect(CLOSE_STATUS_PROTOCOL_ERROR,""); - // TODO: close behavior - return; + throw frame_error("Got a new message before the previous was finished.", + frame::FERR_PROTOCOL_VIOLATION); } m_current_opcode = m_read_frame.get_opcode(); @@ -439,19 +456,13 @@ void session::process_binary() { void session::process_continuation() { if (!m_fragmented) { - handle_error("Got a continuation frame without an outstanding message.", - boost::system::error_code()); - disconnect(CLOSE_STATUS_PROTOCOL_ERROR,""); - // TODO: close behavior - return; + throw frame_error("Got a continuation frame without an outstanding message.", + frame::FERR_PROTOCOL_VIOLATION); } if (m_current_opcode == frame::TEXT_FRAME) { - if (!m_read_frame.validate_utf8(&m_utf8_state,&m_utf8_codepoint)) { - disconnect(CLOSE_STATUS_INVALID_PAYLOAD,"Invalid UTF8 Data"); - // TODO: close behavior - return; - } + // this will throw an exception if validation fails at any point + m_read_frame.validate_utf8(&m_utf8_state,&m_utf8_codepoint); } extract_payload(); @@ -470,17 +481,16 @@ void session::process_close() { m_remote_close_code = status; m_remote_close_msg = message; - - if (m_status == OPEN) { + if (m_state == STATE_OPEN) { // This is the case where the remote initiated the close. m_closed_by_me = false; // TODO: close behavior disconnect(status,message); - } else if (m_status == CLOSING) { + } else if (m_state == STATE_CLOSING) { // this is an ack of our close message m_closed_by_me = true; } else { - throw "fixme"; + throw frame_error("process_closed called from wrong state"); } m_was_clean = true; @@ -503,10 +513,11 @@ void session::deliver_message() { std::string msg; // make sure the finished frame is valid utf8 + // the streaming validator checks for bad codepoints as it goes. It + // doesn't know where the end of the message is though, so we need to + // check here to make sure the final message ends on a valid codepoint. if (m_utf8_state != utf8_validator::UTF8_ACCEPT) { - disconnect(CLOSE_STATUS_INVALID_PAYLOAD,"Invalid UTF8 Data"); - // TODO: close behavior - return; + throw frame_error("Invalid UTF-8 Data",FERR_PAYLOAD_VIOLATION); } if (m_fragmented) { @@ -523,7 +534,7 @@ void session::deliver_message() { // Not sure if this should be a fatal error or not std::stringstream err; err << "Attempted to deliver a message of unsupported opcode " << m_current_opcode; - log(err.str(),LOG_ERROR); + throw frame_error(err.str(),frame::FERR_SOFT_SESSION_ERROR); } } @@ -614,3 +625,16 @@ void session::handle_error(std::string msg, m_error = true; } + +// validates status codes that the end application is allowed to use +bool session::validate_app_close_status(uint16_t status) { + if (status == CLOSE_STATUS_NORMAL) { + return true; + } + + if (status >= 4000 && status < 5000) { + return true; + } + + return false; +} diff --git a/src/websocket_session.hpp b/src/websocket_session.hpp index c9d270a265..6c1ad3ced3 100644 --- a/src/websocket_session.hpp +++ b/src/websocket_session.hpp @@ -73,14 +73,10 @@ class session : public boost::enable_shared_from_this { public: friend class handshake_error; - enum ws_status { - CONNECTING, - OPEN, - CLOSING, - CLOSED - }; - - typedef enum ws_status status_code; + static const uint8_t STATE_CONNECTING = 0; + static const uint8_t STATE_OPEN = 1; + static const uint8_t STATE_CLOSING = 2; + static const uint8_t STATE_CLOSED = 3; static const uint16_t CLOSE_STATUS_NORMAL = 1000; static const uint16_t CLOSE_STATUS_GOING_AWAY = 1001; @@ -160,13 +156,14 @@ protected: void handle_frame_header(const boost::system::error_code& error); void handle_extended_frame_header(const boost::system::error_code& error); void read_payload(); - void handle_read_payload (const boost::system::error_code& error); + void handle_read_frame (const boost::system::error_code& error); // write m_write_frame out to the socket. void write_frame(); void handle_write_frame (const boost::system::error_code& error); // helper functions for processing each opcode + void process_frame(); void process_ping(); void process_pong(); void process_text(); @@ -194,6 +191,9 @@ protected: // prints a diagnostic message and disconnects the local interface void handle_error(std::string msg,const boost::system::error_code& error); + + // misc helpers + bool validate_app_close_status(uint16_t status); private: std::string get_header(const std::string& key, const header_list& list) const; @@ -220,7 +220,7 @@ protected: std::string m_server_http_string; // Mutable connection state; - status_code m_status; + uint8_t m_state; uint16_t m_close_code; std::string m_close_message;