policy refactor echo client passes all autobahn tests

This commit is contained in:
Peter Thorson
2011-12-06 22:14:30 -06:00
parent 0f5302d400
commit 689c136298
9 changed files with 95 additions and 46 deletions

View File

@@ -39,24 +39,7 @@ public:
typedef echo_client_handler type;
typedef plain_endpoint_type::connection_ptr connection_ptr;
void validate(connection_ptr connection) {
//std::cout << "state: " << connection->get_state() << std::endl;
}
void on_open(connection_ptr connection) {
//std::cout << "connection opened" << std::endl;
}
void on_close(connection_ptr connection) {
//std::cout << "connection closed" << std::endl;
}
void on_message(connection_ptr connection,websocketpp::message::data_ptr msg) {
//std::cout << "got message: " << *msg << std::endl;
connection->send(msg->get_payload(),(msg->get_opcode() == websocketpp::frame::opcode::BINARY));
void on_message(connection_ptr connection,websocketpp::message::data_ptr msg) {
if (connection->get_resource() == "/getCaseCount") {
std::cout << "Detected " << msg->get_payload() << " test cases." << std::endl;
m_case_count = atoi(msg->get_payload().c_str());
@@ -64,7 +47,6 @@ public:
connection->send(msg->get_payload(),(msg->get_opcode() == websocketpp::frame::opcode::BINARY));
}
connection->recycle(msg);
}
@@ -94,9 +76,13 @@ int main(int argc, char* argv[]) {
connection_ptr connection;
plain_endpoint_type endpoint(handler);
endpoint.alog().unset_level(websocketpp::log::alevel::ALL);
endpoint.elog().unset_level(websocketpp::log::elevel::ALL);
connection = endpoint.connect("ws://localhost:9001/getCaseCount");
endpoint.run();
/*boost::asio::io_service io_service;
@@ -112,11 +98,15 @@ int main(int argc, char* argv[]) {
std::cout << "case count: " << boost::dynamic_pointer_cast<echo_client_handler>(handler)->m_case_count << std::endl;
for (int i = 1; i <= boost::dynamic_pointer_cast<echo_client_handler>(handler)->m_case_count; i++) {
endpoint.reset();
std::stringstream url;
url << "ws://localhost:9001/runCase?case=" << i << "&agent=\"WebSocket++Snapshot/2011-10-27\"";
url << "ws://localhost:9001/runCase?case=" << i << "&agent=\"WebSocket++Snapshot/2011-12-06\"";
connection = endpoint.connect(url.str());
endpoint.run();
}
std::cout << "done" << std::endl;

View File

@@ -139,9 +139,9 @@ public:
void send(const utf8_string& payload,bool binary = false) {
binary_string_ptr msg;
if (binary) {
msg = m_processor->prepare_frame(frame::opcode::BINARY,false,payload);
msg = m_processor->prepare_frame(frame::opcode::BINARY,!m_endpoint.is_server(),payload);
} else {
msg = m_processor->prepare_frame(frame::opcode::TEXT,false,payload);
msg = m_processor->prepare_frame(frame::opcode::TEXT,!m_endpoint.is_server(),payload);
}
// TODO: decide which of these to use. Direct function call better
@@ -159,7 +159,7 @@ public:
}
void send(const binary_string& data) {
binary_string_ptr msg(m_processor->prepare_frame(frame::opcode::BINARY,
false,data));
!m_endpoint.is_server(),data));
m_endpoint.endpoint_base::m_io_service.post(
boost::bind(
&type::write_message,
@@ -171,7 +171,8 @@ public:
}
void ping(const binary_string& payload) {
binary_string_ptr msg(m_processor->prepare_frame(frame::opcode::PING,
false,payload));
!m_endpoint.is_server(),
payload));
m_endpoint.endpoint_base::m_io_service.post(
boost::bind(
@@ -181,7 +182,7 @@ public:
}
void pong(const binary_string& payload) {
binary_string_ptr msg(m_processor->prepare_frame(frame::opcode::PONG,
false,payload));
!m_endpoint.is_server(),payload));
m_endpoint.endpoint_base::m_io_service.post(
boost::bind(
&type::write_message,
@@ -380,7 +381,7 @@ protected:
if (response) {
// send response ping
write_message(m_processor->prepare_frame(frame::opcode::PONG,false,msg->get_payload()));
write_message(m_processor->prepare_frame(frame::opcode::PONG,!m_endpoint.is_server(),msg->get_payload()));
}
break;
case frame::opcode::PONG:
@@ -454,7 +455,7 @@ protected:
write_message(m_processor->prepare_close_frame(m_local_close_code,
false,
!m_endpoint.is_server(),
m_local_close_reason));
m_write_state = INTURRUPT;
}
@@ -491,7 +492,7 @@ protected:
write_message(m_processor->prepare_close_frame(m_local_close_code,
false,
!m_endpoint.is_server(),
m_local_close_reason));
m_write_state = INTURRUPT;
}

View File

@@ -213,12 +213,14 @@ public:
std::stringstream ss(response);
std::string str_val;
int int_val;
char char_val[256];
ss >> str_val;
set_version(str_val);
ss >> int_val;
set_status(status_code::value(int_val),str_val);
ss.getline(char_val,256);
set_status(status_code::value(int_val),std::string(char_val));
} else {
return false;
}

View File

@@ -356,14 +356,21 @@ public:
// TODO: utf8 validation on payload.
binary_string_ptr response(new binary_string(0));
m_write_frame.reset();
m_write_frame.set_opcode(opcode);
m_write_frame.set_masked(mask);
m_write_frame.set_fin(true);
m_write_frame.set_payload(payload);
m_write_frame.process_payload();
// TODO
response->resize(m_write_frame.get_header_len()+m_write_frame.get_payload().size());
@@ -395,6 +402,8 @@ public:
m_write_frame.set_fin(true);
m_write_frame.set_payload(payload);
m_write_frame.process_payload();
// TODO
response->resize(m_write_frame.get_header_len()+m_write_frame.get_payload().size());
@@ -418,6 +427,8 @@ public:
m_write_frame.set_fin(true);
m_write_frame.set_status(code,reason);
m_write_frame.process_payload();
// TODO
response->resize(m_write_frame.get_header_len()+m_write_frame.get_payload().size());

View File

@@ -83,30 +83,36 @@ public:
uint16_t get_port() const {
return m_uri->get_port();
}
int32_t rand() {
return m_endpoint.rand();
}
protected:
connection(endpoint& e)
: m_endpoint(e),
m_connection(static_cast< connection_type& >(*this)),
// TODO: version shouldn't be hardcoded
m_version(17) {}
m_version(13) {}
void set_uri(uri_ptr u) {
m_uri = u;
}
void async_init() {
m_connection.m_processor = processor::ptr(new processor::hybi<connection_type>(m_connection));
write_request();
}
int32_t rand() {
return m_endpoint.rand();
}
void write_request();
void handle_write_request(const boost::system::error_code& error);
void read_response();
void handle_read_response(const boost::system::error_code& error,
std::size_t bytes_transferred);
void log_open_result();
private:
endpoint& m_endpoint;
connection_type& m_connection;
@@ -265,7 +271,7 @@ void client<endpoint>::connection<connection_type>::write_request() {
m_request.add_header("Connection","Upgrade");
m_request.replace_header("Sec-WebSocket-Version","13");
m_request.replace_header("Host",m_uri->get_host());
m_request.replace_header("Host",m_uri->get_host_port());
if (m_origin != "") {
m_request.replace_header("Origin",m_origin);
@@ -418,7 +424,7 @@ void client<endpoint>::connection<connection_type>::handle_read_response (
} catch (const http::exception& e) {
m_endpoint.elog().at(log::elevel::ERROR)
<< "Error processing server handshake. Server HTTP response: "
<< e.m_error_msg << "(" << e.m_error_code
<< e.m_error_msg << " (" << e.m_error_code
<< ") Local error: " << e.what() << log::endl;
return;
}
@@ -427,6 +433,20 @@ void client<endpoint>::connection<connection_type>::handle_read_response (
// start session loop
}
template <class endpoint>
template <class connection_type>
void client<endpoint>::connection<connection_type>::log_open_result() {
/*std::stringstream version;
version << "v" << m_version << " ";
m_endpoint.alog().at(log::alevel::CONNECT) << (m_version == -1 ? "HTTP" : "WebSocket") << " Connection "
<< m_connection.get_raw_socket().remote_endpoint() << " "
<< (m_version == -1 ? "" : version.str())
<< (get_request_header("User-Agent") == "" ? "NULL" : get_request_header("User-Agent"))
<< " " << m_uri->get_resource() << " " << m_response.get_status_code()
<< log::endl;*/
}
} // namespace role
} // namespace websocketpp

View File

@@ -419,7 +419,7 @@ template <class connection_type>
void server<endpoint>::connection<connection_type>::write_response() {
m_response.set_version("HTTP/1.1");
if (m_response.status_code() == http::status_code::SWITCHING_PROTOCOLS) {
if (m_response.get_status_code() == http::status_code::SWITCHING_PROTOCOLS) {
// websocket response
m_connection.m_processor->handshake_response(m_request,m_response);
@@ -470,7 +470,7 @@ void server<endpoint>::connection<connection_type>::handle_write_response(
log_open_result();
if (m_response.status_code() != http::status_code::SWITCHING_PROTOCOLS) {
if (m_response.get_status_code() != http::status_code::SWITCHING_PROTOCOLS) {
if (m_version == -1) {
// if this was not a websocket connection, we have written
// the expected response and the connection can be closed.
@@ -478,8 +478,8 @@ void server<endpoint>::connection<connection_type>::handle_write_response(
// this was a websocket connection that ended in an error
m_endpoint.elog().at(log::elevel::ERROR)
<< "Handshake ended with HTTP error: "
<< m_response.status_code() << " "
<< m_response.status_msg() << log::endl;
<< m_response.get_status_code() << " "
<< m_response.get_status_msg() << log::endl;
}
m_connection.terminate(true);
return;
@@ -502,7 +502,7 @@ void server<endpoint>::connection<connection_type>::log_open_result() {
<< m_connection.get_raw_socket().remote_endpoint() << " "
<< (m_version == -1 ? "" : version.str())
<< (get_request_header("User-Agent") == "" ? "NULL" : get_request_header("User-Agent"))
<< " " << m_uri->get_resource() << " " << m_response.status_code()
<< " " << m_uri->get_resource() << " " << m_response.get_status_code()
<< log::endl;
}

View File

@@ -118,6 +118,17 @@ std::string uri::get_host() const {
return m_host;
}
std::string uri::get_host_port() const {
if (m_port == (m_secure ? DEFAULT_SECURE_PORT : DEFAULT_PORT)) {
return m_host;
} else {
std::stringstream p;
p << m_host << ":" << m_port;
return p.str();
}
}
uint16_t uri::get_port() const {
return m_port;
}

View File

@@ -64,6 +64,7 @@ public:
bool get_secure() const;
std::string get_host() const;
std::string get_host_port() const;
uint16_t get_port() const;
std::string get_port_str() const;
std::string get_resource() const;

View File

@@ -277,10 +277,11 @@ public:
}
char* get_masking_key() {
if (m_state != STATE_READY) {
/*if (m_state != STATE_READY) {
std::cout << "throw" << std::endl;
throw processor::exception("attempted to get masking_key before reading full header");
}
return m_header[get_header_len()-4];
}*/
return &m_header[get_header_len()-4];
}
// get and set header bits
@@ -427,6 +428,8 @@ public:
throw processor::exception("control frames can't have large payloads",processor::error::PROTOCOL_VIOLATION);
}
bool masked = get_masked();
if (s <= limits::PAYLOAD_SIZE_BASIC) {
m_header[1] = s;
} else if (s <= limits::PAYLOAD_SIZE_EXTENDED) {
@@ -443,6 +446,10 @@ public:
throw processor::exception("payload size limit is 63 bits",processor::error::PROTOCOL_VIOLATION);
}
if (masked) {
m_header[1] |= BPB1_MASK;
}
m_payload.resize(s);
}
@@ -466,8 +473,14 @@ public:
*reinterpret_cast<uint16_t*>(&val[0]) = htons(status);
bool masked = get_masked();
m_header[1] = message.size()+2;
if (masked) {
m_header[1] |= BPB1_MASK;
}
m_payload[0] = val[0];
m_payload[1] = val[1];
@@ -570,7 +583,7 @@ public:
void process_payload() {
if (get_masked()) {
char *masking_key = &m_header[get_header_len()-4];
char *masking_key = get_masking_key();
for (uint64_t i = 0; i < m_payload.size(); i++) {
m_payload[i] = (m_payload[i] ^ masking_key[i%4]);
@@ -638,7 +651,7 @@ public:
}
void generate_masking_key() {
*(reinterpret_cast<int32_t *>(&m_header[get_header_len()-4])) = m_rng.gen();
*(reinterpret_cast<int32_t *>(&m_header[get_header_len()-4])) = m_rng.rand();
}
void clear_masking_key() {
// this is a no-op as clearing the mask bit also changes the get_header_len