From 689c1362984e47e04f88853e2c9f8b23b2d796f1 Mon Sep 17 00:00:00 2001 From: Peter Thorson Date: Tue, 6 Dec 2011 22:14:30 -0600 Subject: [PATCH] policy refactor echo client passes all autobahn tests --- examples/echo_client/echo_client.cpp | 32 ++++++++++------------------ src/connection.hpp | 17 ++++++++------- src/http/parser.hpp | 4 +++- src/processors/hybi.hpp | 11 ++++++++++ src/roles/client.hpp | 32 ++++++++++++++++++++++------ src/roles/server.hpp | 10 ++++----- src/uri.cpp | 11 ++++++++++ src/uri.hpp | 1 + src/websocket_frame.hpp | 23 +++++++++++++++----- 9 files changed, 95 insertions(+), 46 deletions(-) diff --git a/examples/echo_client/echo_client.cpp b/examples/echo_client/echo_client.cpp index 3881127f6b..0839d5b2b2 100644 --- a/examples/echo_client/echo_client.cpp +++ b/examples/echo_client/echo_client.cpp @@ -39,24 +39,7 @@ public: typedef echo_client_handler type; typedef plain_endpoint_type::connection_ptr connection_ptr; - void validate(connection_ptr connection) { - //std::cout << "state: " << connection->get_state() << std::endl; - } - - void on_open(connection_ptr connection) { - //std::cout << "connection opened" << std::endl; - } - - void on_close(connection_ptr connection) { - //std::cout << "connection closed" << std::endl; - } - - void on_message(connection_ptr connection,websocketpp::message::data_ptr msg) { - //std::cout << "got message: " << *msg << std::endl; - connection->send(msg->get_payload(),(msg->get_opcode() == websocketpp::frame::opcode::BINARY)); - - - + void on_message(connection_ptr connection,websocketpp::message::data_ptr msg) { if (connection->get_resource() == "/getCaseCount") { std::cout << "Detected " << msg->get_payload() << " test cases." << std::endl; m_case_count = atoi(msg->get_payload().c_str()); @@ -64,7 +47,6 @@ public: connection->send(msg->get_payload(),(msg->get_opcode() == websocketpp::frame::opcode::BINARY)); } - connection->recycle(msg); } @@ -94,9 +76,13 @@ int main(int argc, char* argv[]) { connection_ptr connection; plain_endpoint_type endpoint(handler); + endpoint.alog().unset_level(websocketpp::log::alevel::ALL); + + endpoint.elog().unset_level(websocketpp::log::elevel::ALL); + connection = endpoint.connect("ws://localhost:9001/getCaseCount"); - + endpoint.run(); /*boost::asio::io_service io_service; @@ -112,11 +98,15 @@ int main(int argc, char* argv[]) { std::cout << "case count: " << boost::dynamic_pointer_cast(handler)->m_case_count << std::endl; for (int i = 1; i <= boost::dynamic_pointer_cast(handler)->m_case_count; i++) { + endpoint.reset(); + std::stringstream url; - url << "ws://localhost:9001/runCase?case=" << i << "&agent=\"WebSocket++Snapshot/2011-10-27\""; + url << "ws://localhost:9001/runCase?case=" << i << "&agent=\"WebSocket++Snapshot/2011-12-06\""; connection = endpoint.connect(url.str()); + + endpoint.run(); } std::cout << "done" << std::endl; diff --git a/src/connection.hpp b/src/connection.hpp index e894e7ada6..7740da9754 100644 --- a/src/connection.hpp +++ b/src/connection.hpp @@ -139,9 +139,9 @@ public: void send(const utf8_string& payload,bool binary = false) { binary_string_ptr msg; if (binary) { - msg = m_processor->prepare_frame(frame::opcode::BINARY,false,payload); + msg = m_processor->prepare_frame(frame::opcode::BINARY,!m_endpoint.is_server(),payload); } else { - msg = m_processor->prepare_frame(frame::opcode::TEXT,false,payload); + msg = m_processor->prepare_frame(frame::opcode::TEXT,!m_endpoint.is_server(),payload); } // TODO: decide which of these to use. Direct function call better @@ -159,7 +159,7 @@ public: } void send(const binary_string& data) { binary_string_ptr msg(m_processor->prepare_frame(frame::opcode::BINARY, - false,data)); + !m_endpoint.is_server(),data)); m_endpoint.endpoint_base::m_io_service.post( boost::bind( &type::write_message, @@ -171,7 +171,8 @@ public: } void ping(const binary_string& payload) { binary_string_ptr msg(m_processor->prepare_frame(frame::opcode::PING, - false,payload)); + !m_endpoint.is_server(), + payload)); m_endpoint.endpoint_base::m_io_service.post( boost::bind( @@ -181,7 +182,7 @@ public: } void pong(const binary_string& payload) { binary_string_ptr msg(m_processor->prepare_frame(frame::opcode::PONG, - false,payload)); + !m_endpoint.is_server(),payload)); m_endpoint.endpoint_base::m_io_service.post( boost::bind( &type::write_message, @@ -380,7 +381,7 @@ protected: if (response) { // send response ping - write_message(m_processor->prepare_frame(frame::opcode::PONG,false,msg->get_payload())); + write_message(m_processor->prepare_frame(frame::opcode::PONG,!m_endpoint.is_server(),msg->get_payload())); } break; case frame::opcode::PONG: @@ -454,7 +455,7 @@ protected: write_message(m_processor->prepare_close_frame(m_local_close_code, - false, + !m_endpoint.is_server(), m_local_close_reason)); m_write_state = INTURRUPT; } @@ -491,7 +492,7 @@ protected: write_message(m_processor->prepare_close_frame(m_local_close_code, - false, + !m_endpoint.is_server(), m_local_close_reason)); m_write_state = INTURRUPT; } diff --git a/src/http/parser.hpp b/src/http/parser.hpp index 1fb26de2a4..6ab143e00a 100644 --- a/src/http/parser.hpp +++ b/src/http/parser.hpp @@ -213,12 +213,14 @@ public: std::stringstream ss(response); std::string str_val; int int_val; + char char_val[256]; ss >> str_val; set_version(str_val); ss >> int_val; - set_status(status_code::value(int_val),str_val); + ss.getline(char_val,256); + set_status(status_code::value(int_val),std::string(char_val)); } else { return false; } diff --git a/src/processors/hybi.hpp b/src/processors/hybi.hpp index fcf5cfe0e3..1ef8cec54b 100644 --- a/src/processors/hybi.hpp +++ b/src/processors/hybi.hpp @@ -356,14 +356,21 @@ public: // TODO: utf8 validation on payload. + + binary_string_ptr response(new binary_string(0)); + m_write_frame.reset(); m_write_frame.set_opcode(opcode); m_write_frame.set_masked(mask); + m_write_frame.set_fin(true); m_write_frame.set_payload(payload); + + m_write_frame.process_payload(); + // TODO response->resize(m_write_frame.get_header_len()+m_write_frame.get_payload().size()); @@ -395,6 +402,8 @@ public: m_write_frame.set_fin(true); m_write_frame.set_payload(payload); + m_write_frame.process_payload(); + // TODO response->resize(m_write_frame.get_header_len()+m_write_frame.get_payload().size()); @@ -418,6 +427,8 @@ public: m_write_frame.set_fin(true); m_write_frame.set_status(code,reason); + m_write_frame.process_payload(); + // TODO response->resize(m_write_frame.get_header_len()+m_write_frame.get_payload().size()); diff --git a/src/roles/client.hpp b/src/roles/client.hpp index e116b7e301..135ab060e5 100644 --- a/src/roles/client.hpp +++ b/src/roles/client.hpp @@ -83,30 +83,36 @@ public: uint16_t get_port() const { return m_uri->get_port(); } + + int32_t rand() { + return m_endpoint.rand(); + } protected: connection(endpoint& e) : m_endpoint(e), m_connection(static_cast< connection_type& >(*this)), // TODO: version shouldn't be hardcoded - m_version(17) {} + m_version(13) {} void set_uri(uri_ptr u) { m_uri = u; } void async_init() { + m_connection.m_processor = processor::ptr(new processor::hybi(m_connection)); + write_request(); } - int32_t rand() { - return m_endpoint.rand(); - } + void write_request(); void handle_write_request(const boost::system::error_code& error); void read_response(); void handle_read_response(const boost::system::error_code& error, std::size_t bytes_transferred); + + void log_open_result(); private: endpoint& m_endpoint; connection_type& m_connection; @@ -265,7 +271,7 @@ void client::connection::write_request() { m_request.add_header("Connection","Upgrade"); m_request.replace_header("Sec-WebSocket-Version","13"); - m_request.replace_header("Host",m_uri->get_host()); + m_request.replace_header("Host",m_uri->get_host_port()); if (m_origin != "") { m_request.replace_header("Origin",m_origin); @@ -418,7 +424,7 @@ void client::connection::handle_read_response ( } catch (const http::exception& e) { m_endpoint.elog().at(log::elevel::ERROR) << "Error processing server handshake. Server HTTP response: " - << e.m_error_msg << "(" << e.m_error_code + << e.m_error_msg << " (" << e.m_error_code << ") Local error: " << e.what() << log::endl; return; } @@ -427,6 +433,20 @@ void client::connection::handle_read_response ( // start session loop } +template +template +void client::connection::log_open_result() { + /*std::stringstream version; + version << "v" << m_version << " "; + + m_endpoint.alog().at(log::alevel::CONNECT) << (m_version == -1 ? "HTTP" : "WebSocket") << " Connection " + << m_connection.get_raw_socket().remote_endpoint() << " " + << (m_version == -1 ? "" : version.str()) + << (get_request_header("User-Agent") == "" ? "NULL" : get_request_header("User-Agent")) + << " " << m_uri->get_resource() << " " << m_response.get_status_code() + << log::endl;*/ +} + } // namespace role } // namespace websocketpp diff --git a/src/roles/server.hpp b/src/roles/server.hpp index f41c1dfa3e..c69ecb487d 100644 --- a/src/roles/server.hpp +++ b/src/roles/server.hpp @@ -419,7 +419,7 @@ template void server::connection::write_response() { m_response.set_version("HTTP/1.1"); - if (m_response.status_code() == http::status_code::SWITCHING_PROTOCOLS) { + if (m_response.get_status_code() == http::status_code::SWITCHING_PROTOCOLS) { // websocket response m_connection.m_processor->handshake_response(m_request,m_response); @@ -470,7 +470,7 @@ void server::connection::handle_write_response( log_open_result(); - if (m_response.status_code() != http::status_code::SWITCHING_PROTOCOLS) { + if (m_response.get_status_code() != http::status_code::SWITCHING_PROTOCOLS) { if (m_version == -1) { // if this was not a websocket connection, we have written // the expected response and the connection can be closed. @@ -478,8 +478,8 @@ void server::connection::handle_write_response( // this was a websocket connection that ended in an error m_endpoint.elog().at(log::elevel::ERROR) << "Handshake ended with HTTP error: " - << m_response.status_code() << " " - << m_response.status_msg() << log::endl; + << m_response.get_status_code() << " " + << m_response.get_status_msg() << log::endl; } m_connection.terminate(true); return; @@ -502,7 +502,7 @@ void server::connection::log_open_result() { << m_connection.get_raw_socket().remote_endpoint() << " " << (m_version == -1 ? "" : version.str()) << (get_request_header("User-Agent") == "" ? "NULL" : get_request_header("User-Agent")) - << " " << m_uri->get_resource() << " " << m_response.status_code() + << " " << m_uri->get_resource() << " " << m_response.get_status_code() << log::endl; } diff --git a/src/uri.cpp b/src/uri.cpp index e3c8a0ec32..493b2bafb3 100644 --- a/src/uri.cpp +++ b/src/uri.cpp @@ -118,6 +118,17 @@ std::string uri::get_host() const { return m_host; } +std::string uri::get_host_port() const { + if (m_port == (m_secure ? DEFAULT_SECURE_PORT : DEFAULT_PORT)) { + return m_host; + } else { + std::stringstream p; + p << m_host << ":" << m_port; + return p.str(); + } + +} + uint16_t uri::get_port() const { return m_port; } diff --git a/src/uri.hpp b/src/uri.hpp index 935c65cc99..d68916ee5f 100644 --- a/src/uri.hpp +++ b/src/uri.hpp @@ -64,6 +64,7 @@ public: bool get_secure() const; std::string get_host() const; + std::string get_host_port() const; uint16_t get_port() const; std::string get_port_str() const; std::string get_resource() const; diff --git a/src/websocket_frame.hpp b/src/websocket_frame.hpp index 0209514f39..a924324d09 100644 --- a/src/websocket_frame.hpp +++ b/src/websocket_frame.hpp @@ -277,10 +277,11 @@ public: } char* get_masking_key() { - if (m_state != STATE_READY) { + /*if (m_state != STATE_READY) { + std::cout << "throw" << std::endl; throw processor::exception("attempted to get masking_key before reading full header"); - } - return m_header[get_header_len()-4]; + }*/ + return &m_header[get_header_len()-4]; } // get and set header bits @@ -427,6 +428,8 @@ public: throw processor::exception("control frames can't have large payloads",processor::error::PROTOCOL_VIOLATION); } + bool masked = get_masked(); + if (s <= limits::PAYLOAD_SIZE_BASIC) { m_header[1] = s; } else if (s <= limits::PAYLOAD_SIZE_EXTENDED) { @@ -443,6 +446,10 @@ public: throw processor::exception("payload size limit is 63 bits",processor::error::PROTOCOL_VIOLATION); } + if (masked) { + m_header[1] |= BPB1_MASK; + } + m_payload.resize(s); } @@ -466,8 +473,14 @@ public: *reinterpret_cast(&val[0]) = htons(status); + bool masked = get_masked(); + m_header[1] = message.size()+2; + if (masked) { + m_header[1] |= BPB1_MASK; + } + m_payload[0] = val[0]; m_payload[1] = val[1]; @@ -570,7 +583,7 @@ public: void process_payload() { if (get_masked()) { - char *masking_key = &m_header[get_header_len()-4]; + char *masking_key = get_masking_key(); for (uint64_t i = 0; i < m_payload.size(); i++) { m_payload[i] = (m_payload[i] ^ masking_key[i%4]); @@ -638,7 +651,7 @@ public: } void generate_masking_key() { - *(reinterpret_cast(&m_header[get_header_len()-4])) = m_rng.gen(); + *(reinterpret_cast(&m_header[get_header_len()-4])) = m_rng.rand(); } void clear_masking_key() { // this is a no-op as clearing the mask bit also changes the get_header_len