adds subprotocol selection to connection

This commit is contained in:
Peter Thorson
2013-04-06 11:09:41 -05:00
parent 8f812aafd5
commit 17fcb3f8be
4 changed files with 138 additions and 16 deletions

View File

@@ -85,18 +85,26 @@ void echo_func(server* s, websocketpp::connection_hdl hdl, message_ptr msg) {
s->send(hdl, msg->get_payload(), msg->get_opcode());
}
bool validate_func_subprotocol(server* s, std::string* out, websocketpp::connection_hdl hdl) {
bool validate_func_subprotocol(server* s, std::string* out, std::string accept,
websocketpp::connection_hdl hdl)
{
server::connection_ptr con = s->get_con_from_hdl(hdl);
std::stringstream o;
o << con->get_requested_subprotocols().size();
if (con->get_requested_subprotocols().size() == 1) {
o << "," << con->get_requested_subprotocols()[0];
const std::vector<std::string> & protocols = con->get_requested_subprotocols();
std::vector<std::string>::const_iterator it;
for (it = protocols.begin(); it != protocols.end(); ++it) {
o << *it << ",";
}
*out = o.str();
if (accept != "") {
con->select_subprotocol(accept);
}
return true;
}
@@ -114,7 +122,7 @@ BOOST_AUTO_TEST_CASE( basic_websocket_request ) {
server s;
s.set_user_agent("test");
BOOST_CHECK(run_server_test(s,input) == output);
BOOST_CHECK_EQUAL(run_server_test(s,input), output);
}
BOOST_AUTO_TEST_CASE( invalid_websocket_version ) {
@@ -125,7 +133,7 @@ BOOST_AUTO_TEST_CASE( invalid_websocket_version ) {
s.set_user_agent("test");
//s.set_message_handler(bind(&echo_func,&s,::_1,::_2));
BOOST_CHECK(run_server_test(s,input) == output);
BOOST_CHECK_EQUAL(run_server_test(s,input), output);
}
BOOST_AUTO_TEST_CASE( unimplimented_websocket_version ) {
@@ -136,10 +144,10 @@ BOOST_AUTO_TEST_CASE( unimplimented_websocket_version ) {
server s;
s.set_user_agent("test");
BOOST_CHECK(run_server_test(s,input) == output);
BOOST_CHECK_EQUAL(run_server_test(s,input), output);
}
BOOST_AUTO_TEST_CASE( accept_subprotocol_empty ) {
BOOST_AUTO_TEST_CASE( list_subprotocol_empty ) {
std::string input = "GET / HTTP/1.1\r\nHost: www.example.com\r\nConnection: upgrade\r\nUpgrade: websocket\r\nSec-WebSocket-Version: 13\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nOrigin: http://www.example.com\r\nSec-WebSocket-Protocol: foo\r\n\r\n";
std::string output = "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\nServer: test\r\nUpgrade: websocket\r\n\r\n";
@@ -154,7 +162,7 @@ BOOST_AUTO_TEST_CASE( accept_subprotocol_empty ) {
BOOST_CHECK_EQUAL(subprotocol, "");
}
BOOST_AUTO_TEST_CASE( accept_subprotocol_one ) {
BOOST_AUTO_TEST_CASE( list_subprotocol_one ) {
std::string input = "GET / HTTP/1.1\r\nHost: www.example.com\r\nConnection: upgrade\r\nUpgrade: websocket\r\nSec-WebSocket-Version: 13\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nOrigin: http://www.example.com\r\nSec-WebSocket-Protocol: foo\r\n\r\n";
std::string output = "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\nServer: test\r\nUpgrade: websocket\r\n\r\n";
@@ -164,14 +172,68 @@ BOOST_AUTO_TEST_CASE( accept_subprotocol_one ) {
server s;
s.set_user_agent("test");
s.set_validate_handler(bind(&validate_func_subprotocol,&s,&validate,::_1));
s.set_validate_handler(bind(&validate_func_subprotocol,&s,&validate,"",::_1));
s.set_open_handler(bind(&open_func_subprotocol,&s,&open,::_1));
BOOST_CHECK_EQUAL(run_server_test(s,input), output);
BOOST_CHECK_EQUAL(validate, "1,foo");
BOOST_CHECK_EQUAL(validate, "foo,");
BOOST_CHECK_EQUAL(open, "");
}
BOOST_AUTO_TEST_CASE( accept_subprotocol_one ) {
std::string input = "GET / HTTP/1.1\r\nHost: www.example.com\r\nConnection: upgrade\r\nUpgrade: websocket\r\nSec-WebSocket-Version: 13\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nOrigin: http://www.example.com\r\nSec-WebSocket-Protocol: foo\r\n\r\n";
std::string output = "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\nSec-WebSocket-Protocol: foo\r\nServer: test\r\nUpgrade: websocket\r\n\r\n";
std::string validate;
std::string open;
server s;
s.set_user_agent("test");
s.set_validate_handler(bind(&validate_func_subprotocol,&s,&validate,"foo",::_1));
s.set_open_handler(bind(&open_func_subprotocol,&s,&open,::_1));
BOOST_CHECK_EQUAL(run_server_test(s,input), output);
BOOST_CHECK_EQUAL(validate, "foo,");
BOOST_CHECK_EQUAL(open, "foo");
}
BOOST_AUTO_TEST_CASE( accept_subprotocol_invalid ) {
std::string input = "GET / HTTP/1.1\r\nHost: www.example.com\r\nConnection: upgrade\r\nUpgrade: websocket\r\nSec-WebSocket-Version: 13\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nOrigin: http://www.example.com\r\nSec-WebSocket-Protocol: foo\r\n\r\n";
std::string output = "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\nSec-WebSocket-Protocol: foo\r\nServer: test\r\nUpgrade: websocket\r\n\r\n";
std::string validate;
std::string open;
server s;
s.set_user_agent("test");
s.set_validate_handler(bind(&validate_func_subprotocol,&s,&validate,"foo2",::_1));
s.set_open_handler(bind(&open_func_subprotocol,&s,&open,::_1));
std::string o;
BOOST_CHECK_THROW(o = run_server_test(s,input), websocketpp::lib::error_code);
}
BOOST_AUTO_TEST_CASE( accept_subprotocol_two ) {
std::string input = "GET / HTTP/1.1\r\nHost: www.example.com\r\nConnection: upgrade\r\nUpgrade: websocket\r\nSec-WebSocket-Version: 13\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nOrigin: http://www.example.com\r\nSec-WebSocket-Protocol: foo, bar\r\n\r\n";
std::string output = "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\nSec-WebSocket-Protocol: bar\r\nServer: test\r\nUpgrade: websocket\r\n\r\n";
std::string validate;
std::string open;
server s;
s.set_user_agent("test");
s.set_validate_handler(bind(&validate_func_subprotocol,&s,&validate,"bar",::_1));
s.set_open_handler(bind(&open_func_subprotocol,&s,&open,::_1));
BOOST_CHECK_EQUAL(run_server_test(s,input), output);
BOOST_CHECK_EQUAL(validate, "foo,bar,");
BOOST_CHECK_EQUAL(open, "bar");
}
/*BOOST_AUTO_TEST_CASE( user_reject_origin ) {
std::string input = "GET / HTTP/1.1\r\nHost: www.example.com\r\nConnection: upgrade\r\nUpgrade: websocket\r\nSec-WebSocket-Version: 13\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nOrigin: http://www.example2.com\r\n\r\n";
std::string output = "HTTP/1.1 403 Forbidden\r\nServer: test\r\n\r\n";

View File

@@ -542,6 +542,31 @@ public:
*/
const std::vector<std::string> & get_requested_subprotocols() const;
/// Select a subprotocol to use (exception free)
/**
* Indicates which subprotocol should be used for this connection. Valid
* only during the validate handler callback. Subprotocol selected must have
* been requested by the client. Consult get_requested_subprotocols() for a
* list of valid subprotocols.
*
* @param value The subprotocol to select
*
* @param ec A reference to an error code that will be filled in the case of
* errors
*/
void select_subprotocol(const std::string & value, lib::error_code & ec);
/// Select a subprotocol to use
/**
* Indicates which subprotocol should be used for this connection. Valid
* only during the validate handler callback. Subprotocol selected must have
* been requested by the client. Consult get_requested_subprotocols() for a
* list of valid subprotocols.
*
* @param value The subprotocol to select
*/
void select_subprotocol(const std::string & value);
/////////////////////////////////////////////////////////////
// Pass-through access to the request and response objects //
/////////////////////////////////////////////////////////////

View File

@@ -84,7 +84,10 @@ enum value {
test,
/// Connection creation attempted failed
con_creation_failed
con_creation_failed,
/// Selected subprotocol was not requested by the client
unrequested_subprotocol
}; // enum value
@@ -128,6 +131,8 @@ public:
return "Test Error";
case error::con_creation_failed:
return "Connection creation attempt failed";
case error::unrequested_subprotocol:
return "Selected subprotocol was not requested by the client";
default:
return "Unknown";
}

View File

@@ -327,9 +327,37 @@ connection<config>::get_requested_subprotocols() const {
return m_requested_subprotocols;
}
template <typename config>
void connection<config>::select_subprotocol(const std::string & value,
lib::error_code & ec)
{
if (value.empty()) {
ec = lib::error_code();
return;
}
std::vector<std::string>::iterator it;
it = std::find(m_requested_subprotocols.begin(),
m_requested_subprotocols.end(),
value);
if (it == m_requested_subprotocols.end()) {
ec = error::make_error_code(error::unrequested_subprotocol);
return;
}
m_subprotocol = value;
}
template <typename config>
void connection<config>::select_subprotocol(const std::string & value) {
lib::error_code ec;
this->select_subprotocol(value,ec);
if (ec) {
throw ec;
}
}
template <typename config>
void connection<config>::set_status(
@@ -868,7 +896,7 @@ bool connection<config>::process_handshake_request() {
// Write the appropriate response headers based on request and
// processor version
ec = m_processor->process_handshake(m_request,m_response);
ec = m_processor->process_handshake(m_request,m_subprotocol,m_response);
if (ec) {
std::stringstream s;
@@ -962,7 +990,9 @@ void connection<config>::send_http_response() {
// Set some common headers
m_response.replace_header("Server",m_user_agent);
// have the processor generate the raw bytes for the wire (if it exists)
if (m_processor) {
m_handshake_buffer = m_processor->get_raw(m_response);