started policy-refactor branch

This commit is contained in:
Peter Thorson
2011-10-28 17:09:36 -05:00
parent b48733747e
commit 86da9f503c
13 changed files with 1726 additions and 1683 deletions

View File

@@ -0,0 +1,38 @@
/*
* Copyright (c) 2011, Peter Thorson. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the WebSocket++ Project nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL PETER THORSON BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
* This Makefile was derived from a similar one included in the libjson project
* It's authors were Jonathan Wallace and Bernhard Fluehmann.
*/
#include "blank_rng.hpp"
using websocketpp::blank_rng;
blank_rng::blank_rng() {}
int32_t blank_rng::gen() {
throw "Random Number generation not supported";
}

View File

@@ -0,0 +1,43 @@
/*
* Copyright (c) 2011, Peter Thorson. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the WebSocket++ Project nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL PETER THORSON BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
* This Makefile was derived from a similar one included in the libjson project
* It's authors were Jonathan Wallace and Bernhard Fluehmann.
*/
#ifndef BLANK_RNG_HPP
#define BLANK_RNG_HPP
#include <stdint.h>
namespace websocketpp {
class blank_rng {
public:
blank_rng();
int32_t gen();
}
}
#endif // BLANK_RNG_HPP

View File

@@ -0,0 +1,39 @@
/*
* Copyright (c) 2011, Peter Thorson. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the WebSocket++ Project nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL PETER THORSON BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
* This Makefile was derived from a similar one included in the libjson project
* It's authors were Jonathan Wallace and Bernhard Fluehmann.
*/
#include "boost_rng.hpp"
using websocketpp::boost_rng;
boost_rng::boost_rng() : m_gen(m_rng,
boost::random::uniform_int_distribution<>(INT32_MIN,INT32_MAX)); {}
int32_t boost_rng::gen() {
return m_gen();
}

View File

@@ -0,0 +1,52 @@
/*
* Copyright (c) 2011, Peter Thorson. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the WebSocket++ Project nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL PETER THORSON BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
* This Makefile was derived from a similar one included in the libjson project
* It's authors were Jonathan Wallace and Bernhard Fluehmann.
*/
#ifndef BOOST_RNG_HPP
#define BOOST_RNG_HPP
#include <stdint.h>
#include <boost/random.hpp>
#include <boost/random/random_device.hpp>
namespace websocketpp {
class boost_rng {
public:
boost_rng();
int32_t gen();
private:
boost::random::random_device m_rng;
boost::random::variate_generator
<
boost::random::random_device&,
boost::random::uniform_int_distribution<>
> m_gen;
}
}
#endif // BOOST_RNG_HPP

156
src/websocket_constants.hpp Normal file
View File

@@ -0,0 +1,156 @@
/*
* Copyright (c) 2011, Peter Thorson. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the WebSocket++ Project nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL PETER THORSON BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
* This Makefile was derived from a similar one included in the libjson project
* It's authors were Jonathan Wallace and Bernhard Fluehmann.
*/
#ifndef WEBSOCKET_CONSTANTS_HPP
#define WEBSOCKET_CONSTANTS_HPP
#include <stdint.h>
// Defaults
namespace websocketpp {
const uint64_t DEFAULT_MAX_MESSAGE_SIZE = 0xFFFFFF; // ~16MB
// System logging levels
static const uint16_t LOG_ALL = 0;
static const uint16_t LOG_DEBUG = 1;
static const uint16_t LOG_INFO = 2;
static const uint16_t LOG_WARN = 3;
static const uint16_t LOG_ERROR = 4;
static const uint16_t LOG_FATAL = 5;
static const uint16_t LOG_OFF = 6;
// Access logging controls
// Individual bits
static const uint16_t ALOG_CONNECT = 0x1;
static const uint16_t ALOG_DISCONNECT = 0x2;
static const uint16_t ALOG_MISC_CONTROL = 0x4;
static const uint16_t ALOG_FRAME = 0x8;
static const uint16_t ALOG_MESSAGE = 0x10;
static const uint16_t ALOG_INFO = 0x20;
static const uint16_t ALOG_HANDSHAKE = 0x40;
// Useful groups
static const uint16_t ALOG_OFF = 0x0;
static const uint16_t ALOG_CONTROL = ALOG_CONNECT
& ALOG_DISCONNECT
& ALOG_MISC_CONTROL;
static const uint16_t ALOG_ALL = 0xFFFF;
namespace close {
namespace status {
enum value {
INVALID_END = 999,
NORMAL = 1000,
GOING_AWAY = 1001,
PROTOCOL_ERROR = 1002,
UNSUPPORTED_DATA = 1003,
RSV_ADHOC_1 = 1004,
NO_STATUS = 1005,
ABNORMAL_CLOSE = 1006,
INVALID_PAYLOAD = 1007,
POLICY_VIOLATION = 1008,
MESSAGE_TOO_BIG = 1009,
EXTENSION_REQUIRE = 1010,
RSV_START = 1011,
RSV_END = 2999,
INVALID_START = 5000
};
inline bool reserved(value s) {
return ((s >= RSV_START && s <= RSV_END) ||
s == RSV_ADHOC_1);
}
inline bool invalid(value s) {
return ((s <= INVALID_END || s >= INVALID_START) ||
s == NO_STATUS ||
s == ABNORMAL_CLOSE);
}
// TODO functions for application ranges?
}
}
namespace frame {
namespace error {
enum value {
FATAL_SESSION_ERROR = 0, // force session end
SOFT_SESSION_ERROR = 1, // should log and ignore
PROTOCOL_VIOLATION = 2, // must end session
PAYLOAD_VIOLATION = 3, // should end session
INTERNAL_SERVER_ERROR = 4, // cleanly end session
MSG_TOO_BIG = 5 // ???
};
}
// Opcodes are 4 bits
// See spec section 5.2
namespace opcode {
enum value {
CONTINUATION = 0x0,
TEXT = 0x1,
BINARY = 0x2,
RSV3 = 0x3,
RSV4 = 0x4,
RSV5 = 0x5,
RSV6 = 0x6,
RSV7 = 0x7,
CLOSE = 0x8,
PING = 0x9,
PONG = 0xA,
CONTROL_RSVB = 0xB,
CONTROL_RSVC = 0xC,
CONTROL_RSVD = 0xD,
CONTROL_RSVE = 0xE,
CONTROL_RSVF = 0xF,
};
inline bool reserved(value v) {
return (v >= RSV3 && v <= RSV7) ||
(v >= CONTROL_RSVB && v <= CONTROL_RSVF);
}
inline bool invalid(value v) {
return (v > 0xF || v < 0);
}
inline bool is_control(value v) {
return v >= 0x8;
}
}
namespace limits {
static const uint8_t PAYLOAD_SIZE_BASIC = 125;
static const uint16_t PAYLOAD_SIZE_EXTENDED = 0xFFFF; // 2^16, 65535
static const uint64_t PAYLOAD_SIZE_JUMBO = 0x7FFFFFFFFFFFFFFF;//2^63
}
}
}
#endif // WEBSOCKET_CONSTANTS_HPP

View File

@@ -25,562 +25,3 @@
*
*/
#include "websocket_frame.hpp"
#include "websocket_server.hpp"
#include "utf8_validator/utf8_validator.hpp"
#include <iostream>
#include <algorithm>
#if defined(WIN32)
#include <winsock2.h>
#else
#include <arpa/inet.h>
#endif
using websocketpp::frame;
uint8_t frame::get_state() const {
return m_state;
}
uint64_t frame::get_bytes_needed() const {
return m_bytes_needed;
}
void frame::reset() {
m_state = STATE_BASIC_HEADER;
m_bytes_needed = BASIC_HEADER_LENGTH;
m_degraded = false;
m_payload.empty();
memset(m_header,0,MAX_HEADER_LENGTH);
}
// Method invariant: One of the following must always be true even in the case
// of exceptions.
// - m_bytes_needed > 0
// - m-state = STATE_READY
void frame::consume(std::istream &s) {
try {
switch (m_state) {
case STATE_BASIC_HEADER:
s.read(&m_header[BASIC_HEADER_LENGTH-m_bytes_needed],m_bytes_needed);
m_bytes_needed -= s.gcount();
if (m_bytes_needed == 0) {
process_basic_header();
validate_basic_header();
if (m_bytes_needed > 0) {
m_state = STATE_EXTENDED_HEADER;
} else {
process_extended_header();
if (m_bytes_needed == 0) {
m_state = STATE_READY;
process_payload();
} else {
m_state = STATE_PAYLOAD;
}
}
}
break;
case STATE_EXTENDED_HEADER:
s.read(&m_header[get_header_len()-m_bytes_needed],m_bytes_needed);
m_bytes_needed -= s.gcount();
if (m_bytes_needed == 0) {
process_extended_header();
if (m_bytes_needed == 0) {
m_state = STATE_READY;
process_payload();
} else {
m_state = STATE_PAYLOAD;
}
}
break;
case STATE_PAYLOAD:
s.read(reinterpret_cast<char *>(&m_payload[m_payload.size()-m_bytes_needed]),
m_bytes_needed);
m_bytes_needed -= s.gcount();
if (m_bytes_needed == 0) {
m_state = STATE_READY;
process_payload();
}
break;
case STATE_RECOVERY:
// Recovery state discards all bytes that are not the first byte
// of a close frame.
do {
s.read(reinterpret_cast<char *>(&m_header[0]),1);
//std::cout << std::hex << int(static_cast<unsigned char>(m_header[0])) << " ";
if (int(static_cast<unsigned char>(m_header[0])) == 0x88) {
//(BPB0_FIN && CONNECTION_CLOSE)
m_bytes_needed--;
m_state = STATE_BASIC_HEADER;
break;
}
} while (s.gcount() > 0);
//std::cout << std::endl;
break;
default:
break;
}
/*if (s.gcount() == 0) {
throw frame_error("consume read zero bytes",FERR_FATAL_SESSION_ERROR);
}*/
} catch (const frame_error& e) {
// After this point all non-close frames must be considered garbage,
// including the current one. Reset it and put the reading frame into
// a recovery state.
if (m_degraded == true) {
throw frame_error("An error occurred while trying to gracefully recover from a less serious frame error.",FERR_FATAL_SESSION_ERROR);
} else {
reset();
m_state = STATE_RECOVERY;
m_degraded = true;
throw e;
}
}
}
char* frame::get_header() {
return m_header;
}
char* frame::get_extended_header() {
return m_header+BASIC_HEADER_LENGTH;
}
unsigned int frame::get_header_len() const {
unsigned int temp = 2;
if (get_masked()) {
temp += 4;
}
if (get_basic_size() == 126) {
temp += 2;
} else if (get_basic_size() == 127) {
temp += 8;
}
return temp;
}
char* frame::get_masking_key() {
if (m_state != STATE_READY) {
throw frame_error("attempted to get masking_key before reading full header");
}
return m_masking_key;
}
// get and set header bits
bool frame::get_fin() const {
return ((m_header[0] & BPB0_FIN) == BPB0_FIN);
}
void frame::set_fin(bool fin) {
if (fin) {
m_header[0] |= BPB0_FIN;
} else {
m_header[0] &= (0xFF ^ BPB0_FIN);
}
}
// get and set reserved bits
bool frame::get_rsv1() const {
return ((m_header[0] & BPB0_RSV1) == BPB0_RSV1);
}
void frame::set_rsv1(bool b) {
if (b) {
m_header[0] |= BPB0_RSV1;
} else {
m_header[0] &= (0xFF ^ BPB0_RSV1);
}
}
bool frame::get_rsv2() const {
return ((m_header[0] & BPB0_RSV2) == BPB0_RSV2);
}
void frame::set_rsv2(bool b) {
if (b) {
m_header[0] |= BPB0_RSV2;
} else {
m_header[0] &= (0xFF ^ BPB0_RSV2);
}
}
bool frame::get_rsv3() const {
return ((m_header[0] & BPB0_RSV3) == BPB0_RSV3);
}
void frame::set_rsv3(bool b) {
if (b) {
m_header[0] |= BPB0_RSV3;
} else {
m_header[0] &= (0xFF ^ BPB0_RSV3);
}
}
frame::opcode frame::get_opcode() const {
return frame::opcode(m_header[0] & BPB0_OPCODE);
}
void frame::set_opcode(frame::opcode op) {
if (op > 0x0F) {
throw frame_error("invalid opcode",FERR_PROTOCOL_VIOLATION);
}
if (get_basic_size() > BASIC_PAYLOAD_LIMIT &&
is_control()) {
throw frame_error("control frames can't have large payloads",FERR_PROTOCOL_VIOLATION);
}
m_header[0] &= (0xFF ^ BPB0_OPCODE); // clear op bits
m_header[0] |= op; // set op bits
}
bool frame::get_masked() const {
return ((m_header[1] & BPB1_MASK) == BPB1_MASK);
}
void frame::set_masked(bool masked) {
if (masked) {
m_header[1] |= BPB1_MASK;
generate_masking_key();
} else {
m_header[1] &= (0xFF ^ BPB1_MASK);
clear_masking_key();
}
}
uint8_t frame::get_basic_size() const {
return m_header[1] & BPB1_PAYLOAD;
}
size_t frame::get_payload_size() const {
if (m_state != STATE_READY && m_state != STATE_PAYLOAD) {
// problem
throw frame_error("attempted to get payload size before reading full header");
}
return m_payload.size();
}
uint16_t frame::get_close_status() const {
if (get_payload_size() == 0) {
return close::status::NO_STATUS;
} else if (get_payload_size() >= 2) {
char val[2];
val[0] = m_payload[0];
val[1] = m_payload[1];
uint16_t code = ntohs(*(
reinterpret_cast<uint16_t*>(&val[0])
));
return code;
} else {
return close::status::PROTOCOL_ERROR;
}
}
std::string frame::get_close_msg() const {
if (get_payload_size() > 2) {
uint32_t state = utf8_validator::UTF8_ACCEPT;
uint32_t codep = 0;
validate_utf8(&state,&codep,2);
if (state != utf8_validator::UTF8_ACCEPT) {
throw frame_error("Invalid UTF-8 Data",
frame::FERR_PAYLOAD_VIOLATION);
}
return std::string(m_payload.begin()+2,m_payload.end());
} else {
return std::string();
}
}
std::vector<unsigned char> &frame::get_payload() {
return m_payload;
}
void frame::set_payload(const std::vector<unsigned char> source) {
set_payload_helper(source.size());
std::copy(source.begin(),source.end(),m_payload.begin());
}
void frame::set_payload(const std::string source) {
set_payload_helper(source.size());
std::copy(source.begin(),source.end(),m_payload.begin());
}
bool frame::is_control() const {
return (get_opcode() > MAX_FRAME_OPCODE);
}
void frame::set_payload_helper(size_t s) {
if (s > max_payload_size) {
throw frame_error("requested payload is over implimentation defined limit",FERR_MSG_TOO_BIG);
}
// limits imposed by the websocket spec
if (s > BASIC_PAYLOAD_LIMIT &&
get_opcode() > MAX_FRAME_OPCODE) {
throw frame_error("control frames can't have large payloads",FERR_PROTOCOL_VIOLATION);
}
if (s <= BASIC_PAYLOAD_LIMIT) {
m_header[1] = s;
} else if (s <= PAYLOAD_16BIT_LIMIT) {
m_header[1] = BASIC_PAYLOAD_16BIT_CODE;
// this reinterprets the second pair of bytes in m_header as a
// 16 bit int and writes the payload size there as an integer
// in network byte order
*reinterpret_cast<uint16_t*>(&m_header[BASIC_HEADER_LENGTH]) = htons(s);
} else if (s <= PAYLOAD_64BIT_LIMIT) {
m_header[1] = BASIC_PAYLOAD_64BIT_CODE;
*reinterpret_cast<uint64_t*>(&m_header[BASIC_HEADER_LENGTH]) = htonll(s);
} else {
throw frame_error("payload size limit is 63 bits",FERR_PROTOCOL_VIOLATION);
}
m_payload.resize(s);
}
void frame::set_status(uint16_t status,const std::string message) {
// check for valid statuses
if (close::status::invalid(status)) {
std::stringstream err;
err << "Status code " << status << " is invalid";
throw frame_error(err.str());
}
if (close::status::reserved(status)) {
std::stringstream err;
err << "Status code " << status << " is reserved";
throw frame_error(err.str());
}
m_payload.resize(2+message.size());
char val[2];
*reinterpret_cast<uint16_t*>(&val[0]) = htons(status);
m_header[1] = message.size()+2;
m_payload[0] = val[0];
m_payload[1] = val[1];
std::copy(message.begin(),message.end(),m_payload.begin()+2);
}
std::string frame::print_frame() const {
std::stringstream f;
unsigned int len = get_header_len();
f << "frame: ";
// print header
for (unsigned int i = 0; i < len; i++) {
f << std::hex << (unsigned short)m_header[i] << " ";
}
// print message
if (m_payload.size() > 50) {
f << "[payload of " << m_payload.size() << " bytes]";
} else {
std::vector<unsigned char>::const_iterator it;
for (it = m_payload.begin(); it != m_payload.end(); it++) {
f << *it;
}
}
return f.str();
}
void frame::process_basic_header() {
m_bytes_needed = get_header_len() - BASIC_HEADER_LENGTH;
}
void frame::process_extended_header() {
uint8_t s = get_basic_size();
uint64_t payload_size;
int mask_index = BASIC_HEADER_LENGTH;
if (s <= BASIC_PAYLOAD_LIMIT) {
payload_size = s;
} else if (s == BASIC_PAYLOAD_16BIT_CODE) {
// reinterpret the second two bytes as a 16 bit integer in network
// byte order. Convert to host byte order and store locally.
payload_size = ntohs(*(
reinterpret_cast<uint16_t*>(&m_header[BASIC_HEADER_LENGTH])
));
if (payload_size < s) {
std::stringstream err;
err << "payload length not minimally encoded. Using 16 bit form for payload size: " << payload_size;
m_bytes_needed = payload_size;
throw frame_error(err.str(),
FERR_PROTOCOL_VIOLATION);
}
mask_index += 2;
} else if (s == BASIC_PAYLOAD_64BIT_CODE) {
// reinterpret the second eight bytes as a 64 bit integer in
// network byte order. Convert to host byte order and store.
payload_size = ntohll(*(
reinterpret_cast<uint64_t*>(&m_header[BASIC_HEADER_LENGTH])
));
if (payload_size <= PAYLOAD_16BIT_LIMIT) {
m_bytes_needed = payload_size;
throw frame_error("payload length not minimally encoded",
FERR_PROTOCOL_VIOLATION);
}
mask_index += 8;
} else {
// shouldn't be here
throw frame_error("invalid get_basic_size in process_extended_header");
}
if (get_masked() == 0) {
clear_masking_key();
} else {
// TODO: use this copy line (needs testing)
// std::copy(m_header[mask_index],m_header[mask_index+4],m_masking_key);
m_masking_key[0] = m_header[mask_index+0];
m_masking_key[1] = m_header[mask_index+1];
m_masking_key[2] = m_header[mask_index+2];
m_masking_key[3] = m_header[mask_index+3];
}
if (payload_size > max_payload_size) {
// TODO: frame/message size limits
throw server_error("got frame with payload greater than maximum frame buffer size.");
}
m_payload.resize(payload_size);
m_bytes_needed = payload_size;
}
void frame::process_payload() {
if (get_masked()) {
char *masking_key = &m_header[get_header_len()-4];
for (uint64_t i = 0; i < m_payload.size(); i++) {
m_payload[i] = (m_payload[i] ^ masking_key[i%4]);
}
}
}
void frame::process_payload2() {
// unmask payload one byte at a time
//uint64_t key = (*((uint32_t*)m_masking_key;)) << 32;
//key += *((uint32_t*)m_masking_key);
// might need to switch byte order
uint32_t key = *((uint32_t*)m_masking_key);
// 4
uint64_t i = 0;
uint64_t s = (m_payload.size() / 4);
std::cout << "s: " << s << std::endl;
// chunks of 4
for (i = 0; i < s; i+=4) {
((uint32_t*)(&m_payload[0]))[i] = (((uint32_t*)(&m_payload[0]))[i] ^ key);
}
// finish the last few
for (i = s; i < m_payload.size(); i++) {
m_payload[i] = (m_payload[i] ^ m_masking_key[i%4]);
}
}
void frame::validate_utf8(uint32_t* state,uint32_t* codep, size_t offset) const {
for (size_t i = offset; i < m_payload.size(); i++) {
using utf8_validator::decode;
if (decode(state,codep,m_payload[i]) == utf8_validator::UTF8_REJECT) {
throw frame_error("Invalid UTF-8 Data",FERR_PAYLOAD_VIOLATION);
}
}
}
void frame::validate_basic_header() const {
// check for control frame size
if (get_basic_size() > BASIC_PAYLOAD_LIMIT && is_control()) {
throw frame_error("Control Frame is too large",FERR_PROTOCOL_VIOLATION);
}
// check for reserved opcodes
if (get_rsv1() || get_rsv2() || get_rsv3()) {
throw frame_error("Reserved bit used",FERR_PROTOCOL_VIOLATION);
}
// check for reserved opcodes
opcode op = get_opcode();
if (op > 0x02 && op < 0x08) {
throw frame_error("Reserved opcode used",FERR_PROTOCOL_VIOLATION);
}
if (op > 0x0A) {
throw frame_error("Reserved opcode used",FERR_PROTOCOL_VIOLATION);
}
// check for fragmented control message
if (is_control() && !get_fin()) {
throw frame_error("Fragmented control message",FERR_PROTOCOL_VIOLATION);
}
}
void frame::generate_masking_key() {
//throw "masking key generation not implimented";
int32_t key = m_gen();
//m_masking_key[0] = reinterpret_cast<char*>(&key)[0];
//m_masking_key[1] = reinterpret_cast<char*>(&key)[1];
//m_masking_key[2] = reinterpret_cast<char*>(&key)[2];
//m_masking_key[3] = reinterpret_cast<char*>(&key)[3];
*(reinterpret_cast<int32_t *>(&m_header[get_header_len()-4])) = key;
//std::cout << "maskkey: " << m_masking_key << std::endl;
/* TODO: test and tune
boost::random::random_device rng;
boost::random::uniform_int_distribution<> mask(0,UINT32_MAX);
*(reinterpret_cast<uint32_t *>(m_masking_key)) = mask(rng);
*/
}
void frame::clear_masking_key() {
// this is a no-op as clearing the mask bit also changes the get_header_len
// method to not include these byte ranges. Whenever the masking bit is re-
// set a new key is generated anyways.
}

View File

@@ -29,50 +29,50 @@
#define WEBSOCKET_FRAME_HPP
#include "network_utilities.hpp"
#include "websocket_constants.hpp"
//#include "websocket_server.hpp"
#include "utf8_validator/utf8_validator.hpp"
#if defined(WIN32)
#include <winsock2.h>
#else
#include <arpa/inet.h>
#endif
#include <string>
#include <vector>
#include <cstring>
#include <boost/random.hpp>
#include <boost/random/random_device.hpp>
#include <iostream>
#include <algorithm>
namespace websocketpp {
namespace frame {
/* policies to abstract out
- random number generation
- utf8 validation
rng
int32_t gen()
class boost_random {
public:
boost_random()
int32_t gen();
private:
boost::random::random_device m_rng;
boost::random::variate_generator<boost::random::random_device&,boost::random::uniform_int_distribution<> > m_gen;
}
*/
class frame {
template <class rng_policy>
class parser {
public:
enum opcode_s {
CONTINUATION_FRAME = 0x00,
TEXT_FRAME = 0x01,
BINARY_FRAME = 0x02,
CONNECTION_CLOSE = 0x08,
PING = 0x09,
PONG = 0x0A
};
typedef enum opcode_s opcode;
static const uint8_t MAX_FRAME_OPCODE = 0x07;
static const uint8_t STATE_BASIC_HEADER = 1;
static const uint8_t STATE_EXTENDED_HEADER = 2;
static const uint8_t STATE_PAYLOAD = 3;
static const uint8_t STATE_READY = 4;
static const uint8_t STATE_RECOVERY = 5;
static const uint16_t FERR_FATAL_SESSION_ERROR = 0; // force session end
static const uint16_t FERR_SOFT_SESSION_ERROR = 1; // should log and ignore
static const uint16_t FERR_PROTOCOL_VIOLATION = 2; // must end session
static const uint16_t FERR_PAYLOAD_VIOLATION = 3; // should end session
static const uint16_t FERR_INTERNAL_SERVER_ERROR = 4; // cleanly end session
static const uint16_t FERR_MSG_TOO_BIG = 5;
// basic payload byte flags
static const uint8_t BPB0_OPCODE = 0x0F;
@@ -83,11 +83,8 @@ public:
static const uint8_t BPB1_PAYLOAD = 0x7F;
static const uint8_t BPB1_MASK = 0x80;
static const uint8_t BASIC_PAYLOAD_LIMIT = 0x7D; // 125
static const uint8_t BASIC_PAYLOAD_16BIT_CODE = 0x7E; // 126
static const uint16_t PAYLOAD_16BIT_LIMIT = 0xFFFF; // 2^16, 65535
static const uint8_t BASIC_PAYLOAD_64BIT_CODE = 0x7F; // 127
static const uint64_t PAYLOAD_64BIT_LIMIT = 0x7FFFFFFFFFFFFFFF; // 2^63
static const unsigned int BASIC_HEADER_LENGTH = 2;
static const unsigned int MAX_HEADER_LENGTH = 14;
@@ -95,109 +92,552 @@ public:
static const uint64_t max_payload_size = 100000000; // 100MB
// create an empty frame for writing into
frame() : m_gen(m_rng,
boost::random::uniform_int_distribution<>(INT32_MIN,INT32_MAX)),m_degraded(false) {
parser(rng_policy& rng) : m_rng(rng) {
reset();
}
uint8_t get_state() const;
uint64_t get_bytes_needed() const;
void reset();
bool ready() const {
return (m_state == STATE_READY);
}
uint64_t get_bytes_needed() const {
return m_bytes_needed;
}
void reset() {
m_state = STATE_BASIC_HEADER;
m_bytes_needed = BASIC_HEADER_LENGTH;
m_degraded = false;
m_payload.empty();
memset(m_header,0,MAX_HEADER_LENGTH);
}
void consume(std::istream &s);
// Method invariant: One of the following must always be true even in the case
// of exceptions.
// - m_bytes_needed > 0
// - m-state = STATE_READY
void consume(std::istream &s) {
try {
switch (m_state) {
case STATE_BASIC_HEADER:
s.read(&m_header[BASIC_HEADER_LENGTH-m_bytes_needed],m_bytes_needed);
m_bytes_needed -= s.gcount();
if (m_bytes_needed == 0) {
process_basic_header();
validate_basic_header();
if (m_bytes_needed > 0) {
m_state = STATE_EXTENDED_HEADER;
} else {
process_extended_header();
if (m_bytes_needed == 0) {
m_state = STATE_READY;
process_payload();
} else {
m_state = STATE_PAYLOAD;
}
}
}
break;
case STATE_EXTENDED_HEADER:
s.read(&m_header[get_header_len()-m_bytes_needed],m_bytes_needed);
m_bytes_needed -= s.gcount();
if (m_bytes_needed == 0) {
process_extended_header();
if (m_bytes_needed == 0) {
m_state = STATE_READY;
process_payload();
} else {
m_state = STATE_PAYLOAD;
}
}
break;
case STATE_PAYLOAD:
s.read(reinterpret_cast<char *>(&m_payload[m_payload.size()-m_bytes_needed]),
m_bytes_needed);
m_bytes_needed -= s.gcount();
if (m_bytes_needed == 0) {
m_state = STATE_READY;
process_payload();
}
break;
case STATE_RECOVERY:
// Recovery state discards all bytes that are not the first byte
// of a close frame.
do {
s.read(reinterpret_cast<char *>(&m_header[0]),1);
//std::cout << std::hex << int(static_cast<unsigned char>(m_header[0])) << " ";
if (int(static_cast<unsigned char>(m_header[0])) == 0x88) {
//(BPB0_FIN && CONNECTION_CLOSE)
m_bytes_needed--;
m_state = STATE_BASIC_HEADER;
break;
}
} while (s.gcount() > 0);
//std::cout << std::endl;
break;
default:
break;
}
/*if (s.gcount() == 0) {
throw frame_error("consume read zero bytes",FERR_FATAL_SESSION_ERROR);
}*/
} catch (const exception& e) {
// After this point all non-close frames must be considered garbage,
// including the current one. Reset it and put the reading frame into
// a recovery state.
if (m_degraded == true) {
throw exception("An error occurred while trying to gracefully recover from a less serious frame error.",FERR_FATAL_SESSION_ERROR);
} else {
reset();
m_state = STATE_RECOVERY;
m_degraded = true;
throw e;
}
}
}
// get pointers to underlying buffers
char* get_header();
char* get_extended_header();
unsigned int get_header_len() const;
char* get_header() {
return m_header;
}
char* get_extended_header() {
return m_header+BASIC_HEADER_LENGTH;
}
unsigned int get_header_len() const {
unsigned int temp = 2;
if (get_masked()) {
temp += 4;
}
if (get_basic_size() == 126) {
temp += 2;
} else if (get_basic_size() == 127) {
temp += 8;
}
return temp;
}
char* get_masking_key();
char* get_masking_key() {
if (m_state != STATE_READY) {
throw exception("attempted to get masking_key before reading full header");
}
return m_header[get_header_len()-4];
}
// get and set header bits
bool get_fin() const;
void set_fin(bool fin);
bool get_fin() const {
return ((m_header[0] & BPB0_FIN) == BPB0_FIN);
}
void set_fin(bool fin) {
if (fin) {
m_header[0] |= BPB0_FIN;
} else {
m_header[0] &= (0xFF ^ BPB0_FIN);
}
}
bool get_rsv1() const;
void set_rsv1(bool b);
bool get_rsv1() const {
return ((m_header[0] & BPB0_RSV1) == BPB0_RSV1);
}
void set_rsv1(bool b) {
if (b) {
m_header[0] |= BPB0_RSV1;
} else {
m_header[0] &= (0xFF ^ BPB0_RSV1);
}
}
bool get_rsv2() const;
void set_rsv2(bool b);
bool get_rsv2() const {
return ((m_header[0] & BPB0_RSV2) == BPB0_RSV2);
}
void set_rsv2(bool b) {
if (b) {
m_header[0] |= BPB0_RSV2;
} else {
m_header[0] &= (0xFF ^ BPB0_RSV2);
}
}
bool get_rsv3() const;
void set_rsv3(bool b);
bool get_rsv3() const {
return ((m_header[0] & BPB0_RSV3) == BPB0_RSV3);
}
void set_rsv3(bool b) {
if (b) {
m_header[0] |= BPB0_RSV3;
} else {
m_header[0] &= (0xFF ^ BPB0_RSV3);
}
}
opcode get_opcode() const;
void set_opcode(opcode op);
opcode::value get_opcode() const {
return frame::opcode::value(m_header[0] & BPB0_OPCODE);
}
void set_opcode(opcode::value op) {
if (frame::opcode::reserved(op)) {
throw exception("reserved opcode",FERR_PROTOCOL_VIOLATION);
}
if (frame::opcode::invalid(op)) {
throw exception("invalid opcode",FERR_PROTOCOL_VIOLATION);
}
if (get_basic_size() > BASIC_PAYLOAD_LIMIT && is_control()) {
throw exception("control frames can't have large payloads",FERR_PROTOCOL_VIOLATION);
}
m_header[0] &= (0xFF ^ BPB0_OPCODE); // clear op bits
m_header[0] |= op; // set op bits
}
bool get_masked() const;
void set_masked(bool masked);
bool get_masked() const {
return ((m_header[1] & BPB1_MASK) == BPB1_MASK);
}
void set_masked(bool masked) {
if (masked) {
m_header[1] |= BPB1_MASK;
generate_masking_key();
} else {
m_header[1] &= (0xFF ^ BPB1_MASK);
clear_masking_key();
}
}
uint8_t get_basic_size() const;
size_t get_payload_size() const;
uint8_t get_basic_size() const {
return m_header[1] & BPB1_PAYLOAD;
}
size_t get_payload_size() const {
if (m_state != STATE_READY && m_state != STATE_PAYLOAD) {
// problem
throw exception("attempted to get payload size before reading full header");
}
return m_payload.size();
}
uint16_t get_close_status() const;
std::string get_close_msg() const;
close::status::value get_close_status() const {
if (get_payload_size() == 0) {
return close::status::NO_STATUS;
} else if (get_payload_size() >= 2) {
char val[2];
val[0] = m_payload[0];
val[1] = m_payload[1];
uint16_t code = ntohs(*(reinterpret_cast<uint16_t*>(&val[0])));
return close::status::value(code);
} else {
return close::status::PROTOCOL_ERROR;
}
}
std::string get_close_msg() const {
if (get_payload_size() > 2) {
uint32_t state = utf8_validator::UTF8_ACCEPT;
uint32_t codep = 0;
validate_utf8(&state,&codep,2);
if (state != utf8_validator::UTF8_ACCEPT) {
throw exception("Invalid UTF-8 Data",
frame::FERR_PAYLOAD_VIOLATION);
}
return std::string(m_payload.begin()+2,m_payload.end());
} else {
return std::string();
}
}
std::vector<unsigned char> &get_payload();
std::vector<unsigned char> &get_payload() {
return m_payload;
}
void set_payload(const std::vector<unsigned char> source);
void set_payload(const std::string source);
void set_payload_helper(size_t s);
void set_payload(const std::vector<unsigned char> source) {
set_payload_helper(source.size());
std::copy(source.begin(),source.end(),m_payload.begin());
}
void set_payload(const std::string source) {
set_payload_helper(source.size());
std::copy(source.begin(),source.end(),m_payload.begin());
}
void set_payload_helper(size_t s) {
if (s > max_payload_size) {
throw exception("requested payload is over implimentation defined limit",FERR_MSG_TOO_BIG);
}
// limits imposed by the websocket spec
if (s > limits::BASIC_PAYLOAD_SIZE &&
is_control()) {
throw exception("control frames can't have large payloads",FERR_PROTOCOL_VIOLATION);
}
if (s <= limits::PAYLOAD_SIZE_BASIC) {
m_header[1] = s;
} else if (s <= limits::PAYLOAD_SIZE_EXTENDED) {
m_header[1] = BASIC_PAYLOAD_16BIT_CODE;
// this reinterprets the second pair of bytes in m_header as a
// 16 bit int and writes the payload size there as an integer
// in network byte order
*reinterpret_cast<uint16_t*>(&m_header[BASIC_HEADER_LENGTH]) = htons(s);
} else if (s <= limits::PAYLOAD_SIZE_JUMBO) {
m_header[1] = BASIC_PAYLOAD_64BIT_CODE;
*reinterpret_cast<uint64_t*>(&m_header[BASIC_HEADER_LENGTH]) = htonll(s);
} else {
throw exception("payload size limit is 63 bits",FERR_PROTOCOL_VIOLATION);
}
m_payload.resize(s);
}
void set_status(uint16_t status,const std::string message = "");
void set_status(uint16_t status,const std::string message = "") {
// check for valid statuses
if (close::status::invalid(status)) {
std::stringstream err;
err << "Status code " << status << " is invalid";
throw exception(err.str());
}
if (close::status::reserved(status)) {
std::stringstream err;
err << "Status code " << status << " is reserved";
throw exception(err.str());
}
m_payload.resize(2+message.size());
char val[2];
*reinterpret_cast<uint16_t*>(&val[0]) = htons(status);
m_header[1] = message.size()+2;
m_payload[0] = val[0];
m_payload[1] = val[1];
std::copy(message.begin(),message.end(),m_payload.begin()+2);
}
bool is_control() const;
bool is_control() const {
return (frame::opcode::is_control(get_opcode()));
}
std::string print_frame() const;
std::string print_frame() const {
std::stringstream f;
unsigned int len = get_header_len();
f << "frame: ";
// print header
for (unsigned int i = 0; i < len; i++) {
f << std::hex << (unsigned short)m_header[i] << " ";
}
// print message
if (m_payload.size() > 50) {
f << "[payload of " << m_payload.size() << " bytes]";
} else {
std::vector<unsigned char>::const_iterator it;
for (it = m_payload.begin(); it != m_payload.end(); it++) {
f << *it;
}
}
return f.str();
}
// reads basic header, sets and returns m_header_bits_needed
void process_basic_header();
void process_extended_header();
void process_payload();
void process_payload2(); // experiment with more efficient masking code.
void process_basic_header() {
m_bytes_needed = get_header_len() - BASIC_HEADER_LENGTH;
}
void process_extended_header() {
uint8_t s = get_basic_size();
uint64_t payload_size;
int mask_index = BASIC_HEADER_LENGTH;
if (s <= BASIC_PAYLOAD_LIMIT) {
payload_size = s;
} else if (s == BASIC_PAYLOAD_16BIT_CODE) {
// reinterpret the second two bytes as a 16 bit integer in network
// byte order. Convert to host byte order and store locally.
payload_size = ntohs(*(
reinterpret_cast<uint16_t*>(&m_header[BASIC_HEADER_LENGTH])
));
if (payload_size < s) {
std::stringstream err;
err << "payload length not minimally encoded. Using 16 bit form for payload size: " << payload_size;
m_bytes_needed = payload_size;
throw exception(err.str(),FERR_PROTOCOL_VIOLATION);
}
mask_index += 2;
} else if (s == BASIC_PAYLOAD_64BIT_CODE) {
// reinterpret the second eight bytes as a 64 bit integer in
// network byte order. Convert to host byte order and store.
payload_size = ntohll(*(
reinterpret_cast<uint64_t*>(&m_header[BASIC_HEADER_LENGTH])
));
if (payload_size <= PAYLOAD_16BIT_LIMIT) {
m_bytes_needed = payload_size;
throw exception("payload length not minimally encoded",
FERR_PROTOCOL_VIOLATION);
}
mask_index += 8;
} else {
// shouldn't be here
throw exception("invalid get_basic_size in process_extended_header");
}
if (get_masked() == 0) {
clear_masking_key();
} else {
// TODO: use this copy line (needs testing)
// std::copy(m_header[mask_index],m_header[mask_index+4],m_masking_key);
m_masking_key[0] = m_header[mask_index+0];
m_masking_key[1] = m_header[mask_index+1];
m_masking_key[2] = m_header[mask_index+2];
m_masking_key[3] = m_header[mask_index+3];
}
if (payload_size > max_payload_size) {
// TODO: frame/message size limits
throw server_error("got frame with payload greater than maximum frame buffer size.");
}
m_payload.resize(payload_size);
m_bytes_needed = payload_size;
}
void process_payload() {
if (get_masked()) {
char *masking_key = &m_header[get_header_len()-4];
for (uint64_t i = 0; i < m_payload.size(); i++) {
m_payload[i] = (m_payload[i] ^ masking_key[i%4]);
}
}
}
void validate_utf8(uint32_t* state,uint32_t* codep,size_t offset = 0) const;
void validate_basic_header() const;
// experiment with more efficient masking code.
void process_payload2() {
// unmask payload one byte at a time
//uint64_t key = (*((uint32_t*)m_masking_key;)) << 32;
//key += *((uint32_t*)m_masking_key);
// might need to switch byte order
uint32_t key = *((uint32_t*)m_masking_key);
// 4
uint64_t i = 0;
uint64_t s = (m_payload.size() / 4);
std::cout << "s: " << s << std::endl;
// chunks of 4
for (i = 0; i < s; i+=4) {
((uint32_t*)(&m_payload[0]))[i] = (((uint32_t*)(&m_payload[0]))[i] ^ key);
}
// finish the last few
for (i = s; i < m_payload.size(); i++) {
m_payload[i] = (m_payload[i] ^ m_masking_key[i%4]);
}
}
void generate_masking_key();
void clear_masking_key();
void validate_utf8(uint32_t* state,uint32_t* codep,size_t offset = 0) const {
for (size_t i = offset; i < m_payload.size(); i++) {
using utf8_validator::decode;
if (decode(state,codep,m_payload[i]) == utf8_validator::UTF8_REJECT) {
throw frame_error("Invalid UTF-8 Data",FERR_PAYLOAD_VIOLATION);
}
}
}
void validate_basic_header() const {
// check for control frame size
if (is_control() && get_basic_size() > BASIC_PAYLOAD_LIMIT) {
throw frame_error("Control Frame is too large",FERR_PROTOCOL_VIOLATION);
}
// check for reserved bits
if (get_rsv1() || get_rsv2() || get_rsv3()) {
throw frame_error("Reserved bit used",FERR_PROTOCOL_VIOLATION);
}
// check for reserved opcodes
if (opcode::reserved(get_opcode())) {
throw frame_error("Reserved opcode used",FERR_PROTOCOL_VIOLATION);
}
// check for fragmented control message
if (is_control() && !get_fin()) {
throw frame_error("Fragmented control message",FERR_PROTOCOL_VIOLATION);
}
}
void generate_masking_key() {
*(reinterpret_cast<int32_t *>(&m_header[get_header_len()-4])) = m_rng.gen();
}
void clear_masking_key() {
// this is a no-op as clearing the mask bit also changes the get_header_len
// method to not include these byte ranges. Whenever the masking bit is re-
// set a new key is generated anyways.
}
private:
static const uint8_t STATE_BASIC_HEADER = 1;
static const uint8_t STATE_EXTENDED_HEADER = 2;
static const uint8_t STATE_PAYLOAD = 3;
static const uint8_t STATE_READY = 4;
static const uint8_t STATE_RECOVERY = 5;
uint8_t m_state;
uint64_t m_bytes_needed;
bool m_degraded;
char m_header[MAX_HEADER_LENGTH];
std::vector<unsigned char> m_payload;
char m_masking_key[4];
boost::random::random_device m_rng;
boost::random::variate_generator<boost::random::random_device&,
boost::random::uniform_int_distribution<> >
m_gen;
rng_policy& m_rng;
};
// Exception classes
class frame_error : public std::exception {
class exception : public std::exception {
public:
frame_error(const std::string& msg,
uint16_t code = frame::FERR_FATAL_SESSION_ERROR)
exception(const std::string& msg,
frame::error::value code = frame::error::FATAL_SESSION_ERROR)
: m_msg(msg),m_code(code) {}
~frame_error() throw() {}
exception() throw() {}
virtual const char* what() const throw() {
return m_msg.c_str();
}
uint16_t code() const throw() {
frame::error::value code() const throw() {
return m_code;
}
std::string m_msg;
uint16_t m_code;
frame::error::value m_code;
};
}
}
#endif // WEBSOCKET_FRAME_HPP

View File

@@ -25,177 +25,3 @@
*
*/
#include "websocket_server.hpp"
#include <boost/bind.hpp>
#include <boost/date_time/posix_time/posix_time.hpp>
#include <boost/date_time/posix_time/posix_time.hpp>
#include <iostream>
using websocketpp::server;
server::server(boost::asio::io_service& io_service,
const tcp::endpoint& endpoint,
websocketpp::connection_handler_ptr defc)
: m_elog_level(LOG_ALL),
m_alog_level(ALOG_ALL),
m_max_message_size(DEFAULT_MAX_MESSAGE_SIZE),
m_io_service(io_service),
m_acceptor(io_service, endpoint),
m_def_con_handler(defc),
m_desc("websocketpp::server") {
m_desc.add_options()
("help", "produce help message")
("host,h",po::value<std::vector<std::string> >()->multitoken()->composing(), "hostnames to listen on")
("port,p",po::value<int>(), "port to listen on")
;
}
void server::parse_command_line(int ac, char* av[]) {
po::store(po::parse_command_line(ac,av, m_desc),m_vm);
po::notify(m_vm);
if (m_vm.count("help") ) {
std::cout << m_desc << std::endl;
}
//m_vm["host"].as<std::string>();
const std::vector< std::string > &foo = m_vm["host"].as< std::vector<std::string> >();
for (int i = 0; i < foo.size(); i++) {
std::cout << foo[i] << std::endl;
}
//std::cout << m_vm["host"].as< std::vector<std::string> >() << std::endl;
}
void server::add_host(std::string host) {
m_hosts.insert(host);
}
void server::remove_host(std::string host) {
m_hosts.erase(host);
}
void server::set_max_message_size(uint64_t val) {
if (val > frame::PAYLOAD_64BIT_LIMIT) {
std::stringstream err;
err << "Invalid maximum message size: " << val;
// TODO: Figure out what the ideal error behavior for this method.
// Options:
// Throw exception
// Log error and set value to maximum allowed
// Log error and leave value at whatever it was before
log(err.str(),LOG_WARN);
//throw server_error(err.str());
}
m_max_message_size = val;
}
bool server::test_elog_level(uint16_t level) {
return (level >= m_elog_level);
}
void server::set_elog_level(uint16_t level) {
std::stringstream msg;
msg << "Error logging level changing from "
<< m_elog_level << " to " << level;
log(msg.str(),LOG_INFO);
m_elog_level = level;
}
bool server::test_alog_level(uint16_t level) {
return ((level & m_alog_level) != 0);
}
void server::set_alog_level(uint16_t level) {
if (test_alog_level(level)) {
return;
}
std::stringstream msg;
msg << "Access logging level " << level << " being set";
access_log(msg.str(),ALOG_INFO);
m_alog_level |= level;
}
void server::unset_alog_level(uint16_t level) {
if (!test_alog_level(level)) {
return;
}
std::stringstream msg;
msg << "Access logging level " << level << " being unset";
access_log(msg.str(),ALOG_INFO);
m_alog_level &= ~level;
}
bool server::validate_host(std::string host) {
if (m_hosts.find(host) == m_hosts.end()) {
return false;
}
return true;
}
bool server::validate_message_size(uint64_t val) {
if (val > m_max_message_size) {
return false;
}
return true;
}
void server::log(std::string msg,uint16_t level) {
if (!test_elog_level(level)) {
return;
}
std::cerr << "[Error Log] "
<< boost::posix_time::to_iso_extended_string(
boost::posix_time::second_clock::local_time())
<< " " << msg << std::endl;
}
void server::access_log(std::string msg,uint16_t level) {
if (!test_alog_level(level)) {
return;
}
std::cout << "[Access Log] "
<< boost::posix_time::to_iso_extended_string(
boost::posix_time::second_clock::local_time())
<< " " << msg << std::endl;
}
void server::start_accept() {
// TODO: sanity check whether the session buffer size bound could be reduced
server_session_ptr new_session(new server_session(shared_from_this(),
m_io_service,
m_def_con_handler,
m_max_message_size*2));
m_acceptor.async_accept(
new_session->socket(),
boost::bind(
&server::handle_accept,
this,
new_session,
boost::asio::placeholders::error
)
);
}
void server::handle_accept(websocketpp::server_session_ptr session,
const boost::system::error_code& error) {
if (!error) {
session->on_connect();
} else {
std::stringstream err;
err << "Error accepting socket connection: " << error;
log(err.str(),LOG_ERROR);
throw server_error(err.str());
}
this->start_accept();
}

View File

@@ -30,20 +30,25 @@
#include <boost/asio.hpp>
#include <boost/shared_ptr.hpp>
#include <boost/enable_shared_from_this.hpp>
#include <boost/program_options.hpp>
namespace po = boost::program_options;
#include <set>
namespace websocketpp {
class server;
typedef boost::shared_ptr<server> server_ptr;
//class server;
//typedef boost::shared_ptr<server> server_ptr;
}
#include "websocketpp.hpp"
#include "websocket_server_session.hpp"
#include "websocket_connection_handler.hpp"
#include <boost/date_time/posix_time/posix_time.hpp>
#include "policy/rng/blank_rng.hpp"
using boost::asio::ip::tcp;
namespace websocketpp {
@@ -61,54 +66,198 @@ private:
std::string m_msg;
};
class server : public boost::enable_shared_from_this<server> {
template <class rng_policy = blank_rng>
class server : public boost::enable_shared_from_this< server<rng_policy> > {
public:
server(boost::asio::io_service& io_service,
typedef boost::shared_ptr< server<rng_policy> > ptr;
server<rng_policy>(boost::asio::io_service& io_service,
const tcp::endpoint& endpoint,
connection_handler_ptr defc);
connection_handler_ptr defc)
: m_elog_level(LOG_ALL),
m_alog_level(ALOG_ALL),
m_max_message_size(DEFAULT_MAX_MESSAGE_SIZE),
m_io_service(io_service),
m_acceptor(io_service, endpoint),
m_def_con_handler(defc),
m_desc("websocketpp::server")
{
m_desc.add_options()
("help", "produce help message")
("host,h",po::value<std::vector<std::string> >()->multitoken()->composing(), "hostnames to listen on")
("port,p",po::value<int>(), "port to listen on")
;
}
// creates a new session object and connects the next websocket
// connection to it.
void start_accept();
void start_accept() {
// TODO: sanity check whether the session buffer size bound could be reduced
server_session_ptr new_session(
new server_session(
shared_from_this(),
m_io_service,
m_def_con_handler,
m_max_message_size*2
)
);
m_acceptor.async_accept(
new_session->socket(),
boost::bind(
&server::handle_accept,
shared_from_this(),
new_session,
boost::asio::placeholders::error
)
);
}
// INTERFACE FOR LOCAL APPLICATIONS
// Add or remove a host string (host:port) to the list of acceptable
// hosts to accept websocket connections from. Additions/deletions here
// only affect new connections.
void add_host(std::string host);
void remove_host(std::string host);
void add_host(std::string host) {
m_hosts.insert(host);
}
void remove_host(std::string host) {
m_hosts.erase(host);
}
void set_max_message_size(uint64_t val);
void set_max_message_size(uint64_t val) {
if (val > frame::limits::JUMBO_PAYLOAD_SIZE) {
std::stringstream err;
err << "Invalid maximum message size: " << val;
// TODO: Figure out what the ideal error behavior for this method.
// Options:
// Throw exception
// Log error and set value to maximum allowed
// Log error and leave value at whatever it was before
log(err.str(),LOG_WARN);
//throw server_error(err.str());
}
m_max_message_size = val;
}
// Test methods determine if a message of the given level should be
// written. elog shows all values above the level set. alog shows only
// the values explicitly set.
bool test_elog_level(uint16_t level);
void set_elog_level(uint16_t level);
bool test_elog_level(uint16_t level) {
return (level >= m_elog_level);
}
void set_elog_level(uint16_t level) {
std::stringstream msg;
msg << "Error logging level changing from "
<< m_elog_level << " to " << level;
log(msg.str(),LOG_INFO);
m_elog_level = level;
}
bool test_alog_level(uint16_t level);
void set_alog_level(uint16_t level);
void unset_alog_level(uint16_t level);
bool test_alog_level(uint16_t level) {
return ((level & m_alog_level) != 0);
}
void set_alog_level(uint16_t level) {
if (test_alog_level(level)) {
return;
}
std::stringstream msg;
msg << "Access logging level " << level << " being set";
access_log(msg.str(),ALOG_INFO);
m_alog_level |= level;
}
void unset_alog_level(uint16_t level) {
if (!test_alog_level(level)) {
return;
}
std::stringstream msg;
msg << "Access logging level " << level << " being unset";
access_log(msg.str(),ALOG_INFO);
m_alog_level &= ~level;
}
void parse_command_line(int ac, char* av[]);
void parse_command_line(int ac, char* av[]) {
po::store(po::parse_command_line(ac,av, m_desc),m_vm);
po::notify(m_vm);
if (m_vm.count("help") ) {
std::cout << m_desc << std::endl;
}
//m_vm["host"].as<std::string>();
const std::vector< std::string > &foo = m_vm["host"].as< std::vector<std::string> >();
for (int i = 0; i < foo.size(); i++) {
std::cout << foo[i] << std::endl;
}
//std::cout << m_vm["host"].as< std::vector<std::string> >() << std::endl;
}
// INTERFACE FOR SESSIONS
rng_policy& get_rng() {
return &m_rng;
}
// Check if this server will respond to this host.
bool validate_host(std::string host);
bool validate_host(std::string host) {
if (m_hosts.find(host) == m_hosts.end()) {
return false;
}
return true;
}
// Check if message size is within server's acceptable parameters
bool validate_message_size(uint64_t val);
bool validate_message_size(uint64_t val) {
if (val > m_max_message_size) {
return false;
}
return true;
}
// write to the server's logs
void log(std::string msg,uint16_t level = LOG_ERROR);
void access_log(std::string msg,uint16_t level);
void log(std::string msg,uint16_t level = LOG_ERROR) {
if (!test_elog_level(level)) {
return;
}
std::cerr << "[Error Log] "
<< boost::posix_time::to_iso_extended_string(
boost::posix_time::second_clock::local_time())
<< " " << msg << std::endl;
}
void access_log(std::string msg,uint16_t level) {
if (!test_alog_level(level)) {
return;
}
std::cout << "[Access Log] "
<< boost::posix_time::to_iso_extended_string(
boost::posix_time::second_clock::local_time())
<< " " << msg << std::endl;
}
private:
// if no errors starts the session's read loop and returns to the
// start_accept phase.
void handle_accept(server_session_ptr session,
const boost::system::error_code& error);
const boost::system::error_code& error)
{
if (!error) {
session->on_connect();
} else {
std::stringstream err;
err << "Error accepting socket connection: " << error;
log(err.str(),LOG_ERROR);
throw server_error(err.str());
}
this->start_accept();
}
private:
uint16_t m_elog_level;
@@ -120,6 +269,8 @@ private:
tcp::acceptor m_acceptor;
connection_handler_ptr m_def_con_handler;
rng_policy m_rng;
po::options_description m_desc;
po::variables_map m_vm;
};

View File

@@ -24,715 +24,3 @@
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
*/
#include "websocketpp.hpp"
#include "websocket_session.hpp"
#include "websocket_frame.hpp"
#include "utf8_validator/utf8_validator.hpp"
#include <boost/asio.hpp>
#include <boost/bind.hpp>
#include <boost/algorithm/string.hpp>
#include <cstdlib>
#include <iostream>
#include <sstream>
#include <string>
using websocketpp::session;
session::session (boost::asio::io_service& io_service,
websocketpp::connection_handler_ptr defc,
uint64_t buf_size)
: m_state(STATE_CONNECTING),
m_writing(false),
m_local_close_code(CLOSE_STATUS_NO_STATUS),
m_remote_close_code(CLOSE_STATUS_NO_STATUS),
m_was_clean(false),
m_closed_by_me(false),
m_dropped_by_me(false),
m_socket(io_service),
m_io_service(io_service),
m_local_interface(defc),
m_timer(io_service,boost::posix_time::seconds(0)),
m_buf(buf_size), // maximum buffered (unconsumed) bytes from network
m_utf8_state(utf8_validator::UTF8_ACCEPT),
m_utf8_codepoint(0) {}
tcp::socket& session::socket() {
return m_socket;
}
boost::asio::io_service& session::io_service() {
return m_io_service;
}
void session::set_handler(websocketpp::connection_handler_ptr new_con) {
if (m_local_interface) {
// TODO: this should be another method and not reusing onclose
//m_local_interface->disconnect(shared_from_this(),4000,"Setting new connection handler");
}
m_local_interface = new_con;
m_local_interface->on_open(shared_from_this());
}
const std::string& session::get_subprotocol() const {
if (m_state == STATE_CONNECTING) {
log("Subprotocol is not avaliable before the handshake has completed.",LOG_WARN);
throw server_error("Subprotocol is not avaliable before the handshake has completed.");
}
return m_server_subprotocol;
}
const std::string& session::get_resource() const {
return m_resource;
}
const std::string& session::get_origin() const {
return m_client_origin;
}
std::string session::get_client_header(const std::string& key) const {
return get_header(key,m_client_headers);
}
std::string session::get_server_header(const std::string& key) const {
return get_header(key,m_server_headers);
}
std::string session::get_header(const std::string& key,
const websocketpp::header_list& list) const {
header_list::const_iterator h = list.find(key);
if (h == list.end()) {
return "";
} else {
return h->second;
}
}
const std::vector<std::string>& session::get_extensions() const {
return m_server_extensions;
}
unsigned int session::get_version() const {
return m_version;
}
void session::send(const std::string &msg) {
if (m_state != STATE_OPEN) {
log("Tried to send a message from a session that wasn't open",LOG_WARN);
return;
}
m_write_frame.set_fin(true);
m_write_frame.set_opcode(frame::TEXT_FRAME);
m_write_frame.set_payload(msg);
write_frame();
}
void session::send(const std::vector<unsigned char> &data) {
if (m_state != STATE_OPEN) {
log("Tried to send a message from a session that wasn't open",LOG_WARN);
return;
}
m_write_frame.set_fin(true);
m_write_frame.set_opcode(frame::BINARY_FRAME);
m_write_frame.set_payload(data);
write_frame();
}
// end user interface to close the connection
void session::close(uint16_t status,const std::string& msg) {
validate_app_close_status(status);
send_close(status,msg);
}
// TODO: clean this up, needs to be broken out into more specific methods
// This method initiates a clean disconnect with the given status code and reason
// it logs an error and is ignored if it is called from a state other than OPEN
// called by process_close when an initiate close method is received.
void session::send_close(uint16_t status,const std::string &message) {
if (m_state != STATE_OPEN) {
log("Tried to disconnect a session that wasn't open",LOG_WARN);
return;
}
m_state = STATE_CLOSING;
m_timer.expires_from_now(boost::posix_time::milliseconds(1000));
m_timer.async_wait(
boost::bind(
&session::handle_close_expired,
shared_from_this(),
boost::asio::placeholders::error
)
);
m_local_close_code = status;
m_local_close_msg = message;
m_write_frame.set_fin(true);
m_write_frame.set_opcode(frame::CONNECTION_CLOSE);
// echo close value unless there is a good reason not to.
if (status == CLOSE_STATUS_NO_STATUS) {
m_write_frame.set_status(CLOSE_STATUS_NORMAL,"");
} else if (status == CLOSE_STATUS_ABNORMAL_CLOSE) {
// Internal implimentation error. There is no good close code for this.
m_write_frame.set_status(CLOSE_STATUS_POLICY_VIOLATION,message);
} else if (close::status::invalid(status)) {
m_write_frame.set_status(close::status::PROTOCOL_ERROR,"Status code is invalid");
} else if (close::status::reserved(status)) {
m_write_frame.set_status(close::status::PROTOCOL_ERROR,"Status code is reserved");
} else {
m_write_frame.set_status(status,message);
}
write_frame();
}
void session::ping(const std::string &msg) {
if (m_state != STATE_OPEN) {
log("Tried to send a ping from a session that wasn't open",LOG_WARN);
return;
}
m_write_frame.set_fin(true);
m_write_frame.set_opcode(frame::PING);
m_write_frame.set_payload(msg);
write_frame();
}
void session::pong(const std::string &msg) {
if (m_state != STATE_OPEN) {
log("Tried to send a pong from a session that wasn't open",LOG_WARN);
return;
}
m_write_frame.set_fin(true);
m_write_frame.set_opcode(frame::PONG);
m_write_frame.set_payload(msg);
write_frame();
}
void session::read_frame() {
// the initial read in the handshake may have read in the first frame.
// handle it (if it exists) before we read anything else.
handle_read_frame(boost::system::error_code());
}
// handle_read_frame reads and processes all socket read commands for the
// session by consuming the read buffer and then starting an async read with
// itself as the callback. The connection is over when this method returns.
void session::handle_read_frame(const boost::system::error_code& error) {
if (m_state != STATE_OPEN && m_state != STATE_CLOSING) {
log("handle_read_frame called in invalid state",LOG_ERROR);
return;
}
if (error) {
if (error == boost::asio::error::eof) {
// if this is a case where we are expecting eof, return, else log & drop
log_error("Recieved EOF",error);
//drop_tcp(false);
//m_state = STATE_CLOSED;
} else if (error == boost::asio::error::operation_aborted) {
// some other part of our client called shutdown on our socket.
// This is usually due to a write error. Everything should have
// already been logged and dropped so we just return here
return;
} else {
log_error("Error reading frame",error);
//drop_tcp(false);
m_state = STATE_CLOSED;
}
}
std::istream s(&m_buf);
while (m_buf.size() > 0 && m_state != STATE_CLOSED) {
try {
if (m_read_frame.get_bytes_needed() == 0) {
throw frame_error("have bytes that no frame needs",frame::FERR_FATAL_SESSION_ERROR);
}
// Consume will read bytes from s
// will throw a frame_error on error.
std::stringstream err;
err << "consuming. have: " << m_buf.size() << " bytes. Need: " << m_read_frame.get_bytes_needed() << " state: " << (int)m_read_frame.get_state();
log(err.str(),LOG_DEBUG);
m_read_frame.consume(s);
err.str("");
err << "consume complete, " << m_buf.size() << " bytes left, " << m_read_frame.get_bytes_needed() << " still needed, state: " << (int)m_read_frame.get_state();
log(err.str(),LOG_DEBUG);
if (m_read_frame.get_state() == frame::STATE_READY) {
// process frame and reset frame state for the next frame.
// will throw a frame_error on error. May set m_state to CLOSED,
// if so no more frames should be processed.
err.str("");
err << "processing frame " << m_buf.size();
log(err.str(),LOG_DEBUG);
m_timer.cancel();
process_frame();
}
} catch (const frame_error& e) {
std::stringstream err;
err << "Caught frame exception: " << e.what();
access_log(e.what(),ALOG_FRAME);
log(err.str(),LOG_ERROR);
// if the exception happened while processing.
// TODO: this is not elegant, perhaps separate frame read vs process
// exceptions need to be used.
if (m_read_frame.get_state() == frame::STATE_READY) {
m_read_frame.reset();
}
// process different types of frame errors
//
if (e.code() == frame::FERR_PROTOCOL_VIOLATION) {
send_close(CLOSE_STATUS_PROTOCOL_ERROR, e.what());
} else if (e.code() == frame::FERR_PAYLOAD_VIOLATION) {
send_close(CLOSE_STATUS_INVALID_PAYLOAD, e.what());
} else if (e.code() == frame::FERR_INTERNAL_SERVER_ERROR) {
send_close(CLOSE_STATUS_ABNORMAL_CLOSE, e.what());
} else if (e.code() == frame::FERR_SOFT_SESSION_ERROR) {
// ignore and continue processing frames
continue;
} else {
// Fatal error, forcibly end connection immediately.
log("Dropping TCP due to unrecoverable exception",LOG_DEBUG);
drop_tcp(true);
}
break;
}
}
if (error == boost::asio::error::eof) {
m_state = STATE_CLOSED;
}
// we have read everything, check if we should read more
if ((m_state == STATE_OPEN || m_state == STATE_CLOSING) && m_read_frame.get_bytes_needed() > 0) {
std::stringstream msg;
msg << "starting async read for " << m_read_frame.get_bytes_needed() << " bytes.";
log(msg.str(),LOG_DEBUG);
// TODO: set a timer here in case we don't want to read forever.
// Ex: when the frame is in a degraded state.
boost::asio::async_read(
m_socket,
m_buf,
boost::asio::transfer_at_least(m_read_frame.get_bytes_needed()),
boost::bind(
&session::handle_read_frame,
shared_from_this(),
boost::asio::placeholders::error
)
);
} else if (m_state == STATE_CLOSED) {
log_close_result();
if (m_local_interface) {
// TODO: make sure close code/msg are properly set.
m_local_interface->on_close(shared_from_this());
}
m_timer.cancel();
} else {
log("handle_read_frame called in invalid state",LOG_ERROR);
}
}
void session::process_frame () {
log("process_frame",LOG_DEBUG);
if (m_state == STATE_OPEN) {
switch (m_read_frame.get_opcode()) {
case frame::CONTINUATION_FRAME:
process_continuation();
break;
case frame::TEXT_FRAME:
process_text();
break;
case frame::BINARY_FRAME:
process_binary();
break;
case frame::CONNECTION_CLOSE:
log("process_close",LOG_DEBUG);
process_close();
break;
case frame::PING:
process_ping();
break;
case frame::PONG:
process_pong();
break;
default:
throw frame_error("Invalid Opcode",
frame::FERR_PROTOCOL_VIOLATION);
break;
}
} else if (m_state == STATE_CLOSING) {
if (m_read_frame.get_opcode() == frame::CONNECTION_CLOSE) {
process_close();
} else {
// Ignore all other frames in closing state
log("ignoring this frame",LOG_DEBUG);
}
} else {
// Recieved message before or after connection was opened/closed
throw frame_error("process_frame called from invalid state");
}
m_read_frame.reset();
}
void session::handle_write_frame (const boost::system::error_code& error) {
if (error) {
log_error("Error writing frame data",error);
drop_tcp(false);
}
access_log("handle_write_frame complete",ALOG_FRAME);
m_writing = false;
}
void session::handle_timer_expired (const boost::system::error_code& error) {
if (error) {
if (error == boost::asio::error::operation_aborted) {
log("timer was aborted",LOG_DEBUG);
//drop_tcp(false);
} else {
log("timer ended with error",LOG_DEBUG);
}
return;
}
log("timer ended without error",LOG_DEBUG);
}
void session::handle_handshake_expired (const boost::system::error_code& error) {
if (error) {
if (error != boost::asio::error::operation_aborted) {
log("Unexpected handshake timer error.",LOG_DEBUG);
drop_tcp(true);
}
return;
}
log("Handshake timed out",LOG_DEBUG);
drop_tcp(true);
}
// The error timer is set when we want to give the other endpoint some time to
// do something but don't want to wait forever. There is a special error code
// that represents the timer being canceled by us (because the other endpoint
// responded in time. All other cases should assume that the other endpoint is
// irrepairibly broken and drop the TCP connection.
void session::handle_error_timer_expired (const boost::system::error_code& error) {
if (error) {
if (error == boost::asio::error::operation_aborted) {
log("error timer was aborted",LOG_DEBUG);
//drop_tcp(false);
} else {
log("error timer ended with error",LOG_DEBUG);
drop_tcp(true);
}
return;
}
log("error timer ended without error",LOG_DEBUG);
drop_tcp(true);
}
void session::handle_close_expired (const boost::system::error_code& error) {
if (error) {
if (error == boost::asio::error::operation_aborted) {
log("timer was aborted",LOG_DEBUG);
//drop_tcp(false);
} else {
log("Unexpected close timer error.",LOG_DEBUG);
drop_tcp(false);
}
return;
}
if (m_state != STATE_CLOSED) {
log("close timed out",LOG_DEBUG);
drop_tcp(false);
}
}
void session::process_ping() {
access_log("Ping",ALOG_MISC_CONTROL);
// TODO: on_ping
// send pong
m_write_frame.set_fin(true);
m_write_frame.set_opcode(frame::PONG);
m_write_frame.set_payload(m_read_frame.get_payload());
write_frame();
}
void session::process_pong() {
access_log("Pong",ALOG_MISC_CONTROL);
// TODO: on_pong
}
void session::process_text() {
// this will throw an exception if validation fails at any point
m_read_frame.validate_utf8(&m_utf8_state,&m_utf8_codepoint);
// otherwise, treat as binary
process_binary();
}
void session::process_binary() {
if (m_fragmented) {
throw frame_error("Got a new message before the previous was finished.",
frame::FERR_PROTOCOL_VIOLATION);
}
m_current_opcode = m_read_frame.get_opcode();
if (m_read_frame.get_fin()) {
deliver_message();
reset_message();
} else {
m_fragmented = true;
extract_payload();
}
}
void session::process_continuation() {
if (!m_fragmented) {
throw frame_error("Got a continuation frame without an outstanding message.",
frame::FERR_PROTOCOL_VIOLATION);
}
if (m_current_opcode == frame::TEXT_FRAME) {
// this will throw an exception if validation fails at any point
m_read_frame.validate_utf8(&m_utf8_state,&m_utf8_codepoint);
}
extract_payload();
// check if we are done
if (m_read_frame.get_fin()) {
deliver_message();
reset_message();
}
}
void session::process_close() {
m_remote_close_code = m_read_frame.get_close_status();
m_remote_close_msg = m_read_frame.get_close_msg();
if (m_state == STATE_OPEN) {
log("process_close sending ack",LOG_DEBUG);
// This is the case where the remote initiated the close.
m_closed_by_me = false;
// send acknowledgement
// check if the remote close code
if (m_remote_close_code >= close::status::RSV_START) {
}
send_close(m_remote_close_code,m_remote_close_msg);
} else if (m_state == STATE_CLOSING) {
log("process_close got ack",LOG_DEBUG);
// this is an ack of our close message
m_closed_by_me = true;
} else {
throw frame_error("process_closed called from wrong state");
}
m_was_clean = true;
m_state = STATE_CLOSED;
}
void session::deliver_message() {
if (!m_local_interface) {
return;
}
if (m_current_opcode == frame::BINARY_FRAME) {
//log("Dispatching Binary Message",LOG_DEBUG);
if (m_fragmented) {
m_local_interface->on_message(shared_from_this(),m_current_message);
} else {
m_local_interface->on_message(shared_from_this(),
m_read_frame.get_payload());
}
} else if (m_current_opcode == frame::TEXT_FRAME) {
std::string msg;
// make sure the finished frame is valid utf8
// the streaming validator checks for bad codepoints as it goes. It
// doesn't know where the end of the message is though, so we need to
// check here to make sure the final message ends on a valid codepoint.
if (m_utf8_state != utf8_validator::UTF8_ACCEPT) {
throw frame_error("Invalid UTF-8 Data",
frame::FERR_PAYLOAD_VIOLATION);
}
if (m_fragmented) {
msg.append(m_current_message.begin(),m_current_message.end());
} else {
msg.append(
m_read_frame.get_payload().begin(),
m_read_frame.get_payload().end()
);
}
//log("Dispatching Text Message",LOG_DEBUG);
m_local_interface->on_message(shared_from_this(),msg);
} else {
// Not sure if this should be a fatal error or not
std::stringstream err;
err << "Attempted to deliver a message of unsupported opcode " << m_current_opcode;
throw frame_error(err.str(),frame::FERR_SOFT_SESSION_ERROR);
}
}
void session::extract_payload() {
std::vector<unsigned char> &msg = m_read_frame.get_payload();
m_current_message.resize(m_current_message.size()+msg.size());
std::copy(msg.begin(),msg.end(),m_current_message.end()-msg.size());
}
void session::write_frame() {
if (!is_server()) {
m_write_frame.set_masked(true); // client must mask frames
}
m_write_frame.process_payload();
std::vector<boost::asio::mutable_buffer> data;
data.push_back(
boost::asio::buffer(
m_write_frame.get_header(),
m_write_frame.get_header_len()
)
);
data.push_back(
boost::asio::buffer(m_write_frame.get_payload())
);
log("Write Frame: "+m_write_frame.print_frame(),LOG_DEBUG);
m_writing = true;
boost::asio::async_write(
m_socket,
data,
boost::bind(
&session::handle_write_frame,
shared_from_this(),
boost::asio::placeholders::error
)
);
}
void session::reset_message() {
m_error = false;
m_fragmented = false;
m_current_message.clear();
m_utf8_state = utf8_validator::UTF8_ACCEPT;
m_utf8_codepoint = 0;
}
void session::log_close_result() {
std::stringstream msg;
msg << "[Connection " << this << "] "
<< (m_was_clean ? "Clean " : "Unclean ")
<< "close local:[" << m_local_close_code
<< (m_local_close_msg == "" ? "" : ","+m_local_close_msg)
<< "] remote:[" << m_remote_close_code
<< (m_remote_close_msg == "" ? "" : ","+m_remote_close_msg) << "]";
access_log(msg.str(),ALOG_DISCONNECT);
}
void session::log_open_result() {
std::stringstream msg;
msg << "[Connection " << this << "] "
<< m_socket.remote_endpoint()
<< " v" << m_version << " "
<< (get_client_header("User-Agent") == "" ? "NULL" : get_client_header("User-Agent"))
<< " " << m_resource << " " << m_server_http_code;
access_log(msg.str(),ALOG_HANDSHAKE);
}
// this is called when an async asio call encounters an error
void session::log_error(std::string msg,const boost::system::error_code& e) {
std::stringstream err;
err << "[Connection " << this << "] " << msg << " (" << e << ")";
log(err.str(),LOG_ERROR);
}
// validates status codes that the end application is allowed to use
bool session::validate_app_close_status(uint16_t status) {
if (status == CLOSE_STATUS_NORMAL) {
return true;
}
if (status >= 4000 && status < 5000) {
return true;
}
return false;
}
void session::drop_tcp(bool dropped_by_me) {
m_timer.cancel();
try {
if (m_socket.is_open()) {
m_socket.shutdown(tcp::socket::shutdown_both);
m_socket.close();
}
} catch (boost::system::system_error& e) {
if (e.code() == boost::asio::error::not_connected) {
// this means the socket was disconnected by the other side before
// we had a chance to. Ignore and continue.
} else {
throw e;
}
}
m_dropped_by_me = dropped_by_me;
m_state = STATE_CLOSED;
}

View File

@@ -77,6 +77,7 @@
#include <boost/asio.hpp>
#include <boost/bind.hpp>
#include <boost/algorithm/string.hpp>
#include <boost/shared_ptr.hpp>
#include <boost/enable_shared_from_this.hpp>
@@ -87,6 +88,7 @@
#include <arpa/inet.h>
#endif
#include <cstdlib>
#include <algorithm>
#include <exception>
#include <iostream>
@@ -109,6 +111,7 @@ namespace websocketpp {
#include "base64/base64.h"
#include "sha1/sha1.h"
#include "utf8_validator/utf8_validator.hpp"
using boost::asio::ip::tcp;
@@ -116,33 +119,54 @@ namespace websocketpp {
typedef std::map<std::string,std::string> header_list;
class session : public boost::enable_shared_from_this<session> {
template <template <class rng_policy> endpoint_policy>
class session : public boost::enable_shared_from_this<session<endpoint_policy<rng_policy> > > {
public:
typedef endpoint_policy<rng_policy> endpoint_t;
typedef boost::shared_ptr<session<endpoint_t> > ptr;
friend class handshake_error;
static const uint8_t STATE_CONNECTING = 0;
static const uint8_t STATE_OPEN = 1;
static const uint8_t STATE_CLOSING = 2;
static const uint8_t STATE_CLOSED = 3;
static const uint16_t CLOSE_STATUS_NORMAL = 1000;
static const uint16_t CLOSE_STATUS_GOING_AWAY = 1001;
static const uint16_t CLOSE_STATUS_PROTOCOL_ERROR = 1002;
static const uint16_t CLOSE_STATUS_UNSUPPORTED_DATA = 1003;
static const uint16_t CLOSE_STATUS_NO_STATUS = 1005;
static const uint16_t CLOSE_STATUS_ABNORMAL_CLOSE = 1006;
static const uint16_t CLOSE_STATUS_INVALID_PAYLOAD = 1007;
static const uint16_t CLOSE_STATUS_POLICY_VIOLATION = 1008;
static const uint16_t CLOSE_STATUS_MESSAGE_TOO_BIG = 1009;
static const uint16_t CLOSE_STATUS_EXTENSION_REQUIRE = 1010;
static const uint16_t CLOSE_STATUS_MAXIMUM = 1011;
namespace state {
enum value {
CONNECTING = 0,
OPEN = 1,
CLOSING = 2,
CLOSED = 3
};
}
session (boost::asio::io_service& io_service,
session (endpoint_t::ptr e,
boost::asio::io_service& io_service,
connection_handler_ptr defc,
uint64_t buf_size);
uint64_t buf_size)
: m_state(STATE_CONNECTING),
m_writing(false),
m_local_close_code(close::status::NO_STATUS),
m_remote_close_code(close::status::NO_STATUS),
m_was_clean(false),
m_closed_by_me(false),
m_dropped_by_me(false),
m_socket(io_service),
m_io_service(io_service),
m_local_interface(defc),
m_timer(io_service,boost::posix_time::seconds(0)),
m_buf(buf_size), // maximum buffered (unconsumed) bytes from network
m_utf8_state(utf8_validator::UTF8_ACCEPT),
m_utf8_codepoint(0),
m_read_frame(e->get_rng()),
m_write_frame(e->get_rng())
{
}
tcp::socket& socket();
boost::asio::io_service& io_service();
tcp::socket& socket() {
return m_socket;
}
boost::asio::io_service& io_service() {
return m_io_service;
}
/*** SERVER INTERFACE ***/
@@ -157,7 +181,14 @@ public:
// Example: a generic lobby handler could validate the handshake negotiate a
// sub protocol to talk to and then pass the connection off to a handler for
// that sub protocol.
void set_handler(connection_handler_ptr new_con);
void set_handler(connection_handler_ptr new_con) {
if (m_local_interface) {
// TODO: this should be another method and not reusing onclose
//m_local_interface->disconnect(shared_from_this(),4000,"Setting new connection handler");
}
m_local_interface = new_con;
m_local_interface->on_open(shared_from_this());
}
/*** HANDSHAKE INTERFACE ***/
@@ -168,24 +199,87 @@ public:
// returns the subprotocol that was negotiated during the opening handshake
// or the empty string if no subprotocol was requested.
const std::string& get_subprotocol() const;
const std::string& get_resource() const;
const std::string& get_origin() const;
std::string get_client_header(const std::string& key) const;
std::string get_server_header(const std::string& key) const;
const std::vector<std::string>& get_extensions() const;
unsigned int get_version() const;
const std::string& get_subprotocol() const {
if (m_state == state::CONNECTING) {
log("Subprotocol is not avaliable before the handshake has completed.",LOG_WARN);
throw server_error("Subprotocol is not avaliable before the handshake has completed.");
}
return m_server_subprotocol;
}
const std::string& get_resource() const {
return m_resource;
}
const std::string& get_origin() const {
return m_client_origin;
}
std::string get_client_header(const std::string& key) const {
return get_header(key,m_client_headers);
}
std::string get_server_header(const std::string& key) const {
return get_header(key,m_server_headers);
}
const std::vector<std::string>& get_extensions() const {
return m_server_extensions;
}
unsigned int get_version() const {
return m_version;
}
/*** SESSION INTERFACE ***/
// send basic frame types
void send(const std::string &msg); // text
void send(const std::vector<unsigned char> &data); // binary
void ping(const std::string &msg);
void pong(const std::string &msg);
void send(const std::string &msg) {
if (m_state != state::OPEN) {
log("Tried to send a message from a session that wasn't open",LOG_WARN);
return;
}
m_write_frame.set_fin(true);
m_write_frame.set_opcode(frame::opcode::TEXT);
m_write_frame.set_payload(msg);
write_frame();
}
void send(const std::vector<unsigned char> &data) {
if (m_state != state::OPEN) {
log("Tried to send a message from a session that wasn't open",LOG_WARN);
return;
}
m_write_frame.set_fin(true);
m_write_frame.set_opcode(frame::opcode::BINARY);
m_write_frame.set_payload(data);
write_frame();
}
void ping(const std::string &msg) {
if (m_state != state::OPEN) {
log("Tried to send a ping from a session that wasn't open",LOG_WARN);
return;
}
m_write_frame.set_fin(true);
m_write_frame.set_opcode(frame::opcode::PING);
m_write_frame.set_payload(msg);
write_frame();
}
void pong(const std::string &msg) {
if (m_state != state::OPEN) {
log("Tried to send a pong from a session that wasn't open",LOG_WARN);
return;
}
m_write_frame.set_fin(true);
m_write_frame.set_opcode(frame::opcode::PONG);
m_write_frame.set_payload(msg);
write_frame();
}
// initiate a connection close
void close(uint16_t status,const std::string &reason);
void close(uint16_t status,const std::string &reason) {
validate_app_close_status(status);
send_close(status,msg);
}
virtual bool is_server() const = 0;
@@ -198,53 +292,561 @@ public: //protected:
virtual void write_handshake() = 0;
virtual void read_handshake() = 0;
void read_frame();
void handle_read_frame (const boost::system::error_code& error);
void read_frame() {
// the initial read in the handshake may have read in the first frame.
// handle it (if it exists) before we read anything else.
handle_read_frame(boost::system::error_code());
}
// handle_read_frame reads and processes all socket read commands for the
// session by consuming the read buffer and then starting an async read with
// itself as the callback. The connection is over when this method returns.
void handle_read_frame (const boost::system::error_code& error) {
if (m_state != state::OPEN && m_state != state::CLOSING) {
log("handle_read_frame called in invalid state",LOG_ERROR);
return;
}
if (error) {
if (error == boost::asio::error::eof) {
// if this is a case where we are expecting eof, return, else log & drop
log_error("Recieved EOF",error);
//drop_tcp(false);
//m_state = STATE_CLOSED;
} else if (error == boost::asio::error::operation_aborted) {
// some other part of our client called shutdown on our socket.
// This is usually due to a write error. Everything should have
// already been logged and dropped so we just return here
return;
} else {
log_error("Error reading frame",error);
//drop_tcp(false);
m_state = state::CLOSED;
}
}
std::istream s(&m_buf);
while (m_buf.size() > 0 && m_state != state::CLOSED) {
try {
if (m_read_frame.get_bytes_needed() == 0) {
throw frame_error("have bytes that no frame needs",frame::error::FATAL_SESSION_ERROR);
}
// Consume will read bytes from s
// will throw a frame_error on error.
std::stringstream err;
err << "consuming. have: " << m_buf.size() << " bytes. Need: " << m_read_frame.get_bytes_needed() << " state: " << (int)m_read_frame.get_state();
log(err.str(),LOG_DEBUG);
m_read_frame.consume(s);
err.str("");
err << "consume complete, " << m_buf.size() << " bytes left, " << m_read_frame.get_bytes_needed() << " still needed, state: " << (int)m_read_frame.get_state();
log(err.str(),LOG_DEBUG);
if (m_read_frame.ready()) {
// process frame and reset frame state for the next frame.
// will throw a frame_error on error. May set m_state to CLOSED,
// if so no more frames should be processed.
err.str("");
err << "processing frame " << m_buf.size();
log(err.str(),LOG_DEBUG);
m_timer.cancel();
process_frame();
}
} catch (const frame_error& e) {
std::stringstream err;
err << "Caught frame exception: " << e.what();
access_log(e.what(),ALOG_FRAME);
log(err.str(),LOG_ERROR);
// if the exception happened while processing.
// TODO: this is not elegant, perhaps separate frame read vs process
// exceptions need to be used.
if (m_read_frame.ready()) {
m_read_frame.reset();
}
// process different types of frame errors
//
if (e.code() == frame::FERR_PROTOCOL_VIOLATION) {
send_close(close::status::PROTOCOL_ERROR, e.what());
} else if (e.code() == frame::FERR_PAYLOAD_VIOLATION) {
send_close(close::status::INVALID_PAYLOAD, e.what());
} else if (e.code() == frame::FERR_INTERNAL_SERVER_ERROR) {
send_close(close::status::ABNORMAL_CLOSE, e.what());
} else if (e.code() == frame::FERR_SOFT_SESSION_ERROR) {
// ignore and continue processing frames
continue;
} else {
// Fatal error, forcibly end connection immediately.
log("Dropping TCP due to unrecoverable exception",LOG_DEBUG);
drop_tcp(true);
}
break;
}
}
if (error == boost::asio::error::eof) {
m_state = state::CLOSED;
}
// we have read everything, check if we should read more
if ((m_state == state::OPEN || m_state == state::CLOSING) && m_read_frame.get_bytes_needed() > 0) {
std::stringstream msg;
msg << "starting async read for " << m_read_frame.get_bytes_needed() << " bytes.";
log(msg.str(),LOG_DEBUG);
// TODO: set a timer here in case we don't want to read forever.
// Ex: when the frame is in a degraded state.
boost::asio::async_read(
m_socket,
m_buf,
boost::asio::transfer_at_least(m_read_frame.get_bytes_needed()),
boost::bind(
&session::handle_read_frame,
shared_from_this(),
boost::asio::placeholders::error
)
);
} else if (m_state == state::CLOSED) {
log_close_result();
if (m_local_interface) {
// TODO: make sure close code/msg are properly set.
m_local_interface->on_close(shared_from_this());
}
m_timer.cancel();
} else {
log("handle_read_frame called in invalid state",LOG_ERROR);
}
}
// write m_write_frame out to the socket.
void write_frame();
void handle_write_frame (const boost::system::error_code& error);
void handle_write_frame (const boost::system::error_code& error) {
if (error) {
log_error("Error writing frame data",error);
drop_tcp(false);
}
access_log("handle_write_frame complete",ALOG_FRAME);
m_writing = false;
}
void handle_timer_expired(const boost::system::error_code& error);
void handle_handshake_expired(const boost::system::error_code& error);
void handle_close_expired(const boost::system::error_code& error);
void handle_error_timer_expired (const boost::system::error_code& error);
void handle_timer_expired(const boost::system::error_code& error) {
if (error) {
if (error == boost::asio::error::operation_aborted) {
log("timer was aborted",LOG_DEBUG);
//drop_tcp(false);
} else {
log("timer ended with error",LOG_DEBUG);
}
return;
}
log("timer ended without error",LOG_DEBUG);
}
void handle_handshake_expired(const boost::system::error_code& error) {
if (error) {
if (error != boost::asio::error::operation_aborted) {
log("Unexpected handshake timer error.",LOG_DEBUG);
drop_tcp(true);
}
return;
}
log("Handshake timed out",LOG_DEBUG);
drop_tcp(true);
}
void handle_close_expired(const boost::system::error_code& error) {
if (error) {
if (error == boost::asio::error::operation_aborted) {
log("timer was aborted",LOG_DEBUG);
//drop_tcp(false);
} else {
log("Unexpected close timer error.",LOG_DEBUG);
drop_tcp(false);
}
return;
}
if (m_state != STATE_CLOSED) {
log("close timed out",LOG_DEBUG);
drop_tcp(false);
}
}
// The error timer is set when we want to give the other endpoint some time to
// do something but don't want to wait forever. There is a special error code
// that represents the timer being canceled by us (because the other endpoint
// responded in time. All other cases should assume that the other endpoint is
// irrepairibly broken and drop the TCP connection.
void handle_error_timer_expired (const boost::system::error_code& error) {
if (error) {
if (error == boost::asio::error::operation_aborted) {
log("error timer was aborted",LOG_DEBUG);
//drop_tcp(false);
} else {
log("error timer ended with error",LOG_DEBUG);
drop_tcp(true);
}
return;
}
log("error timer ended without error",LOG_DEBUG);
drop_tcp(true);
}
// helper functions for processing each opcode
void process_frame();
void process_ping();
void process_pong();
void process_text();
void process_binary();
void process_continuation();
void process_close();
void process_frame() {
log("process_frame",LOG_DEBUG);
if (m_state == state::OPEN) {
switch (m_read_frame.get_opcode()) {
case frame::opcode::CONTINUATION:
process_continuation();
break;
case frame::opcode::TEXT:
process_text();
break;
case frame::opcode::BINARY:
process_binary();
break;
case frame::opcode::CLOSE:
log("process_close",LOG_DEBUG);
process_close();
break;
case frame::opcode::PING:
process_ping();
break;
case frame::opcode::PONG:
process_pong();
break;
default:
throw frame::exception("Invalid Opcode",
frame::FERR_PROTOCOL_VIOLATION);
break;
}
} else if (m_state == state::CLOSING) {
if (m_read_frame.get_opcode() == frame::opcode::CLOSE) {
process_close();
} else {
// Ignore all other frames in closing state
log("ignoring this frame",LOG_DEBUG);
}
} else {
// Recieved message before or after connection was opened/closed
throw frame::exception("process_frame called from invalid state");
}
m_read_frame.reset();
}
void process_ping() {
access_log("Ping",ALOG_MISC_CONTROL);
// TODO: on_ping
// send pong
m_write_frame.set_fin(true);
m_write_frame.set_opcode(frame::PONG);
m_write_frame.set_payload(m_read_frame.get_payload());
write_frame();
}
void process_pong() {
access_log("Pong",ALOG_MISC_CONTROL);
// TODO: on_pong
}
void process_text() {
// this will throw an exception if validation fails at any point
m_read_frame.validate_utf8(&m_utf8_state,&m_utf8_codepoint);
// otherwise, treat as binary
process_binary();
}
void process_binary() {
if (m_fragmented) {
throw frame::exception("Got a new message before the previous was finished.",frame::FERR_PROTOCOL_VIOLATION);
}
m_current_opcode = m_read_frame.get_opcode();
if (m_read_frame.get_fin()) {
deliver_message();
reset_message();
} else {
m_fragmented = true;
extract_payload();
}
}
void process_continuation() {
if (!m_fragmented) {
throw frame::exception("Got a continuation frame without an outstanding message.",frame::FERR_PROTOCOL_VIOLATION);
}
if (m_current_opcode == frame::opcode::TEXT) {
// this will throw an exception if validation fails at any point
m_read_frame.validate_utf8(&m_utf8_state,&m_utf8_codepoint);
}
extract_payload();
// check if we are done
if (m_read_frame.get_fin()) {
deliver_message();
reset_message();
}
}
void process_close() {
m_remote_close_code = m_read_frame.get_close_status();
m_remote_close_msg = m_read_frame.get_close_msg();
if (m_state == state::OPEN) {
log("process_close sending ack",LOG_DEBUG);
// This is the case where the remote initiated the close.
m_closed_by_me = false;
// send acknowledgement
// TODO: check if the remote close code
if (m_remote_close_code >= close::status::RSV_START) {
}
send_close(m_remote_close_code,m_remote_close_msg);
} else if (m_state == state::CLOSING) {
log("process_close got ack",LOG_DEBUG);
// this is an ack of our close message
m_closed_by_me = true;
} else {
throw frame::exception("process_closed called from wrong state");
}
m_was_clean = true;
m_state = state::CLOSED;
}
// deliver message if we have a local interface attached
void deliver_message();
void deliver_message() {
if (!m_local_interface) {
return;
}
if (m_current_opcode == frame::opcode::BINARY) {
//log("Dispatching Binary Message",LOG_DEBUG);
if (m_fragmented) {
m_local_interface->on_message(shared_from_this(),m_current_message);
} else {
m_local_interface->on_message(shared_from_this(),
m_read_frame.get_payload());
}
} else if (m_current_opcode == frame::opcode::TEXT) {
std::string msg;
// make sure the finished frame is valid utf8
// the streaming validator checks for bad codepoints as it goes. It
// doesn't know where the end of the message is though, so we need to
// check here to make sure the final message ends on a valid codepoint.
if (m_utf8_state != utf8_validator::UTF8_ACCEPT) {
throw frame::exception("Invalid UTF-8 Data",
frame::FERR_PAYLOAD_VIOLATION);
}
if (m_fragmented) {
msg.append(m_current_message.begin(),m_current_message.end());
} else {
msg.append(
m_read_frame.get_payload().begin(),
m_read_frame.get_payload().end()
);
}
//log("Dispatching Text Message",LOG_DEBUG);
m_local_interface->on_message(shared_from_this(),msg);
} else {
// Not sure if this should be a fatal error or not
std::stringstream err;
err << "Attempted to deliver a message of unsupported opcode " << m_current_opcode;
throw frame::exception(err.str(),frame::error::SOFT_SESSION_ERROR);
}
}
// copies the current read frame payload into the session so that the read
// frame can be cleared for the next read. This is done when fragmented
// messages are recieved.
void extract_payload();
void extract_payload() {
std::vector<unsigned char> &msg = m_read_frame.get_payload();
m_current_message.resize(m_current_message.size()+msg.size());
std::copy(msg.begin(),msg.end(),m_current_message.end()-msg.size());
}
// reset session for a new message
void reset_message();
void reset_message() {
m_error = false;
m_fragmented = false;
m_current_message.clear();
m_utf8_state = utf8_validator::UTF8_ACCEPT;
m_utf8_codepoint = 0;
}
// logging
virtual void log(const std::string& msg, uint16_t level) const = 0;
virtual void access_log(const std::string& msg, uint16_t level) const = 0;
void log_close_result();
void log_open_result();
void log_error(std::string msg,const boost::system::error_code& e);
void log_close_result() {
std::stringstream msg;
msg << "[Connection " << this << "] "
<< (m_was_clean ? "Clean " : "Unclean ")
<< "close local:[" << m_local_close_code
<< (m_local_close_msg == "" ? "" : ","+m_local_close_msg)
<< "] remote:[" << m_remote_close_code
<< (m_remote_close_msg == "" ? "" : ","+m_remote_close_msg) << "]";
access_log(msg.str(),ALOG_DISCONNECT);
}
void log_open_result() {
std::stringstream msg;
msg << "[Connection " << this << "] "
<< m_socket.remote_endpoint()
<< " v" << m_version << " "
<< (get_client_header("User-Agent") == "" ? "NULL" : get_client_header("User-Agent"))
<< " " << m_resource << " " << m_server_http_code;
access_log(msg.str(),ALOG_HANDSHAKE);
}
// this is called when an async asio call encounters an error
void log_error(std::string msg,const boost::system::error_code& e) {
std::stringstream err;
err << "[Connection " << this << "] " << msg << " (" << e << ")";
log(err.str(),LOG_ERROR);
}
// misc helpers
bool validate_app_close_status(uint16_t status);
void send_close(uint16_t status,const std::string& reason);
void drop_tcp(bool dropped_by_me = true);
// validates status codes that the end application is allowed to use
bool validate_app_close_status(close::status::value status) {
if (status == close::status::NORMAL) {
return true;
}
if (status >= 4000 && status < 5000) {
return true;
}
return false;
}
void send_close(close::status::value status,const std::string& reason) {
if (m_state != state::OPEN) {
log("Tried to disconnect a session that wasn't open",LOG_WARN);
return;
}
m_state = state::CLOSING;
m_timer.expires_from_now(boost::posix_time::milliseconds(1000));
m_timer.async_wait(
boost::bind(
&session::handle_close_expired,
shared_from_this(),
boost::asio::placeholders::error
)
);
m_local_close_code = status;
m_local_close_msg = message;
m_write_frame.set_fin(true);
m_write_frame.set_opcode(frame::opcode::CLOSE);
// echo close value unless there is a good reason not to.
if (status == close::status::NO_STATUS) {
m_write_frame.set_status(close::status::NORMAL,"");
} else if (status == close::status::ABNORMAL_CLOSE) {
// Internal implimentation error. There is no good close code for this.
m_write_frame.set_status(close::status::POLICY_VIOLATION,message);
} else if (close::status::invalid(status)) {
m_write_frame.set_status(close::status::PROTOCOL_ERROR,"Status code is invalid");
} else if (close::status::reserved(status)) {
m_write_frame.set_status(close::status::PROTOCOL_ERROR,"Status code is reserved");
} else {
m_write_frame.set_status(status,message);
}
write_frame() {
if (!is_server()) {
m_write_frame.set_masked(true); // client must mask frames
}
m_write_frame.process_payload();
std::vector<boost::asio::mutable_buffer> data;
data.push_back(
boost::asio::buffer(
m_write_frame.get_header(),
m_write_frame.get_header_len()
)
);
data.push_back(
boost::asio::buffer(m_write_frame.get_payload())
);
log("Write Frame: "+m_write_frame.print_frame(),LOG_DEBUG);
m_writing = true;
boost::asio::async_write(
m_socket,
data,
boost::bind(
&session::handle_write_frame,
shared_from_this(),
boost::asio::placeholders::error
)
);
}
}
void drop_tcp(bool dropped_by_me = true) {
m_timer.cancel();
try {
if (m_socket.is_open()) {
m_socket.shutdown(tcp::socket::shutdown_both);
m_socket.close();
}
} catch (boost::system::system_error& e) {
if (e.code() == boost::asio::error::not_connected) {
// this means the socket was disconnected by the other side before
// we had a chance to. Ignore and continue.
} else {
throw e;
}
}
m_dropped_by_me = dropped_by_me;
m_state = state::CLOSED;
}
private:
std::string get_header(const std::string& key,
const header_list& list) const;
const header_list& list) const {
header_list::const_iterator h = list.find(key);
if (h == list.end()) {
return "";
} else {
return h->second;
}
}
protected:
// Immutable state about the current connection from the handshake
@@ -272,9 +874,9 @@ protected:
bool m_writing;
// Close state
uint16_t m_local_close_code;
close::status::value m_local_close_code;
std::string m_local_close_msg;
uint16_t m_remote_close_code;
close::status::value m_remote_close_code;
std::string m_remote_close_msg;
bool m_was_clean;
bool m_closed_by_me;
@@ -294,13 +896,13 @@ protected:
uint32_t m_utf8_codepoint;
std::vector<unsigned char> m_current_message;
bool m_fragmented;
frame::opcode m_current_opcode;
frame::opcode::value m_current_opcode;
// current frame state
frame m_read_frame;
// unorganized
frame m_write_frame;
// frame parsers
frame::parser<rng_policy> m_read_frame;
frame::parser<rng_policy> m_write_frame;
// unknown
bool m_error;
};

View File

@@ -32,70 +32,7 @@
#include <stdint.h>
// Defaults
namespace websocketpp {
const uint64_t DEFAULT_MAX_MESSAGE_SIZE = 0xFFFFFF; // ~16MB
// System logging levels
static const uint16_t LOG_ALL = 0;
static const uint16_t LOG_DEBUG = 1;
static const uint16_t LOG_INFO = 2;
static const uint16_t LOG_WARN = 3;
static const uint16_t LOG_ERROR = 4;
static const uint16_t LOG_FATAL = 5;
static const uint16_t LOG_OFF = 6;
// Access logging controls
// Individual bits
static const uint16_t ALOG_CONNECT = 0x1;
static const uint16_t ALOG_DISCONNECT = 0x2;
static const uint16_t ALOG_MISC_CONTROL = 0x4;
static const uint16_t ALOG_FRAME = 0x8;
static const uint16_t ALOG_MESSAGE = 0x10;
static const uint16_t ALOG_INFO = 0x20;
static const uint16_t ALOG_HANDSHAKE = 0x40;
// Useful groups
static const uint16_t ALOG_OFF = 0x0;
static const uint16_t ALOG_CONTROL = ALOG_CONNECT
& ALOG_DISCONNECT
& ALOG_MISC_CONTROL;
static const uint16_t ALOG_ALL = 0xFFFF;
namespace close {
namespace status {
enum value {
INVALID_END = 999,
NORMAL = 1000,
GOING_AWAY = 1001,
PROTOCOL_ERROR = 1002,
UNSUPPORTED_DATA = 1003,
RSV_ADHOC_1 = 1004,
NO_STATUS = 1005,
ABNORMAL_CLOSE = 1006,
INVALID_PAYLOAD = 1007,
POLICY_VIOLATION = 1008,
MESSAGE_TOO_BIG = 1009,
EXTENSION_REQUIRE = 1010,
RSV_START = 1011,
RSV_END = 2999,
INVALID_START = 5000
};
inline bool reserved(uint16_t s) {
return ((s >= RSV_START && s <= RSV_END) ||
s == RSV_ADHOC_1);
}
inline bool invalid(uint16_t s) {
return ((s <= INVALID_END || s >= INVALID_START) ||
s == NO_STATUS ||
s == ABNORMAL_CLOSE);
}
}
}
}
#include "websocket_constants.hpp"
#include "websocket_session.hpp"
#include "websocket_server_session.hpp"

View File

@@ -70,6 +70,8 @@
B6DF1CE21435F1860029A1B1 /* libboost_system.dylib in Frameworks */ = {isa = PBXBuildFile; fileRef = B6DF1CE11435F1860029A1B1 /* libboost_system.dylib */; };
B6DF1CE41435F8250029A1B1 /* Foundation.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = B6DF1CE31435F8250029A1B1 /* Foundation.framework */; };
B6FE8CEC145A0F1900B32547 /* libboost_program_options.dylib in Frameworks */ = {isa = PBXBuildFile; fileRef = B6FE8CEB145A0F1900B32547 /* libboost_program_options.dylib */; };
B6FE8D06145AFF5F00B32547 /* websocket_constants.hpp in Headers */ = {isa = PBXBuildFile; fileRef = B6FE8D05145AFF5F00B32547 /* websocket_constants.hpp */; };
B6FE8D07145AFF5F00B32547 /* websocket_constants.hpp in Headers */ = {isa = PBXBuildFile; fileRef = B6FE8D05145AFF5F00B32547 /* websocket_constants.hpp */; };
/* End PBXBuildFile section */
/* Begin PBXContainerItemProxy section */
@@ -176,6 +178,11 @@
B6DF1CE31435F8250029A1B1 /* Foundation.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = Foundation.framework; path = System/Library/Frameworks/Foundation.framework; sourceTree = SDKROOT; };
B6FE8CE2144DE17F00B32547 /* readme.txt */ = {isa = PBXFileReference; lastKnownFileType = text; path = readme.txt; sourceTree = "<group>"; };
B6FE8CEB145A0F1900B32547 /* libboost_program_options.dylib */ = {isa = PBXFileReference; lastKnownFileType = "compiled.mach-o.dylib"; name = libboost_program_options.dylib; path = usr/local/lib/libboost_program_options.dylib; sourceTree = SDKROOT; };
B6FE8D05145AFF5F00B32547 /* websocket_constants.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; name = websocket_constants.hpp; path = src/websocket_constants.hpp; sourceTree = "<group>"; };
B6FE8D0A145B0FA400B32547 /* blank_rng.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; name = blank_rng.cpp; path = src/policy/rng/blank_rng.cpp; sourceTree = "<group>"; };
B6FE8D0B145B0FA400B32547 /* blank_rng.hpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.h; name = blank_rng.hpp; path = src/policy/rng/blank_rng.hpp; sourceTree = "<group>"; };
B6FE8D0C145B0FA400B32547 /* boost_rng.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; name = boost_rng.cpp; path = src/policy/rng/boost_rng.cpp; sourceTree = "<group>"; };
B6FE8D0D145B0FA400B32547 /* boost_rng.hpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.h; name = boost_rng.hpp; path = src/policy/rng/boost_rng.hpp; sourceTree = "<group>"; };
/* End PBXFileReference section */
/* Begin PBXFrameworksBuildPhase section */
@@ -274,6 +281,8 @@
B6DF1C7F1434ABB70029A1B1 /* src */ = {
isa = PBXGroup;
children = (
B6FE8D08145B0F6600B32547 /* policy */,
B6FE8D05145AFF5F00B32547 /* websocket_constants.hpp */,
B6DF1C921434AC470029A1B1 /* websocket_client_session.cpp */,
B6DF1C931434AC470029A1B1 /* websocket_client_session.hpp */,
B6DF1C941434AC470029A1B1 /* websocket_client.cpp */,
@@ -391,6 +400,25 @@
name = documentation;
sourceTree = "<group>";
};
B6FE8D08145B0F6600B32547 /* policy */ = {
isa = PBXGroup;
children = (
B6FE8D09145B0F7400B32547 /* rng */,
);
name = policy;
sourceTree = "<group>";
};
B6FE8D09145B0F7400B32547 /* rng */ = {
isa = PBXGroup;
children = (
B6FE8D0A145B0FA400B32547 /* blank_rng.cpp */,
B6FE8D0B145B0FA400B32547 /* blank_rng.hpp */,
B6FE8D0C145B0FA400B32547 /* boost_rng.cpp */,
B6FE8D0D145B0FA400B32547 /* boost_rng.hpp */,
);
name = rng;
sourceTree = "<group>";
};
/* End PBXGroup section */
/* Begin PBXHeadersBuildPhase section */
@@ -411,6 +439,7 @@
B6DF1CB41434AC470029A1B1 /* websocket_server.hpp in Headers */,
B6DF1CB81434AC470029A1B1 /* websocket_session.hpp in Headers */,
B6BE76EA144EF53000716A77 /* websocket_endpoint.hpp in Headers */,
B6FE8D06145AFF5F00B32547 /* websocket_constants.hpp in Headers */,
);
runOnlyForDeploymentPostprocessing = 0;
};
@@ -431,6 +460,7 @@
B6DF1CB51434AC470029A1B1 /* websocket_server.hpp in Headers */,
B6DF1CB91434AC470029A1B1 /* websocket_session.hpp in Headers */,
B6BE76EB144EF53000716A77 /* websocket_endpoint.hpp in Headers */,
B6FE8D07145AFF5F00B32547 /* websocket_constants.hpp in Headers */,
);
runOnlyForDeploymentPostprocessing = 0;
};