continues frame reading changes fixes a lot of underspecified close behavior. NOTE: this revision won't compile

This commit is contained in:
Peter Thorson
2011-10-07 08:26:41 -05:00
parent 831da9e12b
commit 9180e52103
4 changed files with 122 additions and 82 deletions

View File

@@ -49,6 +49,7 @@ uint8_t frame::get_state() const {
void frame::reset() {
m_state = STATE_BASIC_HEADER;
m_bytes_needed = BASIC_HEADER_LENGTH;
m_payload.empty();
}
void frame::consume(std::istream &s) {
@@ -78,6 +79,15 @@ void frame::consume(std::istream &s) {
process_extended_header();
m_state = STATE_PAYLOAD;
}
} else if (m_state == STATE_PAYLOAD) {
s.read(reinterpret_cast<char *>(&m_payload[0]),m_bytes_needed);
m_bytes_needed -= s.gcount();
if (m_bytes_needed == 0) {
process_payload();
m_state = STATE_READY;
}
}
}
@@ -324,13 +334,10 @@ std::string frame::print_frame() const {
}
void frame::process_basic_header() {
m_payload.empty();
m_bytes_needed = get_header_len() - BASIC_HEADER_LENGTH;
}
void frame::process_extended_header() {
m_extended_header_bytes_needed = 0;
uint8_t s = get_basic_size();
uint64_t payload_size;
int mask_index = BASIC_HEADER_LENGTH;
@@ -345,7 +352,8 @@ void frame::process_extended_header() {
));
if (payload_size < s) {
throw frame_error("payload length not minimally encoded");
throw frame_error("payload length not minimally encoded",
FERR_PROTOCOL_VIOLATION);
}
mask_index += 2;
@@ -357,7 +365,8 @@ void frame::process_extended_header() {
));
if (payload_size <= PAYLOAD_16BIT_LIMIT) {
throw frame_error("payload length not minimally encoded");
throw frame_error("payload length not minimally encoded",
FERR_PROTOCOL_VIOLATION);
}
mask_index += 8;
@@ -378,9 +387,11 @@ void frame::process_extended_header() {
}
if (payload_size > max_payload_size) {
// TODO: frame/message size limits
throw server_error("got frame with payload greater than maximum frame buffer size.");
}
m_payload.resize(payload_size);
m_bytes_needed = payload_size;
}
void frame::process_payload() {
@@ -420,43 +431,39 @@ void frame::process_payload2() {
}
}
bool frame::validate_utf8(uint32_t* state,uint32_t* codep) const {
void frame::validate_utf8(uint32_t* state,uint32_t* codep) const {
for (size_t i = 0; i < m_payload.size(); i++) {
using utf8_validator::decode;
//std::cout << "decoding: " << std::hex << m_payload[i] << std::endl;
if (decode(state,codep,m_payload[i]) == utf8_validator::UTF8_REJECT) {
// std::cout << "bad byte" << std::endl;
return false;
throw frame_error("Invalid UTF-8 Data",FERR_PAYLOAD_VIOLATION);
}
}
return true;
}
void frame::validate_basic_header() const {
// check for control frame size
if (get_basic_size() > BASIC_PAYLOAD_LIMIT && is_control()) {
throw frame_error("Control Frame is too large");
throw frame_error("Control Frame is too large",FERR_PROTOCOL_VIOLATION);
}
// check for reserved opcodes
if (get_rsv1() || get_rsv2() || get_rsv3()) {
throw frame_error("Reserved bit used");
throw frame_error("Reserved bit used",FERR_PROTOCOL_VIOLATION);
}
// check for reserved opcodes
opcode op = get_opcode();
if (op > 0x02 && op < 0x08) {
throw frame_error("Reserved opcode used");
throw frame_error("Reserved opcode used",FERR_PROTOCOL_VIOLATION);
}
if (op > 0x0A) {
throw frame_error("Reserved opcode used");
throw frame_error("Reserved opcode used",FERR_PROTOCOL_VIOLATION);
}
// check for fragmented control message
if (is_control() && !get_fin()) {
throw frame_error("Fragmented control message");
throw frame_error("Fragmented control message",FERR_PROTOCOL_VIOLATION);
}
}

View File

@@ -59,6 +59,11 @@ public:
static const uint8_t STATE_PAYLOAD = 3;
static const uint8_t STATE_READY = 4;
static const uint16_t FERR_FATAL_SESSION_ERROR = 0; // must end session
static const uint16_t FERR_SOFT_SESSION_ERROR = 1; // should log and ignore
static const uint16_t FERR_PROTOCOL_VIOLATION = 2; // must end session
static const uint16_t FERR_PAYLOAD_VIOLATION = 3; // should end session
// basic payload byte flags
static const uint8_t BPB0_OPCODE = 0x0F;
static const uint8_t BPB0_RSV3 = 0x10;
@@ -87,6 +92,7 @@ public:
}
uint8_t get_state() const;
uint64_t get_bytes_needed() const;
void reset();
void consume(std::istream &s);
@@ -141,7 +147,7 @@ public:
void process_payload();
void process_payload2(); // experiment with more efficient masking code.
bool validate_utf8(uint32_t* state,uint32_t* codep) const;
void validate_utf8(uint32_t* state,uint32_t* codep) const;
void validate_basic_header() const;
void generate_masking_key();
@@ -166,7 +172,9 @@ private:
// Exception classes
class frame_error : public std::exception {
public:
frame_error(const std::string& msg) : m_msg(msg) {}
frame_error(const std::string& msg,
uint16_t code = frame::FERR_FATAL_SESSION_ERROR)
: m_msg(msg),m_code(code) {}
~frame_error() throw() {}
virtual const char* what() const throw() {
@@ -174,6 +182,7 @@ public:
}
std::string m_msg;
uint16_t m_code;
};
}

View File

@@ -120,7 +120,7 @@ unsigned int session::get_version() const {
}
void session::send(const std::string &msg) {
if (m_status != OPEN) {
if (m_state != STATE_OPEN) {
log("Tried to send a message from a session that wasn't open",LOG_WARN);
return;
}
@@ -132,7 +132,7 @@ void session::send(const std::string &msg) {
}
void session::send(const std::vector<unsigned char> &data) {
if (m_status != OPEN) {
if (m_state != STATE_OPEN) {
log("Tried to send a message from a session that wasn't open",LOG_WARN);
return;
}
@@ -143,19 +143,28 @@ void session::send(const std::vector<unsigned char> &data) {
write_frame();
}
// end user interface to close the connection
void session::close(uint16_t status,const std::string& msg) {
validate_app_close_status(status);
disconnect(status,msg);
// TODO: close behavior
}
// TODO: clean this up, needs to be broken out into more specific methods
// This method initiates a clean disconnect with the given status code and reason
// it logs an error and is ignored if it is called from a state other than OPEN
// called by process_close when an initiate close method is received.
void session::disconnect(uint16_t status,const std::string &message) {
if (m_status != OPEN) {
if (m_state != STATE_OPEN) {
log("Tried to disconnect a session that wasn't open",LOG_WARN);
return;
}
m_status = CLOSING;
m_state = STATE_CLOSING;
m_close_code = status;
m_close_message = message;
@@ -179,7 +188,7 @@ void session::disconnect(uint16_t status,const std::string &message) {
}
void session::ping(const std::string &msg) {
if (m_status != OPEN) {
if (m_state != STATE_OPEN) {
log("Tried to send a ping from a session that wasn't open",LOG_WARN);
return;
}
@@ -191,7 +200,7 @@ void session::ping(const std::string &msg) {
}
void session::pong(const std::string &msg) {
if (m_status != OPEN) {
if (m_state != STATE_OPEN) {
log("Tried to send a pong from a session that wasn't open",LOG_WARN);
return;
}
@@ -203,16 +212,28 @@ void session::pong(const std::string &msg) {
}
void session::handle_read_frame(const boost::system::error_code& error) {
// while
// if there are enough bytes to do something:
// do something
// read more
if (error) {
handle_error("Error reading extended frame header",error);
// TODO: close behavior
return;
}
if (m_state != STATE_OPEN && m_state != STATE_CLOSING) {
// stop processing frames.
log("handle_read_frame called in invalid state",LOG_ERROR);
return;
}
std::istream s(&m_buf);
try {
while (m_buf.size() > 0) {
m_read_frame.consume(s);
if (m_read_frame.get_state() == frame::STATE_READY) {
process_frame();
// should we break?
}
}
} catch (const frame_error& e) {
std::stringstream err;
@@ -221,15 +242,24 @@ void session::handle_read_frame(const boost::system::error_code& error) {
access_log(e.what(),ALOG_FRAME);
log(err.str(),LOG_ERROR);
disconnect(CLOSE_STATUS_PROTOCOL_ERROR,"");
// TODO: close behavior
return;
}
// we have read everything, check if we should read more
if (m_state != STATE_OPEN && m_state != STATE_CLOSING) {
// stop processing frames.
log("handle_read_frame called in invalid state");
return;
}
// read more
boost::asio::async_read(
m_socket,
m_buf,
boost::asio::transfer_at_least(1),
boost::asio::transfer_at_least(m_read_frame.get_bytes_needed()),
boost::bind(
&session::handle_read_frame,
shared_from_this(),
@@ -313,16 +343,8 @@ void session::read_payload() {
);
}
void session::handle_read_payload (const boost::system::error_code& error) {
if (error) {
handle_error("Error reading payload data frame header",error);
// TODO: close behavior
return;
}
m_read_frame.process_payload();
if (m_status == OPEN) {
void session::process_frame () {
if (m_state == STATE_OPEN) {
switch (m_read_frame.get_opcode()) {
case frame::CONTINUATION_FRAME:
process_continuation();
@@ -347,7 +369,7 @@ void session::handle_read_payload (const boost::system::error_code& error) {
// TODO: close behavior
break;
}
} else if (m_status == CLOSING) {
} else if (m_state == STATE_CLOSING) {
if (m_read_frame.get_opcode() == frame::CONNECTION_CLOSE) {
process_close();
} else {
@@ -377,8 +399,8 @@ void session::handle_read_payload (const boost::system::error_code& error) {
// TODO: close behavior
return;
}
this->read_frame();
m_read_frame.reset();
}
void session::handle_write_frame (const boost::system::error_code& error) {
@@ -408,22 +430,17 @@ void session::process_pong() {
}
void session::process_text() {
if (!m_read_frame.validate_utf8(&m_utf8_state,&m_utf8_codepoint)) {
disconnect(CLOSE_STATUS_INVALID_PAYLOAD,"Invalid UTF8 Data");
// TODO: close behavior
return;
}
// this will throw an exception if validation fails at any point
m_read_frame.validate_utf8(&m_utf8_state,&m_utf8_codepoint);
// otherwise, treat as binary
process_binary();
}
void session::process_binary() {
if (m_fragmented) {
handle_error("Got a new message before the previous was finished.",
boost::system::error_code());
disconnect(CLOSE_STATUS_PROTOCOL_ERROR,"");
// TODO: close behavior
return;
throw frame_error("Got a new message before the previous was finished.",
frame::FERR_PROTOCOL_VIOLATION);
}
m_current_opcode = m_read_frame.get_opcode();
@@ -439,19 +456,13 @@ void session::process_binary() {
void session::process_continuation() {
if (!m_fragmented) {
handle_error("Got a continuation frame without an outstanding message.",
boost::system::error_code());
disconnect(CLOSE_STATUS_PROTOCOL_ERROR,"");
// TODO: close behavior
return;
throw frame_error("Got a continuation frame without an outstanding message.",
frame::FERR_PROTOCOL_VIOLATION);
}
if (m_current_opcode == frame::TEXT_FRAME) {
if (!m_read_frame.validate_utf8(&m_utf8_state,&m_utf8_codepoint)) {
disconnect(CLOSE_STATUS_INVALID_PAYLOAD,"Invalid UTF8 Data");
// TODO: close behavior
return;
}
// this will throw an exception if validation fails at any point
m_read_frame.validate_utf8(&m_utf8_state,&m_utf8_codepoint);
}
extract_payload();
@@ -470,17 +481,16 @@ void session::process_close() {
m_remote_close_code = status;
m_remote_close_msg = message;
if (m_status == OPEN) {
if (m_state == STATE_OPEN) {
// This is the case where the remote initiated the close.
m_closed_by_me = false;
// TODO: close behavior
disconnect(status,message);
} else if (m_status == CLOSING) {
} else if (m_state == STATE_CLOSING) {
// this is an ack of our close message
m_closed_by_me = true;
} else {
throw "fixme";
throw frame_error("process_closed called from wrong state");
}
m_was_clean = true;
@@ -503,10 +513,11 @@ void session::deliver_message() {
std::string msg;
// make sure the finished frame is valid utf8
// the streaming validator checks for bad codepoints as it goes. It
// doesn't know where the end of the message is though, so we need to
// check here to make sure the final message ends on a valid codepoint.
if (m_utf8_state != utf8_validator::UTF8_ACCEPT) {
disconnect(CLOSE_STATUS_INVALID_PAYLOAD,"Invalid UTF8 Data");
// TODO: close behavior
return;
throw frame_error("Invalid UTF-8 Data",FERR_PAYLOAD_VIOLATION);
}
if (m_fragmented) {
@@ -523,7 +534,7 @@ void session::deliver_message() {
// Not sure if this should be a fatal error or not
std::stringstream err;
err << "Attempted to deliver a message of unsupported opcode " << m_current_opcode;
log(err.str(),LOG_ERROR);
throw frame_error(err.str(),frame::FERR_SOFT_SESSION_ERROR);
}
}
@@ -614,3 +625,16 @@ void session::handle_error(std::string msg,
m_error = true;
}
// validates status codes that the end application is allowed to use
bool session::validate_app_close_status(uint16_t status) {
if (status == CLOSE_STATUS_NORMAL) {
return true;
}
if (status >= 4000 && status < 5000) {
return true;
}
return false;
}

View File

@@ -73,14 +73,10 @@ class session : public boost::enable_shared_from_this<session> {
public:
friend class handshake_error;
enum ws_status {
CONNECTING,
OPEN,
CLOSING,
CLOSED
};
typedef enum ws_status status_code;
static const uint8_t STATE_CONNECTING = 0;
static const uint8_t STATE_OPEN = 1;
static const uint8_t STATE_CLOSING = 2;
static const uint8_t STATE_CLOSED = 3;
static const uint16_t CLOSE_STATUS_NORMAL = 1000;
static const uint16_t CLOSE_STATUS_GOING_AWAY = 1001;
@@ -160,13 +156,14 @@ protected:
void handle_frame_header(const boost::system::error_code& error);
void handle_extended_frame_header(const boost::system::error_code& error);
void read_payload();
void handle_read_payload (const boost::system::error_code& error);
void handle_read_frame (const boost::system::error_code& error);
// write m_write_frame out to the socket.
void write_frame();
void handle_write_frame (const boost::system::error_code& error);
// helper functions for processing each opcode
void process_frame();
void process_ping();
void process_pong();
void process_text();
@@ -194,6 +191,9 @@ protected:
// prints a diagnostic message and disconnects the local interface
void handle_error(std::string msg,const boost::system::error_code& error);
// misc helpers
bool validate_app_close_status(uint16_t status);
private:
std::string get_header(const std::string& key,
const header_list& list) const;
@@ -220,7 +220,7 @@ protected:
std::string m_server_http_string;
// Mutable connection state;
status_code m_status;
uint8_t m_state;
uint16_t m_close_code;
std::string m_close_message;