diff --git a/examples/echo_server/echo.hpp b/examples/echo_server/echo.hpp index 5551f231f5..1e3c00e706 100644 --- a/examples/echo_server/echo.hpp +++ b/examples/echo_server/echo.hpp @@ -29,40 +29,35 @@ #define ECHO_SERVER_HANDLER_HPP #include "../../src/websocketpp.hpp" -#include "../../src/websocket_connection_handler.hpp" +#include "../../src/interfaces/session.hpp" #include #include #include +using websocketpp::session::server_ptr; +using websocketpp::session::server_handler; + namespace websocketecho { -template -class echo_server_handler : public websocketpp::connection_handler { + class echo_server_handler : public server_handler { public: - //typedef boost::shared_ptr > - //typedef typename websocketpp::connection_handler::ptr session_ptr; - typedef typename websocketpp::connection_handler::session_ptr session_ptr; - - echo_server_handler() {} - virtual ~echo_server_handler() {} - // The echo server allows all domains is protocol free. - void validate(session_ptr client) {} + void validate(server_ptr session) {} // an echo server is stateless. // The handler has no need to keep track of connected clients. - void on_fail(session_ptr client) {} - void on_open(session_ptr client) {} - void on_close(session_ptr client) {} + void on_fail(server_ptr session) {} + void on_open(server_ptr session) {} + void on_close(server_ptr session) {} // both text and binary messages are echoed back to the sending client. - void on_message(session_ptr client,const std::string &msg) { - client->send(msg); + void on_message(server_ptr session,websocketpp::utf8_string_ptr msg) { + std::cout << *msg << std::endl; + session->send(*msg); } - void on_message(session_ptr client, - const std::vector &data) { - client->send(data); + void on_message(server_ptr session,websocketpp::binary_string_ptr data) { + session->send(*data); } }; diff --git a/examples/echo_server/echo_server.cpp b/examples/echo_server/echo_server.cpp index 2cda611905..97e3431136 100644 --- a/examples/echo_server/echo_server.cpp +++ b/examples/echo_server/echo_server.cpp @@ -53,24 +53,17 @@ int main(int argc, char* argv[]) { try { - boost::asio::io_service io_service; - tcp::endpoint endpoint(tcp::v6(), port); + // create an instance of our handler + server_handler_ptr default_handler(new websocketecho::echo_server_handler()); + + // create a server that listens on port `port` and uses our handler + typedef boost::shared_ptr< websocketpp::server::server<> > server_ptr; - using websocketpp::server; + server_ptr server(new websocketpp::server::server<>(port,default_handler)); - typedef boost::shared_ptr< server<> > server_ptr; - //typedef server<>::ptr server_ptr; + server->elog().set_levels(websocketpp::log::elevel::DEVEL,websocketpp::log::elevel::FATAL); - server_ptr s(new server<>(io_service,endpoint)); - - server<>::connection_handler_ptr handler = s->make_handler(); - - s->set_default_connection_handler(handler); - - s->elog().set_levels(websocketpp::log::elevel::DEVEL,websocketpp::log::elevel::FATAL); - - s->alog().set_level(websocketpp::log::alevel::CONNECT); - s->alog().set_level(websocketpp::log::alevel::DEBUG_HANDSHAKE); + server->alog().set_level(websocketpp::log::alevel::ALL); //server->parse_command_line(argc, argv); @@ -80,15 +73,12 @@ int main(int argc, char* argv[]) { // bump up max message size to maximum since we may be using the echo // server to test performance and protocol extremes. - s->set_max_message_size(websocketpp::frame::limits::PAYLOAD_SIZE_JUMBO); + //server->set_max_message_size(websocketpp::frame::limits::PAYLOAD_SIZE_JUMBO); // start the server - s->start_accept(); + server->run(); std::cout << "Starting echo server on " << full_host << std::endl; - - // start asio - io_service.run(); } catch (std::exception& e) { std::cerr << "Exception: " << e.what() << std::endl; } diff --git a/src/http/constants.hpp b/src/http/constants.hpp index df49b0238a..283e57e1b1 100644 --- a/src/http/constants.hpp +++ b/src/http/constants.hpp @@ -185,8 +185,27 @@ namespace http { } } - - + class exception : public std::exception { + public: + exception(const std::string& log_msg, + status_code::value error_code, + const std::string& error_msg = "", + const std::string& body = "") + : m_msg(log_msg), + m_error_code(error_code), + m_error_msg(error_msg), + m_body(body) {} + ~exception() throw() {} + + virtual const char* what() const throw() { + return m_msg.c_str(); + } + + std::string m_msg; + status_code::value m_error_code; + std::string m_error_msg; + std::string m_body; + }; } } diff --git a/src/http/parser.hpp b/src/http/parser.hpp index 7b1b5e9a0d..d6069eb0e9 100644 --- a/src/http/parser.hpp +++ b/src/http/parser.hpp @@ -76,9 +76,9 @@ public: } } - // multiple calls to set header will result in values aggregating. + // multiple calls to add header will result in values aggregating. // use replace_header if you do not want this behavior. - void set_header(const std::string &key,const std::string &val) { + void add_header(const std::string &key,const std::string &val) { // TODO: prevent use of reserved headers? if (this->header(key) == "") { m_headers[key] = val; @@ -106,7 +106,7 @@ protected: end = header.find(": ",0); if (end != std::string::npos) { - set_header(header.substr(0,end),header.substr(end+2)); + add_header(header.substr(0,end),header.substr(end+2)); } } diff --git a/src/hybi_00_processor.hpp b/src/hybi_00_processor.hpp deleted file mode 100644 index 7313df50cc..0000000000 --- a/src/hybi_00_processor.hpp +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Copyright (c) 2011, Peter Thorson. All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * * Neither the name of the WebSocket++ Project nor the - * names of its contributors may be used to endorse or promote products - * derived from this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - * ARE DISCLAIMED. IN NO EVENT SHALL PETER THORSON BE LIABLE FOR ANY - * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES - * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; - * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND - * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - */ - -#ifndef WEBSOCKET_HYBI_00_PROCESSOR_HPP -#define WEBSOCKET_HYBI_00_PROCESSOR_HPP - -#include "interfaces/protocol.hpp" - -namespace websocketpp { -namespace protocol { - -class hybi_00_processor : public processor { -public: - void validate_handshake(const http::parser::request& headers) const { - - } - - void handshake_response(const http::parser::request& request,http::parser::response& response) { - - } -}; - -} -} -#endif // WEBSOCKET_HYBI_00_PROCESSOR_HPP diff --git a/src/hybi_legacy_processor.hpp b/src/hybi_legacy_processor.hpp new file mode 100644 index 0000000000..398039db25 --- /dev/null +++ b/src/hybi_legacy_processor.hpp @@ -0,0 +1,266 @@ +/* + * Copyright (c) 2011, Peter Thorson. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the WebSocket++ Project nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL PETER THORSON BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +#ifndef WEBSOCKET_HYBI_LEGACY_PROCESSOR_HPP +#define WEBSOCKET_HYBI_LEGACY_PROCESSOR_HPP + +#include "interfaces/protocol.hpp" + +#include "network_utilities.hpp" + +namespace websocketpp { +namespace protocol { + +namespace hybi_legacy_state { + enum value { + INIT = 0, + READ = 1, + DONE = 2 + }; +} + +class hybi_legacy_processor : public processor { +public: + hybi_legacy_processor() : m_state(hybi_legacy_state::INIT) { + reset(); + } + + void validate_handshake(const http::parser::request& headers) const { + + } + + void handshake_response(const http::parser::request& request,http::parser::response& response) { + char key_final[16]; + + // key1 + *reinterpret_cast(&key_final[0]) = + decode_client_key(request.header("Sec-WebSocket-Key1")); + + // key2 + *reinterpret_cast(&key_final[4]) = + decode_client_key(request.header("Sec-WebSocket-Key2")); + + // key3 + memcpy(&key_final[8],request.header("Sec-WebSocket-Key3").c_str(),8); + + // md5 + md5_hash_string(key_final,m_digest); + m_digest[16] = 0; + + response.add_header("Upgrade","websocket"); + response.add_header("Connection","Upgrade"); + + // TODO: require headers that need application specific information? + + // Echo back client's origin unless our local application set a + // more restrictive one. + if (response.header("Sec-WebSocket-Origin") == "") { + response.add_header("Sec-WebSocket-Origin",request.header("Origin")); + } + + // Echo back the client's request host unless our local application + // set a different one. + if (response.header("Sec-WebSocket-Location") == "") { + // TODO: extract from host header rather than hard code + ws_uri uri; + uri.secure = false; + uri.host = "localhost"; + uri.port = 9002; + response.add_header("Sec-WebSocket-Location",uri.base()); + } + } + + void consume(std::istream& s) { + unsigned char c; + while (s.good() && m_state != hybi_legacy_state::DONE) { + c = s.get(); + if (s.good()) { + process(c); + } + } + } + + void process(unsigned char c) { + if (m_state == hybi_legacy_state::INIT) { + // we are looking for a 0x00 + if (c == 0x00) { + // start a message + + m_state = hybi_legacy_state::READ; + } else { + // TODO: ignore or error + throw session::exception("invalid character read",session::error::PROTOCOL_VIOLATION); + } + } else if (m_state == hybi_legacy_state::READ) { + // TODO: utf8 validation + + if (c == 0xFF) { + // end + m_state = hybi_legacy_state::DONE; + } else { + m_utf8_payload->push_back(c); + } + } + } + + bool ready() const { + return m_state == hybi_legacy_state::DONE; + } + + void reset() { + m_state = hybi_legacy_state::INIT; + m_utf8_payload = utf8_string_ptr(new utf8_string()); + } + + uint64_t get_bytes_needed() const { + return 1; + } + + frame::opcode::value get_opcode() const { + return frame::opcode::TEXT; + } + + std::string get_key3() const { + return std::string(m_digest); + } + + utf8_string_ptr get_utf8_payload() const { + if (get_opcode() != frame::opcode::TEXT) { + throw "opcode doesn't match"; + } + + if (!ready()) { + throw "not ready"; + } + + return m_utf8_payload; + } + + binary_string_ptr get_binary_payload() const { + // TODO: opcode doesn't match + throw "opcode doesn't match"; + return binary_string_ptr(); + } + + // legacy hybi doesn't have close codes + close::status::value get_close_code() const { + return close::status::NO_STATUS; + } + + utf8_string get_close_reason() const { + return ""; + } + + binary_string_ptr prepare_frame(frame::opcode::value opcode, + bool mask, + const binary_string& payload) { + if (opcode != frame::opcode::TEXT) { + // TODO: hybi_legacy doesn't allow non-text frames. + throw; + } + + // TODO: mask = ignore? + + // TODO: utf8 validation on payload. + + binary_string_ptr response(new binary_string(payload.size()+2)); + + (*response)[0] = 0x00; + std::copy(payload.begin(),payload.end(),response->begin()+1); + (*response)[response->size()-1] = 0xFF; + + return response; + } + + binary_string_ptr prepare_frame(frame::opcode::value opcode, + bool mask, + const utf8_string& payload) { + if (opcode != frame::opcode::TEXT) { + // TODO: hybi_legacy doesn't allow non-text frames. + throw; + } + + // TODO: mask = ignore? + + // TODO: utf8 validation on payload. + + binary_string_ptr response(new binary_string(payload.size()+2)); + + (*response)[0] = 0x00; + std::copy(payload.begin(),payload.end(),response->begin()+1); + (*response)[response->size()-1] = 0xFF; + + return response; + } + + binary_string_ptr prepare_close_frame(close::status::value code, + bool mask, + const std::string& reason) const { + binary_string_ptr response(new binary_string(2)); + + (*response)[0] = 0xFF; + (*response)[1] = 0x00; + + return response; + } + +private: + uint32_t decode_client_key(const std::string& key) { + int spaces = 0; + std::string digits = ""; + uint32_t num; + + // key2 + for (size_t i = 0; i < key.size(); i++) { + if (key[i] == ' ') { + spaces++; + } else if (key[i] >= '0' && key[i] <= '9') { + digits += key[i]; + } + } + + num = atoi(digits.c_str()); + if (spaces > 0 && num > 0) { + return htonl(num/spaces); + } else { + return 0; + } + } + + hybi_legacy_state::value m_state; + frame::opcode::value m_opcode; + + utf8_string_ptr m_utf8_payload; + + char m_digest[17]; +}; + +typedef boost::shared_ptr hybi_legacy_processor_ptr; + +} +} +#endif // WEBSOCKET_HYBI_LEGACY_PROCESSOR_HPP diff --git a/src/hybi_processor.hpp b/src/hybi_processor.hpp new file mode 100644 index 0000000000..bff1f09073 --- /dev/null +++ b/src/hybi_processor.hpp @@ -0,0 +1,358 @@ +/* + * Copyright (c) 2011, Peter Thorson. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the WebSocket++ Project nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL PETER THORSON BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +#ifndef WEBSOCKET_HYBI_PROCESSOR_HPP +#define WEBSOCKET_HYBI_PROCESSOR_HPP + +#include "interfaces/protocol.hpp" + +#include "websocket_frame.hpp" + +#include "network_utilities.hpp" +#include "http/parser.hpp" + +#include "base64/base64.h" +#include "sha1/sha1.h" + +namespace websocketpp { +namespace protocol { + +namespace hybi_state { + enum value { + INIT = 0, + READ = 1, + DONE = 2 + }; +} + +template +class hybi_processor : public processor { +public: + hybi_processor(rng_policy &rng) : m_read_frame(rng),m_write_frame(rng) { + reset(); + } + + void validate_handshake(const http::parser::request& request) const { + std::stringstream err; + std::string h; + + if (request.method() != "GET") { + err << "Websocket handshake has invalid method: " + << request.method(); + + throw(http::exception(err.str(),http::status_code::BAD_REQUEST)); + } + + // TODO: allow versions greater than 1.1 + if (request.version() != "HTTP/1.1") { + err << "Websocket handshake has invalid HTTP version: " + << request.method(); + + throw(http::exception(err.str(),http::status_code::BAD_REQUEST)); + } + + // verify the presence of required headers + if (request.header("Host") == "") { + throw(http::exception("Required Host header is missing",http::status_code::BAD_REQUEST)); + } + + h = request.header("Upgrade"); + if (h == "") { + throw(http::exception("Required Upgrade header is missing",http::status_code::BAD_REQUEST)); + } else if (!boost::ifind_first(h,"websocket")) { + err << "Upgrade header \"" << h << "\", does not contain required token \"websocket\""; + throw(http::exception(err.str(),http::status_code::BAD_REQUEST)); + } + + h = request.header("Connection"); + if (h == "") { + throw(http::exception("Required Connection header is missing",http::status_code::BAD_REQUEST)); + } else if (!boost::ifind_first(h,"upgrade")) { + err << "Connection header, \"" << h + << "\", does not contain required token \"upgrade\""; + throw(http::exception(err.str(),http::status_code::BAD_REQUEST)); + } + + if (request.header("Sec-WebSocket-Key") == "") { + throw(http::exception("Required Sec-WebSocket-Key header is missing",http::status_code::BAD_REQUEST)); + } + + h = request.header("Sec-WebSocket-Version"); + if (h == "") { + throw(http::exception("Required Sec-WebSocket-Version header is missing",http::status_code::BAD_REQUEST)); + } else { + int version = atoi(h.c_str()); + + if (version != 7 && version != 8 && version != 13) { + err << "This processor doesn't support WebSocket protocol version " + << version; + throw(http::exception(err.str(),http::status_code::BAD_REQUEST)); + } + } + } + + void handshake_response(const http::parser::request& request,http::parser::response& response) { + // TODO: + + std::string server_key = request.header("Sec-WebSocket-Key"); + server_key += "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + + SHA1 sha; + uint32_t message_digest[5]; + + sha.Reset(); + sha << server_key.c_str(); + + if (sha.Result(message_digest)){ + // convert sha1 hash bytes to network byte order because this sha1 + // library works on ints rather than bytes + for (int i = 0; i < 5; i++) { + message_digest[i] = htonl(message_digest[i]); + } + + server_key = base64_encode( + reinterpret_cast + (message_digest),20 + ); + + // set handshake accept headers + response.replace_header("Sec-WebSocket-Accept",server_key); + response.add_header("Upgrade","websocket"); + response.add_header("Connection","Upgrade"); + } else { + //m_endpoint->elog().at(log::elevel::ERROR) + //<< "Error computing handshake sha1 hash" << log::endl; + // TODO: make sure this error path works + response.set_status(http::status_code::INTERNAL_SERVER_ERROR); + } + } + + void consume(std::istream& s) { + // TODO: + + while (s.good() && m_state != hybi_state::DONE) { + m_read_frame.consume(s); + + if (m_read_frame.ready()) { + switch (m_read_frame.get_opcode()) { + case frame::opcode::CONTINUATION: + if (m_opcode == frame::opcode::BINARY) { + extract_binary(); + } else if (m_opcode == frame::opcode::TEXT) { + extract_utf8(); + } else { + // can't be here + } + break; + case frame::opcode::TEXT: + m_opcode = frame::opcode::TEXT; + extract_utf8(); + break; + case frame::opcode::BINARY: + m_opcode = frame::opcode::BINARY; + extract_binary(); + break; + case frame::opcode::CLOSE: + // TODO: + break; + case frame::opcode::PING: + case frame::opcode::PONG: + m_opcode = m_read_frame.get_opcode(); + extract_binary(); + break; + default: + throw session::exception("Invalid Opcode",session::error::PROTOCOL_VIOLATION); + break; + } + if (m_read_frame.get_fin()) { + m_state = hybi_state::DONE; + } + } + } + + + } + + void extract_binary() { + binary_string &msg = m_read_frame.get_payload(); + m_binary_payload->resize(m_binary_payload->size()+msg.size()); + std::copy(msg.begin(),msg.end(),m_binary_payload->end()-msg.size()); + } + + void extract_utf8() { + binary_string &msg = m_read_frame.get_payload(); + + m_utf8_payload->reserve(m_utf8_payload->size()+msg.size()); + m_utf8_payload->append(msg.begin(),msg.end()); + } + + bool ready() const { + return m_state == hybi_state::DONE; + } + + void reset() { + m_state = m_state = hybi_state::INIT; + m_utf8_payload = utf8_string_ptr(new utf8_string()); + m_binary_payload = binary_string_ptr(new binary_string()); + } + + uint64_t get_bytes_needed() const { + return m_read_frame.get_bytes_needed(); + } + + frame::opcode::value get_opcode() const { + if (!ready()) { + throw "not ready"; + } + return m_opcode; + } + + utf8_string_ptr get_utf8_payload() const { + if (get_opcode() != frame::opcode::TEXT) { + throw "opcode doesn't have a utf8 payload"; + } + + if (!ready()) { + throw "not ready"; + } + + return m_utf8_payload; + } + + binary_string_ptr get_binary_payload() const { + // TODO: opcode doesn't match + if (get_opcode() != frame::opcode::BINARY && + get_opcode() != frame::opcode::PING && + get_opcode() != frame::opcode::PONG) { + throw "opcode doesn't have a binary payload"; + } + + if (!ready()) { + throw "not ready"; + } + + return m_binary_payload; + } + + // legacy hybi doesn't have close codes + close::status::value get_close_code() const { + // TODO + return close::status::NO_STATUS; + } + + utf8_string get_close_reason() const { + // TODO + return ""; + } + + binary_string_ptr prepare_frame(frame::opcode::value opcode, + bool mask, + const utf8_string& payload) { + if (opcode != frame::opcode::TEXT) { + // TODO: hybi_legacy doesn't allow non-text frames. + throw; + } + + // TODO: utf8 validation on payload. + + binary_string_ptr response(new binary_string(0)); + + m_write_frame.reset(); + m_write_frame.set_opcode(opcode); + m_write_frame.set_masked(mask); + m_write_frame.set_fin(true); + m_write_frame.set_payload(payload); + + // TODO + response->resize(m_write_frame.get_header_len()+m_write_frame.get_payload().size()); + + // copy header + std::copy(m_write_frame.get_header(),m_write_frame.get_header()+m_write_frame.get_header_len(),response->begin()); + + // copy payload + std::copy(m_write_frame.get_payload().begin(),m_write_frame.get_payload().end(),response->begin()+m_write_frame.get_header_len()); + + + return response; + } + + binary_string_ptr prepare_frame(frame::opcode::value opcode, + bool mask, + const binary_string& payload) { + /*if (opcode != frame::opcode::TEXT) { + // TODO: hybi_legacy doesn't allow non-text frames. + throw; + }*/ + + // TODO: utf8 validation on payload. + + binary_string_ptr response(new binary_string(0)); + + m_write_frame.reset(); + m_write_frame.set_opcode(opcode); + m_write_frame.set_masked(mask); + m_write_frame.set_fin(true); + m_write_frame.set_payload(payload); + + // TODO + response->resize(m_write_frame.get_header_len()+m_write_frame.get_payload().size()); + + // copy header + std::copy(m_write_frame.get_header(),m_write_frame.get_header()+m_write_frame.get_header_len(),response->begin()); + + // copy payload + std::copy(m_write_frame.get_payload().begin(),m_write_frame.get_payload().end(),response->begin()+m_write_frame.get_header_len()); + + return response; + } + + binary_string_ptr prepare_close_frame(close::status::value code, + bool mask, + const std::string& reason) const { + binary_string_ptr response(new binary_string(0)); + + // TODO + + return response; + } + +private: + int m_state; + frame::opcode::value m_opcode; + + utf8_string_ptr m_utf8_payload; + binary_string_ptr m_binary_payload; + + frame::parser m_read_frame; + frame::parser m_write_frame; +}; + +//typedef boost::shared_ptr hybi_processor_ptr; + +} +} +#endif // WEBSOCKET_HYBI_PROCESSOR_HPP diff --git a/src/interfaces/protocol.hpp b/src/interfaces/protocol.hpp index d61e6a8297..eca5ab8870 100644 --- a/src/interfaces/protocol.hpp +++ b/src/interfaces/protocol.hpp @@ -28,9 +28,11 @@ #ifndef WEBSOCKET_INTERFACE_FRAME_PARSER_HPP #define WEBSOCKET_INTERFACE_FRAME_PARSER_HPP +#include "../http/parser.hpp" + #include -#include "../http/parser.hpp" +#include namespace websocketpp { namespace protocol { @@ -47,23 +49,38 @@ public: virtual void validate_handshake(const http::parser::request& headers) const = 0; virtual void handshake_response(const http::parser::request& request,http::parser::response& response) = 0; - - // Given a list of HTTP headers determin if the values are a reasonable - // response to our handshake request. If so - - // construct - + // consume bytes, throw on exception virtual void consume(std::istream& s) = 0; // is there a message ready to be dispatched? - virtual bool ready() = 0; - virtual ? get_opcode() = 0; + virtual bool ready() const = 0; + virtual void reset() = 0; + + virtual uint64_t get_bytes_needed() const = 0; + + // Get information about the message that is ready + virtual frame::opcode::value get_opcode() const = 0; + + virtual utf8_string_ptr get_utf8_payload() const = 0; + virtual binary_string_ptr get_binary_payload() const = 0; + virtual close::status::value get_close_code() const = 0; + virtual utf8_string get_close_reason() const = 0; + + // TODO: prepare a frame + virtual binary_string_ptr prepare_frame(frame::opcode::value opcode, + bool mask, + const utf8_string& payload) = 0; + virtual binary_string_ptr prepare_frame(frame::opcode::value opcode, + bool mask, + const binary_string& payload) = 0; + + virtual binary_string_ptr prepare_close_frame(close::status::value code, + bool mask, + const std::string& reason) const = 0; + + - // consume - // is_message_complete - // deliver message (get_payload) - // some sort of message type? for onping onpong? }; typedef boost::shared_ptr processor_ptr; diff --git a/src/interfaces/session.hpp b/src/interfaces/session.hpp index 0688d9034f..d7d918e8e2 100644 --- a/src/interfaces/session.hpp +++ b/src/interfaces/session.hpp @@ -33,8 +33,8 @@ #include #include -#include "websocket_constants.hpp" -#include "network_utilities.hpp" +#include "../websocket_constants.hpp" +#include "../network_utilities.hpp" namespace websocketpp { namespace session { @@ -48,6 +48,36 @@ namespace state { }; } +namespace error { + enum value { + FATAL_ERROR = 0, // force session end + SOFT_ERROR = 1, // should log and ignore + PROTOCOL_VIOLATION = 2, // must end session + PAYLOAD_VIOLATION = 3, // should end session + INTERNAL_SERVER_ERROR = 4, // cleanly end session + MESSAGE_TOO_BIG = 5 // ??? + }; +} + +class exception : public std::exception { +public: + exception(const std::string& msg, + error::value code = error::FATAL_ERROR) + : m_msg(msg),m_code(code) {} + ~exception() throw() {} + + virtual const char* what() const throw() { + return m_msg.c_str(); + } + + error::value code() const throw() { + return m_code; + } + + std::string m_msg; + error::value m_code; +}; + /* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * Server Session API * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ @@ -55,7 +85,7 @@ class server { public: // Valid always virtual session::state::value get_state() const = 0; - virtual int get_version() const = 0; + virtual unsigned int get_version() const = 0; virtual std::string get_origin() const = 0; virtual std::string get_request_header(const std::string& key) const = 0; @@ -71,20 +101,20 @@ public: virtual void select_extension(const std::string& value) = 0; // Valid for OPEN state - virtual void send(const std::string& msg) = 0; - virtual void send(const std::vector& data) = 0; - virtual void close(close::status::value code, const std::string& reason) = 0; - virtual void ping(const std::string& payload) = 0; - virtual void pong(const std::string& payload) = 0; + virtual void send(const utf8_string& payload) = 0; + virtual void send(const binary_string& data) = 0; + virtual void close(close::status::value code, const utf8_string& reason) = 0; + virtual void ping(const binary_string& payload) = 0; + virtual void pong(const binary_string& payload) = 0; // Valid for CLOSED state virtual close::status::value get_local_close_code() const = 0; - virtual std::string get_local_close_reason() const = 0; + virtual utf8_string get_local_close_reason() const = 0; virtual close::status::value get_remote_close_code() const = 0; - virtual std::string get_remote_close_reason() const = 0; - virtual bool failed_by_me() const = 0; - virtual bool dropped_by_me() const = 0; - virtual bool closed_by_me() const = 0; + virtual utf8_string get_remote_close_reason() const = 0; + virtual bool get_failed_by_me() const = 0; + virtual bool get_dropped_by_me() const = 0; + virtual bool get_closed_by_me() const = 0; }; typedef boost::shared_ptr server_ptr; @@ -92,6 +122,9 @@ typedef boost::shared_ptr server_ptr; * Server Handler API * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ class server_handler { +public: + virtual ~server_handler() {} + // validate will be called after a websocket handshake has been received and // before it is accepted. It provides a handler the ability to refuse a // connection based on application specific logic (ex: restrict domains or @@ -126,13 +159,12 @@ class server_handler { // data will not be avaliable after this callback ends so the handler must // either completely process the message or copy it somewhere else for // processing later. - virtual void on_message(server_ptr session, - const std::vector &data) = 0; + virtual void on_message(server_ptr session, binary_string_ptr data) = 0; // on_message (text version). Identical to on_message except the data // parameter is a string interpreted as UTF-8. WebSocket++ guarantees that // this string is valid UTF-8. - virtual void on_message(server_ptr session,const std::string &msg) = 0; + virtual void on_message(server_ptr session, utf8_string_ptr msg) = 0; // on_fail is called whenever a session is terminated or failed before it // was successfully established. This happens if there is an error during @@ -142,6 +174,19 @@ class server_handler { // application will need information from `session` after this function you // should either save the session_ptr somewhere or copy the data out. virtual void on_fail(server_ptr session) {}; + + // on_ping is called whenever a ping is recieved with the binary + // application data from the ping. If on_ping returns true the + // implimentation will send a response pong. + virtual bool on_ping(server_ptr session, binary_string_ptr data) { + return true; + } + + // on_pong is called whenever a pong is recieved with the binary + // application data from the pong. + virtual void on_pong(server_ptr session, binary_string_ptr data) {} + + // TODO: on_ping_timeout }; typedef boost::shared_ptr server_handler_ptr; @@ -152,7 +197,7 @@ typedef boost::shared_ptr server_handler_ptr; class client { public: - client(const std::string& uri) = 0; + client(const std::string& uri) {}; // Valid always virtual session::state::value get_state() const = 0; @@ -174,17 +219,17 @@ public: virtual std::string get_subprotocol() const; virtual const std::vector& get_extensions() const = 0; - virtual void send(const std::string& msg) = 0; - virtual void send(const std::vector& data) = 0; - virtual void close(close::status::value code, const std::string& reason) = 0; - virtual void ping(const std::string& payload) = 0; - virtual void pong(const std::string& payload) = 0; + virtual void send(const utf8_string& msg) = 0; + virtual void send(const binary_string& data) = 0; + virtual void close(close::status::value code, const binary_string& reason) = 0; + virtual void ping(const binary_string& payload) = 0; + virtual void pong(const binary_string& payload) = 0; // Valid for CLOSED state virtual close::status::value get_local_close_code() const = 0; - virtual std::string get_local_close_reason() const = 0; + virtual utf8_string get_local_close_reason() const = 0; virtual close::status::value get_remote_close_code() const = 0; - virtual std::string get_remote_close_reason() const = 0; + virtual utf8_string get_remote_close_reason() const = 0; virtual bool failed_by_me() const = 0; virtual bool dropped_by_me() const = 0; virtual bool closed_by_me() const = 0; @@ -197,6 +242,7 @@ typedef boost::shared_ptr client_ptr; * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ class client_handler { +public: // on_open is called after the websocket session has been successfully // established and is in the OPEN state. The session is now avaliable to // send messages and will begin reading frames and calling the on_message/ @@ -219,13 +265,12 @@ class client_handler { // data will not be avaliable after this callback ends so the handler must // either completely process the message or copy it somewhere else for // processing later. - virtual void on_message(client_ptr session, - const std::vector &data) = 0; + virtual void on_message(client_ptr session, binary_string_ptr data) = 0; // on_message (text version). Identical to on_message except the data // parameter is a string interpreted as UTF-8. WebSocket++ guarantees that // this string is valid UTF-8. - virtual void on_message(client_ptr session,const std::string &msg) = 0; + virtual void on_message(client_ptr session, utf8_string_ptr msg) = 0; // on_fail is called whenever a session is terminated or failed before it // was successfully established. This happens if there is an error during @@ -235,6 +280,19 @@ class client_handler { // application will need information from `session` after this function you // should either save the session_ptr somewhere or copy the data out. virtual void on_fail(client_ptr session) {}; + + // on_ping is called whenever a ping is recieved with the binary + // application data from the ping. If on_ping returns true the + // implimentation will send a response pong. + virtual bool on_ping(server_ptr session, binary_string_ptr data) { + return true; + } + + // on_pong is called whenever a pong is recieved with the binary + // application data from the pong. + virtual void on_pong(server_ptr session, binary_string_ptr data) {} + + // TODO: on_ping_timeout }; typedef boost::shared_ptr client_handler_ptr; diff --git a/src/logger/logger.hpp b/src/logger/logger.hpp index 0c936e0d5b..5327dbc705 100644 --- a/src/logger/logger.hpp +++ b/src/logger/logger.hpp @@ -62,6 +62,8 @@ namespace alevel { // DEBUG values static const value DEBUG_HANDSHAKE = 0x8000; + static const value DEBUG_CLOSE = 0x4000; + static const value DEVEL = 0x2000; static const value ALL = 0xFFFF; } diff --git a/src/rng/boost_rng.hpp b/src/rng/boost_rng.hpp index 87d368a48d..730fea05ab 100644 --- a/src/rng/boost_rng.hpp +++ b/src/rng/boost_rng.hpp @@ -46,7 +46,7 @@ namespace websocketpp { boost::random::random_device&, boost::random::uniform_int_distribution<> > m_gen; - } + }; } #endif // BOOST_RNG_HPP diff --git a/src/websocket_constants.hpp b/src/websocket_constants.hpp index 8841509d7e..4dc05cc73f 100644 --- a/src/websocket_constants.hpp +++ b/src/websocket_constants.hpp @@ -33,9 +33,18 @@ // for exceptions that should be somewhere else #include #include +#include + +#include // Defaults namespace websocketpp { + typedef std::vector binary_string; + typedef boost::shared_ptr binary_string_ptr; + + typedef std::string utf8_string; + typedef boost::shared_ptr utf8_string_ptr; + const uint64_t DEFAULT_MAX_MESSAGE_SIZE = 0xFFFFFF; // ~16MB // System logging levels @@ -154,6 +163,8 @@ namespace websocketpp { } } + + // TODO: these classes need a better place to live class server_error : public std::exception { public: diff --git a/src/websocket_frame.hpp b/src/websocket_frame.hpp index fe88027e6a..61d37c3194 100644 --- a/src/websocket_frame.hpp +++ b/src/websocket_frame.hpp @@ -382,12 +382,12 @@ public: return m_payload; } - void set_payload(const std::vector source) { + void set_payload(const std::vector& source) { set_payload_helper(source.size()); std::copy(source.begin(),source.end(),m_payload.begin()); } - void set_payload(const std::string source) { + void set_payload(const std::string& source) { set_payload_helper(source.size()); std::copy(source.begin(),source.end(),m_payload.begin()); diff --git a/src/websocket_server.hpp b/src/websocket_server.hpp index 7d6e7d6085..c812d3fae7 100644 --- a/src/websocket_server.hpp +++ b/src/websocket_server.hpp @@ -29,22 +29,29 @@ #define WEBSOCKET_SERVER_HPP #include +#include +#include #include #include #include namespace po = boost::program_options; #include +#include #include "websocketpp.hpp" #include "interfaces/session.hpp" -#include "interfaces/protocol.hpp" -#include "websocket_session.hpp" +// Session processors +#include "interfaces/protocol.hpp" +#include "hybi_legacy_processor.hpp" +#include "hybi_processor.hpp" + +//#include "websocket_session.hpp" #include "websocket_connection_handler.hpp" -#include +//#include #include "rng/blank_rng.hpp" @@ -58,15 +65,23 @@ using websocketpp::protocol::processor_ptr; namespace websocketpp { namespace server { +namespace write_state { + enum value { + IDLE = 0, + WRITING = 1, + INTURRUPT = 2 + }; +} + template -class session : public websocketpp::session::server, boost::enable_shared_from_this< session > { +class connection : public websocketpp::session::server, public boost::enable_shared_from_this< connection > { public: typedef server_policy server_type; - typedef session session_type; + typedef connection connection_type; typedef boost::shared_ptr server_ptr; - session(server_ptr s, + connection(server_ptr s, boost::asio::io_service& io_service, server_handler_ptr handler) : m_server(s), @@ -75,7 +90,139 @@ public: m_timer(io_service,boost::posix_time::seconds(0)), m_buf(/* TODO: needs a max here */), m_handler(handler), - m_state(session::state::CONNECTING) {} + m_state(session::state::CONNECTING), + m_write_buffer(0), + m_write_state(write_state::IDLE) {} + + // implimentation of the server session API + + // Valid always + session::state::value get_state() const { + // TODO: syncronize + return m_state; + } + + unsigned int get_version() const { + return m_version; + } + + std::string get_origin() const { + return m_origin; + } + + std::string get_request_header(const std::string& key) const { + return m_request.header(key); + } + + const ws_uri& get_uri() const { + return m_uri; + } + + bool get_secure() const { + // TODO + return false; + } + + // Valid for CONNECTING state + void add_response_header(const std::string& key, const std::string& value) { + m_response.add_header(key,value); + } + void replace_response_header(const std::string& key, const std::string& value) { + m_response.replace_header(key,value); + } + const std::vector& get_subprotocols() const { + return m_requested_subprotocols; + } + const std::vector& get_extensions() const { + return m_requested_extensions; + } + void select_subprotocol(const std::string& value) { + std::vector::iterator it; + + it = std::find(m_requested_subprotocols.begin(), + m_requested_subprotocols.end(), + value); + + if (value != "" && it == m_requested_subprotocols.end()) { + throw server_error("Attempted to choose a subprotocol not proposed by the client"); + } + + m_subprotocol = value; + } + void select_extension(const std::string& value) { + if (value == "") { + return; + } + + std::vector::iterator it; + + it = std::find(m_requested_extensions.begin(), + m_requested_extensions.end(), + value); + + if (it == m_requested_extensions.end()) { + throw server_error("Attempted to choose an extension not proposed by the client"); + } + + m_extensions.push_back(value); + } + + // Valid for OPEN state + + // These functions invoke write_message through the io_service to gain + // thread safety + void send(const utf8_string& payload) { + binary_string_ptr msg(m_processor->prepare_frame(frame::opcode::TEXT,false,payload)); + + m_io_service.post(boost::bind(&connection_type::write_message,connection_type::shared_from_this(),msg)); + + // TODO: return bytes in flight somehow? + } + + void send(const binary_string& data) { + binary_string_ptr msg(m_processor->prepare_frame(frame::opcode::BINARY,false,data)); + m_io_service.post(boost::bind(&connection_type::write_message,connection_type::shared_from_this(),msg)); + } + + void close(close::status::value code, const utf8_string& reason) { + // TODO + } + + void ping(const binary_string& payload) { + binary_string_ptr msg(m_processor->prepare_frame(frame::opcode::PING,false,payload)); + + m_io_service.post(boost::bind(&connection_type::write_message,connection_type::shared_from_this(),msg)); + } + + void pong(const binary_string& payload) { + binary_string_ptr msg(m_processor->prepare_frame(frame::opcode::PONG,false,payload)); + m_io_service.post(boost::bind(&connection_type::write_message,connection_type::shared_from_this(),msg)); + } + + // Valid for CLOSED state + close::status::value get_local_close_code() const { + return m_local_close_code; + } + utf8_string get_local_close_reason() const { + return m_local_close_reason; + } + close::status::value get_remote_close_code() const { + return m_remote_close_code; + } + utf8_string get_remote_close_reason() const { + return m_remote_close_reason; + } + bool get_failed_by_me() const { + return m_failed_by_me; + } + bool get_dropped_by_me() const { + return m_dropped_by_me; + } + bool get_closed_by_me() const { + return m_closed_by_me; + } + + //////// tcp::socket& get_socket() { return m_socket; @@ -92,8 +239,8 @@ public: m_timer.async_wait( boost::bind( - &session_type::fail_on_expire, - session_type::shared_from_this(), + &connection_type::fail_on_expire, + connection_type::shared_from_this(), boost::asio::placeholders::error ) ); @@ -103,8 +250,8 @@ public: m_buf, "\r\n\r\n", boost::bind( - &session_type::handle_read_request, - session_type::shared_from_this(), + &connection_type::handle_read_request, + connection_type::shared_from_this(), boost::asio::placeholders::error, boost::asio::placeholders::bytes_transferred ) @@ -133,29 +280,34 @@ public: //m_remote_handshake.consume(response_stream); if (!m_request.parse_complete(request)) { // not a valid HTTP request/response - throw handshake_error("Recieved invalid HTTP Request",http::status_code::BAD_REQUEST); + throw http::exception("Recieved invalid HTTP Request",http::status_code::BAD_REQUEST); } // Log the raw handshake. m_server->alog().at(log::alevel::DEBUG_HANDSHAKE) << m_request.raw() << log::endl; // Determine what sort of connection this is: - int m_version = -1; + m_version = -1; - if (boost::ifind_first(m_request.header("Upgrade","websocket"))) { - if (m_request.header("Sec-WebSocket-Version") == "") { + std::string h = m_request.header("Upgrade"); + if (boost::ifind_first(h,"websocket")) { + h = m_request.header("Sec-WebSocket-Version"); + if (h == "") { m_version = 0; } else { - m_version = atoi(m_request.header("Sec-WebSocket-Version").c_str()); + m_version = atoi(h.c_str()); if (m_version == 0) { - throw(handshake_error("Unable to determine connection version",http::status_code::BAD_REQUEST)); + throw(http::exception("Unable to determine connection version",http::status_code::BAD_REQUEST)); } } } + m_server->alog().at(log::alevel::DEBUG_HANDSHAKE) << "determined connection version: " << m_version << log::endl; + if (m_version == -1) { // Probably a plain HTTP request // TODO: forward to an http handler? + } else { // websocket connection // create a processor based on version. @@ -169,14 +321,14 @@ public: request.get(foo,9); if (request.gcount() != 8) { - throw handshake_error("Missing Key3",http::status_code::BAD_REQUEST); + throw http::exception("Missing Key3",http::status_code::BAD_REQUEST); } - m_request.set_header("Sec-WebSocket-Key3",std::string(foo)); + m_request.add_header("Sec-WebSocket-Key3",std::string(foo)); - m_processor = processor_ptr(new protocol::hybi_00_processor()); + m_processor = processor_ptr(new protocol::hybi_legacy_processor()); } else if (m_version == 7 || m_version == 8 || m_version == 13) { // create hybi 17 processor - m_processor = protocol::processor_ptr(new protocol::hybi_17_processor()); + m_processor = processor_ptr(new protocol::hybi_processor(m_rng)); } else { // TODO: respond with unknown version message per spec } @@ -185,18 +337,18 @@ public: m_processor->validate_handshake(m_request); // ask local application to confirm that it wants to accept - m_handler->validate(boost::static_pointer_cast(session_type::shared_from_this())); + m_handler->validate(boost::static_pointer_cast(connection_type::shared_from_this())); m_response.set_status(http::status_code::SWITCHING_PROTOCOLS); } - } catch (const handshake_error& e) { + } catch (const http::exception& e) { m_server->alog().at(log::alevel::DEBUG_HANDSHAKE) << e.what() << log::endl; m_server->elog().at(log::elevel::ERROR) << "Caught handshake exception: " << e.what() << log::endl; - m_response.set_status(e.m_http_error_code,e.m_http_error_msg); + m_response.set_status(e.m_error_code,e.m_error_msg); } write_response(); @@ -225,8 +377,9 @@ public: std::string raw = m_response.raw(); + // Hack for legacy HyBi if (m_version == 0) { - raw += digest; + raw += boost::dynamic_pointer_cast(m_processor)->get_key3(); } m_server->alog().at(log::alevel::DEBUG_HANDSHAKE) << raw << log::endl; @@ -236,8 +389,8 @@ public: m_socket, boost::asio::buffer(raw), boost::bind( - &session_type::handle_write_response, - session_type::shared_from_this(), + &connection_type::handle_write_response, + connection_type::shared_from_this(), boost::asio::placeholders::error ) ); @@ -250,7 +403,7 @@ public: if (error) { log_error("Network error writing handshake response",error); terminate_connection(false); - m_handler->on_fail(boost::static_pointer_cast(session_type::shared_from_this())); + m_handler->on_fail(boost::static_pointer_cast(connection_type::shared_from_this())); return; } @@ -264,15 +417,17 @@ public: << log::endl; terminate_connection(true); - m_handler->on_fail(boost::static_pointer_cast(session_type::shared_from_this())); + m_handler->on_fail(boost::static_pointer_cast(connection_type::shared_from_this())); return; } - m_state = state::OPEN; + m_state = session::state::OPEN; - m_handler->on_open(boost::static_pointer_cast(session_type::shared_from_this())); + m_handler->on_open(boost::static_pointer_cast(connection_type::shared_from_this())); // TODO: start read message loop. + m_server->alog().at(log::alevel::DEVEL) << "calling handle_read_frame" << log::endl; + handle_read_frame(boost::system::error_code()); } void handle_read_frame (const boost::system::error_code& error) { @@ -316,41 +471,105 @@ public: // process data from the buffer just read into std::istream s(&m_buf); + + while (m_buf.size() > 0) { try { - m_processor.consume(s); + m_server->alog().at(log::alevel::DEVEL) << "starting consume, buffer size: " << m_buf.size() << log::endl; + m_processor->consume(s); + m_server->alog().at(log::alevel::DEVEL) << "done consume, buffer size: " << m_buf.size() << log::endl; - if (m_processor.ready()) { - switch (m_processor.get_opcode()) { - case TEXT: - // TODO: on_message + if (m_processor->ready()) { + m_server->alog().at(log::alevel::DEVEL) << "new message ready" << m_buf.size() << log::endl; + + bool response; + switch (m_processor->get_opcode()) { + case frame::opcode::TEXT: + m_handler->on_message(boost::static_pointer_cast(connection_type::shared_from_this()),m_processor->get_utf8_payload()); break; - case BINARY: - // TODO: on_message + case frame::opcode::BINARY: + m_handler->on_message(boost::static_pointer_cast(connection_type::shared_from_this()),m_processor->get_binary_payload()); break; - case PING: - // TODO: on_ping - // TODO: auto-respond pong perhaps based on + case frame::opcode::PING: + response = m_handler->on_ping(boost::static_pointer_cast(connection_type::shared_from_this()),m_processor->get_binary_payload()); + + if (response) { + // send response ping + } break; - case PONG: - // TODO: on_pong + case frame::opcode::PONG: + m_handler->on_pong(boost::static_pointer_cast(connection_type::shared_from_this()),m_processor->get_binary_payload()); + + // TODO: disable ping response timer + break; - case CLOSE: - // TODO: send_close - // if we are server - // drop tcp - // m_state = CLOSED - // else - // set timer - // on timer, drop tcp, m_state = closed. + case frame::opcode::CLOSE: + if (m_state == session::state::OPEN) { + // other end is initiating + m_server->elog().at(log::elevel::DEVEL) + << "sending close ack" << log::endl; + + m_remote_close_code = m_processor->get_close_code(); + m_remote_close_reason = m_processor->get_close_reason(); + + send_close_ack(); + } else if (m_state == session::state::CLOSING) { + // ack of our close + m_server->elog().at(log::elevel::DEVEL) + << "got close ack" << log::endl; + + m_remote_close_code = m_processor->get_close_code(); + m_remote_close_reason = m_processor->get_close_reason(); + terminate_connection(false); + // TODO: start terminate timer (if client) + } break; default: // error break; } + m_processor->reset(); } - } catch (/* session exception */) { + } catch (const session::exception& e) { + m_server->elog().at(log::elevel::ERROR) + << "Caught session exception: " << e.what() << log::endl; + // if the exception happened while processing. + // TODO: this is not elegant, perhaps separate frame read vs process + // exceptions need to be used. + if (m_processor->ready()) { + m_processor->reset(); + } + + if (e.code() == session::error::PROTOCOL_VIOLATION) { + //send_close(close::status::PROTOCOL_ERROR, e.what()); + m_server->elog().at(log::elevel::DEVEL) + << "Dropping TCP due to unrecoverable protocol violation" + << log::endl; + terminate_connection(true); + } else if (e.code() == session::error::PAYLOAD_VIOLATION) { + //send_close(close::status::INVALID_PAYLOAD, e.what()); + m_server->elog().at(log::elevel::DEVEL) + << "Dropping TCP due to unrecoverable payload violation" + << log::endl; + terminate_connection(true); + } else if (e.code() == session::error::INTERNAL_SERVER_ERROR) { + //send_close(close::status::ABNORMAL_CLOSE, e.what()); + m_server->elog().at(log::elevel::DEVEL) + << "Dropping TCP due to unrecoverable internal server error" + << log::endl; + terminate_connection(true); + } else if (e.code() == session::error::SOFT_ERROR) { + // ignore and continue processing frames + continue; + } else { + // Fatal error, forcibly end connection immediately. + m_server->elog().at(log::elevel::DEVEL) + << "Dropping TCP due to unrecoverable exception" + << log::endl; + terminate_connection(true); + } + break; } // if the result of processing/consuming a frame closed the @@ -367,13 +586,142 @@ public: } // try and read more + if (m_processor->get_bytes_needed() > 0) { + // TODO: read timeout timer? + + boost::asio::async_read( + m_socket, + m_buf, + boost::asio::transfer_at_least(m_processor->get_bytes_needed()), + boost::bind( + &connection_type::handle_read_frame, + connection_type::shared_from_this(), + boost::asio::placeholders::error + ) + ); + } else { + // TODO: ???? + } + } + + void send_close() { + + } + + // send an acknowledgement close frame + void send_close_ack() { + // TODO: state should be OPEN + + // echo close value unless there is a good reason not to. + if (m_remote_close_code == close::status::NO_STATUS) { + m_local_close_code = close::status::NORMAL; + m_local_close_reason = ""; + } else if (m_remote_close_code == close::status::ABNORMAL_CLOSE) { + // TODO: can we possibly get here? This means send_close_ack was + // called after a connection ended without getting a close + // frame + throw "shouldn't be here"; + } else if (close::status::invalid(m_remote_close_code)) { + m_local_close_code = close::status::PROTOCOL_ERROR; + m_local_close_reason = "Status code is invalid"; + } else if (close::status::reserved(m_remote_close_code)) { + m_local_close_code = close::status::PROTOCOL_ERROR; + m_local_close_reason = "Status code is reserved"; + } else { + m_local_close_code = m_remote_close_code; + m_local_close_reason = m_remote_close_reason; + } + + binary_string_ptr msg = m_processor->prepare_close_frame(m_local_close_code,false,m_local_close_reason); + + // TODO: check whether we should cancel the current in flight write. + // if not canceled the close message will be sent as soon as the + // current write completes. + + m_write_state = write_state::INTURRUPT; + write_message(m_processor->prepare_close_frame(m_local_close_code, + false, + m_local_close_reason)); + } + + void write_message(binary_string_ptr msg) { + m_write_buffer += msg->size(); + m_write_queue.push(msg); + write(); + } + + void write() { + switch (m_write_state) { + case write_state::IDLE: + break; + case write_state::WRITING: + // already writing. write() will get called again by the write + // handler once it is ready. + return; + case write_state::INTURRUPT: + // clear the queue except for the last message + while (m_write_queue.size() > 1) { + m_write_buffer -= m_write_queue.front()->size(); + m_write_queue.pop(); + } + break; + default: + // TODO: assert shouldn't be here + break; + } + + if (m_write_queue.size() > 0) { + if (m_write_state == write_state::IDLE) { + m_write_state = write_state::WRITING; + } + + boost::asio::async_write( + m_socket, + boost::asio::buffer(*m_write_queue.front()), + boost::bind( + &connection_type::handle_write, + connection_type::shared_from_this(), + boost::asio::placeholders::error + ) + ); + } else { + // if we are in an inturrupted state and had nothing else to write + // it is safe to terminate the connection. + if (m_write_state == write_state::INTURRUPT) { + terminate_connection(false); + } + } + } + + void handle_write(const boost::system::error_code& error) { + if (error) { + if (error == boost::asio::error::operation_aborted) { + // previous write was aborted + } else { + log_error("Error writing frame data",error); + terminate_connection(false); + return; + } + } + + m_write_buffer -= m_write_queue.front()->size(); + m_write_queue.pop(); + + if (m_write_state == write_state::WRITING) { + m_write_state = write_state::IDLE; + } + + write(); } // end conditions // - tcp connection is closed // - session state is CLOSED // - session end flags are set + // - application is notified void terminate_connection(bool failed_by_me) { + m_server->alog().at(log::alevel::DEBUG_CLOSE) << "terminate called" << log::endl; + if (m_state == session::state::CLOSED) { // shouldn't be here } @@ -403,10 +751,10 @@ public: // If we called terminate from the connecting state call on_fail if (old_state == session::state::CONNECTING) { - m_handler->on_fail(boost::static_pointer_cast(session_type::shared_from_this())); + m_handler->on_fail(boost::static_pointer_cast(connection_type::shared_from_this())); } else if (old_state == session::state::OPEN || old_state == session::state::CLOSING) { - m_handler->on_close(boost::static_pointer_cast(session_type::shared_from_this())); + m_handler->on_close(boost::static_pointer_cast(connection_type::shared_from_this())); } else { // if we were already closed something is wrong } @@ -417,6 +765,23 @@ public: m_server->elog().at(log::elevel::ERROR) << msg << "(" << e << ")" << log::endl; } + void log_close_result() { + m_server->alog().at(log::alevel::DISCONNECT) + //<< "Disconnect " << (m_was_clean ? "Clean" : "Unclean") + << "Disconnect " + << " close local:[" << m_local_close_code + << (m_local_close_reason == "" ? "" : ","+m_local_close_reason) + << "] remote:[" << m_remote_close_code + << (m_remote_close_reason == "" ? "" : ","+m_remote_close_reason) << "]" + << log::endl; + } + void log_open_result() { + m_server->alog().at(log::alevel::CONNECT) << "Connection " + << m_socket.remote_endpoint() << " v" << m_version << " " + << (get_request_header("User-Agent") == "" ? "NULL" : get_request_header("User-Agent")) + << " " << m_uri.resource << " " << m_response.status_code() + << log::endl; + } void fail_on_expire(const boost::system::error_code& error) { if (error) { @@ -433,6 +798,8 @@ public: } private: + + server_ptr m_server; boost::asio::io_service& m_io_service; @@ -442,13 +809,33 @@ private: server_handler_ptr m_handler; processor_ptr m_processor; - + blank_rng m_rng; + + // Connection state http::parser::request m_request; http::parser::response m_response; - session::state::value m_state; - int m_version; + std::vector m_requested_subprotocols; + std::vector m_requested_extensions; + std::string m_subprotocol; + std::vector m_extensions; + std::string m_origin; + unsigned int m_version; + bool m_secure; + ws_uri m_uri; + session::state::value m_state; + + // Write queue + std::queue m_write_queue; + size_t m_write_buffer; + write_state::value m_write_state; + + // Close state + close::status::value m_local_close_code; + std::string m_local_close_reason; + close::status::value m_remote_close_code; + std::string m_remote_close_reason; bool m_closed_by_me; bool m_failed_by_me; bool m_dropped_by_me; @@ -461,11 +848,11 @@ template