From 17fcb3f8be0a576f0d0060ea93b0e97340b1477f Mon Sep 17 00:00:00 2001 From: Peter Thorson Date: Sat, 6 Apr 2013 11:09:41 -0500 Subject: [PATCH] adds subprotocol selection to connection --- test/roles/server.cpp | 84 ++++++++++++++++++++++++---- websocketpp/connection.hpp | 25 +++++++++ websocketpp/error.hpp | 7 ++- websocketpp/impl/connection_impl.hpp | 38 +++++++++++-- 4 files changed, 138 insertions(+), 16 deletions(-) diff --git a/test/roles/server.cpp b/test/roles/server.cpp index 522be9bdcb..bf47bfba58 100644 --- a/test/roles/server.cpp +++ b/test/roles/server.cpp @@ -85,18 +85,26 @@ void echo_func(server* s, websocketpp::connection_hdl hdl, message_ptr msg) { s->send(hdl, msg->get_payload(), msg->get_opcode()); } -bool validate_func_subprotocol(server* s, std::string* out, websocketpp::connection_hdl hdl) { +bool validate_func_subprotocol(server* s, std::string* out, std::string accept, + websocketpp::connection_hdl hdl) +{ server::connection_ptr con = s->get_con_from_hdl(hdl); std::stringstream o; - o << con->get_requested_subprotocols().size(); - if (con->get_requested_subprotocols().size() == 1) { - o << "," << con->get_requested_subprotocols()[0]; + const std::vector & protocols = con->get_requested_subprotocols(); + std::vector::const_iterator it; + + for (it = protocols.begin(); it != protocols.end(); ++it) { + o << *it << ","; } *out = o.str(); + if (accept != "") { + con->select_subprotocol(accept); + } + return true; } @@ -114,7 +122,7 @@ BOOST_AUTO_TEST_CASE( basic_websocket_request ) { server s; s.set_user_agent("test"); - BOOST_CHECK(run_server_test(s,input) == output); + BOOST_CHECK_EQUAL(run_server_test(s,input), output); } BOOST_AUTO_TEST_CASE( invalid_websocket_version ) { @@ -125,7 +133,7 @@ BOOST_AUTO_TEST_CASE( invalid_websocket_version ) { s.set_user_agent("test"); //s.set_message_handler(bind(&echo_func,&s,::_1,::_2)); - BOOST_CHECK(run_server_test(s,input) == output); + BOOST_CHECK_EQUAL(run_server_test(s,input), output); } BOOST_AUTO_TEST_CASE( unimplimented_websocket_version ) { @@ -136,10 +144,10 @@ BOOST_AUTO_TEST_CASE( unimplimented_websocket_version ) { server s; s.set_user_agent("test"); - BOOST_CHECK(run_server_test(s,input) == output); + BOOST_CHECK_EQUAL(run_server_test(s,input), output); } -BOOST_AUTO_TEST_CASE( accept_subprotocol_empty ) { +BOOST_AUTO_TEST_CASE( list_subprotocol_empty ) { std::string input = "GET / HTTP/1.1\r\nHost: www.example.com\r\nConnection: upgrade\r\nUpgrade: websocket\r\nSec-WebSocket-Version: 13\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nOrigin: http://www.example.com\r\nSec-WebSocket-Protocol: foo\r\n\r\n"; std::string output = "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\nServer: test\r\nUpgrade: websocket\r\n\r\n"; @@ -154,7 +162,7 @@ BOOST_AUTO_TEST_CASE( accept_subprotocol_empty ) { BOOST_CHECK_EQUAL(subprotocol, ""); } -BOOST_AUTO_TEST_CASE( accept_subprotocol_one ) { +BOOST_AUTO_TEST_CASE( list_subprotocol_one ) { std::string input = "GET / HTTP/1.1\r\nHost: www.example.com\r\nConnection: upgrade\r\nUpgrade: websocket\r\nSec-WebSocket-Version: 13\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nOrigin: http://www.example.com\r\nSec-WebSocket-Protocol: foo\r\n\r\n"; std::string output = "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\nServer: test\r\nUpgrade: websocket\r\n\r\n"; @@ -164,14 +172,68 @@ BOOST_AUTO_TEST_CASE( accept_subprotocol_one ) { server s; s.set_user_agent("test"); - s.set_validate_handler(bind(&validate_func_subprotocol,&s,&validate,::_1)); + s.set_validate_handler(bind(&validate_func_subprotocol,&s,&validate,"",::_1)); s.set_open_handler(bind(&open_func_subprotocol,&s,&open,::_1)); BOOST_CHECK_EQUAL(run_server_test(s,input), output); - BOOST_CHECK_EQUAL(validate, "1,foo"); + BOOST_CHECK_EQUAL(validate, "foo,"); BOOST_CHECK_EQUAL(open, ""); } +BOOST_AUTO_TEST_CASE( accept_subprotocol_one ) { + std::string input = "GET / HTTP/1.1\r\nHost: www.example.com\r\nConnection: upgrade\r\nUpgrade: websocket\r\nSec-WebSocket-Version: 13\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nOrigin: http://www.example.com\r\nSec-WebSocket-Protocol: foo\r\n\r\n"; + + std::string output = "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\nSec-WebSocket-Protocol: foo\r\nServer: test\r\nUpgrade: websocket\r\n\r\n"; + + std::string validate; + std::string open; + + server s; + s.set_user_agent("test"); + s.set_validate_handler(bind(&validate_func_subprotocol,&s,&validate,"foo",::_1)); + s.set_open_handler(bind(&open_func_subprotocol,&s,&open,::_1)); + + BOOST_CHECK_EQUAL(run_server_test(s,input), output); + BOOST_CHECK_EQUAL(validate, "foo,"); + BOOST_CHECK_EQUAL(open, "foo"); +} + +BOOST_AUTO_TEST_CASE( accept_subprotocol_invalid ) { + std::string input = "GET / HTTP/1.1\r\nHost: www.example.com\r\nConnection: upgrade\r\nUpgrade: websocket\r\nSec-WebSocket-Version: 13\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nOrigin: http://www.example.com\r\nSec-WebSocket-Protocol: foo\r\n\r\n"; + + std::string output = "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\nSec-WebSocket-Protocol: foo\r\nServer: test\r\nUpgrade: websocket\r\n\r\n"; + + std::string validate; + std::string open; + + server s; + s.set_user_agent("test"); + s.set_validate_handler(bind(&validate_func_subprotocol,&s,&validate,"foo2",::_1)); + s.set_open_handler(bind(&open_func_subprotocol,&s,&open,::_1)); + + std::string o; + + BOOST_CHECK_THROW(o = run_server_test(s,input), websocketpp::lib::error_code); +} + +BOOST_AUTO_TEST_CASE( accept_subprotocol_two ) { + std::string input = "GET / HTTP/1.1\r\nHost: www.example.com\r\nConnection: upgrade\r\nUpgrade: websocket\r\nSec-WebSocket-Version: 13\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nOrigin: http://www.example.com\r\nSec-WebSocket-Protocol: foo, bar\r\n\r\n"; + + std::string output = "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\nSec-WebSocket-Protocol: bar\r\nServer: test\r\nUpgrade: websocket\r\n\r\n"; + + std::string validate; + std::string open; + + server s; + s.set_user_agent("test"); + s.set_validate_handler(bind(&validate_func_subprotocol,&s,&validate,"bar",::_1)); + s.set_open_handler(bind(&open_func_subprotocol,&s,&open,::_1)); + + BOOST_CHECK_EQUAL(run_server_test(s,input), output); + BOOST_CHECK_EQUAL(validate, "foo,bar,"); + BOOST_CHECK_EQUAL(open, "bar"); +} + /*BOOST_AUTO_TEST_CASE( user_reject_origin ) { std::string input = "GET / HTTP/1.1\r\nHost: www.example.com\r\nConnection: upgrade\r\nUpgrade: websocket\r\nSec-WebSocket-Version: 13\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nOrigin: http://www.example2.com\r\n\r\n"; std::string output = "HTTP/1.1 403 Forbidden\r\nServer: test\r\n\r\n"; diff --git a/websocketpp/connection.hpp b/websocketpp/connection.hpp index 1e31e85103..a4dcd02ba1 100644 --- a/websocketpp/connection.hpp +++ b/websocketpp/connection.hpp @@ -542,6 +542,31 @@ public: */ const std::vector & get_requested_subprotocols() const; + /// Select a subprotocol to use (exception free) + /** + * Indicates which subprotocol should be used for this connection. Valid + * only during the validate handler callback. Subprotocol selected must have + * been requested by the client. Consult get_requested_subprotocols() for a + * list of valid subprotocols. + * + * @param value The subprotocol to select + * + * @param ec A reference to an error code that will be filled in the case of + * errors + */ + void select_subprotocol(const std::string & value, lib::error_code & ec); + + /// Select a subprotocol to use + /** + * Indicates which subprotocol should be used for this connection. Valid + * only during the validate handler callback. Subprotocol selected must have + * been requested by the client. Consult get_requested_subprotocols() for a + * list of valid subprotocols. + * + * @param value The subprotocol to select + */ + void select_subprotocol(const std::string & value); + ///////////////////////////////////////////////////////////// // Pass-through access to the request and response objects // ///////////////////////////////////////////////////////////// diff --git a/websocketpp/error.hpp b/websocketpp/error.hpp index 27b020fc4b..8277ac4a95 100644 --- a/websocketpp/error.hpp +++ b/websocketpp/error.hpp @@ -84,7 +84,10 @@ enum value { test, /// Connection creation attempted failed - con_creation_failed + con_creation_failed, + + /// Selected subprotocol was not requested by the client + unrequested_subprotocol }; // enum value @@ -128,6 +131,8 @@ public: return "Test Error"; case error::con_creation_failed: return "Connection creation attempt failed"; + case error::unrequested_subprotocol: + return "Selected subprotocol was not requested by the client"; default: return "Unknown"; } diff --git a/websocketpp/impl/connection_impl.hpp b/websocketpp/impl/connection_impl.hpp index 6914de4636..ca2c89512c 100644 --- a/websocketpp/impl/connection_impl.hpp +++ b/websocketpp/impl/connection_impl.hpp @@ -327,9 +327,37 @@ connection::get_requested_subprotocols() const { return m_requested_subprotocols; } +template +void connection::select_subprotocol(const std::string & value, + lib::error_code & ec) +{ + if (value.empty()) { + ec = lib::error_code(); + return; + } + + std::vector::iterator it; + + it = std::find(m_requested_subprotocols.begin(), + m_requested_subprotocols.end(), + value); + + if (it == m_requested_subprotocols.end()) { + ec = error::make_error_code(error::unrequested_subprotocol); + return; + } + + m_subprotocol = value; +} - - +template +void connection::select_subprotocol(const std::string & value) { + lib::error_code ec; + this->select_subprotocol(value,ec); + if (ec) { + throw ec; + } +} template void connection::set_status( @@ -868,7 +896,7 @@ bool connection::process_handshake_request() { // Write the appropriate response headers based on request and // processor version - ec = m_processor->process_handshake(m_request,m_response); + ec = m_processor->process_handshake(m_request,m_subprotocol,m_response); if (ec) { std::stringstream s; @@ -962,7 +990,9 @@ void connection::send_http_response() { // Set some common headers m_response.replace_header("Server",m_user_agent); - + + + // have the processor generate the raw bytes for the wire (if it exists) if (m_processor) { m_handshake_buffer = m_processor->get_raw(m_response);