completes frame reading changes.

fixes 37, fixes 32
This commit is contained in:
Peter Thorson
2011-10-10 18:35:38 -05:00
parent eec0390618
commit 1ff8d333a5
10 changed files with 375 additions and 188 deletions

View File

@@ -46,47 +46,114 @@ uint8_t frame::get_state() const {
return m_state;
}
uint64_t frame::get_bytes_needed() const {
return m_bytes_needed;
}
void frame::reset() {
m_state = STATE_BASIC_HEADER;
m_bytes_needed = BASIC_HEADER_LENGTH;
m_degraded = false;
m_payload.empty();
memset(m_header,0,MAX_HEADER_LENGTH);
}
// Method invariant: One of the following must always be true even in the case
// of exceptions.
// - m_bytes_needed > 0
// - m-state = STATE_READY
void frame::consume(std::istream &s) {
if (m_state == STATE_BASIC_HEADER) {
s.read(&m_header[BASIC_HEADER_LENGTH-m_bytes_needed],m_bytes_needed);
try {
switch (m_state) {
case STATE_BASIC_HEADER:
s.read(&m_header[BASIC_HEADER_LENGTH-m_bytes_needed],m_bytes_needed);
m_bytes_needed -= s.gcount();
if (m_bytes_needed == 0) {
process_basic_header();
// basic header validation
validate_basic_header();
if (m_bytes_needed > 0) {
m_state = STATE_EXTENDED_HEADER;
} else {
m_state = STATE_PAYLOAD;
}
m_bytes_needed -= s.gcount();
if (m_bytes_needed == 0) {
process_basic_header();
validate_basic_header();
if (m_bytes_needed > 0) {
m_state = STATE_EXTENDED_HEADER;
} else {
process_extended_header();
if (m_bytes_needed == 0) {
m_state = STATE_READY;
process_payload();
} else {
m_state = STATE_PAYLOAD;
}
}
}
break;
case STATE_EXTENDED_HEADER:
s.read(&m_header[get_header_len()-m_bytes_needed],m_bytes_needed);
m_bytes_needed -= s.gcount();
if (m_bytes_needed == 0) {
process_extended_header();
if (m_bytes_needed == 0) {
m_state = STATE_READY;
process_payload();
} else {
m_state = STATE_PAYLOAD;
}
}
break;
case STATE_PAYLOAD:
s.read(reinterpret_cast<char *>(&m_payload[m_payload.size()-m_bytes_needed]),
m_bytes_needed);
m_bytes_needed -= s.gcount();
if (m_bytes_needed == 0) {
m_state = STATE_READY;
process_payload();
}
break;
case STATE_RECOVERY:
// Recovery state discards all bytes that are not the first byte
// of a close frame.
do {
s.read(reinterpret_cast<char *>(&m_header[0]),1);
//std::cout << std::hex << int(static_cast<unsigned char>(m_header[0])) << " ";
if (int(static_cast<unsigned char>(m_header[0])) == 0x88) {
//(BPB0_FIN && CONNECTION_CLOSE)
m_bytes_needed--;
m_state = STATE_BASIC_HEADER;
break;
}
} while (s.gcount() > 0);
//std::cout << std::endl;
break;
default:
break;
}
} else if (m_state == STATE_EXTENDED_HEADER) {
s.read(&m_header[BASIC_HEADER_LENGTH+get_header_len()-m_bytes_needed],m_bytes_needed);
m_bytes_needed -= s.gcount();
if (m_bytes_needed == 0) {
process_extended_header();
m_state = STATE_PAYLOAD;
}
} else if (m_state == STATE_PAYLOAD) {
s.read(reinterpret_cast<char *>(&m_payload[0]),m_bytes_needed);
m_bytes_needed -= s.gcount();
if (m_bytes_needed == 0) {
process_payload();
m_state = STATE_READY;
/*if (s.gcount() == 0) {
throw frame_error("consume read zero bytes",FERR_FATAL_SESSION_ERROR);
}*/
} catch (const frame_error& e) {
// After this point all non-close frames must be considered garbage,
// including the current one. Reset it and put the reading frame into
// a recovery state.
if (m_degraded == true) {
throw frame_error("An error occurred while trying to gracefully recover from a less serious frame error.",FERR_FATAL_SESSION_ERROR);
} else {
reset();
m_state = STATE_RECOVERY;
m_degraded = true;
throw e;
}
}
}
@@ -116,8 +183,8 @@ unsigned int frame::get_header_len() const {
}
char* frame::get_masking_key() {
if (m_extended_header_bytes_needed > 0) {
throw "attempted to get masking_key before reading full header";
if (m_state != STATE_READY) {
throw frame_error("attempted to get masking_key before reading full header");
}
return m_masking_key;
}
@@ -178,12 +245,12 @@ frame::opcode frame::get_opcode() const {
void frame::set_opcode(frame::opcode op) {
if (op > 0x0F) {
throw "invalid opcode";
throw frame_error("invalid opcode",FERR_PROTOCOL_VIOLATION);
}
if (get_basic_size() > BASIC_PAYLOAD_LIMIT &&
is_control()) {
throw "control frames can't have large payloads";
throw frame_error("control frames can't have large payloads",FERR_PROTOCOL_VIOLATION);
}
m_header[0] &= (0xFF ^ BPB0_OPCODE); // clear op bits
@@ -209,9 +276,9 @@ uint8_t frame::get_basic_size() const {
}
size_t frame::get_payload_size() const {
if (m_extended_header_bytes_needed > 0) {
if (m_state != STATE_READY && m_state != STATE_PAYLOAD) {
// problem
throw "attempted to get payload size before reading full header";
throw frame_error("attempted to get payload size before reading full header");
}
return m_payload.size();
@@ -263,13 +330,13 @@ bool frame::is_control() const {
void frame::set_payload_helper(size_t s) {
if (s > max_payload_size) {
throw "requested payload is over implimentation defined limit";
throw frame_error("requested payload is over implimentation defined limit",FERR_MSG_TOO_BIG);
}
// limits imposed by the websocket spec
if (s > BASIC_PAYLOAD_LIMIT &&
get_opcode() > MAX_FRAME_OPCODE) {
throw "control frames can't have large payloads";
throw frame_error("control frames can't have large payloads",FERR_PROTOCOL_VIOLATION);
}
if (s <= BASIC_PAYLOAD_LIMIT) {
@@ -285,7 +352,7 @@ void frame::set_payload_helper(size_t s) {
m_header[1] = BASIC_PAYLOAD_64BIT_CODE;
*reinterpret_cast<uint64_t*>(&m_header[BASIC_HEADER_LENGTH]) = htonll(s);
} else {
throw "payload size limit is 63 bits";
throw frame_error("payload size limit is 63 bits",FERR_PROTOCOL_VIOLATION);
}
m_payload.resize(s);
@@ -294,11 +361,11 @@ void frame::set_payload_helper(size_t s) {
void frame::set_status(uint16_t status,const std::string message) {
// check for valid statuses
if (status < 1000 || status > 4999) {
throw server_error("Status codes must be in the range 1000-4999");
throw frame_error("Status codes must be in the range 1000-4999");
}
if (status == 1005 || status == 1006) {
throw server_error("Status codes 1005 and 1006 are reserved for internal use and cannot be written to a frame.");
throw frame_error("Status codes 1005 and 1006 are reserved for internal use and cannot be written to a frame.");
}
m_payload.resize(2+message.size());
@@ -326,9 +393,13 @@ std::string frame::print_frame() const {
f << std::hex << (unsigned short)m_header[i] << " ";
}
// print message
std::vector<unsigned char>::const_iterator it;
for (it = m_payload.begin(); it != m_payload.end(); it++) {
f << *it;
if (m_payload.size() > 50) {
f << "[payload of " << m_payload.size() << " bytes]";
} else {
std::vector<unsigned char>::const_iterator it;
for (it = m_payload.begin(); it != m_payload.end(); it++) {
f << *it;
}
}
return f.str();
}
@@ -352,7 +423,10 @@ void frame::process_extended_header() {
));
if (payload_size < s) {
throw frame_error("payload length not minimally encoded",
std::stringstream err;
err << "payload length not minimally encoded. Using 16 bit form for payload size: " << payload_size;
m_bytes_needed = payload_size;
throw frame_error(err.str(),
FERR_PROTOCOL_VIOLATION);
}
@@ -365,6 +439,7 @@ void frame::process_extended_header() {
));
if (payload_size <= PAYLOAD_16BIT_LIMIT) {
m_bytes_needed = payload_size;
throw frame_error("payload length not minimally encoded",
FERR_PROTOCOL_VIOLATION);
}
@@ -471,9 +546,7 @@ void frame::generate_masking_key() {
//throw "masking key generation not implimented";
int32_t key = m_gen();
std::cout << "genkey: " << key << std::endl;
//m_masking_key[0] = reinterpret_cast<char*>(&key)[0];
//m_masking_key[1] = reinterpret_cast<char*>(&key)[1];
//m_masking_key[2] = reinterpret_cast<char*>(&key)[2];