commit fa73c478e95649792403a65ba8d427964228551d Author: Peter Thorson Date: Wed Sep 7 20:27:29 2011 -0500 initial commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000..0ce170c0c7 --- /dev/null +++ b/.gitignore @@ -0,0 +1,13 @@ +.DS_Store + +#vim stuff +*~ +*.swp + +*.o +*.so +*.a +*.dylib + +objs_shared/ +objs_static/ diff --git a/Makefile b/Makefile new file mode 100644 index 0000000000..d29fe6ed34 --- /dev/null +++ b/Makefile @@ -0,0 +1,169 @@ +objects = websocket_session.o websocket_server.o websocket_frame.o \ + network_utilities.o sha1.o base64.o + +OS=$(shell uname) + +# Defaults +ifeq ($(OS), Darwin) + cxxflags_default = -c -O2 -DNDEBUG +else + cxxflags_default = -c -O2 -DNDEBUG +endif +cxxflags_small = -c +cxxflags_debug = -c -g +cxxflags_shared = -f$(PIC) +libname = libwebsocketpp +libname_hdr = websocketpp +libname_debug = $(libname)_dbg +suffix_shared = so +suffix_shared_darwin = dylib +suffix_static = a +major_version = 0 +minor_version = 1.0 +objdir = objs + +# Variables +prefix ?= /usr/local +exec_prefix ?= $(prefix) +libdir ?= lib +includedir ?= include +srcdir ?= src +CXX ?= c++ +AR ?= ar +PIC ?= PIC +BUILD_TYPE ?= "default" +SHARED ?= "1" + +# Internal Variables +inst_path = $(exec_prefix)/$(libdir) +include_path = $(prefix)/$(includedir) + +# BUILD_TYPE specific settings +ifeq ($(BUILD_TYPE), debug) + CXXFLAGS = $(cxxflags_debug) + libname := $(libname_debug) +else + CXXFLAGS ?= $(cxxflags_default) +endif + +# SHARED specific settings +ifeq ($(SHARED), 1) + ifeq ($(OS), Darwin) + libname_shared = $(libname).$(suffix_shared_darwin) + else + libname_shared = $(libname).$(suffix_shared) + endif + libname_shared_major_version = $(libname_shared).$(major_version) + lib_target = $(libname_shared_major_version).$(minor_version) + objdir := $(objdir)_shared + CXXFLAGS := $(CXXFLAGS) $(cxxflags_shared) +else + lib_target = $(libname).$(suffix_static) + objdir := $(objdir)_static +endif + +# Phony targets +.PHONY: all banner installdirs install install_headers clean uninstall \ + uninstall_headers + +# Targets +all: $(lib_target) + @echo "============================================================" + @echo "Done" + @echo "============================================================" + +banner: + @echo "============================================================" + @echo "libwebsocketpp version: "$(major_version).$(minor_version) "target: "$(target) "OS: "$(OS) + @echo "============================================================" + +installdirs: banner + mkdir -p $(objdir) + +# Libraries +ifeq ($(SHARED),1) +$(lib_target): banner installdirs $(addprefix $(objdir)/, $(objects)) + @echo "Link " + cd $(objdir) ; \ + if test "$(OS)" = "Darwin" ; then \ + $(CXX) -dynamiclib -lboost_system -Wl,-dylib_install_name -Wl,$(libname_shared_major_version) -o $@ $(objects) ; \ + else \ + $(CXX) -shared -lboost_system -Wl,-soname,$(libname_shared_major_version) -o $@ $(objects) ; \ + fi ; \ + mv -f $@ ../ + @echo "Link: Done" +else +$(lib_target): banner installdirs $(addprefix $(objdir)/, $(objects)) + @echo "Archive" + cd $(objdir) ; \ + $(AR) -cvq $@ $(objects) ; \ + mv -f $@ ../ + @echo "Archive: Done" +endif + +# Compile object files +$(objdir)/sha1.o: $(srcdir)/sha1/sha1.cpp + $(CXX) $< -o $@ $(CXXFLAGS) + +$(objdir)/base64.o: $(srcdir)/base64/base64.cpp + $(CXX) $< -o $@ $(CXXFLAGS) + +$(objdir)/%.o: $(srcdir)/%.cpp + $(CXX) $< -o $@ $(CXXFLAGS) + +ifeq ($(SHARED),1) +install: banner install_headers $(lib_target) + @echo "Install shared library" + cp -f ./$(lib_target) $(inst_path) + cd $(inst_path) ; \ + ln -sf $(lib_target) $(libname_shared_major_version) ; \ + ln -sf $(libname_shared_major_version) $(libname_shared) + ifneq ($(OS),Darwin) + ldconfig + endif + @echo "Install shared library: Done." +else +install: banner install_headers $(lib_target) + @echo "Install static library" + cp -f ./$(lib_target) $(inst_path) + @echo "Install static library: Done." +endif + +install_headers: banner + @echo "Install header files" + mkdir -p $(include_path)/$(libname_hdr) +# cp -f ./*.hpp $(include_path)/$(libname_hdr) + cp -f ./$(srcdir)/*.hpp $(include_path)/$(libname_hdr) + mkdir -p $(include_path)/$(libname_hdr)/base64 + cp -f ./$(srcdir)/base64/base64.h $(include_path)/$(libname_hdr)/base64 + mkdir -p $(include_path)/$(libname_hdr)/sha1 + cp -f ./$(srcdir)/sha1/sha1.h $(include_path)/$(libname_hdr)/sha1 + chmod -R a+r $(include_path)/$(libname_hdr) + find $(include_path)/$(libname_hdr) -type d -exec chmod a+x {} \; + @echo "Install header files: Done." + +clean: banner + @echo "Clean library and object folder" + rm -rf $(objdir) + rm -f $(lib_target) + @echo "Clean library and object folder: Done" + +ifeq ($(SHARED),1) +uninstall: banner uninstall_headers + @echo "Uninstall shared library" + rm -f $(inst_path)/$(libname_shared) + rm -f $(inst_path)/$(libname_shared_major_version) + rm -f $(inst_path)/$(lib_target) + ldconfig + @echo "Uninstall shared library: Done" +else +uninstall: banner uninstall_headers + @echo "Uninstall static library" + rm -f $(inst_path)/$(lib_target) + @echo "Uninstall static library: Done" +endif + +uninstall_headers: banner + @echo "Uninstall header files" + rm -rf $(include_path)/$(libname) + @echo "Uninstall header files: Done" diff --git a/examples/chat_server/Makefile b/examples/chat_server/Makefile new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/chat_server/chat.cpp b/examples/chat_server/chat.cpp new file mode 100644 index 0000000000..7553297655 --- /dev/null +++ b/examples/chat_server/chat.cpp @@ -0,0 +1,30 @@ +#include "chat.hpp" + +using websocketchat::chat_handler; + + +bool chat_handler::validate(websocketpp::session_ptr client) { + // We only know about the chat resource + if (client->get_request() != "/chat") { + client->set_http_error(404); + return false; + } + + // Require specific origin example + if (client->get_header("Sec-WebSocket-Origin") != "http://zaphoyd.com") { + client->set_http_error(403); + return false; + } + + return true; +} + + +void connect(session_ptr client) { + std::cout << "client " << client << " joined the lobby." << std::endl; + m_connections.insert(client); + + // send user list to all clients + // send signon message from server to all other clients. + +} diff --git a/examples/chat_server/chat.hpp b/examples/chat_server/chat.hpp new file mode 100644 index 0000000000..98047a014e --- /dev/null +++ b/examples/chat_server/chat.hpp @@ -0,0 +1,85 @@ +#ifndef CHAT_HPP +#define CHAT_HPP + +// com.zaphoyd.websocketpp.chat protocol +// +// client messages: +// msg [UTF8 text] +// +// server messages: +// {"type":"msg","sender":"","value":"" } +// {"type":"participants","value":[,...]} + +#include + +#include + +#include +#include +#include +#include + +namespace websocketchat { + +class chat_handler : public websocketpp::connection_handler { +public: + chat_handler() {} + virtual ~chat_handler() {} + + bool validate(websocketpp::session_ptr client); + + // add new connection to the lobby + void connect(session_ptr client) { + std::cout << "client " << client << " joined the lobby." << std::endl; + m_connections.insert(client); + + // send user list to all clients + // send signon message from server to all other clients. + + } + + // someone disconnected from the lobby, remove them + void disconnect(websocketpp::session_ptr client,const std::string &reason) { + std::set::iterator it = m_connections.find(client); + + if (it == m_connections.end()) { + // this client has already disconnected, we can ignore this. + // this happens during certain types of disconnect where there is a + // deliberate "soft" disconnection preceeding the "hard" socket read + // fail or disconnect ack message. + return; + } + + std::cout << "client " << client << " left the lobby." << std::endl; + + m_connections.erase(it); + + // send signoff message from server to all clients + // send user list to all remaining clients + } + + void message(websocketpp::session_ptr client,const std::string &msg) { + std::cout << "message from client " << client << ": " << msg << std::endl; + + // create JSON message to send based on msg + + for (std::set::iterator it = m_connections.begin(); + it != m_connections.end(); it++) { + (*it)->send(msg); + } + } + + // lobby will ignore binary messages + void message(websocketpp::session_ptr client, + const std::vector &data) {} +private: + std::string serialize_state(); + + // list of outstanding connections + std::set m_connections; +}; + +typedef boost::shared_ptr chat_handler_ptr; + +} +#endif // CHAT_HPP diff --git a/examples/chat_server/chat_server.cpp b/examples/chat_server/chat_server.cpp new file mode 100644 index 0000000000..cbce304b55 --- /dev/null +++ b/examples/chat_server/chat_server.cpp @@ -0,0 +1,41 @@ +#include "chat.hpp" + +#include +#include + +#include + +using boost::asio::ip::tcp; + +int main(int argc, char* argv[]) { + std::string host = "localhost:5000"; + short port = 5000; + + if (argc == 3) { + port = atoi(argv[2]); + + std::stringstream temp; + temp << argv[1] << ":" << port; + + host = temp.str(); + } + + websocketchat::chat_handler_ptr chat_handler(new websocketchat::chat_handler()); + + try { + boost::asio::io_service io_service; + tcp::endpoint endpoint(tcp::v6(), port); + + websocketpp::server_ptr server( + new websocketpp::server(io_service,endpoint,host,chat_handler) + ); + + std::cout << "Starting chat server on " << host << std::endl; + + io_service.run(); + } catch (std::exception& e) { + std::cerr << "Exception: " << e.what() << std::endl; + } + + return 0; +} diff --git a/examples/echo_server/Makefile b/examples/echo_server/Makefile new file mode 100644 index 0000000000..dcdfdeb85e --- /dev/null +++ b/examples/echo_server/Makefile @@ -0,0 +1,23 @@ +CFLAGS = -O2 +LDFLAGS = + +CXX ?= c++ +SHARED ?= "1" + +ifeq ($(SHARED), 1) + LDFLAGS := $(LDFLAGS) -lboost_system -lwebsocketpp +else + LDFLAGS := $(LDFLAGS) -lboost_system ../../libwebsocketpp.a +endif + +echo_server: echo_server.o echo.o + $(CXX) $(CFLAGS) $^ -o $@ $(LDFLAGS) + +%.o: %.cpp + $(CXX) -c $(CFLAGS) -o $@ $^ + +# cleanup by removing generated files +# +.PHONY: clean +clean: + rm -f *.o echo_server diff --git a/examples/echo_server/echo.cpp b/examples/echo_server/echo.cpp new file mode 100644 index 0000000000..ab677a0772 --- /dev/null +++ b/examples/echo_server/echo.cpp @@ -0,0 +1,16 @@ +#include "echo.hpp" + +using websocketecho::echo_handler; + +bool echo_handler::validate(websocketpp::session_ptr client) { + return true; +} + +void echo_handler::message(websocketpp::session_ptr client, const std::string &msg) { + client->send(msg); +} + +void echo_handler::message(websocketpp::session_ptr client, + const std::vector &data) { + client->send(data); +} diff --git a/examples/echo_server/echo.hpp b/examples/echo_server/echo.hpp new file mode 100644 index 0000000000..221f305ebb --- /dev/null +++ b/examples/echo_server/echo.hpp @@ -0,0 +1,35 @@ +#ifndef ECHO_HANDLER_HPP +#define ECHO_HANDLER_HPP + +#include "../../src/websocket_connection_handler.hpp" +#include + +#include +#include + +namespace websocketecho { + +class echo_handler : public websocketpp::connection_handler { +public: + echo_handler() {} + virtual ~echo_handler() {} + + // The echo server allows all domains is protocol free. + bool validate(websocketpp::session_ptr client); + + // an echo server is stateless. The handler has no need to keep track of connected + // clients. + void connect(websocketpp::session_ptr client) {} + void disconnect(websocketpp::session_ptr client,const std::string &reason) {} + + // both text and binary messages are echoed back to the sending client. + void message(websocketpp::session_ptr client,const std::string &msg); + void message(websocketpp::session_ptr client, + const std::vector &data); +}; + +typedef boost::shared_ptr echo_handler_ptr; + +} + +#endif // ECHO_HANDLER_HPP diff --git a/examples/echo_server/echo_client.html b/examples/echo_server/echo_client.html new file mode 100644 index 0000000000..8e6bc9bbce --- /dev/null +++ b/examples/echo_server/echo_client.html @@ -0,0 +1,94 @@ + + + + + + + + + + +
+
+ + +
+ +
+
+
+
+ + + \ No newline at end of file diff --git a/examples/echo_server/echo_server b/examples/echo_server/echo_server new file mode 100755 index 0000000000..51987c9609 Binary files /dev/null and b/examples/echo_server/echo_server differ diff --git a/examples/echo_server/echo_server.cpp b/examples/echo_server/echo_server.cpp new file mode 100644 index 0000000000..f8200463e6 --- /dev/null +++ b/examples/echo_server/echo_server.cpp @@ -0,0 +1,43 @@ +#include "echo.hpp" + +#include "../../src/websocketpp.hpp" +#include + +#include + +using boost::asio::ip::tcp; + +int main(int argc, char* argv[]) { + std::string host = "localhost:5000"; + short port = 5000; + + if (argc == 3) { + // TODO: input validation? + port = atoi(argv[2]); + + + std::stringstream temp; + temp << argv[1] << ":" << port; + + host = temp.str(); + } + + websocketecho::echo_handler_ptr echo_handler(new websocketecho::echo_handler()); + + try { + boost::asio::io_service io_service; + tcp::endpoint endpoint(tcp::v6(), port); + + websocketpp::server_ptr server( + new websocketpp::server(io_service,endpoint,host,echo_handler) + ); + + std::cout << "Starting echo server on " << host << std::endl; + + io_service.run(); + } catch (std::exception& e) { + std::cerr << "Exception: " << e.what() << std::endl; + } + + return 0; +} diff --git a/readme.txt b/readme.txt new file mode 100644 index 0000000000..51331c39de --- /dev/null +++ b/readme.txt @@ -0,0 +1,97 @@ +How to use this library + +WebSocket++ is a C++ websocket server library implimented using the Boost Asio networking stack. It is designed to provide a simple interface + +Built on Asio's proactor asyncronious event loop. + +Building a program using WebSocket++ has two parts +1. Impliment a connection handler. +This is done by subclassing websocketpp::connection_handler. Each websocket connection is attached to a connection handler. The handler impliments the following methods: + +validate: +Called after the client handshake is recieved but before the connection is accepted. Allows cookie authentication, origin checking, subprotocol negotiation, etc + +connect: +Called when the connection has been established and writes are allowed. + +disconnect: +Called when the connection has been disconnected + +message: text and binary variants +Called when a new websocket message is recieved. + +The handler has access to the following websocket session api: +get_header +returns the value of an HTTP header sent by the client during the handshake + +get_request +returns the resource requested by the client in the handshake + +set_handler: +pass responsibility for this connection to another connection handler + +set_http_error: +reject the connection with a specific HTTP error + +add_header +adds an HTTP header to the server handshake + +set_subprotocol +selects a subprotocol for the connection + +send: text and binary varients +send a websocket message + +ping: +send a ping + +2. Start Asio's event loop with a TCP endpoint and your connection handler + +There are two example programs in the examples directory that demonstrate this use pattern. One is a trivial stateless echo server, the other is a simple web based chat client. Both include example javascript clients. The echo server is suitable for use with automated testing suites such as the Autobahn test suite. + +By default, a single connection handler object is used for all connections. If needs require, that default handler can either store per-connection state itself or create new handlers and pass off responsibility for the connection to them. + +How to build this library + +Build static library +make + +Build and install in system include directories +make install + +Avaliable flags: +- SHARED=1: build a shared instead of static library. +- DEBUG=1: build library with no optimizations, suitable for debugging. Debug library + is called libwebsocketpp_dbg +- CXX=*: uses * as the c++ compiler instead of system default + +Build tested on +- Mac OS X 10.7 with apple gcc 4.2, macports gcc 4.6, apple llvm/clang + + + + +Outstanding issues +- Acknowledgement details +- Add license information to each file +- Add license.txt file +- Subprotocol negotiation interface +- check draft 13 issues +- make website + +Unimplimented features +- SSL +- frame or streaming based api +- client features +- extension negotiation interface + +Acknowledgements +- Boost Asio and other libraries +- base64 library +- sha1 library +- htonll discussion +- build/makefile from libjson + +- Autobahn test suite +- testing by Keith Brisson + diff --git a/src/base64/base64.cpp b/src/base64/base64.cpp new file mode 100644 index 0000000000..071b05cd38 --- /dev/null +++ b/src/base64/base64.cpp @@ -0,0 +1,123 @@ +/* + base64.cpp and base64.h + + Copyright (C) 2004-2008 René Nyffenegger + + This source code is provided 'as-is', without any express or implied + warranty. In no event will the author be held liable for any damages + arising from the use of this software. + + Permission is granted to anyone to use this software for any purpose, + including commercial applications, and to alter it and redistribute it + freely, subject to the following restrictions: + + 1. The origin of this source code must not be misrepresented; you must not + claim that you wrote the original source code. If you use this source code + in a product, an acknowledgment in the product documentation would be + appreciated but is not required. + + 2. Altered source versions must be plainly marked as such, and must not be + misrepresented as being the original source code. + + 3. This notice may not be removed or altered from any source distribution. + + René Nyffenegger rene.nyffenegger@adp-gmbh.ch + +*/ + +#include "base64.h" +#include + +static const std::string base64_chars = + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghijklmnopqrstuvwxyz" + "0123456789+/"; + + +static inline bool is_base64(unsigned char c) { + return (isalnum(c) || (c == '+') || (c == '/')); +} + +std::string base64_encode(unsigned char const* bytes_to_encode, unsigned int in_len) { + std::string ret; + int i = 0; + int j = 0; + unsigned char char_array_3[3]; + unsigned char char_array_4[4]; + + while (in_len--) { + char_array_3[i++] = *(bytes_to_encode++); + if (i == 3) { + char_array_4[0] = (char_array_3[0] & 0xfc) >> 2; + char_array_4[1] = ((char_array_3[0] & 0x03) << 4) + ((char_array_3[1] & 0xf0) >> 4); + char_array_4[2] = ((char_array_3[1] & 0x0f) << 2) + ((char_array_3[2] & 0xc0) >> 6); + char_array_4[3] = char_array_3[2] & 0x3f; + + for(i = 0; (i <4) ; i++) + ret += base64_chars[char_array_4[i]]; + i = 0; + } + } + + if (i) + { + for(j = i; j < 3; j++) + char_array_3[j] = '\0'; + + char_array_4[0] = (char_array_3[0] & 0xfc) >> 2; + char_array_4[1] = ((char_array_3[0] & 0x03) << 4) + ((char_array_3[1] & 0xf0) >> 4); + char_array_4[2] = ((char_array_3[1] & 0x0f) << 2) + ((char_array_3[2] & 0xc0) >> 6); + char_array_4[3] = char_array_3[2] & 0x3f; + + for (j = 0; (j < i + 1); j++) + ret += base64_chars[char_array_4[j]]; + + while((i++ < 3)) + ret += '='; + + } + + return ret; + +} + +std::string base64_decode(std::string const& encoded_string) { + int in_len = encoded_string.size(); + int i = 0; + int j = 0; + int in_ = 0; + unsigned char char_array_4[4], char_array_3[3]; + std::string ret; + + while (in_len-- && ( encoded_string[in_] != '=') && is_base64(encoded_string[in_])) { + char_array_4[i++] = encoded_string[in_]; in_++; + if (i ==4) { + for (i = 0; i <4; i++) + char_array_4[i] = base64_chars.find(char_array_4[i]); + + char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4); + char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); + char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; + + for (i = 0; (i < 3); i++) + ret += char_array_3[i]; + i = 0; + } + } + + if (i) { + for (j = i; j <4; j++) + char_array_4[j] = 0; + + for (j = 0; j <4; j++) + char_array_4[j] = base64_chars.find(char_array_4[j]); + + char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4); + char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); + char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; + + for (j = 0; (j < i - 1); j++) ret += char_array_3[j]; + } + + return ret; +} \ No newline at end of file diff --git a/src/base64/base64.h b/src/base64/base64.h new file mode 100644 index 0000000000..ceb13579ce --- /dev/null +++ b/src/base64/base64.h @@ -0,0 +1,4 @@ +#include + +std::string base64_encode(unsigned char const* , unsigned int len); +std::string base64_decode(std::string const& s); \ No newline at end of file diff --git a/src/network_utilities.cpp b/src/network_utilities.cpp new file mode 100644 index 0000000000..9850a6de7c --- /dev/null +++ b/src/network_utilities.cpp @@ -0,0 +1,26 @@ +#include "network_utilities.hpp" + +uint64_t htonll(uint64_t src) { + static int typ = TYP_INIT; + unsigned char c; + union { + uint64_t ull; + unsigned char c[8]; + } x; + if (typ == TYP_INIT) { + x.ull = 0x01; + typ = (x.c[7] == 0x01ULL) ? TYP_BIGE : TYP_SMLE; + } + if (typ == TYP_BIGE) + return src; + x.ull = src; + c = x.c[0]; x.c[0] = x.c[7]; x.c[7] = c; + c = x.c[1]; x.c[1] = x.c[6]; x.c[6] = c; + c = x.c[2]; x.c[2] = x.c[5]; x.c[5] = c; + c = x.c[3]; x.c[3] = x.c[4]; x.c[4] = c; + return x.ull; +} + +uint64_t ntohll(uint64_t src) { + return htonll(src); +} \ No newline at end of file diff --git a/src/network_utilities.hpp b/src/network_utilities.hpp new file mode 100644 index 0000000000..154e1c5ad8 --- /dev/null +++ b/src/network_utilities.hpp @@ -0,0 +1,17 @@ +#ifndef NETWORK_UTILITIES_HPP +#define NETWORK_UTILITIES_HPP + +#include + +// http://www.viva64.com/en/k/0018/ +// TODO: impliment stuff from here: +// http://stackoverflow.com/questions/809902/64-bit-ntohl-in-c + +#define TYP_INIT 0 +#define TYP_SMLE 1 +#define TYP_BIGE 2 + +uint64_t htonll(uint64_t src); +uint64_t ntohll(uint64_t src); + +#endif // NETWORK_UTILITIES_HPP \ No newline at end of file diff --git a/src/sha1/Makefile b/src/sha1/Makefile new file mode 100755 index 0000000000..66aaf2b56f --- /dev/null +++ b/src/sha1/Makefile @@ -0,0 +1,41 @@ +# +# Makefile +# +# Copyright (C) 1998, 2009 +# Paul E. Jones +# All Rights Reserved. +# +############################################################################# +# $Id: Makefile 12 2009-06-22 19:34:25Z paulej $ +############################################################################# +# +# Description: +# This is a makefile for UNIX to build the programs sha, shacmp, and +# shatest +# +# + +CC = g++ + +CFLAGS = -c -O2 -Wall -D_FILE_OFFSET_BITS=64 + +LIBS = + +OBJS = sha1.o + +all: sha shacmp shatest + +sha: sha.o $(OBJS) + $(CC) -o $@ sha.o $(OBJS) $(LIBS) + +shacmp: shacmp.o $(OBJS) + $(CC) -o $@ shacmp.o $(OBJS) $(LIBS) + +shatest: shatest.o $(OBJS) + $(CC) -o $@ shatest.o $(OBJS) $(LIBS) + +%.o: %.cpp + $(CC) $(CFLAGS) -o $@ $< + +clean: + $(RM) *.o sha shacmp shatest diff --git a/src/sha1/Makefile.nt b/src/sha1/Makefile.nt new file mode 100755 index 0000000000..9bacf64a1d --- /dev/null +++ b/src/sha1/Makefile.nt @@ -0,0 +1,48 @@ +# +# Makefile.nt +# +# Copyright (C) 1998, 2009 +# Paul E. Jones +# All Rights Reserved. +# +############################################################################# +# $Id: Makefile.nt 13 2009-06-22 20:20:32Z paulej $ +############################################################################# +# +# Description: +# This is a makefile for Win32 to build the programs sha, shacmp, and +# shatest +# +# Portability Issues: +# Designed to work with Visual C++ +# +# + +.silent: + +!include + +RM = del /q + +LIBS = $(conlibs) setargv.obj + +CFLAGS = -D _CRT_SECURE_NO_WARNINGS /EHsc /O2 /W3 + +OBJS = sha1.obj + +all: sha.exe shacmp.exe shatest.exe + +sha.exe: sha.obj $(OBJS) + $(link) $(conflags) -out:$@ sha.obj $(OBJS) $(LIBS) + +shacmp.exe: shacmp.obj $(OBJS) + $(link) $(conflags) -out:$@ shacmp.obj $(OBJS) $(LIBS) + +shatest.exe: shatest.obj $(OBJS) + $(link) $(conflags) -out:$@ shatest.obj $(OBJS) $(LIBS) + +.cpp.obj: + $(cc) $(CFLAGS) $(cflags) $(cvars) $< + +clean: + $(RM) *.obj sha.exe shacmp.exe shatest.exe diff --git a/src/sha1/license.txt b/src/sha1/license.txt new file mode 100755 index 0000000000..8d7f3941cf --- /dev/null +++ b/src/sha1/license.txt @@ -0,0 +1,14 @@ +Copyright (C) 1998, 2009 +Paul E. Jones + +Freeware Public License (FPL) + +This software is licensed as "freeware." Permission to distribute +this software in source and binary forms, including incorporation +into other products, is hereby granted without a fee. THIS SOFTWARE +IS PROVIDED 'AS IS' AND WITHOUT ANY EXPRESSED OR IMPLIED WARRANTIES, +INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY +AND FITNESS FOR A PARTICULAR PURPOSE. THE AUTHOR SHALL NOT BE HELD +LIABLE FOR ANY DAMAGES RESULTING FROM THE USE OF THIS SOFTWARE, EITHER +DIRECTLY OR INDIRECTLY, INCLUDING, BUT NOT LIMITED TO, LOSS OF DATA +OR DATA BEING RENDERED INACCURATE. diff --git a/src/sha1/sha.cpp b/src/sha1/sha.cpp new file mode 100755 index 0000000000..ad860863af --- /dev/null +++ b/src/sha1/sha.cpp @@ -0,0 +1,176 @@ +/* + * sha.cpp + * + * Copyright (C) 1998, 2009 + * Paul E. Jones + * All Rights Reserved + * + ***************************************************************************** + * $Id: sha.cpp 13 2009-06-22 20:20:32Z paulej $ + ***************************************************************************** + * + * Description: + * This utility will display the message digest (fingerprint) for + * the specified file(s). + * + * Portability Issues: + * None. + */ + +#include +#include +#ifdef WIN32 +#include +#endif +#include +#include "sha1.h" + +/* + * Function prototype + */ +void usage(); + + +/* + * main + * + * Description: + * This is the entry point for the program + * + * Parameters: + * argc: [in] + * This is the count of arguments in the argv array + * argv: [in] + * This is an array of filenames for which to compute message digests + * + * Returns: + * Nothing. + * + * Comments: + * + */ +int main(int argc, char *argv[]) +{ + SHA1 sha; // SHA-1 class + FILE *fp; // File pointer for reading files + char c; // Character read from file + unsigned message_digest[5]; // Message digest from "sha" + int i; // Counter + bool reading_stdin; // Are we reading standard in? + bool read_stdin = false; // Have we read stdin? + + /* + * Check the program arguments and print usage information if -? + * or --help is passed as the first argument. + */ + if (argc > 1 && (!strcmp(argv[1],"-?") || !strcmp(argv[1],"--help"))) + { + usage(); + return 1; + } + + /* + * For each filename passed in on the command line, calculate the + * SHA-1 value and display it. + */ + for(i = 0; i < argc; i++) + { + /* + * We start the counter at 0 to guarantee entry into the for loop. + * So if 'i' is zero, we will increment it now. If there is no + * argv[1], we will use STDIN below. + */ + if (i == 0) + { + i++; + } + + if (argc == 1 || !strcmp(argv[i],"-")) + { +#ifdef WIN32 + _setmode(_fileno(stdin), _O_BINARY); +#endif + fp = stdin; + reading_stdin = true; + } + else + { + if (!(fp = fopen(argv[i],"rb"))) + { + fprintf(stderr, "sha: unable to open file %s\n", argv[i]); + return 2; + } + reading_stdin = false; + } + + /* + * We do not want to read STDIN multiple times + */ + if (reading_stdin) + { + if (read_stdin) + { + continue; + } + + read_stdin = true; + } + + /* + * Reset the SHA1 object and process input + */ + sha.Reset(); + + c = fgetc(fp); + while(!feof(fp)) + { + sha.Input(c); + c = fgetc(fp); + } + + if (!reading_stdin) + { + fclose(fp); + } + + if (!sha.Result(message_digest)) + { + fprintf(stderr,"sha: could not compute message digest for %s\n", + reading_stdin?"STDIN":argv[i]); + } + else + { + printf( "%08X %08X %08X %08X %08X - %s\n", + message_digest[0], + message_digest[1], + message_digest[2], + message_digest[3], + message_digest[4], + reading_stdin?"STDIN":argv[i]); + } + } + + return 0; +} + +/* + * usage + * + * Description: + * This function will display program usage information to the user. + * + * Parameters: + * None. + * + * Returns: + * Nothing. + * + * Comments: + * + */ +void usage() +{ + printf("usage: sha [ ...]\n"); + printf("\tThis program will display the message digest (fingerprint)\n"); + printf("\tfor files using the Secure Hashing Algorithm (SHA-1).\n"); +} diff --git a/src/sha1/sha1.cpp b/src/sha1/sha1.cpp new file mode 100755 index 0000000000..fdcbdcc0b4 --- /dev/null +++ b/src/sha1/sha1.cpp @@ -0,0 +1,589 @@ +/* + * sha1.cpp + * + * Copyright (C) 1998, 2009 + * Paul E. Jones + * All Rights Reserved. + * + ***************************************************************************** + * $Id: sha1.cpp 12 2009-06-22 19:34:25Z paulej $ + ***************************************************************************** + * + * Description: + * This class implements the Secure Hashing Standard as defined + * in FIPS PUB 180-1 published April 17, 1995. + * + * The Secure Hashing Standard, which uses the Secure Hashing + * Algorithm (SHA), produces a 160-bit message digest for a + * given data stream. In theory, it is highly improbable that + * two messages will produce the same message digest. Therefore, + * this algorithm can serve as a means of providing a "fingerprint" + * for a message. + * + * Portability Issues: + * SHA-1 is defined in terms of 32-bit "words". This code was + * written with the expectation that the processor has at least + * a 32-bit machine word size. If the machine word size is larger, + * the code should still function properly. One caveat to that + * is that the input functions taking characters and character arrays + * assume that only 8 bits of information are stored in each character. + * + * Caveats: + * SHA-1 is designed to work with messages less than 2^64 bits long. + * Although SHA-1 allows a message digest to be generated for + * messages of any number of bits less than 2^64, this implementation + * only works with messages with a length that is a multiple of 8 + * bits. + * + */ + + +#include "sha1.h" + +/* + * SHA1 + * + * Description: + * This is the constructor for the sha1 class. + * + * Parameters: + * None. + * + * Returns: + * Nothing. + * + * Comments: + * + */ +SHA1::SHA1() +{ + Reset(); +} + +/* + * ~SHA1 + * + * Description: + * This is the destructor for the sha1 class + * + * Parameters: + * None. + * + * Returns: + * Nothing. + * + * Comments: + * + */ +SHA1::~SHA1() +{ + // The destructor does nothing +} + +/* + * Reset + * + * Description: + * This function will initialize the sha1 class member variables + * in preparation for computing a new message digest. + * + * Parameters: + * None. + * + * Returns: + * Nothing. + * + * Comments: + * + */ +void SHA1::Reset() +{ + Length_Low = 0; + Length_High = 0; + Message_Block_Index = 0; + + H[0] = 0x67452301; + H[1] = 0xEFCDAB89; + H[2] = 0x98BADCFE; + H[3] = 0x10325476; + H[4] = 0xC3D2E1F0; + + Computed = false; + Corrupted = false; +} + +/* + * Result + * + * Description: + * This function will return the 160-bit message digest into the + * array provided. + * + * Parameters: + * message_digest_array: [out] + * This is an array of five unsigned integers which will be filled + * with the message digest that has been computed. + * + * Returns: + * True if successful, false if it failed. + * + * Comments: + * + */ +bool SHA1::Result(unsigned *message_digest_array) +{ + int i; // Counter + + if (Corrupted) + { + return false; + } + + if (!Computed) + { + PadMessage(); + Computed = true; + } + + for(i = 0; i < 5; i++) + { + message_digest_array[i] = H[i]; + } + + return true; +} + +/* + * Input + * + * Description: + * This function accepts an array of octets as the next portion of + * the message. + * + * Parameters: + * message_array: [in] + * An array of characters representing the next portion of the + * message. + * + * Returns: + * Nothing. + * + * Comments: + * + */ +void SHA1::Input( const unsigned char *message_array, + unsigned length) +{ + if (!length) + { + return; + } + + if (Computed || Corrupted) + { + Corrupted = true; + return; + } + + while(length-- && !Corrupted) + { + Message_Block[Message_Block_Index++] = (*message_array & 0xFF); + + Length_Low += 8; + Length_Low &= 0xFFFFFFFF; // Force it to 32 bits + if (Length_Low == 0) + { + Length_High++; + Length_High &= 0xFFFFFFFF; // Force it to 32 bits + if (Length_High == 0) + { + Corrupted = true; // Message is too long + } + } + + if (Message_Block_Index == 64) + { + ProcessMessageBlock(); + } + + message_array++; + } +} + +/* + * Input + * + * Description: + * This function accepts an array of octets as the next portion of + * the message. + * + * Parameters: + * message_array: [in] + * An array of characters representing the next portion of the + * message. + * length: [in] + * The length of the message_array + * + * Returns: + * Nothing. + * + * Comments: + * + */ +void SHA1::Input( const char *message_array, + unsigned length) +{ + Input((unsigned char *) message_array, length); +} + +/* + * Input + * + * Description: + * This function accepts a single octets as the next message element. + * + * Parameters: + * message_element: [in] + * The next octet in the message. + * + * Returns: + * Nothing. + * + * Comments: + * + */ +void SHA1::Input(unsigned char message_element) +{ + Input(&message_element, 1); +} + +/* + * Input + * + * Description: + * This function accepts a single octet as the next message element. + * + * Parameters: + * message_element: [in] + * The next octet in the message. + * + * Returns: + * Nothing. + * + * Comments: + * + */ +void SHA1::Input(char message_element) +{ + Input((unsigned char *) &message_element, 1); +} + +/* + * operator<< + * + * Description: + * This operator makes it convenient to provide character strings to + * the SHA1 object for processing. + * + * Parameters: + * message_array: [in] + * The character array to take as input. + * + * Returns: + * A reference to the SHA1 object. + * + * Comments: + * Each character is assumed to hold 8 bits of information. + * + */ +SHA1& SHA1::operator<<(const char *message_array) +{ + const char *p = message_array; + + while(*p) + { + Input(*p); + p++; + } + + return *this; +} + +/* + * operator<< + * + * Description: + * This operator makes it convenient to provide character strings to + * the SHA1 object for processing. + * + * Parameters: + * message_array: [in] + * The character array to take as input. + * + * Returns: + * A reference to the SHA1 object. + * + * Comments: + * Each character is assumed to hold 8 bits of information. + * + */ +SHA1& SHA1::operator<<(const unsigned char *message_array) +{ + const unsigned char *p = message_array; + + while(*p) + { + Input(*p); + p++; + } + + return *this; +} + +/* + * operator<< + * + * Description: + * This function provides the next octet in the message. + * + * Parameters: + * message_element: [in] + * The next octet in the message + * + * Returns: + * A reference to the SHA1 object. + * + * Comments: + * The character is assumed to hold 8 bits of information. + * + */ +SHA1& SHA1::operator<<(const char message_element) +{ + Input((unsigned char *) &message_element, 1); + + return *this; +} + +/* + * operator<< + * + * Description: + * This function provides the next octet in the message. + * + * Parameters: + * message_element: [in] + * The next octet in the message + * + * Returns: + * A reference to the SHA1 object. + * + * Comments: + * The character is assumed to hold 8 bits of information. + * + */ +SHA1& SHA1::operator<<(const unsigned char message_element) +{ + Input(&message_element, 1); + + return *this; +} + +/* + * ProcessMessageBlock + * + * Description: + * This function will process the next 512 bits of the message + * stored in the Message_Block array. + * + * Parameters: + * None. + * + * Returns: + * Nothing. + * + * Comments: + * Many of the variable names in this function, especially the single + * character names, were used because those were the names used + * in the publication. + * + */ +void SHA1::ProcessMessageBlock() +{ + const unsigned K[] = { // Constants defined for SHA-1 + 0x5A827999, + 0x6ED9EBA1, + 0x8F1BBCDC, + 0xCA62C1D6 + }; + int t; // Loop counter + unsigned temp; // Temporary word value + unsigned W[80]; // Word sequence + unsigned A, B, C, D, E; // Word buffers + + /* + * Initialize the first 16 words in the array W + */ + for(t = 0; t < 16; t++) + { + W[t] = ((unsigned) Message_Block[t * 4]) << 24; + W[t] |= ((unsigned) Message_Block[t * 4 + 1]) << 16; + W[t] |= ((unsigned) Message_Block[t * 4 + 2]) << 8; + W[t] |= ((unsigned) Message_Block[t * 4 + 3]); + } + + for(t = 16; t < 80; t++) + { + W[t] = CircularShift(1,W[t-3] ^ W[t-8] ^ W[t-14] ^ W[t-16]); + } + + A = H[0]; + B = H[1]; + C = H[2]; + D = H[3]; + E = H[4]; + + for(t = 0; t < 20; t++) + { + temp = CircularShift(5,A) + ((B & C) | ((~B) & D)) + E + W[t] + K[0]; + temp &= 0xFFFFFFFF; + E = D; + D = C; + C = CircularShift(30,B); + B = A; + A = temp; + } + + for(t = 20; t < 40; t++) + { + temp = CircularShift(5,A) + (B ^ C ^ D) + E + W[t] + K[1]; + temp &= 0xFFFFFFFF; + E = D; + D = C; + C = CircularShift(30,B); + B = A; + A = temp; + } + + for(t = 40; t < 60; t++) + { + temp = CircularShift(5,A) + + ((B & C) | (B & D) | (C & D)) + E + W[t] + K[2]; + temp &= 0xFFFFFFFF; + E = D; + D = C; + C = CircularShift(30,B); + B = A; + A = temp; + } + + for(t = 60; t < 80; t++) + { + temp = CircularShift(5,A) + (B ^ C ^ D) + E + W[t] + K[3]; + temp &= 0xFFFFFFFF; + E = D; + D = C; + C = CircularShift(30,B); + B = A; + A = temp; + } + + H[0] = (H[0] + A) & 0xFFFFFFFF; + H[1] = (H[1] + B) & 0xFFFFFFFF; + H[2] = (H[2] + C) & 0xFFFFFFFF; + H[3] = (H[3] + D) & 0xFFFFFFFF; + H[4] = (H[4] + E) & 0xFFFFFFFF; + + Message_Block_Index = 0; +} + +/* + * PadMessage + * + * Description: + * According to the standard, the message must be padded to an even + * 512 bits. The first padding bit must be a '1'. The last 64 bits + * represent the length of the original message. All bits in between + * should be 0. This function will pad the message according to those + * rules by filling the message_block array accordingly. It will also + * call ProcessMessageBlock() appropriately. When it returns, it + * can be assumed that the message digest has been computed. + * + * Parameters: + * None. + * + * Returns: + * Nothing. + * + * Comments: + * + */ +void SHA1::PadMessage() +{ + /* + * Check to see if the current message block is too small to hold + * the initial padding bits and length. If so, we will pad the + * block, process it, and then continue padding into a second block. + */ + if (Message_Block_Index > 55) + { + Message_Block[Message_Block_Index++] = 0x80; + while(Message_Block_Index < 64) + { + Message_Block[Message_Block_Index++] = 0; + } + + ProcessMessageBlock(); + + while(Message_Block_Index < 56) + { + Message_Block[Message_Block_Index++] = 0; + } + } + else + { + Message_Block[Message_Block_Index++] = 0x80; + while(Message_Block_Index < 56) + { + Message_Block[Message_Block_Index++] = 0; + } + + } + + /* + * Store the message length as the last 8 octets + */ + Message_Block[56] = (Length_High >> 24) & 0xFF; + Message_Block[57] = (Length_High >> 16) & 0xFF; + Message_Block[58] = (Length_High >> 8) & 0xFF; + Message_Block[59] = (Length_High) & 0xFF; + Message_Block[60] = (Length_Low >> 24) & 0xFF; + Message_Block[61] = (Length_Low >> 16) & 0xFF; + Message_Block[62] = (Length_Low >> 8) & 0xFF; + Message_Block[63] = (Length_Low) & 0xFF; + + ProcessMessageBlock(); +} + + +/* + * CircularShift + * + * Description: + * This member function will perform a circular shifting operation. + * + * Parameters: + * bits: [in] + * The number of bits to shift (1-31) + * word: [in] + * The value to shift (assumes a 32-bit integer) + * + * Returns: + * The shifted value. + * + * Comments: + * + */ +unsigned SHA1::CircularShift(int bits, unsigned word) +{ + return ((word << bits) & 0xFFFFFFFF) | ((word & 0xFFFFFFFF) >> (32-bits)); +} diff --git a/src/sha1/sha1.h b/src/sha1/sha1.h new file mode 100755 index 0000000000..c0efa1c944 --- /dev/null +++ b/src/sha1/sha1.h @@ -0,0 +1,89 @@ +/* + * sha1.h + * + * Copyright (C) 1998, 2009 + * Paul E. Jones + * All Rights Reserved. + * + ***************************************************************************** + * $Id: sha1.h 12 2009-06-22 19:34:25Z paulej $ + ***************************************************************************** + * + * Description: + * This class implements the Secure Hashing Standard as defined + * in FIPS PUB 180-1 published April 17, 1995. + * + * Many of the variable names in this class, especially the single + * character names, were used because those were the names used + * in the publication. + * + * Please read the file sha1.cpp for more information. + * + */ + +#ifndef _SHA1_H_ +#define _SHA1_H_ + +class SHA1 +{ + + public: + + SHA1(); + virtual ~SHA1(); + + /* + * Re-initialize the class + */ + void Reset(); + + /* + * Returns the message digest + */ + bool Result(unsigned *message_digest_array); + + /* + * Provide input to SHA1 + */ + void Input( const unsigned char *message_array, + unsigned length); + void Input( const char *message_array, + unsigned length); + void Input(unsigned char message_element); + void Input(char message_element); + SHA1& operator<<(const char *message_array); + SHA1& operator<<(const unsigned char *message_array); + SHA1& operator<<(const char message_element); + SHA1& operator<<(const unsigned char message_element); + + private: + + /* + * Process the next 512 bits of the message + */ + void ProcessMessageBlock(); + + /* + * Pads the current message block to 512 bits + */ + void PadMessage(); + + /* + * Performs a circular left shift operation + */ + inline unsigned CircularShift(int bits, unsigned word); + + unsigned H[5]; // Message digest buffers + + unsigned Length_Low; // Message length in bits + unsigned Length_High; // Message length in bits + + unsigned char Message_Block[64]; // 512-bit message blocks + int Message_Block_Index; // Index into message block array + + bool Computed; // Is the digest computed? + bool Corrupted; // Is the message digest corruped? + +}; + +#endif diff --git a/src/sha1/shacmp.cpp b/src/sha1/shacmp.cpp new file mode 100755 index 0000000000..476ad3f0bd --- /dev/null +++ b/src/sha1/shacmp.cpp @@ -0,0 +1,169 @@ +/* + * shacmp.cpp + * + * Copyright (C) 1998, 2009 + * Paul E. Jones + * All Rights Reserved + * + ***************************************************************************** + * $Id: shacmp.cpp 12 2009-06-22 19:34:25Z paulej $ + ***************************************************************************** + * + * Description: + * This utility will compare two files by producing a message digest + * for each file using the Secure Hashing Algorithm and comparing + * the message digests. This function will return 0 if they + * compare or 1 if they do not or if there is an error. + * Errors result in a return code higher than 1. + * + * Portability Issues: + * none. + * + */ + +#include +#include +#include "sha1.h" + +/* + * Return codes + */ +#define SHA1_COMPARE 0 +#define SHA1_NO_COMPARE 1 +#define SHA1_USAGE_ERROR 2 +#define SHA1_FILE_ERROR 3 + +/* + * Function prototype + */ +void usage(); + +/* + * main + * + * Description: + * This is the entry point for the program + * + * Parameters: + * argc: [in] + * This is the count of arguments in the argv array + * argv: [in] + * This is an array of filenames for which to compute message digests + * + * Returns: + * Nothing. + * + * Comments: + * + */ +int main(int argc, char *argv[]) +{ + SHA1 sha; // SHA-1 class + FILE *fp; // File pointer for reading files + char c; // Character read from file + unsigned message_digest[2][5]; // Message digest for files + int i; // Counter + bool message_match; // Message digest match flag + int returncode; + + /* + * If we have two arguments, we will assume they are filenames. If + * we do not have to arguments, call usage() and exit. + */ + if (argc != 3) + { + usage(); + return SHA1_USAGE_ERROR; + } + + /* + * Get the message digests for each file + */ + for(i = 1; i <= 2; i++) + { + sha.Reset(); + + if (!(fp = fopen(argv[i],"rb"))) + { + fprintf(stderr, "sha: unable to open file %s\n", argv[i]); + return SHA1_FILE_ERROR; + } + + c = fgetc(fp); + while(!feof(fp)) + { + sha.Input(c); + c = fgetc(fp); + } + + fclose(fp); + + if (!sha.Result(message_digest[i-1])) + { + fprintf(stderr,"shacmp: could not compute message digest for %s\n", + argv[i]); + return SHA1_FILE_ERROR; + } + } + + /* + * Compare the message digest values + */ + message_match = true; + for(i = 0; i < 5; i++) + { + if (message_digest[0][i] != message_digest[1][i]) + { + message_match = false; + break; + } + } + + if (message_match) + { + printf("Fingerprints match:\n"); + returncode = SHA1_COMPARE; + } + else + { + printf("Fingerprints do not match:\n"); + returncode = SHA1_NO_COMPARE; + } + + printf( "\t%08X %08X %08X %08X %08X\n", + message_digest[0][0], + message_digest[0][1], + message_digest[0][2], + message_digest[0][3], + message_digest[0][4]); + printf( "\t%08X %08X %08X %08X %08X\n", + message_digest[1][0], + message_digest[1][1], + message_digest[1][2], + message_digest[1][3], + message_digest[1][4]); + + return returncode; +} + +/* + * usage + * + * Description: + * This function will display program usage information to the user. + * + * Parameters: + * None. + * + * Returns: + * Nothing. + * + * Comments: + * + */ +void usage() +{ + printf("usage: shacmp \n"); + printf("\tThis program will compare the message digests (fingerprints)\n"); + printf("\tfor two files using the Secure Hashing Algorithm (SHA-1).\n"); +} diff --git a/src/sha1/shatest.cpp b/src/sha1/shatest.cpp new file mode 100755 index 0000000000..4d29ef9e24 --- /dev/null +++ b/src/sha1/shatest.cpp @@ -0,0 +1,149 @@ +/* + * shatest.cpp + * + * Copyright (C) 1998, 2009 + * Paul E. Jones + * All Rights Reserved + * + ***************************************************************************** + * $Id: shatest.cpp 12 2009-06-22 19:34:25Z paulej $ + ***************************************************************************** + * + * Description: + * This file will exercise the SHA1 class and perform the three + * tests documented in FIPS PUB 180-1. + * + * Portability Issues: + * None. + * + */ + +#include +#include "sha1.h" + +using namespace std; + +/* + * Define patterns for testing + */ +#define TESTA "abc" +#define TESTB "abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq" + +/* + * Function prototype + */ +void DisplayMessageDigest(unsigned *message_digest); + +/* + * main + * + * Description: + * This is the entry point for the program + * + * Parameters: + * None. + * + * Returns: + * Nothing. + * + * Comments: + * + */ +int main() +{ + SHA1 sha; + unsigned message_digest[5]; + + /* + * Perform test A + */ + cout << endl << "Test A: 'abc'" << endl; + + sha.Reset(); + sha << TESTA; + + if (!sha.Result(message_digest)) + { + cerr << "ERROR-- could not compute message digest" << endl; + } + else + { + DisplayMessageDigest(message_digest); + cout << "Should match:" << endl; + cout << '\t' << "A9993E36 4706816A BA3E2571 7850C26C 9CD0D89D" << endl; + } + + /* + * Perform test B + */ + cout << endl << "Test B: " << TESTB << endl; + + sha.Reset(); + sha << TESTB; + + if (!sha.Result(message_digest)) + { + cerr << "ERROR-- could not compute message digest" << endl; + } + else + { + DisplayMessageDigest(message_digest); + cout << "Should match:" << endl; + cout << '\t' << "84983E44 1C3BD26E BAAE4AA1 F95129E5 E54670F1" << endl; + } + + /* + * Perform test C + */ + cout << endl << "Test C: One million 'a' characters" << endl; + + sha.Reset(); + for(int i = 1; i <= 1000000; i++) sha.Input('a'); + + if (!sha.Result(message_digest)) + { + cerr << "ERROR-- could not compute message digest" << endl; + } + else + { + DisplayMessageDigest(message_digest); + cout << "Should match:" << endl; + cout << '\t' << "34AA973C D4C4DAA4 F61EEB2B DBAD2731 6534016F" << endl; + } + + return 0; +} + +/* + * DisplayMessageDigest + * + * Description: + * Display Message Digest array + * + * Parameters: + * None. + * + * Returns: + * Nothing. + * + * Comments: + * + */ +void DisplayMessageDigest(unsigned *message_digest) +{ + ios::fmtflags flags; + + cout << '\t'; + + flags = cout.setf(ios::hex|ios::uppercase,ios::basefield); + cout.setf(ios::uppercase); + + for(int i = 0; i < 5 ; i++) + { + cout << message_digest[i] << ' '; + } + + cout << endl; + + cout.setf(flags); +} diff --git a/src/websocket_connection_handler.hpp b/src/websocket_connection_handler.hpp new file mode 100644 index 0000000000..33ec3b30ec --- /dev/null +++ b/src/websocket_connection_handler.hpp @@ -0,0 +1,61 @@ +#ifndef WEBSOCKET_CONNECTION_HANDLER_HPP +#define WEBSOCKET_CONNECTION_HANDLER_HPP + +#include + +#include +#include + +namespace websocketpp { + class connection_handler; + typedef boost::shared_ptr connection_handler_ptr; +} + +#include "websocket_session.hpp" + +namespace websocketpp { + +class connection_handler { +public: + // validate will be called after a websocket handshake has been received and + // before it is accepted. It provides a handler the ability to refuse a + // connection based on application specific logic (ex: restrict domains or + // negotiate subprotocols) + // + // resource is the resource requested by the connection + // headers is a map containing the HTTP headers included in the client + // handshake as key/value strings + // + // if validate returns true the connection is allowed, otherwise the + // connection is closed. + virtual bool validate(session_ptr client) = 0; + + // this will be called once the connected websocket is avaliable for + // writing messages. client may be a new websocket session or an existing + // session that was recently passed to this handler. + virtual void connect(session_ptr client) = 0; + + // this will be called when the connected websocket is no longer avaliable + // for writing messages. This occurs under the following conditions: + // - Disconnect message recieved from the remote endpoint + // - Someone (usually this object) calls the disconnect method of session + // - A disconnect acknowledgement is recieved (in case another object + // calls the disconnect method of session + // - The connection handler assigned to this client was set to another + // handler + virtual void disconnect(session_ptr client,const std::string &reason) = 0; + + // this will be called when a text message is recieved. Text will be + // encoded as UTF-8. + virtual void message(session_ptr client,const std::string &msg) = 0; + + // this will be called when a binary message is recieved. Argument is a + // vector of the raw bytes in the message body. + virtual void message(session_ptr client, + const std::vector &data) = 0; +}; + + + +} +#endif // WEBSOCKET_CONNECTION_HANDLER_HPP \ No newline at end of file diff --git a/src/websocket_frame.cpp b/src/websocket_frame.cpp new file mode 100644 index 0000000000..b71b67d39f --- /dev/null +++ b/src/websocket_frame.cpp @@ -0,0 +1,340 @@ +#include "websocket_frame.hpp" + +#include +#include + +#include + +using websocketpp::frame; + +char* frame::get_header() { + return m_header; +} + +char* frame::get_extended_header() { + return m_header+BASIC_HEADER_LENGTH; +} + +unsigned int frame::get_header_len() const { + unsigned int temp = 2; + + if (get_masked()) { + temp += 4; + } + + if (get_basic_size() == 126) { + temp += 2; + } else if (get_basic_size() == 127) { + temp += 8; + } + + return temp; +} + +char* frame::get_masking_key() { + if (m_extended_header_bytes_needed > 0) { + throw "attempted to get masking_key before reading full header"; + } + return m_masking_key; +} + +// get and set header bits +bool frame::get_fin() const { + return ((m_header[0] & BPB0_FIN) == BPB0_FIN); +} + +void frame::set_fin(bool fin) { + if (fin) { + m_header[0] |= BPB0_FIN; + } else { + m_header[0] &= (0xFF ^ BPB0_FIN); + } +} + +// get and set reserved bits +bool frame::get_rsv1() const { + return ((m_header[0] & BPB0_RSV1) == BPB0_RSV1); +} + +void frame::set_rsv1(bool b) { + if (b) { + m_header[0] |= BPB0_RSV1; + } else { + m_header[0] &= (0xFF ^ BPB0_RSV1); + } +} + +bool frame::get_rsv2() const { + return ((m_header[0] & BPB0_RSV2) == BPB0_RSV2); +} + +void frame::set_rsv2(bool b) { + if (b) { + m_header[0] |= BPB0_RSV2; + } else { + m_header[0] &= (0xFF ^ BPB0_RSV2); + } +} + +bool frame::get_rsv3() const { + return ((m_header[0] & BPB0_RSV3) == BPB0_RSV3); +} + +void frame::set_rsv3(bool b) { + if (b) { + m_header[0] |= BPB0_RSV3; + } else { + m_header[0] &= (0xFF ^ BPB0_RSV3); + } +} + +frame::opcode frame::get_opcode() const { + return frame::opcode(m_header[0] & BPB0_OPCODE); +} + +void frame::set_opcode(frame::opcode op) { + if (op > 0x0F) { + throw "invalid opcode"; + } + + if (get_basic_size() > BASIC_PAYLOAD_LIMIT && + is_control()) { + throw "control frames can't have large payloads"; + } + + m_header[0] &= (0xFF ^ BPB0_OPCODE); // clear op bits + m_header[0] |= op; // set op bits +} + +bool frame::get_masked() const { + return ((m_header[1] & BPB1_MASK) == BPB1_MASK); +} + +void frame::set_masked(bool masked) { + if (masked) { + m_header[1] |= BPB1_MASK; + generate_masking_key(); + } else { + m_header[1] &= (0xFF ^ BPB1_MASK); + clear_masking_key(); + } +} + +uint8_t frame::get_basic_size() const { + return m_header[1] & BPB1_PAYLOAD; +} + +size_t frame::get_payload_size() const { + if (m_extended_header_bytes_needed > 0) { + // problem + throw "attempted to get payload size before reading full header"; + } + + return m_payload.size(); +} + +std::vector &frame::get_payload() { + return m_payload; +} + +void frame::set_payload(const std::vector source) { + set_payload_helper(source.size()); + + std::copy(source.begin(),source.end(),m_payload.begin()); +} +void frame::set_payload(const std::string source) { + set_payload_helper(source.size()); + + std::copy(source.begin(),source.end(),m_payload.begin()); +} + +bool frame::is_control() const { + return (get_opcode() > MAX_FRAME_OPCODE); +} + +void frame::set_payload_helper(size_t s) { + if (s > max_payload_size) { + throw "requested payload is over implimentation defined limit"; + } + + // limits imposed by the websocket spec + if (s > BASIC_PAYLOAD_LIMIT && + get_opcode() > MAX_FRAME_OPCODE) { + throw "control frames can't have large payloads"; + } + + if (s <= BASIC_PAYLOAD_LIMIT) { + m_header[1] = s; + } else if (s <= PAYLOAD_16BIT_LIMIT) { + m_header[1] = BASIC_PAYLOAD_16BIT_CODE; + + // this reinterprets the second pair of bytes in m_header as a + // 16 bit int and writes the payload size there as an integer + // in network byte order + *reinterpret_cast(&m_header[BASIC_HEADER_LENGTH]) = htons(s); + } else if (s <= PAYLOAD_64BIT_LIMIT) { + m_header[1] = BASIC_PAYLOAD_64BIT_CODE; + *reinterpret_cast(&m_header[BASIC_HEADER_LENGTH]) = htonll(s); + } else { + throw "payload size limit is 63 bits"; + } + + m_payload.resize(s); +} + +void frame::print_frame() const { + /*unsigned int len = get_header_len(); + + std::cout << "frame: "; + // print header + for (unsigned int i = 0; i < len; i++) { + std::cout << std::hex << (unsigned short)m_header[i] << " "; + } + // print message + for (auto &i: m_payload) { + std::cout << i; + } + + std::cout << std::endl;*/ +} + +unsigned int frame::process_basic_header() { + m_extended_header_bytes_needed = 0; + m_payload.empty(); + + m_extended_header_bytes_needed = get_header_len() - BASIC_HEADER_LENGTH; + + return m_extended_header_bytes_needed; +} + +void frame::process_extended_header() { + m_extended_header_bytes_needed = 0; + + uint8_t s = get_basic_size(); + uint64_t payload_size; + int mask_index = BASIC_HEADER_LENGTH; + + if (s <= BASIC_PAYLOAD_LIMIT) { + payload_size = s; + } else if (s == BASIC_PAYLOAD_16BIT_CODE) { + // reinterpret the second two bytes as a 16 bit integer in network + // byte order. Convert to host byte order and store locally. + payload_size = ntohs(*( + reinterpret_cast(&m_header[BASIC_HEADER_LENGTH]) + )); + + mask_index += 2; + } else if (s == BASIC_PAYLOAD_64BIT_CODE) { + // reinterpret the second eight bytes as a 16 bit integer in + // network byte order. Convert to host byte order and store. + payload_size = ntohll(*( + reinterpret_cast(&m_header[BASIC_HEADER_LENGTH]) + )); + + mask_index += 8; + } else { + // shouldn't be here + throw "invalid get_basic_size in process_extended_header"; + } + + if (payload_size < s) { + throw "payload size error"; + } + + if (get_masked() == 0) { + clear_masking_key(); + } else { + // TODO: use this copy line (needs testing) + // std::copy(m_header[mask_index],m_header[mask_index+4],m_masking_key); + m_masking_key[0] = m_header[mask_index+0]; + m_masking_key[1] = m_header[mask_index+1]; + m_masking_key[2] = m_header[mask_index+2]; + m_masking_key[3] = m_header[mask_index+3]; + } + + if (payload_size > max_payload_size) { + throw "got frame with payload greater than maximum frame buffer size."; + } + m_payload.resize(payload_size); +} + +void frame::process_payload() { + // unmask payload one byte at a time + for (uint64_t i = 0; i < m_payload.size(); i++) { + m_payload[i] = (m_payload[i] ^ m_masking_key[i%4]); + } +} + +void frame::process_payload2() { + // unmask payload one byte at a time + + //uint64_t key = (*((uint32_t*)m_masking_key;)) << 32; + //key += *((uint32_t*)m_masking_key); + + // might need to switch byte order + uint32_t key = *((uint32_t*)m_masking_key); + + // 4 + + uint64_t i = 0; + uint64_t s = (m_payload.size() / 4); + + std::cout << "s: " << s << std::endl; + + // chunks of 4 + for (i = 0; i < s; i+=4) { + ((uint32_t*)(&m_payload[0]))[i] = (((uint32_t*)(&m_payload[0]))[i] ^ key); + } + + // finish the last few + for (i = s; i < m_payload.size(); i++) { + m_payload[i] = (m_payload[i] ^ m_masking_key[i%4]); + } +} + +bool frame::validate_basic_header() const { + // check for control frame size + if (get_basic_size() > BASIC_PAYLOAD_LIMIT && is_control()) { + return false; + } + + // check for reserved opcodes + if (get_rsv1() || get_rsv2() || get_rsv3()) { + return false; + } + + // check for reserved opcodes + opcode op = get_opcode(); + if (op > 0x02 && op < 0x08) { + return false; + } + if (op > 0x0A) { + return false; + } + + // check for fragmented control message + if (is_control() && !get_fin()) { + return false; + } + + return true; +} + +void frame::generate_masking_key() { + throw "masking key generation not implimented"; + /* TODO: test and tune + + boost::random::random_device rng; + boost::random::uniform_int_distribution<> mask(0,UINT32_MAX); + + *(reinterpret_cast(m_masking_key)) = mask(rng); + + */ +} + +void frame::clear_masking_key() { + m_masking_key[0] = 0; + m_masking_key[1] = 0; + m_masking_key[2] = 0; + m_masking_key[3] = 0; +} \ No newline at end of file diff --git a/src/websocket_frame.hpp b/src/websocket_frame.hpp new file mode 100644 index 0000000000..a010117615 --- /dev/null +++ b/src/websocket_frame.hpp @@ -0,0 +1,113 @@ +#ifndef WEBSOCKET_FRAME_HPP +#define WEBSOCKET_FRAME_HPP + +#include "network_utilities.hpp" + +#include +#include +#include + +namespace websocketpp { + +class frame { +public: + enum opcode_s { + CONTINUATION_FRAME = 0x00, + TEXT_FRAME = 0x01, + BINARY_FRAME = 0x02, + CONNECTION_CLOSE = 0x08, + PING = 0x09, + PONG = 0x0A + }; + + typedef enum opcode_s opcode; + + static const uint8_t MAX_FRAME_OPCODE = 0x07; + + // basic payload byte flags + static const uint8_t BPB0_OPCODE = 0x0F; + static const uint8_t BPB0_RSV3 = 0x10; + static const uint8_t BPB0_RSV2 = 0x20; + static const uint8_t BPB0_RSV1 = 0x40; + static const uint8_t BPB0_FIN = 0x80; + static const uint8_t BPB1_PAYLOAD = 0x7F; + static const uint8_t BPB1_MASK = 0x80; + + static const uint8_t BASIC_PAYLOAD_LIMIT = 0x7D; // 125 + static const uint8_t BASIC_PAYLOAD_16BIT_CODE = 0x7E; // 126 + static const uint16_t PAYLOAD_16BIT_LIMIT = 0xFFFF; // 2^16, 65535 + static const uint8_t BASIC_PAYLOAD_64BIT_CODE = 0x7F; // 127 + static const uint64_t PAYLOAD_64BIT_LIMIT = 0x7FFFFFFFFFFFFFFF; // 2^63 + + static const unsigned int BASIC_HEADER_LENGTH = 2; + static const unsigned int MAX_HEADER_LENGTH = 14; + static const uint8_t extended_header_length = 12; + static const uint64_t max_payload_size = 100000000; // 100MB + + // create an empty frame for writing into + frame() { + // not sure if these are necessary with c++ but putting in just in case + memset(m_header,0,MAX_HEADER_LENGTH); + } + + // get pointers to underlying buffers + char* get_header(); + char* get_extended_header(); + unsigned int get_header_len() const; + + char* get_masking_key(); + + // get and set header bits + bool get_fin() const; + void set_fin(bool fin); + + bool get_rsv1() const; + void set_rsv1(bool b); + + bool get_rsv2() const; + void set_rsv2(bool b); + + bool get_rsv3() const; + void set_rsv3(bool b); + + opcode get_opcode() const; + void set_opcode(opcode op); + + bool get_masked() const; + void set_masked(bool masked); + + uint8_t get_basic_size() const; + size_t get_payload_size() const; + + std::vector &get_payload(); + + void set_payload(const std::vector source); + void set_payload(const std::string source); + void set_payload_helper(size_t s); + + bool is_control() const; + + void print_frame() const; + + // reads basic header, sets and returns m_header_bits_needed + unsigned int process_basic_header(); + void process_extended_header(); + void process_payload(); + void process_payload2(); // experiment with more efficient masking code. + + bool validate_basic_header() const; + + void generate_masking_key(); + void clear_masking_key(); + +private: + char m_header[MAX_HEADER_LENGTH]; + std::vector m_payload; + + char m_masking_key[4]; + unsigned int m_extended_header_bytes_needed; +}; + +} + +#endif // WEBSOCKET_FRAME_HPP \ No newline at end of file diff --git a/src/websocket_server.cpp b/src/websocket_server.cpp new file mode 100644 index 0000000000..5f6665ad8b --- /dev/null +++ b/src/websocket_server.cpp @@ -0,0 +1,44 @@ +#include "websocket_server.hpp" + +#include + +#include + +using websocketpp::server; + +server::server(boost::asio::io_service& io_service, + const tcp::endpoint& endpoint, + const std::string& host, + connection_handler_ptr defc) + : m_host(host), + m_io_service(io_service), + m_acceptor(io_service, endpoint), + m_def_con_handler(defc) { + this->start_accept(); +} + +void server::start_accept() { + session_ptr new_ws(new session(m_io_service,m_host,m_def_con_handler)); + + m_acceptor.async_accept( + new_ws->socket(), + boost::bind( + &server::handle_accept, + this, + new_ws, + boost::asio::placeholders::error + ) + ); +} + +void server::handle_accept(session_ptr session, + const boost::system::error_code& error) { + + if (!error) { + session->start(); + } else { + std::cout << "Error" << std::endl; + } + + this->start_accept(); +} \ No newline at end of file diff --git a/src/websocket_server.hpp b/src/websocket_server.hpp new file mode 100644 index 0000000000..6818ee1570 --- /dev/null +++ b/src/websocket_server.hpp @@ -0,0 +1,40 @@ +#ifndef WEBSOCKET_SERVER_HPP +#define WEBSOCKET_SERVER_HPP + +#include "websocket_session.hpp" + +#include +#include + +using boost::asio::ip::tcp; + +namespace websocketpp { + +class server { + public: + server(boost::asio::io_service& io_service, + const tcp::endpoint& endpoint, + const std::string& host, + connection_handler_ptr defc); + + // creates a new session object and connects the next websocket + // connection to it. + void start_accept(); + + // if no errors starts the session's read loop and returns to the + // start_accept phase. + void handle_accept(session_ptr session, + const boost::system::error_code& error); + + private: + std::string m_host; + boost::asio::io_service& m_io_service; + tcp::acceptor m_acceptor; + connection_handler_ptr m_def_con_handler; +}; + +typedef boost::shared_ptr server_ptr; + +} + +#endif // WEBSOCKET_SERVER_HPP \ No newline at end of file diff --git a/src/websocket_session.cpp b/src/websocket_session.cpp new file mode 100644 index 0000000000..8d68e65898 --- /dev/null +++ b/src/websocket_session.cpp @@ -0,0 +1,666 @@ +#include "websocket_session.hpp" + +#include "websocket_frame.hpp" + +#include +#include +#include + +#include +#include +#include + +using websocketpp::session; + +void session::start() { + //std::cout << "[Connection " << this << "] WebSocket Connection request from " << m_socket.remote_endpoint() << std::endl; + + // async read to handle_read_handshake + boost::asio::async_read_until( + m_socket, + m_buf, + "\r\n\r\n", + boost::bind( + &session::handle_read_handshake, + shared_from_this(), + boost::asio::placeholders::error, + boost::asio::placeholders::bytes_transferred + ) + ); +} + +void session::set_handler(connection_handler_ptr new_con) { + if (m_local_interface) { + m_local_interface->disconnect(shared_from_this(),"Setting new connection handler"); + } + m_local_interface = new_con; + m_local_interface->connect(shared_from_this()); +} + + +std::string session::get_header(const std::string& key) const { + std::map::const_iterator h = m_headers.find(key); + + if (h == m_headers.end()) { + return std::string(); + } else { + return h->second; + } +} + +std::string session::get_request() const { + return m_request; +} + +void session::set_http_error(int code, std::string msg) { + m_http_error_code = code; + m_http_error_string = (msg != "" ? msg : lookup_http_error_string(code)); +} + +std::string session::lookup_http_error_string(int code) { + switch (code) { + case 400: + return "Bad Request"; + case 401: + return "Unauthorized"; + case 403: + return "Forbidden"; + case 404: + return "Not Found"; + case 405: + return "Method Not Allowed"; + case 406: + return "Not Acceptable"; + case 407: + return "Proxy Authentication Required"; + case 408: + return "Request Timeout"; + case 409: + return "Conflict"; + case 410: + return "Gone"; + case 411: + return "Length Required"; + case 412: + return "Precondition Failed"; + case 413: + return "Request Entity Too Large"; + case 414: + return "Request-URI Too Long"; + case 415: + return "Unsupported Media Type"; + case 416: + return "Requested Range Not Satisfiable"; + case 417: + return "Expectation Failed"; + case 500: + return "Internal Server Error"; + case 501: + return "Not Implimented"; + case 502: + return "Bad Gateway"; + case 503: + return "Service Unavailable"; + case 504: + return "Gateway Timeout"; + case 505: + return "HTTP Version Not Supported"; + default: + return "Unknown"; + } +} + +void session::send(const std::string &msg) { + m_write_frame.set_fin(true); + m_write_frame.set_opcode(frame::TEXT_FRAME); + m_write_frame.set_payload(msg); + + write_frame(); +} + +// send binary frame +void session::send(const std::vector &data) { + m_write_frame.set_fin(true); + m_write_frame.set_opcode(frame::BINARY_FRAME); + m_write_frame.set_payload(data); + + write_frame(); +} + +// send close frame +void session::disconnect(const std::string &reason) { + m_write_frame.set_fin(true); + m_write_frame.set_opcode(frame::CONNECTION_CLOSE); + m_write_frame.set_payload(reason); + + write_frame(); + + if (m_local_interface) { + m_local_interface->disconnect(shared_from_this(),reason); + } +} + +void session::ping(const std::string &msg) { + m_write_frame.set_fin(true); + m_write_frame.set_opcode(frame::PING); + m_write_frame.set_payload(msg); + + write_frame(); +} + +void session::pong(const std::string &msg) { + m_write_frame.set_fin(true); + m_write_frame.set_opcode(frame::PONG); + m_write_frame.set_payload(msg); + + write_frame(); +} + +void session::handle_read_handshake(const boost::system::error_code& e, + std::size_t bytes_transferred) { + // read handshake and set local state (or pass to write_handshake) + std::ostringstream line; + line << &m_buf; + m_handshake += line.str(); + + //std::cout << "=== Raw Message ===" << std::endl; + //std::cout << m_handshake << std::endl; + //std::cout << "=== Raw Message end ===" << std::endl; + + std::vector tokens; + std::string::size_type start = 0; + std::string::size_type end; + + // Get request and parse headers + end = m_handshake.find("\r\n",start); + + while(end != std::string::npos) { + tokens.push_back(m_handshake.substr(start, end - start)); + + start = end + 2; + + end = m_handshake.find("\r\n",start); + } + + for (size_t i = 0; i < tokens.size(); i++) { + if (i == 0) { + m_request = tokens[i]; + } + + end = tokens[i].find(": ",0); + + if (end != std::string::npos) { + m_headers[tokens[i].substr(0,end)] = tokens[i].substr(end+2); + } + } + + // error checking + + // check the method + if (m_request.substr(0,4) != "GET ") { + // invalid method + std::cout << "Websocket handshake has invalid method: " << m_request.substr(0,4) << ", killing connection." << std::endl; + // TODO: exception + this->set_http_error(400); + this->write_http_error(); + return; + } + + // check the HTTP version + // TODO: allow versions greater than 1.1 + end = m_request.find(" HTTP/1.1",4); + if (end == std::string::npos) { + std::cout << "Websocket handshake has invalid HTTP version, killing connection." << std::endl; + this->set_http_error(400); // or error 505 HTTP Version Not Supported + this->write_http_error(); + return; + } + + m_request = m_request.substr(4,end-4); + + // TODO: use exceptions or a helper function or something better here + + // verify the presence of required headers + if (get_header("Host") != m_host) { + std::cerr << "Invalid or missing Host header." << std::endl; + this->set_http_error(400); + } + if (!boost::iequals(get_header("Upgrade"),"websocket")) { + std::cerr << "Invalid or missing Upgrade header." << std::endl; + this->set_http_error(400); + } + if (get_header("Connection").find("Upgrade") == std::string::npos) { + // TODO: case insensitive? + std::cerr << "Invalid or missing Connection header." << std::endl; + this->set_http_error(400); + } + if (get_header("Sec-WebSocket-Key") == "") { + std::cerr << "Invalid or missing Sec-Websocket-Key header." << std::endl; + this->set_http_error(400); + } + + if (get_header("Sec-WebSocket-Version") != "8" && + get_header("Sec-WebSocket-Version") != "7") { + std::cerr << "Invalid or missing Sec-Websocket-Version header." << std::endl; + this->set_http_error(400); + } + + if (m_http_error_code != 0) { + this->write_http_error(); + return; + } + + // optional headers (delegated to the local interface) + if (m_local_interface && !m_local_interface->validate(shared_from_this())) { + std::cerr << "Local interface rejected the connection." << std::endl; + if (m_http_error_code == 0) { + this->set_http_error(400); + } + } + + if (m_http_error_code != 0) { + this->write_http_error(); + return; + } + + this->write_handshake(); +} + +void session::write_handshake() { + std::string server_handshake = ""; + std::string server_key = m_headers["Sec-WebSocket-Key"]; + server_key += "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + + SHA1 sha; + uint32_t message_digest[5]; + + sha.Reset(); + sha << server_key.c_str(); + + if (!sha.Result(message_digest)) { + std::cerr << "Error computing sha1 hash, killing connection." << std::endl; + + return; + } + + // convert sha1 hash bytes to network byte order because this sha1 + // library works on ints rather than bytes + for (int i = 0; i < 5; i++) { + message_digest[i] = htonl(message_digest[i]); + } + + server_key = base64_encode( + reinterpret_cast(message_digest),20); + + server_handshake += "HTTP/1.1 101 Switching Protocols\r\n"; + server_handshake += "Upgrade: websocket\r\n"; + server_handshake += "Connection: Upgrade\r\n"; + server_handshake += "Sec-WebSocket-Accept: "+server_key+"\r\n\r\n"; + + // TODO: handler requested headers + + //std::cout << server_handshake << std::endl; + + // start async write to handle_write_handshake + boost::asio::async_write( + m_socket, + boost::asio::buffer(server_handshake), + boost::bind( + &session::handle_write_handshake, + shared_from_this(), + boost::asio::placeholders::error + ) + ); +} + +void session::handle_write_handshake(const boost::system::error_code& error) { + if (error) { + handle_error("Error writing handshake",error); + return; + } + + //std::cout << "WebSocket Version 8 connection opened." << std::endl; + + m_status = OPEN; + + if (m_local_interface) { + m_local_interface->connect(shared_from_this()); + } + + reset_message(); + this->read_frame(); +} + +void session::write_http_error() { + std::stringstream server_handshake; + + server_handshake << "HTTP/1.1 " << m_http_error_code << " " + << m_http_error_string << "\r\n"; + + // additional headers? + + server_handshake << "\r\n"; + + // start async write to handle_write_handshake + boost::asio::async_write( + m_socket, + boost::asio::buffer(server_handshake.str()), + boost::bind( + &session::handle_write_http_error, + shared_from_this(), + boost::asio::placeholders::error + ) + ); +} + +void session::handle_write_http_error(const boost::system::error_code& error) { + if (error) { + handle_error("Error writing http response",error); + return; + } +} + +void session::read_frame() { + boost::asio::async_read( + m_socket, + boost::asio::buffer(m_read_frame.get_header(), + frame::BASIC_HEADER_LENGTH), + boost::bind( + &session::handle_frame_header, + shared_from_this(), + boost::asio::placeholders::error + ) + ); +} + + + +void session::handle_frame_header(const boost::system::error_code& error) { + if (error) { + handle_error("Error reading basic frame header",error); + return; + } + + uint16_t extended_header_bytes = m_read_frame.process_basic_header(); + + if (!m_read_frame.validate_basic_header()) { + handle_error("Basic header validation failed",boost::system::error_code()); + return; + } + + if (extended_header_bytes == 0) { + read_payload(); + } else { + boost::asio::async_read( + m_socket, + boost::asio::buffer(m_read_frame.get_extended_header(), + extended_header_bytes), + boost::bind( + &session::handle_extended_frame_header, + shared_from_this(), + boost::asio::placeholders::error + ) + ); + } +} + +void session::handle_extended_frame_header( + const boost::system::error_code& error) { + if (error) { + handle_error("Error reading extended frame header",error); + return; + } + + // this sets up the buffer we are about to read into. + m_read_frame.process_extended_header(); + + this->read_payload(); +} + +void session::read_payload() { + /*char * foo = m_read_frame.get_header(); + + std::cout << std::hex << ((uint16_t*)foo)[0] << std::endl; + + std::cout << "opcode: " << m_read_frame.get_opcode() << std::endl; + std::cout << "fin: " << m_read_frame.get_fin() << std::endl; + std::cout << "mask: " << m_read_frame.get_masked() << std::endl; + std::cout << "size: " << (uint16_t)m_read_frame.get_basic_size() << std::endl; + std::cout << "payload_size: " << m_read_frame.get_payload_size() << std::endl;*/ + + boost::asio::async_read( + m_socket, + boost::asio::buffer(m_read_frame.get_payload()), + boost::bind( + &session::handle_read_payload, + shared_from_this(), + boost::asio::placeholders::error + ) + ); +} + +void session::handle_read_payload (const boost::system::error_code& error) { + if (error) { + handle_error("Error reading payload data frame header",error); + return; + } + + m_read_frame.process_payload(); + + switch (m_read_frame.get_opcode()) { + case frame::CONTINUATION_FRAME: + process_continuation(); + break; + case frame::TEXT_FRAME: + process_text(); + break; + case frame::BINARY_FRAME: + process_binary(); + break; + case frame::CONNECTION_CLOSE: + process_close(); + break; + case frame::PING: + process_ping(); + break; + case frame::PONG: + process_pong(); + break; + default: + // TODO: unknown frame type + // ignore or close? + break; + } + + // check if there was an error processing this frame and fail the connection + if (m_error) { + return; + } + + this->read_frame(); +} + +void session::handle_write_frame (const boost::system::error_code& error) { + if (error) { + handle_error("Error writing frame data",error); + } + + //std::cout << "Successfully wrote frame." << std::endl; +} + +void session::process_ping() { + std::cout << "Got ping" << std::endl; + + // send pong + m_write_frame.set_fin(true); + m_write_frame.set_opcode(frame::PONG); + m_write_frame.set_payload(m_read_frame.get_payload()); + + write_frame(); +} + +void session::process_pong() { + std::cout << "Got pong" << std::endl; +} + +void session::process_text() { + // text is binary. + process_binary(); +} + +void session::process_binary() { + if (m_fragmented) { + handle_error("Got a new message before the previous was finished.", + boost::system::error_code()); + return; + } + + m_current_opcode = m_read_frame.get_opcode(); + + if (m_read_frame.get_fin()) { + deliver_message(); + reset_message(); + } else { + m_fragmented = true; + extract_payload(); + } +} + +void session::process_continuation() { + if (!m_fragmented) { + handle_error("Got a continuation frame without an outstanding message.", + boost::system::error_code()); + return; + } + + extract_payload(); + + // check if we are done + if (m_read_frame.get_fin()) { + deliver_message(); + reset_message(); + } +} + +void session::process_close() { + if (m_status == OPEN) { + // send response and set to closed + std::string msg(m_read_frame.get_payload().begin(), + m_read_frame.get_payload().end()); + + std::cout << "Got connection close message, acking and closing the connection. Reason was: " << msg << std::endl; + + m_status = CLOSED; + + // send acknowledgement + m_write_frame.set_fin(true); + m_write_frame.set_opcode(frame::CONNECTION_CLOSE); + m_write_frame.set_payload(""); + + write_frame(); + + // let our local interface know that the remote client has + // disconnected + if (m_local_interface) { + m_local_interface->disconnect(shared_from_this(),msg); + } + } else if (m_status == CLOSING) { + // this is an ack of my close message + // close cleanly + std::cout << "Got ack for my close message, closing the connection" << std::endl; + m_status = CLOSED; + + // let our local interface know that the remote client has + // disconnected and the reason (if any) + if (m_local_interface) { + m_local_interface->disconnect(shared_from_this(),""); + } + } else { + // ignore + } +} + +void session::deliver_message() { + if (!m_local_interface) { + return; + } + + if (m_current_opcode == frame::BINARY_FRAME) { + if (m_fragmented) { + m_local_interface->message(shared_from_this(),m_current_message); + } else { + m_local_interface->message(shared_from_this(), + m_read_frame.get_payload()); + } + } else if (m_current_opcode == frame::TEXT_FRAME) { + std::string msg; + + if (m_fragmented) { + msg.append(m_current_message.begin(),m_current_message.end()); + } else { + msg.append( + m_read_frame.get_payload().begin(), + m_read_frame.get_payload().end() + ); + } + + m_local_interface->message(shared_from_this(),msg); + } else { + // fail + } + +} + +void session::extract_payload() { + std::vector &msg = m_read_frame.get_payload(); + m_current_message.resize(m_current_message.size()+msg.size()); + std::copy(msg.begin(),msg.end(),m_current_message.end()-msg.size()); +} + +void session::write_frame() { + // print debug info + m_write_frame.print_frame(); + + std::vector data; + + data.push_back( + boost::asio::buffer( + m_write_frame.get_header(), + m_write_frame.get_header_len() + ) + ); + data.push_back( + boost::asio::buffer(m_write_frame.get_payload()) + ); + + boost::asio::async_write( + m_socket, + data, + boost::bind( + &session::handle_write_frame, + shared_from_this(), + boost::asio::placeholders::error + ) + ); +} + +void session::reset_message() { + m_error = false; + m_fragmented = false; + m_current_message.clear(); +} + +void session::handle_error(std::string msg, + const boost::system::error_code& error) { + std::stringstream e; + + e << msg << " (" << error << ")"; + + std::cerr << e.str() << std::endl; + + if (m_local_interface) { + m_local_interface->disconnect(shared_from_this(),e.str()); + } + + m_error = true; +} \ No newline at end of file diff --git a/src/websocket_session.hpp b/src/websocket_session.hpp new file mode 100644 index 0000000000..479337ff93 --- /dev/null +++ b/src/websocket_session.hpp @@ -0,0 +1,209 @@ +#ifndef WEBSOCKET_SESSION_HPP +#define WEBSOCKET_SESSION_HPP + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include + +#include + +namespace websocketpp { + class session; + typedef boost::shared_ptr session_ptr; +} + +#include "websocket_frame.hpp" +#include "websocket_connection_handler.hpp" + +#include "base64/base64.h" +#include "sha1/sha1.h" + +using boost::asio::ip::tcp; + +namespace websocketpp { + +class session : public boost::enable_shared_from_this { +public: + enum ws_status { + CONNECTING, + OPEN, + CLOSING, + CLOSED + }; + + typedef enum ws_status status_code; + + session (boost::asio::io_service& io_service, + const std::string& host, + connection_handler_ptr defc) + : m_host(host), + m_socket(io_service), + m_status(CONNECTING), + m_http_error_code(0), + m_local_interface(defc) {} + + tcp::socket& socket() { + return m_socket; + } + + // This function is called to begin the session loop. This method and all + // that come after it are called as a result of an async event completing. + // if any method in this chain returns before adding a new async event the + // session will end. + void start(); + + // sets the internal connection handler of this connection to new_con. + // This is useful if you want to switch handler objects during a connection + // Example: a generic lobby handler could validate the handshake negotiate a + // sub protocol to talk to and then pass the connection off to a handler for + // that sub protocol. + void set_handler(connection_handler_ptr new_con); + + + // Public handshake interface + + // By default a failed handshake validation will return an HTTP 400 Bad + // Request error. If your application wants to reject the connection for + // another reason it can be set here. Example: 404 if the resource request + // is not recognized or 403 forbidden if the origin does not match. If only + // a code is supplied a msg will be generated automatically. + void set_http_error(int code,std::string msg = ""); + + // gets the value of a header or the empty string if not present. + std::string get_header(const std::string &key) const; + // adds an arbitrary header to the server handshake HTTP response. + void add_header(const std::string &key,const std::string &value); + + std::string get_request() const; + + // sets the subprotocol being used. This will result in the appropriate + // Sec-WebSocket-Protocol header being sent back to the client. The value + // here must have been present in the client's opening handshake. + void set_subprotocol(const std::string &protocol); + + + //int get_version(); + + //void add_extension(); + + // Public session interface + + // send basic frame types + void send(const std::string &msg); // text + void send(const std::vector &data); // binary + void ping(const std::string &msg); + void pong(const std::string &msg); + + void disconnect(const std::string &reason); +private: + // handle_read_handshake reads the HTTP headers of the initial websocket + // handshake, parses out the request and headers, and does error checking + // TODO: Generalize a lot of the hard coded things in this method. + void handle_read_handshake(const boost::system::error_code& e, + std::size_t bytes_transferred); + + // write_handshake calculates the server portion of the handshake and + // sends it back. + // TODO: Generalize this to include things like protocols, cookies, etc + void write_handshake(); + // handle_write_handshake checks for errors writing the server handshake, + // officially declares a connection open, notifies the local interface, + // and starts the frame reading loop. + void handle_write_handshake(const boost::system::error_code& error); + + // construct and write an HTTP error in the case the handshake goes poorly + void write_http_error(); + void handle_write_http_error(const boost::system::error_code& error); + + // start async read for a websocket frame (2 bytes) to handle_frame_header + void read_frame(); + + // reads frame header and devices if it needs to read more header or go + // straight to the payload. + void handle_frame_header(const boost::system::error_code& error); + + // process extra headers and start payload read + void handle_extended_frame_header(const boost::system::error_code& error); + + // initiate payload read + void read_payload(); + + // now the frame object should be complete. Process and send it on then + // reset for new frame + void handle_read_payload (const boost::system::error_code& error); + + // checks for errors writing frames + void handle_write_frame (const boost::system::error_code& error); + + // helper functions for processing each opcode + void process_ping(); + void process_pong(); + void process_text(); + void process_binary(); + void process_continuation(); + void process_close(); + + // deliver message if we have a local interface attached + void deliver_message(); + + // copies the current read frame payload into the session so that the read + // frame can be cleared for the next read. This is done when fragmented + // messages are recieved. + void extract_payload(); + + // write m_write_frame out to the socket. + void write_frame(); + + // reset session for a new message + void reset_message(); + + // prints a diagnostic message and disconnects the local interface + void handle_error(std::string msg,const boost::system::error_code& error); + + std::string lookup_http_error_string(int code); +private: + std::string m_host; + tcp::socket m_socket; + status_code m_status; + int m_http_error_code; + std::string m_http_error_string; + + boost::asio::streambuf m_buf; + std::string m_handshake; + + std::string m_request; + std::map m_headers; + + frame m_read_frame; + frame m_write_frame; + std::vector m_current_message; + + bool m_error; + bool m_fragmented; + frame::opcode m_current_opcode; + + connection_handler_ptr m_local_interface; +}; + +} + +#endif // WEBSOCKET_SESSION_HPP + + + +// better debug printing system +// set acceptible origin and host headers +// case sensitive header values? e.g. websocket + + +// double check bugs in autobahn (sending wrong localhost:9000 header) not +// checking masking in the 9.x tests diff --git a/src/websocketpp.hpp b/src/websocketpp.hpp new file mode 100644 index 0000000000..d798ef11db --- /dev/null +++ b/src/websocketpp.hpp @@ -0,0 +1,6 @@ +#ifndef WEBSOCKETPP_HPP +#define WEBSOCKETPP_HPP + +#include + +#endif // WEBSOCKETPP_HPP diff --git a/todo.txt b/todo.txt new file mode 100644 index 0000000000..809a2d1860 --- /dev/null +++ b/todo.txt @@ -0,0 +1,11 @@ +// TODO: impliment stuff from here: +// http://stackoverflow.com/questions/809902/64-bit-ntohl-in-c + +boost unit test suite? + +figure out exception handling setup + + +nagle? + +extensions \ No newline at end of file