From 86da9f503c0ff7179cee853a7ec2d1a5a4b30238 Mon Sep 17 00:00:00 2001 From: Peter Thorson Date: Fri, 28 Oct 2011 17:09:36 -0500 Subject: [PATCH] started policy-refactor branch --- src/policy/rng/blank_rng.cpp | 38 ++ src/policy/rng/blank_rng.hpp | 43 ++ src/policy/rng/boost_rng.cpp | 39 ++ src/policy/rng/boost_rng.hpp | 52 ++ src/websocket_constants.hpp | 156 ++++++ src/websocket_frame.cpp | 559 -------------------- src/websocket_frame.hpp | 614 ++++++++++++++++++--- src/websocket_server.cpp | 174 ------ src/websocket_server.hpp | 193 ++++++- src/websocket_session.cpp | 712 ------------------------- src/websocket_session.hpp | 734 +++++++++++++++++++++++--- src/websocketpp.hpp | 65 +-- websocketpp.xcodeproj/project.pbxproj | 30 ++ 13 files changed, 1726 insertions(+), 1683 deletions(-) create mode 100644 src/policy/rng/blank_rng.cpp create mode 100644 src/policy/rng/blank_rng.hpp create mode 100644 src/policy/rng/boost_rng.cpp create mode 100644 src/policy/rng/boost_rng.hpp create mode 100644 src/websocket_constants.hpp diff --git a/src/policy/rng/blank_rng.cpp b/src/policy/rng/blank_rng.cpp new file mode 100644 index 0000000000..dcfcb81971 --- /dev/null +++ b/src/policy/rng/blank_rng.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2011, Peter Thorson. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the WebSocket++ Project nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL PETER THORSON BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * This Makefile was derived from a similar one included in the libjson project + * It's authors were Jonathan Wallace and Bernhard Fluehmann. + */ + +#include "blank_rng.hpp" + +using websocketpp::blank_rng; + +blank_rng::blank_rng() {} + +int32_t blank_rng::gen() { + throw "Random Number generation not supported"; +} \ No newline at end of file diff --git a/src/policy/rng/blank_rng.hpp b/src/policy/rng/blank_rng.hpp new file mode 100644 index 0000000000..5d06283036 --- /dev/null +++ b/src/policy/rng/blank_rng.hpp @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2011, Peter Thorson. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the WebSocket++ Project nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL PETER THORSON BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * This Makefile was derived from a similar one included in the libjson project + * It's authors were Jonathan Wallace and Bernhard Fluehmann. + */ + +#ifndef BLANK_RNG_HPP +#define BLANK_RNG_HPP + +#include + +namespace websocketpp { + class blank_rng { + public: + blank_rng(); + int32_t gen(); + } +} + +#endif // BLANK_RNG_HPP diff --git a/src/policy/rng/boost_rng.cpp b/src/policy/rng/boost_rng.cpp new file mode 100644 index 0000000000..a20212dca5 --- /dev/null +++ b/src/policy/rng/boost_rng.cpp @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2011, Peter Thorson. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the WebSocket++ Project nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL PETER THORSON BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * This Makefile was derived from a similar one included in the libjson project + * It's authors were Jonathan Wallace and Bernhard Fluehmann. + */ + +#include "boost_rng.hpp" + +using websocketpp::boost_rng; + +boost_rng::boost_rng() : m_gen(m_rng, + boost::random::uniform_int_distribution<>(INT32_MIN,INT32_MAX)); {} + +int32_t boost_rng::gen() { + return m_gen(); +} \ No newline at end of file diff --git a/src/policy/rng/boost_rng.hpp b/src/policy/rng/boost_rng.hpp new file mode 100644 index 0000000000..87d368a48d --- /dev/null +++ b/src/policy/rng/boost_rng.hpp @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2011, Peter Thorson. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the WebSocket++ Project nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL PETER THORSON BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * This Makefile was derived from a similar one included in the libjson project + * It's authors were Jonathan Wallace and Bernhard Fluehmann. + */ + +#ifndef BOOST_RNG_HPP +#define BOOST_RNG_HPP + +#include +#include +#include + +namespace websocketpp { + class boost_rng { + public: + boost_rng(); + int32_t gen(); + private: + boost::random::random_device m_rng; + boost::random::variate_generator + < + boost::random::random_device&, + boost::random::uniform_int_distribution<> + > m_gen; + } +} + +#endif // BOOST_RNG_HPP diff --git a/src/websocket_constants.hpp b/src/websocket_constants.hpp new file mode 100644 index 0000000000..84f57717b1 --- /dev/null +++ b/src/websocket_constants.hpp @@ -0,0 +1,156 @@ +/* + * Copyright (c) 2011, Peter Thorson. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the WebSocket++ Project nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL PETER THORSON BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * This Makefile was derived from a similar one included in the libjson project + * It's authors were Jonathan Wallace and Bernhard Fluehmann. + */ + +#ifndef WEBSOCKET_CONSTANTS_HPP +#define WEBSOCKET_CONSTANTS_HPP + +#include + +// Defaults +namespace websocketpp { + const uint64_t DEFAULT_MAX_MESSAGE_SIZE = 0xFFFFFF; // ~16MB + + // System logging levels + static const uint16_t LOG_ALL = 0; + static const uint16_t LOG_DEBUG = 1; + static const uint16_t LOG_INFO = 2; + static const uint16_t LOG_WARN = 3; + static const uint16_t LOG_ERROR = 4; + static const uint16_t LOG_FATAL = 5; + static const uint16_t LOG_OFF = 6; + + // Access logging controls + // Individual bits + static const uint16_t ALOG_CONNECT = 0x1; + static const uint16_t ALOG_DISCONNECT = 0x2; + static const uint16_t ALOG_MISC_CONTROL = 0x4; + static const uint16_t ALOG_FRAME = 0x8; + static const uint16_t ALOG_MESSAGE = 0x10; + static const uint16_t ALOG_INFO = 0x20; + static const uint16_t ALOG_HANDSHAKE = 0x40; + // Useful groups + static const uint16_t ALOG_OFF = 0x0; + static const uint16_t ALOG_CONTROL = ALOG_CONNECT + & ALOG_DISCONNECT + & ALOG_MISC_CONTROL; + static const uint16_t ALOG_ALL = 0xFFFF; + + + namespace close { + namespace status { + enum value { + INVALID_END = 999, + NORMAL = 1000, + GOING_AWAY = 1001, + PROTOCOL_ERROR = 1002, + UNSUPPORTED_DATA = 1003, + RSV_ADHOC_1 = 1004, + NO_STATUS = 1005, + ABNORMAL_CLOSE = 1006, + INVALID_PAYLOAD = 1007, + POLICY_VIOLATION = 1008, + MESSAGE_TOO_BIG = 1009, + EXTENSION_REQUIRE = 1010, + RSV_START = 1011, + RSV_END = 2999, + INVALID_START = 5000 + }; + + inline bool reserved(value s) { + return ((s >= RSV_START && s <= RSV_END) || + s == RSV_ADHOC_1); + } + + inline bool invalid(value s) { + return ((s <= INVALID_END || s >= INVALID_START) || + s == NO_STATUS || + s == ABNORMAL_CLOSE); + } + + // TODO functions for application ranges? + } + } + + namespace frame { + namespace error { + enum value { + FATAL_SESSION_ERROR = 0, // force session end + SOFT_SESSION_ERROR = 1, // should log and ignore + PROTOCOL_VIOLATION = 2, // must end session + PAYLOAD_VIOLATION = 3, // should end session + INTERNAL_SERVER_ERROR = 4, // cleanly end session + MSG_TOO_BIG = 5 // ??? + }; + } + + // Opcodes are 4 bits + // See spec section 5.2 + namespace opcode { + enum value { + CONTINUATION = 0x0, + TEXT = 0x1, + BINARY = 0x2, + RSV3 = 0x3, + RSV4 = 0x4, + RSV5 = 0x5, + RSV6 = 0x6, + RSV7 = 0x7, + CLOSE = 0x8, + PING = 0x9, + PONG = 0xA, + CONTROL_RSVB = 0xB, + CONTROL_RSVC = 0xC, + CONTROL_RSVD = 0xD, + CONTROL_RSVE = 0xE, + CONTROL_RSVF = 0xF, + }; + + inline bool reserved(value v) { + return (v >= RSV3 && v <= RSV7) || + (v >= CONTROL_RSVB && v <= CONTROL_RSVF); + } + + inline bool invalid(value v) { + return (v > 0xF || v < 0); + } + + inline bool is_control(value v) { + return v >= 0x8; + } + } + + namespace limits { + static const uint8_t PAYLOAD_SIZE_BASIC = 125; + static const uint16_t PAYLOAD_SIZE_EXTENDED = 0xFFFF; // 2^16, 65535 + static const uint64_t PAYLOAD_SIZE_JUMBO = 0x7FFFFFFFFFFFFFFF;//2^63 + } + } +} + +#endif // WEBSOCKET_CONSTANTS_HPP diff --git a/src/websocket_frame.cpp b/src/websocket_frame.cpp index c630f05007..d1001affb6 100644 --- a/src/websocket_frame.cpp +++ b/src/websocket_frame.cpp @@ -25,562 +25,3 @@ * */ -#include "websocket_frame.hpp" - -#include "websocket_server.hpp" -#include "utf8_validator/utf8_validator.hpp" - -#include -#include - -#if defined(WIN32) -#include -#else -#include -#endif - -using websocketpp::frame; - - -uint8_t frame::get_state() const { - return m_state; -} - -uint64_t frame::get_bytes_needed() const { - return m_bytes_needed; -} - -void frame::reset() { - m_state = STATE_BASIC_HEADER; - m_bytes_needed = BASIC_HEADER_LENGTH; - m_degraded = false; - m_payload.empty(); - memset(m_header,0,MAX_HEADER_LENGTH); -} - -// Method invariant: One of the following must always be true even in the case -// of exceptions. -// - m_bytes_needed > 0 -// - m-state = STATE_READY -void frame::consume(std::istream &s) { - try { - switch (m_state) { - case STATE_BASIC_HEADER: - s.read(&m_header[BASIC_HEADER_LENGTH-m_bytes_needed],m_bytes_needed); - - m_bytes_needed -= s.gcount(); - - if (m_bytes_needed == 0) { - process_basic_header(); - - validate_basic_header(); - - if (m_bytes_needed > 0) { - m_state = STATE_EXTENDED_HEADER; - } else { - process_extended_header(); - - if (m_bytes_needed == 0) { - m_state = STATE_READY; - process_payload(); - - } else { - m_state = STATE_PAYLOAD; - } - } - } - break; - case STATE_EXTENDED_HEADER: - s.read(&m_header[get_header_len()-m_bytes_needed],m_bytes_needed); - - m_bytes_needed -= s.gcount(); - - if (m_bytes_needed == 0) { - process_extended_header(); - if (m_bytes_needed == 0) { - m_state = STATE_READY; - process_payload(); - } else { - m_state = STATE_PAYLOAD; - } - } - break; - case STATE_PAYLOAD: - s.read(reinterpret_cast(&m_payload[m_payload.size()-m_bytes_needed]), - m_bytes_needed); - - m_bytes_needed -= s.gcount(); - - if (m_bytes_needed == 0) { - m_state = STATE_READY; - process_payload(); - } - break; - case STATE_RECOVERY: - // Recovery state discards all bytes that are not the first byte - // of a close frame. - do { - s.read(reinterpret_cast(&m_header[0]),1); - - //std::cout << std::hex << int(static_cast(m_header[0])) << " "; - - if (int(static_cast(m_header[0])) == 0x88) { - //(BPB0_FIN && CONNECTION_CLOSE) - m_bytes_needed--; - m_state = STATE_BASIC_HEADER; - break; - } - } while (s.gcount() > 0); - - //std::cout << std::endl; - - break; - default: - break; - } - - /*if (s.gcount() == 0) { - throw frame_error("consume read zero bytes",FERR_FATAL_SESSION_ERROR); - }*/ - } catch (const frame_error& e) { - // After this point all non-close frames must be considered garbage, - // including the current one. Reset it and put the reading frame into - // a recovery state. - if (m_degraded == true) { - throw frame_error("An error occurred while trying to gracefully recover from a less serious frame error.",FERR_FATAL_SESSION_ERROR); - } else { - reset(); - m_state = STATE_RECOVERY; - m_degraded = true; - - throw e; - } - } -} - -char* frame::get_header() { - return m_header; -} - -char* frame::get_extended_header() { - return m_header+BASIC_HEADER_LENGTH; -} - -unsigned int frame::get_header_len() const { - unsigned int temp = 2; - - if (get_masked()) { - temp += 4; - } - - if (get_basic_size() == 126) { - temp += 2; - } else if (get_basic_size() == 127) { - temp += 8; - } - - return temp; -} - -char* frame::get_masking_key() { - if (m_state != STATE_READY) { - throw frame_error("attempted to get masking_key before reading full header"); - } - return m_masking_key; -} - -// get and set header bits -bool frame::get_fin() const { - return ((m_header[0] & BPB0_FIN) == BPB0_FIN); -} - -void frame::set_fin(bool fin) { - if (fin) { - m_header[0] |= BPB0_FIN; - } else { - m_header[0] &= (0xFF ^ BPB0_FIN); - } -} - -// get and set reserved bits -bool frame::get_rsv1() const { - return ((m_header[0] & BPB0_RSV1) == BPB0_RSV1); -} - -void frame::set_rsv1(bool b) { - if (b) { - m_header[0] |= BPB0_RSV1; - } else { - m_header[0] &= (0xFF ^ BPB0_RSV1); - } -} - -bool frame::get_rsv2() const { - return ((m_header[0] & BPB0_RSV2) == BPB0_RSV2); -} - -void frame::set_rsv2(bool b) { - if (b) { - m_header[0] |= BPB0_RSV2; - } else { - m_header[0] &= (0xFF ^ BPB0_RSV2); - } -} - -bool frame::get_rsv3() const { - return ((m_header[0] & BPB0_RSV3) == BPB0_RSV3); -} - -void frame::set_rsv3(bool b) { - if (b) { - m_header[0] |= BPB0_RSV3; - } else { - m_header[0] &= (0xFF ^ BPB0_RSV3); - } -} - -frame::opcode frame::get_opcode() const { - return frame::opcode(m_header[0] & BPB0_OPCODE); -} - -void frame::set_opcode(frame::opcode op) { - if (op > 0x0F) { - throw frame_error("invalid opcode",FERR_PROTOCOL_VIOLATION); - } - - if (get_basic_size() > BASIC_PAYLOAD_LIMIT && - is_control()) { - throw frame_error("control frames can't have large payloads",FERR_PROTOCOL_VIOLATION); - } - - m_header[0] &= (0xFF ^ BPB0_OPCODE); // clear op bits - m_header[0] |= op; // set op bits -} - -bool frame::get_masked() const { - return ((m_header[1] & BPB1_MASK) == BPB1_MASK); -} - -void frame::set_masked(bool masked) { - if (masked) { - m_header[1] |= BPB1_MASK; - generate_masking_key(); - } else { - m_header[1] &= (0xFF ^ BPB1_MASK); - clear_masking_key(); - } -} - -uint8_t frame::get_basic_size() const { - return m_header[1] & BPB1_PAYLOAD; -} - -size_t frame::get_payload_size() const { - if (m_state != STATE_READY && m_state != STATE_PAYLOAD) { - // problem - throw frame_error("attempted to get payload size before reading full header"); - } - - return m_payload.size(); -} - -uint16_t frame::get_close_status() const { - if (get_payload_size() == 0) { - return close::status::NO_STATUS; - } else if (get_payload_size() >= 2) { - char val[2]; - - val[0] = m_payload[0]; - val[1] = m_payload[1]; - - uint16_t code = ntohs(*( - reinterpret_cast(&val[0]) - )); - - return code; - } else { - return close::status::PROTOCOL_ERROR; - } -} - -std::string frame::get_close_msg() const { - if (get_payload_size() > 2) { - uint32_t state = utf8_validator::UTF8_ACCEPT; - uint32_t codep = 0; - validate_utf8(&state,&codep,2); - if (state != utf8_validator::UTF8_ACCEPT) { - throw frame_error("Invalid UTF-8 Data", - frame::FERR_PAYLOAD_VIOLATION); - } - return std::string(m_payload.begin()+2,m_payload.end()); - } else { - return std::string(); - } -} - -std::vector &frame::get_payload() { - return m_payload; -} - -void frame::set_payload(const std::vector source) { - set_payload_helper(source.size()); - - std::copy(source.begin(),source.end(),m_payload.begin()); -} -void frame::set_payload(const std::string source) { - set_payload_helper(source.size()); - - std::copy(source.begin(),source.end(),m_payload.begin()); -} - -bool frame::is_control() const { - return (get_opcode() > MAX_FRAME_OPCODE); -} - -void frame::set_payload_helper(size_t s) { - if (s > max_payload_size) { - throw frame_error("requested payload is over implimentation defined limit",FERR_MSG_TOO_BIG); - } - - // limits imposed by the websocket spec - if (s > BASIC_PAYLOAD_LIMIT && - get_opcode() > MAX_FRAME_OPCODE) { - throw frame_error("control frames can't have large payloads",FERR_PROTOCOL_VIOLATION); - } - - if (s <= BASIC_PAYLOAD_LIMIT) { - m_header[1] = s; - } else if (s <= PAYLOAD_16BIT_LIMIT) { - m_header[1] = BASIC_PAYLOAD_16BIT_CODE; - - // this reinterprets the second pair of bytes in m_header as a - // 16 bit int and writes the payload size there as an integer - // in network byte order - *reinterpret_cast(&m_header[BASIC_HEADER_LENGTH]) = htons(s); - } else if (s <= PAYLOAD_64BIT_LIMIT) { - m_header[1] = BASIC_PAYLOAD_64BIT_CODE; - *reinterpret_cast(&m_header[BASIC_HEADER_LENGTH]) = htonll(s); - } else { - throw frame_error("payload size limit is 63 bits",FERR_PROTOCOL_VIOLATION); - } - - m_payload.resize(s); -} - -void frame::set_status(uint16_t status,const std::string message) { - // check for valid statuses - if (close::status::invalid(status)) { - std::stringstream err; - err << "Status code " << status << " is invalid"; - throw frame_error(err.str()); - } - - if (close::status::reserved(status)) { - std::stringstream err; - err << "Status code " << status << " is reserved"; - throw frame_error(err.str()); - } - - m_payload.resize(2+message.size()); - - char val[2]; - - *reinterpret_cast(&val[0]) = htons(status); - - m_header[1] = message.size()+2; - - m_payload[0] = val[0]; - m_payload[1] = val[1]; - - std::copy(message.begin(),message.end(),m_payload.begin()+2); -} - -std::string frame::print_frame() const { - std::stringstream f; - - unsigned int len = get_header_len(); - - f << "frame: "; - // print header - for (unsigned int i = 0; i < len; i++) { - f << std::hex << (unsigned short)m_header[i] << " "; - } - // print message - if (m_payload.size() > 50) { - f << "[payload of " << m_payload.size() << " bytes]"; - } else { - std::vector::const_iterator it; - for (it = m_payload.begin(); it != m_payload.end(); it++) { - f << *it; - } - } - return f.str(); -} - -void frame::process_basic_header() { - m_bytes_needed = get_header_len() - BASIC_HEADER_LENGTH; -} - -void frame::process_extended_header() { - uint8_t s = get_basic_size(); - uint64_t payload_size; - int mask_index = BASIC_HEADER_LENGTH; - - if (s <= BASIC_PAYLOAD_LIMIT) { - payload_size = s; - } else if (s == BASIC_PAYLOAD_16BIT_CODE) { - // reinterpret the second two bytes as a 16 bit integer in network - // byte order. Convert to host byte order and store locally. - payload_size = ntohs(*( - reinterpret_cast(&m_header[BASIC_HEADER_LENGTH]) - )); - - if (payload_size < s) { - std::stringstream err; - err << "payload length not minimally encoded. Using 16 bit form for payload size: " << payload_size; - m_bytes_needed = payload_size; - throw frame_error(err.str(), - FERR_PROTOCOL_VIOLATION); - } - - mask_index += 2; - } else if (s == BASIC_PAYLOAD_64BIT_CODE) { - // reinterpret the second eight bytes as a 64 bit integer in - // network byte order. Convert to host byte order and store. - payload_size = ntohll(*( - reinterpret_cast(&m_header[BASIC_HEADER_LENGTH]) - )); - - if (payload_size <= PAYLOAD_16BIT_LIMIT) { - m_bytes_needed = payload_size; - throw frame_error("payload length not minimally encoded", - FERR_PROTOCOL_VIOLATION); - } - - mask_index += 8; - } else { - // shouldn't be here - throw frame_error("invalid get_basic_size in process_extended_header"); - } - - if (get_masked() == 0) { - clear_masking_key(); - } else { - // TODO: use this copy line (needs testing) - // std::copy(m_header[mask_index],m_header[mask_index+4],m_masking_key); - m_masking_key[0] = m_header[mask_index+0]; - m_masking_key[1] = m_header[mask_index+1]; - m_masking_key[2] = m_header[mask_index+2]; - m_masking_key[3] = m_header[mask_index+3]; - } - - if (payload_size > max_payload_size) { - // TODO: frame/message size limits - throw server_error("got frame with payload greater than maximum frame buffer size."); - } - m_payload.resize(payload_size); - m_bytes_needed = payload_size; -} - -void frame::process_payload() { - if (get_masked()) { - char *masking_key = &m_header[get_header_len()-4]; - - for (uint64_t i = 0; i < m_payload.size(); i++) { - m_payload[i] = (m_payload[i] ^ masking_key[i%4]); - } - } -} - -void frame::process_payload2() { - // unmask payload one byte at a time - - //uint64_t key = (*((uint32_t*)m_masking_key;)) << 32; - //key += *((uint32_t*)m_masking_key); - - // might need to switch byte order - uint32_t key = *((uint32_t*)m_masking_key); - - // 4 - - uint64_t i = 0; - uint64_t s = (m_payload.size() / 4); - - std::cout << "s: " << s << std::endl; - - // chunks of 4 - for (i = 0; i < s; i+=4) { - ((uint32_t*)(&m_payload[0]))[i] = (((uint32_t*)(&m_payload[0]))[i] ^ key); - } - - // finish the last few - for (i = s; i < m_payload.size(); i++) { - m_payload[i] = (m_payload[i] ^ m_masking_key[i%4]); - } -} - -void frame::validate_utf8(uint32_t* state,uint32_t* codep, size_t offset) const { - for (size_t i = offset; i < m_payload.size(); i++) { - using utf8_validator::decode; - - if (decode(state,codep,m_payload[i]) == utf8_validator::UTF8_REJECT) { - throw frame_error("Invalid UTF-8 Data",FERR_PAYLOAD_VIOLATION); - } - } -} - -void frame::validate_basic_header() const { - // check for control frame size - if (get_basic_size() > BASIC_PAYLOAD_LIMIT && is_control()) { - throw frame_error("Control Frame is too large",FERR_PROTOCOL_VIOLATION); - } - - // check for reserved opcodes - if (get_rsv1() || get_rsv2() || get_rsv3()) { - throw frame_error("Reserved bit used",FERR_PROTOCOL_VIOLATION); - } - - // check for reserved opcodes - opcode op = get_opcode(); - if (op > 0x02 && op < 0x08) { - throw frame_error("Reserved opcode used",FERR_PROTOCOL_VIOLATION); - } - if (op > 0x0A) { - throw frame_error("Reserved opcode used",FERR_PROTOCOL_VIOLATION); - } - - // check for fragmented control message - if (is_control() && !get_fin()) { - throw frame_error("Fragmented control message",FERR_PROTOCOL_VIOLATION); - } -} - -void frame::generate_masking_key() { - //throw "masking key generation not implimented"; - - int32_t key = m_gen(); - - //m_masking_key[0] = reinterpret_cast(&key)[0]; - //m_masking_key[1] = reinterpret_cast(&key)[1]; - //m_masking_key[2] = reinterpret_cast(&key)[2]; - //m_masking_key[3] = reinterpret_cast(&key)[3]; - - *(reinterpret_cast(&m_header[get_header_len()-4])) = key; - - //std::cout << "maskkey: " << m_masking_key << std::endl; - - /* TODO: test and tune - - boost::random::random_device rng; - boost::random::uniform_int_distribution<> mask(0,UINT32_MAX); - - *(reinterpret_cast(m_masking_key)) = mask(rng); - - */ -} - -void frame::clear_masking_key() { - // this is a no-op as clearing the mask bit also changes the get_header_len - // method to not include these byte ranges. Whenever the masking bit is re- - // set a new key is generated anyways. -} diff --git a/src/websocket_frame.hpp b/src/websocket_frame.hpp index 9c719f38b2..221f296eab 100644 --- a/src/websocket_frame.hpp +++ b/src/websocket_frame.hpp @@ -29,50 +29,50 @@ #define WEBSOCKET_FRAME_HPP #include "network_utilities.hpp" +#include "websocket_constants.hpp" + +//#include "websocket_server.hpp" +#include "utf8_validator/utf8_validator.hpp" + +#if defined(WIN32) +#include +#else +#include +#endif #include #include #include - -#include -#include +#include +#include namespace websocketpp { - +namespace frame { + /* policies to abstract out - random number generation - utf8 validation +rng + int32_t gen() + + +class boost_random { +public: + boost_random() + int32_t gen(); +private: + boost::random::random_device m_rng; + boost::random::variate_generator > m_gen; + } + */ -class frame { +template +class parser { public: - enum opcode_s { - CONTINUATION_FRAME = 0x00, - TEXT_FRAME = 0x01, - BINARY_FRAME = 0x02, - CONNECTION_CLOSE = 0x08, - PING = 0x09, - PONG = 0x0A - }; - typedef enum opcode_s opcode; - - static const uint8_t MAX_FRAME_OPCODE = 0x07; - - static const uint8_t STATE_BASIC_HEADER = 1; - static const uint8_t STATE_EXTENDED_HEADER = 2; - static const uint8_t STATE_PAYLOAD = 3; - static const uint8_t STATE_READY = 4; - static const uint8_t STATE_RECOVERY = 5; - - static const uint16_t FERR_FATAL_SESSION_ERROR = 0; // force session end - static const uint16_t FERR_SOFT_SESSION_ERROR = 1; // should log and ignore - static const uint16_t FERR_PROTOCOL_VIOLATION = 2; // must end session - static const uint16_t FERR_PAYLOAD_VIOLATION = 3; // should end session - static const uint16_t FERR_INTERNAL_SERVER_ERROR = 4; // cleanly end session - static const uint16_t FERR_MSG_TOO_BIG = 5; // basic payload byte flags static const uint8_t BPB0_OPCODE = 0x0F; @@ -83,11 +83,8 @@ public: static const uint8_t BPB1_PAYLOAD = 0x7F; static const uint8_t BPB1_MASK = 0x80; - static const uint8_t BASIC_PAYLOAD_LIMIT = 0x7D; // 125 static const uint8_t BASIC_PAYLOAD_16BIT_CODE = 0x7E; // 126 - static const uint16_t PAYLOAD_16BIT_LIMIT = 0xFFFF; // 2^16, 65535 static const uint8_t BASIC_PAYLOAD_64BIT_CODE = 0x7F; // 127 - static const uint64_t PAYLOAD_64BIT_LIMIT = 0x7FFFFFFFFFFFFFFF; // 2^63 static const unsigned int BASIC_HEADER_LENGTH = 2; static const unsigned int MAX_HEADER_LENGTH = 14; @@ -95,109 +92,552 @@ public: static const uint64_t max_payload_size = 100000000; // 100MB // create an empty frame for writing into - frame() : m_gen(m_rng, - boost::random::uniform_int_distribution<>(INT32_MIN,INT32_MAX)),m_degraded(false) { + parser(rng_policy& rng) : m_rng(rng) { reset(); } - uint8_t get_state() const; - uint64_t get_bytes_needed() const; - void reset(); + bool ready() const { + return (m_state == STATE_READY); + } + uint64_t get_bytes_needed() const { + return m_bytes_needed; + } + void reset() { + m_state = STATE_BASIC_HEADER; + m_bytes_needed = BASIC_HEADER_LENGTH; + m_degraded = false; + m_payload.empty(); + memset(m_header,0,MAX_HEADER_LENGTH); + } - void consume(std::istream &s); + // Method invariant: One of the following must always be true even in the case + // of exceptions. + // - m_bytes_needed > 0 + // - m-state = STATE_READY + void consume(std::istream &s) { + try { + switch (m_state) { + case STATE_BASIC_HEADER: + s.read(&m_header[BASIC_HEADER_LENGTH-m_bytes_needed],m_bytes_needed); + + m_bytes_needed -= s.gcount(); + + if (m_bytes_needed == 0) { + process_basic_header(); + + validate_basic_header(); + + if (m_bytes_needed > 0) { + m_state = STATE_EXTENDED_HEADER; + } else { + process_extended_header(); + + if (m_bytes_needed == 0) { + m_state = STATE_READY; + process_payload(); + + } else { + m_state = STATE_PAYLOAD; + } + } + } + break; + case STATE_EXTENDED_HEADER: + s.read(&m_header[get_header_len()-m_bytes_needed],m_bytes_needed); + + m_bytes_needed -= s.gcount(); + + if (m_bytes_needed == 0) { + process_extended_header(); + if (m_bytes_needed == 0) { + m_state = STATE_READY; + process_payload(); + } else { + m_state = STATE_PAYLOAD; + } + } + break; + case STATE_PAYLOAD: + s.read(reinterpret_cast(&m_payload[m_payload.size()-m_bytes_needed]), + m_bytes_needed); + + m_bytes_needed -= s.gcount(); + + if (m_bytes_needed == 0) { + m_state = STATE_READY; + process_payload(); + } + break; + case STATE_RECOVERY: + // Recovery state discards all bytes that are not the first byte + // of a close frame. + do { + s.read(reinterpret_cast(&m_header[0]),1); + + //std::cout << std::hex << int(static_cast(m_header[0])) << " "; + + if (int(static_cast(m_header[0])) == 0x88) { + //(BPB0_FIN && CONNECTION_CLOSE) + m_bytes_needed--; + m_state = STATE_BASIC_HEADER; + break; + } + } while (s.gcount() > 0); + + //std::cout << std::endl; + + break; + default: + break; + } + + /*if (s.gcount() == 0) { + throw frame_error("consume read zero bytes",FERR_FATAL_SESSION_ERROR); + }*/ + } catch (const exception& e) { + // After this point all non-close frames must be considered garbage, + // including the current one. Reset it and put the reading frame into + // a recovery state. + if (m_degraded == true) { + throw exception("An error occurred while trying to gracefully recover from a less serious frame error.",FERR_FATAL_SESSION_ERROR); + } else { + reset(); + m_state = STATE_RECOVERY; + m_degraded = true; + + throw e; + } + } + } // get pointers to underlying buffers - char* get_header(); - char* get_extended_header(); - unsigned int get_header_len() const; + char* get_header() { + return m_header; + } + char* get_extended_header() { + return m_header+BASIC_HEADER_LENGTH; + } + unsigned int get_header_len() const { + unsigned int temp = 2; + + if (get_masked()) { + temp += 4; + } + + if (get_basic_size() == 126) { + temp += 2; + } else if (get_basic_size() == 127) { + temp += 8; + } + + return temp; + } - char* get_masking_key(); + char* get_masking_key() { + if (m_state != STATE_READY) { + throw exception("attempted to get masking_key before reading full header"); + } + return m_header[get_header_len()-4]; + } // get and set header bits - bool get_fin() const; - void set_fin(bool fin); + bool get_fin() const { + return ((m_header[0] & BPB0_FIN) == BPB0_FIN); + } + void set_fin(bool fin) { + if (fin) { + m_header[0] |= BPB0_FIN; + } else { + m_header[0] &= (0xFF ^ BPB0_FIN); + } + } - bool get_rsv1() const; - void set_rsv1(bool b); + bool get_rsv1() const { + return ((m_header[0] & BPB0_RSV1) == BPB0_RSV1); + } + void set_rsv1(bool b) { + if (b) { + m_header[0] |= BPB0_RSV1; + } else { + m_header[0] &= (0xFF ^ BPB0_RSV1); + } + } - bool get_rsv2() const; - void set_rsv2(bool b); + bool get_rsv2() const { + return ((m_header[0] & BPB0_RSV2) == BPB0_RSV2); + } + void set_rsv2(bool b) { + if (b) { + m_header[0] |= BPB0_RSV2; + } else { + m_header[0] &= (0xFF ^ BPB0_RSV2); + } + } - bool get_rsv3() const; - void set_rsv3(bool b); + bool get_rsv3() const { + return ((m_header[0] & BPB0_RSV3) == BPB0_RSV3); + } + void set_rsv3(bool b) { + if (b) { + m_header[0] |= BPB0_RSV3; + } else { + m_header[0] &= (0xFF ^ BPB0_RSV3); + } + } - opcode get_opcode() const; - void set_opcode(opcode op); + opcode::value get_opcode() const { + return frame::opcode::value(m_header[0] & BPB0_OPCODE); + } + void set_opcode(opcode::value op) { + if (frame::opcode::reserved(op)) { + throw exception("reserved opcode",FERR_PROTOCOL_VIOLATION); + } + + if (frame::opcode::invalid(op)) { + throw exception("invalid opcode",FERR_PROTOCOL_VIOLATION); + } + + if (get_basic_size() > BASIC_PAYLOAD_LIMIT && is_control()) { + throw exception("control frames can't have large payloads",FERR_PROTOCOL_VIOLATION); + } + + m_header[0] &= (0xFF ^ BPB0_OPCODE); // clear op bits + m_header[0] |= op; // set op bits + } - bool get_masked() const; - void set_masked(bool masked); + bool get_masked() const { + return ((m_header[1] & BPB1_MASK) == BPB1_MASK); + } + void set_masked(bool masked) { + if (masked) { + m_header[1] |= BPB1_MASK; + generate_masking_key(); + } else { + m_header[1] &= (0xFF ^ BPB1_MASK); + clear_masking_key(); + } + } - uint8_t get_basic_size() const; - size_t get_payload_size() const; + uint8_t get_basic_size() const { + return m_header[1] & BPB1_PAYLOAD; + } + size_t get_payload_size() const { + if (m_state != STATE_READY && m_state != STATE_PAYLOAD) { + // problem + throw exception("attempted to get payload size before reading full header"); + } + + return m_payload.size(); + } - uint16_t get_close_status() const; - std::string get_close_msg() const; + close::status::value get_close_status() const { + if (get_payload_size() == 0) { + return close::status::NO_STATUS; + } else if (get_payload_size() >= 2) { + char val[2]; + + val[0] = m_payload[0]; + val[1] = m_payload[1]; + + uint16_t code = ntohs(*(reinterpret_cast(&val[0]))); + + return close::status::value(code); + } else { + return close::status::PROTOCOL_ERROR; + } + } + std::string get_close_msg() const { + if (get_payload_size() > 2) { + uint32_t state = utf8_validator::UTF8_ACCEPT; + uint32_t codep = 0; + validate_utf8(&state,&codep,2); + if (state != utf8_validator::UTF8_ACCEPT) { + throw exception("Invalid UTF-8 Data", + frame::FERR_PAYLOAD_VIOLATION); + } + return std::string(m_payload.begin()+2,m_payload.end()); + } else { + return std::string(); + } + } - std::vector &get_payload(); + std::vector &get_payload() { + return m_payload; + } - void set_payload(const std::vector source); - void set_payload(const std::string source); - void set_payload_helper(size_t s); + void set_payload(const std::vector source) { + set_payload_helper(source.size()); + + std::copy(source.begin(),source.end(),m_payload.begin()); + } + void set_payload(const std::string source) { + set_payload_helper(source.size()); + + std::copy(source.begin(),source.end(),m_payload.begin()); + } + void set_payload_helper(size_t s) { + if (s > max_payload_size) { + throw exception("requested payload is over implimentation defined limit",FERR_MSG_TOO_BIG); + } + + // limits imposed by the websocket spec + if (s > limits::BASIC_PAYLOAD_SIZE && + is_control()) { + throw exception("control frames can't have large payloads",FERR_PROTOCOL_VIOLATION); + } + + if (s <= limits::PAYLOAD_SIZE_BASIC) { + m_header[1] = s; + } else if (s <= limits::PAYLOAD_SIZE_EXTENDED) { + m_header[1] = BASIC_PAYLOAD_16BIT_CODE; + + // this reinterprets the second pair of bytes in m_header as a + // 16 bit int and writes the payload size there as an integer + // in network byte order + *reinterpret_cast(&m_header[BASIC_HEADER_LENGTH]) = htons(s); + } else if (s <= limits::PAYLOAD_SIZE_JUMBO) { + m_header[1] = BASIC_PAYLOAD_64BIT_CODE; + *reinterpret_cast(&m_header[BASIC_HEADER_LENGTH]) = htonll(s); + } else { + throw exception("payload size limit is 63 bits",FERR_PROTOCOL_VIOLATION); + } + + m_payload.resize(s); + } - void set_status(uint16_t status,const std::string message = ""); + void set_status(uint16_t status,const std::string message = "") { + // check for valid statuses + if (close::status::invalid(status)) { + std::stringstream err; + err << "Status code " << status << " is invalid"; + throw exception(err.str()); + } + + if (close::status::reserved(status)) { + std::stringstream err; + err << "Status code " << status << " is reserved"; + throw exception(err.str()); + } + + m_payload.resize(2+message.size()); + + char val[2]; + + *reinterpret_cast(&val[0]) = htons(status); + + m_header[1] = message.size()+2; + + m_payload[0] = val[0]; + m_payload[1] = val[1]; + + std::copy(message.begin(),message.end(),m_payload.begin()+2); + } - bool is_control() const; + bool is_control() const { + return (frame::opcode::is_control(get_opcode())); + } - std::string print_frame() const; + std::string print_frame() const { + std::stringstream f; + + unsigned int len = get_header_len(); + + f << "frame: "; + // print header + for (unsigned int i = 0; i < len; i++) { + f << std::hex << (unsigned short)m_header[i] << " "; + } + // print message + if (m_payload.size() > 50) { + f << "[payload of " << m_payload.size() << " bytes]"; + } else { + std::vector::const_iterator it; + for (it = m_payload.begin(); it != m_payload.end(); it++) { + f << *it; + } + } + return f.str(); + } // reads basic header, sets and returns m_header_bits_needed - void process_basic_header(); - void process_extended_header(); - void process_payload(); - void process_payload2(); // experiment with more efficient masking code. + void process_basic_header() { + m_bytes_needed = get_header_len() - BASIC_HEADER_LENGTH; + } + void process_extended_header() { + uint8_t s = get_basic_size(); + uint64_t payload_size; + int mask_index = BASIC_HEADER_LENGTH; + + if (s <= BASIC_PAYLOAD_LIMIT) { + payload_size = s; + } else if (s == BASIC_PAYLOAD_16BIT_CODE) { + // reinterpret the second two bytes as a 16 bit integer in network + // byte order. Convert to host byte order and store locally. + payload_size = ntohs(*( + reinterpret_cast(&m_header[BASIC_HEADER_LENGTH]) + )); + + if (payload_size < s) { + std::stringstream err; + err << "payload length not minimally encoded. Using 16 bit form for payload size: " << payload_size; + m_bytes_needed = payload_size; + throw exception(err.str(),FERR_PROTOCOL_VIOLATION); + } + + mask_index += 2; + } else if (s == BASIC_PAYLOAD_64BIT_CODE) { + // reinterpret the second eight bytes as a 64 bit integer in + // network byte order. Convert to host byte order and store. + payload_size = ntohll(*( + reinterpret_cast(&m_header[BASIC_HEADER_LENGTH]) + )); + + if (payload_size <= PAYLOAD_16BIT_LIMIT) { + m_bytes_needed = payload_size; + throw exception("payload length not minimally encoded", + FERR_PROTOCOL_VIOLATION); + } + + mask_index += 8; + } else { + // shouldn't be here + throw exception("invalid get_basic_size in process_extended_header"); + } + + if (get_masked() == 0) { + clear_masking_key(); + } else { + // TODO: use this copy line (needs testing) + // std::copy(m_header[mask_index],m_header[mask_index+4],m_masking_key); + m_masking_key[0] = m_header[mask_index+0]; + m_masking_key[1] = m_header[mask_index+1]; + m_masking_key[2] = m_header[mask_index+2]; + m_masking_key[3] = m_header[mask_index+3]; + } + + if (payload_size > max_payload_size) { + // TODO: frame/message size limits + throw server_error("got frame with payload greater than maximum frame buffer size."); + } + m_payload.resize(payload_size); + m_bytes_needed = payload_size; + } + void process_payload() { + if (get_masked()) { + char *masking_key = &m_header[get_header_len()-4]; + + for (uint64_t i = 0; i < m_payload.size(); i++) { + m_payload[i] = (m_payload[i] ^ masking_key[i%4]); + } + } + } - void validate_utf8(uint32_t* state,uint32_t* codep,size_t offset = 0) const; - void validate_basic_header() const; + // experiment with more efficient masking code. + void process_payload2() { + // unmask payload one byte at a time + + //uint64_t key = (*((uint32_t*)m_masking_key;)) << 32; + //key += *((uint32_t*)m_masking_key); + + // might need to switch byte order + uint32_t key = *((uint32_t*)m_masking_key); + + // 4 + + uint64_t i = 0; + uint64_t s = (m_payload.size() / 4); + + std::cout << "s: " << s << std::endl; + + // chunks of 4 + for (i = 0; i < s; i+=4) { + ((uint32_t*)(&m_payload[0]))[i] = (((uint32_t*)(&m_payload[0]))[i] ^ key); + } + + // finish the last few + for (i = s; i < m_payload.size(); i++) { + m_payload[i] = (m_payload[i] ^ m_masking_key[i%4]); + } + } - void generate_masking_key(); - void clear_masking_key(); + void validate_utf8(uint32_t* state,uint32_t* codep,size_t offset = 0) const { + for (size_t i = offset; i < m_payload.size(); i++) { + using utf8_validator::decode; + + if (decode(state,codep,m_payload[i]) == utf8_validator::UTF8_REJECT) { + throw frame_error("Invalid UTF-8 Data",FERR_PAYLOAD_VIOLATION); + } + } + } + void validate_basic_header() const { + // check for control frame size + if (is_control() && get_basic_size() > BASIC_PAYLOAD_LIMIT) { + throw frame_error("Control Frame is too large",FERR_PROTOCOL_VIOLATION); + } + + // check for reserved bits + if (get_rsv1() || get_rsv2() || get_rsv3()) { + throw frame_error("Reserved bit used",FERR_PROTOCOL_VIOLATION); + } + + // check for reserved opcodes + if (opcode::reserved(get_opcode())) { + throw frame_error("Reserved opcode used",FERR_PROTOCOL_VIOLATION); + } + + // check for fragmented control message + if (is_control() && !get_fin()) { + throw frame_error("Fragmented control message",FERR_PROTOCOL_VIOLATION); + } + } + + void generate_masking_key() { + *(reinterpret_cast(&m_header[get_header_len()-4])) = m_rng.gen(); + } + void clear_masking_key() { + // this is a no-op as clearing the mask bit also changes the get_header_len + // method to not include these byte ranges. Whenever the masking bit is re- + // set a new key is generated anyways. + } private: + static const uint8_t STATE_BASIC_HEADER = 1; + static const uint8_t STATE_EXTENDED_HEADER = 2; + static const uint8_t STATE_PAYLOAD = 3; + static const uint8_t STATE_READY = 4; + static const uint8_t STATE_RECOVERY = 5; + uint8_t m_state; uint64_t m_bytes_needed; bool m_degraded; char m_header[MAX_HEADER_LENGTH]; std::vector m_payload; - - char m_masking_key[4]; - - boost::random::random_device m_rng; - boost::random::variate_generator > - m_gen; + + rng_policy& m_rng; }; // Exception classes -class frame_error : public std::exception { +class exception : public std::exception { public: - frame_error(const std::string& msg, - uint16_t code = frame::FERR_FATAL_SESSION_ERROR) + exception(const std::string& msg, + frame::error::value code = frame::error::FATAL_SESSION_ERROR) : m_msg(msg),m_code(code) {} - ~frame_error() throw() {} + exception() throw() {} virtual const char* what() const throw() { return m_msg.c_str(); } - uint16_t code() const throw() { + frame::error::value code() const throw() { return m_code; } std::string m_msg; - uint16_t m_code; + frame::error::value m_code; }; - + +} } #endif // WEBSOCKET_FRAME_HPP diff --git a/src/websocket_server.cpp b/src/websocket_server.cpp index 52ecce784b..d1001affb6 100644 --- a/src/websocket_server.cpp +++ b/src/websocket_server.cpp @@ -25,177 +25,3 @@ * */ -#include "websocket_server.hpp" - -#include -#include - -#include - -#include - -using websocketpp::server; - -server::server(boost::asio::io_service& io_service, - const tcp::endpoint& endpoint, - websocketpp::connection_handler_ptr defc) - : m_elog_level(LOG_ALL), - m_alog_level(ALOG_ALL), - m_max_message_size(DEFAULT_MAX_MESSAGE_SIZE), - m_io_service(io_service), - m_acceptor(io_service, endpoint), - m_def_con_handler(defc), - m_desc("websocketpp::server") { - m_desc.add_options() - ("help", "produce help message") - ("host,h",po::value >()->multitoken()->composing(), "hostnames to listen on") - ("port,p",po::value(), "port to listen on") - ; - -} - -void server::parse_command_line(int ac, char* av[]) { - po::store(po::parse_command_line(ac,av, m_desc),m_vm); - po::notify(m_vm); - - if (m_vm.count("help") ) { - std::cout << m_desc << std::endl; - } - - //m_vm["host"].as(); - - const std::vector< std::string > &foo = m_vm["host"].as< std::vector >(); - - for (int i = 0; i < foo.size(); i++) { - std::cout << foo[i] << std::endl; - } - - //std::cout << m_vm["host"].as< std::vector >() << std::endl; -} - -void server::add_host(std::string host) { - m_hosts.insert(host); -} - -void server::remove_host(std::string host) { - m_hosts.erase(host); -} - - -void server::set_max_message_size(uint64_t val) { - if (val > frame::PAYLOAD_64BIT_LIMIT) { - std::stringstream err; - err << "Invalid maximum message size: " << val; - - // TODO: Figure out what the ideal error behavior for this method. - // Options: - // Throw exception - // Log error and set value to maximum allowed - // Log error and leave value at whatever it was before - log(err.str(),LOG_WARN); - //throw server_error(err.str()); - } - m_max_message_size = val; -} - -bool server::test_elog_level(uint16_t level) { - return (level >= m_elog_level); -} -void server::set_elog_level(uint16_t level) { - std::stringstream msg; - msg << "Error logging level changing from " - << m_elog_level << " to " << level; - log(msg.str(),LOG_INFO); - - m_elog_level = level; -} -bool server::test_alog_level(uint16_t level) { - return ((level & m_alog_level) != 0); -} -void server::set_alog_level(uint16_t level) { - if (test_alog_level(level)) { - return; - } - std::stringstream msg; - msg << "Access logging level " << level << " being set"; - access_log(msg.str(),ALOG_INFO); - - m_alog_level |= level; -} -void server::unset_alog_level(uint16_t level) { - if (!test_alog_level(level)) { - return; - } - std::stringstream msg; - msg << "Access logging level " << level << " being unset"; - access_log(msg.str(),ALOG_INFO); - - m_alog_level &= ~level; -} - -bool server::validate_host(std::string host) { - if (m_hosts.find(host) == m_hosts.end()) { - return false; - } - return true; -} - -bool server::validate_message_size(uint64_t val) { - if (val > m_max_message_size) { - return false; - } - return true; -} - -void server::log(std::string msg,uint16_t level) { - if (!test_elog_level(level)) { - return; - } - std::cerr << "[Error Log] " - << boost::posix_time::to_iso_extended_string( - boost::posix_time::second_clock::local_time()) - << " " << msg << std::endl; -} -void server::access_log(std::string msg,uint16_t level) { - if (!test_alog_level(level)) { - return; - } - std::cout << "[Access Log] " - << boost::posix_time::to_iso_extended_string( - boost::posix_time::second_clock::local_time()) - << " " << msg << std::endl; -} - -void server::start_accept() { - // TODO: sanity check whether the session buffer size bound could be reduced - server_session_ptr new_session(new server_session(shared_from_this(), - m_io_service, - m_def_con_handler, - m_max_message_size*2)); - - m_acceptor.async_accept( - new_session->socket(), - boost::bind( - &server::handle_accept, - this, - new_session, - boost::asio::placeholders::error - ) - ); -} - -void server::handle_accept(websocketpp::server_session_ptr session, - const boost::system::error_code& error) { - - if (!error) { - session->on_connect(); - } else { - std::stringstream err; - err << "Error accepting socket connection: " << error; - - log(err.str(),LOG_ERROR); - throw server_error(err.str()); - } - - this->start_accept(); -} diff --git a/src/websocket_server.hpp b/src/websocket_server.hpp index dc9b4f03cb..5aad034cd6 100644 --- a/src/websocket_server.hpp +++ b/src/websocket_server.hpp @@ -30,20 +30,25 @@ #include #include +#include #include namespace po = boost::program_options; #include namespace websocketpp { - class server; - typedef boost::shared_ptr server_ptr; + //class server; + //typedef boost::shared_ptr server_ptr; } #include "websocketpp.hpp" #include "websocket_server_session.hpp" #include "websocket_connection_handler.hpp" +#include + +#include "policy/rng/blank_rng.hpp" + using boost::asio::ip::tcp; namespace websocketpp { @@ -61,54 +66,198 @@ private: std::string m_msg; }; -class server : public boost::enable_shared_from_this { +template +class server : public boost::enable_shared_from_this< server > { public: - server(boost::asio::io_service& io_service, + typedef boost::shared_ptr< server > ptr; + + server(boost::asio::io_service& io_service, const tcp::endpoint& endpoint, - connection_handler_ptr defc); + connection_handler_ptr defc) + : m_elog_level(LOG_ALL), + m_alog_level(ALOG_ALL), + m_max_message_size(DEFAULT_MAX_MESSAGE_SIZE), + m_io_service(io_service), + m_acceptor(io_service, endpoint), + m_def_con_handler(defc), + m_desc("websocketpp::server") + { + m_desc.add_options() + ("help", "produce help message") + ("host,h",po::value >()->multitoken()->composing(), "hostnames to listen on") + ("port,p",po::value(), "port to listen on") + ; + } // creates a new session object and connects the next websocket // connection to it. - void start_accept(); + void start_accept() { + // TODO: sanity check whether the session buffer size bound could be reduced + server_session_ptr new_session( + new server_session( + shared_from_this(), + m_io_service, + m_def_con_handler, + m_max_message_size*2 + ) + ); + + m_acceptor.async_accept( + new_session->socket(), + boost::bind( + &server::handle_accept, + shared_from_this(), + new_session, + boost::asio::placeholders::error + ) + ); + } // INTERFACE FOR LOCAL APPLICATIONS // Add or remove a host string (host:port) to the list of acceptable // hosts to accept websocket connections from. Additions/deletions here // only affect new connections. - void add_host(std::string host); - void remove_host(std::string host); + void add_host(std::string host) { + m_hosts.insert(host); + } + void remove_host(std::string host) { + m_hosts.erase(host); + } - void set_max_message_size(uint64_t val); + void set_max_message_size(uint64_t val) { + if (val > frame::limits::JUMBO_PAYLOAD_SIZE) { + std::stringstream err; + err << "Invalid maximum message size: " << val; + + // TODO: Figure out what the ideal error behavior for this method. + // Options: + // Throw exception + // Log error and set value to maximum allowed + // Log error and leave value at whatever it was before + log(err.str(),LOG_WARN); + //throw server_error(err.str()); + } + m_max_message_size = val; + } // Test methods determine if a message of the given level should be // written. elog shows all values above the level set. alog shows only // the values explicitly set. - bool test_elog_level(uint16_t level); - void set_elog_level(uint16_t level); + bool test_elog_level(uint16_t level) { + return (level >= m_elog_level); + } + void set_elog_level(uint16_t level) { + std::stringstream msg; + msg << "Error logging level changing from " + << m_elog_level << " to " << level; + log(msg.str(),LOG_INFO); + + m_elog_level = level; + } - bool test_alog_level(uint16_t level); - void set_alog_level(uint16_t level); - void unset_alog_level(uint16_t level); + bool test_alog_level(uint16_t level) { + return ((level & m_alog_level) != 0); + } + void set_alog_level(uint16_t level) { + if (test_alog_level(level)) { + return; + } + std::stringstream msg; + msg << "Access logging level " << level << " being set"; + access_log(msg.str(),ALOG_INFO); + + m_alog_level |= level; + } + void unset_alog_level(uint16_t level) { + if (!test_alog_level(level)) { + return; + } + std::stringstream msg; + msg << "Access logging level " << level << " being unset"; + access_log(msg.str(),ALOG_INFO); + + m_alog_level &= ~level; + } - void parse_command_line(int ac, char* av[]); + void parse_command_line(int ac, char* av[]) { + po::store(po::parse_command_line(ac,av, m_desc),m_vm); + po::notify(m_vm); + + if (m_vm.count("help") ) { + std::cout << m_desc << std::endl; + } + + //m_vm["host"].as(); + + const std::vector< std::string > &foo = m_vm["host"].as< std::vector >(); + + for (int i = 0; i < foo.size(); i++) { + std::cout << foo[i] << std::endl; + } + + //std::cout << m_vm["host"].as< std::vector >() << std::endl; + } // INTERFACE FOR SESSIONS - + + rng_policy& get_rng() { + return &m_rng; + } + // Check if this server will respond to this host. - bool validate_host(std::string host); + bool validate_host(std::string host) { + if (m_hosts.find(host) == m_hosts.end()) { + return false; + } + return true; + } // Check if message size is within server's acceptable parameters - bool validate_message_size(uint64_t val); + bool validate_message_size(uint64_t val) { + if (val > m_max_message_size) { + return false; + } + return true; + } // write to the server's logs - void log(std::string msg,uint16_t level = LOG_ERROR); - void access_log(std::string msg,uint16_t level); + void log(std::string msg,uint16_t level = LOG_ERROR) { + if (!test_elog_level(level)) { + return; + } + std::cerr << "[Error Log] " + << boost::posix_time::to_iso_extended_string( + boost::posix_time::second_clock::local_time()) + << " " << msg << std::endl; + } + void access_log(std::string msg,uint16_t level) { + if (!test_alog_level(level)) { + return; + } + std::cout << "[Access Log] " + << boost::posix_time::to_iso_extended_string( + boost::posix_time::second_clock::local_time()) + << " " << msg << std::endl; + } private: // if no errors starts the session's read loop and returns to the // start_accept phase. void handle_accept(server_session_ptr session, - const boost::system::error_code& error); + const boost::system::error_code& error) + { + if (!error) { + session->on_connect(); + } else { + std::stringstream err; + err << "Error accepting socket connection: " << error; + + log(err.str(),LOG_ERROR); + throw server_error(err.str()); + } + + this->start_accept(); + } private: uint16_t m_elog_level; @@ -120,6 +269,8 @@ private: tcp::acceptor m_acceptor; connection_handler_ptr m_def_con_handler; + rng_policy m_rng; + po::options_description m_desc; po::variables_map m_vm; }; diff --git a/src/websocket_session.cpp b/src/websocket_session.cpp index a8119b4728..65d401e974 100644 --- a/src/websocket_session.cpp +++ b/src/websocket_session.cpp @@ -24,715 +24,3 @@ * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * */ - -#include "websocketpp.hpp" -#include "websocket_session.hpp" - -#include "websocket_frame.hpp" -#include "utf8_validator/utf8_validator.hpp" - -#include -#include -#include - - -#include -#include -#include -#include - -using websocketpp::session; - -session::session (boost::asio::io_service& io_service, - websocketpp::connection_handler_ptr defc, - uint64_t buf_size) - : m_state(STATE_CONNECTING), - m_writing(false), - m_local_close_code(CLOSE_STATUS_NO_STATUS), - m_remote_close_code(CLOSE_STATUS_NO_STATUS), - m_was_clean(false), - m_closed_by_me(false), - m_dropped_by_me(false), - m_socket(io_service), - m_io_service(io_service), - m_local_interface(defc), - m_timer(io_service,boost::posix_time::seconds(0)), - m_buf(buf_size), // maximum buffered (unconsumed) bytes from network - m_utf8_state(utf8_validator::UTF8_ACCEPT), - m_utf8_codepoint(0) {} - -tcp::socket& session::socket() { - return m_socket; -} - -boost::asio::io_service& session::io_service() { - return m_io_service; -} - -void session::set_handler(websocketpp::connection_handler_ptr new_con) { - if (m_local_interface) { - // TODO: this should be another method and not reusing onclose - //m_local_interface->disconnect(shared_from_this(),4000,"Setting new connection handler"); - } - m_local_interface = new_con; - m_local_interface->on_open(shared_from_this()); -} - -const std::string& session::get_subprotocol() const { - if (m_state == STATE_CONNECTING) { - log("Subprotocol is not avaliable before the handshake has completed.",LOG_WARN); - throw server_error("Subprotocol is not avaliable before the handshake has completed."); - } - return m_server_subprotocol; -} - -const std::string& session::get_resource() const { - return m_resource; -} - -const std::string& session::get_origin() const { - return m_client_origin; -} - -std::string session::get_client_header(const std::string& key) const { - return get_header(key,m_client_headers); -} - -std::string session::get_server_header(const std::string& key) const { - return get_header(key,m_server_headers); -} - -std::string session::get_header(const std::string& key, - const websocketpp::header_list& list) const { - header_list::const_iterator h = list.find(key); - - if (h == list.end()) { - return ""; - } else { - return h->second; - } -} - -const std::vector& session::get_extensions() const { - return m_server_extensions; -} - -unsigned int session::get_version() const { - return m_version; -} - -void session::send(const std::string &msg) { - if (m_state != STATE_OPEN) { - log("Tried to send a message from a session that wasn't open",LOG_WARN); - return; - } - m_write_frame.set_fin(true); - m_write_frame.set_opcode(frame::TEXT_FRAME); - m_write_frame.set_payload(msg); - - write_frame(); -} - -void session::send(const std::vector &data) { - if (m_state != STATE_OPEN) { - log("Tried to send a message from a session that wasn't open",LOG_WARN); - return; - } - m_write_frame.set_fin(true); - m_write_frame.set_opcode(frame::BINARY_FRAME); - m_write_frame.set_payload(data); - - write_frame(); -} - -// end user interface to close the connection -void session::close(uint16_t status,const std::string& msg) { - validate_app_close_status(status); - - send_close(status,msg); -} - -// TODO: clean this up, needs to be broken out into more specific methods - -// This method initiates a clean disconnect with the given status code and reason -// it logs an error and is ignored if it is called from a state other than OPEN - -// called by process_close when an initiate close method is received. - -void session::send_close(uint16_t status,const std::string &message) { - if (m_state != STATE_OPEN) { - log("Tried to disconnect a session that wasn't open",LOG_WARN); - return; - } - - m_state = STATE_CLOSING; - - m_timer.expires_from_now(boost::posix_time::milliseconds(1000)); - - m_timer.async_wait( - boost::bind( - &session::handle_close_expired, - shared_from_this(), - boost::asio::placeholders::error - ) - ); - - m_local_close_code = status; - m_local_close_msg = message; - - m_write_frame.set_fin(true); - m_write_frame.set_opcode(frame::CONNECTION_CLOSE); - - // echo close value unless there is a good reason not to. - if (status == CLOSE_STATUS_NO_STATUS) { - m_write_frame.set_status(CLOSE_STATUS_NORMAL,""); - } else if (status == CLOSE_STATUS_ABNORMAL_CLOSE) { - // Internal implimentation error. There is no good close code for this. - m_write_frame.set_status(CLOSE_STATUS_POLICY_VIOLATION,message); - } else if (close::status::invalid(status)) { - m_write_frame.set_status(close::status::PROTOCOL_ERROR,"Status code is invalid"); - } else if (close::status::reserved(status)) { - m_write_frame.set_status(close::status::PROTOCOL_ERROR,"Status code is reserved"); - } else { - m_write_frame.set_status(status,message); - } - - write_frame(); -} - -void session::ping(const std::string &msg) { - if (m_state != STATE_OPEN) { - log("Tried to send a ping from a session that wasn't open",LOG_WARN); - return; - } - m_write_frame.set_fin(true); - m_write_frame.set_opcode(frame::PING); - m_write_frame.set_payload(msg); - - write_frame(); -} - -void session::pong(const std::string &msg) { - if (m_state != STATE_OPEN) { - log("Tried to send a pong from a session that wasn't open",LOG_WARN); - return; - } - m_write_frame.set_fin(true); - m_write_frame.set_opcode(frame::PONG); - m_write_frame.set_payload(msg); - - write_frame(); -} - -void session::read_frame() { - // the initial read in the handshake may have read in the first frame. - // handle it (if it exists) before we read anything else. - handle_read_frame(boost::system::error_code()); -} - -// handle_read_frame reads and processes all socket read commands for the -// session by consuming the read buffer and then starting an async read with -// itself as the callback. The connection is over when this method returns. -void session::handle_read_frame(const boost::system::error_code& error) { - if (m_state != STATE_OPEN && m_state != STATE_CLOSING) { - log("handle_read_frame called in invalid state",LOG_ERROR); - return; - } - - if (error) { - if (error == boost::asio::error::eof) { - // if this is a case where we are expecting eof, return, else log & drop - - log_error("Recieved EOF",error); - //drop_tcp(false); - //m_state = STATE_CLOSED; - } else if (error == boost::asio::error::operation_aborted) { - // some other part of our client called shutdown on our socket. - // This is usually due to a write error. Everything should have - // already been logged and dropped so we just return here - return; - } else { - log_error("Error reading frame",error); - //drop_tcp(false); - m_state = STATE_CLOSED; - } - } - - std::istream s(&m_buf); - - while (m_buf.size() > 0 && m_state != STATE_CLOSED) { - try { - if (m_read_frame.get_bytes_needed() == 0) { - throw frame_error("have bytes that no frame needs",frame::FERR_FATAL_SESSION_ERROR); - } - - // Consume will read bytes from s - // will throw a frame_error on error. - - std::stringstream err; - - err << "consuming. have: " << m_buf.size() << " bytes. Need: " << m_read_frame.get_bytes_needed() << " state: " << (int)m_read_frame.get_state(); - log(err.str(),LOG_DEBUG); - m_read_frame.consume(s); - - err.str(""); - err << "consume complete, " << m_buf.size() << " bytes left, " << m_read_frame.get_bytes_needed() << " still needed, state: " << (int)m_read_frame.get_state(); - log(err.str(),LOG_DEBUG); - - if (m_read_frame.get_state() == frame::STATE_READY) { - // process frame and reset frame state for the next frame. - // will throw a frame_error on error. May set m_state to CLOSED, - // if so no more frames should be processed. - err.str(""); - err << "processing frame " << m_buf.size(); - log(err.str(),LOG_DEBUG); - m_timer.cancel(); - process_frame(); - } - } catch (const frame_error& e) { - std::stringstream err; - err << "Caught frame exception: " << e.what(); - - access_log(e.what(),ALOG_FRAME); - log(err.str(),LOG_ERROR); - - // if the exception happened while processing. - // TODO: this is not elegant, perhaps separate frame read vs process - // exceptions need to be used. - if (m_read_frame.get_state() == frame::STATE_READY) { - m_read_frame.reset(); - } - - // process different types of frame errors - // - if (e.code() == frame::FERR_PROTOCOL_VIOLATION) { - send_close(CLOSE_STATUS_PROTOCOL_ERROR, e.what()); - } else if (e.code() == frame::FERR_PAYLOAD_VIOLATION) { - send_close(CLOSE_STATUS_INVALID_PAYLOAD, e.what()); - } else if (e.code() == frame::FERR_INTERNAL_SERVER_ERROR) { - send_close(CLOSE_STATUS_ABNORMAL_CLOSE, e.what()); - } else if (e.code() == frame::FERR_SOFT_SESSION_ERROR) { - // ignore and continue processing frames - continue; - } else { - // Fatal error, forcibly end connection immediately. - log("Dropping TCP due to unrecoverable exception",LOG_DEBUG); - drop_tcp(true); - } - - break; - } - } - - if (error == boost::asio::error::eof) { - m_state = STATE_CLOSED; - } - - // we have read everything, check if we should read more - - if ((m_state == STATE_OPEN || m_state == STATE_CLOSING) && m_read_frame.get_bytes_needed() > 0) { - std::stringstream msg; - msg << "starting async read for " << m_read_frame.get_bytes_needed() << " bytes."; - - log(msg.str(),LOG_DEBUG); - - // TODO: set a timer here in case we don't want to read forever. - // Ex: when the frame is in a degraded state. - - boost::asio::async_read( - m_socket, - m_buf, - boost::asio::transfer_at_least(m_read_frame.get_bytes_needed()), - boost::bind( - &session::handle_read_frame, - shared_from_this(), - boost::asio::placeholders::error - ) - ); - } else if (m_state == STATE_CLOSED) { - log_close_result(); - - if (m_local_interface) { - // TODO: make sure close code/msg are properly set. - m_local_interface->on_close(shared_from_this()); - } - - m_timer.cancel(); - } else { - log("handle_read_frame called in invalid state",LOG_ERROR); - } -} - -void session::process_frame () { - log("process_frame",LOG_DEBUG); - - if (m_state == STATE_OPEN) { - switch (m_read_frame.get_opcode()) { - case frame::CONTINUATION_FRAME: - process_continuation(); - break; - case frame::TEXT_FRAME: - process_text(); - break; - case frame::BINARY_FRAME: - process_binary(); - break; - case frame::CONNECTION_CLOSE: - log("process_close",LOG_DEBUG); - process_close(); - break; - case frame::PING: - process_ping(); - break; - case frame::PONG: - process_pong(); - break; - default: - throw frame_error("Invalid Opcode", - frame::FERR_PROTOCOL_VIOLATION); - break; - } - } else if (m_state == STATE_CLOSING) { - if (m_read_frame.get_opcode() == frame::CONNECTION_CLOSE) { - process_close(); - } else { - // Ignore all other frames in closing state - log("ignoring this frame",LOG_DEBUG); - } - } else { - // Recieved message before or after connection was opened/closed - throw frame_error("process_frame called from invalid state"); - } - - m_read_frame.reset(); -} - -void session::handle_write_frame (const boost::system::error_code& error) { - if (error) { - log_error("Error writing frame data",error); - drop_tcp(false); - } - - access_log("handle_write_frame complete",ALOG_FRAME); - m_writing = false; -} - - -void session::handle_timer_expired (const boost::system::error_code& error) { - if (error) { - if (error == boost::asio::error::operation_aborted) { - log("timer was aborted",LOG_DEBUG); - //drop_tcp(false); - } else { - log("timer ended with error",LOG_DEBUG); - } - return; - } - - log("timer ended without error",LOG_DEBUG); - - -} - -void session::handle_handshake_expired (const boost::system::error_code& error) { - if (error) { - if (error != boost::asio::error::operation_aborted) { - log("Unexpected handshake timer error.",LOG_DEBUG); - drop_tcp(true); - } - return; - } - - log("Handshake timed out",LOG_DEBUG); - drop_tcp(true); -} - -// The error timer is set when we want to give the other endpoint some time to -// do something but don't want to wait forever. There is a special error code -// that represents the timer being canceled by us (because the other endpoint -// responded in time. All other cases should assume that the other endpoint is -// irrepairibly broken and drop the TCP connection. -void session::handle_error_timer_expired (const boost::system::error_code& error) { - if (error) { - if (error == boost::asio::error::operation_aborted) { - log("error timer was aborted",LOG_DEBUG); - //drop_tcp(false); - } else { - log("error timer ended with error",LOG_DEBUG); - drop_tcp(true); - } - return; - } - - log("error timer ended without error",LOG_DEBUG); - drop_tcp(true); -} - -void session::handle_close_expired (const boost::system::error_code& error) { - if (error) { - if (error == boost::asio::error::operation_aborted) { - log("timer was aborted",LOG_DEBUG); - //drop_tcp(false); - } else { - log("Unexpected close timer error.",LOG_DEBUG); - drop_tcp(false); - } - return; - } - - if (m_state != STATE_CLOSED) { - log("close timed out",LOG_DEBUG); - drop_tcp(false); - } -} - -void session::process_ping() { - access_log("Ping",ALOG_MISC_CONTROL); - // TODO: on_ping - - // send pong - m_write_frame.set_fin(true); - m_write_frame.set_opcode(frame::PONG); - m_write_frame.set_payload(m_read_frame.get_payload()); - - write_frame(); -} - -void session::process_pong() { - access_log("Pong",ALOG_MISC_CONTROL); - // TODO: on_pong -} - -void session::process_text() { - // this will throw an exception if validation fails at any point - m_read_frame.validate_utf8(&m_utf8_state,&m_utf8_codepoint); - - // otherwise, treat as binary - process_binary(); -} - -void session::process_binary() { - if (m_fragmented) { - throw frame_error("Got a new message before the previous was finished.", - frame::FERR_PROTOCOL_VIOLATION); - } - - m_current_opcode = m_read_frame.get_opcode(); - - if (m_read_frame.get_fin()) { - deliver_message(); - reset_message(); - } else { - m_fragmented = true; - extract_payload(); - } -} - -void session::process_continuation() { - if (!m_fragmented) { - throw frame_error("Got a continuation frame without an outstanding message.", - frame::FERR_PROTOCOL_VIOLATION); - } - - if (m_current_opcode == frame::TEXT_FRAME) { - // this will throw an exception if validation fails at any point - m_read_frame.validate_utf8(&m_utf8_state,&m_utf8_codepoint); - } - - extract_payload(); - - // check if we are done - if (m_read_frame.get_fin()) { - deliver_message(); - reset_message(); - } -} - -void session::process_close() { - m_remote_close_code = m_read_frame.get_close_status(); - m_remote_close_msg = m_read_frame.get_close_msg(); - - if (m_state == STATE_OPEN) { - log("process_close sending ack",LOG_DEBUG); - // This is the case where the remote initiated the close. - m_closed_by_me = false; - // send acknowledgement - - // check if the remote close code - if (m_remote_close_code >= close::status::RSV_START) { - - } - - send_close(m_remote_close_code,m_remote_close_msg); - } else if (m_state == STATE_CLOSING) { - log("process_close got ack",LOG_DEBUG); - // this is an ack of our close message - m_closed_by_me = true; - } else { - throw frame_error("process_closed called from wrong state"); - } - - m_was_clean = true; - m_state = STATE_CLOSED; -} - -void session::deliver_message() { - if (!m_local_interface) { - return; - } - - if (m_current_opcode == frame::BINARY_FRAME) { - //log("Dispatching Binary Message",LOG_DEBUG); - if (m_fragmented) { - m_local_interface->on_message(shared_from_this(),m_current_message); - } else { - m_local_interface->on_message(shared_from_this(), - m_read_frame.get_payload()); - } - } else if (m_current_opcode == frame::TEXT_FRAME) { - std::string msg; - - // make sure the finished frame is valid utf8 - // the streaming validator checks for bad codepoints as it goes. It - // doesn't know where the end of the message is though, so we need to - // check here to make sure the final message ends on a valid codepoint. - if (m_utf8_state != utf8_validator::UTF8_ACCEPT) { - throw frame_error("Invalid UTF-8 Data", - frame::FERR_PAYLOAD_VIOLATION); - } - - if (m_fragmented) { - msg.append(m_current_message.begin(),m_current_message.end()); - } else { - msg.append( - m_read_frame.get_payload().begin(), - m_read_frame.get_payload().end() - ); - } - - //log("Dispatching Text Message",LOG_DEBUG); - m_local_interface->on_message(shared_from_this(),msg); - } else { - // Not sure if this should be a fatal error or not - std::stringstream err; - err << "Attempted to deliver a message of unsupported opcode " << m_current_opcode; - throw frame_error(err.str(),frame::FERR_SOFT_SESSION_ERROR); - } - -} - -void session::extract_payload() { - std::vector &msg = m_read_frame.get_payload(); - m_current_message.resize(m_current_message.size()+msg.size()); - std::copy(msg.begin(),msg.end(),m_current_message.end()-msg.size()); -} - -void session::write_frame() { - if (!is_server()) { - m_write_frame.set_masked(true); // client must mask frames - } - - m_write_frame.process_payload(); - - std::vector data; - - data.push_back( - boost::asio::buffer( - m_write_frame.get_header(), - m_write_frame.get_header_len() - ) - ); - data.push_back( - boost::asio::buffer(m_write_frame.get_payload()) - ); - - log("Write Frame: "+m_write_frame.print_frame(),LOG_DEBUG); - - m_writing = true; - - boost::asio::async_write( - m_socket, - data, - boost::bind( - &session::handle_write_frame, - shared_from_this(), - boost::asio::placeholders::error - ) - ); -} - -void session::reset_message() { - m_error = false; - m_fragmented = false; - m_current_message.clear(); - - m_utf8_state = utf8_validator::UTF8_ACCEPT; - m_utf8_codepoint = 0; -} - -void session::log_close_result() { - std::stringstream msg; - - msg << "[Connection " << this << "] " - << (m_was_clean ? "Clean " : "Unclean ") - << "close local:[" << m_local_close_code - << (m_local_close_msg == "" ? "" : ","+m_local_close_msg) - << "] remote:[" << m_remote_close_code - << (m_remote_close_msg == "" ? "" : ","+m_remote_close_msg) << "]"; - - access_log(msg.str(),ALOG_DISCONNECT); -} - -void session::log_open_result() { - std::stringstream msg; - - msg << "[Connection " << this << "] " - << m_socket.remote_endpoint() - << " v" << m_version << " " - << (get_client_header("User-Agent") == "" ? "NULL" : get_client_header("User-Agent")) - << " " << m_resource << " " << m_server_http_code; - - access_log(msg.str(),ALOG_HANDSHAKE); -} - -// this is called when an async asio call encounters an error -void session::log_error(std::string msg,const boost::system::error_code& e) { - std::stringstream err; - - err << "[Connection " << this << "] " << msg << " (" << e << ")"; - - log(err.str(),LOG_ERROR); -} - -// validates status codes that the end application is allowed to use -bool session::validate_app_close_status(uint16_t status) { - if (status == CLOSE_STATUS_NORMAL) { - return true; - } - - if (status >= 4000 && status < 5000) { - return true; - } - - return false; -} - -void session::drop_tcp(bool dropped_by_me) { - m_timer.cancel(); - try { - if (m_socket.is_open()) { - m_socket.shutdown(tcp::socket::shutdown_both); - m_socket.close(); - } - } catch (boost::system::system_error& e) { - if (e.code() == boost::asio::error::not_connected) { - // this means the socket was disconnected by the other side before - // we had a chance to. Ignore and continue. - } else { - throw e; - } - } - m_dropped_by_me = dropped_by_me; - m_state = STATE_CLOSED; - -} diff --git a/src/websocket_session.hpp b/src/websocket_session.hpp index d1c7833d8c..87e11b3457 100644 --- a/src/websocket_session.hpp +++ b/src/websocket_session.hpp @@ -77,6 +77,7 @@ #include #include +#include #include #include @@ -87,6 +88,7 @@ #include #endif +#include #include #include #include @@ -109,6 +111,7 @@ namespace websocketpp { #include "base64/base64.h" #include "sha1/sha1.h" +#include "utf8_validator/utf8_validator.hpp" using boost::asio::ip::tcp; @@ -116,33 +119,54 @@ namespace websocketpp { typedef std::map header_list; -class session : public boost::enable_shared_from_this { +template