re-adds utf8 validation, passes all autobahn tests except edge close cases (reading invalid wire codes) starts working on chat server example update

This commit is contained in:
Peter Thorson
2011-11-11 16:21:38 -06:00
parent 0767a6ef18
commit 74aa325591
8 changed files with 165 additions and 53 deletions

View File

@@ -30,13 +30,13 @@
#include <boost/algorithm/string/replace.hpp>
using websocketchat::chat_server_handler;
using websocketpp::session_ptr;
using websocketpp::session::server_session_ptr;
void chat_server_handler::validate(session_ptr client) {
void chat_server_handler::validate(server_session_ptr session) {
std::stringstream err;
// We only know about the chat resource
if (client->get_resource() != "/chat") {
if (session->get_uri().resource != "/chat") {
err << "Request for unknown resource " << client->get_resource();
throw(websocketpp::handshake_error(err.str(),404));
}

View File

@@ -38,43 +38,42 @@
// {"type":"msg","sender":"<sender>","value":"<msg>" }
// {"type":"participants","value":[<participant>,...]}
#include <boost/shared_ptr.hpp>
#include "../../src/websocketpp.hpp"
#include "../../src/websocket_connection_handler.hpp"
#include "../../src/interfaces/session.hpp"
#include <boost/shared_ptr.hpp>
#include <map>
#include <string>
#include <vector>
using websocketpp::session::server_session_ptr;
using websocketpp::session::server_handler;
namespace websocketchat {
class chat_server_handler : public websocketpp::connection_handler {
class chat_server_handler : public server_handler {
public:
chat_server_handler() {}
virtual ~chat_server_handler() {}
void validate(websocketpp::session_ptr client);
void validate(server_session_ptr session);
// add new connection to the lobby
void on_open(websocketpp::session_ptr client);
void on_open(server_session_ptr session);
// someone disconnected from the lobby, remove them
void on_close(websocketpp::session_ptr client);
void on_close(server_session_ptr session);
void on_message(websocketpp::session_ptr client,const std::string &msg);
void on_message(server_session_ptr session,websocketpp::utf8_string_ptr msg);
// lobby will ignore binary messages
void on_message(websocketpp::session_ptr client,
const std::vector<unsigned char> &data) {}
void on_message(server_session_ptr session,websocketpp::binary_string_ptr data) {}
private:
std::string serialize_state();
std::string encode_message(std::string sender,std::string msg,bool escape = true);
std::string get_con_id(websocketpp::session_ptr s);
std::string get_con_id(server_session_ptr s);
void send_to_all(std::string data);
// list of outstanding connections
std::map<websocketpp::session_ptr,std::string> m_connections;
std::map<server_session_ptr,std::string> m_connections;
};
typedef boost::shared_ptr<chat_server_handler> chat_server_handler_ptr;

View File

@@ -40,7 +40,7 @@ using websocketpp::session::server_handler;
namespace websocketecho {
class echo_server_handler : public server_handler {
class echo_server_handler : public server_handler {
public:
// The echo server allows all domains is protocol free.
void validate(server_ptr session) {}

View File

@@ -219,7 +219,7 @@ public:
binary_string_ptr prepare_close_frame(close::status::value code,
bool mask,
const std::string& reason) const {
const std::string& reason) {
binary_string_ptr response(new binary_string(2));
(*response)[0] = 0xFF;

View File

@@ -52,9 +52,9 @@ namespace hybi_state {
template <class rng_policy>
class hybi_processor : public processor {
public:
hybi_processor(rng_policy &rng) : m_read_frame(rng),m_write_frame(rng) {
hybi_processor(rng_policy &rng) : m_fragmented_opcode(frame::opcode::CONTINUATION),m_utf8_payload(new utf8_string()),m_binary_payload(new binary_string()),m_read_frame(rng),m_write_frame(rng) {
reset();
}
}
void validate_handshake(const http::parser::request& request) const {
std::stringstream err;
@@ -159,31 +159,28 @@ public:
if (m_read_frame.ready()) {
switch (m_read_frame.get_opcode()) {
case frame::opcode::CONTINUATION:
if (m_opcode == frame::opcode::BINARY) {
extract_binary();
} else if (m_opcode == frame::opcode::TEXT) {
extract_utf8();
} else {
// can't be here
}
process_continuation();
break;
case frame::opcode::TEXT:
m_opcode = frame::opcode::TEXT;
extract_utf8();
process_text();
break;
case frame::opcode::BINARY:
m_opcode = frame::opcode::BINARY;
extract_binary();
process_binary();
break;
case frame::opcode::CLOSE:
if (!utf8_validator::validate(m_read_frame.get_close_msg())) {
throw session::exception("Invalid UTF8",session::error::PAYLOAD_VIOLATION);
}
m_opcode = frame::opcode::CLOSE;
m_close_code = m_read_frame.get_close_status();
m_close_reason = m_read_frame.get_close_msg();
break;
case frame::opcode::PING:
case frame::opcode::PONG:
m_opcode = m_read_frame.get_opcode();
extract_binary();
extract_binary(m_control_payload);
break;
default:
throw session::exception("Invalid Opcode",session::error::PROTOCOL_VIOLATION);
@@ -191,6 +188,13 @@ public:
}
if (m_read_frame.get_fin()) {
m_state = hybi_state::DONE;
if (m_opcode == frame::opcode::TEXT) {
if (!m_validator.complete()) {
m_validator.reset();
throw session::exception("Invalid UTF8",session::error::PAYLOAD_VIOLATION);
}
m_validator.reset();
}
}
m_read_frame.reset();
}
@@ -204,17 +208,57 @@ public:
}
}
void extract_binary() {
binary_string &msg = m_read_frame.get_payload();
m_binary_payload->resize(m_binary_payload->size()+msg.size());
std::copy(msg.begin(),msg.end(),m_binary_payload->end()-msg.size());
// frame type handlers:
void process_continuation() {
if (m_fragmented_opcode == frame::opcode::BINARY) {
extract_binary(m_binary_payload);
} else if (m_fragmented_opcode == frame::opcode::TEXT) {
extract_utf8(m_utf8_payload);
} else if (m_fragmented_opcode == frame::opcode::CONTINUATION) {
// got continuation frame without a message to continue.
throw session::exception("No message to continue.",session::error::PROTOCOL_VIOLATION);
} else {
// can't be here
}
if (m_read_frame.get_fin()) {
m_opcode = m_fragmented_opcode;
}
}
void extract_utf8() {
void process_text() {
if (m_fragmented_opcode != frame::opcode::CONTINUATION) {
throw session::exception("New message started without closing previous.",session::error::PROTOCOL_VIOLATION);
}
extract_utf8(m_utf8_payload);
m_opcode = frame::opcode::TEXT;
m_fragmented_opcode = frame::opcode::TEXT;
}
void process_binary() {
if (m_fragmented_opcode != frame::opcode::CONTINUATION) {
throw session::exception("New message started without closing previous.",session::error::PROTOCOL_VIOLATION);
}
m_opcode = frame::opcode::BINARY;
m_fragmented_opcode = frame::opcode::BINARY;
extract_binary(m_binary_payload);
}
void extract_binary(binary_string_ptr dest) {
binary_string &msg = m_read_frame.get_payload();
dest->resize(dest->size() + msg.size());
std::copy(msg.begin(),msg.end(),dest->end() - msg.size());
}
void extract_utf8(utf8_string_ptr dest) {
binary_string &msg = m_read_frame.get_payload();
m_utf8_payload->reserve(m_utf8_payload->size()+msg.size());
m_utf8_payload->append(msg.begin(),msg.end());
if (!m_validator.decode(msg.begin(),msg.end())) {
throw session::exception("Invalid UTF8",session::error::PAYLOAD_VIOLATION);
}
dest->reserve(dest->size() + msg.size());
dest->append(msg.begin(),msg.end());
}
bool ready() const {
@@ -223,8 +267,13 @@ public:
void reset() {
m_state = m_state = hybi_state::INIT;
m_utf8_payload = utf8_string_ptr(new utf8_string());
m_binary_payload = binary_string_ptr(new binary_string());
m_control_payload = binary_string_ptr(new binary_string());
if (m_fragmented_opcode == m_opcode) {
m_utf8_payload = utf8_string_ptr(new utf8_string());
m_binary_payload = binary_string_ptr(new binary_string());
m_fragmented_opcode = frame::opcode::CONTINUATION;
}
}
uint64_t get_bytes_needed() const {
@@ -251,18 +300,18 @@ public:
}
binary_string_ptr get_binary_payload() const {
// TODO: opcode doesn't match
if (get_opcode() != frame::opcode::BINARY &&
get_opcode() != frame::opcode::PING &&
get_opcode() != frame::opcode::PONG) {
throw "opcode doesn't have a binary payload";
}
if (!ready()) {
throw "not ready";
}
return m_binary_payload;
if (get_opcode() == frame::opcode::BINARY) {
return m_binary_payload;
} else if (get_opcode() != frame::opcode::PING ||
get_opcode() != frame::opcode::PONG) {
return m_control_payload;
} else {
throw "opcode doesn't have a binary payload";
}
}
// legacy hybi doesn't have close codes
@@ -345,10 +394,23 @@ public:
binary_string_ptr prepare_close_frame(close::status::value code,
bool mask,
const std::string& reason) const {
const std::string& reason) {
binary_string_ptr response(new binary_string(0));
m_write_frame.reset();
m_write_frame.set_opcode(frame::opcode::CLOSE);
m_write_frame.set_masked(mask);
m_write_frame.set_fin(true);
m_write_frame.set_status(code,reason);
// TODO
response->resize(m_write_frame.get_header_len()+m_write_frame.get_payload().size());
// copy header
std::copy(m_write_frame.get_header(),m_write_frame.get_header()+m_write_frame.get_header_len(),response->begin());
// copy payload
std::copy(m_write_frame.get_payload().begin(),m_write_frame.get_payload().end(),response->begin()+m_write_frame.get_header_len());
return response;
}
@@ -356,13 +418,17 @@ public:
private:
int m_state;
frame::opcode::value m_opcode;
frame::opcode::value m_fragmented_opcode;
utf8_string_ptr m_utf8_payload;
binary_string_ptr m_binary_payload;
binary_string_ptr m_control_payload;
close::status::value m_close_code;
std::string m_close_reason;
utf8_validator::validator m_validator;
frame::parser<rng_policy> m_read_frame;
frame::parser<rng_policy> m_write_frame;
};

View File

@@ -77,7 +77,7 @@ public:
virtual binary_string_ptr prepare_close_frame(close::status::value code,
bool mask,
const std::string& reason) const = 0;
const std::string& reason) = 0;

View File

@@ -117,7 +117,8 @@ public:
virtual bool get_closed_by_me() const = 0;
};
typedef boost::shared_ptr<server> server_ptr;
typedef server_ptr server_session_ptr;
/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *
* Server Handler API *
* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */

View File

@@ -39,7 +39,53 @@ decode(uint32_t* state, uint32_t* codep, uint32_t byte) {
*state = utf8d[256 + *state*16 + type];
return *state;
}
class validator {
public:
validator() : m_state(UTF8_ACCEPT),m_codepoint(0) {}
bool consume (uint32_t byte) {
if (utf8_validator::decode(&m_state,&m_codepoint,byte) == UTF8_REJECT) {
return false;
}
return true;
}
template <typename iterator_type>
bool decode (iterator_type b, iterator_type e) {
for (iterator_type i = b; i != e; i++) {
if (utf8_validator::decode(&m_state,&m_codepoint,*i) == UTF8_REJECT) {
return false;
}
}
return true;
}
bool complete() {
return m_state == UTF8_ACCEPT;
}
void reset() {
m_state = UTF8_ACCEPT;
m_codepoint = 0;
}
private:
uint32_t m_state;
uint32_t m_codepoint;
};
// convenience function that creates a validator, validates a complete string
// and returns the result.
bool validate(const std::string& s) {
validator v;
if (!v.decode(s.begin(),s.end())) {
return false;
}
return v.complete();
}
} // namespace utf8_validator
#endif // UTF8_VALIDATOR_HPP