diff --git a/src/websocket_frame.cpp b/src/websocket_frame.cpp index 6be794f501..7fdd35cdcf 100644 --- a/src/websocket_frame.cpp +++ b/src/websocket_frame.cpp @@ -41,6 +41,46 @@ using websocketpp::frame; + +uint8_t frame::get_state() const { + return m_state; +} + +void frame::reset() { + m_state = STATE_BASIC_HEADER; + m_bytes_needed = BASIC_HEADER_LENGTH; +} + +void frame::consume(std::istream &s) { + if (m_state == STATE_BASIC_HEADER) { + s.read(&m_header[BASIC_HEADER_LENGTH-m_bytes_needed],m_bytes_needed); + + m_bytes_needed -= s.gcount(); + + if (m_bytes_needed == 0) { + process_basic_header(); + + // basic header validation + validate_basic_header(); + + if (m_bytes_needed > 0) { + m_state = STATE_EXTENDED_HEADER; + } else { + m_state = STATE_PAYLOAD; + } + } + } else if (m_state == STATE_EXTENDED_HEADER) { + s.read(&m_header[BASIC_HEADER_LENGTH+get_header_len()-m_bytes_needed],m_bytes_needed); + + m_bytes_needed -= s.gcount(); + + if (m_bytes_needed == 0) { + process_extended_header(); + m_state = STATE_PAYLOAD; + } + } +} + char* frame::get_header() { return m_header; } @@ -283,13 +323,9 @@ std::string frame::print_frame() const { return f.str(); } -unsigned int frame::process_basic_header() { - m_extended_header_bytes_needed = 0; +void frame::process_basic_header() { m_payload.empty(); - - m_extended_header_bytes_needed = get_header_len() - BASIC_HEADER_LENGTH; - - return m_extended_header_bytes_needed; + m_bytes_needed = get_header_len() - BASIC_HEADER_LENGTH; } void frame::process_extended_header() { @@ -308,22 +344,26 @@ void frame::process_extended_header() { reinterpret_cast(&m_header[BASIC_HEADER_LENGTH]) )); + if (payload_size < s) { + throw frame_error("payload length not minimally encoded"); + } + mask_index += 2; } else if (s == BASIC_PAYLOAD_64BIT_CODE) { - // reinterpret the second eight bytes as a 16 bit integer in + // reinterpret the second eight bytes as a 64 bit integer in // network byte order. Convert to host byte order and store. payload_size = ntohll(*( reinterpret_cast(&m_header[BASIC_HEADER_LENGTH]) )); + if (payload_size <= PAYLOAD_16BIT_LIMIT) { + throw frame_error("payload length not minimally encoded"); + } + mask_index += 8; } else { // shouldn't be here - throw server_error("invalid get_basic_size in process_extended_header"); - } - - if (payload_size < s) { - throw server_error("payload size error"); + throw frame_error("invalid get_basic_size in process_extended_header"); } if (get_masked() == 0) { @@ -394,32 +434,30 @@ bool frame::validate_utf8(uint32_t* state,uint32_t* codep) const { return true; } -bool frame::validate_basic_header() const { +void frame::validate_basic_header() const { // check for control frame size if (get_basic_size() > BASIC_PAYLOAD_LIMIT && is_control()) { - return false; + throw frame_error("Control Frame is too large"); } // check for reserved opcodes if (get_rsv1() || get_rsv2() || get_rsv3()) { - return false; + throw frame_error("Reserved bit used"); } // check for reserved opcodes opcode op = get_opcode(); if (op > 0x02 && op < 0x08) { - return false; + throw frame_error("Reserved opcode used"); } if (op > 0x0A) { - return false; + throw frame_error("Reserved opcode used"); } // check for fragmented control message if (is_control() && !get_fin()) { - return false; + throw frame_error("Fragmented control message"); } - - return true; } void frame::generate_masking_key() { diff --git a/src/websocket_frame.hpp b/src/websocket_frame.hpp index 33436306a5..0d4e4c6802 100644 --- a/src/websocket_frame.hpp +++ b/src/websocket_frame.hpp @@ -54,6 +54,11 @@ public: static const uint8_t MAX_FRAME_OPCODE = 0x07; + static const uint8_t STATE_BASIC_HEADER = 1; + static const uint8_t STATE_EXTENDED_HEADER = 2; + static const uint8_t STATE_PAYLOAD = 3; + static const uint8_t STATE_READY = 4; + // basic payload byte flags static const uint8_t BPB0_OPCODE = 0x0F; static const uint8_t BPB0_RSV3 = 0x10; @@ -81,6 +86,11 @@ public: memset(m_header,0,MAX_HEADER_LENGTH); } + uint8_t get_state() const; + void reset(); + + void consume(std::istream &s); + // get pointers to underlying buffers char* get_header(); char* get_extended_header(); @@ -126,18 +136,21 @@ public: std::string print_frame() const; // reads basic header, sets and returns m_header_bits_needed - unsigned int process_basic_header(); + void process_basic_header(); void process_extended_header(); void process_payload(); void process_payload2(); // experiment with more efficient masking code. bool validate_utf8(uint32_t* state,uint32_t* codep) const; - bool validate_basic_header() const; + void validate_basic_header() const; void generate_masking_key(); void clear_masking_key(); private: + uint8_t m_state; + uint64_t m_bytes_needed; + char m_header[MAX_HEADER_LENGTH]; std::vector m_payload; @@ -150,6 +163,19 @@ private: m_gen; }; +// Exception classes +class frame_error : public std::exception { +public: + frame_error(const std::string& msg) : m_msg(msg) {} + ~frame_error() throw() {} + + virtual const char* what() const throw() { + return m_msg.c_str(); + } + + std::string m_msg; +}; + } #endif // WEBSOCKET_FRAME_HPP diff --git a/src/websocket_session.cpp b/src/websocket_session.cpp index daffeae8a9..98131b2fad 100644 --- a/src/websocket_session.cpp +++ b/src/websocket_session.cpp @@ -202,6 +202,42 @@ void session::pong(const std::string &msg) { write_frame(); } +void session::handle_read_frame(const boost::system::error_code& error) { + // while + // if there are enough bytes to do something: + // do something + // read more + + std::istream s(&m_buf); + + try { + while (m_buf.size() > 0) { + m_read_frame.consume(s); + } + } catch (const frame_error& e) { + std::stringstream err; + err << "Caught frame exception: " << e.what(); + + access_log(e.what(),ALOG_FRAME); + log(err.str(),LOG_ERROR); + + // TODO: close behavior + return; + } + // we have read everything, check if we should read more + + boost::asio::async_read( + m_socket, + m_buf, + boost::asio::transfer_at_least(1), + boost::bind( + &session::handle_read_frame, + shared_from_this(), + boost::asio::placeholders::error + ) + ); +} + void session::read_frame() { boost::asio::async_read( m_socket,