api tweaks, origin and uri detection behaviors. chat server updated for new apis

This commit is contained in:
Peter Thorson
2011-11-13 07:12:37 -06:00
parent 74aa325591
commit 4ed86a7c30
11 changed files with 146 additions and 94 deletions

View File

@@ -32,35 +32,36 @@
using websocketchat::chat_server_handler;
using websocketpp::session::server_session_ptr;
void chat_server_handler::validate(server_session_ptr session) {
std::stringstream err;
// We only know about the chat resource
if (session->get_uri().resource != "/chat") {
err << "Request for unknown resource " << client->get_resource();
throw(websocketpp::handshake_error(err.str(),404));
if (session->get_resource() != "/chat") {
err << "Request for unknown resource " << session->get_resource();
throw(websocketpp::http::exception(err.str(),websocketpp::http::status_code::NOT_FOUND));
}
// Require specific origin example
if (client->get_origin() != "http://zaphoyd.com") {
err << "Request from unrecognized origin: " << client->get_origin();
throw(websocketpp::handshake_error(err.str(),403));
if (session->get_origin() != "http://zaphoyd.com") {
err << "Request from unrecognized origin: " << session->get_origin();
throw(websocketpp::http::exception(err.str(),websocketpp::http::status_code::FORBIDDEN));
}
}
void chat_server_handler::on_open(session_ptr client) {
std::cout << "client " << client << " joined the lobby." << std::endl;
m_connections.insert(std::pair<session_ptr,std::string>(client,get_con_id(client)));
void chat_server_handler::on_open(server_session_ptr session) {
std::cout << "client " << session << " joined the lobby." << std::endl;
m_connections.insert(std::pair<server_session_ptr,std::string>(session,get_con_id(session)));
// send user list and signon message to all clients
send_to_all(serialize_state());
client->send(encode_message("server","Welcome, use the /alias command to set a name, /help for a list of other commands."));
send_to_all(encode_message("server",m_connections[client]+" has joined the chat."));
session->send(encode_message("server","Welcome, use the /alias command to set a name, /help for a list of other commands."));
send_to_all(encode_message("server",m_connections[session]+" has joined the chat."));
}
void chat_server_handler::on_close(session_ptr client) {
std::map<session_ptr,std::string>::iterator it = m_connections.find(client);
void chat_server_handler::on_close(server_session_ptr session) {
std::map<server_session_ptr,std::string>::iterator it = m_connections.find(session);
if (it == m_connections.end()) {
// this client has already disconnected, we can ignore this.
@@ -70,7 +71,7 @@ void chat_server_handler::on_close(session_ptr client) {
return;
}
std::cout << "client " << client << " left the lobby." << std::endl;
std::cout << "client " << session << " left the lobby." << std::endl;
const std::string alias = it->second;
m_connections.erase(it);
@@ -80,31 +81,31 @@ void chat_server_handler::on_close(session_ptr client) {
send_to_all(encode_message("server",alias+" has left the chat."));
}
void chat_server_handler::on_message(session_ptr client,const std::string &msg) {
std::cout << "message from client " << client << ": " << msg << std::endl;
void chat_server_handler::on_message(server_session_ptr session,websocketpp::utf8_string_ptr msg) {
std::cout << "message from client " << session << ": " << *msg << std::endl;
// check for special command messages
if (msg == "/help") {
if (*msg == "/help") {
// print command list
client->send(encode_message("server","avaliable commands:<br />&nbsp;&nbsp;&nbsp;&nbsp;/help - show this help<br />&nbsp;&nbsp;&nbsp;&nbsp;/alias foo - set alias to foo",false));
session->send(encode_message("server","avaliable commands:<br />&nbsp;&nbsp;&nbsp;&nbsp;/help - show this help<br />&nbsp;&nbsp;&nbsp;&nbsp;/alias foo - set alias to foo",false));
return;
}
if (msg.substr(0,7) == "/alias ") {
if (msg->substr(0,7) == "/alias ") {
std::string response;
std::string alias;
if (msg.size() == 7) {
if (msg->size() == 7) {
response = "You must enter an alias.";
client->send(encode_message("server",response));
session->send(encode_message("server",response));
return;
} else {
alias = msg.substr(7);
alias = msg->substr(7);
}
response = m_connections[client] + " is now known as "+alias;
response = m_connections[session] + " is now known as "+alias;
// store alias pre-escaped so we don't have to do this replacing every time this
// user sends a message
@@ -118,7 +119,7 @@ void chat_server_handler::on_message(session_ptr client,const std::string &msg)
boost::algorithm::replace_all(alias,"<","&lt;");
boost::algorithm::replace_all(alias,">","&gt;");
m_connections[client] = alias;
m_connections[session] = alias;
// set alias
send_to_all(serialize_state());
@@ -127,13 +128,13 @@ void chat_server_handler::on_message(session_ptr client,const std::string &msg)
}
// catch other slash commands
if (msg[0] == '/') {
client->send(encode_message("server","unrecognized command"));
if ((*msg)[0] == '/') {
session->send(encode_message("server","unrecognized command"));
return;
}
// create JSON message to send based on msg
send_to_all(encode_message(m_connections[client],msg));
send_to_all(encode_message(m_connections[session],*msg));
}
// {"type":"participants","value":[<participant>,...]}
@@ -142,7 +143,7 @@ std::string chat_server_handler::serialize_state() {
s << "{\"type\":\"participants\",\"value\":[";
std::map<session_ptr,std::string>::iterator it;
std::map<server_session_ptr,std::string>::iterator it;
for (it = m_connections.begin(); it != m_connections.end(); it++) {
s << "\"" << (*it).second << "\"";
@@ -178,14 +179,14 @@ std::string chat_server_handler::encode_message(std::string sender,std::string m
return s.str();
}
std::string chat_server_handler::get_con_id(session_ptr s) {
std::string chat_server_handler::get_con_id(server_session_ptr s) {
std::stringstream endpoint;
endpoint << s->socket().remote_endpoint();
endpoint << s->get_endpoint();
return endpoint.str();
}
void chat_server_handler::send_to_all(std::string data) {
std::map<session_ptr,std::string>::iterator it;
std::map<server_session_ptr,std::string>::iterator it;
for (it = m_connections.begin(); it != m_connections.end(); it++) {
(*it).first->send(data);
}

View File

@@ -36,46 +36,33 @@ using boost::asio::ip::tcp;
using namespace websocketchat;
int main(int argc, char* argv[]) {
std::string host = "localhost";
short port = 9003;
std::string full_host;
if (argc == 3) {
if (argc == 2) {
// TODO: input validation?
host = argv[1];
port = atoi(argv[2]);
}
std::stringstream temp;
temp << host << ":" << port;
full_host = temp.str();
chat_server_handler_ptr chat_handler(new chat_server_handler());
try {
boost::asio::io_service io_service;
tcp::endpoint endpoint(tcp::v6(), port);
// create an instance of our handler
server_handler_ptr default_handler(new chat_server_handler());
websocketpp::server_ptr server(
new websocketpp::server(io_service,endpoint,chat_handler)
);
// create a server that listens on port `port` and uses our handler
websocketpp::basic_server_ptr server(new websocketpp::basic_server(port,default_handler));
server->elog().set_levels(websocketpp::log::elevel::DEVEL,websocketpp::log::elevel::FATAL);
server->alog().set_level(websocketpp::log::alevel::ALL);
// setup server settings
server->add_host(host);
server->add_host(full_host);
// Chat server should only be receiving small text messages, reduce max
// message size limit slightly to save memory, improve performance, and
// guard against DoS attacks.
server->set_max_message_size(0xFFFF); // 64KiB
//server->set_max_message_size(0xFFFF); // 64KiB
// start the server
server->start_accept();
std::cout << "Starting chat server on port " << port << std::endl;
std::cout << "Starting chat server on " << full_host << std::endl;
io_service.run();
server->run();
} catch (std::exception& e) {
std::cerr << "Exception: " << e.what() << std::endl;
}

View File

@@ -35,31 +35,19 @@
using boost::asio::ip::tcp;
int main(int argc, char* argv[]) {
std::string host = "localhost";
short port = 9002;
std::string full_host;
if (argc == 3) {
if (argc == 2) {
// TODO: input validation?
host = argv[1];
port = atoi(argv[2]);
}
std::stringstream temp;
temp << host << ":" << port;
full_host = temp.str();
try {
// create an instance of our handler
server_handler_ptr default_handler(new websocketecho::echo_server_handler());
// create a server that listens on port `port` and uses our handler
typedef boost::shared_ptr< websocketpp::server::server<> > server_ptr;
server_ptr server(new websocketpp::server::server<>(port,default_handler));
websocketpp::basic_server_ptr server(new websocketpp::basic_server(port,default_handler));
server->elog().set_levels(websocketpp::log::elevel::DEVEL,websocketpp::log::elevel::FATAL);
@@ -75,10 +63,10 @@ int main(int argc, char* argv[]) {
// server to test performance and protocol extremes.
//server->set_max_message_size(websocketpp::frame::limits::PAYLOAD_SIZE_JUMBO);
std::cout << "Starting echo server on port " << port << std::endl;
// start the server
server->run();
std::cout << "Starting echo server on " << full_host << std::endl;
} catch (std::exception& e) {
std::cerr << "Exception: " << e.what() << std::endl;
}

View File

@@ -85,7 +85,8 @@ namespace http {
NETWORK_AUTHENTICATION_REQUIRED = 511
};
std::string get_string(value c) {
// TODO: should this be inline?
inline std::string get_string(value c) {
switch (c) {
case CONTINUE:
return "Continue";

View File

@@ -45,7 +45,7 @@ namespace hybi_legacy_state {
class hybi_legacy_processor : public processor {
public:
hybi_legacy_processor() : m_state(hybi_legacy_state::INIT) {
hybi_legacy_processor(bool secure) : m_secure(secure),m_state(hybi_legacy_state::INIT) {
reset();
}
@@ -86,14 +86,26 @@ public:
// set a different one.
if (response.header("Sec-WebSocket-Location") == "") {
// TODO: extract from host header rather than hard code
ws_uri uri;
uri.secure = false;
uri.host = "localhost";
uri.port = 9002;
ws_uri uri = get_uri(request);
response.add_header("Sec-WebSocket-Location",uri.base());
}
}
std::string get_origin(const http::parser::request& request) const {
return request.header("Origin");
}
ws_uri get_uri(const http::parser::request& request) const {
ws_uri uri;
uri.secure = m_secure;
// TODO: check if get_uri is a full uri
// TODO: host and port
uri.resource = request.uri();
return uri;
}
void consume(std::istream& s) {
unsigned char c;
while (s.good() && m_state != hybi_legacy_state::DONE) {
@@ -251,6 +263,8 @@ private:
}
}
bool m_secure;
hybi_legacy_state::value m_state;
frame::opcode::value m_opcode;

View File

@@ -52,7 +52,7 @@ namespace hybi_state {
template <class rng_policy>
class hybi_processor : public processor {
public:
hybi_processor(rng_policy &rng) : m_fragmented_opcode(frame::opcode::CONTINUATION),m_utf8_payload(new utf8_string()),m_binary_payload(new binary_string()),m_read_frame(rng),m_write_frame(rng) {
hybi_processor(bool secure,rng_policy &rng) : m_secure(secure),m_fragmented_opcode(frame::opcode::CONTINUATION),m_utf8_payload(new utf8_string()),m_binary_payload(new binary_string()),m_read_frame(rng),m_write_frame(rng) {
reset();
}
@@ -115,6 +115,30 @@ public:
}
}
std::string get_origin(const http::parser::request& request) const {
std::string h = request.header("Sec-WebSocket-Version");
int version = atoi(h.c_str());
if (version == 13) {
return request.header("Origin");
} else if (version == 7 || version == 8) {
return request.header("Sec-WebSocket-Origin");
} else {
throw(http::exception("Could not determine origin header.",http::status_code::BAD_REQUEST));
}
}
ws_uri get_uri(const http::parser::request& request) const {
ws_uri uri;
uri.secure = m_secure;
// TODO: check if get_uri is a full uri
// TODO: host and port
uri.resource = request.uri();
return uri;
}
void handshake_response(const http::parser::request& request,http::parser::response& response) {
// TODO:
@@ -416,6 +440,7 @@ public:
}
private:
bool m_secure;
int m_state;
frame::opcode::value m_opcode;
frame::opcode::value m_fragmented_opcode;

View File

@@ -49,7 +49,12 @@ public:
virtual void validate_handshake(const http::parser::request& headers) const = 0;
virtual void handshake_response(const http::parser::request& request,http::parser::response& response) = 0;
// Extracts client origin from a handshake request
virtual std::string get_origin(const http::parser::request& request) const = 0;
// Extracts client uri from a handshake request
virtual ws_uri get_uri(const http::parser::request& request) const = 0;
// consume bytes, throw on exception
virtual void consume(std::istream& s) = 0;

View File

@@ -87,10 +87,17 @@ public:
virtual session::state::value get_state() const = 0;
virtual unsigned int get_version() const = 0;
virtual std::string get_origin() const = 0;
virtual std::string get_request_header(const std::string& key) const = 0;
virtual const ws_uri& get_uri() const = 0;
virtual std::string get_origin() const = 0;
// Information about the requested URI
virtual bool get_secure() const = 0;
virtual std::string get_host() const = 0;
virtual std::string get_resource() const = 0;
virtual uint16_t get_port() const = 0;
// Information about the connected endpoint
/* Tentative API member function */ virtual boost::asio::ip::tcp::endpoint get_endpoint() const = 0;
// Valid for CONNECTING state
virtual void add_response_header(const std::string& key, const std::string& value) = 0;

View File

@@ -78,7 +78,8 @@ private:
// convenience function that creates a validator, validates a complete string
// and returns the result.
bool validate(const std::string& s) {
// TODO: should this be inline?
inline bool validate(const std::string& s) {
validator v;
if (!v.decode(s.begin(),s.end())) {
return false;

View File

@@ -58,7 +58,7 @@ namespace po = boost::program_options;
#include "http/parser.hpp"
#include "logger/logger.hpp"
using boost::asio ::ip::tcp;
using boost::asio::ip::tcp;
using websocketpp::session::server_handler_ptr;
using websocketpp::protocol::processor_ptr;
@@ -114,15 +114,27 @@ public:
return m_request.header(key);
}
const ws_uri& get_uri() const {
return m_uri;
}
bool get_secure() const {
// TODO
return false;
}
std::string get_host() const {
return m_uri.host;
}
uint16_t get_port() const {
return m_uri.port;
}
std::string get_resource() const {
return m_uri.resource;
}
tcp::endpoint get_endpoint() const {
return m_socket.remote_endpoint();
}
// Valid for CONNECTING state
void add_response_header(const std::string& key, const std::string& value) {
m_response.add_header(key,value);
@@ -325,16 +337,18 @@ public:
}
m_request.add_header("Sec-WebSocket-Key3",std::string(foo));
m_processor = processor_ptr(new protocol::hybi_legacy_processor());
m_processor = processor_ptr(new protocol::hybi_legacy_processor(false));
} else if (m_version == 7 || m_version == 8 || m_version == 13) {
// create hybi 17 processor
m_processor = processor_ptr(new protocol::hybi_processor<blank_rng>(m_rng));
m_processor = processor_ptr(new protocol::hybi_processor<blank_rng>(false,m_rng));
} else {
// TODO: respond with unknown version message per spec
}
// ask new protocol whether this set of headers is valid
m_processor->validate_handshake(m_request);
m_origin = m_processor->get_origin(m_request);
m_uri = m_processor->get_uri(m_request);
// ask local application to confirm that it wants to accept
m_handler->validate(boost::static_pointer_cast<websocketpp::session::server>(connection_type::shared_from_this()));
@@ -1029,8 +1043,20 @@ private:
po::options_description m_desc;
po::variables_map m_vm;
};
}
// convenience type interface
// websocketpp::server_ptr represents a basic non-secure websocket server
typedef server::server<> basic_server;
typedef basic_server::ptr basic_server_ptr;
// websocketpp::secure_server_ptr represents a basic secure websocket server
// TODO:
typedef server::server<> secure_server;
typedef secure_server::ptr secure_server_ptr;
}
#endif // WEBSOCKET_SERVER_HPP

View File

@@ -32,9 +32,6 @@
#include "websocket_constants.hpp"
//#include "websocket_session.hpp"
//#include "websocket_server_session.hpp"
//#include "websocket_client_session.hpp"
#include "websocket_server.hpp"
//#include "websocket_client.hpp"